use crate::error::{Result, VecStoreError};
/// Common interface for strategies that break a text into smaller chunks.
pub trait TextSplitter {
    /// Splits `text` into chunk strings.
    ///
    /// # Errors
    ///
    /// Implementations decide when splitting fails; the splitters in this
    /// file report invalid size parameters this way.
    fn split_text(&self, text: &str) -> Result<Vec<String>>;

    /// Splits `text` and wraps every chunk in a [`TextChunk`].
    ///
    /// `index` is the chunk's position in the output of
    /// [`split_text`](TextSplitter::split_text). This default implementation
    /// does not locate the chunks in the source text: `char_start` and
    /// `char_end` are always `0`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `split_text`.
    fn split_with_metadata(&self, text: &str) -> Result<Vec<TextChunk>> {
        let pieces = self.split_text(text)?;
        let mut annotated = Vec::with_capacity(pieces.len());
        for (index, content) in pieces.into_iter().enumerate() {
            annotated.push(TextChunk {
                index,
                content,
                // Offsets are placeholders in the default implementation.
                char_start: 0,
                char_end: 0,
            });
        }
        Ok(annotated)
    }
}
/// A chunk of text together with its position in a splitter's output.
///
/// All fields are plain `usize`/`String` values, so the type is also `Eq` and
/// `Hash` and can be used in sets or as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TextChunk {
    /// Position of the chunk in the sequence of chunks it came from.
    pub index: usize,
    /// The chunk's text.
    pub content: String,
    /// Start offset of the chunk in the source text.
    ///
    /// NOTE(review): the unit (chars vs. bytes) is only suggested by the
    /// field name; confirm against whatever fills it in.
    pub char_start: usize,
    /// End offset of the chunk in the source text.
    ///
    /// NOTE(review): whether this bound is inclusive or exclusive is not
    /// established by this definition.
    pub char_end: usize,
}
/// Splits text by trying a list of separators from coarsest to finest.
///
/// Parts produced by one separator are merged greedily up to `chunk_size`;
/// a part that is still too large is split again with the next separator,
/// and finally cut at a fixed width when no separator is left.
///
/// Size units: merge decisions compare `str::len` (bytes) against
/// `chunk_size`, while the fixed-width fallback and the overlap count
/// `char`s. The two agree for ASCII text.
#[derive(Debug, Clone)]
pub struct RecursiveCharacterTextSplitter {
    /// Maximum size of a chunk before the overlap prefix is added.
    chunk_size: usize,
    /// Number of trailing characters of the previous chunk that are repeated
    /// at the start of the next chunk.
    chunk_overlap: usize,
    /// Separators in priority order; an empty string selects the fixed-width
    /// fallback.
    separators: Vec<String>,
}

