Skip to main content

leann_core/chunking/
ast.rs

1use std::collections::HashMap;
2
/// Detects the programming language of `filename` from its extension.
///
/// The extension is the text after the last `.` in `filename` and is matched
/// case-insensitively. Returns `None` when the name contains no `.` at all
/// (so a file literally called `go`, `c` or `sh` is not mistaken for source
/// code) or when the extension is not one of the recognised ones.
pub fn detect_language(filename: &str) -> Option<&'static str> {
    // `rsplit_once` yields `None` for dot-less names, whereas
    // `rsplit('.').next()` would hand back the entire file name as the
    // "extension".
    let (_, ext) = filename.rsplit_once('.')?;
    match ext.to_lowercase().as_str() {
        "py" => Some("python"),
        "rs" => Some("rust"),
        "js" | "jsx" => Some("javascript"),
        "ts" | "tsx" => Some("typescript"),
        "java" => Some("java"),
        "go" => Some("go"),
        "c" | "h" => Some("c"),
        "cpp" | "cxx" | "cc" | "hpp" => Some("cpp"),
        "rb" => Some("ruby"),
        "sh" | "bash" => Some("bash"),
        _ => None,
    }
}
20
21/// A code chunk extracted from AST analysis.
22#[derive(Debug, Clone)]
23pub struct CodeChunk {
24    pub text: String,
25    pub chunk_type: String, // "function", "class", "method", "module", etc.
26    pub name: Option<String>,
27    pub start_line: usize,
28    pub end_line: usize,
29    pub language: String,
30    pub metadata: HashMap<String, serde_json::Value>,
31}
32
33/// Chunk source code using AST analysis.
34///
35/// When tree-sitter features are enabled, uses grammar-based parsing for
36/// accurate AST boundaries. Falls back to heuristic-based parsing otherwise.
37pub fn chunk_code(source: &str, filename: &str, max_chunk_size: usize) -> Vec<CodeChunk> {
38    #[cfg(any(
39        feature = "tree-sitter-python",
40        feature = "tree-sitter-java",
41        feature = "tree-sitter-c-sharp",
42        feature = "tree-sitter-typescript",
43        feature = "tree-sitter-javascript",
44    ))]
45    if let Some(chunks) =
46        super::tree_sitter::chunk_code_tree_sitter(source, filename, max_chunk_size)
47        && !chunks.is_empty()
48    {
49        return chunks;
50    }
51
52    // Heuristic fallback
53    let language = detect_language(filename).unwrap_or("unknown");
54
55    match language {
56        "python" => chunk_python(source, filename, max_chunk_size),
57        "rust" => chunk_rust(source, filename, max_chunk_size),
58        "javascript" | "typescript" => chunk_js_ts(source, filename, max_chunk_size),
59        _ => chunk_generic(source, filename, language, max_chunk_size),
60    }
61}
62
63/// Chunk Python source code by detecting function and class definitions.
64fn chunk_python(source: &str, filename: &str, max_chunk_size: usize) -> Vec<CodeChunk> {
65    let lines: Vec<&str> = source.lines().collect();
66    let mut chunks = Vec::new();
67    let mut i = 0;
68
69    while i < lines.len() {
70        let line = lines[i];
71        let trimmed = line.trim();
72
73        // Detect function or class definition
74        if trimmed.starts_with("def ")
75            || trimmed.starts_with("class ")
76            || trimmed.starts_with("async def ")
77        {
78            let indent = line.len() - line.trim_start().len();
79            let chunk_type = if trimmed.starts_with("class ") {
80                "class"
81            } else {
82                "function"
83            };
84
85            let name = extract_name(trimmed);
86            let start_line = i;
87
88            // Find the end of this block (next line at same or lower indentation)
89            let mut end_line = i + 1;
90            while end_line < lines.len() {
91                let next = lines[end_line];
92                if next.trim().is_empty() {
93                    end_line += 1;
94                    continue;
95                }
96                let next_indent = next.len() - next.trim_start().len();
97                if next_indent <= indent && !next.trim().is_empty() {
98                    // Check if this is a decorator for the next function
99                    if next.trim().starts_with('@') {
100                        break;
101                    }
102                    // Same or lower indent means block ended
103                    break;
104                }
105                end_line += 1;
106            }
107
108            let text: String = lines[start_line..end_line].join("\n");
109            if text.len() <= max_chunk_size {
110                chunks.push(CodeChunk {
111                    text,
112                    chunk_type: chunk_type.to_string(),
113                    name: Some(name),
114                    start_line: start_line + 1,
115                    end_line,
116                    language: "python".to_string(),
117                    metadata: make_metadata(filename, start_line + 1, end_line),
118                });
119            } else {
120                // Split large blocks
121                let sub_chunks = split_large_block(&lines[start_line..end_line], max_chunk_size);
122                for sub in sub_chunks {
123                    chunks.push(CodeChunk {
124                        text: sub,
125                        chunk_type: format!("{}_part", chunk_type),
126                        name: None,
127                        start_line: start_line + 1,
128                        end_line,
129                        language: "python".to_string(),
130                        metadata: make_metadata(filename, start_line + 1, end_line),
131                    });
132                }
133            }
134
135            i = end_line;
136        } else {
137            i += 1;
138        }
139    }
140
141    // If no chunks were found, fall back to generic chunking
142    if chunks.is_empty() {
143        return chunk_generic(source, filename, "python", max_chunk_size);
144    }
145
146    chunks
147}
148
149/// Chunk Rust source code by detecting fn, struct, impl, enum blocks.
150fn chunk_rust(source: &str, filename: &str, max_chunk_size: usize) -> Vec<CodeChunk> {
151    let lines: Vec<&str> = source.lines().collect();
152    let mut chunks = Vec::new();
153    let mut i = 0;
154
155    while i < lines.len() {
156        let trimmed = lines[i].trim();
157
158        let is_block_start = trimmed.starts_with("pub fn ")
159            || trimmed.starts_with("fn ")
160            || trimmed.starts_with("pub struct ")
161            || trimmed.starts_with("struct ")
162            || trimmed.starts_with("pub enum ")
163            || trimmed.starts_with("enum ")
164            || trimmed.starts_with("impl ")
165            || trimmed.starts_with("pub impl ")
166            || trimmed.starts_with("pub trait ")
167            || trimmed.starts_with("trait ")
168            || trimmed.starts_with("pub mod ")
169            || trimmed.starts_with("mod ");
170
171        if is_block_start {
172            let chunk_type = if trimmed.contains("fn ") {
173                "function"
174            } else if trimmed.contains("struct ") {
175                "struct"
176            } else if trimmed.contains("enum ") {
177                "enum"
178            } else if trimmed.contains("impl ") {
179                "impl"
180            } else if trimmed.contains("trait ") {
181                "trait"
182            } else {
183                "module"
184            };
185
186            let name = extract_rust_name(trimmed);
187            let start_line = i;
188
189            // Find matching closing brace using brace counting
190            let mut brace_count = 0;
191            let mut end_line = i;
192            let mut found_open = false;
193
194            for (j, line) in lines.iter().enumerate().skip(i) {
195                for ch in line.chars() {
196                    if ch == '{' {
197                        brace_count += 1;
198                        found_open = true;
199                    } else if ch == '}' {
200                        brace_count -= 1;
201                    }
202                }
203                end_line = j + 1;
204                if found_open && brace_count == 0 {
205                    break;
206                }
207            }
208
209            let text: String = lines[start_line..end_line].join("\n");
210            if text.len() <= max_chunk_size {
211                chunks.push(CodeChunk {
212                    text,
213                    chunk_type: chunk_type.to_string(),
214                    name: Some(name),
215                    start_line: start_line + 1,
216                    end_line,
217                    language: "rust".to_string(),
218                    metadata: make_metadata(filename, start_line + 1, end_line),
219                });
220            } else {
221                let sub_chunks = split_large_block(&lines[start_line..end_line], max_chunk_size);
222                for sub in sub_chunks {
223                    chunks.push(CodeChunk {
224                        text: sub,
225                        chunk_type: format!("{}_part", chunk_type),
226                        name: None,
227                        start_line: start_line + 1,
228                        end_line,
229                        language: "rust".to_string(),
230                        metadata: make_metadata(filename, start_line + 1, end_line),
231                    });
232                }
233            }
234
235            i = end_line;
236        } else {
237            i += 1;
238        }
239    }
240
241    if chunks.is_empty() {
242        return chunk_generic(source, filename, "rust", max_chunk_size);
243    }
244
245    chunks
246}
247
248/// Chunk JavaScript/TypeScript source code.
249fn chunk_js_ts(source: &str, filename: &str, max_chunk_size: usize) -> Vec<CodeChunk> {
250    let lines: Vec<&str> = source.lines().collect();
251    let mut chunks = Vec::new();
252    let mut i = 0;
253    let language = detect_language(filename).unwrap_or("javascript");
254
255    while i < lines.len() {
256        let trimmed = lines[i].trim();
257
258        let is_block_start = trimmed.starts_with("function ")
259            || trimmed.starts_with("async function ")
260            || trimmed.starts_with("export function ")
261            || trimmed.starts_with("export async function ")
262            || trimmed.starts_with("export default function ")
263            || trimmed.starts_with("class ")
264            || trimmed.starts_with("export class ")
265            || trimmed.starts_with("export default class ")
266            || trimmed.contains("=> {");
267
268        if is_block_start {
269            let chunk_type = if trimmed.contains("class ") {
270                "class"
271            } else {
272                "function"
273            };
274
275            let start_line = i;
276            let mut brace_count = 0;
277            let mut end_line = i;
278            let mut found_open = false;
279
280            for (j, line) in lines.iter().enumerate().skip(i) {
281                for ch in line.chars() {
282                    if ch == '{' {
283                        brace_count += 1;
284                        found_open = true;
285                    } else if ch == '}' {
286                        brace_count -= 1;
287                    }
288                }
289                end_line = j + 1;
290                if found_open && brace_count == 0 {
291                    break;
292                }
293            }
294
295            let text: String = lines[start_line..end_line].join("\n");
296            if text.len() <= max_chunk_size {
297                chunks.push(CodeChunk {
298                    text,
299                    chunk_type: chunk_type.to_string(),
300                    name: None,
301                    start_line: start_line + 1,
302                    end_line,
303                    language: language.to_string(),
304                    metadata: make_metadata(filename, start_line + 1, end_line),
305                });
306            }
307
308            i = end_line;
309        } else {
310            i += 1;
311        }
312    }
313
314    if chunks.is_empty() {
315        return chunk_generic(source, filename, language, max_chunk_size);
316    }
317
318    chunks
319}
320
321/// Generic line-based chunking for unsupported languages.
322fn chunk_generic(
323    source: &str,
324    filename: &str,
325    language: &str,
326    max_chunk_size: usize,
327) -> Vec<CodeChunk> {
328    let lines: Vec<&str> = source.lines().collect();
329    let mut chunks = Vec::new();
330    let mut current = String::new();
331    let mut start_line = 0;
332
333    for (i, line) in lines.iter().enumerate() {
334        if current.len() + line.len() + 1 > max_chunk_size && !current.is_empty() {
335            chunks.push(CodeChunk {
336                text: std::mem::take(&mut current),
337                chunk_type: "block".to_string(),
338                name: None,
339                start_line: start_line + 1,
340                end_line: i,
341                language: language.to_string(),
342                metadata: make_metadata(filename, start_line + 1, i),
343            });
344            start_line = i;
345        }
346
347        if !current.is_empty() {
348            current.push('\n');
349        }
350        current.push_str(line);
351    }
352
353    if !current.trim().is_empty() {
354        chunks.push(CodeChunk {
355            text: current,
356            chunk_type: "block".to_string(),
357            name: None,
358            start_line: start_line + 1,
359            end_line: lines.len(),
360            language: language.to_string(),
361            metadata: make_metadata(filename, start_line + 1, lines.len()),
362        });
363    }
364
365    chunks
366}
367
/// Extracts the identifier that follows `def` or `class` on a Python
/// definition line.
///
/// Handles lines such as `def foo(...):`, `class Foo:`, `class Foo(Base):`
/// and `async def bar(...)`. Only the identifier itself is returned; the
/// parameter list, base classes and trailing `:` are not part of the name.
/// Returns `"unknown"` when no keyword, or no identifier after it, is found.
fn extract_name(definition_line: &str) -> String {
    let mut tokens = definition_line.split_whitespace();

    // Consume tokens up to and including the first `def` / `class` keyword.
    let has_keyword = tokens.by_ref().any(|tok| tok == "def" || tok == "class");
    if !has_keyword {
        return "unknown".to_string();
    }

    // The name is the identifier prefix of the next token. Stripping only a
    // trailing `(` or `:` is not enough, because a token like `hello():` ends
    // in `):` and would be returned as `hello()`.
    let name: String = tokens
        .next()
        .unwrap_or("")
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();

    if name.is_empty() {
        "unknown".to_string()
    } else {
        name
    }
}
381
/// Extracts the identifier that follows the first `fn`, `struct`, `enum`,
/// `impl`, `trait` or `mod` keyword on a Rust item line.
///
/// Only the identifier itself is returned: generic parameters, argument
/// lists, a trailing `;` and an opening brace are not part of the name.
/// Returns `"unknown"` when no keyword, or no identifier after it, is found.
fn extract_rust_name(definition_line: &str) -> String {
    const KEYWORDS: [&str; 6] = ["fn", "struct", "enum", "impl", "trait", "mod"];

    let mut tokens = definition_line.split_whitespace();

    // Consume tokens up to and including the first item keyword.
    let has_keyword = tokens.by_ref().any(|tok| KEYWORDS.contains(&tok));
    if !has_keyword {
        return "unknown".to_string();
    }

    // The name is the identifier prefix of the next token. Stripping only
    // trailing `{`, `<` or `(` is not enough, because tokens such as
    // `hello()` or `parse<T>(input:` do not end in one of those characters.
    let name: String = tokens
        .next()
        .unwrap_or("")
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();

    if name.is_empty() {
        "unknown".to_string()
    } else {
        name
    }
}
399
/// Splits `lines` into strings of at most `max_size` bytes where possible.
///
/// Lines are accumulated, joined by `\n`, until adding the next line would
/// exceed `max_size`; the accumulated text is then emitted and a new piece is
/// started. A single line longer than `max_size` is cut into pieces of at
/// most `max_size` bytes. Cuts are always placed on UTF-8 character
/// boundaries, so multi-byte text never causes a panic. If `max_size` is
/// smaller than the next character (including `max_size == 0`), that
/// character is emitted on its own so the split always makes progress.
///
/// A trailing piece that contains only whitespace is discarded.
pub(crate) fn split_large_block(lines: &[&str], max_size: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut current = String::new();

    for line in lines {
        if current.len() + line.len() + 1 > max_size && !current.is_empty() {
            chunks.push(std::mem::take(&mut current));
        }

        // If a single line exceeds max_size, split it on its own.
        if line.len() > max_size && current.is_empty() {
            let mut offset = 0;
            while offset < line.len() {
                let mut end = (offset + max_size).min(line.len());

                // Back up to a character boundary: slicing a `str` in the
                // middle of a multi-byte character panics.
                while end > offset && !line.is_char_boundary(end) {
                    end -= 1;
                }

                // The budget cannot hold even one character: emit exactly one
                // character anyway, otherwise the loop would never advance.
                if end == offset {
                    end = offset + 1;
                    while !line.is_char_boundary(end) {
                        end += 1;
                    }
                }

                chunks.push(line[offset..end].to_string());
                offset = end;
            }
            continue;
        }

        if !current.is_empty() {
            current.push('\n');
        }
        current.push_str(line);
    }

    if !current.trim().is_empty() {
        chunks.push(current);
    }

    chunks
}
432
433pub(crate) fn make_metadata(
434    filename: &str,
435    start_line: usize,
436    end_line: usize,
437) -> HashMap<String, serde_json::Value> {
438    let mut m = HashMap::new();
439    m.insert("source".to_string(), serde_json::json!(filename));
440    m.insert("start_line".to_string(), serde_json::json!(start_line));
441    m.insert("end_line".to_string(), serde_json::json!(end_line));
442    m
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunks `source` as `filename` with a 1000-byte budget and checks that
    /// at least two chunks come back.
    fn assert_at_least_two_chunks(source: &str, filename: &str) {
        let chunks = chunk_code(source, filename, 1000);
        assert!(
            chunks.len() >= 2,
            "Expected at least 2 chunks, got {}",
            chunks.len()
        );
    }

    /// Known extensions map to their language; unknown ones map to `None`.
    #[test]
    fn test_detect_language() {
        let expectations = [
            ("foo.py", Some("python")),
            ("bar.rs", Some("rust")),
            ("baz.js", Some("javascript")),
            ("qux.txt", None),
        ];
        for (filename, expected) in expectations {
            assert_eq!(detect_language(filename), expected);
        }
    }

    /// Two functions and a class yield at least two Python chunks.
    #[test]
    fn test_chunk_python() {
        let python_source = r#"
def hello():
    print("hello")

def world():
    print("world")

class Foo:
    def bar(self):
        pass
"#;
        assert_at_least_two_chunks(python_source, "test.py");
    }

    /// Two functions and a struct yield at least two Rust chunks.
    #[test]
    fn test_chunk_rust() {
        let rust_source = r#"
fn hello() {
    println!("hello");
}

fn world() {
    println!("world");
}

struct Foo {
    bar: i32,
}
"#;
        assert_at_least_two_chunks(rust_source, "test.rs");
    }

    /// An unrecognised extension still produces chunks via the generic path.
    #[test]
    fn test_chunk_generic() {
        let plain_text = "line 1\nline 2\nline 3\nline 4\nline 5";
        let produced = chunk_code(plain_text, "test.txt", 20);
        assert!(!produced.is_empty());
    }
}