1use crate::{
7 core::{ChunkId, ChunkingStrategy, DocumentId, TextChunk},
8 text::{HierarchicalChunker, SemanticChunker},
9};
10
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Process-wide monotonically increasing counter; gives every emitted chunk a
/// unique numeric id suffix, shared by all strategies in this module.
static CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
15
/// [`ChunkingStrategy`] adapter that delegates text splitting to a
/// [`HierarchicalChunker`] and wraps the resulting pieces in [`TextChunk`]s.
pub struct HierarchicalChunkingStrategy {
    inner: HierarchicalChunker, // performs the actual splitting
    chunk_size: usize,          // target chunk size forwarded to `chunk_text`
    overlap: usize,             // overlap forwarded to `chunk_text`
    document_id: DocumentId,    // owning document; prefixes every chunk id
}
26
27impl HierarchicalChunkingStrategy {
28 pub fn new(chunk_size: usize, overlap: usize, document_id: DocumentId) -> Self {
30 Self {
31 inner: HierarchicalChunker::new().with_min_size(50),
32 chunk_size,
33 overlap,
34 document_id,
35 }
36 }
37
38 pub fn with_min_size(mut self, min_size: usize) -> Self {
40 self.inner = self.inner.with_min_size(min_size);
41 self
42 }
43}
44
45impl ChunkingStrategy for HierarchicalChunkingStrategy {
46 fn chunk(&self, text: &str) -> Vec<TextChunk> {
47 let chunks_text = self.inner.chunk_text(text, self.chunk_size, self.overlap);
48 let mut chunks = Vec::new();
49 let mut current_pos = 0;
50
51 for chunk_content in chunks_text {
52 if !chunk_content.trim().is_empty() {
53 let chunk_id = ChunkId::new(format!(
54 "{}_{}",
55 self.document_id,
56 CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
57 ));
58 let chunk_start = current_pos;
59 let chunk_end = chunk_start + chunk_content.len();
60
61 let chunk = TextChunk::new(
62 chunk_id,
63 self.document_id.clone(),
64 chunk_content.clone(),
65 chunk_start,
66 chunk_end,
67 );
68 chunks.push(chunk);
69 current_pos = chunk_end;
70 } else {
71 current_pos += chunk_content.len();
72 }
73 }
74
75 chunks
76 }
77}
78
/// [`ChunkingStrategy`] intended to be backed by a [`SemanticChunker`].
///
/// NOTE(review): `_inner` is stored but never consulted by the current
/// `chunk` implementation, which uses simple sentence grouping instead.
pub struct SemanticChunkingStrategy {
    _inner: SemanticChunker,  // held for a future semantic implementation
    document_id: DocumentId,  // owning document; prefixes every chunk id
}
87
88impl SemanticChunkingStrategy {
89 pub fn new(chunker: SemanticChunker, document_id: DocumentId) -> Self {
91 Self {
92 _inner: chunker,
93 document_id,
94 }
95 }
96}
97
98impl ChunkingStrategy for SemanticChunkingStrategy {
99 fn chunk(&self, text: &str) -> Vec<TextChunk> {
100 let sentences: Vec<&str> = text
106 .split(&['.', '!', '?'][..])
107 .filter(|s| !s.trim().is_empty())
108 .collect();
109
110 let mut chunks = Vec::new();
111 let mut current_pos = 0;
112
113 let chunk_size = 5; for chunk_sentences in sentences.chunks(chunk_size) {
116 let chunk_content = chunk_sentences.join(". ") + ".";
117 let chunk_id = ChunkId::new(format!(
118 "{}_{}",
119 self.document_id,
120 CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
121 ));
122 let chunk_start = current_pos;
123 let chunk_end = chunk_start + chunk_content.len();
124
125 let chunk = TextChunk::new(
126 chunk_id,
127 self.document_id.clone(),
128 chunk_content,
129 chunk_start,
130 chunk_end,
131 );
132 chunks.push(chunk);
133 current_pos = chunk_end;
134 }
135
136 chunks
137 }
138}
139
/// Chunks Rust source code by top-level syntactic items using tree-sitter.
#[cfg(feature = "code-chunking")]
pub struct RustCodeChunkingStrategy {
    min_chunk_size: usize,   // items shorter than this (in bytes) are skipped
    document_id: DocumentId, // owning document; prefixes every chunk id
}
149
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Creates a strategy for `document_id`; syntactic items shorter than
    /// `min_chunk_size` bytes will not be emitted as chunks.
    pub fn new(min_chunk_size: usize, document_id: DocumentId) -> Self {
        Self {
            document_id,
            min_chunk_size,
        }
    }
}
160
#[cfg(feature = "code-chunking")]
impl ChunkingStrategy for RustCodeChunkingStrategy {
    /// Parses `text` as Rust and emits one chunk per top-level item
    /// (functions, impls, structs, enums, modules, traits).
    ///
    /// If parsing yields no items — or fails outright — the whole non-empty
    /// input is emitted as a single fallback chunk.
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        use tree_sitter::Parser;