impl RecursiveCharacterTextSplitter {
    /// Creates a splitter with the default separator list: paragraph break,
    /// line break, sentence endings (`". "`, `"! "`, `"? "`), space, and
    /// finally the empty string (fixed-width split).
    ///
    /// The parameters are not validated here; the `TextSplitter::split_text`
    /// implementation rejects `chunk_size == 0` and
    /// `chunk_overlap >= chunk_size`.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        let separators = ["\n\n", "\n", ". ", "! ", "? ", " ", ""]
            .iter()
            .map(|sep| sep.to_string())
            .collect();
        Self {
            chunk_size,
            chunk_overlap,
            separators,
        }
    }

    /// Replaces the separator list (highest priority first).
    pub fn with_separators(mut self, separators: Vec<String>) -> Self {
        self.separators = separators;
        self
    }

    /// Splits `text` with `separators` and then adds the overlap prefix.
    ///
    /// The overlap is applied exactly once, after the whole recursion has
    /// finished. Applying it inside every recursion level (as an earlier
    /// version did) stacks the prefix once per separator level, producing
    /// chunks such as `"nger nger nger text ..."`.
    fn split_recursive(&self, text: &str, separators: &[String]) -> Vec<String> {
        let pieces = self.split_without_overlap(text, separators);
        self.add_overlap(pieces)
    }

    /// Recursive worker: returns consecutive, non-overlapping pieces whose
    /// concatenation is `text`.
    fn split_without_overlap(&self, text: &str, separators: &[String]) -> Vec<String> {
        if text.len() <= self.chunk_size {
            return vec![text.to_string()];
        }

        // No separator left, or the explicit "" separator: cut at fixed width.
        let (sep, finer_separators) = match separators.split_first() {
            Some((first, rest)) if !first.is_empty() => (first.as_str(), rest),
            _ => return self.split_by_chars(text),
        };

        let mut pieces = Vec::new();
        let mut current = String::new();

        // `split_inclusive` keeps the separator attached to the part that
        // precedes it, so no text is lost when parts are re-joined.
        for part in text.split_inclusive(sep) {
            if part.len() > self.chunk_size {
                // Too large even on its own: flush what we have and retry the
                // part with the finer separators.
                if !current.is_empty() {
                    pieces.push(std::mem::take(&mut current));
                }
                pieces.extend(self.split_without_overlap(part, finer_separators));
            } else if current.len() + part.len() <= self.chunk_size {
                current.push_str(part);
            } else {
                if !current.is_empty() {
                    pieces.push(std::mem::take(&mut current));
                }
                current.push_str(part);
            }
        }

        if !current.is_empty() {
            pieces.push(current);
        }
        pieces
    }

    /// Cuts `text` into consecutive pieces of at most `chunk_size` characters.
    ///
    /// The pieces do not overlap; the overlap is added later by
    /// `add_overlap`, together with all other pieces. A `chunk_size` of zero
    /// is treated as one so the function always makes progress.
    fn split_by_chars(&self, text: &str) -> Vec<String> {
        let width = self.chunk_size.max(1);
        let chars: Vec<char> = text.chars().collect();
        chars
            .chunks(width)
            .map(|window| window.iter().collect())
            .collect()
    }

    /// Prefixes every chunk except the first with the last `chunk_overlap`
    /// characters of the chunk that precedes it in `chunks`.
    ///
    /// Returns `chunks` unchanged when the overlap is zero or there are fewer
    /// than two chunks.
    fn add_overlap(&self, chunks: Vec<String>) -> Vec<String> {
        if self.chunk_overlap == 0 || chunks.len() <= 1 {
            return chunks;
        }

        let mut result = Vec::with_capacity(chunks.len());
        result.push(chunks[0].clone());

        for pair in chunks.windows(2) {
            let (previous, chunk) = (&pair[0], &pair[1]);
            // Byte index of the `chunk_overlap`-th character from the end, or
            // 0 when the previous chunk is shorter than the overlap.
            let tail_start = previous
                .char_indices()
                .rev()
                .nth(self.chunk_overlap - 1)
                .map_or(0, |(idx, _)| idx);

            let mut overlapped =
                String::with_capacity(previous.len() - tail_start + chunk.len());
            overlapped.push_str(&previous[tail_start..]);
            overlapped.push_str(chunk);
            result.push(overlapped);
        }
        result
    }
}
impl TextSplitter for RecursiveCharacterTextSplitter {
    /// Validates the configuration and splits `text` with the configured
    /// separators.
    ///
    /// Empty input yields an empty list without validating the parameters.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error when `chunk_size` is zero or when
    /// `chunk_overlap` is not smaller than `chunk_size`.
    fn split_text(&self, text: &str) -> Result<Vec<String>> {
        if text.is_empty() {
            return Ok(Vec::new());
        }

        match (self.chunk_size, self.chunk_overlap) {
            (0, _) => Err(VecStoreError::invalid_parameter(
                "chunk_size",
                "must be greater than 0",
            )),
            (size, overlap) if overlap >= size => Err(VecStoreError::invalid_parameter(
                "chunk_overlap",
                "must be less than chunk_size",
            )),
            _ => Ok(self.split_recursive(text, &self.separators)),
        }
    }
}
/// Splitter whose limits are expressed in tokens rather than characters.
///
/// Token counts are converted to character counts with a fixed
/// characters-per-token ratio; no tokenizer is involved.
#[derive(Debug, Clone)]
pub struct TokenTextSplitter {
    /// Maximum number of tokens per chunk.
    max_tokens: usize,
    /// Number of tokens shared between consecutive chunks.
    token_overlap: usize,
    /// Characters assumed per token when converting the limits (default 4).
    chars_per_token: usize,
}

impl TokenTextSplitter {
    /// Creates a splitter that assumes 4 characters per token.
    pub fn new(max_tokens: usize, token_overlap: usize) -> Self {
        Self {
            max_tokens,
            token_overlap,
            chars_per_token: 4,
        }
    }

