use crate::types::{MemoryError, MemoryResult, MIN_CHUNK_LENGTH};
use tiktoken_rs::cl100k_base;
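/// A contiguous slice of a larger text, carrying its token count and the
/// byte offsets (`start_index..end_index`) it covers in the original string.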
#[derive(Debug, Clone)]
pub struct TextChunk {
pub content: String,
pub token_count: usize,
pub start_index: usize,
pub end_index: usize,
}
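/// Settings for the chunking functions. `chunk_size` and `chunk_overlap` are
/// measured in tokens. The `separator` field is currently not consulted by
/// either chunking function.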
#[derive(Debug, Clone)]
pub struct ChunkingConfig {
pub chunk_size: usize,
pub chunk_overlap: usize,
pub separator: Option<String>,
}
impl Default for ChunkingConfig {
fn default() -> Self {
Self {
chunk_size: 512,
chunk_overlap: 64,
separator: None,
}
}
}
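/// Thin wrapper around the `cl100k_base` byte-pair encoder from `tiktoken-rs`.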
pub struct Tokenizer {
bpe: tiktoken_rs::CoreBPE,
}
impl Tokenizer {
pub fn new() -> MemoryResult<Self> {
let bpe = cl100k_base().map_err(|e| MemoryError::Tokenization(e.to_string()))?;
Ok(Self { bpe })
}
pub fn count_tokens(&self, text: &str) -> usize {
self.bpe.encode_with_special_tokens(text).len()
}
pub fn encode(&self, text: &str) -> Vec<u32> {
self.bpe.encode_with_special_tokens(text)
}
    pub fn decode(&self, tokens: &[u32]) -> String {
        // Falls back to an empty string if the tokens do not decode to valid UTF-8.
        self.bpe.decode(tokens.to_vec()).unwrap_or_default()
    }
}
impl Default for Tokenizer {
fn default() -> Self {
Self::new().expect("Failed to initialize tokenizer")
}
}
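/// Splits `text` into chunks of at most `chunk_size` tokens using a sliding
/// window that advances by `chunk_size - chunk_overlap` tokens, so consecutive
/// chunks share roughly `chunk_overlap` tokens of context. Offsets are byte
/// offsets recovered by decoding the token prefix; they can be inexact when a
/// chunk boundary falls inside a multi-byte character.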
pub fn chunk_text(text: &str, config: &ChunkingConfig) -> MemoryResult<Vec<TextChunk>> {
if text.is_empty() {
return Ok(Vec::new());
}
if text.len() < MIN_CHUNK_LENGTH {
let tokenizer = Tokenizer::new()?;
let token_count = tokenizer.count_tokens(text);
return Ok(vec![TextChunk {
content: text.to_string(),
token_count,
start_index: 0,
end_index: text.len(),
}]);
}
let tokenizer = Tokenizer::new()?;
let tokens = tokenizer.encode(text);
if tokens.len() <= config.chunk_size {
return Ok(vec![TextChunk {
content: text.to_string(),
token_count: tokens.len(),
start_index: 0,
end_index: text.len(),
}]);
}
let mut chunks = Vec::new();
let mut start = 0;
while start < tokens.len() {
let end = (start + config.chunk_size).min(tokens.len());
let chunk_tokens = &tokens[start..end];
let chunk_text = tokenizer.decode(chunk_tokens);
let start_char = if start == 0 {
0
} else {
let prev_tokens = &tokens[..start];
tokenizer.decode(prev_tokens).len()
};
let end_char = start_char + chunk_text.len();
chunks.push(TextChunk {
content: chunk_text,
token_count: chunk_tokens.len(),
start_index: start_char,
end_index: end_char.min(text.len()),
});
        // Stop once the final token has been emitted; otherwise the overlap
        // step would append a trailing chunk that is just a suffix of the
        // previous one.
        if end == tokens.len() {
            break;
        }
        let step = config.chunk_size.saturating_sub(config.chunk_overlap);
        if step == 0 {
            // chunk_overlap >= chunk_size: advance without overlap to guarantee progress.
            start = end;
        } else {
            start += step;
        }
}
Ok(chunks)
}
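/// Splits `text` along paragraph boundaries (and sentence boundaries for
/// paragraphs that exceed `chunk_size` tokens) while keeping each chunk under
/// `chunk_size` tokens. Byte offsets are approximate because paragraph and
/// sentence separators are normalized when chunks are rebuilt.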
pub fn chunk_text_semantic(text: &str, config: &ChunkingConfig) -> MemoryResult<Vec<TextChunk>> {
if text.is_empty() {
return Ok(Vec::new());
}
if text.len() < MIN_CHUNK_LENGTH {
let tokenizer = Tokenizer::new()?;
let token_count = tokenizer.count_tokens(text);
return Ok(vec![TextChunk {
content: text.to_string(),
token_count,
start_index: 0,
end_index: text.len(),
}]);
}
let tokenizer = Tokenizer::new()?;
let tokens = tokenizer.encode(text);
if tokens.len() <= config.chunk_size {
return Ok(vec![TextChunk {
content: text.to_string(),
token_count: tokens.len(),
start_index: 0,
end_index: text.len(),
}]);
}
let paragraphs: Vec<&str> = text
.split("\n\n")
.filter(|p| !p.trim().is_empty())
.collect();
let mut chunks = Vec::new();
let mut current_chunk = String::new();
let mut current_tokens = 0;
let mut chunk_start = 0;
let mut current_pos = 0;
for paragraph in paragraphs {
let para_tokens = tokenizer.count_tokens(paragraph);
if para_tokens > config.chunk_size {
if !current_chunk.is_empty() {
chunks.push(TextChunk {
content: current_chunk.clone(),
token_count: current_tokens,
start_index: chunk_start,
end_index: current_pos,
});
current_chunk.clear();
current_tokens = 0;
chunk_start = current_pos;
}
let sentences: Vec<&str> = paragraph
.split(['.', '!', '?'])
.filter(|s| !s.trim().is_empty())
.collect();
            for sentence in sentences {
                // The split above discards the original terminator, so a '.'
                // is re-appended uniformly to every sentence.
                let sentence_with_punct = format!("{}.", sentence.trim());
                let sent_tokens = tokenizer.count_tokens(&sentence_with_punct);
if current_tokens + sent_tokens > config.chunk_size && !current_chunk.is_empty() {
chunks.push(TextChunk {
content: current_chunk.clone(),
token_count: current_tokens,
start_index: chunk_start,
end_index: current_pos,
});
                    let overlap_tokens = current_tokens
                        .saturating_sub(config.chunk_size.saturating_sub(config.chunk_overlap));
                    if overlap_tokens > 0 && overlap_tokens < current_tokens {
                        let overlap_text =
                            get_last_n_tokens(&tokenizer, &current_chunk, overlap_tokens);
                        current_chunk = overlap_text;
                        current_tokens = overlap_tokens;
                        chunk_start = current_pos.saturating_sub(current_chunk.len());
} else {
current_chunk.clear();
current_tokens = 0;
chunk_start = current_pos;
}
}
current_chunk.push_str(&sentence_with_punct);
current_chunk.push(' ');
current_tokens += sent_tokens;
current_pos += sentence_with_punct.len() + 1;
}
} else if current_tokens + para_tokens > config.chunk_size {
if !current_chunk.is_empty() {
chunks.push(TextChunk {
content: current_chunk.clone(),
token_count: current_tokens,
start_index: chunk_start,
end_index: current_pos,
});
}
current_chunk = paragraph.to_string();
current_chunk.push('\n');
current_tokens = para_tokens;
chunk_start = current_pos;
current_pos += paragraph.len() + 1;
} else {
current_chunk.push_str(paragraph);
current_chunk.push('\n');
current_tokens += para_tokens;
current_pos += paragraph.len() + 1;
}
}
if !current_chunk.is_empty() {
chunks.push(TextChunk {
content: current_chunk.trim().to_string(),
token_count: current_tokens,
start_index: chunk_start,
end_index: text.len(),
});
}
Ok(chunks)
}
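/// Returns the text corresponding to the last `n` tokens of `text`.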
fn get_last_n_tokens(tokenizer: &Tokenizer, text: &str, n: usize) -> String {
let tokens = tokenizer.encode(text);
let start = tokens.len().saturating_sub(n);
let last_tokens = &tokens[start..];
tokenizer.decode(last_tokens)
}
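/// Rough token estimate using the common ~4-characters-per-token heuristic;
/// prefer `Tokenizer::count_tokens` when an exact count matters.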
pub fn estimate_token_count(text: &str) -> usize {
text.len() / 4
}
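/// Truncates `text` so that it encodes to at most `max_tokens` tokens,
/// returning the decoded prefix.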
pub fn truncate_to_tokens(text: &str, max_tokens: usize) -> MemoryResult<String> {
let tokenizer = Tokenizer::new()?;
let tokens = tokenizer.encode(text);
if tokens.len() <= max_tokens {
Ok(text.to_string())
} else {
let truncated = &tokens[..max_tokens];
Ok(tokenizer.decode(truncated))
}
}
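/// Merges any chunk with fewer than `min_tokens` tokens into the chunks that
/// follow it, preserving order; the final chunk may still fall below the
/// threshold.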
pub fn merge_small_chunks(chunks: Vec<TextChunk>, min_tokens: usize) -> Vec<TextChunk> {
if chunks.len() < 2 {
return chunks;
}
let mut merged = Vec::new();
let mut current = chunks[0].clone();
for chunk in chunks.into_iter().skip(1) {
if current.token_count < min_tokens {
current.content.push('\n');
current.content.push_str(&chunk.content);
current.token_count += chunk.token_count;
current.end_index = chunk.end_index;
} else {
merged.push(current);
current = chunk;
}
}
merged.push(current);
merged
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chunk_text_empty() {
let config = ChunkingConfig::default();
let chunks = chunk_text("", &config).unwrap();
assert!(chunks.is_empty());
}
#[test]
fn test_chunk_text_short() {
let config = ChunkingConfig::default();
let text = "This is a short text.";
let chunks = chunk_text(text, &config).unwrap();
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].content, text);
}
#[test]
fn test_chunk_text_long() {
let config = ChunkingConfig {
chunk_size: 10,
chunk_overlap: 2,
separator: None,
};
let text = "This is a much longer text that needs to be split into multiple chunks. It contains several sentences and should be broken up appropriately.";
let chunks = chunk_text(text, &config).unwrap();
assert!(chunks.len() > 1);
for i in 1..chunks.len() {
let prev_end = chunks[i - 1].end_index;
let curr_start = chunks[i].start_index;
assert!(curr_start < prev_end, "Chunks should overlap");
}
}
#[test]
fn test_tokenizer_count() {
let tokenizer = Tokenizer::new().unwrap();
let count = tokenizer.count_tokens("Hello world");
assert!(count > 0);
}
#[test]
fn test_estimate_token_count() {
let text = "This is a test sentence with approximately twelve tokens.";
let estimated = estimate_token_count(text);
let tokenizer = Tokenizer::new().unwrap();
let actual = tokenizer.count_tokens(text);
let diff = (estimated as i64 - actual as i64).abs();
assert!(diff < 5, "Estimate should be close to actual");
}
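    // Additional illustrative tests. Exact token counts depend on the
    // cl100k_base vocabulary, so only bounds are asserted here.
    #[test]
    fn test_truncate_to_tokens() {
        let text = "One two three four five six seven eight nine ten.";
        let truncated = truncate_to_tokens(text, 3).unwrap();
        let tokenizer = Tokenizer::new().unwrap();
        assert!(!truncated.is_empty());
        assert!(tokenizer.count_tokens(&truncated) <= 3);
    }
    #[test]
    fn test_merge_small_chunks() {
        let chunks = vec![
            TextChunk {
                content: "a".to_string(),
                token_count: 1,
                start_index: 0,
                end_index: 1,
            },
            TextChunk {
                content: "b".to_string(),
                token_count: 1,
                start_index: 1,
                end_index: 2,
            },
        ];
        let merged = merge_small_chunks(chunks, 4);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].token_count, 2);
        assert_eq!(merged[0].end_index, 2);
    }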
}