use tree_sitter::{Node, Parser};
use crate::language::{is_javascript_like, is_typescript_like, tree_sitter_language};
use crate::model::Chunk;
/// Soft size budget for a single chunk.
///
/// Despite the `CHARS` suffix, the chunkers in this file compare it against
/// differences of byte offsets into the source, so it is effectively a byte budget.
const DESIRED_CHUNK_LENGTH_CHARS: usize = 1500;
/// Half-open byte range `start..end` into the source text being chunked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ChunkBoundary {
    /// Byte offset of the first byte covered by the chunk.
    start: usize,
    /// Byte offset one past the last byte covered by the chunk.
    end: usize,
}
/// Reports whether `node` is one of the node kinds that this chunker treats as
/// the start of a definition for `language`.
///
/// The language is resolved in a fixed order: the literal names `"rust"` and
/// `"python"` first, then the JavaScript-like and TypeScript-like families, then
/// the remaining literal names. Any other language yields `false`.
fn is_definition_node(language: &str, node: &Node) -> bool {
    const RUST_KINDS: &[&str] = &[
        "function_item",
        "impl_item",
        "struct_item",
        "enum_item",
        "trait_item",
        "mod_item",
        "const_item",
        "static_item",
        "type_item",
        "macro_definition",
        "attribute_item",
    ];
    const PYTHON_KINDS: &[&str] = &[
        "function_definition",
        "class_definition",
        "decorated_definition",
    ];
    const JAVASCRIPT_KINDS: &[&str] = &[
        "function_declaration",
        "class_declaration",
        "export_statement",
        "lexical_declaration",
        "variable_declaration",
    ];
    const TYPESCRIPT_KINDS: &[&str] = &[
        "function_declaration",
        "class_declaration",
        "interface_declaration",
        "type_alias_declaration",
        "enum_declaration",
        "export_statement",
        "lexical_declaration",
        "variable_declaration",
    ];
    const GO_KINDS: &[&str] = &[
        "function_declaration",
        "method_declaration",
        "type_declaration",
    ];
    const JAVA_KINDS: &[&str] = &[
        "class_declaration",
        "method_declaration",
        "interface_declaration",
        "enum_declaration",
        "constructor_declaration",
        "record_declaration",
    ];
    const C_KINDS: &[&str] = &[
        "function_definition",
        "struct_specifier",
        "enum_specifier",
        "declaration",
    ];
    const CPP_KINDS: &[&str] = &[
        "function_definition",
        "class_specifier",
        "struct_specifier",
        "enum_specifier",
        "declaration",
        "namespace_definition",
        "template_declaration",
    ];

    // The order of these checks matters: the JavaScript-like test runs before
    // the TypeScript-like one, and both run before the remaining literal names.
    let definition_kinds: &[&str] = if language == "rust" {
        RUST_KINDS
    } else if language == "python" {
        PYTHON_KINDS
    } else if is_javascript_like(language) {
        JAVASCRIPT_KINDS
    } else if is_typescript_like(language) {
        TYPESCRIPT_KINDS
    } else {
        match language {
            "go" => GO_KINDS,
            "java" => JAVA_KINDS,
            "c" => C_KINDS,
            "cpp" => CPP_KINDS,
            _ => return false,
        }
    };

    definition_kinds.contains(&node.kind())
}
/// Splits `source` at the start of each top-level definition found by the
/// tree-sitter grammar for `language`, then merges neighbouring pieces up to
/// `DESIRED_CHUNK_LENGTH_CHARS`.
///
/// Returns `None` when no grammar is available for `language`, when the parser
/// cannot be configured or fails to produce a tree, or when the root node has
/// no child recognised by `is_definition_node`.
fn chunk_with_tree_sitter(source: &str, language: &str) -> Option<Vec<ChunkBoundary>> {
    let grammar = tree_sitter_language(language)?;
    let mut parser = Parser::new();
    parser.set_language(&grammar).ok()?;
    let tree = parser.parse(source, None)?;
    let root = tree.root_node();

    // Only direct children of the root are inspected; nested definitions stay
    // inside the chunk of their enclosing top-level item.
    let mut walker = root.walk();
    let definition_starts: Vec<usize> = root
        .children(&mut walker)
        .filter(|child| is_definition_node(language, child))
        .map(|child| child.start_byte())
        .collect();

    // `split_first` doubles as the "no definitions at all" bail-out.
    let (_, later_starts) = definition_starts.split_first()?;

    // The first piece is stretched back to offset 0 so that anything before
    // the first definition is kept; every piece runs up to the next
    // definition start, and the last one to the end of the source.
    let piece_starts = std::iter::once(0).chain(later_starts.iter().copied());
    let piece_ends = later_starts
        .iter()
        .copied()
        .chain(std::iter::once(source.len()));
    let pieces: Vec<ChunkBoundary> = piece_starts
        .zip(piece_ends)
        .filter(|&(start, end)| start < end)
        .map(|(start, end)| ChunkBoundary { start, end })
        .collect();

    if pieces.is_empty() {
        return None;
    }
    Some(merge_adjacent_chunks(&pieces, DESIRED_CHUNK_LENGTH_CHARS))
}
/// Reports whether `line` is an ATX heading (`#` through `######`).
///
/// Follows the CommonMark rules for the opening sequence:
/// - at most three leading spaces are allowed; deeper indentation (or a leading
///   tab) denotes indented code, not a heading;
/// - the run of one to six `#` must be followed by a space, a tab, or the end
///   of the line, so `#hashtag` is not a heading while a bare `##` is an
///   (empty) heading.
///
/// `line` may or may not carry its trailing line terminator.
fn is_markdown_heading(line: &str) -> bool {
    let indent = line.len() - line.trim_start_matches(' ').len();
    if indent > 3 {
        return false;
    }
    // Everything inspected from here on is ASCII, so work on bytes directly.
    let rest = &line.as_bytes()[indent..];
    let hashes = rest.iter().take_while(|&&byte| byte == b'#').count();
    if !(1..=6).contains(&hashes) {
        return false;
    }
    matches!(
        rest.get(hashes),
        None | Some(b' ' | b'\t' | b'\n' | b'\r')
    )
}
/// Recognises a fenced-code delimiter line.
///
/// Returns the fence character (a backtick or `~`), the length of the fence run
/// and the text following the run when `line` starts, after at most three
/// spaces, with three or more identical fence characters.
fn markdown_fence(line: &str) -> Option<(char, usize, &str)> {
    let indent = line.len() - line.trim_start_matches(' ').len();
    if indent > 3 {
        return None;
    }
    let rest = &line[indent..];
    let marker = rest.chars().next().filter(|&c| c == '`' || c == '~')?;
    let run = rest.chars().take_while(|&c| c == marker).count();
    if run < 3 {
        return None;
    }
    // The fence characters are ASCII, so the char count is also a byte offset.
    Some((marker, run, &rest[run..]))
}

