use tree_sitter::{Language, Node, Parser, TreeCursor};
/// Default chunk-size budget.
///
/// NOTE(review): this constant is not referenced in the code visible here; presumably
/// callers pass it as `desired_length`. Despite the `CHARS` suffix, the chunkers in this
/// file compare *byte* lengths (`end_byte - start_byte`) against the budget — confirm
/// which unit is intended.
pub const DEFAULT_DESIRED_CHUNK_CHARS: usize = 1500;
/// Location of one chunk within a source string.
///
/// The byte offsets form a half-open range `[start_byte, end_byte)`; line numbers are
/// 1-based.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkBoundary {
/// Byte offset at which the chunk starts (inclusive).
pub start_byte: usize,
/// Byte offset at which the chunk ends (exclusive).
pub end_byte: usize,
/// 1-based number of the line containing the chunk's first byte.
pub start_line: usize,
/// 1-based number of the line containing the chunk's last byte.
pub end_line: usize,
}
impl ChunkBoundary {
    /// Returns the part of `source` that this boundary covers.
    ///
    /// Offsets are first capped at `source.len()`. The end offset is then pushed
    /// forward and the start offset pulled backward to the nearest UTF-8 character
    /// boundary, so a boundary that lands inside a multi-byte character widens to
    /// include that character instead of panicking on the slice.
    #[must_use]
    pub fn content<'a>(&self, source: &'a str) -> &'a str {
        let capped_end = source.len().min(self.end_byte);
        let upper = clamp_to_char_boundary(source, capped_end, true);
        // The start may never exceed the (already adjusted) end.
        let capped_start = upper.min(self.start_byte);
        let lower = clamp_to_char_boundary(source, capped_start, false);
        &source[lower..upper]
    }
}
/// Moves `byte` to a nearby UTF-8 character boundary of `source`.
///
/// `byte` is first capped at `source.len()`. While it does not sit on a character
/// boundary it is stepped one byte at a time — toward the end of the string when
/// `forward` is true, toward the start otherwise. At most four steps are taken
/// (a UTF-8 scalar is at most four bytes long), after which the current offset is
/// returned as-is.
fn clamp_to_char_boundary(source: &str, mut byte: usize, forward: bool) -> usize {
    let limit = source.len();
    byte = byte.min(limit);
    let mut steps_left = 4;
    while steps_left > 0 && !source.is_char_boundary(byte) {
        byte = if forward {
            (byte + 1).min(limit)
        } else {
            byte.saturating_sub(1)
        };
        steps_left -= 1;
    }
    byte
}
/// Splits `source` into chunks, using syntax-aware chunking when possible.
///
/// Returns an empty vector when `source` is empty or whitespace-only. When a
/// `language` is supplied and parsing yields a tree, the chunks are derived from
/// that tree; in every other case (no language, or no tree produced) the
/// line-based chunker is used instead.
///
/// `desired_length` is the size budget handed to whichever chunker runs.
#[must_use]
pub fn chunk_source(
    source: &str,
    language: Option<&Language>,
    desired_length: usize,
) -> Vec<ChunkBoundary> {
    if source.trim().is_empty() {
        return Vec::new();
    }
    let parsed = language.and_then(|lang| parse_source(source, lang));
    match parsed {
        Some(tree) => chunk_tree(source, tree.root_node(), desired_length),
        None => chunk_lines(source, desired_length),
    }
}
/// Splits `source` into chunks made of whole lines.
///
/// Returns an empty vector when `source` is empty or whitespace-only. Otherwise
/// every line (including its trailing `\n`, if any) becomes a one-line boundary,
/// and neighbouring boundaries are then combined by `merge_adjacent_chunks`
/// under the `desired_length` budget.
#[must_use]
pub fn chunk_lines(source: &str, desired_length: usize) -> Vec<ChunkBoundary> {
    if source.trim().is_empty() {
        return Vec::new();
    }
    let mut offset = 0usize;
    let per_line: Vec<ChunkBoundary> = source
        .split_inclusive('\n')
        .enumerate()
        .map(|(index, piece)| {
            // Each piece produced by `split_inclusive('\n')` holds at most one
            // newline, so it always spans exactly one line.
            let line = index + 1;
            let start_byte = offset;
            offset += piece.len();
            ChunkBoundary {
                start_byte,
                end_byte: offset,
                start_line: line,
                end_line: line,
            }
        })
        .collect();
    merge_adjacent_chunks(per_line, desired_length, source)
}
/// Parses `source` with a fresh parser configured for `language`.
///
/// Returns `None` when the language cannot be set on the parser or when the
/// parser does not produce a tree.
fn parse_source(source: &str, language: &Language) -> Option<tree_sitter::Tree> {
    let mut parser = Parser::new();
    match parser.set_language(language) {
        Ok(()) => parser.parse(source, None),
        Err(_) => None,
    }
}
/// Builds chunk boundaries from the syntax tree rooted at `root`.
///
/// The byte spans produced by `merge_node_inner` are annotated with line numbers
/// and then passed through `merge_adjacent_chunks` with the same budget.
fn chunk_tree(source: &str, root: Node<'_>, desired_length: usize) -> Vec<ChunkBoundary> {
    let spans = merge_node_inner(root, desired_length);
    let mut boundaries = Vec::with_capacity(spans.len());
    for (start_byte, end_byte) in spans {
        // The end offset is exclusive, so the last covered byte is `end_byte - 1`
        // (never before the start, which also handles empty spans).
        let last_byte = end_byte.saturating_sub(1).max(start_byte);
        boundaries.push(ChunkBoundary {
            start_byte,
            end_byte,
            start_line: line_at_byte(source, start_byte),
            end_line: line_at_byte(source, last_byte),
        });
    }
    merge_adjacent_chunks(boundaries, desired_length, source)
}
/// Returns the 1-based line number of the byte at offset `byte`.
///
/// The result is one plus the number of `\n` bytes that precede the offset;
/// offsets past the end of `source` are treated as `source.len()`.
fn line_at_byte(source: &str, byte: usize) -> usize {
    let bytes = source.as_bytes();
    let prefix = &bytes[..bytes.len().min(byte)];
    bytecount::count(prefix, b'\n') + 1
}
/// Groups the children of `node` into byte spans that respect `desired_length`.
///
/// * A node without children yields a single span covering the node itself,
///   whatever its size.
/// * A child whose own byte length exceeds `desired_length` is processed
///   recursively and its spans are appended in place.
/// * Any other child starts a group that greedily absorbs following siblings
///   while the *sum of the siblings' byte lengths* stays within
///   `desired_length`. Bytes lying between siblings are not counted toward that
///   sum, although the emitted span runs from the first sibling's start to the
///   last sibling's end.
///
/// Returns `(start_byte, end_byte)` pairs in document order.
fn merge_node_inner(node: Node<'_>, desired_length: usize) -> Vec<(usize, usize)> {
    let mut walker = node.walk();
    if !walker.goto_first_child() {
        return vec![(node.start_byte(), node.end_byte())];
    }

    let mut spans: Vec<(usize, usize)> = Vec::new();
    let mut on_child = true;
    while on_child {
        let head = walker.node();
        let head_len = head.end_byte() - head.start_byte();

        if head_len > desired_length {
            // Too large to keep whole: descend and split it further.
            spans.extend(merge_node_inner(head, desired_length));
        } else {
            let span_start = head.start_byte();
            let mut span_end = head.end_byte();
            let mut used = head_len;
            loop {
                // Look ahead without moving, so a rejected sibling stays unvisited
                // and becomes the head of the next group.
                let Some(candidate) = peek_next_sibling(&walker) else {
                    break;
                };
                let candidate_len = candidate.end_byte() - candidate.start_byte();
                if used + candidate_len > desired_length {
                    break;
                }
                walker.goto_next_sibling();
                span_end = walker.node().end_byte();
                used += candidate_len;
            }
            spans.push((span_start, span_end));
        }

        on_child = walker.goto_next_sibling();
    }
    spans
}
/// Returns the sibling that follows the cursor's current node, if there is one,
/// without moving `cursor` itself (the step is taken on a clone).
fn peek_next_sibling<'a>(cursor: &TreeCursor<'a>) -> Option<Node<'a>> {
    let mut lookahead = cursor.clone();
    lookahead.goto_next_sibling().then(|| lookahead.node())
}
/// Greedily combines consecutive chunks while they fit in `desired_length`.
///
/// Chunks are visited in order. A chunk is folded into the one currently being
/// built when the sum of their byte lengths (each `end_byte - start_byte`) does
/// not exceed `desired_length`; otherwise the current chunk is emitted and the
/// new chunk starts the next one. After merging, every chunk's `end_line` is
/// recomputed from `source` as the line of its last byte.
///
/// An empty input produces an empty output.
fn merge_adjacent_chunks(
    chunks: Vec<ChunkBoundary>,
    desired_length: usize,
    source: &str,
) -> Vec<ChunkBoundary> {
    let mut pending = chunks.into_iter();
    let Some(mut open) = pending.next() else {
        return Vec::new();
    };
    let mut open_len = open.end_byte - open.start_byte;
    let mut out: Vec<ChunkBoundary> = Vec::new();

    for candidate in pending {
        let candidate_len = candidate.end_byte - candidate.start_byte;
        if open_len + candidate_len <= desired_length {
            // Still within budget: stretch the open chunk over the candidate.
            open.end_byte = candidate.end_byte;
            open.end_line = candidate.end_line;
            open_len += candidate_len;
        } else {
            out.push(open);
            open = candidate;
            open_len = candidate_len;
        }
    }
    out.push(open);

    for boundary in &mut out {
        // `end_byte` is exclusive; look at the last covered byte, but never
        // before the chunk's own start.
        let last_byte = boundary.end_byte.saturating_sub(1).max(boundary.start_byte);
        boundary.end_line = line_at_byte(source, last_byte);
    }
    out
}
// Unit tests for the line-based and tree-based chunkers.
#[cfg(test)]
mod tests {
use super::*;
// Line chunker with a budget of 8 bytes: "aaa\n" (4) + "bbbbb\n" (6) exceeds the
// budget, as does "bbbbb\n" (6) + "ccc\n" (4), while "ccc\n" + "ddd\n" is exactly 8
// and is therefore merged — three chunks in total.
#[test]
fn chunker_line_fallback() {
let source = "aaa\nbbbbb\nccc\nddd\n";
let chunks = chunk_lines(source, 8);
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0].content(source), "aaa\n");
assert_eq!(chunks[1].content(source), "bbbbb\n");
assert_eq!(chunks[2].content(source), "ccc\nddd\n");
}
// Empty and whitespace-only inputs must yield no chunks, both through the line
// chunker directly and through `chunk_source` without a language.
#[test]
fn empty_input_returns_empty() {
assert_eq!(chunk_lines("", 100), Vec::<ChunkBoundary>::new());
assert_eq!(chunk_lines(" \n ", 100), Vec::<ChunkBoundary>::new());
assert_eq!(chunk_source("", None, 100), Vec::<ChunkBoundary>::new());
}
// Tree-based chunking of three small Python functions: a larger budget must not
// produce more chunks than a smaller one, and the chunks must cover nearly all of
// the source.
#[test]
fn chunker_merge_adjacent() {
let lang: Language = tree_sitter_python::LANGUAGE.into();
let source = "def a():\n return 1\n\ndef b():\n return 2\n\ndef c():\n return 3\n";
let combined = chunk_source(source, Some(&lang), 200);
let split = chunk_source(source, Some(&lang), 50);
assert!(
combined.len() <= split.len(),
"larger budget should produce fewer-or-equal chunks; got combined={} split={}",
combined.len(),
split.len()
);
// Sum of chunk byte lengths; a shortfall of up to 4 bytes is tolerated.
let combined_chars: usize = combined.iter().map(|c| c.end_byte - c.start_byte).sum();
assert!(
combined_chars >= source.len() - 4,
"chunks should cover ~all of source ({combined_chars} of {})",
source.len()
);
}
// A single function with 40 assignment lines is far larger than the 100-byte
// budget, so the chunker has to split it into more than one chunk.
#[test]
fn chunker_recurses_oversized_nodes() {
let lang: Language = tree_sitter_python::LANGUAGE.into();
let mut body = String::new();
for i in 0..40 {
use std::fmt::Write;
let _ = writeln!(&mut body, " x{i} = {i}");
}
let source = format!("def big_function():\n{body}");
let chunks = chunk_source(&source, Some(&lang), 100);
assert!(
chunks.len() > 1,
"oversized function should split into multiple chunks; got {}",
chunks.len()
);
}
// Structural invariants on Python input: chunks appear in order without
// overlapping, and none extends past the end of the source.
// NOTE(review): the name mentions "parity", but only these invariants are asserted
// here — confirm whether a comparison against a reference chunker was intended.
#[test]
fn property_chunker_parity_python() {
let lang: Language = tree_sitter_python::LANGUAGE.into();
let source = "def a():\n pass\n\ndef b():\n pass\n\nclass C:\n def m(self):\n pass\n";
let chunks = chunk_source(source, Some(&lang), 1500);
let mut prev_end = 0;
for c in &chunks {
assert!(
c.start_byte >= prev_end,
"chunk overlap: prev_end={prev_end} c.start_byte={}",
c.start_byte
);
prev_end = c.end_byte;
}
assert!(prev_end <= source.len(), "chunks extend past source");
}
}