Skip to main content

graphify_extract/treesitter/
mod.rs

1//! Tree-sitter based AST extraction engine.
2//!
3//! Provides accurate structural extraction using native tree-sitter grammars
4//! for Python, JavaScript, TypeScript, Rust, Go, Java, C, C++, Ruby, C#, and Dart.
5//! Falls back gracefully to the regex-based extractor for unsupported languages.
6
7mod handlers;
8mod imports;
9mod treesitter_config;
10
11pub use treesitter_config::TsConfig;
12
13use std::collections::{HashMap, HashSet};
14use std::path::Path;
15
16use graphify_core::confidence::Confidence;
17use graphify_core::id::make_id;
18use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
19use tracing::trace;
20use tree_sitter::{Language, Node, Parser};
21
22pub fn try_extract(path: &Path, source: &[u8], lang: &str) -> Option<ExtractionResult> {
23    let (language, config) = treesitter_config::resolve_language(lang)?;
24    extract_with_treesitter(path, source, language, &config, lang)
25}
26
27fn extract_with_treesitter(
28    path: &Path,
29    source: &[u8],
30    language: Language,
31    config: &TsConfig,
32    lang: &str,
33) -> Option<ExtractionResult> {
34    let mut parser = Parser::new();
35    parser.set_language(&language).ok()?;
36    let tree = parser.parse(source, None)?;
37    let root = tree.root_node();
38
39    let stem = path
40        .file_stem()
41        .and_then(|s| s.to_str())
42        .unwrap_or("unknown");
43    let str_path = path.to_string_lossy();
44
45    let mut nodes = Vec::new();
46    let mut edges = Vec::new();
47    let mut seen_ids = HashSet::new();
48    let mut raw_calls: Vec<(String, String)> = Vec::new();
49    let mut ruby_bodies: Vec<(String, usize, usize)> = Vec::new();
50
51    let file_nid = make_id(&[&str_path]);
52    seen_ids.insert(file_nid.clone());
53    nodes.push(GraphNode {
54        id: file_nid.clone(),
55        label: stem.to_string(),
56        source_file: str_path.to_string(),
57        source_location: None,
58        node_type: NodeType::File,
59        community: None,
60        extra: HashMap::new(),
61    });
62
63    {
64        let mut ctx = WalkContext {
65            lang,
66            file_nid: &file_nid,
67            str_path: &str_path,
68            nodes: &mut nodes,
69            edges: &mut edges,
70            seen_ids: &mut seen_ids,
71            raw_calls: &mut raw_calls,
72            ruby_bodies: &mut ruby_bodies,
73        };
74        walk_node(root, source, config, &mut ctx, None);
75    }
76
77    let label_to_nid: HashMap<String, String> = nodes
78        .iter()
79        .filter(|n| matches!(n.node_type, NodeType::Function | NodeType::Method))
80        .map(|n| {
81            let normalized = n
82                .label
83                .trim_end_matches("()")
84                .trim_start_matches('.')
85                .to_lowercase();
86            (normalized, n.id.clone())
87        })
88        .collect();
89
90    let mut seen_calls: HashSet<(String, String)> = HashSet::new();
91    for (caller_nid, callee_name) in &raw_calls {
92        let name_lower = callee_name.to_lowercase();
93        if let Some(callee_nid) = label_to_nid.get(&name_lower) {
94            if callee_nid == caller_nid {
95                continue;
96            }
97            let key = (caller_nid.clone(), callee_nid.clone());
98            if seen_calls.insert(key) {
99                edges.push(GraphEdge {
100                    source: caller_nid.clone(),
101                    target: callee_nid.clone(),
102                    relation: "calls".to_string(),
103                    confidence: Confidence::Inferred,
104                    confidence_score: Confidence::Inferred.default_score(),
105                    source_file: str_path.to_string(),
106                    source_location: None,
107                    weight: 1.0,
108                    extra: HashMap::new(),
109                });
110            }
111        }
112    }
113
114    if lang == "ruby" {
115        for (caller_nid, body_start, body_end) in &ruby_bodies {
116            let body_text = &source[*body_start..*body_end];
117            let body_str = String::from_utf8_lossy(body_text);
118            let body_lower = body_str.to_lowercase();
119            for (func_label, callee_nid) in &label_to_nid {
120                if callee_nid == caller_nid {
121                    continue;
122                }
123                let found = body_lower
124                    .find(func_label.as_str())
125                    .is_some_and(|pos| {
126                        let after = pos + func_label.len();
127                        if after >= body_lower.len() {
128                            true
129                        } else {
130                            let next_ch = body_lower.as_bytes()[after];
131                            !next_ch.is_ascii_alphanumeric() && next_ch != b'_'
132                        }
133                    });
134                if found {
135                    let key = (caller_nid.clone(), callee_nid.clone());
136                    if seen_calls.insert(key) {
137                        edges.push(GraphEdge {
138                            source: caller_nid.clone(),
139                            target: callee_nid.clone(),
140                            relation: "calls".to_string(),
141                            confidence: Confidence::Inferred,
142                            confidence_score: Confidence::Inferred.default_score(),
143                            source_file: str_path.to_string(),
144                            source_location: None,
145                            weight: 1.0,
146                            extra: HashMap::new(),
147                        });
148                    }
149                }
150            }
151        }
152    }
153
154    trace!(
155        "treesitter({}): {} nodes, {} edges from {}",
156        lang,
157        nodes.len(),
158        edges.len(),
159        str_path
160    );
161
162    Some(ExtractionResult {
163        nodes,
164        edges,
165        hyperedges: vec![],
166    })
167}
168
169fn collect_callees(body: Node, source: &[u8], config: &TsConfig) -> Vec<String> {
170    let mut callees = Vec::new();
171    collect_callees_recursive(body, source, config, &mut callees);
172    callees
173}
174
175fn collect_callees_recursive(
176    node: Node,
177    source: &[u8],
178    config: &TsConfig,
179    callees: &mut Vec<String>,
180) {
181    if config.call_types.contains(node.kind())
182        && let Some(name) = extract_callee_name(node, source, config)
183    {
184        callees.push(name);
185    }
186
187    let mut cursor = node.walk();
188    for child in node.children(&mut cursor) {
189        collect_callees_recursive(child, source, config, callees);
190    }
191}
192
193fn extract_callee_name(call_node: Node, source: &[u8], config: &TsConfig) -> Option<String> {
194    let func_node = call_node.child_by_field_name(config.call_function_field)?;
195    extract_name_from_callee(func_node, source)
196}
197
198fn extract_name_from_callee(node: Node, source: &[u8]) -> Option<String> {
199    match node.kind() {
200        "identifier" | "field_identifier" => Some(node_text(node, source)),
201        "attribute" => node
202            .child_by_field_name("attribute")
203            .map(|n| node_text(n, source)),
204        "field_expression" | "member_expression" => node
205            .child_by_field_name("field")
206            .or_else(|| node.child_by_field_name("property"))
207            .map(|n| node_text(n, source)),
208        "scoped_identifier" | "qualified_identifier" => node
209            .child_by_field_name("name")
210            .map(|n| node_text(n, source)),
211        "selector_expression" => node
212            .child_by_field_name("field")
213            .map(|n| node_text(n, source)),
214        _ => None,
215    }
216}
217
/// Structural role of an Elixir `call` node, as determined by
/// `classify_elixir_call` from the text of the call's target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ElixirCallKind {
    /// `import`, `use`, `require` or `alias`.
    Import,
    /// `defmodule`, `defprotocol` or `defimpl`.
    Class,
    /// `def`, `defp`, `defmacro`, `defmacrop`, `defguard`, `defguardp` or
    /// `defdelegate`.
    Function,
    /// Any other call target (including a missing one).
    Other,
}
224
225fn classify_elixir_call(node: Node, source: &[u8], config: &TsConfig) -> ElixirCallKind {
226    let target = node
227        .child_by_field_name(config.name_field)
228        .map(|n| node_text(n, source))
229        .unwrap_or_default();
230    match target.as_str() {
231        "import" | "use" | "require" | "alias" => ElixirCallKind::Import,
232        "defmodule" | "defprotocol" | "defimpl" => ElixirCallKind::Class,
233        "def" | "defp" | "defmacro" | "defmacrop" | "defguard" | "defguardp" | "defdelegate" => {
234            ElixirCallKind::Function
235        }
236        _ => ElixirCallKind::Other,
237    }
238}
239
240pub(crate) fn walk_node(
241    node: Node,
242    source: &[u8],
243    config: &TsConfig,
244    ctx: &mut WalkContext,
245    parent_class_nid: Option<&str>,
246) {
247    let kind = node.kind();
248
249    if ctx.lang == "elixir" && kind == "call" {
250        match classify_elixir_call(node, source, config) {
251            ElixirCallKind::Import => {
252                imports::extract_import(
253                    node,
254                    source,
255                    ctx.file_nid,
256                    ctx.str_path,
257                    ctx.lang,
258                    ctx.edges,
259                    ctx.nodes,
260                );
261                return;
262            }
263            ElixirCallKind::Class => {
264                handlers::handle_class_like(node, source, config, ctx);
265                return;
266            }
267            ElixirCallKind::Function => {
268                handlers::handle_function(node, source, config, ctx, parent_class_nid);
269                return;
270            }
271            ElixirCallKind::Other => {}
272        }
273    } else if config.import_types.contains(kind) {
274        if ctx.lang == "ruby" && kind == "call" {
275            let method_name = node
276                .child_by_field_name("method")
277                .map(|n| node_text(n, source))
278                .unwrap_or_default();
279            if method_name == "require" || method_name == "require_relative" {
280                imports::extract_import(
281                    node,
282                    source,
283                    ctx.file_nid,
284                    ctx.str_path,
285                    ctx.lang,
286                    ctx.edges,
287                    ctx.nodes,
288                );
289                return;
290            }
291        } else {
292            imports::extract_import(
293                node,
294                source,
295                ctx.file_nid,
296                ctx.str_path,
297                ctx.lang,
298                ctx.edges,
299                ctx.nodes,
300            );
301            return;
302        }
303    } else if config.class_types.contains(kind) {
304        handlers::handle_class_like(node, source, config, ctx);
305        return;
306    } else if config.function_types.contains(kind) {
307        handlers::handle_function(node, source, config, ctx, parent_class_nid);
308        return;
309    }
310
311    let mut cursor = node.walk();
312    for child in node.children(&mut cursor) {
313        walk_node(child, source, config, ctx, parent_class_nid);
314    }
315}
316
/// Shared state threaded through [`walk_node`] and the handlers it calls
/// while a single file's syntax tree is traversed.
///
/// All fields borrow from locals owned by `extract_with_treesitter`, which
/// post-processes the collected data once the walk is finished.
pub(crate) struct WalkContext<'a> {
    /// Language name of the file being walked (compared against values such
    /// as `"elixir"` and `"ruby"`).
    pub lang: &'a str,
    /// Id of the `File` node representing the file being walked.
    pub file_nid: &'a str,
    /// Path of the file being walked, as a (lossily converted) string.
    pub str_path: &'a str,
    /// Accumulator for graph nodes; already contains the `File` node when the
    /// walk starts.
    pub nodes: &'a mut Vec<GraphNode>,
    /// Accumulator for graph edges.
    pub edges: &'a mut Vec<GraphEdge>,
    /// Node ids seen so far; seeded with `file_nid`.
    /// NOTE(review): presumably used by the handlers to avoid duplicate
    /// nodes — confirm in `handlers`.
    pub seen_ids: &'a mut HashSet<String>,
    /// `(caller node id, callee name)` pairs, resolved into `calls` edges
    /// after the walk.
    pub raw_calls: &'a mut Vec<(String, String)>,
    /// `(caller node id, body start, body end)` entries; start/end are byte
    /// offsets into the source and are only consumed when `lang` is `"ruby"`.
    pub ruby_bodies: &'a mut Vec<(String, usize, usize)>,
}
327
328pub(crate) fn node_text(node: Node, source: &[u8]) -> String {
329    node.utf8_text(source).unwrap_or("").to_string()
330}
331
332pub(crate) fn get_name(node: Node, source: &[u8], field: &str) -> Option<String> {
333    let name_node = node.child_by_field_name(field)?;
334    let text = unwrap_declarator_name(name_node, source);
335    if text.is_empty() { None } else { Some(text) }
336}
337
338pub(crate) fn unwrap_declarator_name(node: Node, source: &[u8]) -> String {
339    match node.kind() {
340        "function_declarator"
341        | "pointer_declarator"
342        | "reference_declarator"
343        | "parenthesized_declarator" => {
344            if let Some(inner) = node.child_by_field_name("declarator") {
345                return unwrap_declarator_name(inner, source);
346            }
347            let mut cursor = node.walk();
348            for child in node.children(&mut cursor) {
349                if child.kind() == "identifier" || child.kind() == "field_identifier" {
350                    return node_text(child, source);
351                }
352            }
353            node_text(node, source)
354        }
355        "qualified_identifier" | "scoped_identifier" => {
356            if let Some(name) = node.child_by_field_name("name") {
357                return node_text(name, source);
358            }
359            node_text(node, source)
360        }
361        _ => node_text(node, source),
362    }
363}
364
365pub(crate) fn make_edge(
366    source_id: &str,
367    target_id: &str,
368    relation: &str,
369    source_file: &str,
370    line: usize,
371) -> GraphEdge {
372    GraphEdge {
373        source: source_id.to_string(),
374        target: target_id.to_string(),
375        relation: relation.to_string(),
376        confidence: Confidence::Extracted,
377        confidence_score: Confidence::Extracted.default_score(),
378        source_file: source_file.to_string(),
379        source_location: Some(format!("L{line}")),
380        weight: 1.0,
381        extra: HashMap::new(),
382    }
383}