use crate::chunking::traits::{ChunkMetadata, Chunker};
use crate::core::Chunk;
use crate::error::Result;
use rayon::prelude::*;
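
/// Wraps any [`Chunker`] and fans chunking work out across text segments with
/// rayon. Inputs below `min_parallel_size` fall back to the inner chunker on a
/// single thread, so the wrapper is safe to use unconditionally.
///
/// A minimal usage sketch (the inner `SemanticChunker` here is the one used in
/// the tests below; any `Chunker + Clone + Send + Sync` works):
///
/// ```ignore
/// let chunker = ParallelChunker::new(SemanticChunker::with_size(1_000))
///     .min_parallel_size(50_000)
///     .num_segments(8);
/// let chunks = chunker.chunk(1, &large_text, None)?;
/// ```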
#[derive(Debug, Clone)]
pub struct ParallelChunker<C: Chunker + Clone> {
    /// The wrapped chunker that does the actual chunking of each segment.
    inner: C,
    /// Inputs shorter than this many bytes are chunked on a single thread.
    min_parallel_size: usize,
    /// Target number of segments to split the input into.
    num_segments: usize,
}
impl<C: Chunker + Clone> ParallelChunker<C> {
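    /// Wraps `inner`, defaulting to a ~100 KB parallelism threshold and one
    /// segment per logical CPU (with a floor of two).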
    #[must_use]
    pub fn new(inner: C) -> Self {
        Self {
            inner,
            min_parallel_size: 100_000,
            num_segments: num_cpus::get().max(2),
        }
    }
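
    /// Sets the minimum input size, in bytes, below which chunking stays
    /// single-threaded.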
    #[must_use]
    pub const fn min_parallel_size(mut self, size: usize) -> Self {
        self.min_parallel_size = size;
        self
    }
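
    /// Sets the number of segments the text is split into (clamped to at
    /// least one).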
    #[must_use]
    pub fn num_segments(mut self, n: usize) -> Self {
        self.num_segments = n.max(1);
        self
    }
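
    /// Splits `text` into up to `n` segments at natural boundaries, returning
    /// each segment together with its byte offset into the original text.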
    fn split_into_segments<'a>(&self, text: &'a str, n: usize) -> Vec<(usize, &'a str)> {
        if n <= 1 || text.len() < self.min_parallel_size {
            return vec![(0, text)];
        }
        let segment_size = text.len() / n;
        let mut segments = Vec::with_capacity(n);
        let mut start = 0;
        for i in 0..n {
            // The last segment absorbs any remainder from the integer division.
            let target_end = if i == n - 1 {
                text.len()
            } else {
                start + segment_size
            };
            let end = Self::find_segment_boundary(text, target_end);
            // Guarantee forward progress, but never step past the end of the
            // text or leave `end` in the middle of a multi-byte character
            // (the `start + 1` clamp could otherwise split a code point).
            let mut end = end.max(start + 1).min(text.len());
            while !text.is_char_boundary(end) {
                end += 1;
            }
            if start < text.len() {
                segments.push((start, &text[start..end]));
            }
            start = end;
            if start >= text.len() {
                break;
            }
        }
        segments
    }
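
    /// Finds a natural split point at or before `target`, preferring paragraph
    /// breaks, then line breaks, then spaces, and finally falling back to the
    /// nearest UTF-8 character boundary.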
    fn find_segment_boundary(text: &str, target: usize) -> usize {
        if target >= text.len() {
            return text.len();
        }
        // Snap the target down to a character boundary so the slice below
        // cannot panic on multi-byte UTF-8 text.
        let mut end = target;
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        // Look back up to 1000 bytes for a natural break, again starting on a
        // character boundary.
        let mut search_start = end.saturating_sub(1000);
        while !text.is_char_boundary(search_start) {
            search_start -= 1;
        }
        let search_region = &text[search_start..end];
        if let Some(pos) = search_region.rfind("\n\n") {
            return search_start + pos + 2;
        }
        if let Some(pos) = search_region.rfind('\n') {
            return search_start + pos + 1;
        }
        if let Some(pos) = search_region.rfind(' ') {
            return search_start + pos + 1;
        }
        // No natural break found: fall back to the character boundary itself.
        end
    }
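
    /// Flattens per-segment chunk lists into one list, reassigning sequential
    /// indices and a uniform buffer id.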
    fn merge_chunks(segment_chunks: Vec<Vec<Chunk>>, buffer_id: i64) -> Vec<Chunk> {
        let mut all_chunks: Vec<Chunk> = Vec::new();
        let mut index = 0;
        for chunks in segment_chunks {
            for mut chunk in chunks {
                chunk.index = index;
                chunk.buffer_id = buffer_id;
                all_chunks.push(chunk);
                index += 1;
            }
        }
        all_chunks
    }
}
impl<C: Chunker + Clone + Send + Sync> Chunker for ParallelChunker<C> {
    fn chunk(
        &self,
        buffer_id: i64,
        text: &str,
        metadata: Option<&ChunkMetadata>,
    ) -> Result<Vec<Chunk>> {
        // Small inputs are not worth the thread overhead.
        if text.len() < self.min_parallel_size {
            return self.inner.chunk(buffer_id, text, metadata);
        }
        let segments = self.split_into_segments(text, self.num_segments);
        if segments.len() <= 1 {
            return self.inner.chunk(buffer_id, text, metadata);
        }
        // Chunk each segment in parallel, then shift byte ranges back into
        // the coordinate space of the full text.
        let results: Vec<Result<Vec<Chunk>>> = segments
            .par_iter()
            .map(|(offset, segment)| {
                let mut chunks = self.inner.chunk(buffer_id, segment, metadata)?;
                for chunk in &mut chunks {
                    chunk.byte_range =
                        (chunk.byte_range.start + offset)..(chunk.byte_range.end + offset);
                }
                Ok(chunks)
            })
            .collect();
        // Propagate the first error, if any, before merging.
        let mut all_segment_chunks = Vec::with_capacity(results.len());
        for result in results {
            all_segment_chunks.push(result?);
        }
        Ok(Self::merge_chunks(all_segment_chunks, buffer_id))
    }

    fn name(&self) -> &'static str {
        "parallel"
    }

    fn supports_parallel(&self) -> bool {
        true
    }

    fn description(&self) -> &'static str {
        "Parallel chunking using rayon for multi-threaded processing"
    }
}
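
/// Minimal stand-in for the `num_cpus` crate, backed by the standard
/// library's `available_parallelism`, defaulting to four threads when the
/// parallelism cannot be determined.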
mod num_cpus {
    pub fn get() -> usize {
        std::thread::available_parallelism()
            .map(std::num::NonZeroUsize::get)
            .unwrap_or(4)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunking::SemanticChunker;

    #[test]
    fn test_parallel_chunker_small_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(50));
        let text = "Hello, world!";
        let chunks = chunker.chunk(1, text, None).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].content, text);
    }

    #[test]
    fn test_parallel_chunker_large_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(1000))
            .min_parallel_size(1000)
            .num_segments(4);
        let text = "Hello, world! This is a test sentence. ".repeat(500);
        let chunks = chunker.chunk(1, &text, None).unwrap();
        for chunk in &chunks {
            assert!(!chunk.content.is_empty());
            assert_eq!(&text[chunk.byte_range.clone()], chunk.content);
        }
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.index, i);
        }
    }
    #[test]
    fn test_parallel_chunker_preserves_content() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(500))
            .min_parallel_size(500)
            .num_segments(2);
        let text = "Paragraph one. Sentence two.\n\nParagraph two. More text here.\n\n".repeat(50);
        let chunks = chunker.chunk(1, &text, None).unwrap();
        let mut reconstructed = String::new();
        let mut last_end = 0;
        for chunk in &chunks {
            use std::cmp::Ordering;
            match chunk.byte_range.start.cmp(&last_end) {
                Ordering::Greater => {
                    // Gap before this chunk: the skipped bytes are lost, but
                    // the chunk's own content still belongs in the output.
                    reconstructed.push_str(&chunk.content);
                }
                Ordering::Less => {
                    // Overlap with the previous chunk: drop the duplicated prefix.
                    let skip = last_end - chunk.byte_range.start;
                    if skip < chunk.content.len() {
                        reconstructed.push_str(&chunk.content[skip..]);
                    }
                }
                Ordering::Equal => {
                    reconstructed.push_str(&chunk.content);
                }
            }
            last_end = chunk.byte_range.end;
        }
        assert!(!chunks.is_empty());
        assert!(!reconstructed.is_empty());
    }
    #[test]
    fn test_parallel_chunker_strategy_name() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        assert_eq!(chunker.name(), "parallel");
        assert!(chunker.supports_parallel());
    }

    #[test]
    fn test_split_into_segments() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(10)
            .num_segments(3);
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let segments = chunker.split_into_segments(text, 3);
        assert!(!segments.is_empty());
        for (_, segment) in &segments {
            assert!(!segment.is_empty());
        }
    }

    #[test]
    fn test_parallel_chunker_empty_text() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        let chunks = chunker.chunk(1, "", None).unwrap();
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_split_into_segments_single_segment() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(10)
            .num_segments(1);
        let text = "This is some test content";
        let segments = chunker.split_into_segments(text, 1);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].1, text);
    }

    #[test]
    fn test_split_into_segments_text_too_small() {
        let chunker = ParallelChunker::new(SemanticChunker::new())
            .min_parallel_size(1000)
            .num_segments(4);
        let text = "Short text";
        let segments = chunker.split_into_segments(text, 4);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0].1, text);
    }

    #[test]
    fn test_parallel_chunker_segments_collapse_to_one() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(100))
            .min_parallel_size(10)
            .num_segments(10);
        let text = "A short text that won't split well.";
        let chunks = chunker.chunk(1, text, None).unwrap();
        assert!(!chunks.is_empty());
    }

    #[test]
    fn test_parallel_chunker_description() {
        let chunker = ParallelChunker::new(SemanticChunker::new());
        let desc = chunker.description();
        assert!(desc.contains("Parallel"));
        assert!(!desc.is_empty());
    }

    #[test]
    fn test_find_segment_boundary_no_good_boundary() {
        let text = "AAAAAAAAAAAAAAAAAAAA";
        let boundary = ParallelChunker::<SemanticChunker>::find_segment_boundary(text, 10);
        assert!(boundary <= text.len());
    }

    #[test]
    fn test_find_segment_boundary_at_end() {
        let text = "Short";
        let boundary = ParallelChunker::<SemanticChunker>::find_segment_boundary(text, 100);
        assert_eq!(boundary, text.len());
    }

    #[test]
    fn test_find_segment_boundary_finds_space() {
        let text = "word1 word2 word3 word4";
        let boundary = ParallelChunker::<SemanticChunker>::find_segment_boundary(text, 15);
        assert!(boundary > 0 && boundary <= text.len());
    }
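
    // A sketch of a regression test for the UTF-8 boundary handling above:
    // splitting must not panic when a segment target lands inside a
    // multi-byte character. Assumes SemanticChunker keeps the byte_range ->
    // content invariant asserted in test_parallel_chunker_large_text.
    #[test]
    fn test_parallel_chunker_multibyte_text() {
        let chunker = ParallelChunker::new(SemanticChunker::with_size(200))
            .min_parallel_size(100)
            .num_segments(4);
        // 22 bytes per repeat, so naive segment targets fall mid-character.
        let text = "αβγδε ζηθικ ".repeat(101);
        let chunks = chunker.chunk(1, &text, None).unwrap();
        for chunk in &chunks {
            assert_eq!(&text[chunk.byte_range.clone()], chunk.content);
        }
    }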
}