    /// Overrides the characters-per-token ratio used for the conversion.
    pub fn with_chars_per_token(self, chars_per_token: usize) -> Self {
        Self {
            chars_per_token,
            ..self
        }
    }
}
impl TextSplitter for TokenTextSplitter {
    /// Converts the token limits to character limits and delegates to a
    /// [`RecursiveCharacterTextSplitter`].
    ///
    /// Empty input yields an empty list.
    ///
    /// # Errors
    ///
    /// Propagates the character splitter's validation errors, e.g. when the
    /// converted chunk size is zero or the converted overlap is not smaller
    /// than the converted chunk size.
    fn split_text(&self, text: &str) -> Result<Vec<String>> {
        if text.is_empty() {
            return Ok(vec![]);
        }

        // Saturate instead of overflowing: a plain `*` panics in debug builds
        // and silently wraps to a tiny limit in release builds.
        let chunk_size = self.max_tokens.saturating_mul(self.chars_per_token);
        let chunk_overlap = self.token_overlap.saturating_mul(self.chars_per_token);

        RecursiveCharacterTextSplitter::new(chunk_size, chunk_overlap).split_text(text)
    }
}
/// Splitter that cuts Markdown documents at ATX headers (`#` to `######`).
#[derive(Debug, Clone)]
pub struct MarkdownTextSplitter {
    /// Maximum chunk size.
    chunk_size: usize,
    /// Overlap handed to the character splitter used for oversized sections.
    chunk_overlap: usize,
    /// Whether chunks are prefixed with the chain of enclosing headers.
    preserve_headers: bool,
}

impl MarkdownTextSplitter {
    /// Creates a splitter with header preservation turned off.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            preserve_headers: false,
        }
    }

    /// Enables or disables prefixing chunks with their header chain.
    pub fn with_preserve_headers(mut self, preserve: bool) -> Self {
        self.preserve_headers = preserve;
        self
    }

    /// Breaks `text` into sections, one per header, plus an optional leading
    /// section (level 0, empty header) for content before the first header.
    ///
    /// Lines are joined with `'\n'`; blank lines at the very start of a
    /// section's content are dropped. Lines inside fenced code blocks are
    /// always content, so a `# comment` in a code sample does not start a new
    /// section.
    fn parse_sections(&self, text: &str) -> Vec<MarkdownSection> {
        let mut sections = Vec::new();
        let mut current = MarkdownSection {
            level: 0,
            header: String::new(),
            content: String::new(),
            header_chain: Vec::new(),
        };
        // Headers currently "open", outermost first, as (level, text).
        let mut header_stack: Vec<(usize, String)> = Vec::new();
        // Marker of the code fence we are inside of, if any.
        let mut open_fence: Option<&'static str> = None;

        for line in text.lines() {
            let fence_here = Self::fence_marker(line);
            // The fence lines themselves are content as well.
            let inside_code = open_fence.is_some() || fence_here.is_some();
            match (open_fence, fence_here) {
                (None, Some(marker)) => open_fence = Some(marker),
                (Some(open), Some(marker)) if open == marker => open_fence = None,
                _ => {}
            }

            let header_level = if inside_code {
                None
            } else {
                self.parse_header_level(line)
            };

            match header_level {
                Some(level) => {
                    if !current.content.is_empty() || !current.header.is_empty() {
                        sections.push(current);
                    }
                    // Strip indentation first; otherwise the `#` run is not at
                    // the start of the string and would stay in the header.
                    let header_text = line
                        .trim_start()
                        .trim_start_matches('#')
                        .trim()
                        .to_string();
                    // A new header closes every open header of the same or a
                    // deeper level.
                    header_stack.retain(|(open_level, _)| *open_level < level);
                    header_stack.push((level, header_text.clone()));
                    current = MarkdownSection {
                        level,
                        header: header_text,
                        content: String::new(),
                        header_chain: header_stack.iter().map(|(_, h)| h.clone()).collect(),
                    };
                }
                None => {
                    if !current.content.is_empty() {
                        current.content.push('\n');
                    }
                    current.content.push_str(line);
                }
            }
        }

        if !current.content.is_empty() || !current.header.is_empty() {
            sections.push(current);
        }
        sections
    }

    /// Returns the header level (1-6) when `line` is an ATX header.
    ///
    /// Leading whitespace is ignored. The run of `#` must be one to six
    /// characters long and be followed by whitespace or the end of the line,
    /// so `#hashtag` and `####### seven` are not headers.
    fn parse_header_level(&self, line: &str) -> Option<usize> {
        let trimmed = line.trim_start();
        let level = trimmed.chars().take_while(|&c| c == '#').count();
        if !(1..=6).contains(&level) {
            return None;
        }
        // '#' is one byte, so `level` is also the byte offset past the run.
        match trimmed[level..].chars().next() {
            None => Some(level),
            Some(c) if c.is_whitespace() => Some(level),
            Some(_) => None,
        }
    }

    /// Returns the fence marker when `line` opens or closes a fenced code
    /// block, i.e. starts (after indentation) with three backticks or three
    /// tildes.
    fn fence_marker(line: &str) -> Option<&'static str> {
        let trimmed = line.trim_start();
        ["```", "~~~"]
            .iter()
            .copied()
            .find(|marker| trimmed.starts_with(*marker))
    }
}

