Skip to main content

lean_ctx/core/deep_queries/
mod.rs

1//! Tree-sitter deep queries for extracting imports, call sites, and type definitions.
2//!
3//! Replaces regex-based extraction in `deps.rs` with precise AST parsing.
//! Grammars are registered for: TypeScript/JavaScript, Python, Rust, Go, Java, C/C++, Ruby,
//! C#, Kotlin, Swift, PHP, Bash, Dart, Scala, Elixir, and Zig.
5
6mod calls;
7mod imports;
8mod type_defs;
9mod types;
10
11pub use types::*;
12
13#[cfg(feature = "tree-sitter")]
14use tree_sitter::{Language, Node, Parser};
15
16pub fn analyze(content: &str, ext: &str) -> DeepAnalysis {
17    #[cfg(feature = "tree-sitter")]
18    {
19        if let Some(result) = analyze_with_tree_sitter(content, ext) {
20            return result;
21        }
22    }
23
24    let _ = (content, ext);
25    DeepAnalysis::empty()
26}
27
/// Parses `content` with the grammar registered for `ext` and runs every extractor on the
/// resulting tree.
///
/// Returns `None` when no grammar is registered for `ext`, when the parser rejects the
/// grammar, or when the parser does not produce a tree.
#[cfg(feature = "tree-sitter")]
fn analyze_with_tree_sitter(content: &str, ext: &str) -> Option<DeepAnalysis> {
    let grammar = get_language(ext)?;

    let mut parser = Parser::new();
    parser.set_language(&grammar).ok()?;

    let syntax_tree = parser.parse(content.as_bytes(), None)?;
    let root = syntax_tree.root_node();

    // Field initialisers run top to bottom, so the extractors are invoked in the order
    // imports, calls, types, exports.
    Some(DeepAnalysis {
        imports: imports::extract_imports(root, content, ext),
        calls: calls::extract_calls(root, content, ext),
        types: type_defs::extract_types(root, content, ext),
        exports: type_defs::extract_exports(root, content, ext),
    })
}
48
/// Maps a file extension (without the leading dot) to its tree-sitter grammar.
///
/// Returns `None` for extensions that have no registered grammar.
#[cfg(feature = "tree-sitter")]
fn get_language(ext: &str) -> Option<Language> {
    match ext {
        "rs" => Some(tree_sitter_rust::LANGUAGE.into()),
        "ts" => Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
        // TSX needs its own grammar: the plain TypeScript grammar does not accept JSX
        // elements, so `.tsx` sources would otherwise parse with ERROR nodes around markup.
        "tsx" => Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
        // The JavaScript grammar is used for both plain JS and JSX sources.
        "js" | "jsx" => Some(tree_sitter_javascript::LANGUAGE.into()),
        "py" => Some(tree_sitter_python::LANGUAGE.into()),
        "go" => Some(tree_sitter_go::LANGUAGE.into()),
        "java" => Some(tree_sitter_java::LANGUAGE.into()),
        // `.h` headers are parsed with the C grammar.
        "c" | "h" => Some(tree_sitter_c::LANGUAGE.into()),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Some(tree_sitter_cpp::LANGUAGE.into()),
        "rb" => Some(tree_sitter_ruby::LANGUAGE.into()),
        "cs" => Some(tree_sitter_c_sharp::LANGUAGE.into()),
        "kt" | "kts" => Some(tree_sitter_kotlin_ng::LANGUAGE.into()),
        "swift" => Some(tree_sitter_swift::LANGUAGE.into()),
        "php" => Some(tree_sitter_php::LANGUAGE_PHP.into()),
        "sh" | "bash" => Some(tree_sitter_bash::LANGUAGE.into()),
        "dart" => Some(tree_sitter_dart::LANGUAGE.into()),
        "scala" | "sc" => Some(tree_sitter_scala::LANGUAGE.into()),
        "ex" | "exs" => Some(tree_sitter_elixir::LANGUAGE.into()),
        "zig" => Some(tree_sitter_zig::LANGUAGE.into()),
        _ => None,
    }
}
73
74// ---------------------------------------------------------------------------
75// Shared helpers (accessible by child modules via `super::`)
76// ---------------------------------------------------------------------------
77
/// Returns the slice of `src` covered by `node`'s byte range.
///
/// Uses plain `str` indexing, so it panics if the range is out of bounds for `src` or
/// does not fall on UTF-8 character boundaries.
#[cfg(feature = "tree-sitter")]
fn node_text<'a>(node: Node, src: &'a str) -> &'a str {
    let span = node.byte_range();
    &src[span]
}
82
/// Returns the first direct child of `node` whose kind equals `kind`, if any.
///
/// Only immediate children are inspected; deeper descendants are not visited.
#[cfg(feature = "tree-sitter")]
fn find_child_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        if child.kind() == kind {
            return Some(child);
        }
    }
    None
}
89
/// Depth-first, pre-order search for the first node of the given `kind`.
///
/// `node` itself is checked first, so it is returned when its own kind matches;
/// otherwise each child subtree is searched in order.
#[cfg(feature = "tree-sitter")]
fn find_descendant_by_kind<'a>(node: Node<'a>, kind: &str) -> Option<Node<'a>> {
    if node.kind() == kind {
        return Some(node);
    }

    let mut walker = node.walk();
    // Bind the result before returning so the iterator borrowing `walker` is dropped
    // before `walker` itself goes out of scope.
    let hit = node
        .children(&mut walker)
        .find_map(|child| find_descendant_by_kind(child, kind));
    hit
}
103
104// ---------------------------------------------------------------------------
105// Tests
106// ---------------------------------------------------------------------------
107
#[cfg(all(test, feature = "tree-sitter"))]
mod tests {
    use super::*;

