Skip to main content

lean_ctx/core/deep_queries/
mod.rs

1//! Tree-sitter deep queries for extracting imports, call sites, and type definitions.
2//!
3//! Replaces regex-based extraction in `deps.rs` with precise AST parsing.
//! Grammars are registered for TypeScript/JavaScript, Python, Rust, Go, Java, C/C++, Ruby, C#,
//! Kotlin, Swift, PHP, Bash, Dart, Scala, Elixir, and Zig.
5
6mod calls;
7mod imports;
8mod type_defs;
9mod types;
10
11pub use types::*;
12
13#[cfg(feature = "tree-sitter")]
14use tree_sitter::{Language, Node, Parser};
15
16pub fn analyze(content: &str, ext: &str) -> DeepAnalysis {
17    #[cfg(feature = "tree-sitter")]
18    {
19        if let Some(result) = analyze_with_tree_sitter(content, ext) {
20            return result;
21        }
22    }
23
24    let _ = (content, ext);
25    DeepAnalysis::empty()
26}
27
/// Parses `content` with the grammar selected by `ext` and runs every extractor over the tree.
///
/// Returns `None` when `ext` has no registered grammar, when the parser rejects the
/// grammar, or when parsing produces no tree.
#[cfg(feature = "tree-sitter")]
fn analyze_with_tree_sitter(content: &str, ext: &str) -> Option<DeepAnalysis> {
    let language = get_language(ext)?;

    // One parser per thread, reused across calls.
    thread_local! {
        static PARSER: std::cell::RefCell<Parser> = std::cell::RefCell::new(Parser::new());
    }

    let tree = PARSER.with(|p| {
        let mut parser = p.borrow_mut();
        // The parser outlives this call, so a rejected grammar would leave the previous
        // call's language installed. Bail out rather than parse with the wrong grammar.
        parser.set_language(&language).ok()?;
        parser.parse(content.as_bytes(), None)
    })?;
    let root = tree.root_node();

    Some(DeepAnalysis {
        imports: imports::extract_imports(root, content, ext),
        calls: calls::extract_calls(root, content, ext),
        types: type_defs::extract_types(root, content, ext),
        exports: type_defs::extract_exports(root, content, ext),
    })
}
55
/// Maps a file extension (without the leading dot) to its tree-sitter grammar.
///
/// Returns `None` for extensions that have no grammar registered here.
#[cfg(feature = "tree-sitter")]
fn get_language(ext: &str) -> Option<Language> {
    match ext {
        "rs" => Some(tree_sitter_rust::LANGUAGE.into()),
        "ts" => Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
        // TSX is a separate dialect: the plain TypeScript grammar has no JSX rules, so
        // `.tsx` files must be parsed with the dedicated TSX grammar.
        "tsx" => Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
        "js" | "jsx" => Some(tree_sitter_javascript::LANGUAGE.into()),
        "py" => Some(tree_sitter_python::LANGUAGE.into()),
        "go" => Some(tree_sitter_go::LANGUAGE.into()),
        "java" => Some(tree_sitter_java::LANGUAGE.into()),
        "c" | "h" => Some(tree_sitter_c::LANGUAGE.into()),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Some(tree_sitter_cpp::LANGUAGE.into()),
        "rb" => Some(tree_sitter_ruby::LANGUAGE.into()),
        "cs" => Some(tree_sitter_c_sharp::LANGUAGE.into()),
        "kt" | "kts" => Some(tree_sitter_kotlin_ng::LANGUAGE.into()),
        "swift" => Some(tree_sitter_swift::LANGUAGE.into()),
        "php" => Some(tree_sitter_php::LANGUAGE_PHP.into()),
        "sh" | "bash" => Some(tree_sitter_bash::LANGUAGE.into()),
        "dart" => Some(tree_sitter_dart::LANGUAGE.into()),
        "scala" | "sc" => Some(tree_sitter_scala::LANGUAGE.into()),
        "ex" | "exs" => Some(tree_sitter_elixir::LANGUAGE.into()),
        "zig" => Some(tree_sitter_zig::LANGUAGE.into()),
        _ => None,
    }
}
80
81// ---------------------------------------------------------------------------
82// Shared helpers (accessible by child modules via `super::`)
83// ---------------------------------------------------------------------------
84
/// Returns the slice of `src` spanned by `node`'s start and end byte offsets.
///
/// Like any string slice indexing, this panics if the offsets are not a valid range
/// within `src`.
#[cfg(feature = "tree-sitter")]
fn node_text<'a>(node: Node, src: &'a str) -> &'a str {
    let span = node.start_byte()..node.end_byte();
    &src[span]
}
89
/// Returns the first direct child of `node` whose kind equals `kind`, if there is one.
///
/// Only immediate children are inspected; deeper descendants are not visited.
#[cfg(feature = "tree-sitter")]
fn find_child_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        if child.kind() == kind {
            return Some(child);
        }
    }
    None
}
96
/// Depth-first, pre-order search for the first node whose kind equals `kind`.
///
/// The search starts at `node` itself, so `node` is returned when it already matches.
/// An explicit work stack is used instead of recursion so that very deeply nested
/// syntax trees cannot exhaust the thread's call stack.
#[cfg(feature = "tree-sitter")]
fn find_descendant_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
    // `Node::children` resets the cursor to the node it is called on, so one cursor
    // can serve every node visited.
    let mut cursor = node.walk();
    let mut pending = vec![node];

    while let Some(current) = pending.pop() {
        if current.kind() == kind {
            return Some(current);
        }
        // Append the children, then flip just that tail so the first child is popped
        // next, preserving the left-to-right visiting order.
        let first_new = pending.len();
        pending.extend(current.children(&mut cursor));
        pending[first_new..].reverse();
    }
    None
}
110
111// ---------------------------------------------------------------------------
112// Tests
113// ---------------------------------------------------------------------------
114
/// End-to-end tests that drive `analyze` with small source snippets per language.
#[cfg(all(test, feature = "tree-sitter"))]
mod tests {
    use super::*;

