use std::sync::Arc;
use crate::document::Document;
pub use cognis_core::tokenizer::{CharTokenizer, FnTokenizer, Tokenizer};
use super::{child_doc, recursive::RecursiveCharSplitter, TextSplitter};
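
/// Splits a [`Document`] into chunks of at most `max_tokens` tokens, as
/// measured by a pluggable [`Tokenizer`], optionally repeating a short token
/// tail across consecutive chunks for context overlap.
///
/// Internally a [`RecursiveCharSplitter`] does a coarse first pass; any
/// chunk still over the token budget is re-split against the tokenizer.
///
/// # Example
///
/// Illustrative sketch (marked `ignore` because the exact import paths
/// depend on the crate layout):
///
/// ```ignore
/// let splitter = TokenAwareSplitter::new(Arc::new(CharTokenizer), 256)
///     .with_overlap_tokens(32);
/// let chunks = splitter.split(&doc);
/// ```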
pub struct TokenAwareSplitter {
tokenizer: Arc<dyn Tokenizer>,
max_tokens: usize,
overlap_tokens: usize,
inner: RecursiveCharSplitter,
}
impl TokenAwareSplitter {
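    /// Creates a splitter that caps each chunk at `max_tokens` tokens, with
    /// no overlap between chunks.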
pub fn new(tokenizer: Arc<dyn Tokenizer>, max_tokens: usize) -> Self {
Self {
tokenizer,
max_tokens,
overlap_tokens: 0,
            // Pre-split at roughly four characters per token, a rough
            // heuristic for subword tokenizers; `.max(1)` guards against a
            // zero chunk size when `max_tokens` is 0.
            inner: RecursiveCharSplitter::new()
                .with_chunk_size(max_tokens.saturating_mul(4).max(1)),
}
}
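
    /// Sets how many tokens from the tail of each emitted chunk are repeated
    /// at the start of the next one. The value is clamped to `max_tokens - 1`
    /// so a chunk can never consist solely of overlap.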
pub fn with_overlap_tokens(mut self, n: usize) -> Self {
let cap = self.max_tokens.saturating_sub(1);
self.overlap_tokens = n.min(cap);
self
}
}
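
/// Returns the longest suffix of `s` that still fits in `n_tokens` tokens
/// according to `tok`; used to seed the overlap region of the next chunk.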
fn token_tail(s: &str, n_tokens: usize, tok: &dyn Tokenizer) -> String {
    if n_tokens == 0 {
        return String::new();
    }
    let mut tail = String::new();
    // Grow the suffix one char at a time from the end of `s`; `Chars` is a
    // double-ended iterator, so no intermediate Vec<char> is needed. Each
    // step re-tokenizes the candidate, trading O(len) tokenizer calls for
    // correctness with any tokenizer.
    for c in s.chars().rev() {
        let mut candidate = String::with_capacity(tail.len() + c.len_utf8());
        candidate.push(c);
        candidate.push_str(&tail);
        // Stop just before the suffix would exceed the token budget.
        if tok.count(&candidate) > n_tokens {
            break;
        }
        tail = candidate;
    }
    tail
}
impl TextSplitter for TokenAwareSplitter {
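    // Two-pass strategy: the inner recursive character splitter yields
    // coarse chunks cheaply; any chunk that still exceeds the token budget
    // is re-split character by character against the tokenizer.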
fn split(&self, doc: &Document) -> Vec<Document> {
let intermediate = self.inner.split(doc);
let mut out: Vec<Document> = Vec::new();
for d in intermediate {
if self.tokenizer.count(&d.content) <= self.max_tokens {
out.push(child_doc(doc, d.content, out.len()));
continue;
}
            let mut buf = String::new();
            // Tracks whether `buf` holds anything beyond seeded overlap, so
            // the final flush never emits a chunk that merely duplicates the
            // tail of the previous one.
            let mut has_new = false;
            for ch in d.content.chars() {
                buf.push(ch);
                has_new = true;
                // Re-counting on every char costs O(len) tokenizer calls per
                // chunk; simple, and acceptable at typical chunk sizes.
                if self.tokenizer.count(&buf) >= self.max_tokens {
                    out.push(child_doc(doc, std::mem::take(&mut buf), out.len()));
                    has_new = false;
                    if self.overlap_tokens > 0 {
                        // Seed the next chunk with the token-bounded tail of
                        // the chunk just emitted.
                        let last = &out.last().unwrap().content;
                        buf.push_str(&token_tail(
                            last,
                            self.overlap_tokens,
                            self.tokenizer.as_ref(),
                        ));
                    }
                }
            }
            if has_new {
                out.push(child_doc(doc, buf, out.len()));
            }
}
out
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn char_tokenizer_caps_chunk_size() {
let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
let s = TokenAwareSplitter::new(tok, 10);
let doc = Document::new("a".repeat(50));
let chunks = s.split(&doc);
assert!(chunks.iter().all(|c| c.content.chars().count() <= 10));
assert!(!chunks.is_empty());
}
#[test]
fn fn_tokenizer_works() {
let tok: Arc<dyn Tokenizer> = Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
assert_eq!(tok.count("hello rust world"), 3);
}
#[test]
fn overlap_clamps_below_max_tokens() {
let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
let s = TokenAwareSplitter::new(tok, 5).with_overlap_tokens(20);
assert_eq!(s.overlap_tokens, 4);
}
#[test]
fn overlap_uses_token_count_not_char_count() {
let tok: Arc<dyn Tokenizer> = Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
let tail = token_tail("alpha beta gamma delta", 2, tok.as_ref());
assert_eq!(tok.count(&tail), 2);
assert!(tail.ends_with("gamma delta"), "tail = {tail:?}");
}
#[test]
fn overlap_zero_tokens_returns_empty_tail() {
let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
assert_eq!(token_tail("anything", 0, tok.as_ref()), "");
}
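
    #[test]
    fn overlap_seeds_tail_without_trailing_duplicate() {
        // End-to-end sketch of the overlap path. Assumes `CharTokenizer`
        // counts one token per character and that the inner recursive
        // splitter passes an 8-char document through whole (its character
        // budget here is 4 * 4 = 16).
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let s = TokenAwareSplitter::new(tok, 4).with_overlap_tokens(2);
        let chunks = s.split(&Document::new("abcdefgh".to_string()));
        // Each chunk after the first begins with the 2-token tail of its
        // predecessor, and no pure-overlap chunk trails at the end.
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].content, "abcd");
        assert_eq!(chunks[1].content, "cdef");
        assert_eq!(chunks[2].content, "efgh");
    }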
}