1#[cfg(feature = "tree-sitter")]
10use tree_sitter::{Language, Node, Parser, Query, QueryCursor, StreamingIterator};
11
12use super::bm25_index::{ChunkKind, CodeChunk};
13
/// Tree-sitter query selecting the Rust items that become chunks: functions,
/// structs, enums, traits, `impl` blocks and constants.
///
/// Every pattern captures the whole item as `@chunk` and its identifier as `@name`.
/// `impl` blocks are named after the implemented type and only match when that
/// type is a plain `type_identifier`.
// NOTE(review): impls on generic or path-qualified types presumably use a different
// node kind in the `type` field and would not match — confirm against the grammar.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_RUST: &str = r"
(function_item name: (identifier) @name) @chunk
(struct_item name: (type_identifier) @name) @chunk
(enum_item name: (type_identifier) @name) @chunk
(trait_item name: (type_identifier) @name) @chunk
(impl_item type: (type_identifier) @name) @chunk
(const_item name: (identifier) @name) @chunk
";
23
/// Tree-sitter query selecting the TypeScript declarations that become chunks:
/// functions, classes, abstract classes, interfaces, type aliases, methods, and
/// variable declarators whose value is an arrow function.
///
/// Every pattern captures the whole declaration as `@chunk` and its identifier
/// as `@name`. `get_chunk_query` returns this query for both `ts` and `tsx`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_TYPESCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (type_identifier) @name) @chunk
(abstract_class_declaration name: (type_identifier) @name) @chunk
(interface_declaration name: (type_identifier) @name) @chunk
(type_alias_declaration name: (type_identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
34
/// Tree-sitter query selecting the JavaScript declarations that become chunks:
/// functions, classes, methods, and variable declarators whose value is an
/// arrow function.
///
/// Every pattern captures the whole declaration as `@chunk` and its identifier
/// as `@name`. `get_chunk_query` returns this query for both `js` and `jsx`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVASCRIPT: &str = r"
(function_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(method_definition name: (property_identifier) @name) @chunk
(variable_declarator name: (identifier) @name value: (arrow_function)) @chunk
";
42
/// Tree-sitter query selecting the Python definitions that become chunks:
/// function definitions and class definitions.
///
/// Every pattern captures the whole definition as `@chunk` and its identifier
/// as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_PYTHON: &str = r"
(function_definition name: (identifier) @name) @chunk
(class_definition name: (identifier) @name) @chunk
";
48
/// Tree-sitter query selecting the Go declarations that become chunks:
/// functions, methods and type specs.
///
/// Every pattern captures the whole declaration as `@chunk` and its identifier
/// as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_GO: &str = r"
(function_declaration name: (identifier) @name) @chunk
(method_declaration name: (field_identifier) @name) @chunk
(type_spec name: (type_identifier) @name) @chunk
";
55
/// Tree-sitter query selecting the Java declarations that become chunks:
/// methods, classes, interfaces, enums and constructors.
///
/// Every pattern captures the whole declaration as `@chunk` and its identifier
/// as `@name`.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_JAVA: &str = r"
(method_declaration name: (identifier) @name) @chunk
(class_declaration name: (identifier) @name) @chunk
(interface_declaration name: (identifier) @name) @chunk
(enum_declaration name: (identifier) @name) @chunk
(constructor_declaration name: (identifier) @name) @chunk
";
64
/// Tree-sitter query selecting the C definitions that become chunks: function
/// definitions, named structs and named enums.
///
/// Every pattern captures the whole definition as `@chunk` and its identifier
/// as `@name`. A function only matches when its `function_declarator` sits
/// directly in the definition's `declarator` field and is named by a plain
/// `identifier`. `get_chunk_query` returns this query for both `c` and `h`.
// NOTE(review): functions whose declarator is wrapped in another node (for example
// a pointer return type) presumably do not match this shape — confirm against the
// C grammar and decide whether they should be chunked.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_C: &str = r"
(function_definition
 declarator: (function_declarator
 declarator: (identifier) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
";
73
/// Tree-sitter query selecting the C++ definitions that become chunks: function
/// definitions, named structs, classes, enums and namespaces.
///
/// Every pattern captures the whole definition as `@chunk`. For functions the
/// `@name` capture is whatever node fills the `declarator` field of the
/// `function_declarator` (the wildcard `(_)`), so the captured name is not
/// limited to a plain identifier; the other patterns capture the identifier.
#[cfg(feature = "tree-sitter")]
const CHUNK_QUERY_CPP: &str = r"
(function_definition
 declarator: (function_declarator
 declarator: (_) @name)) @chunk
(struct_specifier name: (type_identifier) @name) @chunk
(class_specifier name: (type_identifier) @name) @chunk
(enum_specifier name: (type_identifier) @name) @chunk
(namespace_definition name: (identifier) @name) @chunk
";
84
/// Splits `content` into symbol-level chunks using the tree-sitter grammar
/// registered for `file_ext`.
///
/// Each match of the per-language chunk query yields one [`CodeChunk`] named
/// after the `@name` capture. Its `content` is made of the *whole source lines*
/// spanned by the `@chunk` node, joined with `\n`. A node whose line span is
/// contained in, but not identical to, the span of an already accepted chunk is
/// skipped, so members nested inside an accepted container are not emitted again.
///
/// `file_path` is copied verbatim into every chunk. `start_line` and `end_line`
/// are 1-based. `tokens` is left empty; only `token_count` is filled in here.
///
/// Returns `None` when the extension has no grammar or query, when the grammar
/// cannot be loaded into the parser, when parsing or query compilation fails, or
/// when no chunk was produced. Otherwise the chunks are ordered by `start_line`.
#[cfg(feature = "tree-sitter")]
pub fn extract_chunks_ts(file_path: &str, content: &str, file_ext: &str) -> Option<Vec<CodeChunk>> {
    let language = get_language(file_ext)?;
    let query_src = get_chunk_query(file_ext)?;

    thread_local! {
        static PARSER: std::cell::RefCell<Parser> = std::cell::RefCell::new(Parser::new());
    }

    let tree = PARSER.with(|p| {
        let mut parser = p.borrow_mut();
        // The parser is reused across calls on this thread. If the requested grammar
        // cannot be installed, stop here instead of parsing with whatever state the
        // parser was left in by an earlier call.
        parser.set_language(&language).ok()?;
        parser.parse(content, None)
    })?;

    // NOTE(review): the query is recompiled on every call; consider caching it per
    // language if this shows up in profiles.
    let query = Query::new(&language, query_src).ok()?;
    let chunk_idx = find_capture_index(&query, "chunk")?;
    let name_idx = find_capture_index(&query, "name")?;

    let source = content.as_bytes();
    let lines: Vec<&str> = content.lines().collect();
    let mut chunks = Vec::new();
    let mut cursor = QueryCursor::new();
    let mut matches = cursor.matches(&query, tree.root_node(), source);
    // Zero-based (start_row, end_row) spans of every chunk accepted so far.
    let mut seen_ranges: Vec<(usize, usize)> = Vec::new();

    while let Some(m) = matches.next() {
        let mut chunk_node: Option<Node> = None;
        // Borrowed from `content`; only turned into a `String` for accepted chunks.
        let mut name: Option<&str> = None;

        for cap in m.captures {
            if cap.index == chunk_idx {
                chunk_node = Some(cap.node);
            } else if cap.index == name_idx {
                if let Ok(text) = cap.node.utf8_text(source) {
                    name = Some(text);
                }
            }
        }

        // A usable match needs both the node and a non-empty name.
        let (Some(node), Some(name)) = (chunk_node, name.filter(|n| !n.is_empty())) else {
            continue;
        };

        let start_line = node.start_position().row;
        let end_line = node.end_position().row;
        let range = (start_line, end_line);

        // Drop nodes nested inside an accepted chunk. Identical spans are kept so
        // that distinct symbols sharing the same lines are not lost.
        let nested = seen_ranges
            .iter()
            .any(|&(s, e)| s <= start_line && end_line <= e && range != (s, e));
        if nested {
            continue;
        }

        // The node's end row may lie past the last entry of `lines`, so clamp it.
        // Checked slicing skips an out-of-range span instead of panicking.
        let last_row = end_line.min(lines.len().saturating_sub(1));
        let Some(block_lines) = lines.get(start_line..=last_row) else {
            continue;
        };
        seen_ranges.push(range);

        let block = block_lines.join("\n");
        let kind = node_kind_to_chunk_kind(node.kind());
        let token_count = super::bm25_index::tokenize_for_index(&block).len();

        chunks.push(CodeChunk {
            file_path: file_path.to_string(),
            symbol_name: name.to_string(),
            kind,
            // Tree-sitter rows are zero-based; chunks expose 1-based line numbers.
            start_line: start_line + 1,
            end_line: end_line + 1,
            content: block,
            tokens: Vec::new(),
            token_count,
        });
    }

    if chunks.is_empty() {
        return None;
    }

    // Stable sort: chunks starting on the same line keep their match order.
    chunks.sort_by_key(|c| c.start_line);
    Some(chunks)
}
173
/// Fallback compiled when the crate is built without the `tree-sitter` feature.
///
/// Keeps the same signature as the tree-sitter backed implementation so callers
/// compile unchanged. It ignores all three arguments and always returns `None`.
#[cfg(not(feature = "tree-sitter"))]
pub fn extract_chunks_ts(
    _file_path: &str,
    _content: &str,
    _file_ext: &str,
) -> Option<Vec<CodeChunk>> {
    None
}
182
/// Resolves a file extension (written without the leading dot) to the
/// tree-sitter grammar used to parse that kind of file.
///
/// The comparison is an exact, case-sensitive string match. Extensions with no
/// matching arm yield `None`.
#[cfg(feature = "tree-sitter")]
fn get_language(ext: &str) -> Option<Language> {
    let grammar: Language = match ext {
        "rs" => tree_sitter_rust::LANGUAGE.into(),
        // TypeScript ships two grammars: one for plain `.ts`, one for `.tsx`.
        "ts" => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
        "tsx" => tree_sitter_typescript::LANGUAGE_TSX.into(),
        "js" | "jsx" => tree_sitter_javascript::LANGUAGE.into(),
        "py" => tree_sitter_python::LANGUAGE.into(),
        "go" => tree_sitter_go::LANGUAGE.into(),
        "java" => tree_sitter_java::LANGUAGE.into(),
        // `.h` headers are parsed with the C grammar.
        "c" | "h" => tree_sitter_c::LANGUAGE.into(),
        "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => tree_sitter_cpp::LANGUAGE.into(),
        _ => return None,
    };
    Some(grammar)
}
198
199#[cfg(feature = "tree-sitter")]
200fn get_chunk_query(ext: &str) -> Option<&'static str> {
201 Some(match ext {
202 "rs" => CHUNK_QUERY_RUST,
203 "ts" | "tsx" => CHUNK_QUERY_TYPESCRIPT,
204 "js" | "jsx" => CHUNK_QUERY_JAVASCRIPT,
205 "py" => CHUNK_QUERY_PYTHON,
206 "go" => CHUNK_QUERY_GO,
207 "java" => CHUNK_QUERY_JAVA,
208 "c" | "h" => CHUNK_QUERY_C,
209 "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => CHUNK_QUERY_CPP,
210 _ => return None,
211 })
212}
213
/// Looks up the numeric index of the capture called `name` (written without the
/// leading `@`) in `query`.
///
/// Returns `None` when the query declares no capture with that name. Delegates
/// to the bindings' own lookup rather than scanning `capture_names()` by hand
/// and narrowing the position with an unchecked `as u32` cast.
#[cfg(feature = "tree-sitter")]
fn find_capture_index(query: &Query, name: &str) -> Option<u32> {
    query.capture_index_for_name(name)
}
222
223fn node_kind_to_chunk_kind(kind: &str) -> ChunkKind {
224 match kind {
225 "function_item"
226 | "function_declaration"
227 | "function_definition"
228 | "method_declaration"
229 | "method_definition"
230 | "constructor_declaration"
231 | "variable_declarator" => ChunkKind::Function,
232
233 "struct_item"
234 | "struct_specifier"
235 | "struct_declaration"
236 | "enum_item"
237 | "enum_specifier"
238 | "enum_declaration"
239 | "trait_item"
240 | "interface_declaration"
241 | "type_alias_declaration"
242 | "type_spec" => ChunkKind::Struct,
243
244 "impl_item" => ChunkKind::Impl,
245
246 "class_declaration"
247 | "abstract_class_declaration"
248 | "class_specifier"
249 | "class_definition" => ChunkKind::Class,
250
251 "namespace_definition" | "namespace_declaration" => ChunkKind::Module,
252
253 _ => ChunkKind::Other,
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that expect real chunks are gated on the `tree-sitter` feature: without
    // it `extract_chunks_ts` is a stub that always returns `None`, so their
    // `unwrap()` calls would panic. The two `*_returns_none` tests hold in both
    // configurations and stay ungated.

    // Top-level Rust items are chunked with the right name, kind and body.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn extract_rust_chunks() {
        let src = r#"use std::io;

pub fn process(input: &str) -> String {
    input.to_uppercase()
}

pub struct Config {
    pub name: String,
    pub port: u16,
}

impl Config {
    pub fn new() -> Self {
        Self { name: "default".into(), port: 8080 }
    }
}

fn helper() -> bool {
    true
}
"#;
        let chunks = extract_chunks_ts("main.rs", src, "rs").unwrap();
        assert!(
            chunks.len() >= 4,
            "expected >=4 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"process"), "got {names:?}");
        assert!(names.contains(&"Config"), "got {names:?}");
        assert!(names.contains(&"helper"), "got {names:?}");

        let process = chunks.iter().find(|c| c.symbol_name == "process").unwrap();
        assert!(matches!(process.kind, ChunkKind::Function));
        assert!(process.content.contains("to_uppercase"));
    }

    // Functions, classes and arrow-function bindings are found in TypeScript.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn extract_typescript_chunks() {
        let src = r"
export function greet(name: string): string {
    return `Hello ${name}`;
}

export class UserService {
    findUser(id: number): User {
        return db.find(id);
    }
}

const handler = async (req: Request): Promise<Response> => {
    return new Response();
};
";
        let chunks = extract_chunks_ts("app.ts", src, "ts").unwrap();
        assert!(
            chunks.len() >= 3,
            "expected >=3 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"greet"), "got {names:?}");
        assert!(names.contains(&"UserService"), "got {names:?}");
    }

    // A Python class chunk carries its methods in its content.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn extract_python_chunks() {
        let src = r"
class AuthService:
    def __init__(self, db):
        self.db = db

    def authenticate(self, email: str) -> bool:
        user = self.db.find(email)
        return user is not None

def create_app():
    return Flask(__name__)
";
        let chunks = extract_chunks_ts("app.py", src, "py").unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );

        let names: Vec<&str> = chunks.iter().map(|c| c.symbol_name.as_str()).collect();
        assert!(names.contains(&"AuthService"), "got {names:?}");
        assert!(names.contains(&"create_app"), "got {names:?}");

        let auth = chunks
            .iter()
            .find(|c| c.symbol_name == "AuthService")
            .unwrap();
        assert!(auth.content.contains("authenticate"));
    }

    // A chunk spans the whole function body, not just its signature.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn chunks_contain_full_body() {
        let src = r#"
pub fn complex(x: i32, y: i32) -> Result<String, Error> {
    let sum = x + y;
    let result = format!("Sum: {}", sum);
    if sum > 100 {
        return Err(Error::new("too large"));
    }
    Ok(result)
}
"#;
        let chunks = extract_chunks_ts("math.rs", src, "rs").unwrap();
        let complex = chunks.iter().find(|c| c.symbol_name == "complex").unwrap();
        assert!(complex.content.contains("sum > 100"));
        assert!(complex.content.contains("Ok(result)"));
    }

    // Holds with and without the feature: unknown extensions never yield chunks.
    #[test]
    fn unsupported_language_returns_none() {
        assert!(extract_chunks_ts("file.xyz", "content", "xyz").is_none());
    }

    // Holds with and without the feature: nothing to chunk in an empty file.
    #[test]
    fn empty_file_returns_none() {
        assert!(extract_chunks_ts("empty.rs", "", "rs").is_none());
    }

    // Chunks come back ordered by their starting line.
    #[cfg(feature = "tree-sitter")]
    #[test]
    fn chunks_sorted_by_line() {
        let src = r"
fn b_func() {}
fn a_func() {}
";
        let chunks = extract_chunks_ts("sort.rs", src, "rs").unwrap();
        assert!(
            chunks.len() >= 2,
            "expected >=2 chunks, got {}",
            chunks.len()
        );
        // Check every adjacent pair instead of indexing two fixed positions.
        assert!(chunks
            .windows(2)
            .all(|pair| pair[0].start_line <= pair[1].start_line));
    }
}