        let mut parser = Parser::new();
        let language = tree_sitter_rust::language();
        // A grammar-version mismatch is a build/config bug, so expect() here
        // states an invariant rather than handling user input.
        parser
            .set_language(&language)
            .expect("Error loading Rust grammar");

        let mut chunks = Vec::new();

        // The previous code panicked on a failed parse; a library API should
        // degrade gracefully, so a `None` tree falls through to the
        // whole-text fallback below.
        if let Some(tree) = parser.parse(text, None) {
            let root_node = tree.root_node();
            self.extract_chunks(&root_node, text, &mut chunks);
        }

        // Fallback: never silently drop non-empty input.
        if chunks.is_empty() && !text.trim().is_empty() {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text.to_string(),
                0,
                text.len(),
            ));
        }

        chunks
    }
}
200
201#[cfg(feature = "code-chunking")]
202impl RustCodeChunkingStrategy {
203 fn extract_chunks(&self, node: &tree_sitter::Node, source: &str, chunks: &mut Vec<TextChunk>) {
205 match node.kind() {
206 "function_item" | "impl_item" | "struct_item" | "enum_item" | "mod_item"
208 | "trait_item" => {
209 let start_byte = node.start_byte();
210 let end_byte = node.end_byte();
211
212 let start_pos = source.len() - source[start_byte..].len();
214 let end_pos = source.len() - source[end_byte..].len();
215
216 let chunk_content = &source[start_pos..end_pos];
217
218 if chunk_content.len() >= self.min_chunk_size {
219 let chunk_id = ChunkId::new(format!(
220 "{}_{}",
221 self.document_id,
222 CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
223 ));
224
225 let chunk = TextChunk::new(
226 chunk_id,
227 self.document_id.clone(),
228 chunk_content.to_string(),
229 start_pos,
230 end_pos,
231 );
232 chunks.push(chunk);
233 }
234 },
235
236 "source_file" => {
238 let mut child = node.child(0);
239 while let Some(current) = child {
240 self.extract_chunks(¤t, source, chunks);
241 child = current.next_sibling();
242 }
243 },
244
245 _ => {
247 let mut child = node.child(0);
248 while let Some(current) = child {
249 self.extract_chunks(¤t, source, chunks);
250 child = current.next_sibling();
251 }
252 },
253 }
254 }
255}
256
/// Boundary-aware chunking: detects structural boundaries in the text,
/// scores candidate splits for semantic coherence via an embedding provider,
/// then enforces min/max size constraints on the result.
pub struct BoundaryAwareChunkingStrategy {
    boundary_detector: crate::text::BoundaryDetector,
    coherence_scorer: std::sync::Arc<crate::text::SemanticCoherenceScorer>,
    max_chunk_chars: usize,  // chunks longer than this get split by sentence
    min_chunk_chars: usize,  // chunks shorter than this get merged into the previous chunk
    document_id: DocumentId, // owning document; prefixes every chunk id
}
273
impl BoundaryAwareChunkingStrategy {
    /// Builds a strategy from explicit boundary-detection and coherence
    /// configurations plus the embedding provider used by the coherence
    /// scorer.
    pub fn new(
        boundary_config: crate::text::BoundaryDetectionConfig,
        coherence_config: crate::text::CoherenceConfig,
        embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
        max_chunk_chars: usize,
        min_chunk_chars: usize,
        document_id: DocumentId,
    ) -> Self {
        Self {
            boundary_detector: crate::text::BoundaryDetector::with_config(boundary_config),
            coherence_scorer: std::sync::Arc::new(crate::text::SemanticCoherenceScorer::new(
                coherence_config,
                embedding_provider,
            )),
            max_chunk_chars,
            min_chunk_chars,
            document_id,
        }
    }

