// codemem_engine/index/chunker.rs
1//! CST-aware code chunking.
2//!
3//! Splits source files into semantically meaningful chunks using the
4//! concrete syntax tree (CST) produced by ast-grep/tree-sitter. The algorithm:
5//!
6//! 1. If a CST node fits within `max_chunk_size` (non-whitespace chars) -> emit it as a chunk.
7//! 2. If too large -> recurse into named children, preferring semantic boundaries
8//!    (function/class/impl definitions) as split points.
9//! 3. Adjacent small siblings are merged greedily, but only when they share the same
10//!    semantic category (e.g., imports with imports, declarations with declarations).
11//! 4. When a chunk comes from inside a function/class, a truncated signature header
12//!    is prepended so the chunk is self-contextualizing for embeddings.
13//!
14//! Each chunk records its parent symbol (resolved by line-range containment).
15
16use crate::index::symbol::Symbol;
17use ast_grep_core::{Doc, Node};
18use serde::{Deserialize, Serialize};
19
/// A code chunk produced by the CST-aware chunker.
///
/// Produced by [`chunk_file`]; serializable so chunks can be persisted
/// alongside their embeddings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// 0-based index of this chunk within the file.
    pub index: usize,
    /// The source text of this chunk. May start with an injected
    /// `[context: <signature>]` header line when the chunk lies strictly
    /// inside a symbol, and may include overlap lines from the preceding
    /// chunk when `ChunkConfig::overlap_lines > 0`.
    pub text: String,
    /// The tree-sitter node kind (e.g., "function_item", "impl_item").
    /// Merged chunks carry a comma-separated list of kinds.
    pub node_kind: String,
    /// 0-based starting line (shifted earlier when overlap is applied).
    pub line_start: usize,
    /// 0-based ending line.
    pub line_end: usize,
    /// Byte offset start. NOTE: byte offsets always refer to the original
    /// CST node range; injected context/overlap text is NOT reflected here.
    pub byte_start: usize,
    /// Byte offset end.
    pub byte_end: usize,
    /// Count of non-whitespace characters in `text` (including any injected
    /// context header).
    pub non_ws_chars: usize,
    /// Qualified name of the innermost containing symbol, if any.
    pub parent_symbol: Option<String>,
    /// Path of the source file.
    pub file_path: String,
}
44
/// Configuration for the chunker.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum chunk size in non-whitespace characters (default: 1500).
    /// Nodes larger than this are split at CST boundaries.
    pub max_chunk_size: usize,
    /// Minimum chunk size in non-whitespace characters (default: 50).
    /// Adjacent chunks below this may be merged with compatible neighbors.
    pub min_chunk_size: usize,
    /// Number of lines to overlap between adjacent chunks (0 = no overlap;
    /// default: 0).
    pub overlap_lines: usize,
}
55
56impl Default for ChunkConfig {
57    fn default() -> Self {
58        Self {
59            max_chunk_size: 1500,
60            min_chunk_size: 50,
61            overlap_lines: 0,
62        }
63    }
64}
65
/// Number of characters in `s` that are not whitespace (Unicode-aware,
/// counted per `char`).
fn count_non_ws(s: &str) -> usize {
    s.chars()
        .fold(0, |acc, ch| acc + usize::from(!ch.is_whitespace()))
}
70
/// Semantic category of a CST node, used to decide merge compatibility.
///
/// Assigned by [`classify_node`]; only same-category neighbors merge,
/// except `Comment`, which merges with anything.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SemanticCategory {
    /// Import/use statements
    Import,
    /// Function, method, class, struct, enum, impl, trait, interface definitions
    Declaration,
    /// Comments and doc comments
    Comment,
    /// Everything else (expressions, statements, etc.)
    Other,
}
83
84/// Classify a tree-sitter node kind into a semantic category.
85fn classify_node(kind: &str) -> SemanticCategory {
86    match kind {
87        // Imports / use
88        k if k.contains("import")
89            || k == "use_declaration"
90            || k == "use_item"
91            || k == "extern_crate_declaration"
92            || k == "include_directive"
93            || k == "using_declaration"
94            || k == "package_declaration" =>
95        {
96            SemanticCategory::Import
97        }
98
99        // Comments
100        k if k.contains("comment")
101            || k == "line_comment"
102            || k == "block_comment"
103            || k == "doc_comment" =>
104        {
105            SemanticCategory::Comment
106        }
107
108        // Declarations — functions, classes, structs, etc.
109        k if k.contains("function")
110            || k.contains("method")
111            || k.contains("class")
112            || k.contains("struct")
113            || k.contains("enum")
114            || k.contains("interface")
115            || k.contains("trait")
116            || k.contains("impl")
117            || k == "const_item"
118            || k == "static_item"
119            || k == "type_alias"
120            || k == "type_item"
121            || k == "mod_item"
122            || k == "module"
123            || k == "lexical_declaration"
124            || k == "variable_declaration"
125            || k == "export_statement" =>
126        {
127            SemanticCategory::Declaration
128        }
129
130        _ => SemanticCategory::Other,
131    }
132}
133
134/// Returns true if a node kind represents a semantic boundary — a top-level
135/// declaration that should be kept whole when possible.
136fn is_semantic_boundary(kind: &str) -> bool {
137    matches!(classify_node(kind), SemanticCategory::Declaration)
138}
139
/// Intermediate chunk before index assignment and parent resolution.
///
/// Mirrors [`CodeChunk`] minus `index`, `parent_symbol`, and `file_path`,
/// which are filled in by `chunk_file` after merging/overlap.
struct RawChunk {
    /// Raw source text of the chunk (no context header yet).
    text: String,
    /// Tree-sitter kind, or a comma-separated list for merged chunks.
    node_kind: String,
    /// 0-based starting line.
    line_start: usize,
    /// 0-based ending line.
    line_end: usize,
    /// Byte offset start within the source file.
    byte_start: usize,
    /// Byte offset end within the source file.
    byte_end: usize,
    /// Cached non-whitespace character count of `text`.
    non_ws_chars: usize,
    /// Semantic category for merge compatibility.
    category: SemanticCategory,
}
152
/// Chunk a file using its CST tree.
///
/// Pipeline: recursively collect raw chunks from the CST, merge adjacent
/// small chunks, optionally prepend line overlap, then resolve each chunk's
/// parent symbol and inject a truncated signature header for chunks that
/// lie strictly inside a symbol.
///
/// - `root`: the ast-grep parsed root (AstGrep instance).
/// - `source`: the raw source string.
/// - `file_path`: path to the source file (stored in each chunk).
/// - `symbols`: symbols extracted from this file (used for parent resolution).
/// - `config`: chunking parameters.
///
/// Returns chunks in source order with 0-based `index` assigned. A
/// whitespace-only `source` yields an empty vector.
pub fn chunk_file<D: Doc>(
    root: &ast_grep_core::AstGrep<D>,
    source: &str,
    file_path: &str,
    symbols: &[Symbol],
    config: &ChunkConfig,
) -> Vec<CodeChunk>
where
    D::Lang: ast_grep_core::Language,
{
    if source.trim().is_empty() {
        return Vec::new();
    }

    let root_node = root.root();
    let mut raw_chunks = Vec::new();
    collect_chunks(&root_node, config, &mut raw_chunks);

    // Merge adjacent small chunks greedily
    let merged = merge_small_chunks(raw_chunks, source, config);

    // C3: Apply overlap — prepend trailing lines from previous chunk
    let merged = if config.overlap_lines > 0 {
        apply_overlap(merged, source, config.overlap_lines)
    } else {
        merged
    };

    // C2: Build interval index once for O(log n) parent resolution per chunk
    let interval_index = SymbolIntervalIndex::build(symbols);

    // Assign indices, resolve parent symbols, and inject signature context
    merged
        .into_iter()
        .enumerate()
        .map(|(idx, raw)| {
            let parent = interval_index.resolve(raw.line_start, raw.line_end);
            let parent_symbol = parent.map(|s| s.qualified_name.clone());

            // Signature context injection: if this chunk is strictly inside a
            // symbol (doesn't start at the symbol's first line), prepend a
            // truncated signature so the chunk is self-contextualizing.
            let text = if let Some(sym) = parent {
                if raw.line_start > sym.line_start && !sym.signature.is_empty() {
                    let sig = truncate_signature(&sym.signature, 120);
                    format!("[context: {sig}]\n{}", raw.text)
                } else {
                    raw.text
                }
            } else {
                raw.text
            };

            CodeChunk {
                index: idx,
                // Recounted over the final text so an injected context
                // header contributes to the reported size.
                non_ws_chars: count_non_ws(&text),
                text,
                node_kind: raw.node_kind,
                line_start: raw.line_start,
                line_end: raw.line_end,
                byte_start: raw.byte_start,
                byte_end: raw.byte_end,
                parent_symbol,
                file_path: file_path.to_string(),
            }
        })
        .collect()
}
228
229/// Recursively collect chunks from a CST node.
230fn collect_chunks<D: Doc>(node: &Node<'_, D>, config: &ChunkConfig, out: &mut Vec<RawChunk>)
231where
232    D::Lang: ast_grep_core::Language,
233{
234    let text = node.text();
235    let nws = count_non_ws(&text);
236    let kind = node.kind().to_string();
237
238    // If the node fits, emit it as a single chunk
239    if nws <= config.max_chunk_size {
240        let range = node.range();
241        out.push(RawChunk {
242            text: text.to_string(),
243            category: classify_node(&kind),
244            node_kind: kind,
245            line_start: node.start_pos().line(),
246            line_end: node.end_pos().line(),
247            byte_start: range.start,
248            byte_end: range.end,
249            non_ws_chars: nws,
250        });
251        return;
252    }
253
254    // Too large: recurse into named children
255    let named_children: Vec<_> = node.children().filter(|c| c.is_named()).collect();
256    if named_children.is_empty() {
257        // No named children (e.g., a very large string literal), emit as-is
258        let range = node.range();
259        out.push(RawChunk {
260            text: text.to_string(),
261            category: classify_node(&kind),
262            node_kind: kind,
263            line_start: node.start_pos().line(),
264            line_end: node.end_pos().line(),
265            byte_start: range.start,
266            byte_end: range.end,
267            non_ws_chars: nws,
268        });
269        return;
270    }
271
272    // Semantic-boundary-aware splitting: if this node contains semantic boundaries
273    // (e.g., an impl block with methods), split at those boundaries. Group
274    // non-boundary children between boundaries together.
275    let has_boundaries = named_children
276        .iter()
277        .any(|c| is_semantic_boundary(&c.kind()));
278
279    if has_boundaries {
280        // Collect runs: non-boundary children are grouped, boundary children
281        // are recursed individually.
282        let mut non_boundary_group: Vec<&Node<'_, D>> = Vec::new();
283        for child in &named_children {
284            if is_semantic_boundary(&child.kind()) {
285                // Flush any accumulated non-boundary nodes as a merged chunk
286                if !non_boundary_group.is_empty() {
287                    emit_group(&non_boundary_group, config, out);
288                    non_boundary_group.clear();
289                }
290                // Recurse the boundary child on its own
291                collect_chunks(child, config, out);
292            } else {
293                non_boundary_group.push(child);
294            }
295        }
296        // Flush trailing non-boundary nodes
297        if !non_boundary_group.is_empty() {
298            emit_group(&non_boundary_group, config, out);
299        }
300    } else {
301        for child in &named_children {
302            collect_chunks(child, config, out);
303        }
304    }
305}
306
307/// Emit a group of non-boundary sibling nodes. If they fit together, emit as one
308/// chunk; otherwise recurse each individually.
309fn emit_group<D: Doc>(nodes: &[&Node<'_, D>], config: &ChunkConfig, out: &mut Vec<RawChunk>)
310where
311    D::Lang: ast_grep_core::Language,
312{
313    if nodes.is_empty() {
314        return;
315    }
316
317    // Check total size of the group
318    let total_nws: usize = nodes.iter().map(|n| count_non_ws(&n.text())).sum();
319    if total_nws <= config.max_chunk_size {
320        // Emit as a single merged chunk
321        let first = nodes.first().unwrap();
322        let last = nodes.last().unwrap();
323        let text: String = nodes
324            .iter()
325            .map(|n| n.text().to_string())
326            .collect::<Vec<_>>()
327            .join("\n");
328        let first_kind = first.kind();
329        let kind = nodes
330            .iter()
331            .map(|n| n.kind().to_string())
332            .collect::<Vec<_>>()
333            .join(",");
334        let range_start = first.range().start;
335        let range_end = last.range().end;
336        out.push(RawChunk {
337            text,
338            category: classify_node(&first_kind),
339            node_kind: kind,
340            line_start: first.start_pos().line(),
341            line_end: last.end_pos().line(),
342            byte_start: range_start,
343            byte_end: range_end,
344            non_ws_chars: total_nws,
345        });
346    } else {
347        // Too large together — recurse each individually
348        for node in nodes {
349            collect_chunks(node, config, out);
350        }
351    }
352}
353
354/// Returns true if two semantic categories are compatible for merging.
355/// Only merges chunks of the same category, treating Comment as mergeable
356/// with anything (comments often annotate adjacent code).
357fn categories_mergeable(a: SemanticCategory, b: SemanticCategory) -> bool {
358    a == b || a == SemanticCategory::Comment || b == SemanticCategory::Comment
359}
360
361/// Merge adjacent small chunks greedily, respecting semantic categories.
362fn merge_small_chunks(chunks: Vec<RawChunk>, source: &str, config: &ChunkConfig) -> Vec<RawChunk> {
363    if chunks.is_empty() {
364        return Vec::new();
365    }
366
367    let mut result: Vec<RawChunk> = Vec::new();
368
369    for chunk in chunks {
370        if let Some(last) = result.last_mut() {
371            // Only merge if at least one is below min_chunk_size AND categories are compatible
372            if (last.non_ws_chars < config.min_chunk_size
373                || chunk.non_ws_chars < config.min_chunk_size)
374                && categories_mergeable(last.category, chunk.category)
375            {
376                // Compute actual merged non-whitespace count before deciding to merge
377                let merged_start = last.byte_start;
378                let merged_end = chunk.byte_end;
379                let merged_text = if merged_end <= source.len() {
380                    source[merged_start..merged_end].to_string()
381                } else {
382                    format!("{}\n{}", last.text, chunk.text)
383                };
384                let merged_nws = count_non_ws(&merged_text);
385
386                if merged_nws <= config.max_chunk_size {
387                    last.text = merged_text;
388                    // C4: Preserve individual node_kinds as comma-separated
389                    if last.node_kind.contains(&chunk.node_kind) {
390                        // Already contains this kind, no-op
391                    } else {
392                        last.node_kind = format!("{},{}", last.node_kind, chunk.node_kind);
393                    }
394                    last.line_end = chunk.line_end;
395                    last.byte_end = merged_end;
396                    last.non_ws_chars = merged_nws;
397                    // Keep the more specific category (prefer non-Comment)
398                    if last.category == SemanticCategory::Comment {
399                        last.category = chunk.category;
400                    }
401                    continue;
402                }
403            }
404        }
405        result.push(chunk);
406    }
407
408    result
409}
410
411/// C3: Apply overlap between adjacent chunks by prepending trailing lines
412/// from the previous chunk to the current one.
413fn apply_overlap(chunks: Vec<RawChunk>, source: &str, overlap_lines: usize) -> Vec<RawChunk> {
414    if chunks.len() <= 1 || overlap_lines == 0 {
415        return chunks;
416    }
417
418    let source_lines: Vec<&str> = source.lines().collect();
419    let mut result = Vec::with_capacity(chunks.len());
420
421    for (i, mut chunk) in chunks.into_iter().enumerate() {
422        if i > 0 && chunk.line_start > 0 {
423            // Prepend `overlap_lines` lines from before this chunk's start
424            let overlap_start = chunk.line_start.saturating_sub(overlap_lines);
425            if overlap_start < chunk.line_start && overlap_start < source_lines.len() {
426                let end = chunk.line_start.min(source_lines.len());
427                let prefix: String = source_lines[overlap_start..end].join("\n");
428                chunk.text = format!("{}\n{}", prefix, chunk.text);
429                chunk.line_start = overlap_start;
430                chunk.non_ws_chars = count_non_ws(&chunk.text);
431            }
432        }
433        result.push(chunk);
434    }
435
436    result
437}
438
/// Truncate a signature to at most `max_len` bytes, cutting at a word
/// boundary when possible.
///
/// Only the first line of a multi-line signature is considered. Unlike a
/// plain `&s[..max_len]` slice, this never panics on multi-byte UTF-8: the
/// cut point is backed off to the nearest char boundary first.
fn truncate_signature(sig: &str, max_len: usize) -> &str {
    // Take only the first line of multi-line signatures
    let first_line = sig.lines().next().unwrap_or(sig);
    if first_line.len() <= max_len {
        return first_line;
    }
    // Back off to a valid char boundary so slicing can't panic when
    // `max_len` falls inside a multi-byte character.
    let mut cut = max_len;
    while !first_line.is_char_boundary(cut) {
        cut -= 1;
    }
    // Prefer cutting at the last space before the limit.
    match first_line[..cut].rfind(' ') {
        Some(pos) => &first_line[..pos],
        None => &first_line[..cut],
    }
}
452
/// Pre-sorted symbol index for fast parent resolution via binary search.
///
/// Built once per file by `chunk_file` and queried once per chunk.
struct SymbolIntervalIndex<'a> {
    /// Symbols sorted by (line_start ASC, line_end DESC) — outermost first at each start.
    sorted: Vec<&'a Symbol>,
}
458
459impl<'a> SymbolIntervalIndex<'a> {
460    fn build(symbols: &'a [Symbol]) -> Self {
461        let mut sorted: Vec<&Symbol> = symbols.iter().collect();
462        sorted.sort_by(|a, b| {
463            a.line_start
464                .cmp(&b.line_start)
465                .then_with(|| b.line_end.cmp(&a.line_end))
466        });
467        Self { sorted }
468    }
469
470    /// Find the innermost symbol containing [line_start, line_end].
471    /// Uses binary search to find candidates starting at or before line_start,
472    /// then scans forward for the tightest containment.
473    fn resolve(&self, line_start: usize, line_end: usize) -> Option<&'a Symbol> {
474        if self.sorted.is_empty() {
475            return None;
476        }
477
478        // Binary search: find the rightmost symbol whose line_start <= line_start
479        let idx = match self
480            .sorted
481            .binary_search_by(|s| s.line_start.cmp(&line_start))
482        {
483            Ok(i) => i,
484            Err(i) => {
485                if i == 0 {
486                    return None;
487                }
488                i - 1
489            }
490        };
491
492        let mut best: Option<&Symbol> = None;
493        let mut best_span = usize::MAX;
494
495        // Scan backwards from idx (all candidates have line_start <= line_start)
496        for &sym in self.sorted[..=idx].iter().rev() {
497            if sym.line_start > line_start {
498                continue;
499            }
500            // Once we pass symbols that start too early and are too short, stop
501            if best.is_some() && sym.line_end < line_end {
502                // Symbols are sorted with largest span first at each start position,
503                // so once we see one that doesn't contain us and we already have
504                // a best, earlier symbols with the same start won't either.
505                // But symbols with smaller line_start may still contain us.
506                continue;
507            }
508            if sym.line_end >= line_end {
509                let span = sym.line_end - sym.line_start;
510                if span < best_span {
511                    best_span = span;
512                    best = Some(sym);
513                }
514            }
515        }
516
517        best
518    }
519}
520
// Unit tests live out-of-line in `tests/chunker_tests.rs`.
#[cfg(test)]
#[path = "tests/chunker_tests.rs"]
mod tests;