use crate::{
core::{ChunkId, ChunkingStrategy, DocumentId, TextChunk},
text::{HierarchicalChunker, SemanticChunker},
};
use std::sync::atomic::{AtomicU64, Ordering};
/// Process-wide counter used to build chunk ids of the form `{document_id}_{n}`.
///
/// It is shared by every strategy that calls `fetch_add` on it, so `n` keeps
/// increasing across documents and strategies; nothing in the visible code resets it.
static CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
/// [`ChunkingStrategy`] adapter that delegates the actual splitting to a
/// [`HierarchicalChunker`] and tags the resulting pieces with a document id.
pub struct HierarchicalChunkingStrategy {
/// Wrapped chunker that performs the text splitting.
inner: HierarchicalChunker,
/// Size argument forwarded to `HierarchicalChunker::chunk_text` (units are defined by that chunker).
chunk_size: usize,
/// Overlap argument forwarded to `HierarchicalChunker::chunk_text`.
overlap: usize,
/// Document every produced chunk is attributed to; also the prefix of each chunk id.
document_id: DocumentId,
}
impl HierarchicalChunkingStrategy {
    /// Builds a strategy for `document_id` whose inner chunker starts with a
    /// minimum size of 50.
    ///
    /// `chunk_size` and `overlap` are stored and handed to the inner chunker
    /// on every `chunk` call.
    pub fn new(chunk_size: usize, overlap: usize, document_id: DocumentId) -> Self {
        let inner = HierarchicalChunker::new().with_min_size(50);
        Self {
            inner,
            chunk_size,
            overlap,
            document_id,
        }
    }

    /// Returns the strategy with the inner chunker's minimum size replaced by
    /// `min_size`; all other settings are carried over unchanged.
    pub fn with_min_size(self, min_size: usize) -> Self {
        let Self {
            inner,
            chunk_size,
            overlap,
            document_id,
        } = self;
        Self {
            inner: inner.with_min_size(min_size),
            chunk_size,
            overlap,
            document_id,
        }
    }
}
impl ChunkingStrategy for HierarchicalChunkingStrategy {
    /// Splits `text` with the wrapped [`HierarchicalChunker`] and wraps every
    /// non-blank piece in a [`TextChunk`].
    ///
    /// Chunk ids are `{document_id}_{n}`, with `n` taken from the global
    /// `CHUNK_COUNTER`. Offsets are running totals of the piece lengths
    /// (blank pieces are skipped but still advance the running position).
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        let pieces = self.inner.chunk_text(text, self.chunk_size, self.overlap);

        // NOTE(review): the offsets below only match positions in `text` when the
        // chunker returns contiguous, non-overlapping pieces. With a non-zero
        // `overlap` they presumably drift — confirm against
        // `HierarchicalChunker::chunk_text`.
        let mut chunks = Vec::new();
        let mut cursor = 0;
        for piece in pieces {
            // Every piece advances the cursor, whether or not it is emitted.
            let start = cursor;
            let end = start + piece.len();
            cursor = end;

            if piece.trim().is_empty() {
                continue;
            }

            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            // The piece is owned, so it is moved into the chunk instead of cloned.
            chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                piece,
                start,
                end,
            ));
        }
        chunks
    }
}
/// [`ChunkingStrategy`] that groups sentences of a document into chunks.
pub struct SemanticChunkingStrategy {
/// Chunker supplied at construction time. The leading underscore marks it as
/// not yet read: the `chunk` implementation in this file does its own sentence splitting.
_inner: SemanticChunker,
/// Document every produced chunk is attributed to; also the prefix of each chunk id.
document_id: DocumentId,
}
impl SemanticChunkingStrategy {
    /// Creates a strategy for `document_id` that owns the given `chunker`.
    pub fn new(chunker: SemanticChunker, document_id: DocumentId) -> Self {
        let _inner = chunker;
        Self { _inner, document_id }
    }
}
impl ChunkingStrategy for SemanticChunkingStrategy {
    /// Groups the sentences of `text` into chunks of up to five sentences.
    ///
    /// A sentence is a non-blank run of text ended by `.`, `!` or `?` (or by the
    /// end of the input). Each chunk's content is the exact slice of `text`
    /// from the start of its first sentence to the end of its last one, and
    /// `start`/`end` are the byte offsets of that slice, so
    /// `&text[start..end] == content` always holds.
    ///
    /// Chunk ids are `{document_id}_{n}`, with `n` taken from the global
    /// `CHUNK_COUNTER`. Blank input yields no chunks.
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        const SENTENCES_PER_CHUNK: usize = 5;
        const TERMINATORS: &[char] = &['.', '!', '?'];