    /// Convenience constructor: default boundary/coherence configs, a
    /// 2000-character maximum and a 200-character minimum chunk size.
    pub fn with_defaults(
        embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
        document_id: DocumentId,
    ) -> Self {
        Self::new(
            crate::text::BoundaryDetectionConfig::default(),
            crate::text::CoherenceConfig::default(),
            embedding_provider,
            2000, 200, document_id,
        )
    }

    /// Core async pipeline: detect boundaries, ask the coherence scorer for
    /// an optimal split, fall back to raw boundary splitting on error, then
    /// enforce size constraints.
    async fn chunk_async(&self, text: &str) -> Vec<TextChunk> {
        let boundaries = self.boundary_detector.detect_boundaries(text);

        // Only strong structural boundaries count as split candidates.
        let boundary_positions: Vec<usize> = boundaries
            .iter()
            .filter(|b| {
                matches!(
                    b.boundary_type,
                    crate::text::BoundaryType::Paragraph
                        | crate::text::BoundaryType::Heading
                        | crate::text::BoundaryType::CodeBlock
                )
            })
            .map(|b| b.position)
            .collect();

        let optimal_result = self
            .coherence_scorer
            .find_optimal_split(text, &boundary_positions)
            .await;

        // The scorer's error is intentionally discarded: any failure
        // degrades to a purely structural split.
        let chunks = match optimal_result {
            Ok(result) => {
                self.create_text_chunks_from_scored(&result.chunks)
            },
            Err(_) => {
                self.create_text_chunks_from_boundaries(text, &boundary_positions)
            },
        };

        self.enforce_size_constraints(chunks)
    }

