use crate::tokenizer::{CharacterTokenizer, Tokenizer};
use crate::types::Chunk;
use rayon::prelude::*;
/// Splits text into fixed-size, optionally overlapping windows of tokens.
///
/// The `Tokenizer` bound lives on the `impl` blocks rather than on the struct, so code
/// that merely stores or passes a `TokenChunker<T>` does not have to repeat it.
pub struct TokenChunker<T> {
    /// Tokenizer used to encode the input text and decode each window back to text.
    pub tokenizer: T,
    /// Maximum number of tokens in a single chunk.
    pub chunk_size: usize,
    /// Number of tokens shared by consecutive chunks; expected to be smaller than
    /// `chunk_size` so the window can advance.
    pub chunk_overlap: usize,
}
impl<T: Tokenizer> TokenChunker<T> {
    /// Creates a chunker that splits text into windows of `chunk_size` tokens, with
    /// consecutive windows sharing `chunk_overlap` tokens.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0 or if `chunk_overlap >= chunk_size`; with either
    /// configuration the window could never advance.
    pub fn new(tokenizer: T, chunk_size: usize, chunk_overlap: usize) -> Self {
        Self::assert_valid(chunk_size, chunk_overlap);
        Self {
            tokenizer,
            chunk_size,
            chunk_overlap,
        }
    }

    /// Checks the window configuration. Called from both `new` and `chunk` because the
    /// fields are public and may be changed after construction.
    fn assert_valid(chunk_size: usize, chunk_overlap: usize) {
        assert!(chunk_size > 0, "chunk_size must be greater than 0");
        assert!(
            chunk_overlap < chunk_size,
            "chunk_overlap ({chunk_overlap}) must be smaller than chunk_size ({chunk_size})"
        );
    }

    /// Splits `text` into token windows and returns one [`Chunk`] per window.
    ///
    /// A new window starts every `chunk_size - chunk_overlap` tokens. The last window may
    /// hold fewer than `chunk_size` tokens. When `chunk_overlap > 0`, it may also lie
    /// entirely inside the previous window.
    ///
    /// `start_index`/`end_index` are byte offsets built by summing the lengths of decoded
    /// token runs. They line up with `text` only if decoding consecutive token slices
    /// reproduces the original text, as with the character tokenizer used in the tests.
    /// NOTE(review): confirm this for subword tokenizers.
    ///
    /// Returns an empty vector when `text` encodes to no tokens.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0 or `chunk_overlap >= chunk_size`.
    pub fn chunk(&self, text: &str) -> Vec<Chunk> {
        Self::assert_valid(self.chunk_size, self.chunk_overlap);
        let step = self.chunk_size - self.chunk_overlap;

        let tokens = self.tokenizer.encode(text);
        let mut chunks = Vec::with_capacity(tokens.len().div_ceil(step));
        let mut current_pos = 0;
        // Byte offset (in decoded text) at which `tokens[current_pos]` begins.
        let mut start_index = 0;

        while current_pos < tokens.len() {
            let end = (current_pos + self.chunk_size).min(tokens.len());
            let chunk_tokens = &tokens[current_pos..end];
            let chunk_text = self.tokenizer.decode(chunk_tokens);
            let end_index = start_index + chunk_text.len();

            // The next window begins `step` tokens later, so advance the offset by the
            // decoded length of just those tokens. Jumping to this window's end would
            // ignore the overlap. Without overlap the step covers the whole window, so
            // the already-decoded text is reused.
            let advance = if step >= chunk_tokens.len() {
                chunk_text.len()
            } else {
                self.tokenizer.decode(&chunk_tokens[..step]).len()
            };

            chunks.push(Chunk::new(
                chunk_text,
                start_index,
                end_index,
                chunk_tokens.len(),
            ));
            start_index += advance;
            current_pos += step;
        }
        chunks
    }

    /// Chunks every text in `texts` in parallel using rayon. Results are returned in
    /// the same order as the inputs.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TokenChunker::chunk`].
    pub fn chunk_batch(&self, texts: &[String]) -> Vec<Vec<Chunk>> {
        texts.par_iter().map(|text| self.chunk(text)).collect()
    }
}
impl Default for TokenChunker<CharacterTokenizer> {
    /// Builds a character-level chunker with 512-token windows and no overlap.
    fn default() -> Self {
        const DEFAULT_CHUNK_SIZE: usize = 512;
        const DEFAULT_CHUNK_OVERLAP: usize = 0;

        let tokenizer = CharacterTokenizer::new();
        Self {
            chunk_overlap: DEFAULT_CHUNK_OVERLAP,
            chunk_size: DEFAULT_CHUNK_SIZE,
            tokenizer,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::CharacterTokenizer;

    /// Character-level chunker with 12-token windows and no overlap, shared by the tests.
    fn twelve_char_chunker() -> TokenChunker<CharacterTokenizer> {
        TokenChunker::new(CharacterTokenizer::new(), 12, 0)
    }

    #[test]
    fn test_chunk() {
        let chunks = twelve_char_chunker().chunk("Hello, world! This is a test.");
        assert_eq!(chunks.len(), 3);

        let first = &chunks[0];
        assert_eq!(first.text, "Hello, world");
        assert_eq!((first.start_index, first.end_index), (0, 12));
    }

    #[test]
    fn test_chunk_batch() {
        let texts: Vec<String> = ["Hello, world! This is a test.", "This is another test."]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let per_text_counts: Vec<usize> = twelve_char_chunker()
            .chunk_batch(&texts)
            .iter()
            .map(Vec::len)
            .collect();
        assert_eq!(per_text_counts, vec![3, 2]);
    }
}