use std::ops::Range;
use memchr::memchr;
use tree_sitter::TreeCursor;
/// Splits `source` into contiguous `(start, end)` byte ranges of at most
/// `max_chunk_size` bytes, preferring cut positions where syntax nodes of
/// `tree` begin.
///
/// Returns an empty vector when `source` is empty or `max_chunk_size` is zero,
/// and a single `(0, source.len())` range when the whole source already fits.
/// Otherwise the start offset of every node below the root is grouped by the
/// depth reported by the tree cursor, and the source is split using the
/// shallowest depth first, descending only for regions that are still too
/// large.
pub fn split_code(source: &str, tree: &tree_sitter::Tree, max_chunk_size: usize) -> Vec<(usize, usize)> {
    let total_len = source.len();
    if total_len == 0 || max_chunk_size == 0 {
        return Vec::new();
    }
    if total_len <= max_chunk_size {
        return vec![(0, total_len)];
    }

    // One bucket of candidate cut offsets per tree depth. Bucket 0 always
    // exists so the recursion has a level to start from.
    let mut levels: Vec<Vec<usize>> = vec![Vec::new()];
    for node in collect_node_ranges(tree.walk()) {
        if node.depth >= levels.len() {
            levels.resize_with(node.depth + 1, Vec::new);
        }
        levels[node.depth].push(node.range.start);
    }

    // Every level also gets the end of the source as a point, then is put
    // into ascending order without duplicates.
    for level in &mut levels {
        level.push(total_len);
        level.sort_unstable();
        level.dedup();
    }

    let mut chunks: Vec<(usize, usize)> = Vec::new();
    split_recursive(source, 0, total_len, max_chunk_size, &levels, 0, &mut chunks);
    chunks
}
/// A syntax node visited by `collect_node_ranges`, reduced to the two facts
/// the splitter needs.
#[derive(Debug, Clone)]
struct NodeRange {
/// Depth reported by the tree cursor while it was positioned on the node.
depth: usize,
/// Byte range of the node, as returned by the node's `byte_range()`.
range: Range<usize>,
}
/// Walks every node below the cursor's starting node in pre-order (a node is
/// recorded before its children, children before following siblings) and
/// records each one's cursor depth and byte range.
///
/// The starting node itself is not recorded. Returns an empty vector when the
/// starting node has no children.
fn collect_node_ranges(cursor: TreeCursor<'_>) -> Vec<NodeRange> {
    let mut walker = cursor;
    let mut collected = Vec::new();

    if !walker.goto_first_child() {
        return collected;
    }

    loop {
        // Invariant: `walker` sits on a node that has not been recorded yet.
        collected.push(NodeRange {
            depth: walker.depth() as usize,
            range: walker.node().byte_range(),
        });

        // Descend first; the child is recorded at the top of the loop.
        if walker.goto_first_child() {
            continue;
        }

        // No children: climb until some ancestor (or this node) has a next
        // sibling. Failing to climb means the walk is back at the start node.
        while !walker.goto_next_sibling() {
            if !walker.goto_parent() {
                return collected;
            }
        }
    }
}
fn split_recursive(
source: &str,
region_start: usize,
region_end: usize,
max_chunk_size: usize,
split_points_by_depth: &[Vec<usize>],
current_depth: usize,
out: &mut Vec<(usize, usize)>,
) {
let region_size = region_end - region_start;
if region_size <= max_chunk_size {
if region_size > 0 {
out.push((region_start, region_end));
}
return;
}
if current_depth < split_points_by_depth.len() {
let points = &split_points_by_depth[current_depth];
let relevant: Vec<usize> = points
.iter()
.copied()
.filter(|&p| p > region_start && p < region_end)
.collect();
if !relevant.is_empty() {
let mut boundaries = Vec::with_capacity(relevant.len() + 2);
boundaries.push(region_start);
boundaries.extend_from_slice(&relevant);
boundaries.push(region_end);
let mut cursor = 0;
while cursor < boundaries.len() - 1 {
let chunk_start = boundaries[cursor];
let mut best_end_idx = cursor + 1;
for (j, &boundary) in boundaries.iter().enumerate().skip(cursor + 1) {
if boundary - chunk_start <= max_chunk_size {
best_end_idx = j;
} else {
break;
}
}
let chunk_end = boundaries[best_end_idx];
if chunk_end - chunk_start <= max_chunk_size {
if chunk_end > chunk_start {
out.push((chunk_start, chunk_end));
}
cursor = best_end_idx;
} else {
split_recursive(
source,
chunk_start,
chunk_end,
max_chunk_size,
split_points_by_depth,
current_depth + 1,
out,
);
cursor = best_end_idx;
}
}
return;
}
if current_depth + 1 < split_points_by_depth.len() {
split_recursive(
source,
region_start,
region_end,
max_chunk_size,
split_points_by_depth,
current_depth + 1,
out,
);
return;
}
}
split_at_lines(source, region_start, region_end, max_chunk_size, out);
}
/// Splits `source[region_start..region_end]` into ranges that end on line
/// boundaries and appends them, in order, to `out`.
///
/// Whole lines (each including its trailing `\n`) are packed greedily into
/// ranges of at most `max_chunk_size` bytes. A single line that is longer
/// than `max_chunk_size` is passed to `split_at_bytes`, which cuts it on
/// character boundaries.
///
/// # Panics
///
/// Panics if `region_start..region_end` is not a valid slice range of
/// `source` (out of bounds or not on character boundaries).
fn split_at_lines(
    source: &str,
    region_start: usize,
    region_end: usize,
    max_chunk_size: usize,
    out: &mut Vec<(usize, usize)>,
) {
    let region_bytes = source[region_start..region_end].as_bytes();

    // Absolute offsets just past each '\n' inside the region.
    let mut line_ends: Vec<usize> = Vec::new();
    let mut search_start = 0;
    while let Some(rel_pos) = memchr(b'\n', &region_bytes[search_start..]) {
        search_start += rel_pos + 1;
        line_ends.push(region_start + search_start);
    }
    // A final line without a trailing newline still ends at the region end.
    if line_ends.last().copied() != Some(region_end) {
        line_ends.push(region_end);
    }

    let mut chunk_start = region_start;
    let mut prev_line_end = region_start;
    for &line_end in &line_ends {
        if line_end - chunk_start > max_chunk_size {
            // Adding this line would overflow: flush the lines gathered so far.
            if prev_line_end > chunk_start {
                out.push((chunk_start, prev_line_end));
                chunk_start = prev_line_end;
            }
            // The line on its own is still too long: cut it by bytes.
            if line_end - chunk_start > max_chunk_size {
                split_at_bytes(source, chunk_start, line_end, max_chunk_size, out);
                chunk_start = line_end;
            }
        }
        prev_line_end = line_end;
    }

    // Whatever remains after the last flush fits within the limit.
    if chunk_start < region_end {
        out.push((chunk_start, region_end));
    }
}
/// Cuts `region_start..region_end` of `source` into consecutive ranges of at
/// most `max_chunk_size` bytes, never cutting inside a UTF-8 character, and
/// appends them to `out`.
///
/// A range exceeds `max_chunk_size` only when a single character is wider
/// than the limit; that character is then emitted as its own range so the
/// loop always makes progress.
fn split_at_bytes(
    source: &str,
    region_start: usize,
    region_end: usize,
    max_chunk_size: usize,
    out: &mut Vec<(usize, usize)>,
) {
    let mut cursor = region_start;
    while cursor < region_end {
        // The tail fits as-is.
        if region_end - cursor <= max_chunk_size {
            out.push((cursor, region_end));
            return;
        }

        // Largest character boundary in (cursor, cursor + max_chunk_size].
        let limit = cursor + max_chunk_size;
        let boundary = (cursor + 1..=limit)
            .rev()
            .find(|&offset| source.is_char_boundary(offset));

        let cut = match boundary {
            Some(offset) => offset,
            // No boundary within the limit: take the whole next character.
            None => match source[cursor..region_end].chars().next() {
                Some(ch) => cursor + ch.len_utf8(),
                None => return,
            },
        };

        out.push((cursor, cut));
        cursor = cut;
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn test_parser() -> Option<tree_sitter::Parser> {
let langs = crate::available_languages();
let lang_name = langs.first()?;
let language = crate::get_language(lang_name).ok()?;
let mut parser = tree_sitter::Parser::new();
parser.set_language(&language).ok()?;
Some(parser)
}
fn parse_or_skip(source: &str) -> Option<tree_sitter::Tree> {
let mut parser = test_parser()?;
parser.parse(source, None)
}
#[test]
fn empty_source_returns_empty_vec() {
if let Some(tree) = parse_or_skip("x") {
let result = split_code("", &tree, 100);
assert!(result.is_empty());
}
}
#[test]
fn zero_max_chunk_size_returns_empty_vec() {
if let Some(tree) = parse_or_skip("x") {
let result = split_code("x", &tree, 0);
assert!(result.is_empty());
}
}
#[test]
fn source_fits_in_one_chunk() {
let source = "let x = 1;";
if let Some(tree) = parse_or_skip(source) {
let result = split_code(source, &tree, 1000);
assert_eq!(result, vec![(0, source.len())]);
}
}
#[test]
fn chunks_cover_entire_source() {
let source = "fn foo() {}\nfn bar() {}\nfn baz() {}\n";
if let Some(tree) = parse_or_skip(source) {
let chunks = split_code(source, &tree, 15);
assert!(!chunks.is_empty());
assert_eq!(chunks.first().unwrap().0, 0);
assert_eq!(chunks.last().unwrap().1, source.len());
for window in chunks.windows(2) {
assert_eq!(window[0].1, window[1].0, "chunks must be contiguous");
}
}
}
#[test]
fn line_fallback_when_no_ast_boundaries() {
let source = "aaaa\nbbbb\ncccc\ndddd\n";
if let Some(tree) = parse_or_skip(source) {
let chunks = split_code(source, &tree, 10);
assert!(chunks.len() > 1);
for &(s, e) in &chunks {
assert!(e - s <= 10);
}
}
}
#[test]
fn byte_fallback_on_long_line() {
let source = "abcdefghijklmnopqrstuvwxyz";
if let Some(tree) = parse_or_skip(source) {
let chunks = split_code(source, &tree, 10);
let joined: String = chunks.iter().map(|&(s, e)| &source[s..e]).collect();
assert_eq!(joined, source);
for &(s, e) in &chunks {
assert!(e - s <= 10);
}
}
}
#[test]
fn utf8_safety_in_byte_fallback() {
let source = "aaaa\u{1F600}\u{1F600}\u{1F600}\u{1F600}";
if let Some(tree) = parse_or_skip(source) {
let chunks = split_code(source, &tree, 6);
let joined: String = chunks.iter().map(|&(s, e)| &source[s..e]).collect();
assert_eq!(joined, source);
for &(s, e) in &chunks {
assert!(source.is_char_boundary(s));
assert!(source.is_char_boundary(e));
}
}
}
#[test]
fn collect_node_ranges_depth_first() {
let source = "fn main() {\n let x = 5;\n}";
if let Some(tree) = parse_or_skip(source) {
let ranges = collect_node_ranges(tree.walk());
for nr in &ranges {
assert!(nr.range.start <= source.len());
assert!(nr.range.end <= source.len());
assert!(nr.range.start <= nr.range.end);
assert!(nr.depth >= 1);
}
}
}
#[test]
fn split_at_lines_basic() {
let source = "line1\nline2\nline3\n";
let mut out = Vec::new();
split_at_lines(source, 0, source.len(), 7, &mut out);
assert!(!out.is_empty());
let joined: String = out.iter().map(|&(s, e)| &source[s..e]).collect();
assert_eq!(joined, source);
}
#[test]
fn split_at_bytes_basic() {
let source = "abcdefghij";
let mut out = Vec::new();
split_at_bytes(source, 0, source.len(), 4, &mut out);
let joined: String = out.iter().map(|&(s, e)| &source[s..e]).collect();
assert_eq!(joined, source);
for &(s, e) in &out {
assert!(e - s <= 4);
}
}
#[test]
fn split_at_bytes_utf8() {
let source = "\u{1F600}\u{1F600}\u{1F600}"; let mut out = Vec::new();
split_at_bytes(source, 0, source.len(), 5, &mut out);
let joined: String = out.iter().map(|&(s, e)| &source[s..e]).collect();
assert_eq!(joined, source);
for &(s, e) in &out {
assert!(source.is_char_boundary(s));
assert!(source.is_char_boundary(e));
}
}
}