use crate::index::symbol::Symbol;
use ast_grep_core::{Doc, Node};
use serde::{Deserialize, Serialize};
/// A single indexed chunk of source code, as produced by [`chunk_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// 0-based position of this chunk within the file's chunk list.
    pub index: usize,
    /// Chunk text; may carry a `[context: ...]` signature prefix and/or
    /// overlap lines, so it is not always an exact slice of the source.
    pub text: String,
    /// Tree-sitter node kind, or a comma-joined list for merged chunks.
    pub node_kind: String,
    pub line_start: usize,
    pub line_end: usize,
    pub byte_start: usize,
    pub byte_end: usize,
    /// Count of non-whitespace characters in `text`.
    pub non_ws_chars: usize,
    /// Qualified name of the innermost enclosing symbol, if resolved.
    pub parent_symbol: Option<String>,
    pub file_path: String,
}
/// Tuning knobs for the chunking pass. Sizes are measured in
/// non-whitespace characters (see `count_non_ws`), not bytes or lines.
#[derive(Debug, Clone)]
pub struct ChunkConfig {
    /// Upper bound on a chunk's non-whitespace size; bigger nodes are split.
    pub max_chunk_size: usize,
    /// Chunks below this size are candidates for merging with a neighbour.
    pub min_chunk_size: usize,
    /// Number of preceding source lines prepended to each chunk (0 = off).
    pub overlap_lines: usize,
}
impl Default for ChunkConfig {
    // Defaults: at most ~1500 non-whitespace chars per chunk, merge chunks
    // under 50, and no overlap between adjacent chunks.
    fn default() -> Self {
        Self {
            max_chunk_size: 1500,
            min_chunk_size: 50,
            overlap_lines: 0,
        }
    }
}
/// Number of characters in `s` that are not Unicode whitespace.
fn count_non_ws(s: &str) -> usize {
    s.chars().fold(0, |n, c| n + usize::from(!c.is_whitespace()))
}
/// Coarse classification of a syntax node's role, used to decide where to
/// split (`is_semantic_boundary`) and what may merge (`categories_mergeable`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SemanticCategory {
    /// Import / use / include style directives.
    Import,
    /// Functions, types, modules and other named declarations.
    Declaration,
    /// Any comment node.
    Comment,
    /// Everything else (expressions, statements, ...).
    Other,
}
/// Maps a tree-sitter node kind string to a coarse semantic category.
///
/// Precedence matters: import-like kinds win over comment-like kinds, which
/// win over declaration-like kinds.
fn classify_node(kind: &str) -> SemanticCategory {
    // Exact kind names that count as imports across grammars.
    const IMPORT_KINDS: [&str; 6] = [
        "use_declaration",
        "use_item",
        "extern_crate_declaration",
        "include_directive",
        "using_declaration",
        "package_declaration",
    ];
    // Substrings that mark a declaration kind in most grammars.
    const DECL_SUBSTRINGS: [&str; 8] = [
        "function", "method", "class", "struct", "enum", "interface", "trait", "impl",
    ];
    // Exact declaration kind names not covered by the substrings above.
    const DECL_KINDS: [&str; 9] = [
        "const_item",
        "static_item",
        "type_alias",
        "type_item",
        "mod_item",
        "module",
        "lexical_declaration",
        "variable_declaration",
        "export_statement",
    ];

    if kind.contains("import") || IMPORT_KINDS.contains(&kind) {
        return SemanticCategory::Import;
    }
    // "line_comment", "block_comment" and "doc_comment" all contain
    // "comment", so the substring test alone covers them.
    if kind.contains("comment") {
        return SemanticCategory::Comment;
    }
    if DECL_SUBSTRINGS.iter().any(|s| kind.contains(s)) || DECL_KINDS.contains(&kind) {
        return SemanticCategory::Declaration;
    }
    SemanticCategory::Other
}
/// A node kind is a chunking boundary when it classifies as a declaration.
fn is_semantic_boundary(kind: &str) -> bool {
    classify_node(kind) == SemanticCategory::Declaration
}
/// Intermediate chunk produced by the tree walk, before symbol resolution
/// and final [`CodeChunk`] assembly.
struct RawChunk {
    text: String,
    /// Node kind, or a comma-joined list for grouped/merged chunks.
    node_kind: String,
    line_start: usize,
    line_end: usize,
    byte_start: usize,
    byte_end: usize,
    /// Cached non-whitespace character count of `text`.
    non_ws_chars: usize,
    /// Semantic class consulted by the merge pass.
    category: SemanticCategory,
}
/// Splits `source` into semantically coherent chunks using its parsed AST.
///
/// Pipeline: recursive boundary-aware splitting, merging of undersized
/// neighbours, optional line overlap, then parent-symbol attribution. A
/// chunk that starts inside a symbol's body gets a `[context: ...]`
/// signature prefix. Whitespace-only sources yield no chunks.
pub fn chunk_file<D: Doc>(
    root: &ast_grep_core::AstGrep<D>,
    source: &str,
    file_path: &str,
    symbols: &[Symbol],
    config: &ChunkConfig,
) -> Vec<CodeChunk>
where
    D::Lang: ast_grep_core::Language,
{
    if source.trim().is_empty() {
        return Vec::new();
    }

    // Phase 1: walk the tree, splitting at declaration boundaries.
    let mut raws = Vec::new();
    collect_chunks(&root.root(), config, &mut raws);

    // Phase 2: coalesce undersized neighbours.
    let mut processed = merge_small_chunks(raws, source, config);

    // Phase 3 (optional): prepend overlap lines from the preceding region.
    if config.overlap_lines > 0 {
        processed = apply_overlap(processed, source, config.overlap_lines);
    }

    // Phase 4: attach the innermost enclosing symbol and build the output.
    let symbol_index = SymbolIntervalIndex::build(symbols);
    let mut chunks = Vec::with_capacity(processed.len());
    for (i, raw) in processed.into_iter().enumerate() {
        let parent = symbol_index.resolve(raw.line_start, raw.line_end);
        let parent_symbol = parent.map(|s| s.qualified_name.clone());
        // Only prefix context when the chunk starts strictly below the
        // symbol's own first line (i.e. the signature is not already there).
        let text = match parent {
            Some(sym) if raw.line_start > sym.line_start && !sym.signature.is_empty() => {
                let sig = truncate_signature(&sym.signature, 120);
                format!("[context: {sig}]\n{}", raw.text)
            }
            _ => raw.text,
        };
        chunks.push(CodeChunk {
            index: i,
            non_ws_chars: count_non_ws(&text),
            text,
            node_kind: raw.node_kind,
            line_start: raw.line_start,
            line_end: raw.line_end,
            byte_start: raw.byte_start,
            byte_end: raw.byte_end,
            parent_symbol,
            file_path: file_path.to_string(),
        });
    }
    chunks
}
/// Recursively splits `node` into `RawChunk`s of at most
/// `config.max_chunk_size` non-whitespace characters, preferring splits at
/// declaration boundaries.
fn collect_chunks<D: Doc>(node: &Node<'_, D>, config: &ChunkConfig, out: &mut Vec<RawChunk>)
where
    D::Lang: ast_grep_core::Language,
{
    let text = node.text();
    let nws = count_non_ws(&text);
    let kind = node.kind().to_string();

    // Fits the budget: emit the whole node as one chunk.
    if nws <= config.max_chunk_size {
        out.push(make_raw(node, &text, kind, nws));
        return;
    }

    let named_children: Vec<_> = node.children().filter(|c| c.is_named()).collect();

    // Oversized but unsplittable (no named children): emit as-is.
    if named_children.is_empty() {
        out.push(make_raw(node, &text, kind, nws));
        return;
    }

    // No declarations directly below: just descend into every child.
    if !named_children
        .iter()
        .any(|c| is_semantic_boundary(&c.kind()))
    {
        for child in &named_children {
            collect_chunks(child, config, out);
        }
        return;
    }

    // Declarations recurse on their own; runs of in-between siblings
    // (comments, attributes, stray statements) are grouped together.
    let mut pending: Vec<&Node<'_, D>> = Vec::new();
    for child in &named_children {
        if is_semantic_boundary(&child.kind()) {
            if !pending.is_empty() {
                emit_group(&pending, config, out);
                pending.clear();
            }
            collect_chunks(child, config, out);
        } else {
            pending.push(child);
        }
    }
    if !pending.is_empty() {
        emit_group(&pending, config, out);
    }
}

/// Builds a `RawChunk` covering exactly `node`.
fn make_raw<D: Doc>(node: &Node<'_, D>, text: &str, kind: String, nws: usize) -> RawChunk
where
    D::Lang: ast_grep_core::Language,
{
    let range = node.range();
    RawChunk {
        text: text.to_string(),
        category: classify_node(&kind),
        node_kind: kind,
        line_start: node.start_pos().line(),
        line_end: node.end_pos().line(),
        byte_start: range.start,
        byte_end: range.end,
        non_ws_chars: nws,
    }
}
/// Emits a run of consecutive non-declaration siblings, either merged into
/// one combined chunk (when the total fits the budget) or by recursively
/// chunking each node on its own.
fn emit_group<D: Doc>(nodes: &[&Node<'_, D>], config: &ChunkConfig, out: &mut Vec<RawChunk>)
where
    D::Lang: ast_grep_core::Language,
{
    if nodes.is_empty() {
        return;
    }
    let total_nws: usize = nodes.iter().map(|n| count_non_ws(&n.text())).sum();

    // Combined group would blow the budget: fall back to per-node chunking.
    if total_nws > config.max_chunk_size {
        for node in nodes {
            collect_chunks(node, config, out);
        }
        return;
    }

    let first = nodes[0];
    let last = nodes[nodes.len() - 1];
    let mut texts: Vec<String> = Vec::with_capacity(nodes.len());
    let mut kinds: Vec<String> = Vec::with_capacity(nodes.len());
    for n in nodes {
        texts.push(n.text().to_string());
        kinds.push(n.kind().to_string());
    }
    out.push(RawChunk {
        text: texts.join("\n"),
        // The first node's kind decides the category for the whole group.
        category: classify_node(&first.kind()),
        node_kind: kinds.join(","),
        line_start: first.start_pos().line(),
        line_end: last.end_pos().line(),
        byte_start: first.range().start,
        byte_end: last.range().end,
        non_ws_chars: total_nws,
    });
}
/// Two chunk categories may merge when they match, or when either side is a
/// comment (comments attach to whatever they sit next to).
fn categories_mergeable(a: SemanticCategory, b: SemanticCategory) -> bool {
    a == b || [a, b].contains(&SemanticCategory::Comment)
}
/// Greedy left-to-right pass that folds undersized chunks into their
/// predecessor when categories allow and the result stays under the budget.
fn merge_small_chunks(chunks: Vec<RawChunk>, source: &str, config: &ChunkConfig) -> Vec<RawChunk> {
    let mut result: Vec<RawChunk> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        let absorbed = match result.last_mut() {
            Some(last) => try_absorb(last, &chunk, source, config),
            None => false,
        };
        if !absorbed {
            result.push(chunk);
        }
    }
    result
}

/// Attempts to merge `chunk` into `last` in place; returns whether it did.
fn try_absorb(last: &mut RawChunk, chunk: &RawChunk, source: &str, config: &ChunkConfig) -> bool {
    // Only merge when at least one side is undersized and categories fit.
    let undersized = last.non_ws_chars < config.min_chunk_size
        || chunk.non_ws_chars < config.min_chunk_size;
    if !undersized || !categories_mergeable(last.category, chunk.category) {
        return false;
    }

    let (start, end) = (last.byte_start, chunk.byte_end);
    // Re-slice from the source so interstitial text between the two chunks
    // is preserved; fall back to concatenation if the range is out of bounds.
    let merged_text = if end <= source.len() {
        source[start..end].to_string()
    } else {
        format!("{}\n{}", last.text, chunk.text)
    };
    let merged_nws = count_non_ws(&merged_text);
    if merged_nws > config.max_chunk_size {
        return false;
    }

    last.text = merged_text;
    if !last.node_kind.contains(&chunk.node_kind) {
        last.node_kind = format!("{},{}", last.node_kind, chunk.node_kind);
    }
    last.line_end = chunk.line_end;
    last.byte_end = end;
    last.non_ws_chars = merged_nws;
    // A comment-only chunk adopts the category of what it merged with.
    if last.category == SemanticCategory::Comment {
        last.category = chunk.category;
    }
    true
}
/// Prepends up to `overlap_lines` preceding source lines to every chunk
/// after the first, adjusting `line_start` and the cached size accordingly.
fn apply_overlap(chunks: Vec<RawChunk>, source: &str, overlap_lines: usize) -> Vec<RawChunk> {
    // Nothing to do for a zero overlap or fewer than two chunks.
    if overlap_lines == 0 || chunks.len() < 2 {
        return chunks;
    }
    let lines: Vec<&str> = source.lines().collect();
    let mut out = Vec::with_capacity(chunks.len());
    for (i, mut chunk) in chunks.into_iter().enumerate() {
        // The first chunk, or one starting at line 0, has nothing before it.
        if i == 0 || chunk.line_start == 0 {
            out.push(chunk);
            continue;
        }
        let from = chunk.line_start.saturating_sub(overlap_lines);
        if from < chunk.line_start && from < lines.len() {
            let to = chunk.line_start.min(lines.len());
            let prefix = lines[from..to].join("\n");
            chunk.text = format!("{prefix}\n{}", chunk.text);
            chunk.line_start = from;
            chunk.non_ws_chars = count_non_ws(&chunk.text);
            // NOTE(review): byte_start is left at the pre-overlap position,
            // so text no longer equals source[byte_start..byte_end] here —
            // confirm downstream consumers expect that.
        }
        out.push(chunk);
    }
    out
}
/// Returns the first line of `sig`, truncated to at most `max_len` bytes.
///
/// Truncation prefers the last space before the limit so a word is not cut
/// in half; otherwise it cuts at the largest UTF-8 character boundary at or
/// below `max_len`.
fn truncate_signature(sig: &str, max_len: usize) -> &str {
    let first_line = sig.lines().next().unwrap_or(sig);
    if first_line.len() <= max_len {
        return first_line;
    }
    // Walk back from max_len to the nearest char boundary: the original
    // `first_line[..max_len]` panics when max_len lands inside a multi-byte
    // UTF-8 character (e.g. non-ASCII identifiers in the signature).
    let mut cut = max_len;
    while !first_line.is_char_boundary(cut) {
        cut -= 1;
    }
    match first_line[..cut].rfind(' ') {
        Some(pos) => &first_line[..pos],
        None => &first_line[..cut],
    }
}
/// Lookup structure for finding the innermost symbol enclosing a line range.
struct SymbolIntervalIndex<'a> {
    // Symbols sorted by line_start ascending, then line_end descending.
    sorted: Vec<&'a Symbol>,
}
impl<'a> SymbolIntervalIndex<'a> {
    /// Builds the index: symbols sorted by start line ascending and, for
    /// equal start lines, by end line descending (outermost first).
    fn build(symbols: &'a [Symbol]) -> Self {
        let mut sorted: Vec<&Symbol> = symbols.iter().collect();
        sorted.sort_by(|a, b| {
            a.line_start
                .cmp(&b.line_start)
                .then_with(|| b.line_end.cmp(&a.line_end))
        });
        Self { sorted }
    }

    /// Returns the innermost (smallest line-span) symbol fully containing
    /// the range `[line_start, line_end]`, or `None` if no symbol does.
    fn resolve(&self, line_start: usize, line_end: usize) -> Option<&'a Symbol> {
        // Every containment candidate starts at or before `line_start`;
        // partition_point gives the count of such symbols. The previous
        // binary_search_by could land on an arbitrary duplicate start line
        // and exclude tighter symbols sorted after it (same line_start,
        // smaller line_end), returning an outer symbol instead of the
        // innermost one.
        let upper = self
            .sorted
            .partition_point(|s| s.line_start <= line_start);
        let mut best: Option<&Symbol> = None;
        let mut best_span = usize::MAX;
        for &sym in self.sorted[..upper].iter().rev() {
            // Containment: the symbol must end at or after the chunk end.
            if sym.line_end >= line_end {
                let span = sym.line_end - sym.line_start;
                if span < best_span {
                    best_span = span;
                    best = Some(sym);
                }
            }
        }
        best
    }
}
// Unit tests live in a sibling file to keep this module focused.
#[cfg(test)]
#[path = "tests/chunker_tests.rs"]
mod tests;