/// Splits markdown `source` into chunks that start at headings.
///
/// Heading lines inside fenced code blocks are ignored so that, for example,
/// shell comments in a code sample do not split the sample. Text before the
/// first heading becomes its own section, whitespace-only sections are dropped,
/// heading-only sections are folded into what follows by
/// `merge_heading_only_sections`, and any section larger than
/// `DESIRED_CHUNK_LENGTH_CHARS` is further split on line boundaries.
///
/// Returns `None` when the source contains no heading (or nothing but
/// whitespace survives), letting the caller fall back to another strategy.
fn chunk_markdown(source: &str) -> Option<Vec<ChunkBoundary>> {
    let mut heading_starts = Vec::new();
    // Fence character and run length of the code block currently open, if any.
    let mut open_fence: Option<(char, usize)> = None;
    let mut index = 0;
    for line in source.split_inclusive('\n') {
        match (open_fence, markdown_fence(line)) {
            (Some((marker, run)), Some((candidate, candidate_run, tail))) => {
                // A closing fence uses the same character, is at least as long
                // as the opening one, and carries no info string.
                if candidate == marker && candidate_run >= run && tail.trim().is_empty() {
                    open_fence = None;
                }
            }
            (Some(_), None) => {}
            (None, Some((marker, run, tail))) => {
                // A backtick fence whose info string contains a backtick is
                // inline code, not the start of a code block.
                if marker != '`' || !tail.contains('`') {
                    open_fence = Some((marker, run));
                }
            }
            (None, None) => {
                if is_markdown_heading(line) {
                    heading_starts.push(index);
                }
            }
        }
        index += line.len();
    }

    let first_heading = *heading_starts.first()?;
    if first_heading != 0 {
        // Keep whatever precedes the first heading as a section of its own.
        heading_starts.insert(0, 0);
    }

    let mut sections = Vec::new();
    for (i, &start) in heading_starts.iter().enumerate() {
        let end = heading_starts.get(i + 1).copied().unwrap_or(source.len());
        if !source[start..end].trim().is_empty() {
            sections.push(ChunkBoundary { start, end });
        }
    }
    let sections = merge_heading_only_sections(source, &sections);

    let mut boundaries = Vec::new();
    for section in &sections {
        let (start, end) = (section.start, section.end);
        if end - start <= DESIRED_CHUNK_LENGTH_CHARS {
            boundaries.push(ChunkBoundary { start, end });
            continue;
        }
        // Oversized section: split on lines and shift the section-relative
        // offsets back into offsets of the whole source.
        for part in chunk_lines(&source[start..end], DESIRED_CHUNK_LENGTH_CHARS) {
            boundaries.push(ChunkBoundary {
                start: start + part.start,
                end: start + part.end,
            });
        }
    }

    if boundaries.is_empty() {
        None
    } else {
        Some(boundaries)
    }
}
/// Folds sections that consist of nothing but a heading into the section that
/// follows them.
///
/// A heading-only section is held back until the next section is seen:
/// - if that next section has content, the two are emitted as one chunk that
///   starts at the held-back heading;
/// - if it is another heading-only section, the held-back one is emitted on its
///   own (ending where the new one starts) and the new one is held back instead.
///
/// A heading-only section still held back at the end is emitted as a chunk that
/// runs to the end of `source`.
fn merge_heading_only_sections(source: &str, sections: &[ChunkBoundary]) -> Vec<ChunkBoundary> {
    let mut out = Vec::new();
    let mut held_back: Option<usize> = None;

    for section in sections {
        let heading_only = is_markdown_heading_only(&source[section.start..section.end]);
        match (heading_only, held_back.take()) {
            (true, previous) => {
                if let Some(start) = previous {
                    out.push(ChunkBoundary {
                        start,
                        end: section.start,
                    });
                }
                held_back = Some(section.start);
            }
            (false, previous) => out.push(ChunkBoundary {
                start: previous.unwrap_or(section.start),
                end: section.end,
            }),
        }
    }

    if let Some(start) = held_back {
        out.push(ChunkBoundary {
            start,
            end: source.len(),
        });
    }
    out
}
/// Reports whether `section` contains exactly one non-blank line and that line
/// is a markdown heading.
///
/// Lines are trimmed before being inspected, so surrounding whitespace and
/// blank lines around the heading do not matter.
fn is_markdown_heading_only(section: &str) -> bool {
    let mut content_lines = section
        .lines()
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty());

    match (content_lines.next(), content_lines.next()) {
        (Some(only_line), None) => is_markdown_heading(only_line),
        _ => false,
    }
}
/// Greedily merges consecutive `chunks` while the summed length of the merged
/// pieces stays within `desired_length`.
///
/// A piece that would push the running total past `desired_length` closes the
/// current chunk and opens a new one, so a single piece that is itself longer
/// than `desired_length` is emitted unsplit. An empty input yields an empty
/// result.
fn merge_adjacent_chunks(chunks: &[ChunkBoundary], desired_length: usize) -> Vec<ChunkBoundary> {
    let Some((head, tail)) = chunks.split_first() else {
        return Vec::new();
    };

    let mut merged = Vec::new();
    let mut open = ChunkBoundary {
        start: head.start,
        end: head.end,
    };
    // Tracked separately from `open.end - open.start`: it is the sum of the
    // piece lengths, which is what the budget is checked against.
    let mut open_length = head.end - head.start;

    for piece in tail {
        let piece_length = piece.end - piece.start;
        if open_length + piece_length <= desired_length {
            open.end = piece.end;
            open_length += piece_length;
        } else {
            merged.push(open);
            open = ChunkBoundary {
                start: piece.start,
                end: piece.end,
            };
            open_length = piece_length;
        }
    }

    merged.push(open);
    merged
}
/// Splits `text` into chunks made of whole lines, merging consecutive lines
/// while their combined length stays within `desired_length`.
///
/// Returned offsets are byte offsets relative to the start of `text`. Text that
/// is empty or whitespace-only yields no chunks. A single line longer than
/// `desired_length` is kept as one chunk.
fn chunk_lines(text: &str, desired_length: usize) -> Vec<ChunkBoundary> {
    if text.trim().is_empty() {
        return Vec::new();
    }

    // `split_inclusive` yields every byte of `text` exactly once, including a
    // final line without a trailing newline, so no remainder handling is needed.
    let mut offset = 0;
    let line_spans: Vec<ChunkBoundary> = text
        .split_inclusive('\n')
        .map(|line| {
            let start = offset;
            offset += line.len();
            ChunkBoundary { start, end: offset }
        })
        .collect();

    merge_adjacent_chunks(&line_spans, desired_length)
}
/// Splits `source` into [`Chunk`]s annotated with `file_path`, 1-based line
/// numbers and the optional `language`.
///
/// Strategy selection:
/// - `Some("markdown")` uses heading-based chunking;
/// - any other `Some(lang)` uses tree-sitter definition-based chunking;
/// - `None`, or a strategy that returns `None`, falls back to plain line-based
///   chunking.
///
/// Whitespace-only input yields no chunks.
///
/// Line numbers are `1 +` the number of `'\n'` bytes before the respective
/// offset. NOTE(review): for a chunk that ends right after a newline this makes
/// `end_line` the line following the chunk's last line — confirm whether
/// `Chunk` expects an inclusive or exclusive end line.
pub fn chunk_source(source: &str, file_path: &str, language: Option<&str>) -> Vec<Chunk> {
    if source.trim().is_empty() {
        return Vec::new();
    }

    let boundaries = match language {
        Some("markdown") => chunk_markdown(source),
        Some(lang) => chunk_with_tree_sitter(source, lang),
        None => None,
    }
    .unwrap_or_else(|| chunk_lines(source, DESIRED_CHUNK_LENGTH_CHARS));

    // Count newlines incrementally instead of rescanning the whole prefix for
    // every boundary: only the bytes between the previous offset and the
    // requested one are examined.
    let bytes = source.as_bytes();
    let mut scanned = 0usize;
    let mut newlines_seen = 0usize;
    let mut line_number_at = |offset: usize| -> usize {
        if offset < scanned {
            // Offsets are expected to be non-decreasing; restart from the top
            // if that ever does not hold so the result stays exact.
            scanned = 0;
            newlines_seen = 0;
        }
        newlines_seen += bytes[scanned..offset]
            .iter()
            .filter(|&&byte| byte == b'\n')
            .count();
        scanned = offset;
        newlines_seen + 1
    };

    let mut chunks = Vec::with_capacity(boundaries.len());
    for boundary in &boundaries {
        let start_line = line_number_at(boundary.start);
        let end_line = line_number_at(boundary.end);
        chunks.push(Chunk::new(
            source[boundary.start..boundary.end].to_string(),
            file_path.to_string(),
            start_line,
            end_line,
            language.map(String::from),
        ));
    }
    chunks
}
// Unit tests for `chunk_source`, covering the tree-sitter, markdown and
// line-based fallback strategies through the public entry point.
#[cfg(test)]
mod tests {
use super::*;
// A small Rust source must come back with the `use` preamble, the function and
// the struct all present in the concatenated chunk contents.
#[test]
fn test_rust_tree_sitter_chunking_small() {
let source = r#"
use std::collections::HashMap;
fn foo() {
println!("foo");
}
struct MyStruct {
field: i32,
}
"#;
let chunks = chunk_source(source, "test.rs", Some("rust"));
assert!(!chunks.is_empty());
let all_content: String = chunks.iter().map(|c| c.content.as_str()).collect();
assert!(all_content.contains("fn foo"));
assert!(all_content.contains("struct MyStruct"));
assert!(all_content.contains("use std::collections"));
}
// Three functions with 100-line bodies must be split into at least two chunks.
#[test]
fn test_rust_tree_sitter_splits_large() {
let long_body = " let x = 1;\n".repeat(100);
let source = format!(
"fn foo() {{\n{long_body}}}\n\nfn bar() {{\n{long_body}}}\n\nfn baz() {{\n{long_body}}}\n"
);
let chunks = chunk_source(&source, "test.rs", Some("rust"));
assert!(
chunks.len() >= 2,
"large source should split: got {} chunks",
chunks.len()
);
}
// TSX input: the exported component must appear in the chunk contents.
#[test]
fn test_tsx_tree_sitter_chunking() {
let source = r#"
import React from 'react';
export function Button() {
return <button>Save</button>;
}
"#;
let chunks = chunk_source(source, "Button.tsx", Some("tsx"));
assert!(!chunks.is_empty());
let all_content: String = chunks.iter().map(|c| c.content.as_str()).collect();
assert!(all_content.contains("export function Button"));
}
// A Python class and a function with 100-line bodies must be split into at
// least two chunks, and both definitions must be present in the contents.
#[test]
fn test_python_tree_sitter_chunking() {
let long_body = " x = 1\n".repeat(100);
let source =
format!("import os\n\nclass MyClass:\n{long_body}\ndef standalone():\n{long_body}\n");
let chunks = chunk_source(&source, "test.py", Some("python"));
assert!(
chunks.len() >= 2,
"large python source should split: got {} chunks",
chunks.len()
);
let all_content: String = chunks.iter().map(|c| c.content.as_str()).collect();
assert!(all_content.contains("class MyClass"));
assert!(all_content.contains("def standalone"));
}
// With no language given, the line-based fallback must still produce chunks.
#[test]
fn test_fallback_for_unknown_language() {
let source = "line1\nline2\nline3\n";
let chunks = chunk_source(source, "test.xyz", None);
assert!(!chunks.is_empty());
}
// A markdown document with one long section must be split, and some chunk
// must begin exactly at the `## Retrieval` heading.
#[test]
fn test_markdown_heading_chunking_splits_large_sections() {
let long_body = "retrieval notes preserve wiki context\n".repeat(80);
let source = format!(
"# Project Wiki\nshort intro\n\n## Retrieval\n{long_body}\n## Verification\nfacts need sources\n"
);
let chunks = chunk_source(&source, "wiki/index.md", Some("markdown"));
assert!(
chunks.len() >= 2,
"large markdown source should split on headings: got {} chunks",
chunks.len()
);
assert!(
chunks
.iter()
.any(|chunk| chunk.content.starts_with("## Retrieval")),
"retrieval section should start at its heading"
);
}
// Two small sections that each have body text stay as two separate chunks.
#[test]
fn test_markdown_heading_chunking_keeps_small_sections_separate() {
let source = "# First\nalpha beta\n\n## Second\ngamma delta\n";
let chunks = chunk_source(source, "wiki/index.md", Some("markdown"));
assert_eq!(chunks.len(), 2, "content sections should not be merged");
assert!(chunks[0].content.starts_with("# First"));
assert!(chunks[1].content.starts_with("## Second"));
}
// A title with no body of its own is merged into the section that follows it.
#[test]
fn test_markdown_heading_chunking_merges_empty_title_with_first_section() {
let source = "# Title\n\n## Claim\nbody text\n\n## Evidence\nmore text\n";
let chunks = chunk_source(source, "wiki/index.md", Some("markdown"));
assert_eq!(chunks.len(), 2);
assert!(chunks[0].content.starts_with("# Title"));
assert!(chunks[0].content.contains("## Claim"));
assert!(chunks[1].content.starts_with("## Evidence"));
}
// JavaScript input: both the function and the class must appear in the
// chunk contents.
#[test]
fn test_javascript_tree_sitter_chunking() {
let source = r#"
const x = require('something');
function hello() {
console.log("hello");
}
class Greeter {
greet() {
return "hi";
}
}
"#;
let chunks = chunk_source(source, "test.js", Some("javascript"));
assert!(!chunks.is_empty());
let all_content: String = chunks.iter().map(|c| c.content.as_str()).collect();
assert!(all_content.contains("function hello"));
assert!(all_content.contains("class Greeter"));
}
// Go input: both functions must appear in the chunk contents.
#[test]
fn test_go_tree_sitter_chunking() {
let source = r#"
package main
import "fmt"
func main() {
fmt.Println("hello")
}
func helper() int {
return 42
}
"#;
let chunks = chunk_source(source, "test.go", Some("go"));
assert!(!chunks.is_empty());
let all_content: String = chunks.iter().map(|c| c.content.as_str()).collect();
assert!(all_content.contains("func main"));
assert!(all_content.contains("func helper"));
}
}