Skip to main content

lean_ctx/core/
chunks_ts.rs

//! Tree-sitter AST-aware code chunking for semantic search.
//!
//! Replaces heuristic line-prefix matching with proper AST parsing.
//! Extracts function bodies, struct definitions, class declarations, etc.
//! as complete, self-contained chunks with accurate boundaries.
//!
//! Returns `None` for unsupported languages so that callers can fall back to
//! heuristic chunking.
8
9#[cfg(feature = "tree-sitter")]
10use tree_sitter::{Language, Node, Parser, Query, QueryCursor, StreamingIterator};
11
12use super::vector_index::{ChunkKind, CodeChunk};
13
/// Tree-sitter query selecting the Rust items that become chunks.
///
/// Each pattern captures the whole item as `@chunk` and its identifier as
/// `@name`: functions, structs, enums, traits, consts, and `impl` blocks whose
/// `type` field is a plain `type_identifier`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_RUST: &str = r"
(function_item name: (identifier) @name) @chunk
(struct_item name: (type_identifier) @name) @chunk
(enum_item name: (type_identifier) @name) @chunk
(trait_item name: (type_identifier) @name) @chunk
(impl_item type: (type_identifier) @name) @chunk
(const_item name: (identifier) @name) @chunk
";
23
/// Tree-sitter query selecting the TypeScript nodes that become chunks.
///
/// Captures (as `@chunk`, with the identifier as `@name`) function
/// declarations, classes, abstract classes, interfaces, type aliases, methods,
/// and variable declarators whose value is an arrow function.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_TYPESCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (type_identifier) @name) @chunk
(abstract_class_declaration name: (type_identifier) @name) @chunk
(interface_declaration name: (type_identifier) @name) @chunk
(type_alias_declaration name: (type_identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
34
/// Tree-sitter query selecting the JavaScript nodes that become chunks.
///
/// Captures (as `@chunk`, with the identifier as `@name`) function
/// declarations, classes, methods, and variable declarators whose value is an
/// arrow function.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVASCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
42
/// Tree-sitter query selecting the Python nodes that become chunks.
///
/// Captures function definitions and class definitions as `@chunk`, with the
/// defined identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_PYTHON: &str = r"
(function_definition name: (identifier) @name) @chunk
(class_definition name: (identifier) @name) @chunk
";
48
/// Tree-sitter query selecting the Go nodes that become chunks.
///
/// Captures function declarations, method declarations and type specs as
/// `@chunk`, with the declared identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_GO: &str = r"
(function_declaration name: (identifier) @name) @chunk
(method_declaration name: (field_identifier) @name) @chunk
(type_spec name: (type_identifier) @name) @chunk
";
55
/// Tree-sitter query selecting the Java nodes that become chunks.
///
/// Captures methods, classes, interfaces, enums and constructors as `@chunk`,
/// with the declared identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVA: &str = r"
(method_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(interface_declaration name: (identifier) @name) @chunk
(enum_declaration name: (identifier) @name) @chunk
(constructor_declaration name: (identifier) @name) @chunk
";
64
/// Tree-sitter query selecting the C nodes that become chunks.
///
/// Captures as `@chunk` function definitions whose `function_declarator`
/// directly wraps a plain `identifier` (captured as `@name`), plus named
/// structs and named enums.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_C: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (identifier) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
";
73
/// Tree-sitter query selecting the C++ nodes that become chunks.
///
/// Captures as `@chunk` function definitions (the `(_)` wildcard accepts any
/// node as the inner declarator, which is captured as `@name`), plus named
/// structs, classes, enums and namespaces.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_CPP: &str = r"
(function_definition
  declarator: (function_declarator
    declarator: (_) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(class_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
(namespace_definition name: (identifier) @name) @chunk
";
84
/// Extract code chunks from a file using tree-sitter AST parsing.
///
/// Every node captured as `@chunk` by the language's chunk query becomes one
/// [`CodeChunk`]. Its content is the complete source lines the node spans and
/// its symbol name is the text of the accompanying `@name` capture. Matches
/// without a usable name are ignored, and so are nodes whose line range lies
/// strictly inside a range accepted earlier (for example methods inside an
/// already emitted `impl` or class), so nested definitions are not emitted twice.
///
/// # Arguments
///
/// * `file_path` - Path stored verbatim in every returned chunk.
/// * `content` - Full source text of the file.
/// * `file_ext` - File extension without the dot (e.g. `"rs"`); selects both the
///   grammar and the chunk query.
///
/// # Returns
///
/// `Some(chunks)` sorted by ascending 1-based `start_line`, or `None` when the
/// extension is unsupported, the grammar cannot be installed on the parser,
/// parsing or query compilation fails, or no named chunk is found. Returning
/// `None` allows callers to fall back to heuristic-based chunking.
#[cfg(feature = "tree-sitter")]
pub fn extract_chunks_ts(file_path: &str, content: &str, file_ext: &str) -> Option<Vec<CodeChunk>> {
    thread_local! {
        static PARSER: std::cell::RefCell<Parser> = std::cell::RefCell::new(Parser::new());
    }

    let language = get_language(file_ext)?;
    let query_src = get_chunk_query(file_ext)?;

    let lines: Vec<&str> = content.lines().collect();
    // Without any line there is nothing to chunk; this also guarantees that
    // `last_row` is a valid index into `lines` for the clamping done below.
    let last_row = lines.len().checked_sub(1)?;

    let tree = PARSER.with(|p| {
        let mut parser = p.borrow_mut();
        // The parser is shared across calls on this thread. If the grammar
        // cannot be installed the parser would still hold the grammar of a
        // previous call, so bail out instead of parsing with the wrong language.
        parser.set_language(&language).ok()?;
        parser.parse(content, None)
    })?;

    let query = Query::new(&language, query_src).ok()?;
    let chunk_idx = find_capture_index(&query, "chunk")?;
    let name_idx = find_capture_index(&query, "name")?;

    let source = content.as_bytes();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, tree.root_node(), source);
    let mut seen_ranges: Vec<(usize, usize)> = Vec::new();
    let mut chunks = Vec::new();

    // `matches` is a streaming iterator, so it has to be driven with `while let`.
    while let Some(m) = matches.next() {
        let mut chunk_node: Option<Node> = None;
        let mut name: Option<&str> = None;

        for cap in m.captures {
            if cap.index == chunk_idx {
                chunk_node = Some(cap.node);
            } else if cap.index == name_idx {
                if let Ok(text) = cap.node.utf8_text(source) {
                    name = Some(text);
                }
            }
        }

        let (Some(node), Some(name)) = (chunk_node, name) else {
            continue;
        };
        if name.is_empty() {
            continue;
        }

        let start_line = node.start_position().row;
        let end_line = node.end_position().row;

        // Drop nodes strictly contained (by line range) in an accepted chunk.
        // Identical ranges are kept so that items sharing the same lines survive.
        let is_nested = seen_ranges
            .iter()
            .any(|&(s, e)| s <= start_line && end_line <= e && (s, e) != (start_line, end_line));
        if is_nested {
            continue;
        }
        seen_ranges.push((start_line, end_line));

        // The node's end row may lie past the last element of `lines`, so clamp
        // it and slice with `get` rather than indexing, which could panic.
        let last_line = end_line.min(last_row);
        let Some(body) = lines.get(start_line..=last_line) else {
            continue;
        };
        let block = body.join("\n");

        let kind = node_kind_to_chunk_kind(node.kind());
        let tokens = super::vector_index::tokenize_for_index(&block);
        let token_count = tokens.len();

        chunks.push(CodeChunk {
            file_path: file_path.to_string(),
            symbol_name: name.to_string(),
            kind,
            // Rows are 0-based; chunk line numbers are 1-based and `end_line`
            // always refers to the last line actually present in `content`.
            start_line: start_line + 1,
            end_line: last_line + 1,
            content: block,
            tokens,
            token_count,
        });
    }

    if chunks.is_empty() {
        return None;
    }

    // Stable sort: chunks starting on the same line keep their match order.
    chunks.sort_by_key(|c| c.start_line);
    Some(chunks)
}
174
175#[cfg(not(feature = "tree-sitter"))]
176pub fn extract_chunks_ts(
177    _file_path: &str,
178    _content: &str,
179    _file_ext: &str,
180) -> Option<Vec<CodeChunk>> {
181    None
182}
183
/// Resolve the tree-sitter grammar for a file extension (given without the dot).
///
/// Yields `None` for any extension that has no grammar wired in here.
#[cfg(feature = "tree-sitter")]
fn get_language(ext: &str) -> Option<Language> {
    match ext {
        "rs" => Some(tree_sitter_rust::LANGUAGE.into()),
        "ts" => Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()),
        "tsx" => Some(tree_sitter_typescript::LANGUAGE_TSX.into()),
        "js" | "jsx" => Some(tree_sitter_javascript::LANGUAGE.into()),
        "py" => Some(tree_sitter_python::LANGUAGE.into()),
        "go" => Some(tree_sitter_go::LANGUAGE.into()),
        "java" => Some(tree_sitter_java::LANGUAGE.into()),
        "c" | "h" => Some(tree_sitter_c::LANGUAGE.into()),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Some(tree_sitter_cpp::LANGUAGE.into()),
        _ => None,
    }
}
199
/// Pick the chunk query source for a file extension (given without the dot).
///
/// The extension groups mirror those of `get_language`; `ts` and `tsx` share
/// one query. Yields `None` for extensions without a query.
#[cfg(feature = "tree-sitter")]
fn get_chunk_query(ext: &str) -> Option<&'static str> {
    match ext {
        "rs" => Some(CHUNK_QUERY_RUST),
        "ts" | "tsx" => Some(CHUNK_QUERY_TYPESCRIPT),
        "js" | "jsx" => Some(CHUNK_QUERY_JAVASCRIPT),
        "py" => Some(CHUNK_QUERY_PYTHON),
        "go" => Some(CHUNK_QUERY_GO),
        "java" => Some(CHUNK_QUERY_JAVA),
        "c" | "h" => Some(CHUNK_QUERY_C),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Some(CHUNK_QUERY_CPP),
        _ => None,
    }
}
214
/// Look up the index of the capture called `name` (without the leading `@`)
/// in `query`.
///
/// Returns `None` when the query declares no capture with that name.
///
/// Delegates to tree-sitter's own lookup instead of re-scanning
/// `capture_names()` by hand and narrowing the position with an unchecked
/// `as u32` cast.
#[cfg(feature = "tree-sitter")]
fn find_capture_index(query: &Query, name: &str) -> Option<u32> {
    query.capture_index_for_name(name)
}
223
224fn node_kind_to_chunk_kind(kind: &str) -> ChunkKind {
225    match kind {
226        "function_item"
227        | "function_declaration"
228        | "function_definition"
229        | "method_declaration"
230        | "method_definition"
231        | "constructor_declaration"
232        | "variable_declarator" => ChunkKind::Function,
233
234        "struct_item"
235        | "struct_specifier"
236        | "struct_declaration"
237        | "enum_item"
238        | "enum_specifier"
239        | "enum_declaration"
240        | "trait_item"
241        | "interface_declaration"
242        | "type_alias_declaration"
243        | "type_spec" => ChunkKind::Struct,
244
245        "impl_item" => ChunkKind::Impl,
246
247        "class_declaration"
248        | "abstract_class_declaration"
249        | "class_specifier"
250        | "class_definition" => ChunkKind::Class,
251
252        "namespace_definition" | "namespace_declaration" => ChunkKind::Module,
253
254        _ => ChunkKind::Other,
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that need a real parser are gated on the `tree-sitter` feature:
    // without it `extract_chunks_ts` always returns `None`, so their `unwrap()`
    // calls would panic. The two `None` tests hold in both configurations.

    /// Top-level Rust items are found, named and classified.
    #[test]
    #[cfg(feature = "tree-sitter")]
    fn extract_rust_chunks() {
        let src = r#"use std::io;

pub fn process(input: &str) -> String {
    input.to_uppercase()
}

pub struct Config {
    pub name: String,
    pub port: u16,
}

impl Config {
    pub fn new() -> Self {
        Self { name: "default".into(), port: 8080 }
    }
}

fn helper() -> bool {
    true
}
"#;
        let chunks = extract_chunks_ts("main.rs", src, "rs").unwrap();
        assert!(
            chunks.len() >= 4,
            "expected >=4 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"process"), "got {names:?}");
        assert!(names.contains(&"Config"), "got {names:?}");
        assert!(names.contains(&"helper"), "got {names:?}");

        let process = chunks.iter().find(|c| c.symbol_name == "process").unwrap();
        assert!(matches!(process.kind, ChunkKind::Function));
        assert!(process.content.contains("to_uppercase"));
    }

    /// Functions, classes and arrow-function constants are found in TypeScript.
    #[test]
    #[cfg(feature = "tree-sitter")]
    fn extract_typescript_chunks() {
        let src = r"
export function greet(name: string): string {
    return `Hello ${name}`;
}

export class UserService {
    findUser(id: number): User {
        return db.find(id);
    }
}

const handler = async (req: Request): Promise<Response> => {
    return new Response();
};
";
        let chunks = extract_chunks_ts("app.ts", src, "ts").unwrap();
        assert!(
            chunks.len() >= 3,
            "expected >=3 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"greet"), "got {names:?}");
        assert!(names.contains(&"UserService"), "got {names:?}");
        assert!(names.contains(&"handler"), "got {names:?}");
    }

    /// A Python class chunk carries its methods in its content.
    #[test]
    #[cfg(feature = "tree-sitter")]
    fn extract_python_chunks() {
        let src = r"
class AuthService:
    def __init__(self, db):
        self.db = db

    def authenticate(self, email: str) -> bool:
        user = self.db.find(email)
        return user is not None

def create_app():
    return Flask(__name__)
";
        let chunks = extract_chunks_ts("app.py", src, "py").unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"AuthService"), "got {names:?}");
        assert!(names.contains(&"create_app"), "got {names:?}");

        let auth = chunks
            .iter()
            .find(|c| c.symbol_name == "AuthService")
            .unwrap();
        assert!(auth.content.contains("authenticate"));
    }

    /// A chunk spans the whole definition and reports 1-based line numbers.
    #[test]
    #[cfg(feature = "tree-sitter")]
    fn chunks_contain_full_body() {
        let src = r#"
pub fn complex(x: i32, y: i32) -> Result<String, Error> {
    let sum = x + y;
    let result = format!("Sum: {}", sum);
    if sum > 100 {
        return Err(Error::new("too large"));
    }
    Ok(result)
}
"#;
        let chunks = extract_chunks_ts("math.rs", src, "rs").unwrap();
        let complex = chunks.iter().find(|c| c.symbol_name == "complex").unwrap();
        assert!(complex.content.contains("sum > 100"));
        assert!(complex.content.contains("Ok(result)"));
        // The source starts with an empty line, so the function covers lines 2-9.
        assert_eq!(complex.start_line, 2);
        assert_eq!(complex.end_line, 9);
    }

    /// Unknown extensions yield `None` with or without the feature.
    #[test]
    fn unsupported_language_returns_none() {
        assert!(extract_chunks_ts("file.xyz", "content", "xyz").is_none());
    }

    /// Empty input yields `None` with or without the feature.
    #[test]
    fn empty_file_returns_none() {
        assert!(extract_chunks_ts("empty.rs", "", "rs").is_none());
    }

    /// Chunks come back ordered by position in the file, not by name.
    #[test]
    #[cfg(feature = "tree-sitter")]
    fn chunks_sorted_by_line() {
        let src = r"
fn b_func() {}
fn a_func() {}
";
        let chunks = extract_chunks_ts("sort.rs", src, "rs").unwrap();
        assert_eq!(chunks.len(), 2);
        assert!(chunks[0].start_line <= chunks[1].start_line);
        assert_eq!(chunks[0].symbol_name, "b_func");
        assert_eq!(chunks[1].symbol_name, "a_func");
    }
}