Skip to main content

mimir_graph/
languages.rs

1//! Per-language tree-sitter adapters: which AST nodes are definitions,
2//! scopes, calls, and imports — and how to read docs/signatures off them.
3
4use tree_sitter::Node;
5
6use crate::extract::ImportRef;
7
/// Source languages that have a tree-sitter adapter in this module.
///
/// The variant for a file is chosen by extension in [`Lang::from_path`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Lang {
    /// `.rs` files.
    Rust,
    /// `.ts` / `.mts` / `.cts`, parsed with the TypeScript grammar.
    TypeScript,
    /// `.tsx` plus JavaScript sources (`.jsx`, `.js`, `.mjs`, `.cjs`), all
    /// parsed with the TSX grammar.
    Tsx,
    /// `.py` / `.pyi` files.
    Python,
    /// `.go` files.
    Go,
    /// `.java` files.
    Java,
    /// `.rb` / `.rake` files.
    Ruby,
    /// C sources and headers (`.c`, `.h`).
    C,
}
19
20impl Lang {
21    pub fn from_path(path: &str) -> Option<Lang> {
22        let ext = path.rsplit('.').next()?;
23        Some(match ext {
24            "rs" => Lang::Rust,
25            "ts" | "mts" | "cts" => Lang::TypeScript,
26            "tsx" | "jsx" | "js" | "mjs" | "cjs" => Lang::Tsx,
27            "py" | "pyi" => Lang::Python,
28            "go" => Lang::Go,
29            "java" => Lang::Java,
30            "rb" | "rake" => Lang::Ruby,
31            "c" | "h" => Lang::C,
32            _ => return None,
33        })
34    }
35
36    pub fn name(&self) -> &'static str {
37        match self {
38            Lang::Rust => "rust",
39            Lang::TypeScript | Lang::Tsx => "typescript",
40            Lang::Python => "python",
41            Lang::Go => "go",
42            Lang::Java => "java",
43            Lang::Ruby => "ruby",
44            Lang::C => "c",
45        }
46    }
47
48    pub fn language(&self) -> tree_sitter::Language {
49        match self {
50            Lang::Rust => tree_sitter_rust::LANGUAGE.into(),
51            Lang::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
52            Lang::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
53            Lang::Python => tree_sitter_python::LANGUAGE.into(),
54            Lang::Go => tree_sitter_go::LANGUAGE.into(),
55            Lang::Java => tree_sitter_java::LANGUAGE.into(),
56            Lang::Ruby => tree_sitter_ruby::LANGUAGE.into(),
57            Lang::C => tree_sitter_c::LANGUAGE.into(),
58        }
59    }
60
61    pub fn separator(&self) -> String {
62        "::".into()
63    }
64
65    /// Is this node a symbol definition? Returns (name, kind). The name may
66    /// be pre-qualified with `::` (Go methods carry their receiver).
67    pub fn definition(&self, node: Node, src: &str) -> Option<(String, &'static str)> {
68        let text = |n: Node| src[n.byte_range()].to_string();
69        match self {
70            Lang::Rust => match node.kind() {
71                "function_item" => Some((text(node.child_by_field_name("name")?), "function")),
72                "struct_item" => Some((text(node.child_by_field_name("name")?), "struct")),
73                "enum_item" => Some((text(node.child_by_field_name("name")?), "enum")),
74                "trait_item" => Some((text(node.child_by_field_name("name")?), "trait")),
75                "union_item" => Some((text(node.child_by_field_name("name")?), "struct")),
76                _ => None,
77            },
78            Lang::TypeScript | Lang::Tsx => match node.kind() {
79                "function_declaration" | "generator_function_declaration" => {
80                    Some((text(node.child_by_field_name("name")?), "function"))
81                }
82                "class_declaration" => Some((text(node.child_by_field_name("name")?), "class")),
83                "method_definition" => {
84                    let name = text(node.child_by_field_name("name")?);
85                    if name == "constructor" {
86                        return None;
87                    }
88                    Some((name, "method"))
89                }
90                "interface_declaration" => {
91                    Some((text(node.child_by_field_name("name")?), "interface"))
92                }
93                "enum_declaration" => Some((text(node.child_by_field_name("name")?), "enum")),
94                "type_alias_declaration" => Some((text(node.child_by_field_name("name")?), "type")),
95                // const f = (..) => ..  /  const f = function(..) {..}
96                "variable_declarator" => {
97                    let value = node.child_by_field_name("value")?;
98                    if matches!(value.kind(), "arrow_function" | "function_expression") {
99                        let name = node.child_by_field_name("name")?;
100                        if name.kind() == "identifier" {
101                            return Some((text(name), "function"));
102                        }
103                    }
104                    None
105                }
106                _ => None,
107            },
108            Lang::Python => match node.kind() {
109                "function_definition" => {
110                    Some((text(node.child_by_field_name("name")?), "function"))
111                }
112                "class_definition" => Some((text(node.child_by_field_name("name")?), "class")),
113                _ => None,
114            },
115            Lang::Go => match node.kind() {
116                "function_declaration" => {
117                    Some((text(node.child_by_field_name("name")?), "function"))
118                }
119                "method_declaration" => {
120                    let name = text(node.child_by_field_name("name")?);
121                    let recv = node
122                        .child_by_field_name("receiver")
123                        .and_then(|r| receiver_type(r, src));
124                    Some((
125                        match recv {
126                            Some(t) => format!("{t}::{name}"),
127                            None => name,
128                        },
129                        "method",
130                    ))
131                }
132                "type_spec" => {
133                    let name = text(node.child_by_field_name("name")?);
134                    let kind = match node.child_by_field_name("type").map(|t| t.kind()) {
135                        Some("struct_type") => "struct",
136                        Some("interface_type") => "interface",
137                        _ => "type",
138                    };
139                    Some((name, kind))
140                }
141                _ => None,
142            },
143            Lang::Java => match node.kind() {
144                "class_declaration" => Some((text(node.child_by_field_name("name")?), "class")),
145                "interface_declaration" => {
146                    Some((text(node.child_by_field_name("name")?), "interface"))
147                }
148                "enum_declaration" => Some((text(node.child_by_field_name("name")?), "enum")),
149                "record_declaration" => Some((text(node.child_by_field_name("name")?), "class")),
150                "method_declaration" => Some((text(node.child_by_field_name("name")?), "method")),
151                "constructor_declaration" => {
152                    Some((text(node.child_by_field_name("name")?), "method"))
153                }
154                _ => None,
155            },
156            Lang::Ruby => match node.kind() {
157                "class" => Some((const_name(node.child_by_field_name("name")?, src), "class")),
158                "module" => Some((const_name(node.child_by_field_name("name")?, src), "module")),
159                "method" => Some((text(node.child_by_field_name("name")?), "method")),
160                // def self.foo — a class method.
161                "singleton_method" => Some((text(node.child_by_field_name("name")?), "method")),
162                _ => None,
163            },
164            Lang::C => match node.kind() {
165                "function_definition" => Some((
166                    c_declarator_name(node.child_by_field_name("declarator")?, src)?,
167                    "function",
168                )),
169                "struct_specifier" => Some((text(node.child_by_field_name("name")?), "struct")),
170                "union_specifier" => Some((text(node.child_by_field_name("name")?), "struct")),
171                "enum_specifier" => Some((text(node.child_by_field_name("name")?), "enum")),
172                _ => None,
173            },
174        }
175    }
176
177    /// Containers that qualify children without being symbols themselves.
178    pub fn scope_only(&self, node: Node, src: &str) -> Option<String> {
179        match self {
180            Lang::Rust => match node.kind() {
181                // impl Foo { .. } / impl Trait for Foo { .. } → scope "Foo"
182                "impl_item" => {
183                    let ty = node.child_by_field_name("type")?;
184                    Some(base_type_name(ty, src))
185                }
186                "mod_item" => Some(src[node.child_by_field_name("name")?.byte_range()].to_string()),
187                _ => None,
188            },
189            _ => None,
190        }
191    }
192
193    /// Field holding the body (cut point for signatures), per node kind.
194    pub fn body_field(&self) -> Option<&'static str> {
195        // All supported definition kinds use "body" except TS declarators,
196        // which signature_text handles via the generic fallback.
197        Some("body")
198    }
199
200    /// If this node is a call, return the bare callee name.
201    pub fn call(&self, node: Node, src: &str) -> Option<String> {
202        let text = |n: Node| src[n.byte_range()].to_string();
203        match self {
204            Lang::Rust => {
205                if node.kind() != "call_expression" {
206                    return None;
207                }
208                let f = node.child_by_field_name("function")?;
209                match f.kind() {
210                    "identifier" => Some(text(f)),
211                    "field_expression" => f.child_by_field_name("field").map(text),
212                    "scoped_identifier" => f.child_by_field_name("name").map(text),
213                    "generic_function" => {
214                        let inner = f.child_by_field_name("function")?;
215                        match inner.kind() {
216                            "identifier" => Some(text(inner)),
217                            "scoped_identifier" => inner.child_by_field_name("name").map(text),
218                            _ => None,
219                        }
220                    }
221                    _ => None,
222                }
223            }
224            Lang::TypeScript | Lang::Tsx => {
225                if node.kind() != "call_expression" {
226                    return None;
227                }
228                let f = node.child_by_field_name("function")?;
229                match f.kind() {
230                    "identifier" => Some(text(f)),
231                    "member_expression" => f.child_by_field_name("property").map(text),
232                    _ => None,
233                }
234            }
235            Lang::Python => {
236                if node.kind() != "call" {
237                    return None;
238                }
239                let f = node.child_by_field_name("function")?;
240                match f.kind() {
241                    "identifier" => Some(text(f)),
242                    "attribute" => f.child_by_field_name("attribute").map(text),
243                    _ => None,
244                }
245            }
246            Lang::Go => {
247                if node.kind() != "call_expression" {
248                    return None;
249                }
250                let f = node.child_by_field_name("function")?;
251                match f.kind() {
252                    "identifier" => Some(text(f)),
253                    "selector_expression" => f.child_by_field_name("field").map(text),
254                    _ => None,
255                }
256            }
257            Lang::Java => {
258                if node.kind() != "method_invocation" {
259                    return None;
260                }
261                node.child_by_field_name("name").map(text)
262            }
263            Lang::Ruby => {
264                // `foo(...)`, `obj.foo(...)`, `obj.foo` — the method name.
265                if node.kind() != "call" {
266                    return None;
267                }
268                node.child_by_field_name("method").map(text)
269            }
270            Lang::C => {
271                if node.kind() != "call_expression" {
272                    return None;
273                }
274                let f = node.child_by_field_name("function")?;
275                match f.kind() {
276                    "identifier" => Some(text(f)),
277                    _ => None,
278                }
279            }
280        }
281    }
282
283    /// Collect imports declared by this node.
284    pub fn imports(&self, node: Node, src: &str, out: &mut Vec<ImportRef>) {
285        let text = |n: Node| src[n.byte_range()].to_string();
286        match self {
287            Lang::Rust => {
288                if node.kind() == "use_declaration" {
289                    if let Some(arg) = node.child_by_field_name("argument") {
290                        rust_use_tree(arg, src, "", out);
291                    }
292                }
293            }
294            Lang::TypeScript | Lang::Tsx => {
295                if node.kind() != "import_statement" {
296                    return;
297                }
298                let Some(source) = node
299                    .child_by_field_name("source")
300                    .map(|s| text(s).trim_matches(['"', '\'']).to_string())
301                else {
302                    return;
303                };
304                let mut cursor = node.walk();
305                for child in node.children(&mut cursor) {
306                    if child.kind() != "import_clause" {
307                        continue;
308                    }
309                    let mut c2 = child.walk();
310                    for part in child.children(&mut c2) {
311                        match part.kind() {
312                            "identifier" => out.push(ImportRef {
313                                local: text(part),
314                                source: source.clone(),
315                            }),
316                            "named_imports" => {
317                                let mut c3 = part.walk();
318                                for spec in part.children(&mut c3) {
319                                    if spec.kind() != "import_specifier" {
320                                        continue;
321                                    }
322                                    let local = spec
323                                        .child_by_field_name("alias")
324                                        .or_else(|| spec.child_by_field_name("name"))
325                                        .map(text);
326                                    if let Some(local) = local {
327                                        out.push(ImportRef {
328                                            local,
329                                            source: source.clone(),
330                                        });
331                                    }
332                                }
333                            }
334                            "namespace_import" => {
335                                // import * as ns from "x"
336                                let mut c3 = part.walk();
337                                for id in part.children(&mut c3) {
338                                    if id.kind() == "identifier" {
339                                        out.push(ImportRef {
340                                            local: text(id),
341                                            source: source.clone(),
342                                        });
343                                    }
344                                }
345                            }
346                            _ => {}
347                        }
348                    }
349                }
350            }
351            Lang::Python => match node.kind() {
352                "import_statement" => {
353                    let mut cursor = node.walk();
354                    for child in node.children(&mut cursor) {
355                        match child.kind() {
356                            "dotted_name" => out.push(ImportRef {
357                                local: text(child)
358                                    .rsplit('.')
359                                    .next()
360                                    .unwrap_or_default()
361                                    .to_string(),
362                                source: text(child),
363                            }),
364                            "aliased_import" => {
365                                let name = child.child_by_field_name("name").map(text);
366                                let alias = child.child_by_field_name("alias").map(text);
367                                if let (Some(name), Some(alias)) = (name, alias) {
368                                    out.push(ImportRef {
369                                        local: alias,
370                                        source: name,
371                                    });
372                                }
373                            }
374                            _ => {}
375                        }
376                    }
377                }
378                "import_from_statement" => {
379                    let Some(module) = node.child_by_field_name("module_name").map(text) else {
380                        return;
381                    };
382                    let mut cursor = node.walk();
383                    let mut past_import = false;
384                    for child in node.children(&mut cursor) {
385                        if child.kind() == "import" {
386                            past_import = true;
387                            continue;
388                        }
389                        if !past_import {
390                            continue;
391                        }
392                        match child.kind() {
393                            "dotted_name" => out.push(ImportRef {
394                                local: text(child),
395                                source: module.clone(),
396                            }),
397                            "aliased_import" => {
398                                if let Some(alias) = child.child_by_field_name("alias").map(text) {
399                                    out.push(ImportRef {
400                                        local: alias,
401                                        source: module.clone(),
402                                    });
403                                }
404                            }
405                            _ => {}
406                        }
407                    }
408                }
409                _ => {}
410            },
411            Lang::Go => {
412                if node.kind() != "import_spec" {
413                    return;
414                }
415                let Some(path) = node
416                    .child_by_field_name("path")
417                    .map(|p| text(p).trim_matches('"').to_string())
418                else {
419                    return;
420                };
421                let local = node
422                    .child_by_field_name("name")
423                    .map(text)
424                    .unwrap_or_else(|| path.rsplit('/').next().unwrap_or(&path).to_string());
425                out.push(ImportRef {
426                    local,
427                    source: path,
428                });
429            }
430            Lang::Java => {
431                // import a.b.C;  /  import static a.b.C.m;  → bind the last segment.
432                if node.kind() != "import_declaration" {
433                    return;
434                }
435                let mut cursor = node.walk();
436                let Some(scoped) = node
437                    .children(&mut cursor)
438                    .find(|c| c.kind() == "scoped_identifier")
439                else {
440                    return;
441                };
442                let source = text(scoped);
443                let local = source.rsplit('.').next().unwrap_or(&source).to_string();
444                out.push(ImportRef { local, source });
445            }
446            Lang::C => {
447                // #include "foo.h" / <foo.h> → a file→file edge by path.
448                if node.kind() != "preproc_include" {
449                    return;
450                }
451                let Some(path_node) = node.child_by_field_name("path") else {
452                    return;
453                };
454                let source = text(path_node).trim_matches(['"', '<', '>']).to_string();
455                let local = source
456                    .rsplit('/')
457                    .next()
458                    .unwrap_or(&source)
459                    .trim_end_matches(".h")
460                    .to_string();
461                out.push(ImportRef { local, source });
462            }
463            // Ruby's `require` is a method call, not an import node; calls
464            // still resolve same-file (tier 1) and globally by name (tier 3).
465            Lang::Ruby => {}
466        }
467    }
468
469    /// Doc comment attached to a definition node.
470    pub fn doc_comment(&self, node: Node, src: &str) -> Option<String> {
471        match self {
472            Lang::Python => {
473                // Docstring: first statement of the body is a string literal.
474                let body = node.child_by_field_name("body")?;
475                let first = body.named_child(0)?;
476                if first.kind() != "expression_statement" {
477                    return None;
478                }
479                let s = first.named_child(0)?;
480                if s.kind() != "string" {
481                    return None;
482                }
483                let raw = &src[s.byte_range()];
484                let cleaned = raw
485                    .trim_start_matches(['r', 'b', 'f', 'u', 'R', 'B', 'F', 'U'])
486                    .trim_matches(['"', '\''])
487                    .trim();
488                Some(cleaned.lines().next().unwrap_or("").trim().to_string())
489                    .filter(|s| !s.is_empty())
490            }
491            Lang::Rust
492            | Lang::Go
493            | Lang::TypeScript
494            | Lang::Tsx
495            | Lang::Java
496            | Lang::C
497            | Lang::Ruby => {
498                // Contiguous comment siblings directly above the node
499                // (a blank line breaks the chain; `//!` belongs to the
500                // module, not this item).
501                let mut lines: Vec<String> = Vec::new();
502                let mut expect_row = node.start_position().row;
503                let mut prev = node.prev_sibling();
504                while let Some(p) = prev {
505                    if !p.kind().contains("comment")
506                        || expect_row.saturating_sub(p.end_position().row) > 1
507                        || src[p.byte_range()].starts_with("//!")
508                    {
509                        break;
510                    }
511                    lines.push(src[p.byte_range()].to_string());
512                    expect_row = p.start_position().row;
513                    prev = p.prev_sibling();
514                }
515                if lines.is_empty() {
516                    return None;
517                }
518                lines.reverse();
519                let cleaned: Vec<String> = lines
520                    .iter()
521                    .flat_map(|c| c.lines())
522                    .map(|l| {
523                        l.trim()
524                            .trim_start_matches("///")
525                            .trim_start_matches("//!")
526                            .trim_start_matches("//")
527                            .trim_start_matches("/**")
528                            .trim_start_matches("/*")
529                            .trim_end_matches("*/")
530                            .trim_start_matches('*')
531                            .trim_start_matches('#') // Ruby line comments
532                            .trim()
533                            .to_string()
534                    })
535                    .filter(|l| !l.is_empty())
536                    .collect();
537                if cleaned.is_empty() {
538                    None
539                } else {
540                    Some(cleaned.join(" ").chars().take(300).collect())
541                }
542            }
543        }
544    }
545}
546
547/// `impl Foo`, `impl Foo<T>`, `impl Trait for Foo<T>` → "Foo".
548fn base_type_name(ty: Node, src: &str) -> String {
549    match ty.kind() {
550        "generic_type" => ty
551            .child_by_field_name("type")
552            .map(|t| src[t.byte_range()].to_string())
553            .unwrap_or_else(|| src[ty.byte_range()].to_string()),
554        _ => src[ty.byte_range()].to_string(),
555    }
556}
557
558/// Ruby class/module name: a bare `constant` or `A::B` scope_resolution —
559/// keep the last segment as the local name.
560fn const_name(node: Node, src: &str) -> String {
561    let full = src[node.byte_range()].to_string();
562    full.rsplit("::").next().unwrap_or(&full).to_string()
563}
564
565/// C function name: unwrap pointer/function declarators down to the
566/// identifier — `*foo(...)`, `foo(...)`, `(*foo)(...)` all yield "foo".
567fn c_declarator_name(node: Node, src: &str) -> Option<String> {
568    match node.kind() {
569        "identifier" => Some(src[node.byte_range()].to_string()),
570        "function_declarator" | "pointer_declarator" | "parenthesized_declarator" => {
571            c_declarator_name(node.child_by_field_name("declarator")?, src)
572        }
573        _ => {
574            // Fall back to the first identifier descendant.
575            let mut cursor = node.walk();
576            for child in node.children(&mut cursor) {
577                if let Some(name) = c_declarator_name(child, src) {
578                    return Some(name);
579                }
580            }
581            None
582        }
583    }
584}
585
586/// Go receiver `(s *Server)` → "Server".
587fn receiver_type(receiver: Node, src: &str) -> Option<String> {
588    let mut cursor = receiver.walk();
589    for child in receiver.children(&mut cursor) {
590        if child.kind() == "parameter_declaration" {
591            let ty = child.child_by_field_name("type")?;
592            let base = match ty.kind() {
593                "pointer_type" => ty.named_child(0)?,
594                _ => ty,
595            };
596            return Some(src[base.byte_range()].to_string());
597        }
598    }
599    None
600}
601
602/// Rust use-tree walker: `use a::{b::C, d as E};` → C←a::b::C, E←a::d.
603fn rust_use_tree(node: Node, src: &str, prefix: &str, out: &mut Vec<ImportRef>) {
604    let text = |n: Node| src[n.byte_range()].to_string();
605    let join = |prefix: &str, seg: &str| {
606        if prefix.is_empty() {
607            seg.to_string()
608        } else {
609            format!("{prefix}::{seg}")
610        }
611    };
612    match node.kind() {
613        "identifier" | "crate" | "self" | "super" => {
614            let seg = text(node);
615            out.push(ImportRef {
616                local: seg.clone(),
617                source: join(prefix, &seg),
618            });
619        }
620        "scoped_identifier" => {
621            let full = join(prefix, &text(node));
622            let local = node
623                .child_by_field_name("name")
624                .map(text)
625                .unwrap_or_default();
626            if !local.is_empty() {
627                out.push(ImportRef {
628                    local,
629                    source: full,
630                });
631            }
632        }
633        "use_as_clause" => {
634            let alias = node.child_by_field_name("alias").map(text);
635            let path = node.child_by_field_name("path").map(text);
636            if let (Some(alias), Some(path)) = (alias, path) {
637                out.push(ImportRef {
638                    local: alias,
639                    source: join(prefix, &path),
640                });
641            }
642        }
643        "scoped_use_list" => {
644            let new_prefix = node
645                .child_by_field_name("path")
646                .map(|p| join(prefix, &text(p)))
647                .unwrap_or_else(|| prefix.to_string());
648            if let Some(list) = node.child_by_field_name("list") {
649                let mut cursor = list.walk();
650                for child in list.named_children(&mut cursor) {
651                    rust_use_tree(child, src, &new_prefix, out);
652                }
653            }
654        }
655        "use_list" => {
656            let mut cursor = node.walk();
657            for child in node.named_children(&mut cursor) {
658                rust_use_tree(child, src, prefix, out);
659            }
660        }
661        // use_wildcard and attributes: nothing useful to bind.
662        _ => {}
663    }
664}