Skip to main content

mimir_graph/
extract.rs

1//! Tree-sitter symbol/call/import extraction. No LLM, no type checker —
2//! honest static extraction with explicit confidence tiers downstream.
3
4use tree_sitter::{Node, Parser};
5
6use crate::languages::Lang;
7
/// One definition (function, method, type, ...) found in a source file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SymbolDef {
    /// Bare name (resolution bucket): the last `::` segment of the
    /// definition's name, e.g. "resolve_ref".
    pub name: String,
    /// Nesting-qualified name: the enclosing scope names and the symbol's
    /// own name joined with "::" in every language, e.g.
    /// "MatrixCache::ensure" or "ClassName::method".
    pub qualified: String,
    /// function | method | struct | class | trait | enum | interface | type
    pub kind: &'static str,
    /// Signature line(s) — what gets embedded alongside the doc comment.
    pub signature: String,
    /// Doc comment / docstring attached to the definition, if any.
    pub doc: Option<String>,
    /// First line of the definition; 1-based, inclusive.
    pub start_line: usize,
    /// Last line of the definition; 1-based, inclusive.
    pub end_line: usize,
}
23
/// One call expression, attributed to the definition it appears in.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallSite {
    /// Qualified name of the enclosing definition ("" = file top level).
    pub caller: String,
    /// Bare callee name as written (rightmost path segment), e.g. "Sprintf"
    /// for `fmt.Sprintf(...)`.
    pub callee: String,
}
31
/// One name brought into scope by an import / use declaration.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ImportRef {
    /// Name bound locally (rightmost segment or alias).
    pub local: String,
    /// Module/path text as written ("./util", "foo::bar", "pkg.mod").
    /// For a Rust `use` this is the full path including the imported item;
    /// for Python `from ... import`, TypeScript and Go it is the module.
    pub source: String,
}
39
40#[derive(Debug, Default)]
41pub struct FileExtract {
42    pub symbols: Vec<SymbolDef>,
43    pub calls: Vec<CallSite>,
44    pub imports: Vec<ImportRef>,
45}
46
47pub fn extract(lang: Lang, source: &str) -> FileExtract {
48    let mut parser = Parser::new();
49    if parser.set_language(&lang.language()).is_err() {
50        return FileExtract::default();
51    }
52    let Some(tree) = parser.parse(source, None) else {
53        return FileExtract::default();
54    };
55    let mut out = FileExtract::default();
56    walk(lang, tree.root_node(), source, &mut Vec::new(), &mut out);
57    out
58}
59
60/// Recursive walk keeping a stack of enclosing definition names.
61fn walk(lang: Lang, node: Node, src: &str, scope: &mut Vec<String>, out: &mut FileExtract) {
62    let mut pushed = false;
63
64    if let Some((name, kind)) = lang.definition(node, src) {
65        let qualified = qualify(scope, &name);
66        // Methods: a function nested inside a type/class scope.
67        let kind = if kind == "function" && !scope.is_empty() {
68            "method"
69        } else {
70            kind
71        };
72        out.symbols.push(SymbolDef {
73            signature: signature_text(lang, node, src),
74            doc: lang.doc_comment(node, src),
75            start_line: node.start_position().row + 1,
76            end_line: node.end_position().row + 1,
77            name: qualified_tail(&name),
78            qualified: qualified.clone(),
79            kind,
80        });
81        // Push the name as returned (it may carry a `::` receiver prefix,
82        // e.g. Go methods) so scope.join("::") == qualified for children.
83        scope.push(name);
84        pushed = true;
85    } else if let Some(scope_name) = lang.scope_only(node, src) {
86        // Containers that qualify children but aren't symbols themselves
87        // (Rust impl blocks, modules).
88        scope.push(scope_name);
89        pushed = true;
90    }
91
92    if let Some(callee) = lang.call(node, src) {
93        out.calls.push(CallSite {
94            caller: scope.join(&lang.separator()),
95            callee,
96        });
97    }
98    lang.imports(node, src, &mut out.imports);
99
100    let mut cursor = node.walk();
101    for child in node.children(&mut cursor) {
102        walk(lang, child, src, scope, out);
103    }
104    if pushed {
105        scope.pop();
106    }
107}
108
/// Builds a nesting-qualified name by joining the enclosing scope names and
/// `name` with "::". At top level (empty `scope`), `name` is returned as-is.
fn qualify(scope: &[String], name: &str) -> String {
    match scope {
        [] => name.to_owned(),
        outer => {
            let mut qualified = outer.join("::");
            qualified.push_str("::");
            qualified.push_str(name);
            qualified
        }
    }
}
116
/// Returns the part of `qualified` after its last "::" (the whole string
/// when there is none), e.g. "Server::Start" -> "Start".
fn qualified_tail(qualified: &str) -> String {
    qualified
        .rsplit_once("::")
        .map_or(qualified, |(_, tail)| tail)
        .to_owned()
}
124
125/// Text from the definition start to its body — the signature.
126fn signature_text(lang: Lang, node: Node, src: &str) -> String {
127    let full = &src[node.byte_range()];
128    let cut = lang
129        .body_field()
130        .and_then(|f| node.child_by_field_name(f))
131        .map(|b| b.start_byte().saturating_sub(node.start_byte()))
132        .unwrap_or(full.len());
133    let sig: String = full[..cut].split_whitespace().collect::<Vec<_>>().join(" ");
134    // Defensive cap: pathological one-line definitions.
135    sig.chars().take(300).collect()
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// `(qualified, kind)` for every extracted symbol.
    fn symbol_kinds(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.symbols
            .iter()
            .map(|sym| (sym.qualified.as_str(), sym.kind))
            .collect()
    }

    /// `(caller, callee)` for every extracted call site.
    fn call_pairs(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.calls
            .iter()
            .map(|call| (call.caller.as_str(), call.callee.as_str()))
            .collect()
    }

    /// `(local, source)` for every extracted import.
    fn import_pairs(fx: &FileExtract) -> Vec<(&str, &str)> {
        fx.imports
            .iter()
            .map(|imp| (imp.local.as_str(), imp.source.as_str()))
            .collect()
    }

    /// The symbol whose qualified name is exactly `qualified`.
    fn find_symbol<'a>(fx: &'a FileExtract, qualified: &str) -> &'a SymbolDef {
        fx.symbols
            .iter()
            .find(|sym| sym.qualified == qualified)
            .unwrap()
    }

    #[test]
    fn rust_extraction() {
        let src = r#"
//! module docs

/// Adds things.
pub fn add(a: i32, b: i32) -> i32 { helper(a) + b }

fn helper(x: i32) -> i32 { x }

pub struct Counter { n: u64 }

impl Counter {
    /// Bump it.
    pub fn bump(&mut self) { self.n += 1; validate(self.n); }
}

pub trait Resettable { fn reset(&mut self); }

pub enum Mode { A, B }

use std::collections::HashMap;
use crate::store::resolve_ref as rr;
"#;
        let fx = extract(Lang::Rust, src);

        let n = symbol_kinds(&fx);
        for want in [
            ("add", "function"),
            ("helper", "function"),
            ("Counter", "struct"),
            ("Counter::bump", "method"),
            ("Resettable", "trait"),
            ("Mode", "enum"),
        ] {
            assert!(n.contains(&want), "{n:?}");
        }

        let add = find_symbol(&fx, "add");
        assert_eq!(add.doc.as_deref(), Some("Adds things."));
        assert!(add.signature.contains("pub fn add(a: i32, b: i32) -> i32"));

        let calls = call_pairs(&fx);
        for want in [("add", "helper"), ("Counter::bump", "validate")] {
            assert!(calls.contains(&want), "{calls:?}");
        }

        let imports = import_pairs(&fx);
        for want in [
            ("HashMap", "std::collections::HashMap"),
            ("rr", "crate::store::resolve_ref"),
        ] {
            assert!(imports.contains(&want), "{imports:?}");
        }
    }

    #[test]
    fn typescript_extraction() {
        let src = r#"
import { fetchUser, postUser as pu } from "./api";
import db from "../db";

/** Greets. */
export function greet(name: string): string { return hello(name); }

const shout = (s: string) => s.toUpperCase();

export class UserService {
    find(id: number) { return fetchUser(id); }
}

interface Shape { area(): number; }
"#;
        let fx = extract(Lang::TypeScript, src);

        let n = symbol_kinds(&fx);
        for want in [
            ("greet", "function"),
            ("shout", "function"),
            ("UserService", "class"),
            ("UserService::find", "method"),
            ("Shape", "interface"),
        ] {
            assert!(n.contains(&want), "{n:?}");
        }

        let calls = call_pairs(&fx);
        for want in [("greet", "hello"), ("UserService::find", "fetchUser")] {
            assert!(calls.contains(&want), "{calls:?}");
        }

        let imports = import_pairs(&fx);
        for want in [("fetchUser", "./api"), ("pu", "./api"), ("db", "../db")] {
            assert!(imports.contains(&want), "{imports:?}");
        }
    }

    #[test]
    fn python_extraction() {
        let src = r#"
import os
from collections import OrderedDict as OD
from .util import slugify

def top(x):
    """Top-level docstring."""
    return slugify(x)

class Repo:
    def save(self, item):
        validate(item)
        return persist(item)
"#;
        let fx = extract(Lang::Python, src);

        let n = symbol_kinds(&fx);
        for want in [
            ("top", "function"),
            ("Repo", "class"),
            ("Repo::save", "method"),
        ] {
            assert!(n.contains(&want), "{n:?}");
        }

        let top_sym = find_symbol(&fx, "top");
        assert_eq!(top_sym.doc.as_deref(), Some("Top-level docstring."));

        let calls = call_pairs(&fx);
        for want in [("top", "slugify"), ("Repo::save", "validate")] {
            assert!(calls.contains(&want), "{calls:?}");
        }

        let imports = import_pairs(&fx);
        for want in [("os", "os"), ("OD", "collections"), ("slugify", ".util")] {
            assert!(imports.contains(&want), "{imports:?}");
        }
    }

    #[test]
    fn go_extraction() {
        let src = r#"
package main

import (
    "fmt"
    alias "net/http"
)

// Greet says hi.
func Greet(name string) string { return fmt.Sprintf("hi %s", name) }

type Server struct{ port int }

func (s *Server) Start() error { return listen(s.port) }
"#;
        let fx = extract(Lang::Go, src);

        let n = symbol_kinds(&fx);
        for want in [
            ("Greet", "function"),
            ("Server", "struct"),
            ("Server::Start", "method"),
        ] {
            assert!(n.contains(&want), "{n:?}");
        }

        let greet = find_symbol(&fx, "Greet");
        assert_eq!(greet.doc.as_deref(), Some("Greet says hi."));

        let calls = call_pairs(&fx);
        for want in [("Greet", "Sprintf"), ("Server::Start", "listen")] {
            assert!(calls.contains(&want), "{calls:?}");
        }

        let imports = import_pairs(&fx);
        for want in [("fmt", "fmt"), ("alias", "net/http")] {
            assert!(imports.contains(&want), "{imports:?}");
        }
    }

    #[test]
    fn broken_source_does_not_panic() {
        let inputs = ["fn class def func ((((", ""];
        for lang in [Lang::Rust, Lang::TypeScript, Lang::Python, Lang::Go] {
            for input in inputs {
                extract(lang, input);
            }
        }
    }
}