    /// Callee names of every extracted call site, in extraction order.
    fn callee_names(analysis: &DeepAnalysis) -> Vec<&str> {
        analysis.calls.iter().map(|c| c.callee.as_str()).collect()
    }

    /// Names of every extracted type definition, in extraction order.
    fn type_names(analysis: &DeepAnalysis) -> Vec<&str> {
        analysis.types.iter().map(|t| t.name.as_str()).collect()
    }

    /// Whether at least one extracted type definition has the given kind.
    fn has_type_kind(analysis: &DeepAnalysis, kind: &TypeDefKind) -> bool {
        analysis.types.iter().any(|t| &t.kind == kind)
    }

    // ---- TypeScript ----

    #[test]
    fn ts_named_import() {
        let out = analyze(r"import { useState, useEffect } from 'react';", "ts");
        assert_eq!(out.imports.len(), 1);
        let first = &out.imports[0];
        assert_eq!(first.source, "react");
        assert_eq!(first.names, vec!["useState", "useEffect"]);
    }

    #[test]
    fn ts_default_import() {
        let out = analyze(r"import React from 'react';", "ts");
        assert_eq!(out.imports.len(), 1);
        let first = &out.imports[0];
        assert_eq!(first.kind, ImportKind::Default);
        assert_eq!(first.names, vec!["React"]);
    }

    #[test]
    fn ts_star_import() {
        let out = analyze(r"import * as path from 'path';", "ts");
        assert_eq!(out.imports.len(), 1);
        assert_eq!(out.imports[0].kind, ImportKind::Star);
    }

    #[test]
    fn ts_side_effect_import() {
        let out = analyze(r"import './styles.css';", "ts");
        assert_eq!(out.imports.len(), 1);
        let first = &out.imports[0];
        assert_eq!(first.kind, ImportKind::SideEffect);
        assert_eq!(first.source, "./styles.css");
    }

    #[test]
    fn ts_type_only_import() {
        let out = analyze(r"import type { User } from './types';", "ts");
        assert_eq!(out.imports.len(), 1);
        assert!(out.imports[0].is_type_only);
    }

    #[test]
    fn ts_reexport() {
        let out = analyze(r"export { foo, bar } from './utils';", "ts");
        assert_eq!(out.imports.len(), 1);
        assert_eq!(out.imports[0].kind, ImportKind::Reexport);
    }

    #[test]
    fn ts_call_sites() {
        let code = r"
const x = foo(1);
const y = obj.method(2);
";
        let out = analyze(code, "ts");
        assert!(out.calls.len() >= 2);
        let callees = callee_names(&out);
        assert!(callees.contains(&"foo"));
        assert!(callees.contains(&"method"));
    }