/// One header-delimited part of a Markdown document.
#[derive(Debug, Clone)]
struct MarkdownSection {
    /// Header level (1-6), or 0 for content before the first header.
    level: usize,
    /// Header text without the `#` prefix; empty for the leading section.
    header: String,
    /// Lines between this header and the next one, joined with `'\n'`.
    content: String,
    /// Texts of the enclosing headers, outermost first, ending with this
    /// section's own header.
    header_chain: Vec<String>,
}
impl TextSplitter for MarkdownTextSplitter {
    /// Splits a Markdown document into chunks along its header structure.
    ///
    /// Consecutive sections are packed into one chunk while they fit into
    /// `chunk_size` (byte length). A section that is too large on its own is
    /// handed to a [`RecursiveCharacterTextSplitter`]. With
    /// `preserve_headers`, each section (and each piece of an oversized
    /// section) is prefixed with its header chain, rendered with one `#` per
    /// chain position.
    ///
    /// NOTE(review): the header chain ends with the section's own header and
    /// the section text starts with that header too, so with
    /// `preserve_headers` the header appears twice. Confirm whether that is
    /// intended.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error when `chunk_size` is zero, and
    /// propagates errors from the character splitter used for oversized
    /// sections (its chunk size is `chunk_size` minus the header-context
    /// length, which can be zero or not larger than `chunk_overlap`).
    fn split_text(&self, text: &str) -> Result<Vec<String>> {
        if text.is_empty() {
            return Ok(vec![]);
        }
        if self.chunk_size == 0 {
            return Err(VecStoreError::invalid_parameter(
                "chunk_size",
                "must be greater than 0",
            ));
        }

        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        // Only ever filled when `preserve_headers` is set, so it can be used
        // unconditionally as a (possibly empty) prefix below.
        let mut header_context = String::new();

        for section in self.parse_sections(text) {
            if self.preserve_headers && !section.header_chain.is_empty() {
                header_context = section
                    .header_chain
                    .iter()
                    .enumerate()
                    .map(|(depth, header)| format!("{} {}", "#".repeat(depth + 1), header))
                    .collect::<Vec<_>>()
                    .join("\n");
                header_context.push_str("\n\n");
            }

            let section_text = if section.header.is_empty() {
                section.content
            } else {
                format!(
                    "{} {}\n\n{}",
                    "#".repeat(section.level),
                    section.header,
                    section.content
                )
            };

            // Section content carries no trailing newline, so without a
            // separator the next header would be glued to the previous line
            // ("text## Next") and stop being a header.
            let separator = if current_chunk.is_empty() || current_chunk.ends_with('\n') {
                ""
            } else {
                "\n"
            };
            let candidate = format!(
                "{}{}{}{}",
                current_chunk, separator, header_context, section_text
            );

            if candidate.len() <= self.chunk_size {
                current_chunk = candidate;
                continue;
            }

            // The section does not fit: close the chunk built so far.
            if !current_chunk.is_empty() {
                chunks.push(current_chunk.trim().to_string());
            }

            if section_text.len() > self.chunk_size {
                // Leave room for the header prefix added to every piece.
                let splitter = RecursiveCharacterTextSplitter::new(
                    self.chunk_size.saturating_sub(header_context.len()),
                    self.chunk_overlap,
                );
                for sub_chunk in splitter.split_text(&section_text)? {
                    if header_context.is_empty() {
                        chunks.push(sub_chunk);
                    } else {
                        chunks.push(format!("{}{}", header_context, sub_chunk));
                    }
                }
                current_chunk = String::new();
            } else {
                current_chunk = format!("{}{}", header_context, section_text);
            }
        }

        if !current_chunk.is_empty() {
            chunks.push(current_chunk.trim().to_string());
        }
        Ok(chunks)
    }
}
/// Splitter for source code that prefers to cut at definition boundaries.
#[derive(Debug, Clone)]
pub struct CodeTextSplitter {
    /// Maximum chunk size.
    chunk_size: usize,
    /// Overlap handed to the character splitter.
    chunk_overlap: usize,
    /// Language name selecting the block-start heuristics; `None` disables
    /// block detection in `split_text`.
    language: Option<String>,
}

