//! rlm_rs/chunking/parallel.rs

use crate::chunking::traits::{ChunkMetadata, Chunker};
7use crate::core::Chunk;
8use crate::error::Result;
9use rayon::prelude::*;
10
/// Wraps another [`Chunker`] and runs it over independent text segments in
/// parallel using rayon.
#[derive(Debug, Clone)]
pub struct ParallelChunker<C: Chunker + Clone> {
    /// The underlying chunker applied to each segment.
    inner: C,
    /// Texts shorter than this many bytes are chunked sequentially.
    min_parallel_size: usize,
    /// Number of segments large texts are split into (always >= 1).
    num_segments: usize,
}
34
35impl<C: Chunker + Clone> ParallelChunker<C> {
36 #[must_use]
42 pub fn new(inner: C) -> Self {
43 Self {
44 inner,
45 min_parallel_size: 100_000, num_segments: num_cpus::get().max(2),
47 }
48 }
49
50 #[must_use]
54 pub const fn min_parallel_size(mut self, size: usize) -> Self {
55 self.min_parallel_size = size;
56 self
57 }
58
59 #[must_use]
61 pub fn num_segments(mut self, n: usize) -> Self {
62 self.num_segments = n.max(1);
63 self
64 }
65
66 fn split_into_segments<'a>(&self, text: &'a str, n: usize) -> Vec<(usize, &'a str)> {
68 if n <= 1 || text.len() < self.min_parallel_size {
69 return vec![(0, text)];
70 }
71
72 let segment_size = text.len() / n;
73 let mut segments = Vec::with_capacity(n);
74 let mut start = 0;
75
76 for i in 0..n {
77 let target_end = if i == n - 1 {
78 text.len()
79 } else {
80 start + segment_size
81 };
82
83 let end = Self::find_segment_boundary(text, target_end);
84 let end = end.max(start + 1).min(text.len());
85
86 if start < text.len() {
87 segments.push((start, &text[start..end]));
88 }
89
90 start = end;
91 if start >= text.len() {
92 break;
93 }
94 }
95
96 segments
97 }
98
99 fn find_segment_boundary(text: &str, target: usize) -> usize {
101 if target >= text.len() {
102 return text.len();
103 }
104
105 let search_start = target.saturating_sub(1000);
107 let search_region = &text[search_start..target.min(text.len())];
108
109 if let Some(pos) = search_region.rfind("\n\n") {
110 return search_start + pos + 2;
111 }
112
113 if let Some(pos) = search_region.rfind('\n') {
115 return search_start + pos + 1;
116 }
117
118 if let Some(pos) = search_region.rfind(' ') {
120 return search_start + pos + 1;
121 }
122
123 let mut pos = target;
125 while !text.is_char_boundary(pos) && pos > 0 {
126 pos -= 1;
127 }
128 pos
129 }
130
131 fn merge_chunks(segment_chunks: Vec<Vec<Chunk>>, buffer_id: i64) -> Vec<Chunk> {
133 let mut all_chunks: Vec<Chunk> = Vec::new();
134 let mut index = 0;
135
136 for chunks in segment_chunks {
137 for mut chunk in chunks {
138 chunk.index = index;
139 chunk.buffer_id = buffer_id;
140 all_chunks.push(chunk);
141 index += 1;
142 }
143 }
144
145 all_chunks
146 }
147}
148
149impl<C: Chunker + Clone + Send + Sync> Chunker for ParallelChunker<C> {
150 fn chunk(
151 &self,
152 buffer_id: i64,
153 text: &str,
154 metadata: Option<&ChunkMetadata>,
155 ) -> Result<Vec<Chunk>> {
156 if text.len() < self.min_parallel_size {
158 return self.inner.chunk(buffer_id, text, metadata);
159 }
160
161 let segments = self.split_into_segments(text, self.num_segments);
163
164 if segments.len() <= 1 {
165 return self.inner.chunk(buffer_id, text, metadata);
166 }
167
168 let results: Vec<Result<Vec<Chunk>>> = segments
170 .par_iter()
171 .map(|(offset, segment)| {
172 let mut chunks = self.inner.chunk(buffer_id, segment, metadata)?;
173
174 for chunk in &mut chunks {
176 chunk.byte_range =
177 (chunk.byte_range.start + offset)..(chunk.byte_range.end + offset);
178 }
179
180 Ok(chunks)
181 })
182 .collect();
183
184 let mut all_segment_chunks = Vec::with_capacity(results.len());
186 for result in results {
187 all_segment_chunks.push(result?);
188 }
189
190 Ok(Self::merge_chunks(all_segment_chunks, buffer_id))
192 }
193
194 fn name(&self) -> &'static str {
195 "parallel"
196 }
197
198 fn supports_parallel(&self) -> bool {
199 true
200 }
201
202 fn description(&self) -> &'static str {
203 "Parallel chunking using rayon for multi-threaded processing"
204 }
205}
206
/// Minimal stand-in for the `num_cpus` crate, backed by the standard library.
mod num_cpus {
    /// Returns the number of logical CPUs, defaulting to 4 when the
    /// available parallelism cannot be determined.
    pub fn get() -> usize {
        std::thread::available_parallelism().map_or(4, std::num::NonZeroUsize::get)
    }
}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunking::SemanticChunker;

    /// A text below the parallel threshold is delegated straight to the
    /// inner chunker and comes back as a single chunk.
    #[test]
    fn test_parallel_chunker_small_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(50));
        let input = "Hello, world!";
        let chunks = chunker.chunk(1, input, None).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, input);
    }

    /// Large inputs go through the parallel path; every chunk's byte range
    /// must map back to its content and indices must be sequential.
    #[test]
    fn test_parallel_chunker_large_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(1000))
            .min_parallel_size(1000)
            .num_segments(4);

        let input = "Hello, world! This is a test sentence. ".repeat(500);
        let chunks = chunker.chunk(1, &input, None).unwrap();

        for (position, chunk) in chunks.iter().enumerate() {
            assert!(!chunk.content.is_empty());
            assert_eq!(&input[chunk.byte_range.clone()], chunk.content);
            assert_eq!(chunk.index, position);
        }
    }

    /// Rebuilding the text from chunk contents (dropping overlap with the
    /// already-consumed region, tolerating gaps) yields non-empty output.
    #[test]
    fn test_parallel_chunker_preserves_content() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(500))
            .min_parallel_size(500)
            .num_segments(2);

        let input =
            "Paragraph one. Sentence two.\n\nParagraph two. More text here.\n\n".repeat(50);
        let chunks = chunker.chunk(1, &input, None).unwrap();

        let mut rebuilt = String::new();
        let mut consumed = 0;

        for chunk in &chunks {
            let begin = chunk.byte_range.start;
            if begin == consumed {
                rebuilt.push_str(&chunk.content);
            } else if begin < consumed {
                // Overlapping chunk: append only the unseen tail.
                let overlap = consumed - begin;
                if overlap < chunk.content.len() {
                    rebuilt.push_str(&chunk.content[overlap..]);
                }
            }
            // begin > consumed (a gap) is tolerated.
            consumed = chunk.byte_range.end;
        }

        assert!(!chunks.is_empty());
        assert!(!rebuilt.is_empty());
    }

    #[test]
    fn test_parallel_chunker_strategy_name() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        assert_eq!(chunker.name(), "parallel");
        assert!(chunker.supports_parallel());
    }

    #[test]
    fn test_split_into_segments() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(10)
            .num_segments(3);

        let input = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let segments = chunker.split_into_segments(input, 3);

        assert!(!segments.is_empty());
        assert!(segments.iter().all(|(_, segment)| !segment.is_empty()));
    }

    #[test]
    fn test_parallel_chunker_empty_text() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        assert!(chunker.chunk(1, "", None).unwrap().is_empty());
    }

    #[test]
    fn test_split_into_segments_single_segment() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(10)
            .num_segments(1);

        let input = "This is some test content";
        let segments = chunker.split_into_segments(input, 1);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].1, input);
    }

    #[test]
    fn test_split_into_segments_text_too_small() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(1000)
            .num_segments(4);

        let input = "Short text";
        let segments = chunker.split_into_segments(input, 4);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].1, input);
    }

    #[test]
    fn test_parallel_chunker_segments_collapse_to_one() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(100))
            .min_parallel_size(10)
            .num_segments(10);

        let chunks = chunker
            .chunk(1, "A short text that won't split well.", None)
            .unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_parallel_chunker_description() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        let description = chunker.description();
        assert!(description.contains("Parallel"));
        assert!(!description.is_empty());
    }

    #[test]
    fn test_find_segment_boundary_no_good_boundary() {
        let input = "AAAAAAAAAAAAAAAAAAAA";
        let boundary = ParallelChunker::<SemanticChunker>::find_segment_boundary(input, 10);
        assert!(boundary <= input.len());
    }

    #[test]
    fn test_find_segment_boundary_at_end() {
        let boundary = ParallelChunker::<SemanticChunker>::find_segment_boundary("Short", 100);
        assert_eq!(boundary, "Short".len());
    }

    #[test]
    fn test_find_segment_boundary_finds_space() {
        let input = "word1 word2 word3 word4";
        let boundary = ParallelChunker::<SemanticChunker>::find_segment_boundary(input, 15);
        assert!(boundary > 0 && boundary <= input.len());
    }
}