rlm_rs/chunking/
parallel.rs1use crate::chunking::traits::{ChunkMetadata, Chunker};
7use crate::core::Chunk;
8use crate::error::Result;
9use rayon::prelude::*;
10
11#[derive(Debug, Clone)]
26pub struct ParallelChunker<C: Chunker + Clone> {
27 inner: C,
29 min_parallel_size: usize,
31 num_segments: usize,
33}
34
35impl<C: Chunker + Clone> ParallelChunker<C> {
36 #[must_use]
42 pub fn new(inner: C) -> Self {
43 Self {
44 inner,
45 min_parallel_size: 100_000, num_segments: num_cpus::get().max(2),
47 }
48 }
49
50 #[must_use]
54 pub const fn min_parallel_size(mut self, size: usize) -> Self {
55 self.min_parallel_size = size;
56 self
57 }
58
59 #[must_use]
61 pub fn num_segments(mut self, n: usize) -> Self {
62 self.num_segments = n.max(1);
63 self
64 }
65
66 fn split_into_segments<'a>(&self, text: &'a str, n: usize) -> Vec<(usize, &'a str)> {
68 if n <= 1 || text.len() < self.min_parallel_size {
69 return vec![(0, text)];
70 }
71
72 let segment_size = text.len() / n;
73 let mut segments = Vec::with_capacity(n);
74 let mut start = 0;
75
76 for i in 0..n {
77 let target_end = if i == n - 1 {
78 text.len()
79 } else {
80 start + segment_size
81 };
82
83 let end = Self::find_segment_boundary(text, target_end);
84 let end = end.max(start + 1).min(text.len());
85
86 if start < text.len() {
87 segments.push((start, &text[start..end]));
88 }
89
90 start = end;
91 if start >= text.len() {
92 break;
93 }
94 }
95
96 segments
97 }
98
99 fn find_segment_boundary(text: &str, target: usize) -> usize {
101 if target >= text.len() {
102 return text.len();
103 }
104
105 let search_start = target.saturating_sub(1000);
107 let search_region = &text[search_start..target.min(text.len())];
108
109 if let Some(pos) = search_region.rfind("\n\n") {
110 return search_start + pos + 2;
111 }
112
113 if let Some(pos) = search_region.rfind('\n') {
115 return search_start + pos + 1;
116 }
117
118 if let Some(pos) = search_region.rfind(' ') {
120 return search_start + pos + 1;
121 }
122
123 let mut pos = target;
125 while !text.is_char_boundary(pos) && pos > 0 {
126 pos -= 1;
127 }
128 pos
129 }
130
131 fn merge_chunks(segment_chunks: Vec<Vec<Chunk>>, buffer_id: i64) -> Vec<Chunk> {
133 let mut all_chunks: Vec<Chunk> = Vec::new();
134 let mut index = 0;
135
136 for chunks in segment_chunks {
137 for mut chunk in chunks {
138 chunk.index = index;
139 chunk.buffer_id = buffer_id;
140 all_chunks.push(chunk);
141 index += 1;
142 }
143 }
144
145 all_chunks
146 }
147}
148
149impl<C: Chunker + Clone + Send + Sync> Chunker for ParallelChunker<C> {
150 fn chunk(
151 &self,
152 buffer_id: i64,
153 text: &str,
154 metadata: Option<&ChunkMetadata>,
155 ) -> Result<Vec<Chunk>> {
156 if text.len() < self.min_parallel_size {
158 return self.inner.chunk(buffer_id, text, metadata);
159 }
160
161 let segments = self.split_into_segments(text, self.num_segments);
163
164 if segments.len() <= 1 {
165 return self.inner.chunk(buffer_id, text, metadata);
166 }
167
168 let results: Vec<Result<Vec<Chunk>>> = segments
170 .par_iter()
171 .map(|(offset, segment)| {
172 let mut chunks = self.inner.chunk(buffer_id, segment, metadata)?;
173
174 for chunk in &mut chunks {
176 chunk.byte_range =
177 (chunk.byte_range.start + offset)..(chunk.byte_range.end + offset);
178 }
179
180 Ok(chunks)
181 })
182 .collect();
183
184 let mut all_segment_chunks = Vec::with_capacity(results.len());
186 for result in results {
187 all_segment_chunks.push(result?);
188 }
189
190 Ok(Self::merge_chunks(all_segment_chunks, buffer_id))
192 }
193
194 fn name(&self) -> &'static str {
195 "parallel"
196 }
197
198 fn supports_parallel(&self) -> bool {
199 true
200 }
201
202 fn description(&self) -> &'static str {
203 "Parallel chunking using rayon for multi-threaded processing"
204 }
205}
206
/// Minimal stand-in for a CPU-count helper, backed by the standard library.
mod num_cpus {
    use std::thread;

