1use crate::index::symbol::Symbol;
17use ast_grep_core::{Doc, Node};
18use serde::{Deserialize, Serialize};
19
/// A contiguous piece of a source file produced by the chunker, carrying
/// enough location metadata to map the text back to the original file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// Zero-based position of this chunk within the file's chunk list.
    pub index: usize,
    /// Chunk text; may start with an injected `[context: <signature>]` line
    /// when the chunk begins inside a symbol's body (see `chunk_file`).
    pub text: String,
    /// AST node kind(s) the chunk was built from; comma-separated when
    /// several nodes were grouped or merged together.
    pub node_kind: String,
    /// First source line covered (as reported by the parser's `start_pos` —
    /// presumably 0-based; confirm against `Node::start_pos`).
    pub line_start: usize,
    /// Last source line covered (same numbering base as `line_start`).
    pub line_end: usize,
    /// Byte offset of the chunk's start in the original source.
    pub byte_start: usize,
    /// Byte offset one past the chunk's end in the original source.
    pub byte_end: usize,
    /// Count of non-whitespace characters in `text` (includes any injected
    /// context prefix).
    pub non_ws_chars: usize,
    /// Qualified name of the innermost symbol enclosing this chunk, if any.
    pub parent_symbol: Option<String>,
    /// Path of the file this chunk was extracted from.
    pub file_path: String,
}
44
/// Tunables for the chunking pass. All sizes are measured in
/// non-whitespace characters (see `count_non_ws`), not bytes or lines.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Upper bound on a chunk's non-whitespace character count; larger AST
    /// nodes are split at semantic boundaries or recursed into.
    pub max_chunk_size: usize,
    /// Chunks below this size are candidates for merging with a neighbor.
    pub min_chunk_size: usize,
    /// Number of preceding source lines prepended to each chunk (0 disables
    /// overlap entirely).
    pub overlap_lines: usize,
}
55
56impl Default for ChunkConfig {
57 fn default() -> Self {
58 Self {
59 max_chunk_size: 1500,
60 min_chunk_size: 50,
61 overlap_lines: 0,
62 }
63 }
64}
65
/// Number of non-whitespace characters in `s` (Unicode-aware: counts
/// `char`s, not bytes).
fn count_non_ws(s: &str) -> usize {
    s.chars().map(|c| usize::from(!c.is_whitespace())).sum()
}
70
/// Coarse semantic role of an AST node kind, used to decide where chunks
/// may split and which neighbors may merge.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SemanticCategory {
    /// Import / use / include style statements.
    Import,
    /// Named definitions: functions, types, modules, variables, exports.
    Declaration,
    /// Line, block, and doc comments.
    Comment,
    /// Anything else.
    Other,
}

/// Maps a tree-sitter node-kind string to a `SemanticCategory`.
///
/// Checks run in priority order — import, then comment, then declaration —
/// so a kind matching an earlier class never falls through to a later one.
/// The exact-name lists cover grammars whose kinds don't contain the
/// obvious substring.
fn classify_node(kind: &str) -> SemanticCategory {
    const IMPORT_EXACT: [&str; 6] = [
        "use_declaration",
        "use_item",
        "extern_crate_declaration",
        "include_directive",
        "using_declaration",
        "package_declaration",
    ];
    if kind.contains("import") || IMPORT_EXACT.contains(&kind) {
        return SemanticCategory::Import;
    }

    const COMMENT_EXACT: [&str; 3] = ["line_comment", "block_comment", "doc_comment"];
    if kind.contains("comment") || COMMENT_EXACT.contains(&kind) {
        return SemanticCategory::Comment;
    }

    const DECL_SUBSTRINGS: [&str; 8] = [
        "function",
        "method",
        "class",
        "struct",
        "enum",
        "interface",
        "trait",
        "impl",
    ];
    const DECL_EXACT: [&str; 9] = [
        "const_item",
        "static_item",
        "type_alias",
        "type_item",
        "mod_item",
        "module",
        "lexical_declaration",
        "variable_declaration",
        "export_statement",
    ];
    if DECL_SUBSTRINGS.iter().any(|s| kind.contains(s)) || DECL_EXACT.contains(&kind) {
        return SemanticCategory::Declaration;
    }

    SemanticCategory::Other
}

