Skip to main content

lean_ctx/core/
chunks_ts.rs

1//! Tree-sitter AST-aware code chunking for semantic search.
2//!
3//! Replaces heuristic line-prefix matching with proper AST parsing.
4//! Extracts function bodies, struct definitions, class declarations etc.
5//! as complete, self-contained chunks with accurate boundaries.
6//!
//! Returns `None` for unsupported languages so callers can fall back to heuristic chunking.
8
9#[cfg(feature = "tree-sitter")]
10use tree_sitter::{Language, Node, Parser, Query, QueryCursor, StreamingIterator};
11
12use super::bm25_index::{ChunkKind, CodeChunk};
13
/// Tree-sitter query selecting the Rust items that become chunks.
///
/// Every pattern captures the whole item as `@chunk` and its identifier as `@name`:
/// functions, structs, enums, traits, `impl` blocks and consts.
// NOTE(review): the `impl_item` pattern only accepts a bare `type_identifier` as the
// implemented type; impls on generic or path-qualified types are presumably not
// matched — confirm against the grammar.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_RUST: &str = r"
(function_item name: (identifier) @name) @chunk
(struct_item name: (type_identifier) @name) @chunk
(enum_item name: (type_identifier) @name) @chunk
(trait_item name: (type_identifier) @name) @chunk
(impl_item type: (type_identifier) @name) @chunk
(const_item name: (identifier) @name) @chunk
";
23
/// Tree-sitter query selecting the TypeScript constructs that become chunks.
///
/// Every pattern captures the whole construct as `@chunk` and its identifier as
/// `@name`: function declarations, classes, abstract classes, interfaces, type
/// aliases, methods, and variables whose initializer is an arrow function.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_TYPESCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (type_identifier) @name) @chunk
(abstract_class_declaration name: (type_identifier) @name) @chunk
(interface_declaration name: (type_identifier) @name) @chunk
(type_alias_declaration name: (type_identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
34
/// Tree-sitter query selecting the JavaScript constructs that become chunks.
///
/// Every pattern captures the whole construct as `@chunk` and its identifier as
/// `@name`: function declarations, classes, methods, and variables whose
/// initializer is an arrow function.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVASCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
42
/// Tree-sitter query selecting the Python constructs that become chunks.
///
/// Captures function definitions and class definitions as `@chunk`, with their
/// identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_PYTHON: &str = r"
(function_definition name: (identifier) @name) @chunk
(class_definition name: (identifier) @name) @chunk
";
48
/// Tree-sitter query selecting the Go constructs that become chunks.
///
/// Captures function declarations, method declarations and type specs as
/// `@chunk`, with their identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_GO: &str = r"
(function_declaration name: (identifier) @name) @chunk
(method_declaration name: (field_identifier) @name) @chunk
(type_spec name: (type_identifier) @name) @chunk
";
55
/// Tree-sitter query selecting the Java constructs that become chunks.
///
/// Captures methods, classes, interfaces, enums and constructors as `@chunk`,
/// with their identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVA: &str = r"
(method_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(interface_declaration name: (identifier) @name) @chunk
(enum_declaration name: (identifier) @name) @chunk
(constructor_declaration name: (identifier) @name) @chunk
";
64
/// Tree-sitter query selecting the C constructs that become chunks.
///
/// Captures function definitions whose declarator is directly a
/// `function_declarator` naming an identifier, plus named structs and enums.
/// The whole construct is `@chunk`; the identifier is `@name`.
// NOTE(review): a function whose declarator is wrapped in another declarator node
// (e.g. a pointer return type) presumably does not match the first pattern —
// confirm against the grammar.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_C: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
";
73
/// Tree-sitter query selecting the C++ constructs that become chunks.
///
/// Captures function definitions (the name is whatever node sits in the inner
/// declarator position, hence the `(_)` wildcard), structs, classes, enums and
/// namespaces. The whole construct is `@chunk`; the name node is `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_CPP: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (_) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(class_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
(namespace_definition name: (identifier) @name) @chunk
";
84
/// Return the compiled chunk [`Query`] for `file_ext`.
///
/// All queries are compiled once, on the first call, and stored in a static
/// `OnceLock`; later calls are plain map lookups.
///
/// Returns `None` when the extension has no grammar/query pair or when its query
/// failed to compile against the grammar.
#[cfg(feature = "tree-sitter")]
fn get_cached_query(file_ext: &str) -> Option<&'static Query> {
    use std::collections::HashMap;
    use std::sync::OnceLock;

    // Must list every extension accepted by `get_language` / `get_chunk_query`:
    // an extension missing here is parsed but can never yield chunks.
    const SUPPORTED_EXTS: &[&str] = &[
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "java", "c", "h", "cpp", "cc", "cxx", "hpp",
        "hxx", "hh",
    ];

    static QUERY_CACHE: OnceLock<HashMap<&'static str, Query>> = OnceLock::new();

    let cache = QUERY_CACHE.get_or_init(|| {
        SUPPORTED_EXTS
            .iter()
            .filter_map(|&ext| {
                let lang = get_language(ext)?;
                let src = get_chunk_query(ext)?;
                // A query that does not compile is left out of the cache, so the
                // extension is treated as unsupported instead of panicking.
                let query = Query::new(&lang, src).ok()?;
                Some((ext, query))
            })
            .collect()
    });

    cache.get(file_ext)
}
113
/// Invoke `visitor` for every structural chunk root (`@chunk` capture) found in `content`.
///
/// Each query match contributes its `@chunk` node together with the text of its
/// `@name` capture. A match is skipped when it has no (non-empty) name, or when its
/// row range lies within — and is not identical to — the row range of a chunk that
/// was already visited, so items nested in an already reported chunk are not
/// reported again.
///
/// The visitor receives `(node, name, kind, start_line, end_line)`, where
/// `start_line` / `end_line` are **1-based** inclusive line numbers (matching [`CodeChunk`]).
///
/// Returns `None` if the extension is unsupported, its query is unavailable, the
/// grammar cannot be installed in the parser, or parsing fails.
#[cfg(feature = "tree-sitter")]
pub(crate) fn for_each_chunk_node(
    content: &str,
    file_ext: &str,
    mut visitor: impl FnMut(Node, &str, ChunkKind, usize, usize),
) -> Option<()> {
    let language = get_language(file_ext)?;

    // Resolve the query and its capture slots before parsing, so an unusable
    // query does not cost a full parse of `content`.
    let query = get_cached_query(file_ext)?;
    let chunk_idx = find_capture_index(query, "chunk")?;
    let name_idx = find_capture_index(query, "name")?;

    thread_local! {
        static PARSER: std::cell::RefCell<Parser> = std::cell::RefCell::new(Parser::new());
    }

    let tree = PARSER.with(|p| {
        let mut parser = p.borrow_mut();
        // The parser is reused across calls on this thread. If the grammar cannot be
        // installed, bail out rather than parse with whatever language a previous
        // call left behind.
        parser.set_language(&language).ok()?;
        parser.parse(content, None)
    })?;

    let source = content.as_bytes();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(query, tree.root_node(), source);
    // 0-based (start_row, end_row) of every chunk handed to the visitor so far.
    let mut seen_ranges: Vec<(usize, usize)> = Vec::new();

    while let Some(m) = matches.next() {
        let mut chunk_node: Option<Node> = None;
        let mut name_text: Option<&str> = None;

        for cap in m.captures {
            if cap.index == chunk_idx {
                chunk_node = Some(cap.node);
            } else if cap.index == name_idx {
                // Names that are not valid UTF-8 are ignored; the match is then
                // dropped below for lacking a name.
                if let Ok(text) = cap.node.utf8_text(source) {
                    name_text = Some(text);
                }
            }
        }

        let (Some(node), Some(name)) = (chunk_node, name_text) else {
            continue;
        };
        if name.is_empty() {
            continue;
        }

        let start_row0 = node.start_position().row;
        let end_row0 = node.end_position().row;
        let range = (start_row0, end_row0);

        // Skip nodes nested inside an already visited chunk. Identical ranges are
        // deliberately kept so that distinct items sharing the same rows survive.
        let nested = seen_ranges
            .iter()
            .any(|&(s, e)| s <= start_row0 && end_row0 <= e && range != (s, e));
        if nested {
            continue;
        }
        seen_ranges.push(range);

        let kind = node_kind_to_chunk_kind(node.kind());
        visitor(node, name, kind, start_row0 + 1, end_row0 + 1);
    }

    Some(())
}
183
/// Extract code chunks from a file using tree-sitter AST parsing.
///
/// Every chunk contains the complete source lines spanned by its AST node (joined
/// with `\n`), and the chunks are returned ordered by `start_line`.
///
/// Returns `None` if the language is unsupported, parsing fails, or no chunk was
/// found, allowing callers to fall back to heuristic-based chunking.
#[cfg(feature = "tree-sitter")]
pub fn extract_chunks_ts(file_path: &str, content: &str, file_ext: &str) -> Option<Vec<CodeChunk>> {
    let lines: Vec<&str> = content.lines().collect();
    let mut chunks = Vec::new();

    for_each_chunk_node(
        content,
        file_ext,
        |_node, name_text, kind, start_line, end_line| {
            // The visitor reports 1-based inclusive lines; turn them into a 0-based
            // half-open range clamped to the lines that actually exist.
            let first = start_line.saturating_sub(1);
            let last = end_line.min(lines.len());
            // `get` returns `None` for an out-of-bounds or inverted range, so a node
            // that does not map onto `lines` is skipped instead of panicking.
            let Some(span) = lines.get(first..last).filter(|s| !s.is_empty()) else {
                return;
            };

            let block = span.join("\n");
            // Only the count is stored here; `tokens` is left empty.
            let token_count = super::bm25_index::tokenize_for_index(&block).len();

            chunks.push(CodeChunk {
                file_path: file_path.to_string(),
                symbol_name: name_text.to_string(),
                kind,
                start_line,
                end_line,
                content: block,
                tokens: Vec::new(),
                token_count,
            });
        },
    )?;

    if chunks.is_empty() {
        return None;
    }

    chunks.sort_by_key(|c| c.start_line);
    Some(chunks)
}
220
/// Stub compiled when the `tree-sitter` feature is disabled.
///
/// Ignores all arguments and always returns `None`.
#[cfg(not(feature = "tree-sitter"))]
pub fn extract_chunks_ts(
    _file_path: &str,
    _content: &str,
    _file_ext: &str,
) -> Option<Vec<CodeChunk>> {
    None
}
229
/// Resolve a file extension (given without a leading dot) to its tree-sitter grammar.
///
/// Yields `None` for any extension that has no grammar wired in here.
#[cfg(feature = "tree-sitter")]
fn get_language(ext: &str) -> Option<Language> {
    let grammar: Language = match ext {
        "rs" => tree_sitter_rust::LANGUAGE.into(),
        "ts" => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
        "tsx" => tree_sitter_typescript::LANGUAGE_TSX.into(),
        "js" | "jsx" => tree_sitter_javascript::LANGUAGE.into(),
        "py" => tree_sitter_python::LANGUAGE.into(),
        "go" => tree_sitter_go::LANGUAGE.into(),
        "java" => tree_sitter_java::LANGUAGE.into(),
        "c" | "h" => tree_sitter_c::LANGUAGE.into(),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => tree_sitter_cpp::LANGUAGE.into(),
        // Anything else has no grammar available.
        _ => return None,
    };

    Some(grammar)
}
245
/// Pick the chunk query source text that belongs to a file extension.
///
/// TypeScript and TSX share one query, as do JavaScript and JSX, the C
/// extensions, and the C++ extensions. Unknown extensions yield `None`.
#[cfg(feature = "tree-sitter")]
fn get_chunk_query(ext: &str) -> Option<&'static str> {
    match ext {
        "rs" => Some(CHUNK_QUERY_RUST),
        "ts" | "tsx" => Some(CHUNK_QUERY_TYPESCRIPT),
        "js" | "jsx" => Some(CHUNK_QUERY_JAVASCRIPT),
        "py" => Some(CHUNK_QUERY_PYTHON),
        "go" => Some(CHUNK_QUERY_GO),
        "java" => Some(CHUNK_QUERY_JAVA),
        "c" | "h" => Some(CHUNK_QUERY_C),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Some(CHUNK_QUERY_CPP),
        _ => None,
    }
}
260
/// Look up the numeric index of the capture called `name` (written without the
/// leading `@`) in `query`.
///
/// Returns `None` when the query declares no capture with that name.
#[cfg(feature = "tree-sitter")]
fn find_capture_index(query: &Query, name: &str) -> Option<u32> {
    // Use tree-sitter's own lookup instead of scanning `capture_names()` by hand
    // and narrowing the position with an unchecked `as u32` cast.
    query.capture_index_for_name(name)
}
269
270fn node_kind_to_chunk_kind(kind: &str) -> ChunkKind {
271    match kind {
272        "function_item"
273        | "function_declaration"
274        | "function_definition"
275        | "method_declaration"
276        | "method_definition"
277        | "constructor_declaration"
278        | "variable_declarator" => ChunkKind::Function,
279
280        "struct_item"
281        | "struct_specifier"
282        | "struct_declaration"
283        | "enum_item"
284        | "enum_specifier"
285        | "enum_declaration"
286        | "trait_item"
287        | "interface_declaration"
288        | "type_alias_declaration"
289        | "type_spec" => ChunkKind::Struct,
290
291        "impl_item" => ChunkKind::Impl,
292
293        "class_declaration"
294        | "abstract_class_declaration"
295        | "class_specifier"
296        | "class_definition" => ChunkKind::Class,
297
298        "namespace_definition" | "namespace_declaration" => ChunkKind::Module,
299
300        _ => ChunkKind::Other,
301    }
302}
303
#[cfg(test)]
mod tests {
    //! Tests for AST-based chunk extraction.
    //!
    //! Tests that need real parsing are gated on the `tree-sitter` feature, because
    //! without it `extract_chunks_ts` is a stub that always returns `None`.

    use super::*;

    /// Rust: top-level functions, structs and impl blocks become chunks.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn extract_rust_chunks() {
        let src = r#"use std::io;

pub fn process(input: &str) -> String {
    input.to_uppercase()
}

pub struct Config {
    pub name: String,
    pub port: u16,
}

impl Config {
    pub fn new() -> Self {
        Self { name: "default".into(), port: 8080 }
    }
}

fn helper() -> bool {
    true
}
"#;
        let chunks = extract_chunks_ts("main.rs", src, "rs").unwrap();
        assert!(
            chunks.len() >= 4,
            "expected >=4 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"process"), "got {names:?}");
        assert!(names.contains(&"Config"), "got {names:?}");
        assert!(names.contains(&"helper"), "got {names:?}");

        let process = chunks.iter().find(|c| c.symbol_name == "process").unwrap();
        assert!(matches!(process.kind, ChunkKind::Function));
        assert!(process.content.contains("to_uppercase"));
    }

    /// TypeScript: functions, classes and arrow-function constants become chunks.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn extract_typescript_chunks() {
        let src = r"
export function greet(name: string): string {
    return `Hello ${name}`;
}

export class UserService {
    findUser(id: number): User {
        return db.find(id);
    }
}

const handler = async (req: Request): Promise<Response> => {
    return new Response();
};
";
        let chunks = extract_chunks_ts("app.ts", src, "ts").unwrap();
        assert!(
            chunks.len() >= 3,
            "expected >=3 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"greet"), "got {names:?}");
        assert!(names.contains(&"UserService"), "got {names:?}");
    }

    /// Python: a class chunk carries its methods in its content.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn extract_python_chunks() {
        let src = r"
class AuthService:
    def __init__(self, db):
        self.db = db

    def authenticate(self, email: str) -> bool:
        user = self.db.find(email)
        return user is not None

def create_app():
    return Flask(__name__)
";
        let chunks = extract_chunks_ts("app.py", src, "py").unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"AuthService"), "got {names:?}");
        assert!(names.contains(&"create_app"), "got {names:?}");

        let auth = chunks
            .iter()
            .find(|c| c.symbol_name == "AuthService")
            .unwrap();
        assert!(auth.content.contains("authenticate"));
    }

    /// A chunk spans the whole body of its item, not just the signature.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn chunks_contain_full_body() {
        let src = r#"
