graphrag_core/text/
chunking_strategies.rs1use crate::{
7 core::{ChunkId, DocumentId, TextChunk, ChunkingStrategy},
8 text::{HierarchicalChunker, SemanticChunker},
9};
10
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Counter shared by every chunking strategy in this module.
///
/// It starts at zero and each strategy bumps it with `fetch_add` when it builds a chunk id,
/// so the numeric suffix of an id is never reused within one process.
static CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
15
16pub struct HierarchicalChunkingStrategy {
21 inner: HierarchicalChunker,
22 chunk_size: usize,
23 overlap: usize,
24 document_id: DocumentId,
25}
26
27impl HierarchicalChunkingStrategy {
28 pub fn new(chunk_size: usize, overlap: usize, document_id: DocumentId) -> Self {
30 Self {
31 inner: HierarchicalChunker::new().with_min_size(50),
32 chunk_size,
33 overlap,
34 document_id,
35 }
36 }
37
38 pub fn with_min_size(mut self, min_size: usize) -> Self {
40 self.inner = self.inner.with_min_size(min_size);
41 self
42 }
43}
44
45impl ChunkingStrategy for HierarchicalChunkingStrategy {
46 fn chunk(&self, text: &str) -> Vec<TextChunk> {
47 let chunks_text = self.inner.chunk_text(text, self.chunk_size, self.overlap);
48 let mut chunks = Vec::new();
49 let mut current_pos = 0;
50
51 for chunk_content in chunks_text {
52 if !chunk_content.trim().is_empty() {
53 let chunk_id = ChunkId::new(format!("{}_{}", self.document_id,
54 CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)));
55 let chunk_start = current_pos;
56 let chunk_end = chunk_start + chunk_content.len();
57
58 let chunk = TextChunk::new(
59 chunk_id,
60 self.document_id.clone(),
61 chunk_content.clone(),
62 chunk_start,
63 chunk_end,
64 );
65 chunks.push(chunk);
66 current_pos = chunk_end;
67 } else {
68 current_pos += chunk_content.len();
69 }
70 }
71
72 chunks
73 }
74}
75
76pub struct SemanticChunkingStrategy {
81 inner: SemanticChunker,
82 document_id: DocumentId,
83}
84
85impl SemanticChunkingStrategy {
86 pub fn new(chunker: SemanticChunker, document_id: DocumentId) -> Self {
88 Self {
89 inner: chunker,
90 document_id,
91 }
92 }
93}
94
95impl ChunkingStrategy for SemanticChunkingStrategy {
96 fn chunk(&self, text: &str) -> Vec<TextChunk> {
97 let sentences: Vec<&str> = text.split(&['.', '!', '?'][..])
103 .filter(|s| !s.trim().is_empty())
104 .collect();
105
106 let mut chunks = Vec::new();
107 let mut current_pos = 0;
108
109 let chunk_size = 5; for chunk_sentences in sentences.chunks(chunk_size) {
112 let chunk_content = chunk_sentences.join(". ") + ".";
113 let chunk_id = ChunkId::new(format!("{}_{}", self.document_id,
114 CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)));
115 let chunk_start = current_pos;
116 let chunk_end = chunk_start + chunk_content.len();
117
118 let chunk = TextChunk::new(
119 chunk_id,
120 self.document_id.clone(),
121 chunk_content,
122 chunk_start,
123 chunk_end,
124 );
125 chunks.push(chunk);
126 current_pos = chunk_end;
127 }
128
129 chunks
130 }
131}
132
/// Chunking strategy for Rust source code, available with the `code-chunking` feature.
#[cfg(feature = "code-chunking")]
pub struct RustCodeChunkingStrategy {
    /// Minimum byte length an item's source text must have to be emitted as a chunk.
    min_chunk_size: usize,
    /// Document id stamped on every chunk and used as the prefix of each chunk id.
    document_id: DocumentId,
}
142
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Creates a strategy that emits items of at least `min_chunk_size` bytes and attributes
    /// them to `document_id`.
    pub fn new(min_chunk_size: usize, document_id: DocumentId) -> Self {
        Self {
            document_id,
            min_chunk_size,
        }
    }
}
153
#[cfg(feature = "code-chunking")]
impl ChunkingStrategy for RustCodeChunkingStrategy {
    /// Parses `text` as Rust with tree-sitter and returns one chunk per collected item.
    ///
    /// Item collection is done by `extract_chunks`. When it collects nothing and `text` is
    /// not blank, the whole input is returned as a single chunk spanning `0..text.len()`.
    ///
    /// # Panics
    ///
    /// Panics if the Rust grammar cannot be loaded into the parser or if the parser returns
    /// no tree.
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        use tree_sitter::Parser;

        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::language())
            .expect("Error loading Rust grammar");

        let tree = parser.parse(text, None).expect("Error parsing Rust code");

        let mut chunks = Vec::new();
        self.extract_chunks(&tree.root_node(), text, &mut chunks);

        // Fallback: nothing qualified as a chunk, so keep the whole (non-blank) input.
        let needs_fallback = chunks.is_empty() && !text.trim().is_empty();
        if needs_fallback {
            let sequence = CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst);
            let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, sequence));
            chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text.to_string(),
                0,
                text.len(),
            ));
        }

        chunks
    }
}
188
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Walks the syntax tree below `node` and appends one chunk per item found.
    ///
    /// A node whose kind is `function_item`, `impl_item`, `struct_item`, `enum_item`,
    /// `mod_item` or `trait_item` becomes a chunk covering its byte range in `source`,
    /// provided that range is at least `min_chunk_size` bytes long. Such a node is never
    /// descended into, so items nested inside it are not emitted separately. Any other node
    /// (including the `source_file` root) is only traversed: each child is visited in order.
    ///
    /// Chunk ids have the form `<document_id>_<n>`, where `n` comes from the module-wide
    /// `CHUNK_COUNTER`; offsets are the node's start/end byte offsets in `source`.
    fn extract_chunks(&self, node: &tree_sitter::Node, source: &str, chunks: &mut Vec<TextChunk>) {
        match node.kind() {
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "mod_item"
            | "trait_item" => {
                // The node's byte offsets index `source` directly.
                let start = node.start_byte();
                let end = node.end_byte();
                let content = &source[start..end];

                if content.len() >= self.min_chunk_size {
                    let sequence = CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst);
                    let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, sequence));

                    chunks.push(TextChunk::new(
                        chunk_id,
                        self.document_id.clone(),
                        content.to_string(),
                        start,
                        end,
                    ));
                }
            }

            // Everything else, the `source_file` root included, is a container to walk.
            _ => {
                let mut child = node.child(0);
                while let Some(current) = child {
                    self.extract_chunks(&current, source, chunks);
                    child = current.next_sibling();
                }
            }
        }
    }
}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every chunk has content and a non-empty offset range.
    fn assert_chunks_well_formed(chunks: &[TextChunk]) {
        for chunk in chunks {
            assert!(!chunk.content.is_empty());
            assert!(chunk.start_offset < chunk.end_offset);
        }
    }

    /// Two short paragraphs must yield at least one well-formed chunk.
    #[test]
    fn test_hierarchical_chunking_strategy() {
        let strategy =
            HierarchicalChunkingStrategy::new(100, 20, DocumentId::new("test_doc".to_string()));

        let text = "This is paragraph one.\n\nThis is paragraph two with more content to test chunking behavior.";
        let chunks = strategy.chunk(text);

        assert!(!chunks.is_empty());
        assert_chunks_well_formed(&chunks);
    }

    /// Only builds the inputs a semantic strategy would need; no strategy is exercised here.
    #[test]
    fn test_semantic_chunking_strategy() {
        let _document_id = DocumentId::new("test_doc".to_string());
        let _config = crate::text::semantic_chunking::SemanticChunkerConfig::default();
    }

    /// A function, a struct and an impl block must produce at least two well-formed chunks.
    #[test]
    #[cfg(feature = "code-chunking")]
    fn test_rust_code_chunking_strategy() {
        let strategy = RustCodeChunkingStrategy::new(10, DocumentId::new("rust_code".to_string()));

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
}

struct Point {
    x: f64,
    y: f64,
}

impl Point {
    fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }
}
"#;

        let chunks = strategy.chunk(rust_code);

        assert!(!chunks.is_empty());
        assert!(chunks.len() >= 2);
        assert_chunks_well_formed(&chunks);
    }
}