use super::*;
use std::collections::HashMap;
/// Depth-first-search chunker that partitions the configured initial splits into
/// consecutive chunks whose estimated token counts all lie within
/// `[absolute_length_min, length_max]`.
///
/// Chunks are only ever cut on split boundaries; if no such partition is found,
/// [`DfsTextChunker::run`] returns `None`.
pub struct DfsTextChunker {
    /// The splits to partition, in text order.
    splits: VecDeque<TextSplit>,
    /// Shared configuration: length limits and the `chunks_found` cancellation flag.
    config: Arc<ChunkerConfig>,
    /// Estimated token count of all splits, set once in [`DfsTextChunker::run`].
    /// Not read by the search methods below.
    remaining_token_count: f32,
    /// Start indices whose search subtree was fully explored without producing a
    /// partition, mapped to the candidate chunk-end indices that were tried from them.
    /// A successful search unwinds immediately, so any start found here is a dead end
    /// and can be pruned when another path reaches it.
    valid_split_indices_memo: HashMap<usize, Vec<usize>>,
}

impl DfsTextChunker {
    /// Attempts to partition `config.initial_splits` into chunks.
    ///
    /// Returns `None` when:
    /// - any initial split's token count exceeds `config.length_max` (also logged to
    ///   stderr),
    /// - no partition satisfying the length bounds is found,
    /// - the search is cancelled because `config.chunks_found` is set, or
    /// - the partition found consists of a single chunk.
    ///
    /// # Panics
    /// Panics if unwrapping any initial split's `token_count` fails.
    pub fn run(config: &Arc<ChunkerConfig>) -> Option<Vec<Chunk>> {
        let splits = config.initial_splits.clone();
        // A split larger than `length_max` can never fit into any chunk, so fail fast.
        if splits
            .iter()
            .any(|split: &TextSplit| split.token_count.unwrap() as f32 > config.length_max)
        {
            eprintln!(
                "\nPure semantic chunking is impossible for separator: {:#?}.\nA splits token count is more than length_max: {:#?}.",
                config.initial_separator, config.length_max,
            );
            return None;
        }
        let remaining_token_count = config.estimate_splits_token_count(&splits);
        let mut chunker = DfsTextChunker {
            splits,
            config: Arc::clone(config),
            remaining_token_count,
            valid_split_indices_memo: HashMap::new(),
        };
        let chunk_split_indexes = chunker.find_valid_chunk_combinations()?;
        chunker.create_chunks(chunk_split_indexes)
    }

    /// Runs the search from the first split and rejects single-chunk results.
    ///
    /// Returns the inclusive end index of every chunk, in order.
    fn find_valid_chunk_combinations(&mut self) -> Option<Vec<usize>> {
        // A one-element result means the whole text became a single chunk, which is
        // not treated as a successful chunking. The search never returns an empty
        // vector, so `> 1` is exactly "not a single chunk".
        self.recursive_chunk_tester(0)
            .filter(|chunk_ends| chunk_ends.len() > 1)
    }

    /// Depth-first search for a partition of the splits from index `start` onward.
    ///
    /// Returns the inclusive end index of each chunk, in order, the last being
    /// `self.splits.len() - 1`. Returns `None` if no partition exists from `start`,
    /// if `start` was already explored without success, or if the search was
    /// cancelled through `config.chunks_found`.
    fn recursive_chunk_tester(&mut self, start: usize) -> Option<Vec<usize>> {
        if self.config.chunks_found.load(Ordering::Relaxed) {
            return None;
        }
        // Any memoized start has already failed (success unwinds the whole search),
        // so exploring it again cannot succeed.
        if self.valid_split_indices_memo.contains_key(&start) {
            return None;
        }
        let candidate_ends = self.find_valid_split_indices_for_chunk(start);
        if candidate_ends.is_empty() {
            self.valid_split_indices_memo.insert(start, candidate_ends);
            return None;
        }
        // Prefer closing out the remaining text with one chunk when that is valid.
        if let Some(&end) = candidate_ends
            .iter()
            .find(|&&end| end + 1 == self.splits.len())
        {
            return Some(vec![end]);
        }
        // Otherwise try each candidate first chunk, shortest first.
        for &end in &candidate_ends {
            if let Some(mut chunk_ends) = self.recursive_chunk_tester(end + 1) {
                chunk_ends.insert(0, end);
                return Some(chunk_ends);
            }
        }
        // Dead end: record it (moving, not cloning, the candidates) so that other
        // paths reaching `start` are pruned.
        self.valid_split_indices_memo.insert(start, candidate_ends);
        None
    }

