Skip to main content

graphify_extract/treesitter/
mod.rs

//! Tree-sitter based AST extraction engine.
//!
//! Provides accurate structural extraction using native tree-sitter grammars
//! for Python, JavaScript, TypeScript, Rust, Go, Java, C, C++, Ruby, C#, and Dart.
//! For languages without a configured grammar, [`try_extract`] returns `None` so the
//! caller can fall back to the regex-based extractor.
6
7mod handlers;
8mod imports;
9mod treesitter_config;
10
11pub use treesitter_config::TsConfig;
12
13use std::collections::{HashMap, HashSet};
14use std::path::Path;
15
16use graphify_core::confidence::Confidence;
17use graphify_core::id::make_id;
18use graphify_core::model::{ExtractionResult, GraphEdge, GraphNode, NodeType};
19use tracing::trace;
20use tree_sitter::{Language, Node, Parser};
21
22pub fn try_extract(path: &Path, source: &[u8], lang: &str) -> Option<ExtractionResult> {
23    let (language, config) = treesitter_config::resolve_language(lang)?;
24    extract_with_treesitter(path, source, language, &config, lang)
25}
26
27fn extract_with_treesitter(
28    path: &Path,
29    source: &[u8],
30    language: Language,
31    config: &TsConfig,
32    lang: &str,
33) -> Option<ExtractionResult> {
34    let mut parser = Parser::new();
35    parser.set_language(&language).ok()?;
36    let tree = parser.parse(source, None)?;
37    let root = tree.root_node();
38
39    let stem = path
40        .file_stem()
41        .and_then(|s| s.to_str())
42        .unwrap_or("unknown");
43    let str_path = path.to_string_lossy();
44
45    let mut nodes = Vec::new();
46    let mut edges = Vec::new();
47    let mut seen_ids = HashSet::new();
48    let mut raw_calls: Vec<(String, String)> = Vec::new();
49    let mut ruby_bodies: Vec<(String, usize, usize)> = Vec::new();
50
51    let file_nid = make_id(&[&str_path]);
52    seen_ids.insert(file_nid.clone());
53    nodes.push(GraphNode {
54        id: file_nid.clone(),
55        label: stem.to_string(),
56        source_file: str_path.to_string(),
57        source_location: None,
58        node_type: NodeType::File,
59        community: None,
60        extra: HashMap::new(),
61    });
62
63    {
64        let mut ctx = WalkContext {
65            lang,
66            file_nid: &file_nid,
67            str_path: &str_path,
68            nodes: &mut nodes,
69            edges: &mut edges,
70            seen_ids: &mut seen_ids,
71            raw_calls: &mut raw_calls,
72            ruby_bodies: &mut ruby_bodies,
73        };
74        walk_node(root, source, config, &mut ctx, None);
75    }
76
77    let label_to_nid: HashMap<String, String> = nodes
78        .iter()
79        .filter(|n| matches!(n.node_type, NodeType::Function | NodeType::Method))
80        .map(|n| {
81            let normalized = n
82                .label
83                .trim_end_matches("()")
84                .trim_start_matches('.')
85                .to_lowercase();
86            (normalized, n.id.clone())
87        })
88        .collect();
89
90    let mut seen_calls: HashSet<(String, String)> = HashSet::new();
91    for (caller_nid, callee_name) in &raw_calls {
92        let name_lower = callee_name.to_lowercase();
93        if let Some(callee_nid) = label_to_nid.get(&name_lower) {
94            if callee_nid == caller_nid {
95                continue;
96            }
97            let key = (caller_nid.clone(), callee_nid.clone());
98            if seen_calls.insert(key) {
99                edges.push(GraphEdge {
100                    source: caller_nid.clone(),
101                    target: callee_nid.clone(),
102                    relation: "calls".to_string(),
103                    confidence: Confidence::Inferred,
104                    confidence_score: Confidence::Inferred.default_score(),
105                    source_file: str_path.to_string(),
106                    source_location: None,
107                    weight: 1.0,
108                    extra: HashMap::new(),
109                });
110            }
111        }
112    }
113
114    if lang == "ruby" {
115        for (caller_nid, body_start, body_end) in &ruby_bodies {
116            let body_text = &source[*body_start..*body_end];
117            let body_str = String::from_utf8_lossy(body_text);
118            let body_lower = body_str.to_lowercase();
119            for (func_label, callee_nid) in &label_to_nid {
120                if callee_nid == caller_nid {
121                    continue;
122                }
123                let found = body_lower.find(func_label.as_str()).is_some_and(|pos| {
124                    let after = pos + func_label.len();
125                    if after >= body_lower.len() {
126                        true
127                    } else {
128                        let next_ch = body_lower.as_bytes()[after];
129                        !next_ch.is_ascii_alphanumeric() && next_ch != b'_'
130                    }
131                });
132                if found {
133                    let key = (caller_nid.clone(), callee_nid.clone());
134                    if seen_calls.insert(key) {
135                        edges.push(GraphEdge {
136                            source: caller_nid.clone(),
137                            target: callee_nid.clone(),
138                            relation: "calls".to_string(),
139                            confidence: Confidence::Inferred,
140                            confidence_score: Confidence::Inferred.default_score(),
141                            source_file: str_path.to_string(),
142                            source_location: None,
143                            weight: 1.0,
144                            extra: HashMap::new(),
145                        });
146                    }
147                }
148            }
149        }
150    }
151
152    trace!(
153        "treesitter({}): {} nodes, {} edges from {}",
154        lang,
155        nodes.len(),
156        edges.len(),
157        str_path
158    );
159
160    Some(ExtractionResult {
161        nodes,
162        edges,
163        hyperedges: vec![],
164    })
165}
166
167fn collect_callees(body: Node, source: &[u8], config: &TsConfig) -> Vec<String> {
168    let mut callees = Vec::new();
169    collect_callees_recursive(body, source, config, &mut callees);
170    callees
171}
172
173fn collect_callees_recursive(
174    node: Node,
175    source: &[u8],
176    config: &TsConfig,
177    callees: &mut Vec<String>,
178) {
179    if config.call_types.contains(node.kind())
180        && let Some(name) = extract_callee_name(node, source, config)
181    {
182        callees.push(name);
183    }
184
185    let mut cursor = node.walk();
186    for child in node.children(&mut cursor) {
187        collect_callees_recursive(child, source, config, callees);
188    }
189}
190
191fn extract_callee_name(call_node: Node, source: &[u8], config: &TsConfig) -> Option<String> {
192    let func_node = call_node.child_by_field_name(config.call_function_field)?;
193    extract_name_from_callee(func_node, source)
194}
195
196fn extract_name_from_callee(node: Node, source: &[u8]) -> Option<String> {
197    match node.kind() {
198        "identifier" | "field_identifier" => Some(node_text(node, source)),
199        "attribute" => node
200            .child_by_field_name("attribute")
201            .map(|n| node_text(n, source)),
202        "field_expression" | "member_expression" => node
203            .child_by_field_name("field")
204            .or_else(|| node.child_by_field_name("property"))
205            .map(|n| node_text(n, source)),
206        "scoped_identifier" | "qualified_identifier" => node
207            .child_by_field_name("name")
208            .map(|n| node_text(n, source)),
209        "selector_expression" => node
210            .child_by_field_name("field")
211            .map(|n| node_text(n, source)),
212        _ => None,
213    }
214}
215
/// Category assigned by [`classify_elixir_call`] to an Elixir `call` node, based on the
/// call's target name.
enum ElixirCallKind {
    /// `import`, `use`, `require` or `alias`.
    Import,
    /// `defmodule`, `defprotocol` or `defimpl`.
    Class,
    /// `def`, `defp`, `defmacro`, `defmacrop`, `defguard`, `defguardp` or `defdelegate`.
    Function,
    /// Any other call; [`walk_node`] descends into its children.
    Other,
}
222
223fn classify_elixir_call(node: Node, source: &[u8], config: &TsConfig) -> ElixirCallKind {
224    let target = node
225        .child_by_field_name(config.name_field)
226        .map(|n| node_text(n, source))
227        .unwrap_or_default();
228    match target.as_str() {
229        "import" | "use" | "require" | "alias" => ElixirCallKind::Import,
230        "defmodule" | "defprotocol" | "defimpl" => ElixirCallKind::Class,
231        "def" | "defp" | "defmacro" | "defmacrop" | "defguard" | "defguardp" | "defdelegate" => {
232            ElixirCallKind::Function
233        }
234        _ => ElixirCallKind::Other,
235    }
236}
237
238pub(crate) fn walk_node(
239    node: Node,
240    source: &[u8],
241    config: &TsConfig,
242    ctx: &mut WalkContext,
243    parent_class_nid: Option<&str>,
244) {
245    let kind = node.kind();
246
247    if ctx.lang == "elixir" && kind == "call" {
248        match classify_elixir_call(node, source, config) {
249            ElixirCallKind::Import => {
250                imports::extract_import(
251                    node,
252                    source,
253                    ctx.file_nid,
254                    ctx.str_path,
255                    ctx.lang,
256                    ctx.edges,
257                    ctx.nodes,
258                );
259                return;
260            }
261            ElixirCallKind::Class => {
262                handlers::handle_class_like(node, source, config, ctx);
263                return;
264            }
265            ElixirCallKind::Function => {
266                handlers::handle_function(node, source, config, ctx, parent_class_nid);
267                return;
268            }
269            ElixirCallKind::Other => {}
270        }
271    } else if config.import_types.contains(kind) {
272        if ctx.lang == "ruby" && kind == "call" {
273            let method_name = node
274                .child_by_field_name("method")
275                .map(|n| node_text(n, source))
276                .unwrap_or_default();
277            if method_name == "require" || method_name == "require_relative" {
278                imports::extract_import(
279                    node,
280                    source,
281                    ctx.file_nid,
282                    ctx.str_path,
283                    ctx.lang,
284                    ctx.edges,
285                    ctx.nodes,
286                );
287                return;
288            }
289        } else {
290            imports::extract_import(
291                node,
292                source,
293                ctx.file_nid,
294                ctx.str_path,
295                ctx.lang,
296                ctx.edges,
297                ctx.nodes,
298            );
299            return;
300        }
301    } else if config.class_types.contains(kind) {
302        handlers::handle_class_like(node, source, config, ctx);
303        return;
304    } else if config.function_types.contains(kind) {
305        handlers::handle_function(node, source, config, ctx, parent_class_nid);
306        return;
307    }
308
309    let mut cursor = node.walk();
310    for child in node.children(&mut cursor) {
311        walk_node(child, source, config, ctx, parent_class_nid);
312    }
313}
314
315pub(crate) struct WalkContext<'a> {
316    pub lang: &'a str,
317    pub file_nid: &'a str,
318    pub str_path: &'a str,
319    pub nodes: &'a mut Vec<GraphNode>,
320    pub edges: &'a mut Vec<GraphEdge>,
321    pub seen_ids: &'a mut HashSet<String>,
322    pub raw_calls: &'a mut Vec<(String, String)>,
323    pub ruby_bodies: &'a mut Vec<(String, usize, usize)>,
324}
325
326pub(crate) fn node_text(node: Node, source: &[u8]) -> String {
327    node.utf8_text(source).unwrap_or("").to_string()
328}
329
330pub(crate) fn get_name(node: Node, source: &[u8], field: &str) -> Option<String> {
331    let name_node = node.child_by_field_name(field)?;
332    let text = unwrap_declarator_name(name_node, source);
333    if text.is_empty() { None } else { Some(text) }
334}
335
336pub(crate) fn unwrap_declarator_name(node: Node, source: &[u8]) -> String {
337    match node.kind() {
338        "function_declarator"
339        | "pointer_declarator"
340        | "reference_declarator"
341        | "parenthesized_declarator" => {
342            if let Some(inner) = node.child_by_field_name("declarator") {
343                return unwrap_declarator_name(inner, source);
344            }
345            let mut cursor = node.walk();
346            for child in node.children(&mut cursor) {
347                if child.kind() == "identifier" || child.kind() == "field_identifier" {
348                    return node_text(child, source);
349                }
350            }
351            node_text(node, source)
352        }
353        "qualified_identifier" | "scoped_identifier" => {
354            if let Some(name) = node.child_by_field_name("name") {
355                return node_text(name, source);
356            }
357            node_text(node, source)
358        }
359        _ => node_text(node, source),
360    }
361}
362
363pub(crate) fn make_edge(
364    source_id: &str,
365    target_id: &str,
366    relation: &str,
367    source_file: &str,
368    line: usize,
369) -> GraphEdge {
370    GraphEdge {
371        source: source_id.to_string(),
372        target: target_id.to_string(),
373        relation: relation.to_string(),
374        confidence: Confidence::Extracted,
375        confidence_score: Confidence::Extracted.default_score(),
376        source_file: source_file.to_string(),
377        source_location: Some(format!("L{line}")),
378        weight: 1.0,
379        extra: HashMap::new(),
380    }
381}