use super::TextSplitter;
use cognis_core::utils::tokens::{estimate_token_count, get_model_context_window};
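/// Splits text into chunks bounded by an estimated token count rather than a
/// character count, so chunk sizes can be aligned with a model's context window.
///
/// Separators are tried in priority order (paragraph, line, sentence, then word
/// by default), and an optional token overlap is carried between adjacent chunks.
///
/// # Example
///
/// A minimal usage sketch; exact chunk boundaries depend on the heuristic in
/// `estimate_token_count`:
///
/// ```ignore
/// // assumes `TextSplitter` is in scope for `split_text`
/// let splitter = TokenAwareTextSplitter::new()
///     .with_max_tokens(200)
///     .with_overlap_tokens(20);
/// let chunks = splitter.split_text("First paragraph.\n\nSecond paragraph.");
/// assert!(!chunks.is_empty());
/// ```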
pub struct TokenAwareTextSplitter {
pub max_tokens: usize,
pub overlap_tokens: usize,
pub model_name: Option<String>,
pub separators: Vec<String>,
}
impl Default for TokenAwareTextSplitter {
fn default() -> Self {
Self {
max_tokens: 500,
overlap_tokens: 50,
model_name: None,
separators: vec!["\n\n".into(), "\n".into(), ". ".into(), " ".into()],
}
}
}
impl TokenAwareTextSplitter {
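/// Creates a splitter with the default configuration: 500 max tokens per
/// chunk, 50 overlap tokens, and paragraph/line/sentence/word separators.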
pub fn new() -> Self {
Self::default()
}
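/// Sets the maximum number of estimated tokens per chunk.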
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
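/// Sets the number of estimated tokens carried over between adjacent chunks.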
pub fn with_overlap_tokens(mut self, overlap_tokens: usize) -> Self {
self.overlap_tokens = overlap_tokens;
self
}
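/// Records the target model name (currently informational; see `estimate_tokens`).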
pub fn with_model(mut self, model_name: impl Into<String>) -> Self {
self.model_name = Some(model_name.into());
self
}
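/// Replaces the default separators; they are tried in priority order.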
pub fn with_separators(mut self, separators: Vec<String>) -> Self {
self.separators = separators;
self
}
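/// Builds a splitter whose chunk size is derived from a model's context window
/// divided by the number of chunks expected to share one context. Unknown
/// models fall back to a 2_000-token window, and a `chunks_per_context` of
/// zero falls back to the full window.
///
/// A sketch of the arithmetic, using the `gpt-4o` window size exercised in
/// the tests below (128_000 tokens):
///
/// ```ignore
/// let splitter = TokenAwareTextSplitter::from_model_context("gpt-4o", 10);
/// assert_eq!(splitter.max_tokens, 12_800); // 128_000 / 10
/// ```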
pub fn from_model_context(model_name: &str, chunks_per_context: usize) -> Self {
let context_window = get_model_context_window(model_name).unwrap_or(2000);
let max_tokens = context_window
.checked_div(chunks_per_context)
.unwrap_or(context_window);
Self {
max_tokens,
overlap_tokens: 50,
model_name: Some(model_name.to_string()),
separators: vec!["\n\n".into(), "\n".into(), ". ".into(), " ".into()],
}
}
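/// Thin wrapper around `estimate_token_count`. The `_model` parameter is a
/// hook for model-specific tokenizers and is currently ignored.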
fn estimate_tokens(text: &str, _model: Option<&str>) -> usize {
estimate_token_count(text)
}
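/// Core splitting routine. Text that already fits within `max_tokens` is
/// returned as a single chunk; otherwise it is split on the first separator it
/// contains (falling back to whitespace) and the pieces are greedily packed
/// into chunks. Pieces that are individually over budget are broken down
/// further via `split_subsection`, and overlap is applied at the end when
/// configured.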
fn split_with_separators(&self, text: &str) -> Vec<String> {
if text.is_empty() {
return Vec::new();
}
let text_tokens = Self::estimate_tokens(text, self.model_name.as_deref());
if text_tokens <= self.max_tokens {
return vec![text.to_string()];
}
let separator = self
.separators
.iter()
.find(|sep| text.contains(sep.as_str()))
.cloned();
let pieces: Vec<&str> = match &separator {
Some(sep) => text.split(sep.as_str()).collect(),
None => text.split_whitespace().collect(),
};
let pieces: Vec<&str> = pieces.iter().copied().filter(|p| !p.is_empty()).collect();
let mut chunks: Vec<String> = Vec::new();
let mut current = String::new();
let mut current_tokens: usize = 0;
for piece in &pieces {
let piece_tokens = Self::estimate_tokens(piece, self.model_name.as_deref());
if piece_tokens > self.max_tokens {
if !current.is_empty() {
chunks.push(current.trim().to_string());
current = String::new();
}
let sub_chunks = self.split_subsection(piece);
chunks.extend(sub_chunks);
continue;
}
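// Account for the separator's own tokens when checking whether the piece
// still fits in the current chunk.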
let sep_str = separator.as_deref().unwrap_or(" ");
let would_be = if current.is_empty() {
piece_tokens
} else {
current_tokens
+ Self::estimate_tokens(sep_str, self.model_name.as_deref())
+ piece_tokens
};
if would_be > self.max_tokens && !current.is_empty() {
chunks.push(current.trim().to_string());
current = String::new();
}
if current.is_empty() {
current = piece.to_string();
current_tokens = piece_tokens;
} else {
current.push_str(sep_str);
current.push_str(piece);
current_tokens = Self::estimate_tokens(&current, self.model_name.as_deref());
}
}
if !current.is_empty() {
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
chunks.push(trimmed);
}
}
if self.overlap_tokens > 0 && chunks.len() > 1 {
chunks = self.apply_overlap(chunks);
}
chunks
}
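/// Breaks an oversized piece down using progressively finer separators. The
/// first separator that splits the piece into more than one part is used; if
/// none applies, the piece is returned whole and may still exceed `max_tokens`.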
fn split_subsection(&self, text: &str) -> Vec<String> {
for sep in &self.separators {
if text.contains(sep.as_str()) {
let pieces: Vec<&str> =
text.split(sep.as_str()).filter(|p| !p.is_empty()).collect();
if pieces.len() > 1 {
let mut sub_chunks = Vec::new();
let mut current = String::new();
let mut current_tokens: usize = 0;
for piece in &pieces {
let piece_tokens = Self::estimate_tokens(piece, self.model_name.as_deref());
let would_be = if current.is_empty() {
piece_tokens
} else {
current_tokens
+ Self::estimate_tokens(sep, self.model_name.as_deref())
+ piece_tokens
};
if would_be > self.max_tokens && !current.is_empty() {
sub_chunks.push(current.trim().to_string());
current = String::new();
}
if current.is_empty() {
current = piece.to_string();
current_tokens = piece_tokens;
} else {
current.push_str(sep);
current.push_str(piece);
current_tokens = Self::estimate_tokens(&current, self.model_name.as_deref());
}
}
if !current.is_empty() {
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
sub_chunks.push(trimmed);
}
}
return sub_chunks;
}
}
}
vec![text.to_string()]
}
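/// Prepends a token-bounded suffix of each chunk to its successor. The merged
/// chunk is kept only if it still fits within `max_tokens`; otherwise the
/// successor is left unchanged.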
fn apply_overlap(&self, chunks: Vec<String>) -> Vec<String> {
if chunks.len() <= 1 {
return chunks;
}
let mut result = Vec::with_capacity(chunks.len());
result.push(chunks[0].clone());
for i in 1..chunks.len() {
let prev = &chunks[i - 1];
let overlap_text = self.get_overlap_suffix(prev);
if overlap_text.is_empty() {
result.push(chunks[i].clone());
} else {
let merged = format!("{} {}", overlap_text.trim(), chunks[i].trim());
let merged_tokens = Self::estimate_tokens(&merged, self.model_name.as_deref());
if merged_tokens <= self.max_tokens {
result.push(merged);
} else {
result.push(chunks[i].clone());
}
}
}
result
}
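/// Collects trailing words until adding another would exceed `overlap_tokens`,
/// returning them in their original order.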
fn get_overlap_suffix(&self, text: &str) -> String {
let words: Vec<&str> = text.split_whitespace().collect();
let mut suffix_words: Vec<&str> = Vec::new();
let mut token_count = 0;
for word in words.iter().rev() {
let word_tokens = Self::estimate_tokens(word, self.model_name.as_deref());
if token_count + word_tokens > self.overlap_tokens {
break;
}
token_count += word_tokens;
suffix_words.push(word);
}
suffix_words.reverse();
suffix_words.join(" ")
}
}
impl TextSplitter for TokenAwareTextSplitter {
fn split_text(&self, text: &str) -> Vec<String> {
self.split_with_separators(text)
}
fn chunk_size(&self) -> usize {
self.max_tokens
}
fn chunk_overlap(&self) -> usize {
self.overlap_tokens
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_short_text_returns_single_chunk() {
let splitter = TokenAwareTextSplitter::new().with_max_tokens(100);
let result = splitter.split_text("Hello world.");
assert_eq!(result.len(), 1);
assert_eq!(result[0], "Hello world.");
}
#[test]
fn test_long_text_splits_into_multiple_chunks() {
let splitter = TokenAwareTextSplitter::new()
.with_max_tokens(10)
.with_overlap_tokens(0);
let text = "The quick brown fox jumps over the lazy dog. \
The quick brown fox jumps over the lazy dog. \
The quick brown fox jumps over the lazy dog.";
let chunks = splitter.split_text(text);
assert!(
chunks.len() > 1,
"Expected multiple chunks, got {}",
chunks.len()
);
for chunk in &chunks {
let tokens = estimate_token_count(chunk);
assert!(
tokens <= splitter.max_tokens + 2,
"Chunk has {} tokens, max is {}: {:?}",
tokens,
splitter.max_tokens,
chunk
);
}
}
#[test]
fn test_overlap_between_chunks() {
let splitter = TokenAwareTextSplitter::new()
.with_max_tokens(15)
.with_overlap_tokens(5);
let text = "Alpha beta gamma delta. Epsilon zeta eta theta. \
Iota kappa lambda mu. Nu xi omicron pi.";
let chunks = splitter.split_text(text);
assert!(
chunks.len() > 1,
"Expected multiple chunks for overlap test"
);
let mut found_overlap = false;
for i in 1..chunks.len() {
let prev_words: Vec<&str> = chunks[i - 1].split_whitespace().collect();
let curr_words: Vec<&str> = chunks[i].split_whitespace().collect();
for word in &prev_words {
if curr_words.contains(word) && word.len() > 3 {
found_overlap = true;
break;
}
}
if found_overlap {
break;
}
}
assert!(found_overlap, "Expected overlap between consecutive chunks");
}
#[test]
fn test_custom_separators() {
let splitter = TokenAwareTextSplitter::new()
.with_max_tokens(10)
.with_overlap_tokens(0)
.with_separators(vec!["||".into()]);
let text = "chunk one text here||chunk two text here||chunk three text here";
let chunks = splitter.split_text(text);
assert!(
chunks.len() >= 2,
"Expected at least 2 chunks with custom separator, got {}",
chunks.len()
);
}
#[test]
fn test_from_model_context_factory() {
let splitter = TokenAwareTextSplitter::from_model_context("gpt-4o", 10);
assert_eq!(splitter.max_tokens, 12_800);
assert_eq!(splitter.model_name.as_deref(), Some("gpt-4o"));
let splitter_claude = TokenAwareTextSplitter::from_model_context("claude-3-opus", 20);
assert_eq!(splitter_claude.max_tokens, 10_000);
let splitter_unknown = TokenAwareTextSplitter::from_model_context("unknown-model", 4);
assert_eq!(splitter_unknown.max_tokens, 500);
}
#[test]
fn test_empty_text_returns_empty_vec() {
let splitter = TokenAwareTextSplitter::new();
let result = splitter.split_text("");
assert!(result.is_empty());
}
#[test]
fn test_chunk_size_and_overlap_trait_methods() {
let splitter = TokenAwareTextSplitter::new()
.with_max_tokens(256)
.with_overlap_tokens(32);
assert_eq!(splitter.chunk_size(), 256);
assert_eq!(splitter.chunk_overlap(), 32);
}
}