impl CodeTextSplitter {
    /// Creates a splitter with no language set.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            language: None,
        }
    }

    /// Sets the language used to recognise block starts.
    ///
    /// Recognised names are `rust`, `python`, `javascript`, `typescript`,
    /// `java`, `c`, `cpp` and `go`; any other name uses a small generic set
    /// of keywords.
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Heuristically decides whether `line` starts a new top-level construct
    /// (function, type, class, ...) in the configured language.
    ///
    /// Leading whitespace is ignored. The checks are plain prefix/substring
    /// tests, not a parser.
    fn is_code_block_start(&self, line: &str) -> bool {
        let trimmed = line.trim_start();
        match self.language.as_deref() {
            Some("rust") => Self::starts_with_any(
                trimmed,
                &[
                    "fn ",
                    "pub fn ",
                    "struct ",
                    "pub struct ",
                    "enum ",
                    "pub enum ",
                    "impl ",
                    "trait ",
                    // Public traits are block starts just like public
                    // functions, structs and enums.
                    "pub trait ",
                ],
            ),
            Some("python") => Self::starts_with_any(trimmed, &["def ", "class ", "async def "]),
            Some("javascript") | Some("typescript") => Self::starts_with_any(
                trimmed,
                &[
                    "function ",
                    "class ",
                    "const ",
                    "let ",
                    "async function ",
                    "export ",
                ],
            ),
            Some("java") | Some("c") | Some("cpp") => {
                // Something that looks like a signature: parentheses plus one
                // of a few common keywords anywhere in the line.
                let has_parens = trimmed.contains('(') && trimmed.contains(')');
                let has_keyword = ["public", "private", "void", "int"]
                    .iter()
                    .any(|keyword| trimmed.contains(*keyword));
                (has_parens && has_keyword) || trimmed.starts_with("class ")
            }
            Some("go") => Self::starts_with_any(trimmed, &["func ", "type ", "struct "]),
            _ => Self::starts_with_any(trimmed, &["fn ", "function ", "def ", "class "]),
        }
    }

    /// Separators for the character splitter, coarsest first: blank line,
    /// closing brace on its own line, line break, `"; "`, space, and the
    /// empty string (fixed-width split).
    ///
    /// Each separator appears once; a second `"\n\n"` entry could never
    /// split anything the first one had not already split.
    fn get_separators(&self) -> Vec<String> {
        ["\n\n", "\n}\n", "\n", "; ", " ", ""]
            .iter()
            .map(|sep| sep.to_string())
            .collect()
    }

    /// Returns `true` when `line` starts with any of `prefixes`.
    fn starts_with_any(line: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|prefix| line.starts_with(*prefix))
    }
}
impl TextSplitter for CodeTextSplitter {
    /// Splits source code into chunks.
    ///
    /// Without a language, the text goes straight to a
    /// [`RecursiveCharacterTextSplitter`] configured with the code
    /// separators. With a language, lines are grouped into blocks that start
    /// at lines recognised by `is_code_block_start`; blocks are packed into
    /// chunks of at most `chunk_size` bytes, and a block that grows beyond
    /// `chunk_size` is split by the character splitter. Every line is
    /// re-emitted with a trailing `'\n'`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error when `chunk_size` is zero and
    /// propagates errors from the character splitter.
    fn split_text(&self, text: &str) -> Result<Vec<String>> {
        // Moves `block` into `chunk`, closing `chunk` first when the two do
        // not fit into `limit` together. Leaves `block` empty.
        fn absorb_block(
            chunks: &mut Vec<String>,
            chunk: &mut String,
            block: &mut String,
            limit: usize,
        ) {
            if !chunk.is_empty() && chunk.len() + block.len() > limit {
                chunks.push(std::mem::take(chunk));
            }
            chunk.push_str(block);
            block.clear();
        }

        if text.is_empty() {
            return Ok(vec![]);
        }
        if self.chunk_size == 0 {
            return Err(VecStoreError::invalid_parameter(
                "chunk_size",
                "must be greater than 0",
            ));
        }

        let splitter = RecursiveCharacterTextSplitter::new(self.chunk_size, self.chunk_overlap)
            .with_separators(self.get_separators());

        if self.language.is_none() {
            return splitter.split_text(text);
        }

        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut current_block = String::new();

        for line in text.lines() {
            // A new definition starts: the block collected so far is complete.
            if self.is_code_block_start(line) && !current_block.is_empty() {
                absorb_block(
                    &mut chunks,
                    &mut current_chunk,
                    &mut current_block,
                    self.chunk_size,
                );
            }

            current_block.push_str(line);
            current_block.push('\n');

            // A single block larger than a chunk is split on its own.
            if current_block.len() > self.chunk_size {
                if !current_chunk.is_empty() {
                    chunks.push(std::mem::take(&mut current_chunk));
                }
                chunks.extend(splitter.split_text(&current_block)?);
                current_block.clear();
            }
        }

        // Flush the trailing block with the same size check as inside the
        // loop, so the last chunk cannot grow past `chunk_size` either.
        if !current_block.is_empty() {
            absorb_block(
                &mut chunks,
                &mut current_chunk,
                &mut current_block,
                self.chunk_size,
            );
        }
        if !current_chunk.is_empty() {
            chunks.push(current_chunk);
        }
        Ok(chunks)
    }
}
/// Source of embedding vectors for pieces of text.
pub trait Embedder {
    /// Returns the embedding vector for `text`, or an error if the
    /// implementation could not produce one.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;
}
/// Splitter that groups consecutive sentences by embedding similarity.
pub struct SemanticTextSplitter {
    /// Produces the embedding of each sentence.
    embedder: Box<dyn Embedder>,
    /// Size above which the current chunk is closed before adding a sentence.
    max_chunk_size: usize,
    /// Size a chunk must reach before it may be closed.
    min_chunk_size: usize,
    /// Minimum cosine similarity for a sentence to join the current chunk
    /// (default 0.7).
    similarity_threshold: f32,
}

