Skip to main content

codemem_engine/index/
chunker.rs

1//! CST-aware code chunking.
2//!
3//! Splits source files into semantically meaningful chunks using the
4//! concrete syntax tree (CST) produced by ast-grep/tree-sitter. The algorithm:
5//!
6//! 1. If a CST node fits within `max_chunk_size` (non-whitespace chars) -> emit it as a chunk.
7//! 2. If too large -> recurse into named children.
8//! 3. Adjacent small siblings are merged greedily until the merged size would exceed `max_chunk_size`.
9//!
10//! Each chunk records its parent symbol (resolved by line-range containment).
11
12use crate::index::symbol::Symbol;
13use ast_grep_core::{Doc, Node};
14use serde::{Deserialize, Serialize};
15
/// A code chunk produced by the CST-aware chunker.
///
/// Values of this type are built by `chunk_file`, one per emitted chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// 0-based index of this chunk within the file.
    pub index: usize,
    /// The source text of this chunk (may start with overlap lines taken from
    /// before the chunk when overlap is enabled).
    pub text: String,
    /// The tree-sitter node kind (e.g., "function_item", "impl_item").
    /// When several nodes were merged into one chunk this is a
    /// comma-separated list of their kinds.
    pub node_kind: String,
    /// 0-based starting line.
    pub line_start: usize,
    /// 0-based ending line.
    pub line_end: usize,
    /// Byte offset start.
    pub byte_start: usize,
    /// Byte offset end (used as the exclusive end when slicing the source).
    pub byte_end: usize,
    /// Count of non-whitespace characters in `text`.
    pub non_ws_chars: usize,
    /// Qualified name of the innermost containing symbol, if any.
    pub parent_symbol: Option<String>,
    /// Path of the source file.
    pub file_path: String,
}
40
/// Configuration for the chunker.
///
/// All sizes are measured in non-whitespace characters.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Maximum chunk size in non-whitespace characters.
    ///
    /// Nodes larger than this are split into their named children, and merges
    /// never produce a chunk above it. It is a soft bound: an oversized node
    /// with no named children is still emitted whole.
    pub max_chunk_size: usize,
    /// Minimum chunk size in non-whitespace characters.
    ///
    /// Acts as the merge threshold: a chunk below this size is a candidate for
    /// being merged with its neighbour. Chunks smaller than this can still be
    /// emitted when no merge fits.
    pub min_chunk_size: usize,
    /// Number of lines to overlap between adjacent chunks (0 = no overlap).
    pub overlap_lines: usize,
}
51
52impl Default for ChunkConfig {
53    fn default() -> Self {
54        Self {
55            max_chunk_size: 1500,
56            min_chunk_size: 50,
57            overlap_lines: 0,
58        }
59    }
60}
61
/// Returns how many characters of `s` are not Unicode whitespace.
///
/// Counts `char`s (not bytes), so a multi-byte character counts once.
fn count_non_ws(s: &str) -> usize {
    let mut total = 0;
    for ch in s.chars() {
        if !ch.is_whitespace() {
            total += 1;
        }
    }
    total
}
66
/// Intermediate chunk before index assignment and parent resolution.
///
/// Carries the same positional data as `CodeChunk`, minus the chunk index,
/// parent symbol and file path, which `chunk_file` fills in at the end.
struct RawChunk {
    /// Source text covered by the chunk.
    text: String,
    /// Node kind, or a comma-separated list of kinds once chunks were merged.
    node_kind: String,
    /// 0-based line on which the chunk starts.
    line_start: usize,
    /// 0-based line on which the chunk ends.
    line_end: usize,
    /// Byte offset at which the chunk starts.
    byte_start: usize,
    /// Byte offset at which the chunk ends (exclusive when slicing the source).
    byte_end: usize,
    /// Number of non-whitespace characters in `text`.
    non_ws_chars: usize,
}
77
78/// Chunk a file using its CST tree.
79///
80/// - `root`: the ast-grep parsed root (AstGrep instance).
81/// - `source`: the raw source string.
82/// - `file_path`: path to the source file (stored in each chunk).
83/// - `symbols`: symbols extracted from this file (used for parent resolution).
84/// - `config`: chunking parameters.
85pub fn chunk_file<D: Doc>(
86    root: &ast_grep_core::AstGrep<D>,
87    source: &str,
88    file_path: &str,
89    symbols: &[Symbol],
90    config: &ChunkConfig,
91) -> Vec<CodeChunk>
92where
93    D::Lang: ast_grep_core::Language,
94{
95    if source.trim().is_empty() {
96        return Vec::new();
97    }
98
99    let root_node = root.root();
100    let mut raw_chunks = Vec::new();
101    collect_chunks(&root_node, config, &mut raw_chunks);
102
103    // Merge adjacent small chunks greedily
104    let merged = merge_small_chunks(raw_chunks, source, config);
105
106    // C3: Apply overlap — prepend trailing lines from previous chunk
107    let merged = if config.overlap_lines > 0 {
108        apply_overlap(merged, source, config.overlap_lines)
109    } else {
110        merged
111    };
112
113    // C2: Build interval index once for O(log n) parent resolution per chunk
114    let interval_index = SymbolIntervalIndex::build(symbols);
115
116    // Assign indices and resolve parent symbols
117    merged
118        .into_iter()
119        .enumerate()
120        .map(|(idx, raw)| {
121            let parent_symbol = interval_index
122                .resolve(raw.line_start, raw.line_end)
123                .map(|s| s.qualified_name.clone());
124            CodeChunk {
125                index: idx,
126                text: raw.text,
127                node_kind: raw.node_kind,
128                line_start: raw.line_start,
129                line_end: raw.line_end,
130                byte_start: raw.byte_start,
131                byte_end: raw.byte_end,
132                non_ws_chars: raw.non_ws_chars,
133                parent_symbol,
134                file_path: file_path.to_string(),
135            }
136        })
137        .collect()
138}
139
140/// Recursively collect chunks from a CST node.
141fn collect_chunks<D: Doc>(node: &Node<'_, D>, config: &ChunkConfig, out: &mut Vec<RawChunk>)
142where
143    D::Lang: ast_grep_core::Language,
144{
145    let text = node.text();
146    let nws = count_non_ws(&text);
147
148    // If the node fits, emit it as a single chunk
149    if nws <= config.max_chunk_size {
150        let range = node.range();
151        out.push(RawChunk {
152            text: text.to_string(),
153            node_kind: node.kind().to_string(),
154            line_start: node.start_pos().line(),
155            line_end: node.end_pos().line(),
156            byte_start: range.start,
157            byte_end: range.end,
158            non_ws_chars: nws,
159        });
160        return;
161    }
162
163    // Too large: recurse into named children
164    let named_children: Vec<_> = node.children().filter(|c| c.is_named()).collect();
165    if named_children.is_empty() {
166        // No named children (e.g., a very large string literal), emit as-is
167        let range = node.range();
168        out.push(RawChunk {
169            text: text.to_string(),
170            node_kind: node.kind().to_string(),
171            line_start: node.start_pos().line(),
172            line_end: node.end_pos().line(),
173            byte_start: range.start,
174            byte_end: range.end,
175            non_ws_chars: nws,
176        });
177    } else {
178        for child in &named_children {
179            collect_chunks(child, config, out);
180        }
181    }
182}
183
184/// Merge adjacent small chunks greedily.
185fn merge_small_chunks(chunks: Vec<RawChunk>, source: &str, config: &ChunkConfig) -> Vec<RawChunk> {
186    if chunks.is_empty() {
187        return Vec::new();
188    }
189
190    let mut result: Vec<RawChunk> = Vec::new();
191
192    for chunk in chunks {
193        if let Some(last) = result.last_mut() {
194            // If both the current accumulator and new chunk are small, try merging
195            if last.non_ws_chars < config.min_chunk_size
196                || chunk.non_ws_chars < config.min_chunk_size
197            {
198                // Compute actual merged non-whitespace count before deciding to merge
199                let merged_start = last.byte_start;
200                let merged_end = chunk.byte_end;
201                let merged_text = if merged_end <= source.len() {
202                    source[merged_start..merged_end].to_string()
203                } else {
204                    format!("{}\n{}", last.text, chunk.text)
205                };
206                let merged_nws = count_non_ws(&merged_text);
207
208                if merged_nws <= config.max_chunk_size {
209                    last.text = merged_text;
210                    // C4: Preserve individual node_kinds as comma-separated
211                    if last.node_kind.contains(&chunk.node_kind) {
212                        // Already contains this kind, no-op
213                    } else {
214                        last.node_kind = format!("{},{}", last.node_kind, chunk.node_kind);
215                    }
216                    last.line_end = chunk.line_end;
217                    last.byte_end = merged_end;
218                    last.non_ws_chars = merged_nws;
219                    continue;
220                }
221            }
222        }
223        result.push(chunk);
224    }
225
226    result
227}
228
229/// C3: Apply overlap between adjacent chunks by prepending trailing lines
230/// from the previous chunk to the current one.
231fn apply_overlap(chunks: Vec<RawChunk>, source: &str, overlap_lines: usize) -> Vec<RawChunk> {
232    if chunks.len() <= 1 || overlap_lines == 0 {
233        return chunks;
234    }
235
236    let source_lines: Vec<&str> = source.lines().collect();
237    let mut result = Vec::with_capacity(chunks.len());
238
239    for (i, mut chunk) in chunks.into_iter().enumerate() {
240        if i > 0 && chunk.line_start > 0 {
241            // Prepend `overlap_lines` lines from before this chunk's start
242            let overlap_start = chunk.line_start.saturating_sub(overlap_lines);
243            if overlap_start < chunk.line_start && overlap_start < source_lines.len() {
244                let end = chunk.line_start.min(source_lines.len());
245                let prefix: String = source_lines[overlap_start..end].join("\n");
246                chunk.text = format!("{}\n{}", prefix, chunk.text);
247                chunk.line_start = overlap_start;
248                chunk.non_ws_chars = count_non_ws(&chunk.text);
249            }
250        }
251        result.push(chunk);
252    }
253
254    result
255}
256
257/// Pre-sorted symbol index for O(log n) parent resolution via binary search.
258struct SymbolIntervalIndex<'a> {
259    /// Symbols sorted by (line_start ASC, line_end DESC) — outermost first at each start.
260    sorted: Vec<&'a Symbol>,
261}
262
263impl<'a> SymbolIntervalIndex<'a> {
264    fn build(symbols: &'a [Symbol]) -> Self {
265        let mut sorted: Vec<&Symbol> = symbols.iter().collect();
266        sorted.sort_by(|a, b| {
267            a.line_start
268                .cmp(&b.line_start)
269                .then_with(|| b.line_end.cmp(&a.line_end))
270        });
271        Self { sorted }
272    }
273
274    /// Find the innermost symbol containing [line_start, line_end].
275    /// Uses binary search to find candidates starting at or before line_start,
276    /// then scans forward for the tightest containment.
277    fn resolve(&self, line_start: usize, line_end: usize) -> Option<&'a Symbol> {
278        if self.sorted.is_empty() {
279            return None;
280        }
281
282        // Binary search: find the rightmost symbol whose line_start <= line_start
283        let idx = match self
284            .sorted
285            .binary_search_by(|s| s.line_start.cmp(&line_start))
286        {
287            Ok(i) => i,
288            Err(i) => {
289                if i == 0 {
290                    return None;
291                }
292                i - 1
293            }
294        };
295
296        let mut best: Option<&Symbol> = None;
297        let mut best_span = usize::MAX;
298
299        // Scan backwards from idx (all candidates have line_start <= line_start)
300        for &sym in self.sorted[..=idx].iter().rev() {
301            if sym.line_start > line_start {
302                continue;
303            }
304            // Once we pass symbols that start too early and are too short, stop
305            if best.is_some() && sym.line_end < line_end {
306                // Symbols are sorted with largest span first at each start position,
307                // so once we see one that doesn't contain us and we already have
308                // a best, earlier symbols with the same start won't either.
309                // But symbols with smaller line_start may still contain us.
310                continue;
311            }
312            if sym.line_end >= line_end {
313                let span = sym.line_end - sym.line_start;
314                if span < best_span {
315                    best_span = span;
316                    best = Some(sym);
317                }
318            }
319        }
320
321        best
322    }
323}
324
325#[cfg(test)]
326#[path = "tests/chunker_tests.rs"]
327mod tests;