1#[cfg(feature = "tree-sitter")]
10use tree_sitter::{Language, Node, Parser, Query, QueryCursor, StreamingIterator};
11
12use super::bm25_index::{ChunkKind, CodeChunk};
13
/// Tree-sitter query selecting chunkable definitions in Rust sources.
///
/// Every pattern captures the whole definition node as `@chunk` and its
/// identifier as `@name`. Covered node kinds: `function_item`, `struct_item`,
/// `enum_item`, `trait_item`, `impl_item` (named after its `type` field) and
/// `const_item`.
// NOTE(review): the `impl_item` pattern only accepts a bare `type_identifier`,
// so impls on other type shapes (e.g. generic types) are presumably not
// captured — confirm against the grammar.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_RUST: &str = r"
(function_item name: (identifier) @name) @chunk
(struct_item name: (type_identifier) @name) @chunk
(enum_item name: (type_identifier) @name) @chunk
(trait_item name: (type_identifier) @name) @chunk
(impl_item type: (type_identifier) @name) @chunk
(const_item name: (identifier) @name) @chunk
";
23
/// Tree-sitter query selecting chunkable definitions in TypeScript sources
/// (`get_chunk_query` uses it for both `ts` and `tsx`).
///
/// Every pattern captures the whole definition node as `@chunk` and its
/// identifier as `@name`. Covered: function, class, abstract class, interface
/// and type-alias declarations, method definitions, and variable declarators
/// whose value is an arrow function.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_TYPESCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (type_identifier) @name) @chunk
(abstract_class_declaration name: (type_identifier) @name) @chunk
(interface_declaration name: (type_identifier) @name) @chunk
(type_alias_declaration name: (type_identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
34
/// Tree-sitter query selecting chunkable definitions in JavaScript sources
/// (`get_chunk_query` uses it for both `js` and `jsx`).
///
/// Every pattern captures the whole definition node as `@chunk` and its
/// identifier as `@name`. Covered: function and class declarations, method
/// definitions, and variable declarators whose value is an arrow function.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVASCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
42
/// Tree-sitter query selecting chunkable definitions in Python sources.
///
/// Captures `function_definition` and `class_definition` nodes as `@chunk`,
/// with their identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_PYTHON: &str = r"
(function_definition name: (identifier) @name) @chunk
(class_definition name: (identifier) @name) @chunk
";
48
/// Tree-sitter query selecting chunkable definitions in Go sources.
///
/// Captures function declarations, method declarations and `type_spec` nodes
/// as `@chunk`, with their identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_GO: &str = r"
(function_declaration name: (identifier) @name) @chunk
(method_declaration name: (field_identifier) @name) @chunk
(type_spec name: (type_identifier) @name) @chunk
";
55
/// Tree-sitter query selecting chunkable definitions in Java sources.
///
/// Captures method, class, interface, enum and constructor declarations as
/// `@chunk`, with their identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVA: &str = r"
(method_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(interface_declaration name: (identifier) @name) @chunk
(enum_declaration name: (identifier) @name) @chunk
(constructor_declaration name: (identifier) @name) @chunk
";
64
/// Tree-sitter query selecting chunkable definitions in C sources
/// (`get_chunk_query` uses it for `c` and `h`).
///
/// Captures function definitions (named by the identifier nested inside their
/// `function_declarator`) plus named struct and enum specifiers as `@chunk`,
/// with the identifier as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_C: &str = r"
(function_definition
 declarator: (function_declarator
 declarator: (identifier) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
";
73
/// Tree-sitter query selecting chunkable definitions in C++ sources.
///
/// Captures function definitions plus named struct, class and enum specifiers
/// and namespace definitions as `@chunk`, with their name as `@name`.
/// The function pattern uses a wildcard `(_)` for the inner declarator, so the
/// name capture accepts any node kind in that position rather than only a
/// plain identifier.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_CPP: &str = r"
(function_definition
 declarator: (function_declarator
 declarator: (_) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(class_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
(namespace_definition name: (identifier) @name) @chunk
";
84
/// Returns the compiled chunk query for `file_ext`, or `None` when the
/// extension is unsupported or its query failed to compile.
///
/// All queries are compiled once, on first call, and kept in a process-wide
/// cache for the rest of the program's lifetime.
#[cfg(feature = "tree-sitter")]
fn get_cached_query(file_ext: &str) -> Option<&'static Query> {
    use std::collections::HashMap;
    use std::sync::OnceLock;

    // Must list every extension accepted by `get_language` and
    // `get_chunk_query`; an extension missing here never gets a query, which
    // makes chunk extraction silently return `None` for that file type.
    const SUPPORTED_EXTS: &[&str] = &[
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "java", "c", "h", "cpp", "cc", "cxx", "hpp",
        "hxx", "hh",
    ];

    static QUERY_CACHE: OnceLock<HashMap<&'static str, Query>> = OnceLock::new();

    QUERY_CACHE
        .get_or_init(|| {
            SUPPORTED_EXTS
                .iter()
                .filter_map(|&ext| {
                    let language = get_language(ext)?;
                    let query_src = get_chunk_query(ext)?;
                    // Best effort: a query that does not compile against its
                    // grammar is left out of the cache instead of aborting.
                    Query::new(&language, query_src)
                        .ok()
                        .map(|query| (ext, query))
                })
                .collect()
        })
        .get(file_ext)
}
113
/// Parses `content` with the grammar for `file_ext` and calls `visitor` once
/// for every chunk matched by that language's chunk query.
///
/// `visitor` receives `(node, name, kind, start_line, end_line)`, where the
/// line numbers are 1-based and inclusive.
///
/// Returns `None` when the extension has no grammar or no usable query, when
/// the grammar cannot be loaded into the parser, or when parsing fails.
/// Returns `Some(())` otherwise, even if no chunk was visited.
#[cfg(feature = "tree-sitter")]
pub(crate) fn for_each_chunk_node(
    content: &str,
    file_ext: &str,
    mut visitor: impl FnMut(Node, &str, ChunkKind, usize, usize),
) -> Option<()> {
    let language = get_language(file_ext)?;

    // Resolve the query first so no parse is wasted on a file type that has
    // no usable query.
    let query = get_cached_query(file_ext)?;
    let chunk_idx = find_capture_index(query, "chunk")?;
    let name_idx = find_capture_index(query, "name")?;

    thread_local! {
        static PARSER: std::cell::RefCell<Parser> = std::cell::RefCell::new(Parser::new());
    }

    let tree = PARSER.with(|p| {
        let mut parser = p.borrow_mut();
        // The parser is reused across calls on this thread. If the grammar
        // cannot be installed, bail out instead of parsing with whatever
        // language a previous call left behind.
        parser.set_language(&language).ok()?;
        parser.parse(content, None)
    })?;

    let source = content.as_bytes();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(query, tree.root_node(), source);
    let mut seen_ranges: Vec<(usize, usize)> = Vec::new();

    // `matches` is a streaming iterator, so it has to be driven with
    // `while let` rather than a `for` loop.
    while let Some(m) = matches.next() {
        let mut chunk_node: Option<Node> = None;
        // Borrowed straight from `content`; no per-match allocation needed.
        let mut name_text: &str = "";

        for cap in m.captures {
            if cap.index == chunk_idx {
                chunk_node = Some(cap.node);
            } else if cap.index == name_idx {
                if let Ok(text) = cap.node.utf8_text(source) {
                    name_text = text;
                }
            }
        }

        let Some(node) = chunk_node else {
            continue;
        };
        if name_text.is_empty() {
            continue;
        }

        let start_row0 = node.start_position().row;
        let end_row0 = node.end_position().row;
        let range = (start_row0, end_row0);

        // Skip a node whose row span lies inside a span already reported
        // earlier; a node with exactly the same span is still reported.
        let nested_in_seen = seen_ranges
            .iter()
            .any(|&(s, e)| s <= start_row0 && end_row0 <= e && range != (s, e));
        if nested_in_seen {
            continue;
        }
        seen_ranges.push(range);

        let kind = node_kind_to_chunk_kind(node.kind());
        // Tree-sitter rows are 0-based; the visitor gets 1-based lines.
        visitor(node, name_text, kind, start_row0 + 1, end_row0 + 1);
    }

    Some(())
}
183
/// Extracts the code chunks of `content` using the tree-sitter grammar that
/// matches `file_ext`.
///
/// Each chunk holds the full source lines spanned by its definition node, its
/// symbol name, kind, 1-based line range and index token count. `tokens` is
/// left empty. Chunks are returned sorted by `start_line`.
///
/// Returns `None` when the extension is unsupported, parsing fails, or no
/// chunk was found.
#[cfg(feature = "tree-sitter")]
pub fn extract_chunks_ts(file_path: &str, content: &str, file_ext: &str) -> Option<Vec<CodeChunk>> {
    let lines: Vec<&str> = content.lines().collect();
    let mut chunks = Vec::new();

    for_each_chunk_node(
        content,
        file_ext,
        |node, name_text, kind, start_line, end_line| {
            let first_row = node.start_position().row;
            // The node may end on a row past the last collected line (e.g. at
            // a trailing newline), so clamp to the final line index.
            let last_row = node
                .end_position()
                .row
                .min(lines.len().saturating_sub(1));

            // Checked slicing: drop a chunk whose rows fall outside `lines`
            // instead of panicking on an out-of-range index.
            let Some(block_lines) = lines.get(first_row..=last_row) else {
                return;
            };
            let block = block_lines.join("\n");
            let token_count = super::bm25_index::tokenize_for_index(&block).len();

            chunks.push(CodeChunk {
                file_path: file_path.to_string(),
                symbol_name: name_text.to_string(),
                kind,
                start_line,
                end_line,
                content: block,
                tokens: Vec::new(),
                token_count,
            });
        },
    )?;

    if chunks.is_empty() {
        return None;
    }

    // Stable sort keeps visit order for chunks that start on the same line.
    chunks.sort_by_key(|c| c.start_line);
    Some(chunks)
}
220
/// Fallback used when the `tree-sitter` feature is disabled.
///
/// Ignores all arguments and always returns `None`, so no chunks are ever
/// produced by this build configuration.
#[cfg(not(feature = "tree-sitter"))]
pub fn extract_chunks_ts(
    _file_path: &str,
    _content: &str,
    _file_ext: &str,
) -> Option<Vec<CodeChunk>> {
    None
}
229
/// Maps a file extension (written without a leading dot, e.g. `"rs"`) to the
/// tree-sitter grammar used to parse it.
///
/// Returns `None` for any extension that is not listed below.
#[cfg(feature = "tree-sitter")]
fn get_language(ext: &str) -> Option<Language> {
    let grammar: Language = match ext {
        "rs" => tree_sitter_rust::LANGUAGE.into(),
        // TypeScript and TSX use separate grammars.
        "ts" => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
        "tsx" => tree_sitter_typescript::LANGUAGE_TSX.into(),
        "js" | "jsx" => tree_sitter_javascript::LANGUAGE.into(),
        "py" => tree_sitter_python::LANGUAGE.into(),
        "go" => tree_sitter_go::LANGUAGE.into(),
        "java" => tree_sitter_java::LANGUAGE.into(),
        "c" | "h" => tree_sitter_c::LANGUAGE.into(),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => tree_sitter_cpp::LANGUAGE.into(),
        _ => return None,
    };

    Some(grammar)
}
245
/// Picks the chunk query source text for a file extension (written without a
/// leading dot).
///
/// Returns `None` for any extension that is not listed below.
#[cfg(feature = "tree-sitter")]
fn get_chunk_query(ext: &str) -> Option<&'static str> {
    let query_src = match ext {
        "rs" => CHUNK_QUERY_RUST,
        // TSX shares the TypeScript query, JSX shares the JavaScript one.
        "ts" | "tsx" => CHUNK_QUERY_TYPESCRIPT,
        "js" | "jsx" => CHUNK_QUERY_JAVASCRIPT,
        "py" => CHUNK_QUERY_PYTHON,
        "go" => CHUNK_QUERY_GO,
        "java" => CHUNK_QUERY_JAVA,
        "c" | "h" => CHUNK_QUERY_C,
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => CHUNK_QUERY_CPP,
        _ => return None,
    };

    Some(query_src)
}
260
/// Returns the index of the capture called `name` in `query`, or `None` when
/// the query declares no such capture.
///
/// Delegates to tree-sitter's own lookup rather than scanning the capture
/// names by hand and narrowing the position with a truncating `as` cast.
#[cfg(feature = "tree-sitter")]
fn find_capture_index(query: &Query, name: &str) -> Option<u32> {
    query.capture_index_for_name(name)
}
269
270fn node_kind_to_chunk_kind(kind: &str) -> ChunkKind {
271 match kind {
272 "function_item"
273 | "function_declaration"
274 | "function_definition"
275 | "method_declaration"
276 | "method_definition"
277 | "constructor_declaration"
278 | "variable_declarator" => ChunkKind::Function,
279
280 "struct_item"
281 | "struct_specifier"
282 | "struct_declaration"
283 | "enum_item"
284 | "enum_specifier"
285 | "enum_declaration"
286 | "trait_item"
287 | "interface_declaration"
288 | "type_alias_declaration"
289 | "type_spec" => ChunkKind::Struct,
290
291 "impl_item" => ChunkKind::Impl,
292
293 "class_declaration"
294 | "abstract_class_declaration"
295 | "class_specifier"
296 | "class_definition" => ChunkKind::Class,
297
298 "namespace_definition" | "namespace_declaration" => ChunkKind::Module,
299
300 _ => ChunkKind::Other,
301 }
302}
303
304#[cfg(test)]
305mod tests {
306 use super::*;
307
308 #[test]
309 fn extract_rust_chunks() {
310 let src = r#"use std::io;
311
312pub fn process(input: &str) -> String {
313 input.to_uppercase()
314}
315
316pub struct Config {
317 pub name: String,
318 pub port: u16,
319}
320
321impl Config {
322 pub fn new() -> Self {
323 Self { name: "default".into(), port: 8080 }
324 }
325}
326
327fn helper() -> bool {
328 true
329}
330"#;
331 let chunks = extract_chunks_ts("main.rs", src, "rs").unwrap();
332 assert!(
333 chunks.len() >= 4,
334 "expected >=4 chunks, got {}",
335 chunks.len()
336 );
337
338 let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
339 assert!(names.contains(&"process"), "got {names:?}");
340 assert!(names.contains(&"Config"), "got {names:?}");
341 assert!(names.contains(&"helper"), "got {names:?}");
342
343 let process = chunks.iter().find(|c| c.symbol_name == "process").unwrap();
344 assert!(matches!(process.kind, ChunkKind::Function));
345 assert!(process.content.contains("to_uppercase"));
346 }
347
348 #[test]
349 fn extract_typescript_chunks() {
350 let src = r"
351export function greet(name: string): string {
352 return `Hello ${name}`;
353}
354
355export class UserService {
356 findUser(id: number): User {
357 return db.find(id);
358 }
359}
360
361const handler = async (req: Request): Promise<Response> => {
362 return new Response();
363};
364";
365 let chunks = extract_chunks_ts("app.ts", src, "ts").unwrap();
366 assert!(
367 chunks.len() >= 3,
368 "expected >=3 chunks, got {}",
369 chunks.len()
370 );
371
372 let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
373 assert!(names.contains(&"greet"), "got {names:?}");
374 assert!(names.contains(&"UserService"), "got {names:?}");
375 }
376
377 #[test]
378 fn extract_python_chunks() {
379 let src = r"
380class AuthService:
381 def __init__(self, db):
382 self.db = db
383
384 def authenticate(self, email: str) -> bool:
385 user = self.db.find(email)
386 return user is not None
387
388def create_app():
389 return Flask(__name__)
390";
391 let chunks = extract_chunks_ts("app.py", src, "py").unwrap();
392 assert!(
393 chunks.len() >= 2,
394 "expected >=2 chunks, got {}",
395 chunks.len()
396 );
397
398 let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
399 assert!(names.contains(&"AuthService"), "got {names:?}");
400 assert!(names.contains(&"create_app"), "got {names:?}");
401
402 let auth = chunks
403 .iter()
404 .find(|c| c.symbol_name == "AuthService")
405 .unwrap();
406 assert!(auth.content.contains("authenticate"));
407 }
408
409 #[test]
410 fn chunks_contain_full_body() {
411 let src = r#"
412pub fn complex(x: i32, y: i32) -> Result<String, Error> {
413 let sum = x + y;
414 let result = format!("Sum: {}", sum);
415 if sum > 100 {
416 return Err(Error::new("too large"));
417 }
418 Ok(result)
419}
420"#;
421 let chunks = extract_chunks_ts("math.rs", src, "rs").unwrap();
422 let complex = chunks.iter().find(|c| c.symbol_name == "complex").unwrap();
423 assert!(complex.content.contains("sum > 100"));
424 assert!(complex.content.contains("Ok(result)"));
425 }
426
427 #[test]
428 fn unsupported_language_returns_none() {
429 assert!(extract_chunks_ts("file.xyz", "content", "xyz").is_none());
430 }
431
432 #[test]
433 fn empty_file_returns_none() {
434 assert!(extract_chunks_ts("empty.rs", "", "rs").is_none());
435 }
436
437 #[test]
438 fn chunks_sorted_by_line() {
439 let src = r"
440fn b_func() {}
441fn a_func() {}
442";
443 let chunks = extract_chunks_ts("sort.rs", src, "rs").unwrap();
444 assert!(chunks[0].start_line <= chunks[1].start_line);
445 }
446}