impl SemanticTextSplitter {
    /// Creates a splitter with a similarity threshold of 0.7.
    pub fn new(embedder: Box<dyn Embedder>, max_chunk_size: usize, min_chunk_size: usize) -> Self {
        Self {
            embedder,
            max_chunk_size,
            min_chunk_size,
            similarity_threshold: 0.7,
        }
    }

    /// Sets the similarity threshold, clamped to the range `0.0..=1.0`.
    pub fn with_similarity_threshold(mut self, threshold: f32) -> Self {
        self.similarity_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns `0.0` when the slices differ in length or when either vector
    /// has zero norm.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let dot_product: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Splits `text` into trimmed sentences at `.`, `!` and `?`.
    ///
    /// The terminating punctuation (including runs such as `...` or `?!`)
    /// stays attached to its sentence, so no characters of the text are lost.
    /// Fragments consisting only of whitespace and terminators are skipped.
    /// This is a character-level heuristic: abbreviations and decimal numbers
    /// are split as well.
    fn split_sentences(&self, text: &str) -> Vec<String> {
        const TERMINATORS: [char; 3] = ['.', '!', '?'];

        let mut sentences = Vec::new();
        let mut rest = text;

        while !rest.is_empty() {
            // The sentence body runs up to the first terminator (or the end).
            let body_end = rest.find(&TERMINATORS[..]).unwrap_or(rest.len());
            // Then comes the whole run of terminators that follows the body.
            let after_body = &rest[body_end..];
            let run_len = after_body.len() - after_body.trim_start_matches(&TERMINATORS[..]).len();

            let (sentence, tail) = rest.split_at(body_end + run_len);
            if !rest[..body_end].trim().is_empty() {
                sentences.push(sentence.trim().to_string());
            }
            rest = tail;
        }
        sentences
    }
}
impl TextSplitter for SemanticTextSplitter {
    /// Groups consecutive sentences into chunks by embedding similarity.
    ///
    /// Every sentence is embedded once. A sentence joins the current chunk
    /// when its cosine similarity to the chunk's running embedding reaches
    /// the threshold; otherwise the chunk is closed and the sentence starts a
    /// new one. The running embedding is updated by averaging it with each
    /// added sentence's embedding. Sentences inside a chunk are joined with a
    /// single space.
    ///
    /// Size rules (byte lengths):
    /// * a chunk that has reached `min_chunk_size` is closed before a
    ///   sentence that would push it past `max_chunk_size`;
    /// * a chunk below `min_chunk_size` is never closed or discarded — it
    ///   keeps absorbing sentences even if they are dissimilar or push it
    ///   past `max_chunk_size`, so no text is dropped;
    /// * a trailing chunk below `min_chunk_size` is appended to the previous
    ///   chunk.
    ///
    /// If no chunk could be produced at all, the whole text is split with a
    /// [`RecursiveCharacterTextSplitter`] using `max_chunk_size` and an
    /// overlap of `min_chunk_size / 2`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error when `max_chunk_size` is zero, and
    /// propagates errors from the embedder and from the fallback splitter.
    fn split_text(&self, text: &str) -> Result<Vec<String>> {
        if text.is_empty() {
            return Ok(vec![]);
        }
        if self.max_chunk_size == 0 {
            return Err(VecStoreError::invalid_parameter(
                "max_chunk_size",
                "must be greater than 0",
            ));
        }

        let sentences = self.split_sentences(text);
        if sentences.is_empty() {
            return Ok(vec![]);
        }

        let mut sentence_embeddings = Vec::with_capacity(sentences.len());
        for sentence in &sentences {
            sentence_embeddings.push(self.embedder.embed(sentence)?);
        }

        let mut chunks: Vec<String> = Vec::new();
        let mut current_chunk = String::new();
        // `Some` exactly while `current_chunk` is non-empty.
        let mut current_embedding: Option<Vec<f32>> = None;

        for (sentence, embedding) in sentences.iter().zip(&sentence_embeddings) {
            // Length of the chunk if this sentence (plus the joining space)
            // were appended.
            let joined_len =
                current_chunk.len() + sentence.len() + usize::from(!current_chunk.is_empty());
            if joined_len > self.max_chunk_size
                && !current_chunk.is_empty()
                && current_chunk.len() >= self.min_chunk_size
            {
                chunks.push(std::mem::take(&mut current_chunk));
                current_embedding = None;
            }

            let similar = match current_embedding {
                Some(ref chunk_emb) => {
                    self.cosine_similarity(chunk_emb, embedding) >= self.similarity_threshold
                }
                None => true,
            };
            // Too small to stand on its own: keep accumulating rather than
            // throwing the text away.
            let below_minimum = current_chunk.len() < self.min_chunk_size;

            if similar || below_minimum {
                if !current_chunk.is_empty() {
                    current_chunk.push(' ');
                }
                current_chunk.push_str(sentence);

                if let Some(ref mut chunk_emb) = current_embedding {
                    // `zip` stops at the shorter vector, so embeddings of
                    // different lengths cannot cause an out-of-bounds index.
                    for (acc, val) in chunk_emb.iter_mut().zip(embedding) {
                        *acc = (*acc + val) / 2.0;
                    }
                } else {
                    current_embedding = Some(embedding.clone());
                }
            } else {
                chunks.push(std::mem::replace(&mut current_chunk, sentence.clone()));
                current_embedding = Some(embedding.clone());
            }
        }

        if !current_chunk.is_empty() {
            if current_chunk.len() >= self.min_chunk_size {
                chunks.push(current_chunk);
            } else if let Some(last) = chunks.last_mut() {
                // Keep a short tail by attaching it to the previous chunk.
                last.push(' ');
                last.push_str(&current_chunk);
            }
        }

        if chunks.is_empty() {
            let fallback =
                RecursiveCharacterTextSplitter::new(self.max_chunk_size, self.min_chunk_size / 2);
            return fallback.split_text(text);
        }
        Ok(chunks)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- RecursiveCharacterTextSplitter ----

    /// Text shorter than the chunk size comes back as one unchanged chunk.
    #[test]
    fn test_recursive_splitter_basic() {
        let input = "Short text.";
        let pieces = RecursiveCharacterTextSplitter::new(20, 0)
            .split_text(input)
            .unwrap();
        assert_eq!(pieces.len(), 1);
        assert_eq!(pieces[0], input);
    }

    /// Two paragraphs produce at least one chunk.
    #[test]
    fn test_recursive_splitter_paragraphs() {
        let pieces = RecursiveCharacterTextSplitter::new(50, 0)
            .split_text("First paragraph.\n\nSecond paragraph.")
            .unwrap();
        assert!(!pieces.is_empty());
    }

    /// Text longer than the chunk size is split into several chunks.
    #[test]
    fn test_recursive_splitter_overlap() {
        let pieces = RecursiveCharacterTextSplitter::new(20, 5)
            .split_text("This is a longer text that should be split into multiple chunks.")
            .unwrap();
        assert!(pieces.len() > 1);
    }

    /// The token splitter produces output for ordinary text.
    #[test]
    fn test_token_splitter() {
        let pieces = TokenTextSplitter::new(10, 2)
            .split_text("This is a test. This text should be split based on token count.")
            .unwrap();
        assert!(!pieces.is_empty());
    }

    /// Empty input yields no chunks.
    #[test]
    fn test_empty_text() {
        let pieces = RecursiveCharacterTextSplitter::new(100, 10)
            .split_text("")
            .unwrap();
        assert!(pieces.is_empty());
    }

    /// A chunk size of zero is rejected.
    #[test]
    fn test_invalid_chunk_size() {
        let outcome = RecursiveCharacterTextSplitter::new(0, 0).split_text("test");
        assert!(outcome.is_err());
    }

    /// An overlap equal to the chunk size is rejected.
    #[test]
    fn test_invalid_overlap() {
        let outcome = RecursiveCharacterTextSplitter::new(100, 100).split_text("test");
        assert!(outcome.is_err());
    }

    // ---- MarkdownTextSplitter ----

    /// A small document with two headers produces at least one chunk.
    #[test]
    fn test_markdown_splitter_basic() {
        let document = "# Header 1\n\nSome content here.\n\n## Header 2\n\nMore content.";
        let pieces = MarkdownTextSplitter::new(200, 20)
            .split_text(document)
            .unwrap();
        assert!(!pieces.is_empty());
    }

    /// Header preservation still produces at least one chunk.
    #[test]
    fn test_markdown_splitter_preserve_headers() {
        let document = "# Main\n\nContent 1\n\n## Section\n\nContent 2";
        let pieces = MarkdownTextSplitter::new(200, 20)
            .with_preserve_headers(true)
            .split_text(document)
            .unwrap();
        assert!(!pieces.is_empty());
    }

    /// Header levels are recognised for one to six `#` characters only.
    #[test]
    fn test_markdown_header_parsing() {
        let md = MarkdownTextSplitter::new(100, 10);
        let cases = [
            ("# H1", Some(1)),
            ("## H2", Some(2)),
            ("### H3", Some(3)),
            ("Not a header", None),
            ("####### Too many", None),
        ];
        for (line, expected) in cases.iter() {
            assert_eq!(md.parse_header_level(line), *expected);
        }
    }

    /// Header preservation is off unless requested.
    #[test]
    fn test_markdown_simple_by_default() {
        let md = MarkdownTextSplitter::new(500, 50);
        assert!(!md.preserve_headers);
    }

    // ---- CodeTextSplitter ----

    /// Without a language the code is still split into at least one chunk.
    #[test]
    fn test_code_splitter_basic() {
        let source = "fn main() {\n println!(\"Hello\");\n}\n\nfn test() {\n // test\n}";
        let pieces = CodeTextSplitter::new(200, 20).split_text(source).unwrap();
        assert!(!pieces.is_empty());
    }

    /// With a language set the code is split into at least one chunk.
    #[test]
    fn test_code_splitter_with_language() {
        let source =
            "fn main() {\n println!(\"Hello\");\n}\n\nfn test() {\n println!(\"Test\");\n}";
        let pieces = CodeTextSplitter::new(300, 30)
            .with_language("rust")
            .split_text(source)
            .unwrap();
        assert!(!pieces.is_empty());
    }

    /// Rust definitions are block starts; ordinary statements are not.
    #[test]
    fn test_code_block_detection() {
        let rust = CodeTextSplitter::new(100, 10).with_language("rust");
        for line in ["fn main() {", "pub fn test() {", "struct Foo {"].iter() {
            assert!(rust.is_code_block_start(line));
        }
        assert!(!rust.is_code_block_start(" let x = 5;"));
    }

    /// No language is configured unless requested.
    #[test]
    fn test_code_splitter_simple_by_default() {
        let code = CodeTextSplitter::new(500, 50);
        assert!(code.language.is_none());
    }

    // ---- SemanticTextSplitter ----

    /// Embedder whose output depends only on the byte length of the text.
    struct MockEmbedder;

    impl Embedder for MockEmbedder {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let scale = text.len() as f32;
            Ok(vec![scale / 100.0, scale / 50.0, scale / 25.0])
        }
    }