pub fn complex(x: i32, y: i32) -> Result<String, Error> {
    let sum = x + y;
    let result = format!("Sum: {}", sum);
    if sum > 100 {
        return Err(Error::new("too large"));
    }
    Ok(result)
}
"#;
        let chunks = extract_chunks_ts("math.rs", src, "rs").unwrap();
        let complex = chunks.iter().find(|c| c.symbol_name == "complex").unwrap();
        assert!(complex.content.contains("sum > 100"));
        assert!(complex.content.contains("Ok(result)"));
    }

    /// Holds with and without the feature: unknown extensions never produce chunks.
    #[test]
    fn unsupported_language_returns_none() {
        assert!(extract_chunks_ts("file.xyz", "content", "xyz").is_none());
    }

    /// Holds with and without the feature: an empty file has nothing to chunk.
    #[test]
    fn empty_file_returns_none() {
        assert!(extract_chunks_ts("empty.rs", "", "rs").is_none());
    }

    /// Chunks come back ordered by their starting line.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn chunks_sorted_by_line() {
        let src = r"
fn b_func() {}
fn a_func() {}
";
        let chunks = extract_chunks_ts("sort.rs", src, "rs").unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );
        assert!(chunks
            .windows(2)
            .all(|pair| pair[0].start_line <= pair[1].start_line));
    }
}
446}