// cognis_rag/splitters/token_aware.rs
use std::sync::Arc;
4
5use crate::document::Document;
6
7pub use cognis_core::tokenizer::{CharTokenizer, FnTokenizer, Tokenizer};
10
11use super::{child_doc, recursive::RecursiveCharSplitter, TextSplitter};
12
/// Splits documents into chunks bounded in *tokens* (as measured by the
/// supplied [`Tokenizer`]) rather than in characters.
///
/// A character-based [`RecursiveCharSplitter`] does a first coarse pass;
/// any intermediate chunk still over the token budget is then re-split
/// character by character (see the [`TextSplitter`] impl below).
pub struct TokenAwareSplitter {
    /// Tokenizer used to measure chunk sizes.
    tokenizer: Arc<dyn Tokenizer>,
    /// Hard upper bound on tokens per emitted chunk.
    max_tokens: usize,
    /// Tokens carried over from the end of one chunk into the start of
    /// the next; clamped to `max_tokens - 1` by `with_overlap_tokens`.
    overlap_tokens: usize,
    /// Character-level pre-splitter, configured in `new` with a chunk
    /// size of `max_tokens * 4` characters (a rough character budget —
    /// presumably ~4 chars/token; confirm against the tokenizers used).
    inner: RecursiveCharSplitter,
}
22
23impl TokenAwareSplitter {
24 pub fn new(tokenizer: Arc<dyn Tokenizer>, max_tokens: usize) -> Self {
26 Self {
27 tokenizer,
28 max_tokens,
29 overlap_tokens: 0,
30 inner: RecursiveCharSplitter::new()
33 .with_chunk_size(max_tokens.saturating_mul(4).max(1)),
34 }
35 }
36
37 pub fn with_overlap_tokens(mut self, n: usize) -> Self {
42 let cap = self.max_tokens.saturating_sub(1);
43 self.overlap_tokens = n.min(cap);
44 self
45 }
46}
47
48fn token_tail(s: &str, n_tokens: usize, tok: &dyn Tokenizer) -> String {
52 if n_tokens == 0 {
53 return String::new();
54 }
55 let chars: Vec<char> = s.chars().collect();
59 let mut tail = String::new();
60 for &c in chars.iter().rev() {
61 let mut candidate = String::with_capacity(tail.len() + c.len_utf8());
62 candidate.push(c);
63 candidate.push_str(&tail);
64 if tok.count(&candidate) > n_tokens {
65 break;
66 }
67 tail = candidate;
68 }
69 tail
70}
71
72impl TextSplitter for TokenAwareSplitter {
73 fn split(&self, doc: &Document) -> Vec<Document> {
74 let intermediate = self.inner.split(doc);
76 let mut out: Vec<Document> = Vec::new();
78 for d in intermediate {
79 if self.tokenizer.count(&d.content) <= self.max_tokens {
80 out.push(child_doc(doc, d.content, out.len()));
81 continue;
82 }
83 let mut buf = String::new();
85 for ch in d.content.chars() {
86 buf.push(ch);
87 if self.tokenizer.count(&buf) >= self.max_tokens {
88 out.push(child_doc(doc, std::mem::take(&mut buf), out.len()));
89 if self.overlap_tokens > 0 {
90 let last = &out.last().unwrap().content;
91 buf.push_str(&token_tail(
92 last,
93 self.overlap_tokens,
94 self.tokenizer.as_ref(),
95 ));
96 }
97 }
98 }
99 if !buf.is_empty() {
100 out.push(child_doc(doc, buf, out.len()));
101 }
102 }
103 out
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    // With a per-character tokenizer, no emitted chunk may exceed the
    // configured token budget.
    #[test]
    fn char_tokenizer_caps_chunk_size() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let splitter = TokenAwareSplitter::new(tokenizer, 10);
        let input = Document::new("a".repeat(50));
        let pieces = splitter.split(&input);
        assert!(!pieces.is_empty());
        for piece in &pieces {
            assert!(piece.content.chars().count() <= 10);
        }
    }

    // A closure-backed tokenizer counts whatever the closure counts.
    #[test]
    fn fn_tokenizer_works() {
        let whitespace: Arc<dyn Tokenizer> =
            Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
        assert_eq!(whitespace.count("hello rust world"), 3);
    }

    // Requested overlap is clamped to max_tokens - 1.
    #[test]
    fn overlap_clamps_below_max_tokens() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let splitter = TokenAwareSplitter::new(tokenizer, 5).with_overlap_tokens(20);
        assert_eq!(splitter.overlap_tokens, 4);
    }

    // The tail is measured in tokens, not characters.
    #[test]
    fn overlap_uses_token_count_not_char_count() {
        let whitespace: Arc<dyn Tokenizer> =
            Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
        let tail = token_tail("alpha beta gamma delta", 2, whitespace.as_ref());
        assert_eq!(whitespace.count(&tail), 2);
        assert!(tail.ends_with("gamma delta"), "tail = {tail:?}");
    }

    // Zero requested overlap yields an empty tail.
    #[test]
    fn overlap_zero_tokens_returns_empty_tail() {
        let tokenizer: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        assert_eq!(token_tail("anything", 0, tokenizer.as_ref()), "");
    }
}