use crate::character::validate_chunk_config;
use crate::chunk::{chunks_from_spans, TextChunk, TextSpan};
use crate::error::ChunkError;
/// Chunker for Markdown text.
///
/// `split_chunks` first locates block-level spans with `markdown_blocks` and
/// then combines them with `merge_structure_spans` using the two limits below.
#[derive(Debug, Clone)]
pub struct MarkdownChunker {
// Size limit a merged chunk is compared against, in the unit reported by `crate::char_len`.
chunk_size: usize,
// Largest size a trailing block may have and still be carried over into the next chunk.
chunk_overlap: usize,
}
impl MarkdownChunker {
    /// Creates a Markdown chunker for the given size and overlap limits.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_chunk_config` when it rejects
    /// the `chunk_size` / `chunk_overlap` pair.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result<Self, ChunkError> {
        validate_chunk_config(chunk_size, chunk_overlap)?;
        let chunker = Self {
            chunk_size,
            chunk_overlap,
        };
        Ok(chunker)
    }

    /// Splits `input` into chunks that borrow from it.
    ///
    /// Block spans come from `markdown_blocks`; they are merged by
    /// `merge_structure_spans` and converted with `chunks_from_spans`.
    pub fn split_chunks<'a>(&self, input: &'a str) -> Vec<TextChunk<'a>> {
        let (size, overlap) = (self.chunk_size, self.chunk_overlap);
        let merged = merge_structure_spans(input, &markdown_blocks(input), size, overlap);
        chunks_from_spans(input, merged, &crate::char_len)
    }

    /// Same as [`Self::split_chunks`], but returns each chunk's text as an owned `String`.
    pub fn split_text(&self, input: &str) -> Vec<String> {
        let chunks = self.split_chunks(input);
        let mut texts = Vec::with_capacity(chunks.len());
        for chunk in chunks {
            texts.push(chunk.text.to_string());
        }
        texts
    }
}
/// Chunker for HTML-like markup.
///
/// `split_chunks` first locates element spans with `html_blocks` and then
/// combines them with `merge_structure_spans` using the two limits below.
#[derive(Debug, Clone)]
pub struct HtmlChunker {
// Size limit a merged chunk is compared against, in the unit reported by `crate::char_len`.
chunk_size: usize,
// Largest size a trailing block may have and still be carried over into the next chunk.
chunk_overlap: usize,
}
impl HtmlChunker {
    /// Creates an HTML chunker for the given size and overlap limits.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_chunk_config` when it rejects
    /// the `chunk_size` / `chunk_overlap` pair.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Result<Self, ChunkError> {
        validate_chunk_config(chunk_size, chunk_overlap)?;
        let chunker = Self {
            chunk_size,
            chunk_overlap,
        };
        Ok(chunker)
    }

    /// Splits `input` into chunks that borrow from it.
    ///
    /// Element spans come from `html_blocks`; they are merged by
    /// `merge_structure_spans` and converted with `chunks_from_spans`.
    pub fn split_chunks<'a>(&self, input: &'a str) -> Vec<TextChunk<'a>> {
        let (size, overlap) = (self.chunk_size, self.chunk_overlap);
        let merged = merge_structure_spans(input, &html_blocks(input), size, overlap);
        chunks_from_spans(input, merged, &crate::char_len)
    }

    /// Same as [`Self::split_chunks`], but returns each chunk's text as an owned `String`.
    pub fn split_text(&self, input: &str) -> Vec<String> {
        let chunks = self.split_chunks(input);
        let mut texts = Vec::with_capacity(chunks.len());
        for chunk in chunks {
            texts.push(chunk.text.to_string());
        }
        texts
    }
}
/// Alias exposing [`HtmlChunker`] under an XML-oriented name; both names refer to the same type.
// NOTE(review): block detection therefore uses the HTML tag list in `html_blocks` — confirm that
// is sufficient for the XML documents callers pass in.
pub type XmlChunker = HtmlChunker;
/// Scans `input` line by line and returns the byte spans of its Markdown blocks.
///
/// * A line whose trimmed start is a code-fence marker toggles fence mode. The
///   span pushed on the closing fence runs from the start of the block that was
///   open when the fence began (or the opening fence line) to the end of the
///   closing line, and is not trimmed.
/// * Outside a fence, a blank line closes the open block.
/// * A heading (`#`), bullet (`- `, `* `, `+ `) or ordered-list line closes the
///   open block and starts a new one at that line.
/// * Any other line starts a block only when none is open.
///
/// Whatever is still open at the end of the input (including an unterminated
/// fence) is emitted up to `input.len()`.
fn markdown_blocks(input: &str) -> Vec<TextSpan> {
    const BULLETS: [&str; 3] = ["- ", "* ", "+ "];

    let mut out = Vec::new();
    // Byte offset where the currently open block began, if any.
    let mut open: Option<usize> = None;
    let mut fenced = false;

    for (offset, raw) in lines_with_offsets(input) {
        let body = raw.trim_start();

        if body.starts_with("```") || body.starts_with("~~~") {
            // A fence joins the block that is already open, otherwise it opens one.
            let start = *open.get_or_insert(offset);
            fenced = !fenced;
            if !fenced {
                out.push(TextSpan::new(start, offset + raw.len()));
                open = None;
            }
            continue;
        }
        if fenced {
            // Fence contents never affect block boundaries.
            continue;
        }

        if raw.trim().is_empty() {
            if let Some(start) = open.take() {
                out.extend(TextSpan::new(start, offset).trim(input));
            }
            continue;
        }

        let begins_block = body.starts_with('#')
            || BULLETS.iter().any(|marker| body.starts_with(*marker))
            || ordered_list_marker(body);

        match open {
            Some(start) if begins_block => {
                if start < offset {
                    out.extend(TextSpan::new(start, offset).trim(input));
                }
                open = Some(offset);
            }
            // Continuation line of the block that is already open.
            Some(_) => {}
            None => open = Some(offset),
        }
    }

    if let Some(start) = open {
        out.extend(TextSpan::new(start, input.len()).trim(input));
    }
    out
}
/// Returns the byte spans of block-level elements found in `input`.
///
/// Matching is done on an ASCII-lowercased copy of the input (which keeps byte
/// offsets identical), so tag names are case-insensitive. Each span runs from
/// the `<` of an opening tag located by `find_next_open_tag` to the end of the
/// first matching `</tag>` after it; when no closing tag exists the span covers
/// only the opening tag. Scanning resumes right after each span, so elements
/// nested inside a matched element are not reported separately.
///
/// When no element is found at all, the whole trimmed input is returned as a
/// single span (or nothing if it trims to empty).
fn html_blocks(input: &str) -> Vec<TextSpan> {
    const TAGS: &[&str] = &[
        "section",
        "article",
        "h1",
        "h2",
        "h3",
        "h4",
        "h5",
        "h6",
        "p",
        "li",
        "pre",
        "code",
        "table",
        "tr",
        "blockquote",
        "div",
    ];

    let lowered = input.to_ascii_lowercase();
    let mut found = Vec::new();
    let mut pos = 0usize;

    while pos < input.len() {
        let Some((name, tag_start)) = find_next_open_tag(&lowered, pos, TAGS) else {
            break;
        };
        // An opening tag without `>` cannot be delimited; stop scanning.
        let Some(gt) = lowered[tag_start..].find('>') else {
            break;
        };
        let body_start = tag_start + gt + 1;

        let closing = format!("</{name}>");
        let block_end = match lowered[body_start..].find(closing.as_str()) {
            Some(rel) => body_start + rel + closing.len(),
            None => body_start,
        };

        found.extend(TextSpan::new(tag_start, block_end).trim(input));
        // `block_end` is never before `body_start`, so the scan always advances.
        pos = block_end;
    }

    if found.is_empty() {
        found.extend(TextSpan::new(0, input.len()).trim(input));
    }
    found
}
/// Greedily merges consecutive `blocks` into chunk spans of at most `chunk_size`.
///
/// Blocks are appended to the open chunk while the text from the chunk start to
/// the end of the new block stays within `chunk_size` (measured with
/// `crate::char_len`, so text lying between blocks counts as well). When a block
/// would not fit, the open chunk is emitted (trimmed) and a new chunk starts.
///
/// The new chunk starts at the previous block — repeating it as overlap — only
/// when that block is no larger than `chunk_overlap` *and* the previous block
/// together with the new one still fits in `chunk_size`. Otherwise it starts at
/// the new block, so overlap never causes a chunk to exceed the limit.
///
/// A single block that is larger than `chunk_size` is still emitted as its own
/// chunk; this function does not split inside a block.
fn merge_structure_spans(
    input: &str,
    blocks: &[TextSpan],
    chunk_size: usize,
    chunk_overlap: usize,
) -> Vec<TextSpan> {
    let measure =
        |start: usize, end: usize| crate::char_len(TextSpan::new(start, end).text(input));

    let mut chunks = Vec::new();
    // `(start, end)` byte offsets of the chunk currently being assembled.
    let mut open: Option<(usize, usize)> = None;
    let mut previous_block: Option<TextSpan> = None;

    for block in blocks {
        let start = match open {
            Some((start, end)) if measure(start, block.end) > chunk_size => {
                chunks.extend(TextSpan::new(start, end).trim(input));
                previous_block
                    .filter(|prev| {
                        crate::char_len(prev.text(input)) <= chunk_overlap
                            // Carrying the previous block over must not overflow the
                            // new chunk; otherwise overlap alone would break the limit.
                            && measure(prev.start, block.end) <= chunk_size
                    })
                    .map_or(block.start, |prev| prev.start)
            }
            Some((start, _)) => start,
            None => block.start,
        };
        open = Some((start, block.end));
        previous_block = Some(*block);
    }

    if let Some((start, end)) = open {
        chunks.extend(TextSpan::new(start, end).trim(input));
    }
    chunks
}
/// Iterates over the lines of `input`, pairing each with its starting byte offset.
///
/// Lines keep their trailing `'\n'` (they come from `split_inclusive`), so the
/// offset of a line plus its length is the offset of the next line.
fn lines_with_offsets(input: &str) -> impl Iterator<Item = (usize, &str)> {
    input
        .split_inclusive('\n')
        .scan(0usize, |next_start, line| {
            let start = *next_start;
            *next_start += line.len();
            Some((start, line))
        })
}
/// Returns `true` when `line` begins with an ordered-list marker.
///
/// A marker is one or more ASCII digits followed by `.` or `)` and then a
/// space, e.g. `1. item` or `2) item`. The caller is expected to pass the line
/// with leading whitespace already removed.
fn ordered_list_marker(line: &str) -> bool {
    let digit_count = line.bytes().take_while(u8::is_ascii_digit).count();
    if digit_count == 0 {
        return false;
    }
    // ASCII digits are single bytes, so `digit_count` is a valid char boundary.
    let rest = &line[digit_count..];
    rest.starts_with(". ") || rest.starts_with(") ")
}
/// Finds the earliest opening tag from `tags` at or after byte offset `cursor`.
///
/// `lower` must already be lowercased; `tags` holds lowercase tag names. A
/// candidate `<name` only counts when the tag name ends there, i.e. the next
/// byte is `>`, `/` or ASCII whitespace. Without that check `<p` would also
/// match `<pre>` or `<path>`, and `<li` would match `<link>`.
///
/// Returns the matched tag name and the absolute byte offset of its `<`, or
/// `None` when no listed tag opens in the remaining text.
fn find_next_open_tag<'a>(
    lower: &'a str,
    cursor: usize,
    tags: &'a [&'a str],
) -> Option<(&'a str, usize)> {
    let haystack = &lower[cursor..];
    let bytes = haystack.as_bytes();

    tags.iter()
        .filter_map(|tag| {
            let needle = format!("<{tag}");
            let needle_len = needle.len();
            // First occurrence of `<name` that is followed by a tag-name terminator.
            let hit = haystack
                .match_indices(needle.as_str())
                .map(|(pos, _)| pos)
                .find(|&pos| {
                    matches!(
                        bytes.get(pos + needle_len).copied(),
                        Some(b'>' | b'/' | b' ' | b'\t' | b'\n' | b'\r' | b'\x0c')
                    )
                });
            hit.map(|pos| (*tag, cursor + pos))
        })
        .min_by_key(|&(_, pos)| pos)
}
/// Source language selector for [`CodeChunker`].
#[cfg(feature = "code")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodeLanguage {
// Rust source; `code_spans` parses it with `tree_sitter_rust::LANGUAGE`.
Rust,
// Python source; `code_spans` parses it with `tree_sitter_python::LANGUAGE`.
Python,
// Any other language name; `code_spans` rejects it via `ChunkError::unsupported_language`.
Other(&'static str),
}
/// Chunker for source code.
///
/// `try_split_chunks` collects syntax-level spans with `code_spans` and then
/// combines them with `merge_structure_spans` using the two limits below.
#[cfg(feature = "code")]
#[derive(Debug, Clone)]
pub struct CodeChunker {
// Language whose grammar is used to parse the input.
language: CodeLanguage,
// Size limit for a span or merged chunk, in the unit reported by `crate::char_len`.
chunk_size: usize,
// Largest size a trailing block may have and still be carried over into the next chunk.
chunk_overlap: usize,
}
#[cfg(feature = "code")]
impl CodeChunker {
    /// Creates a code chunker for `language` with the given size and overlap limits.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_chunk_config` when it rejects
    /// the `chunk_size` / `chunk_overlap` pair.
    pub fn new(
        language: CodeLanguage,
        chunk_size: usize,
        chunk_overlap: usize,
    ) -> Result<Self, ChunkError> {
        validate_chunk_config(chunk_size, chunk_overlap)?;
        let chunker = Self {
            language,
            chunk_size,
            chunk_overlap,
        };
        Ok(chunker)
    }