    #[test]
    fn ts_interface() {
        let code = r"
export interface User {
    name: string;
    age: number;
}
";
        let out = analyze(code, "ts");
        assert_eq!(out.types.len(), 1);
        let first = &out.types[0];
        assert_eq!(first.name, "User");
        assert_eq!(first.kind, TypeDefKind::Interface);
    }

    #[test]
    fn ts_type_alias_union() {
        let out = analyze(r"type Result = Success | Error;", "ts");
        assert_eq!(out.types.len(), 1);
        assert_eq!(out.types[0].kind, TypeDefKind::Union);
    }

    // ---- Rust ----

    #[test]
    fn rust_use_statements() {
        let code = r"
use crate::core::session;
use anyhow::Result;
use std::collections::HashMap;
";
        let out = analyze(code, "rs");
        assert_eq!(out.imports.len(), 2);
        let sources: Vec<&str> = out.imports.iter().map(|i| i.source.as_str()).collect();
        for expected in ["crate::core::session", "anyhow::Result"] {
            assert!(sources.contains(&expected));
        }
    }

    #[test]
    fn rust_pub_use_reexport() {
        let out = analyze(r"pub use crate::tools::ctx_read;", "rs");
        assert_eq!(out.imports.len(), 1);
        assert_eq!(out.imports[0].kind, ImportKind::Reexport);
    }

    #[test]
    fn rust_struct_and_trait() {
        let code = r"
pub struct Config {
    pub name: String,
}

pub trait Service {
    fn run(&self);
}
";
        let out = analyze(code, "rs");
        assert_eq!(out.types.len(), 2);
        let names = type_names(&out);
        assert!(names.contains(&"Config"));
        assert!(names.contains(&"Service"));
    }

    #[test]
    fn rust_call_sites() {
        let code = r"
fn main() {
    let x = calculate(42);
    let y = self.process();
    Vec::new();
}
";
        let out = analyze(code, "rs");
        assert!(out.calls.len() >= 2);
        assert!(callee_names(&out).contains(&"calculate"));
    }

    // ---- Python ----

    #[test]
    fn python_imports() {
        let code = r"
import os
from pathlib import Path
from . import utils
from ..models import User, Role
";
        let out = analyze(code, "py");
        assert!(out.imports.len() >= 3);
    }

    #[test]
    fn python_class_protocol() {
        let code = r"
class MyProtocol(Protocol):
    def method(self) -> None: ...

class User:
    name: str
";
        let out = analyze(code, "py");
        assert_eq!(out.types.len(), 2);
        assert_eq!(out.types[0].kind, TypeDefKind::Protocol);
        assert_eq!(out.types[1].kind, TypeDefKind::Class);
    }

    // ---- Go ----

    #[test]
    fn go_imports() {
        let code = r#"
package main

import (
    "fmt"
    "net/http"
    _ "github.com/lib/pq"
)
"#;
        let out = analyze(code, "go");
        assert!(out.imports.len() >= 3);
        // The blank-identifier import must be classified as a side-effect import.
        let pq = out.imports.iter().find(|i| i.source.contains("pq"));
        assert!(pq.is_some());
        assert_eq!(pq.unwrap().kind, ImportKind::SideEffect);
    }

    #[test]
    fn go_struct_and_interface() {
        let code = r"
package main

type Server struct {
    Port int
}

type Handler interface {
    Handle(r *Request)
}
";
        let out = analyze(code, "go");
        assert_eq!(out.types.len(), 2);
        assert!(has_type_kind(&out, &TypeDefKind::Struct));
        assert!(has_type_kind(&out, &TypeDefKind::Interface));
    }

    // ---- Java ----

    #[test]
    fn java_imports() {
        let code = r"
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
";
        let out = analyze(code, "java");
        assert!(out.imports.len() >= 2);
    }