    /// Several sentences produce at least one chunk.
    #[test]
    fn test_semantic_splitter_basic() {
        let semantic = SemanticTextSplitter::new(Box::new(MockEmbedder), 200, 20);
        let input =
            "First sentence. Second sentence here. Third one is different. Fourth continues.";
        let pieces = semantic.split_text(input).unwrap();
        assert!(!pieces.is_empty());
    }

    /// A custom threshold still produces at least one chunk.
    #[test]
    fn test_semantic_splitter_with_threshold() {
        let semantic = SemanticTextSplitter::new(Box::new(MockEmbedder), 300, 30)
            .with_similarity_threshold(0.8);
        let pieces = semantic
            .split_text("Sentence one. Sentence two. Sentence three.")
            .unwrap();
        assert!(!pieces.is_empty());
    }

    /// Identical vectors have similarity 1, orthogonal vectors 0.
    #[test]
    fn test_semantic_splitter_cosine_similarity() {
        let semantic = SemanticTextSplitter::new(Box::new(MockEmbedder), 100, 10);
        let x_axis = vec![1.0, 0.0, 0.0];
        let x_axis_again = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];

        let same = semantic.cosine_similarity(&x_axis, &x_axis_again);
        assert!((same - 1.0).abs() < 0.01);

        let orthogonal = semantic.cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < 0.01);
    }

    /// Any `Embedder` implementation can be plugged into the splitter.
    #[test]
    fn test_embedder_trait_composable() {
        struct CustomEmbedder;

        impl Embedder for CustomEmbedder {
            fn embed(&self, _text: &str) -> Result<Vec<f32>> {
                Ok(vec![1.0, 2.0, 3.0])
            }
        }

        let semantic = SemanticTextSplitter::new(Box::new(CustomEmbedder), 500, 50);
        let outcome = semantic.split_text("Test text.");
        assert!(outcome.is_ok());
    }
}