    /// Returns, in increasing order, every index `end >= start` for which the chunk
    /// made of splits `start..=end` has an estimated token count within
    /// `[absolute_length_min, length_max]`.
    ///
    /// Scanning stops at the first prefix that has reached the minimum and whose
    /// estimate exceeds `length_max`.
    fn find_valid_split_indices_for_chunk(&self, start: usize) -> Vec<usize> {
        let mut valid_split_indices = Vec::new();
        let mut chunk = Chunk::new(&self.config);
        for (index, split) in self.splits.iter().enumerate().skip(start) {
            chunk.add_split(split.clone(), false);
            if chunk.estimated_token_count < self.config.absolute_length_min as f32 {
                // Still too short to stand on its own; keep growing the chunk.
                continue;
            }
            if chunk.estimated_token_count > self.config.length_max {
                // Adding more splits is assumed not to shrink the estimate, so stop.
                break;
            }
            valid_split_indices.push(index);
        }
        valid_split_indices
    }

    /// Materializes chunks from the inclusive end indices produced by the search.
    ///
    /// The first chunk starts at split 0 and each subsequent chunk starts right after
    /// the previous chunk's end. An empty index list yields an empty vector.
    fn create_chunks(&self, chunk_split_indexes: Vec<usize>) -> Option<Vec<Chunk>> {
        let mut chunks = Vec::with_capacity(chunk_split_indexes.len());
        let mut chunk_start = 0;
        for chunk_end in chunk_split_indexes {
            let mut chunk = Chunk::new(&self.config);
            for split in self.splits.range(chunk_start..=chunk_end) {
                chunk.add_split(split.clone(), false);
            }
            chunks.push(chunk);
            chunk_start = chunk_end + 1;
        }
        Some(chunks)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use llm_models::local_model::gguf::preset::LlmPreset;

    /// Chunk texts expected from the fixture text in [`runner`], in order.
    const EXPECTED_CHUNKS: [&str; 3] = [
        "One one one one.",
        "Two two two two.",
        "Three three three three.",
    ];

    /// Builds a [`ChunkerConfig`] for a fixed three-paragraph fixture and runs
    /// [`DfsTextChunker`] on it with the given tokenizer and separator.
    fn runner(
        tokenizer: &std::sync::Arc<LlmTokenizer>,
        separator: Separator,
    ) -> Option<Vec<Chunk>> {
        let chunks_found: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
        let incoming_text =
            "\n\nOne one one one.\n\nTwo two two two.\n\n\nThree three three three.\n\n";
        let absolute_length_max = 5;
        let config = Arc::new(ChunkerConfig::new(
            &chunks_found,
            separator,
            incoming_text,
            absolute_length_max,
            None,
            Some(0.0),
            Arc::clone(tokenizer),
        )?);
        DfsTextChunker::run(&config)
    }

    /// Runs the fixture through every separator with `tokenizer` and asserts the
    /// exact chunk texts, including the number of chunks.
    fn assert_chunks_for_all_separators(tokenizer: &Arc<LlmTokenizer>) {
        let separators = [
            Separator::TwoPlusEoL,
            Separator::SingleEol,
            Separator::SentencesRuleBased,
            Separator::SentencesUnicode,
        ];
        for separator in separators {
            let mut chunks =
                runner(tokenizer, separator).expect("DfsTextChunker::run returned None");
            let texts: Vec<String> = chunks.iter_mut().map(|chunk| chunk.text()).collect();
            // Comparing whole vectors also catches missing or extra chunks, which a
            // per-index comparison would silently skip or panic on.
            assert_eq!(texts, EXPECTED_CHUNKS);
        }
    }

    #[test]
    fn all() {
        let tiktoken: Arc<LlmTokenizer> =
            Arc::new(LlmTokenizer::new_tiktoken(TOKENIZER_TIKTOKEN_DEFAULT).unwrap());
        assert_chunks_for_all_separators(&tiktoken);

        let llama_tokenizer = LlmPreset::Llama3_1_8bInstruct
            .load()
            .unwrap()
            .model_base
            .tokenizer;
        assert_chunks_for_all_separators(&llama_tokenizer);
    }
}