use crate::common::{Chunk, Chunks, EmbeddingModel, EmbeddingModelMetadata, TokenizerWrapper};
use std::num::NonZeroUsize;
use thiserror::Error;
use super::traits::Chunker;
/// Splits raw text into overlapping, token-based chunks sized to fit an
/// embedding model's context window.
pub struct TokenChunker {
    // Maximum number of tokens per chunk; non-zero by construction.
    chunk_size: NonZeroUsize,
    // Number of tokens shared between consecutive chunks. `try_new`
    // guarantees this is strictly smaller than `chunk_size`.
    chunk_overlap: usize,
    // Tokenizer taken from the embedding model's metadata.
    tokenizer: Box<dyn TokenizerWrapper>,
}
impl TokenChunker {
    /// Builds a `TokenChunker` configured for the given embedding model.
    ///
    /// The tokenizer is taken from the model's metadata, and the requested
    /// `chunk_size` / `chunk_overlap` are validated against the model's
    /// token limit.
    ///
    /// # Errors
    ///
    /// Returns [`TokenChunkingError::InvalidChunkSize`] when `chunk_size`
    /// exceeds the model's maximum token count, or
    /// [`TokenChunkingError::ChunkOverlapTooLarge`] when `chunk_overlap` is
    /// not strictly smaller than `chunk_size`.
    pub fn try_new(
        chunk_size: NonZeroUsize,
        chunk_overlap: usize,
        embedding_model: impl EmbeddingModel,
    ) -> Result<Self, TokenChunkingError> {
        let metadata: EmbeddingModelMetadata = embedding_model.metadata();
        Self::validate_arguments(chunk_size.into(), chunk_overlap, metadata.max_tokens)?;
        Ok(TokenChunker {
            chunk_size,
            chunk_overlap,
            tokenizer: metadata.tokenizer,
        })
    }

    /// Checks the chunking parameters against the model's limits.
    ///
    /// `chunk_size` may equal `max_chunk_size`, but `chunk_overlap` must be
    /// strictly smaller than `chunk_size` so that the sliding window in
    /// `generate_chunks` always advances by at least one token.
    fn validate_arguments(
        chunk_size: usize,
        chunk_overlap: usize,
        max_chunk_size: usize,
    ) -> Result<(), TokenChunkingError> {
        if chunk_size > max_chunk_size {
            // Equality is allowed, so the message says "exceed" rather than
            // the previous (incorrect) "must be smaller than".
            return Err(TokenChunkingError::InvalidChunkSize(format!(
                "Chunk size must not exceed {}",
                max_chunk_size
            )));
        }
        if chunk_overlap >= chunk_size {
            // Message now names the parameter ("chunk overlap") instead of
            // the inconsistent "window size".
            return Err(TokenChunkingError::ChunkOverlapTooLarge(
                "Chunk overlap must be smaller than chunk size".to_string(),
            ));
        }
        Ok(())
    }
}
impl Chunker for TokenChunker {
    type ErrorType = TokenChunkingError;

    /// Tokenizes `raw_text` and groups the tokens into chunks of at most
    /// `chunk_size` tokens, with consecutive chunks sharing
    /// `chunk_overlap` tokens. When the overlap is non-zero, the final
    /// window(s) may contain only the overlapping tail of the text (this
    /// behavior is relied upon by the unit tests).
    ///
    /// # Errors
    ///
    /// Returns [`TokenChunkingError::TokenizationError`] when the
    /// tokenizer cannot tokenize the input.
    fn generate_chunks(&self, raw_text: &str) -> Result<Chunks, Self::ErrorType> {
        let tokens: Vec<String> = self.tokenizer.tokenize(raw_text).ok_or_else(|| {
            TokenChunkingError::TokenizationError("Unable to tokenize text".to_string())
        })?;
        let chunk_size: usize = self.chunk_size.into();
        let mut chunks: Chunks = Chunks::new();
        let mut start = 0;
        while start < tokens.len() {
            let end = (start + chunk_size).min(tokens.len());
            // Tokens carry their own leading whitespace, so they are
            // concatenated directly (no separator); `concat()` avoids the
            // needless `to_vec()` clone of every token. `trim` drops the
            // leading space of the chunk's first token.
            chunks.push(Chunk::new(tokens[start..end].concat().trim()));
            // Cannot underflow: `try_new` guarantees overlap < chunk_size,
            // so the window always advances by at least one token.
            start += chunk_size - self.chunk_overlap;
        }
        // `chunks` is already a `Chunks`; the previous `Chunks::from(chunks)`
        // was an identity conversion.
        Ok(chunks)
    }
}
/// Errors produced while configuring or running a [`TokenChunker`].
#[derive(Error, Debug, PartialEq, Eq)]
pub enum TokenChunkingError {
    /// The requested overlap is not strictly smaller than the chunk size.
    #[error("{0}")]
    ChunkOverlapTooLarge(String),
    /// The tokenizer failed to tokenize the input text.
    #[error("{0}")]
    TokenizationError(String),
    /// The requested chunk size exceeds the embedding model's token limit.
    #[error("{0}")]
    InvalidChunkSize(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::OpenAIEmbeddingModel::TextEmbeddingAda002;

    /// Builds a chunker over the Ada-002 tokenizer, panicking on invalid
    /// arguments (fine inside tests).
    fn build_chunker(size: usize, overlap: usize) -> TokenChunker {
        TokenChunker::try_new(NonZeroUsize::new(size).unwrap(), overlap, TextEmbeddingAda002)
            .unwrap()
    }

    /// Collects chunk contents as plain strings for easy comparison.
    fn contents(chunks: Chunks) -> Vec<String> {
        chunks
            .into_iter()
            .map(|chunk| chunk.content().to_string())
            .collect()
    }

    #[test]
    fn test_generate_chunks_with_valid_input() {
        let chunker = build_chunker(2, 1);
        let produced = contents(chunker.generate_chunks("This is a test string").unwrap());
        assert_eq!(produced.len(), 5);
        assert_eq!(
            produced,
            vec!["This is", "is a", "a test", "test string", "string"]
        );
    }

    #[test]
    fn test_generate_chunks_with_empty_string() {
        let chunker = build_chunker(2, 1);
        let produced = contents(chunker.generate_chunks("").unwrap());
        assert_eq!(produced.len(), 0);
        assert_eq!(produced, Vec::<String>::new());
    }

    #[test]
    fn test_generate_chunks_with_invalid_window_size() {
        // Overlap (3) >= chunk size (2) must be rejected at construction.
        let result = TokenChunker::try_new(
            NonZeroUsize::new(2).unwrap(),
            3,
            TextEmbeddingAda002,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_generate_chunks_with_invalid_chunk_size() {
        // A chunk size (20000) beyond the model's max tokens must be rejected.
        let result = TokenChunker::try_new(
            NonZeroUsize::new(20000).unwrap(),
            3,
            TextEmbeddingAda002,
        );
        assert!(result.is_err());
    }
}