/// True when a node kind marks a point where the chunker should split:
/// only declarations count as boundaries.
fn is_semantic_boundary(kind: &str) -> bool {
    classify_node(kind) == SemanticCategory::Declaration
}
139
/// Intermediate chunk produced during collection and merging, before symbol
/// resolution turns it into a public `CodeChunk`.
struct RawChunk {
    text: String,
    node_kind: String,
    line_start: usize,
    line_end: usize,
    byte_start: usize,
    byte_end: usize,
    non_ws_chars: usize,
    // Only used by the merge pass; dropped from the final `CodeChunk`.
    category: SemanticCategory,
}
152
/// Splits `source` into semantically coherent chunks using the parsed AST.
///
/// Pipeline:
/// 1. recursively collect chunks from the AST (`collect_chunks`),
/// 2. merge undersized neighbors (`merge_small_chunks`),
/// 3. optionally prepend overlapping lines (`apply_overlap`),
/// 4. resolve each chunk's innermost enclosing symbol from `symbols` and,
///    for chunks that start inside a symbol's body, prefix a
///    `[context: <signature>]` line.
///
/// Returns an empty vector for blank or whitespace-only sources.
pub fn chunk_file<D: Doc>(
    root: &ast_grep_core::AstGrep<D>,
    source: &str,
    file_path: &str,
    symbols: &[Symbol],
    config: &ChunkConfig,
) -> Vec<CodeChunk>
where
    D::Lang: ast_grep_core::Language,
{
    if source.trim().is_empty() {
        return Vec::new();
    }

    let root_node = root.root();
    let mut raw_chunks = Vec::new();
    collect_chunks(&root_node, config, &mut raw_chunks);

    let merged = merge_small_chunks(raw_chunks, source, config);

    let merged = if config.overlap_lines > 0 {
        apply_overlap(merged, source, config.overlap_lines)
    } else {
        merged
    };

    let interval_index = SymbolIntervalIndex::build(symbols);

    merged
        .into_iter()
        .enumerate()
        .map(|(idx, raw)| {
            let parent = interval_index.resolve(raw.line_start, raw.line_end);
            let parent_symbol = parent.map(|s| s.qualified_name.clone());

            // Only inject the context prefix when the chunk starts strictly
            // below the symbol's own first line (i.e. the signature is not
            // already part of the chunk text).
            let text = if let Some(sym) = parent {
                if raw.line_start > sym.line_start && !sym.signature.is_empty() {
                    let sig = truncate_signature(&sym.signature, 120);
                    format!("[context: {sig}]\n{}", raw.text)
                } else {
                    raw.text
                }
            } else {
                raw.text
            };

            CodeChunk {
                index: idx,
                // Recounted here so the count includes any injected prefix.
                non_ws_chars: count_non_ws(&text),
                text,
                node_kind: raw.node_kind,
                line_start: raw.line_start,
                line_end: raw.line_end,
                byte_start: raw.byte_start,
                byte_end: raw.byte_end,
                parent_symbol,
                file_path: file_path.to_string(),
            }
        })
        .collect()
}
228
229fn collect_chunks<D: Doc>(node: &Node<'_, D>, config: &ChunkConfig, out: &mut Vec<RawChunk>)
231where
232 D::Lang: ast_grep_core::Language,
233{
234 let text = node.text();
235 let nws = count_non_ws(&text);
236 let kind = node.kind().to_string();
237
238 if nws <= config.max_chunk_size {
240 let range = node.range();
241 out.push(RawChunk {
242 text: text.to_string(),
243 category: classify_node(&kind),
244 node_kind: kind,
245 line_start: node.start_pos().line(),
246 line_end: node.end_pos().line(),
247 byte_start: range.start,
248 byte_end: range.end,
249 non_ws_chars: nws,
250 });
251 return;
252 }
253
254 let named_children: Vec<_> = node.children().filter(|c| c.is_named()).collect();
256 if named_children.is_empty() {
257 let range = node.range();
259 out.push(RawChunk {
260 text: text.to_string(),
261 category: classify_node(&kind),
262 node_kind: kind,
263 line_start: node.start_pos().line(),
264 line_end: node.end_pos().line(),
265 byte_start: range.start,
266 byte_end: range.end,
267 non_ws_chars: nws,
268 });
269 return;
270 }
271
272 let has_boundaries = named_children
276 .iter()
277 .any(|c| is_semantic_boundary(&c.kind()));
278
279 if has_boundaries {
280 let mut non_boundary_group: Vec<&Node<'_, D>> = Vec::new();
283 for child in &named_children {
284 if is_semantic_boundary(&child.kind()) {
285 if !non_boundary_group.is_empty() {
287 emit_group(&non_boundary_group, config, out);
288 non_boundary_group.clear();
289 }
290 collect_chunks(child, config, out);
292 } else {
293 non_boundary_group.push(child);
294 }
295 }
296 if !non_boundary_group.is_empty() {
298 emit_group(&non_boundary_group, config, out);
299 }
300 } else {
301 for child in &named_children {
302 collect_chunks(child, config, out);
303 }
304 }
305}
306
307fn emit_group<D: Doc>(nodes: &[&Node<'_, D>], config: &ChunkConfig, out: &mut Vec<RawChunk>)
310where
311 D::Lang: ast_grep_core::Language,
312{
313 if nodes.is_empty() {
314 return;
315 }
316
317 let total_nws: usize = nodes.iter().map(|n| count_non_ws(&n.text())).sum();
319 if total_nws <= config.max_chunk_size {
320 let first = nodes.first().unwrap();
322 let last = nodes.last().unwrap();
323 let text: String = nodes
324 .iter()
325 .map(|n| n.text().to_string())
326 .collect::<Vec<_>>()
327 .join("\n");
328 let first_kind = first.kind();
329 let kind = nodes
330 .iter()
331 .map(|n| n.kind().to_string())
332 .collect::<Vec<_>>()
333 .join(",");
334 let range_start = first.range().start;
335 let range_end = last.range().end;
336 out.push(RawChunk {
337 text,
338 category: classify_node(&first_kind),
339 node_kind: kind,
340 line_start: first.start_pos().line(),
341 line_end: last.end_pos().line(),
342 byte_start: range_start,
343 byte_end: range_end,
344 non_ws_chars: total_nws,
345 });
346 } else {
347 for node in nodes {
349 collect_chunks(node, config, out);
350 }
351 }
352}
353
354fn categories_mergeable(a: SemanticCategory, b: SemanticCategory) -> bool {
358 a == b || a == SemanticCategory::Comment || b == SemanticCategory::Comment
359}
360
361fn merge_small_chunks(chunks: Vec<RawChunk>, source: &str, config: &ChunkConfig) -> Vec<RawChunk> {
363 if chunks.is_empty() {
364 return Vec::new();
365 }
366
367 let mut result: Vec<RawChunk> = Vec::new();
368
369 for chunk in chunks {
370 if let Some(last) = result.last_mut() {
371 if (last.non_ws_chars < config.min_chunk_size
373 || chunk.non_ws_chars < config.min_chunk_size)
374 && categories_mergeable(last.category, chunk.category)
375 {
376 let merged_start = last.byte_start;
378 let merged_end = chunk.byte_end;
379 let merged_text = if merged_end <= source.len() {
380 source[merged_start..merged_end].to_string()
381 } else {
382 format!("{}\n{}", last.text, chunk.text)
383 };
384 let merged_nws = count_non_ws(&merged_text);
385
386 if merged_nws <= config.max_chunk_size {
387 last.text = merged_text;
388 if last.node_kind.contains(&chunk.node_kind) {
390 } else {
392 last.node_kind = format!("{},{}", last.node_kind, chunk.node_kind);
393 }
394 last.line_end = chunk.line_end;
395 last.byte_end = merged_end;
396 last.non_ws_chars = merged_nws;
397 if last.category == SemanticCategory::Comment {
399 last.category = chunk.category;
400 }
401 continue;
402 }
403 }
404 }
405 result.push(chunk);
406 }
407
408 result
409}
410
411fn apply_overlap(chunks: Vec<RawChunk>, source: &str, overlap_lines: usize) -> Vec<RawChunk> {
414 if chunks.len() <= 1 || overlap_lines == 0 {
415 return chunks;
416 }
417
418 let source_lines: Vec<&str> = source.lines().collect();
419 let mut result = Vec::with_capacity(chunks.len());
420
421 for (i, mut chunk) in chunks.into_iter().enumerate() {
422 if i > 0 && chunk.line_start > 0 {
423 let overlap_start = chunk.line_start.saturating_sub(overlap_lines);
425 if overlap_start < chunk.line_start && overlap_start < source_lines.len() {
426 let end = chunk.line_start.min(source_lines.len());
427 let prefix: String = source_lines[overlap_start..end].join("\n");
428 chunk.text = format!("{}\n{}", prefix, chunk.text);
429 chunk.line_start = overlap_start;
430 chunk.non_ws_chars = count_non_ws(&chunk.text);
431 }
432 }
433 result.push(chunk);
434 }
435
436 result
437}
438
/// Returns the first line of `sig`, truncated to at most `max_len` bytes.
///
/// When truncation is needed, the cut prefers the last space before the
/// limit so a word is not split mid-token. The cut point is always backed
/// off to a `char` boundary first: slicing `first_line[..max_len]` directly
/// would panic if `max_len` landed inside a multibyte character.
fn truncate_signature(sig: &str, max_len: usize) -> &str {
    let first_line = sig.lines().next().unwrap_or(sig);
    if first_line.len() <= max_len {
        return first_line;
    }
    // Largest char boundary <= max_len; terminates because 0 is always a
    // boundary. Only reached when max_len < first_line.len().
    let mut cut = max_len;
    while !first_line.is_char_boundary(cut) {
        cut -= 1;
    }
    match first_line[..cut].rfind(' ') {
        Some(pos) => &first_line[..pos],
        None => &first_line[..cut],
    }
}
452
/// Lookup structure for finding the innermost symbol whose line interval
/// contains a given chunk's line range.
struct SymbolIntervalIndex<'a> {
    // Sorted by `line_start` ascending, ties broken by `line_end` descending
    // (outermost first) — see `build`.
    sorted: Vec<&'a Symbol>,
}
458
459impl<'a> SymbolIntervalIndex<'a> {
460 fn build(symbols: &'a [Symbol]) -> Self {
461 let mut sorted: Vec<&Symbol> = symbols.iter().collect();
462 sorted.sort_by(|a, b| {
463 a.line_start
464 .cmp(&b.line_start)
465 .then_with(|| b.line_end.cmp(&a.line_end))
466 });
467 Self { sorted }
468 }
469
470 fn resolve(&self, line_start: usize, line_end: usize) -> Option<&'a Symbol> {
474 if self.sorted.is_empty() {
475 return None;
476 }
477
478 let idx = match self
480 .sorted
481 .binary_search_by(|s| s.line_start.cmp(&line_start))
482 {
483 Ok(i) => i,
484 Err(i) => {
485 if i == 0 {
486 return None;
487 }
488 i - 1
489 }
490 };
491
492 let mut best: Option<&Symbol> = None;
493 let mut best_span = usize::MAX;
494
495 for &sym in self.sorted[..=idx].iter().rev() {
497 if sym.line_start > line_start {
498 continue;
499 }
500 if best.is_some() && sym.line_end < line_end {
502 continue;
507 }
508 if sym.line_end >= line_end {
509 let span = sym.line_end - sym.line_start;
510 if span < best_span {
511 best_span = span;
512 best = Some(sym);
513 }
514 }
515 }
516
517 best
518 }
519}
520
521#[cfg(test)]
522#[path = "tests/chunker_tests.rs"]
523mod tests;