    /// Splits `input` into chunks that borrow from it.
    ///
    /// # Errors
    ///
    /// Propagates any `ChunkError` returned by `code_spans`.
    pub fn try_split_chunks<'a>(&self, input: &'a str) -> Result<Vec<TextChunk<'a>>, ChunkError> {
        let units = code_spans(input, self.language, self.chunk_size)?;
        let merged = merge_structure_spans(input, &units, self.chunk_size, self.chunk_overlap);
        let chunks = chunks_from_spans(input, merged, &crate::char_len);
        Ok(chunks)
    }

    /// Same as [`Self::try_split_chunks`], but hands the chunks back as an owning iterator.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`Self::try_split_chunks`].
    pub fn try_chunks<'a>(
        &self,
        input: &'a str,
    ) -> Result<std::vec::IntoIter<TextChunk<'a>>, ChunkError> {
        let chunks = self.try_split_chunks(input)?;
        Ok(chunks.into_iter())
    }

    /// Same as [`Self::try_split_chunks`], but returns each chunk's text as an owned `String`.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`Self::try_split_chunks`].
    pub fn try_split_text(&self, input: &str) -> Result<Vec<String>, ChunkError> {
        let chunks = self.try_split_chunks(input)?;
        let mut texts = Vec::with_capacity(chunks.len());
        for chunk in chunks {
            texts.push(chunk.text.to_string());
        }
        Ok(texts)
    }
}
/// Parses `input` with tree-sitter and returns the spans of its semantic units.
///
/// The spans are gathered by `collect_code_spans`. When the parse tree contains
/// no semantic unit at all (for example empty input, or a file made only of
/// imports and top-level statements), the whole trimmed input is returned as a
/// single span if it fits in `chunk_size`; input that trims to nothing yields
/// an empty list.
///
/// # Errors
///
/// * `ChunkError::unsupported_language` for `CodeLanguage::Other`.
/// * `ChunkError::ParseFailure` when the grammar cannot be loaded, no tree is
///   produced, or the tree contains syntax errors.
/// * `ChunkError::OversizedSemanticUnit` when a unit cannot be reduced below
///   `chunk_size`, including structure-less input that is itself too large.
#[cfg(feature = "code")]
fn code_spans(
    input: &str,
    language: CodeLanguage,
    chunk_size: usize,
) -> Result<Vec<TextSpan>, ChunkError> {
    use tree_sitter::Parser;

    let mut parser = Parser::new();
    let language_fn = match language {
        CodeLanguage::Rust => tree_sitter_rust::LANGUAGE,
        CodeLanguage::Python => tree_sitter_python::LANGUAGE,
        CodeLanguage::Other(language) => return Err(ChunkError::unsupported_language(language)),
    };
    parser
        .set_language(&language_fn.into())
        .map_err(|err| ChunkError::ParseFailure {
            message: err.to_string(),
        })?;
    let tree = parser
        .parse(input, None)
        .ok_or_else(|| ChunkError::ParseFailure {
            message: "tree-sitter returned no parse tree".to_string(),
        })?;
    if tree.root_node().has_error() {
        return Err(ChunkError::ParseFailure {
            message: "tree-sitter parse contains syntax errors".to_string(),
        });
    }

    let mut cursor = tree.root_node().walk();
    let mut spans = Vec::new();
    collect_code_spans(input, language, &mut cursor, chunk_size, &mut spans)?;

    if spans.is_empty() {
        // No semantic unit was found. Treat the whole input as one unit instead of
        // reporting small or empty input as "oversized".
        let Some(whole) = TextSpan::new(0, input.len()).trim(input) else {
            return Ok(spans);
        };
        let measured = crate::char_len(whole.text(input));
        if measured > chunk_size {
            return Err(ChunkError::OversizedSemanticUnit {
                measured,
                limit: chunk_size,
            });
        }
        spans.push(whole);
    }
    Ok(spans)
}
/// Walks the node under `cursor` and its following siblings, appending the
/// spans of semantic units (see `is_semantic_code_node`) to `spans`.
///
/// A semantic node is emitted, together with the comment lines attached by
/// `with_attached_comments`, when it fits in `chunk_size`. A node that is too
/// large is searched for smaller semantic units in its subtree instead.
/// Non-semantic nodes are always descended into.
///
/// On success the cursor is left on the last sibling of the level it started on.
///
/// # Errors
///
/// Returns `ChunkError::OversizedSemanticUnit` when a semantic node exceeds
/// `chunk_size` and its subtree yields no smaller unit. Such a node would
/// otherwise vanish from the output without any signal.
#[cfg(feature = "code")]
fn collect_code_spans(
    input: &str,
    language: CodeLanguage,
    cursor: &mut tree_sitter::TreeCursor<'_>,
    chunk_size: usize,
    spans: &mut Vec<TextSpan>,
) -> Result<(), ChunkError> {
    loop {
        let node = cursor.node();
        if is_semantic_code_node(language, node.kind()) {
            let span =
                with_attached_comments(input, TextSpan::new(node.start_byte(), node.end_byte()));
            let measured = crate::char_len(span.text(input));
            if measured <= chunk_size {
                spans.push(span);
            } else {
                let emitted_before = spans.len();
                if cursor.goto_first_child() {
                    collect_code_spans(input, language, cursor, chunk_size, spans)?;
                    cursor.goto_parent();
                }
                // Nothing smaller was found inside: the unit cannot be represented
                // within the limit, so report it rather than dropping it.
                if spans.len() == emitted_before {
                    return Err(ChunkError::OversizedSemanticUnit {
                        measured,
                        limit: chunk_size,
                    });
                }
            }
        } else if cursor.goto_first_child() {
            collect_code_spans(input, language, cursor, chunk_size, spans)?;
            cursor.goto_parent();
        }
        if !cursor.goto_next_sibling() {
            break;
        }
    }
    Ok(())
}
/// Reports whether a tree-sitter node `kind` is a unit that should be kept
/// together as one span for `language`.
///
/// * Rust: modules, structs, enums, traits, impl blocks and functions.
/// * Python: classes and functions, plus `decorated_definition` so that
///   decorators stay attached to the definition they modify.
/// * `CodeLanguage::Other`: nothing is considered semantic.
#[cfg(feature = "code")]
fn is_semantic_code_node(language: CodeLanguage, kind: &str) -> bool {
    match language {
        CodeLanguage::Rust => matches!(
            kind,
            "mod_item"
                | "struct_item"
                | "enum_item"
                | "trait_item"
                | "impl_item"
                | "function_item"
        ),
        CodeLanguage::Python => matches!(
            kind,
            "decorated_definition" | "class_definition" | "function_definition"
        ),
        CodeLanguage::Other(_) => false,
    }
}
/// Extends `span` backwards over the comment-like lines directly above it.
///
/// Starting at `span.start`, the text before the span is examined one line at a
/// time, moving upwards. A line is absorbed when, after trimming, it is empty
/// or starts with `//` (which covers `///` and `//!`) or `#`. The walk stops at
/// the first line that is anything else, or at the start of the input. The end
/// of the span is unchanged.
///
/// `span.start` is expected to lie on a character boundary of `input`.
#[cfg(feature = "code")]
fn with_attached_comments(input: &str, span: TextSpan) -> TextSpan {
    let mut start = span.start;
    while start > 0 {
        let head = &input[..start];
        // Step back over one whole character before looking for the previous
        // newline. Subtracting a byte could land inside a multi-byte character
        // and make the slice panic.
        let Some((last_char_at, _)) = head.char_indices().next_back() else {
            break;
        };
        let line_start = head[..last_char_at].rfind('\n').map_or(0, |idx| idx + 1);
        let line = head[line_start..].trim();
        if line.is_empty() || line.starts_with("//") || line.starts_with('#') {
            start = line_start;
        } else {
            break;
        }
    }
    TextSpan::new(start, span.end)
}