    /// Returns the parallelism reported by the standard library, or 4 when
    /// that query fails.
    pub fn get() -> usize {
        match thread::available_parallelism() {
            Ok(count) => count.get(),
            Err(_) => 4,
        }
    }
}
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunking::SemanticChunker;

    /// Shorthand for the boundary search on a concrete chunker type.
    fn boundary(text: &str, target: usize) -> usize {
        ParallelChunker::<SemanticChunker>::find_segment_boundary(text, target)
    }

    /// Asserts that `segments` are non-empty, contiguous and cover all of `text`.
    fn assert_covers(text: &str, segments: &[(usize, &str)]) {
        let mut pos = 0;
        for (offset, segment) in segments {
            assert_eq!(*offset, pos, "segment must start where the previous one ended");
            assert!(!segment.is_empty());
            assert_eq!(&text[*offset..*offset + segment.len()], *segment);
            pos += segment.len();
        }
        assert_eq!(pos, text.len(), "segments must cover the whole text");
    }

    #[test]
    fn test_parallel_chunker_small_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(50));
        let text = "Hello, world!";
        let chunks = chunker.chunk(1, text, None).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, text);
    }

    #[test]
    fn test_parallel_chunker_large_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(1000))
            .min_parallel_size(1000)
            .num_segments(4);

        let text = "Hello, world! This is a test sentence. ".repeat(500);

        let chunks = chunker.chunk(1, &text, None).unwrap();

        // Byte ranges must be expressed in whole-text coordinates.
        for chunk in &chunks {
            assert!(!chunk.content.is_empty());
            assert_eq!(&text[chunk.byte_range.clone()], chunk.content);
        }

        // Indices are renumbered consecutively across segments.
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.index, i);
        }
    }

    #[test]
    fn test_parallel_chunker_preserves_content() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(500))
            .min_parallel_size(500)
            .num_segments(2);

        let text =
            "Paragraph one. Sentence two.\n\nParagraph two. More text here.\n\n".repeat(50);

        let chunks = chunker.chunk(1, &text, None).unwrap();

        // Stitch chunk contents back together, dropping any overlap with the
        // previous chunk. Gaps between chunks are tolerated, so this is a
        // smoke check rather than an exact round-trip.
        let mut reconstructed = String::new();
        let mut last_end = 0;

        for chunk in &chunks {
            use std::cmp::Ordering;
            match chunk.byte_range.start.cmp(&last_end) {
                Ordering::Greater => {}
                Ordering::Less => {
                    let skip = last_end - chunk.byte_range.start;
                    if skip < chunk.content.len() {
                        reconstructed.push_str(&chunk.content[skip..]);
                    }
                }
                Ordering::Equal => {
                    reconstructed.push_str(&chunk.content);
                }
            }
            last_end = chunk.byte_range.end;
        }

        assert!(!chunks.is_empty());
        assert!(!reconstructed.is_empty());
    }

    #[test]
    fn test_parallel_chunker_strategy_name() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        assert_eq!(chunker.name(), "parallel");
        assert!(chunker.supports_parallel());
    }

    #[test]
    fn test_split_into_segments() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(10)
            .num_segments(3);

        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let segments = chunker.split_into_segments(text, 3);

        assert!(!segments.is_empty());
        assert!(segments.len() <= 3);
        assert_covers(text, &segments);
    }

    #[test]
    fn test_parallel_chunker_empty_text() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        let chunks = chunker.chunk(1, "", None).unwrap();
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_split_into_segments_single_segment() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(10)
            .num_segments(1);

        let text = "This is some test content";
        let segments = chunker.split_into_segments(text, 1);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].0, 0);
        assert_eq!(segments[0].1, text);
    }

    #[test]
    fn test_split_into_segments_text_too_small() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(1000)
            .num_segments(4);

        let text = "Short text";
        let segments = chunker.split_into_segments(text, 4);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].0, 0);
        assert_eq!(segments[0].1, text);
    }

    #[test]
    fn test_parallel_chunker_segments_collapse_to_one() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(100))
            .min_parallel_size(10)
            .num_segments(10);

        let text = "A short text that won't split well.";
        let chunks = chunker.chunk(1, text, None).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_parallel_chunker_description() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        let desc = chunker.description();
        assert!(desc.contains("Parallel"));
        assert!(!desc.is_empty());
    }

    #[test]
    fn test_find_segment_boundary_no_good_boundary() {
        // No newline or space anywhere: the target itself is the split point.
        let text = "AAAAAAAAAAAAAAAAAAAA";
        assert_eq!(boundary(text, 10), 10);
    }

    #[test]
    fn test_find_segment_boundary_at_end() {
        let text = "Short";
        assert_eq!(boundary(text, 100), text.len());
    }

    #[test]
    fn test_find_segment_boundary_finds_space() {
        // The last space before offset 15 is at index 11, so the split lands
        // just after it.
        let text = "word1 word2 word3 word4";
        assert_eq!(boundary(text, 15), 12);
    }
}