    /// Converts scorer output into `TextChunk`s, copying the coherence score
    /// and sentence count into each chunk's custom metadata (as strings).
    ///
    /// NOTE(review): ids here use the enumeration index, unlike the other
    /// strategies which use the global CHUNK_COUNTER — repeated calls for
    /// the same document therefore reuse ids; confirm this is intended.
    fn create_text_chunks_from_scored(
        &self,
        scored_chunks: &[crate::text::ScoredChunk],
    ) -> Vec<TextChunk> {
        scored_chunks
            .iter()
            .enumerate()
            .map(|(i, sc)| {
                let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, i));
                let mut chunk = TextChunk::new(
                    chunk_id,
                    self.document_id.clone(),
                    sc.text.clone(),
                    sc.start_pos,
                    sc.end_pos,
                );

                chunk.metadata.custom.insert(
                    "coherence_score".to_string(),
                    sc.coherence_score.to_string(),
                );
                chunk
                    .metadata
                    .custom
                    .insert("sentence_count".to_string(), sc.sentence_count.to_string());

                chunk
            })
            .collect()
    }

    /// Fallback split: one chunk per span between consecutive boundary
    /// positions, plus a trailing chunk for any text after the last boundary.
    fn create_text_chunks_from_boundaries(
        &self,
        text: &str,
        boundaries: &[usize],
    ) -> Vec<TextChunk> {
        let mut chunks = Vec::new();
        let mut prev_pos = 0;

        for (i, &pos) in boundaries.iter().enumerate() {
            // Skip zero-width spans (e.g. a boundary at position 0).
            if pos > prev_pos {
                let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, i));
                let chunk = TextChunk::new(
                    chunk_id,
                    self.document_id.clone(),
                    text[prev_pos..pos].to_string(),
                    prev_pos,
                    pos,
                );
                chunks.push(chunk);
                prev_pos = pos;
            }
        }

        // Trailing remainder after the last boundary.
        if prev_pos < text.len() {
            let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, chunks.len()));
            let chunk = TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text[prev_pos..].to_string(),
                prev_pos,
                text.len(),
            );
            chunks.push(chunk);
        }

        chunks
    }

    /// Post-processing pass: over-long chunks are split by sentence; chunks
    /// below the minimum are merged into the previous chunk (joined with a
    /// single space). A too-small *first* chunk is kept as-is since there is
    /// nothing to merge it into.
    fn enforce_size_constraints(&self, mut chunks: Vec<TextChunk>) -> Vec<TextChunk> {
        let mut result = Vec::new();

        for chunk in chunks.drain(..) {
            let chunk_len = chunk.content.len();

            if chunk_len > self.max_chunk_chars {
                result.extend(self.split_large_chunk(chunk));
            } else if chunk_len < self.min_chunk_chars && !result.is_empty() {
                // Merge into the previous chunk; the merged chunk keeps the
                // previous chunk's id and start offset.
                if let Some(mut prev_chunk) = result.pop() {
                    prev_chunk.content.push(' ');
                    prev_chunk.content.push_str(&chunk.content);
                    prev_chunk.end_offset = chunk.end_offset;
                    result.push(prev_chunk);
                } else {
                    result.push(chunk);
                }
            } else {
                result.push(chunk);
            }
        }

        result
    }

    /// Splits an over-long chunk at sentence boundaries, packing sentences
    /// until `max_chunk_chars` would be exceeded.
    ///
    /// NOTE(review): the `.`/`!`/`?` split discards the original punctuation
    /// and a '.' is re-appended to every sentence, so sub-chunk content and
    /// offsets are approximations of the original text — confirm downstream
    /// consumers tolerate this.
    fn split_large_chunk(&self, chunk: TextChunk) -> Vec<TextChunk> {
        let sentences: Vec<&str> = chunk
            .content
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();

        let mut sub_chunks = Vec::new();
        let mut current_text = String::new();
        let mut current_start = chunk.start_offset;

        for sentence in sentences {
            // Flush the accumulator before it would overflow the limit.
            if current_text.len() + sentence.len() > self.max_chunk_chars
                && !current_text.is_empty()
            {
                let chunk_id = ChunkId::new(format!(
                    "{}_{}",
                    self.document_id,
                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
                ));
                let end = current_start + current_text.len();
                sub_chunks.push(TextChunk::new(
                    chunk_id,
                    self.document_id.clone(),
                    current_text.clone(),
                    current_start,
                    end,
                ));

                current_start = end;
                current_text.clear();
            }

            current_text.push_str(sentence);
            current_text.push('.');
        }

        // Flush whatever remains; its end offset is pinned to the original
        // chunk's end.
        if !current_text.is_empty() {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            sub_chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                current_text,
                current_start,
                chunk.end_offset,
            ));
        }

        sub_chunks
    }
}
520
521impl ChunkingStrategy for BoundaryAwareChunkingStrategy {
522 fn chunk(&self, text: &str) -> Vec<TextChunk> {
523 let runtime = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
526
527 runtime.block_on(self.chunk_async(text))
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Hierarchical strategy produces non-empty chunks with ordered offsets.
    #[test]
    fn test_hierarchical_chunking_strategy() {
        let document_id = DocumentId::new("test_doc".to_string());
        let strategy = HierarchicalChunkingStrategy::new(100, 20, document_id);

        let text = "This is paragraph one.\n\nThis is paragraph two with more content to test chunking behavior.";
        let chunks = strategy.chunk(text);

        assert!(!chunks.is_empty());
        for chunk in &chunks {
            assert!(!chunk.content.is_empty());
            assert!(chunk.start_offset < chunk.end_offset);
        }
    }

    /// TODO(review): this test is a truncated stub — it builds inputs but
    /// never constructs a strategy or asserts anything. Bindings are
    /// underscored to silence unused-variable warnings until the test is
    /// completed.
    #[test]
    fn test_semantic_chunking_strategy() {
        let _document_id = DocumentId::new("test_doc".to_string());
        let _config = crate::text::semantic_chunking::SemanticChunkerConfig::default();
    }

    /// tree-sitter strategy extracts the fn, struct and impl items from a
    /// small Rust snippet.
    #[test]
    #[cfg(feature = "code-chunking")]
    fn test_rust_code_chunking_strategy() {
        let document_id = DocumentId::new("rust_code".to_string());
        let strategy = RustCodeChunkingStrategy::new(10, document_id);

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
}

struct Point {
    x: f64,
    y: f64,
}

impl Point {
    fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }
}
"#;

        let chunks = strategy.chunk(rust_code);

        assert!(!chunks.is_empty());
        assert!(chunks.len() >= 2);

        for chunk in &chunks {
            assert!(!chunk.content.is_empty());
            assert!(chunk.start_offset < chunk.end_offset);
        }
    }
}