1use crate::index::symbol::Symbol;
13use ast_grep_core::{Doc, Node};
14use serde::{Deserialize, Serialize};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CodeChunk {
19 pub index: usize,
21 pub text: String,
23 pub node_kind: String,
25 pub line_start: usize,
27 pub line_end: usize,
29 pub byte_start: usize,
31 pub byte_end: usize,
33 pub non_ws_chars: usize,
35 pub parent_symbol: Option<String>,
37 pub file_path: String,
39}
40
/// Tuning knobs for [`chunk_file`]. All sizes are counted in non-whitespace
/// characters.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Largest size a node may have and still be emitted as a single chunk.
    /// Bigger nodes are split along their named children; a node without
    /// named children is emitted whole even if it exceeds this limit.
    pub max_chunk_size: usize,
    /// Chunks smaller than this are merged with their neighbour, provided the
    /// merged chunk still fits within `max_chunk_size`.
    pub min_chunk_size: usize,
    /// Number of preceding source lines prepended to every chunk after the
    /// first. `0` disables overlap.
    pub overlap_lines: usize,
}
51
52impl Default for ChunkConfig {
53 fn default() -> Self {
54 Self {
55 max_chunk_size: 1500,
56 min_chunk_size: 50,
57 overlap_lines: 0,
58 }
59 }
60}
61
/// Returns how many characters of `s` are not Unicode whitespace.
fn count_non_ws(s: &str) -> usize {
    // Each character contributes 1 when it is not whitespace and 0 otherwise.
    s.chars().map(|ch| usize::from(!ch.is_whitespace())).sum()
}
66
/// Intermediate chunk used while collecting, merging and overlapping, before
/// the index, parent symbol and file path are attached.
struct RawChunk {
    /// Text covered by the chunk.
    text: String,
    /// Node kind, or a `,`-joined list of kinds after merging.
    node_kind: String,
    /// First line of the chunk, taken from the node's start position.
    line_start: usize,
    /// Last line of the chunk, taken from the node's end position.
    line_end: usize,
    /// Start of the node range as a byte offset into the source.
    byte_start: usize,
    /// End of the node range as a byte offset into the source.
    byte_end: usize,
    /// Number of non-whitespace characters in `text`.
    non_ws_chars: usize,
}
77
78pub fn chunk_file<D: Doc>(
86 root: &ast_grep_core::AstGrep<D>,
87 source: &str,
88 file_path: &str,
89 symbols: &[Symbol],
90 config: &ChunkConfig,
91) -> Vec<CodeChunk>
92where
93 D::Lang: ast_grep_core::Language,
94{
95 if source.trim().is_empty() {
96 return Vec::new();
97 }
98
99 let root_node = root.root();
100 let mut raw_chunks = Vec::new();
101 collect_chunks(&root_node, config, &mut raw_chunks);
102
103 let merged = merge_small_chunks(raw_chunks, source, config);
105
106 let merged = if config.overlap_lines > 0 {
108 apply_overlap(merged, source, config.overlap_lines)
109 } else {
110 merged
111 };
112
113 let interval_index = SymbolIntervalIndex::build(symbols);
115
116 merged
118 .into_iter()
119 .enumerate()
120 .map(|(idx, raw)| {
121 let parent_symbol = interval_index
122 .resolve(raw.line_start, raw.line_end)
123 .map(|s| s.qualified_name.clone());
124 CodeChunk {
125 index: idx,
126 text: raw.text,
127 node_kind: raw.node_kind,
128 line_start: raw.line_start,
129 line_end: raw.line_end,
130 byte_start: raw.byte_start,
131 byte_end: raw.byte_end,
132 non_ws_chars: raw.non_ws_chars,
133 parent_symbol,
134 file_path: file_path.to_string(),
135 }
136 })
137 .collect()
138}
139
140fn collect_chunks<D: Doc>(node: &Node<'_, D>, config: &ChunkConfig, out: &mut Vec<RawChunk>)
142where
143 D::Lang: ast_grep_core::Language,
144{
145 let text = node.text();
146 let nws = count_non_ws(&text);
147
148 if nws <= config.max_chunk_size {
150 let range = node.range();
151 out.push(RawChunk {
152 text: text.to_string(),
153 node_kind: node.kind().to_string(),
154 line_start: node.start_pos().line(),
155 line_end: node.end_pos().line(),
156 byte_start: range.start,
157 byte_end: range.end,
158 non_ws_chars: nws,
159 });
160 return;
161 }
162
163 let named_children: Vec<_> = node.children().filter(|c| c.is_named()).collect();
165 if named_children.is_empty() {
166 let range = node.range();
168 out.push(RawChunk {
169 text: text.to_string(),
170 node_kind: node.kind().to_string(),
171 line_start: node.start_pos().line(),
172 line_end: node.end_pos().line(),
173 byte_start: range.start,
174 byte_end: range.end,
175 non_ws_chars: nws,
176 });
177 } else {
178 for child in &named_children {
179 collect_chunks(child, config, out);
180 }
181 }
182}
183
184fn merge_small_chunks(chunks: Vec<RawChunk>, source: &str, config: &ChunkConfig) -> Vec<RawChunk> {
186 if chunks.is_empty() {
187 return Vec::new();
188 }
189
190 let mut result: Vec<RawChunk> = Vec::new();
191
192 for chunk in chunks {
193 if let Some(last) = result.last_mut() {
194 if last.non_ws_chars < config.min_chunk_size
196 || chunk.non_ws_chars < config.min_chunk_size
197 {
198 let merged_start = last.byte_start;
200 let merged_end = chunk.byte_end;
201 let merged_text = if merged_end <= source.len() {
202 source[merged_start..merged_end].to_string()
203 } else {
204 format!("{}\n{}", last.text, chunk.text)
205 };
206 let merged_nws = count_non_ws(&merged_text);
207
208 if merged_nws <= config.max_chunk_size {
209 last.text = merged_text;
210 if last.node_kind.contains(&chunk.node_kind) {
212 } else {
214 last.node_kind = format!("{},{}", last.node_kind, chunk.node_kind);
215 }
216 last.line_end = chunk.line_end;
217 last.byte_end = merged_end;
218 last.non_ws_chars = merged_nws;
219 continue;
220 }
221 }
222 }
223 result.push(chunk);
224 }
225
226 result
227}
228
229fn apply_overlap(chunks: Vec<RawChunk>, source: &str, overlap_lines: usize) -> Vec<RawChunk> {
232 if chunks.len() <= 1 || overlap_lines == 0 {
233 return chunks;
234 }
235
236 let source_lines: Vec<&str> = source.lines().collect();
237 let mut result = Vec::with_capacity(chunks.len());
238
239 for (i, mut chunk) in chunks.into_iter().enumerate() {
240 if i > 0 && chunk.line_start > 0 {
241 let overlap_start = chunk.line_start.saturating_sub(overlap_lines);
243 if overlap_start < chunk.line_start && overlap_start < source_lines.len() {
244 let end = chunk.line_start.min(source_lines.len());
245 let prefix: String = source_lines[overlap_start..end].join("\n");
246 chunk.text = format!("{}\n{}", prefix, chunk.text);
247 chunk.line_start = overlap_start;
248 chunk.non_ws_chars = count_non_ws(&chunk.text);
249 }
250 }
251 result.push(chunk);
252 }
253
254 result
255}
256
257struct SymbolIntervalIndex<'a> {
259 sorted: Vec<&'a Symbol>,
261}
262
263impl<'a> SymbolIntervalIndex<'a> {
264 fn build(symbols: &'a [Symbol]) -> Self {
265 let mut sorted: Vec<&Symbol> = symbols.iter().collect();
266 sorted.sort_by(|a, b| {
267 a.line_start
268 .cmp(&b.line_start)
269 .then_with(|| b.line_end.cmp(&a.line_end))
270 });
271 Self { sorted }
272 }
273
274 fn resolve(&self, line_start: usize, line_end: usize) -> Option<&'a Symbol> {
278 if self.sorted.is_empty() {
279 return None;
280 }
281
282 let idx = match self
284 .sorted
285 .binary_search_by(|s| s.line_start.cmp(&line_start))
286 {
287 Ok(i) => i,
288 Err(i) => {
289 if i == 0 {
290 return None;
291 }
292 i - 1
293 }
294 };
295
296 let mut best: Option<&Symbol> = None;
297 let mut best_span = usize::MAX;
298
299 for &sym in self.sorted[..=idx].iter().rev() {
301 if sym.line_start > line_start {
302 continue;
303 }
304 if best.is_some() && sym.line_end < line_end {
306 continue;
311 }
312 if sym.line_end >= line_end {
313 let span = sym.line_end - sym.line_start;
314 if span < best_span {
315 best_span = span;
316 best = Some(sym);
317 }
318 }
319 }
320
321 best
322 }
323}
324
325#[cfg(test)]
326#[path = "tests/chunker_tests.rs"]
327mod tests;