    // --- TypeScript: imports -------------------------------------------------

    #[test]
    fn ts_named_import() {
        let src = r"import { useState, useEffect } from 'react';";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.imports.len(), 1);
        assert_eq!(analysis.imports[0].source, "react");
        assert_eq!(analysis.imports[0].names, vec!["useState", "useEffect"]);
    }

    #[test]
    fn ts_default_import() {
        let src = r"import React from 'react';";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.imports.len(), 1);
        assert_eq!(analysis.imports[0].kind, ImportKind::Default);
        assert_eq!(analysis.imports[0].names, vec!["React"]);
    }

    #[test]
    fn ts_star_import() {
        let src = r"import * as path from 'path';";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.imports.len(), 1);
        assert_eq!(analysis.imports[0].kind, ImportKind::Star);
    }

    #[test]
    fn ts_side_effect_import() {
        let src = r"import './styles.css';";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.imports.len(), 1);
        assert_eq!(analysis.imports[0].kind, ImportKind::SideEffect);
        assert_eq!(analysis.imports[0].source, "./styles.css");
    }

    #[test]
    fn ts_type_only_import() {
        let src = r"import type { User } from './types';";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.imports.len(), 1);
        assert!(analysis.imports[0].is_type_only);
    }

    #[test]
    fn ts_reexport() {
        let src = r"export { foo, bar } from './utils';";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.imports.len(), 1);
        assert_eq!(analysis.imports[0].kind, ImportKind::Reexport);
    }

    // --- TypeScript: calls and types -----------------------------------------

    #[test]
    fn ts_call_sites() {
        let src = r"
const x = foo(1);
const y = obj.method(2);
";
        let analysis = analyze(src, "ts");
        assert!(analysis.calls.len() >= 2);
        let fns: Vec<&str> = analysis.calls.iter().map(|c| c.callee.as_str()).collect();
        assert!(fns.contains(&"foo"));
        assert!(fns.contains(&"method"));
    }

    #[test]
    fn ts_interface() {
        let src = r"
export interface User {
    name: string;
    age: number;
}
";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.types.len(), 1);
        assert_eq!(analysis.types[0].name, "User");
        assert_eq!(analysis.types[0].kind, TypeDefKind::Interface);
    }

    #[test]
    fn ts_type_alias_union() {
        let src = r"type Result = Success | Error;";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.types.len(), 1);
        assert_eq!(analysis.types[0].kind, TypeDefKind::Union);
    }

    // --- Rust ----------------------------------------------------------------

    #[test]
    fn rust_use_statements() {
        let src = r"
use crate::core::session;
use anyhow::Result;
use std::collections::HashMap;
";
        let analysis = analyze(src, "rs");
        // NOTE(review): three `use` lines but two expected imports — presumably the
        // extractor skips `std::` paths; confirm against `imports::extract_imports`.
        assert_eq!(analysis.imports.len(), 2);
        let sources: Vec<&str> = analysis.imports.iter().map(|i| i.source.as_str()).collect();
        assert!(sources.contains(&"crate::core::session"));
        assert!(sources.contains(&"anyhow::Result"));
    }

    #[test]
    fn rust_pub_use_reexport() {
        let src = r"pub use crate::tools::ctx_read;";
        let analysis = analyze(src, "rs");
        assert_eq!(analysis.imports.len(), 1);
        assert_eq!(analysis.imports[0].kind, ImportKind::Reexport);
    }

    #[test]
    fn rust_struct_and_trait() {
        let src = r"
pub struct Config {
    pub name: String,
}

pub trait Service {
    fn run(&self);
}
";
        let analysis = analyze(src, "rs");
        assert_eq!(analysis.types.len(), 2);
        let names: Vec<&str> = analysis.types.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"Config"));
        assert!(names.contains(&"Service"));
    }

    #[test]
    fn rust_call_sites() {
        let src = r"
fn main() {
    let x = calculate(42);
    let y = self.process();
    Vec::new();
}
";
        let analysis = analyze(src, "rs");
        assert!(analysis.calls.len() >= 2);
        let fns: Vec<&str> = analysis.calls.iter().map(|c| c.callee.as_str()).collect();
        assert!(fns.contains(&"calculate"));
    }

    // --- Python --------------------------------------------------------------

    #[test]
    fn python_imports() {
        let src = r"
import os
from pathlib import Path
from . import utils
from ..models import User, Role
";
        let analysis = analyze(src, "py");
        assert!(analysis.imports.len() >= 3);
    }

    #[test]
    fn python_class_protocol() {
        let src = r"
class MyProtocol(Protocol):
    def method(self) -> None: ...

class User:
    name: str
";
        let analysis = analyze(src, "py");
        assert_eq!(analysis.types.len(), 2);
        assert_eq!(analysis.types[0].kind, TypeDefKind::Protocol);
        assert_eq!(analysis.types[1].kind, TypeDefKind::Class);
    }

    // --- Go ------------------------------------------------------------------

    #[test]
    fn go_imports() {
        let src = r#"
package main

import (
    "fmt"
    "net/http"
    _ "github.com/lib/pq"
)
"#;
        let analysis = analyze(src, "go");
        assert!(analysis.imports.len() >= 3);
        let side_effect = analysis
            .imports
            .iter()
            .find(|i| i.source.contains("pq"))
            .expect("blank-identifier import of lib/pq should be extracted");
        assert_eq!(side_effect.kind, ImportKind::SideEffect);
    }

    #[test]
    fn go_struct_and_interface() {
        let src = r"
package main

type Server struct {
    Port int
}

type Handler interface {
    Handle(r *Request)
}
";
        let analysis = analyze(src, "go");
        assert_eq!(analysis.types.len(), 2);
        let kinds: Vec<&TypeDefKind> = analysis.types.iter().map(|t| &t.kind).collect();
        assert!(kinds.contains(&&TypeDefKind::Struct));
        assert!(kinds.contains(&&TypeDefKind::Interface));
    }

    // --- Java ----------------------------------------------------------------

    #[test]
    fn java_imports() {
        let src = r"
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
";
        let analysis = analyze(src, "java");
        assert!(analysis.imports.len() >= 2);
    }

    #[test]
    fn java_class_and_interface() {
        let src = r"
public class UserService {
    public void save(User u) {}
}

public interface Repository<T> {
    T findById(int id);
}

public enum Status { ACTIVE, INACTIVE }

public record Point(int x, int y) {}
";
        let analysis = analyze(src, "java");
        assert!(analysis.types.len() >= 3);
        let kinds: Vec<&TypeDefKind> = analysis.types.iter().map(|t| &t.kind).collect();
        assert!(kinds.contains(&&TypeDefKind::Class));
        assert!(kinds.contains(&&TypeDefKind::Interface));
        assert!(kinds.contains(&&TypeDefKind::Enum));
    }

    // --- Kotlin --------------------------------------------------------------

    #[test]
    fn kotlin_imports_and_aliases() {
        let src = r"
package com.example.app

import com.example.services.UserService
import com.example.factories.WidgetFactory as Factory
import com.example.shared.*
";
        let analysis = analyze(src, "kt");
        assert_eq!(analysis.imports.len(), 3);
        assert_eq!(
            analysis.imports[0].source,
            "com.example.services.UserService"
        );
        assert_eq!(analysis.imports[1].names, vec!["Factory"]);
        assert_eq!(analysis.imports[2].kind, ImportKind::Star);
    }

    #[test]
    fn kotlin_call_sites() {
        let src = r"
class UserService {
    fun run() {
        prepare()
        repository.save(user)
        Factory.create()
    }
}
";
        let analysis = analyze(src, "kt");
        let callees: Vec<&str> = analysis.calls.iter().map(|c| c.callee.as_str()).collect();
        assert!(callees.contains(&"prepare"));
        assert!(callees.contains(&"save"));
        assert!(callees.contains(&"create"));
    }

    #[test]
    fn kotlin_types_and_visibility() {
        let src = r"
sealed interface Handler
data class User(val id: String)
enum class Status { ACTIVE, INACTIVE }
object Registry
private typealias UserId = String
";
        let analysis = analyze(src, "kt");
        let names: Vec<&str> = analysis.types.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"Handler"));
        assert!(names.contains(&"User"));
        assert!(names.contains(&"Status"));
        assert!(names.contains(&"Registry"));
        assert!(names.contains(&"UserId"));
        let handler = analysis
            .types
            .iter()
            .find(|t| t.name == "Handler")
            .expect("Handler was asserted present above");
        assert_eq!(handler.kind, TypeDefKind::Interface);
        let alias = analysis
            .types
            .iter()
            .find(|t| t.name == "UserId")
            .expect("UserId was asserted present above");
        assert!(!alias.is_exported);
    }

    // --- TypeScript: generics and combined input -----------------------------

    #[test]
    fn ts_generics_extracted() {
        let src = r"interface Result<T, E> { ok: T; err: E; }";
        let analysis = analyze(src, "ts");
        assert_eq!(analysis.types.len(), 1);
        assert!(!analysis.types[0].generics.is_empty());
    }

    #[test]
    fn mixed_analysis_ts() {
        let src = r"
import { Request, Response } from 'express';
import type { User } from './models';

export interface Handler {
    handle(req: Request): Response;
}

export class Router {
    register(path: string, handler: Handler) {
        this.handlers.set(path, handler);
    }
}

const app = express();
app.listen(3000);
";
        let analysis = analyze(src, "ts");
        assert!(analysis.imports.len() >= 2, "Should find imports");
        assert!(!analysis.types.is_empty(), "Should find types");
        assert!(!analysis.calls.is_empty(), "Should find calls");
    }

    // --- Degenerate inputs ---------------------------------------------------

    #[test]
    fn empty_file() {
        let analysis = analyze("", "ts");
        assert!(analysis.imports.is_empty());
        assert!(analysis.calls.is_empty());
        assert!(analysis.types.is_empty());
    }

    #[test]
    fn unsupported_extension() {
        let analysis = analyze("some content", "txt");
        assert!(analysis.imports.is_empty());
    }

    // --- C, Bash, Zig --------------------------------------------------------

    #[test]
    fn c_include_import() {
        let src = r#"
#include "foo/bar.h"
#include <stdio.h>
"#;
        let analysis = analyze(src, "c");
        assert!(analysis.imports.iter().any(|i| i.source == "foo/bar.h"));
    }

    #[test]
    fn bash_source_import() {
        let src = r#"
source "./scripts/env.sh"
. ../common.sh
"#;
        let analysis = analyze(src, "sh");
        assert!(
            analysis
                .imports
                .iter()
                .any(|i| i.source.contains("scripts/env.sh")),
            "expected source import"
        );
    }

    #[test]
    fn zig_at_import() {
        let src = r#"
const m = @import("lib/math.zig");
const std = @import("std");
"#;
        let analysis = analyze(src, "zig");
        assert!(analysis.imports.iter().any(|i| i.source == "lib/math.zig"));
    }
}