    #[test]
    fn java_class_and_interface() {
        let code = r"
public class UserService {
    public void save(User u) {}
}

public interface Repository<T> {
    T findById(int id);
}

public enum Status { ACTIVE, INACTIVE }

public record Point(int x, int y) {}
";
        let out = analyze(code, "java");
        assert!(out.types.len() >= 3);
        assert!(has_type_kind(&out, &TypeDefKind::Class));
        assert!(has_type_kind(&out, &TypeDefKind::Interface));
        assert!(has_type_kind(&out, &TypeDefKind::Enum));
    }

    // ---- Kotlin ----

    #[test]
    fn kotlin_imports_and_aliases() {
        let code = r"
package com.example.app

import com.example.services.UserService
import com.example.factories.WidgetFactory as Factory
import com.example.shared.*
";
        let out = analyze(code, "kt");
        assert_eq!(out.imports.len(), 3);
        assert_eq!(out.imports[0].source, "com.example.services.UserService");
        assert_eq!(out.imports[1].names, vec!["Factory"]);
        assert_eq!(out.imports[2].kind, ImportKind::Star);
    }

    #[test]
    fn kotlin_call_sites() {
        let code = r"
class UserService {
    fun run() {
        prepare()
        repository.save(user)
        Factory.create()
    }
}
";
        let out = analyze(code, "kt");
        let callees = callee_names(&out);
        assert!(callees.contains(&"prepare"));
        assert!(callees.contains(&"save"));
        assert!(callees.contains(&"create"));
    }

    #[test]
    fn kotlin_types_and_visibility() {
        let code = r"
sealed interface Handler
data class User(val id: String)
enum class Status { ACTIVE, INACTIVE }
object Registry
private typealias UserId = String
";
        let out = analyze(code, "kt");
        let names = type_names(&out);
        assert!(names.contains(&"Handler"));
        assert!(names.contains(&"User"));
        assert!(names.contains(&"Status"));
        assert!(names.contains(&"Registry"));
        assert!(names.contains(&"UserId"));

        let handler = out.types.iter().find(|t| t.name == "Handler").unwrap();
        assert_eq!(handler.kind, TypeDefKind::Interface);

        // A `private` typealias must not be reported as exported.
        let alias = out.types.iter().find(|t| t.name == "UserId").unwrap();
        assert!(!alias.is_exported);
    }

    // ---- Cross-cutting ----

    #[test]
    fn ts_generics_extracted() {
        let out = analyze(r"interface Result<T, E> { ok: T; err: E; }", "ts");
        assert_eq!(out.types.len(), 1);
        assert!(!out.types[0].generics.is_empty());
    }

    #[test]
    fn mixed_analysis_ts() {
        let code = r"
import { Request, Response } from 'express';
import type { User } from './models';

export interface Handler {
    handle(req: Request): Response;
}

export class Router {
    register(path: string, handler: Handler) {
        this.handlers.set(path, handler);
    }
}

const app = express();
app.listen(3000);
";
        let out = analyze(code, "ts");
        assert!(out.imports.len() >= 2, "Should find imports");
        assert!(!out.types.is_empty(), "Should find types");
        assert!(!out.calls.is_empty(), "Should find calls");
    }

    #[test]
    fn empty_file() {
        let out = analyze("", "ts");
        assert!(out.imports.is_empty());
        assert!(out.calls.is_empty());
        assert!(out.types.is_empty());
    }

    #[test]
    fn unsupported_extension() {
        let out = analyze("some content", "txt");
        assert!(out.imports.is_empty());
    }

    // ---- C / Bash / Zig ----

    #[test]
    fn c_include_import() {
        let code = r#"
#include "foo/bar.h"
#include <stdio.h>
"#;
        let out = analyze(code, "c");
        assert!(out.imports.iter().any(|i| i.source == "foo/bar.h"));
    }

    #[test]
    fn bash_source_import() {
        let code = r#"
source "./scripts/env.sh"
. ../common.sh
"#;
        let out = analyze(code, "sh");
        let found = out
            .imports
            .iter()
            .any(|i| i.source.contains("scripts/env.sh"));
        assert!(found, "expected source import");
    }

    #[test]
    fn zig_at_import() {
        let code = r#"
const m = @import("lib/math.zig");
const std = @import("std");
"#;
        let out = analyze(code, "zig");
        assert!(out.imports.iter().any(|i| i.source == "lib/math.zig"));
    }
}