        // Byte span `(start, end)` of every sentence, terminator included.
        let mut spans: Vec<(usize, usize)> = Vec::new();
        let mut cursor = 0;
        for piece in text.split_inclusive(TERMINATORS) {
            let piece_start = cursor;
            cursor += piece.len();

            let body = piece.trim_end_matches(TERMINATORS);
            if body.trim().is_empty() {
                // A stray terminator (as in "..." or "?!") is not a sentence of
                // its own; attach it to the sentence it follows so it is kept.
                if body.len() < piece.len() {
                    if let Some(previous) = spans.last_mut() {
                        previous.1 = cursor;
                    }
                }
                continue;
            }

            // Drop surrounding whitespace from the span; trimming removes whole
            // characters, so both ends stay on char boundaries.
            let leading_ws = piece.len() - piece.trim_start().len();
            let start = piece_start + leading_ws;
            let end = piece_start + piece.trim_end().len();
            spans.push((start, end));
        }

        let mut chunks = Vec::new();
        // `chunks()` never yields an empty slice, so indexing first/last is safe.
        for group in spans.chunks(SENTENCES_PER_CHUNK) {
            let start = group[0].0;
            let end = group[group.len() - 1].1;
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text[start..end].to_string(),
                start,
                end,
            ));
        }
        chunks
    }
}
/// [`ChunkingStrategy`] that splits Rust source code along top-level item
/// boundaries found by tree-sitter. Only compiled with the `code-chunking` feature.
#[cfg(feature = "code-chunking")]
pub struct RustCodeChunkingStrategy {
/// Minimum byte length an item's source text must have to be emitted as a chunk.
min_chunk_size: usize,
/// Document every produced chunk is attributed to; also the prefix of each chunk id.
document_id: DocumentId,
}
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Creates a strategy for `document_id` that only emits items whose source
    /// text is at least `min_chunk_size` bytes long.
    pub fn new(min_chunk_size: usize, document_id: DocumentId) -> Self {
        let strategy = Self {
            min_chunk_size,
            document_id,
        };
        strategy
    }
}
#[cfg(feature = "code-chunking")]
impl ChunkingStrategy for RustCodeChunkingStrategy {
    /// Parses `text` as Rust with tree-sitter and emits one chunk per item
    /// selected by `extract_chunks`.
    ///
    /// If no item qualifies (or the parser yields no tree) and `text` is not
    /// blank, the whole input is returned as a single chunk spanning
    /// `0..text.len()`.
    ///
    /// # Panics
    ///
    /// Panics if the Rust grammar cannot be loaded into the parser, which
    /// indicates a grammar/library version mismatch rather than bad input.
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        use tree_sitter::Parser;

        let mut parser = Parser::new();
        let language = tree_sitter_rust::language();
        parser
            .set_language(&language)
            .expect("Error loading Rust grammar");

        let mut chunks = Vec::new();
        // A missing tree is treated like "no items found" instead of panicking,
        // so the whole-text fallback below still produces a usable chunk.
        if let Some(tree) = parser.parse(text, None) {
            self.extract_chunks(&tree.root_node(), text, &mut chunks);
        }

        if chunks.is_empty() && !text.trim().is_empty() {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text.to_string(),
                0,
                text.len(),
            ));
        }
        chunks
    }
}
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Walks the syntax tree rooted at `node` and appends one chunk to `chunks`
    /// for every function, impl, struct, enum, module or trait item whose
    /// source text is at least `min_chunk_size` bytes long.
    ///
    /// A matching item is never descended into, so anything nested inside it
    /// (for example methods of an `impl`) stays part of that item's chunk; an
    /// item below the size threshold is dropped together with its contents.
    /// Every other node kind is only traversed.
    ///
    /// `source` must be the text the tree was parsed from: node byte offsets
    /// are used directly as the chunk's start/end offsets and to slice it.
    fn extract_chunks(&self, node: &tree_sitter::Node, source: &str, chunks: &mut Vec<TextChunk>) {
        let is_item = matches!(
            node.kind(),
            "function_item"
                | "impl_item"
                | "struct_item"
                | "enum_item"
                | "mod_item"
                | "trait_item"
        );

        if is_item {
            let start_pos = node.start_byte();
            let end_pos = node.end_byte();
            let chunk_content = &source[start_pos..end_pos];
            if chunk_content.len() >= self.min_chunk_size {
                let chunk_id = ChunkId::new(format!(
                    "{}_{}",
                    self.document_id,
                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
                ));
                chunks.push(TextChunk::new(
                    chunk_id,
                    self.document_id.clone(),
                    chunk_content.to_string(),
                    start_pos,
                    end_pos,
                ));
            }
            return;
        }

        // Not an item we chunk on (this includes the `source_file` root):
        // visit the children in order.
        let mut child = node.child(0);
        while let Some(current) = child {
            self.extract_chunks(&current, source, chunks);
            child = current.next_sibling();
        }
    }
}
/// [`ChunkingStrategy`] that splits text at detected structural boundaries and,
/// with the `async` feature, lets a coherence scorer choose among them.
///
/// Without the `async` feature only `max_chunk_chars` and `document_id` are
/// read, which is why the other fields carry conditional `dead_code` allowances.
pub struct BoundaryAwareChunkingStrategy {
/// Finds candidate split positions in the input text (used by the async path).
#[cfg_attr(not(feature = "async"), allow(dead_code))]
boundary_detector: crate::text::BoundaryDetector,
/// Picks the split among the candidate boundaries (used by the async path).
#[cfg_attr(not(feature = "async"), allow(dead_code))]
coherence_scorer: std::sync::Arc<crate::text::SemanticCoherenceScorer>,
/// Upper size limit for a chunk. Despite the name it is compared against
/// `String::len()`, i.e. a byte count.
max_chunk_chars: usize,
/// Chunks shorter than this (again measured with `len()`) are merged into
/// their predecessor by the async path.
#[cfg_attr(not(feature = "async"), allow(dead_code))]
min_chunk_chars: usize,
/// Document every produced chunk is attributed to; also the prefix of each chunk id.
document_id: DocumentId,
}
impl BoundaryAwareChunkingStrategy {
/// Builds a strategy from explicit configuration.
///
/// `boundary_config` configures the boundary detector; `coherence_config`
/// and `embedding_provider` are handed to the coherence scorer, which is
/// stored behind an `Arc`. `max_chunk_chars` / `min_chunk_chars` are the
/// size limits applied to chunks, and `document_id` is the owning document.
pub fn new(
    boundary_config: crate::text::BoundaryDetectionConfig,
    coherence_config: crate::text::CoherenceConfig,
    embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
    max_chunk_chars: usize,
    min_chunk_chars: usize,
    document_id: DocumentId,
) -> Self {
    let boundary_detector = crate::text::BoundaryDetector::with_config(boundary_config);
    let scorer = crate::text::SemanticCoherenceScorer::new(coherence_config, embedding_provider);
    Self {
        boundary_detector,
        coherence_scorer: std::sync::Arc::new(scorer),
        max_chunk_chars,
        min_chunk_chars,
        document_id,
    }
}
/// Builds a strategy with default boundary-detection and coherence
/// configuration, a maximum chunk size of 2000 and a minimum of 200.
pub fn with_defaults(
    embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
    document_id: DocumentId,
) -> Self {
    const DEFAULT_MAX_CHUNK_CHARS: usize = 2000;
    const DEFAULT_MIN_CHUNK_CHARS: usize = 200;

    let boundary_config = crate::text::BoundaryDetectionConfig::default();
    let coherence_config = crate::text::CoherenceConfig::default();
    Self::new(
        boundary_config,
        coherence_config,
        embedding_provider,
        DEFAULT_MAX_CHUNK_CHARS,
        DEFAULT_MIN_CHUNK_CHARS,
        document_id,
    )
}
/// Chunks `text` using detected boundaries and coherence scoring.
///
/// Only paragraph, heading and code-block boundaries are offered to the
/// scorer. If scoring fails, the text is cut at those boundaries instead.
/// Either way the result is passed through `enforce_size_constraints`.
#[cfg(feature = "async")]
async fn chunk_async(&self, text: &str) -> Vec<TextChunk> {
    let detected = self.boundary_detector.detect_boundaries(text);

    // Collect the positions of the structural boundaries only.
    let mut split_points: Vec<usize> = Vec::new();
    for boundary in detected.iter() {
        let is_structural = matches!(
            boundary.boundary_type,
            crate::text::BoundaryType::Paragraph
                | crate::text::BoundaryType::Heading
                | crate::text::BoundaryType::CodeBlock
        );
        if is_structural {
            split_points.push(boundary.position);
        }
    }

    let scored = self
        .coherence_scorer
        .find_optimal_split(text, &split_points)
        .await;

    let chunks = match scored {
        // Scoring failed: fall back to a plain cut at every boundary.
        Err(_) => self.create_text_chunks_from_boundaries(text, &split_points),
        Ok(split) => self.create_text_chunks_from_scored(&split.chunks),
    };
    self.enforce_size_constraints(chunks)
}
/// Converts scorer output into [`TextChunk`]s, copying each scored chunk's
/// text and offsets and recording its `coherence_score` and
/// `sentence_count` as custom metadata strings.
///
/// Ids are `{document_id}_{n}` with `n` drawn from the global
/// `CHUNK_COUNTER`, the same source `split_large_chunk` uses, so ids
/// produced by the two steps cannot collide.
#[cfg(feature = "async")]
fn create_text_chunks_from_scored(
    &self,
    scored_chunks: &[crate::text::ScoredChunk],
) -> Vec<TextChunk> {
    scored_chunks
        .iter()
        .map(|sc| {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            let mut chunk = TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                sc.text.clone(),
                sc.start_pos,
                sc.end_pos,
            );
            chunk.metadata.custom.insert(
                "coherence_score".to_string(),
                sc.coherence_score.to_string(),
            );
            chunk
                .metadata
                .custom
                .insert("sentence_count".to_string(), sc.sentence_count.to_string());
            chunk
        })
        .collect()
}
/// Cuts `text` at the given byte positions and returns one chunk per
/// resulting piece, plus a final chunk for whatever follows the last cut.
///
/// A boundary is skipped when it does not move past the previous cut, lies
/// beyond the end of `text`, or does not fall on a UTF-8 character
/// boundary; the text around a skipped boundary simply stays in one chunk.
/// Ids are `{document_id}_{n}` with `n` drawn from the global
/// `CHUNK_COUNTER`, so every chunk gets a distinct id.
#[cfg(feature = "async")]
fn create_text_chunks_from_boundaries(
    &self,
    text: &str,
    boundaries: &[usize],
) -> Vec<TextChunk> {
    let make_chunk = |start: usize, end: usize| {
        let chunk_id = ChunkId::new(format!(
            "{}_{}",
            self.document_id,
            CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
        ));
        TextChunk::new(
            chunk_id,
            self.document_id.clone(),
            text[start..end].to_string(),
            start,
            end,
        )
    };

    let mut chunks = Vec::with_capacity(boundaries.len() + 1);
    let mut prev_pos = 0;
    for &pos in boundaries {
        // `is_char_boundary` is also false for positions past the end, so this
        // one check keeps the slice below from panicking.
        if pos <= prev_pos || !text.is_char_boundary(pos) {
            continue;
        }
        chunks.push(make_chunk(prev_pos, pos));
        prev_pos = pos;
    }

    if prev_pos < text.len() {
        chunks.push(make_chunk(prev_pos, text.len()));
    }
    chunks
}
/// Applies the configured size limits to `chunks`, preserving order.
///
/// * A chunk longer than `max_chunk_chars` is replaced by the pieces
///   returned from `split_large_chunk`.
/// * A chunk shorter than `min_chunk_chars` is appended (after a single
///   space) to the previously emitted chunk, whose `end_offset` is moved
///   to the merged chunk's end. If nothing has been emitted yet, the
///   short chunk is kept as is.
/// * Any other chunk is passed through unchanged.
///
/// Lengths are byte lengths of the chunk content. A merged chunk is not
/// re-checked against `max_chunk_chars`.
#[cfg(feature = "async")]
fn enforce_size_constraints(&self, chunks: Vec<TextChunk>) -> Vec<TextChunk> {
    let mut result: Vec<TextChunk> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        let chunk_len = chunk.content.len();

        if chunk_len > self.max_chunk_chars {
            result.extend(self.split_large_chunk(chunk));
            continue;
        }

        if chunk_len < self.min_chunk_chars {
            // Merge in place; `None` means there is no predecessor yet.
            if let Some(prev_chunk) = result.last_mut() {
                prev_chunk.content.push(' ');
                prev_chunk.content.push_str(&chunk.content);
                prev_chunk.end_offset = chunk.end_offset;
                continue;
            }
        }

        result.push(chunk);
    }
    result
}
/// Splits an oversized chunk at sentence ends (`.`, `!`, `?`) into pieces
/// of at most `max_chunk_chars` bytes where possible.
///
/// The chunk's content is consumed piece by piece, terminators included,
/// so sub-chunk contents are verbatim consecutive slices of the original
/// content and their offsets advance by exactly the bytes they contain.
/// The last sub-chunk ends at the original `end_offset`. A single sentence
/// longer than the limit is kept whole, and a trailing piece that is only
/// whitespace is dropped.
///
/// Ids are `{document_id}_{n}` with `n` drawn from the global `CHUNK_COUNTER`.
#[cfg(feature = "async")]
fn split_large_chunk(&self, chunk: TextChunk) -> Vec<TextChunk> {
    let next_id = || {
        ChunkId::new(format!(
            "{}_{}",
            self.document_id,
            CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
        ))
    };

    let mut sub_chunks = Vec::new();
    let mut buffer = String::new();
    let mut buffer_start = chunk.start_offset;

    for piece in chunk.content.split_inclusive(&['.', '!', '?'][..]) {
        // Flush before the buffer would overflow, but never emit a blank piece.
        if buffer.len() + piece.len() > self.max_chunk_chars && !buffer.trim().is_empty() {
            let end = buffer_start + buffer.len();
            sub_chunks.push(TextChunk::new(
                next_id(),
                self.document_id.clone(),
                std::mem::take(&mut buffer),
                buffer_start,
                end,
            ));
            buffer_start = end;
        }
        buffer.push_str(piece);
    }

    if !buffer.trim().is_empty() {
        sub_chunks.push(TextChunk::new(
            next_id(),
            self.document_id.clone(),
            buffer,
            buffer_start,
            chunk.end_offset,
        ));
    }
    sub_chunks
}
}
impl ChunkingStrategy for BoundaryAwareChunkingStrategy {
    /// Synchronous entry point for the async chunking pipeline: creates a
    /// fresh Tokio runtime and blocks on `chunk_async`.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be created.
    // NOTE(review): building a runtime per call is costly, and blocking on a
    // runtime from inside another Tokio runtime presumably panics — confirm
    // how callers invoke this.
    #[cfg(feature = "async")]
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        let runtime = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
        runtime.block_on(self.chunk_async(text))
    }

    /// Fallback used without the `async` feature: packs consecutive sentences
    /// of `text` into chunks of at most `max_chunk_chars` bytes.
    ///
    /// A sentence is a non-blank run of text ended by `.`, `!` or `?` (or by
    /// the end of the input). Each chunk's content is the exact slice of
    /// `text` from the start of its first sentence to the end of its last
    /// one, and `start`/`end` are the byte offsets of that slice. A single
    /// sentence longer than the limit becomes a chunk of its own.
    ///
    /// Chunk ids are `{document_id}_{n}`, with `n` taken from the global
    /// `CHUNK_COUNTER`.
    #[cfg(not(feature = "async"))]
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        const TERMINATORS: &[char] = &['.', '!', '?'];

        let make_chunk = |start: usize, end: usize| {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text[start..end].to_string(),
                start,
                end,
            )
        };

        let mut chunks = Vec::new();
        // Byte span of the chunk currently being filled, if any.
        let mut current: Option<(usize, usize)> = None;
        let mut cursor = 0;

        for piece in text.split_inclusive(TERMINATORS) {
            let piece_start = cursor;
            cursor += piece.len();

            let body = piece.trim_end_matches(TERMINATORS);
            if body.trim().is_empty() {
                // A stray terminator (as in "..." or "?!") stays with the chunk
                // being filled; pure whitespace is ignored.
                if body.len() < piece.len() {
                    if let Some(span) = current.as_mut() {
                        span.1 = cursor;
                    }
                }
                continue;
            }

            // Trimming removes whole characters, so both ends are char boundaries.
            let sentence_start = piece_start + (piece.len() - piece.trim_start().len());
            let sentence_end = piece_start + piece.trim_end().len();

            current = match current {
                // Adding this sentence would overflow: emit what we have and
                // start a new chunk with the sentence.
                Some((start, end)) if sentence_end - start > self.max_chunk_chars => {
                    chunks.push(make_chunk(start, end));
                    Some((sentence_start, sentence_end))
                },
                Some((start, _)) => Some((start, sentence_end)),
                None => Some((sentence_start, sentence_end)),
            };
        }

        if let Some((start, end)) = current {
            chunks.push(make_chunk(start, end));
        }
        chunks
    }
}
// Unit tests for the chunking strategies defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Hierarchical strategy: a two-paragraph input must yield at least one chunk,
// and every chunk must have content and a non-empty offset range.
#[test]
fn test_hierarchical_chunking_strategy() {
let document_id = DocumentId::new("test_doc".to_string());
let strategy = HierarchicalChunkingStrategy::new(100, 20, document_id);
let text = "This is paragraph one.\n\nThis is paragraph two with more content to test chunking behavior.";
let chunks = strategy.chunk(text);
assert!(!chunks.is_empty());
for chunk in &chunks {
assert!(!chunk.content.is_empty());
assert!(chunk.start_offset < chunk.end_offset);
}
}
// Placeholder: only checks that a document id and a default chunker config can
// be constructed. No `SemanticChunkingStrategy` is built and no chunking runs.
#[test]
fn test_semantic_chunking_strategy() {
let _document_id = DocumentId::new("test_doc".to_string());
let _config = crate::text::semantic_chunking::SemanticChunkerConfig::default();
}
// Rust code strategy (needs the `code-chunking` feature): a snippet with a
// function, a struct and an impl must yield at least two non-empty chunks.
#[test]
#[cfg(feature = "code-chunking")]
fn test_rust_code_chunking_strategy() {
let document_id = DocumentId::new("rust_code".to_string());
let strategy = RustCodeChunkingStrategy::new(10, document_id);
let rust_code = r#"
fn main() {
println!("Hello, world!");
}
struct Point {
x: f64,
y: f64,
}
impl Point {
fn new(x: f64, y: f64) -> Self {
Point { x, y }
}
}
"#;
let chunks = strategy.chunk(rust_code);
assert!(!chunks.is_empty());
assert!(chunks.len() >= 2);
for chunk in &chunks {
assert!(!chunk.content.is_empty());
assert!(chunk.start_offset < chunk.end_offset);
}
}
}