
cognis_rag/splitters/token_aware.rs

//! Token-aware splitter using `cognis_core`'s pluggable [`Tokenizer`].

use std::sync::Arc;

use crate::document::Document;

// Re-export the canonical trait from cognis-core so users can write
// `cognis_rag::Tokenizer` or `cognis_core::Tokenizer` interchangeably.
pub use cognis_core::tokenizer::{CharTokenizer, FnTokenizer, Tokenizer};

use super::{child_doc, recursive::RecursiveCharSplitter, TextSplitter};

/// Splits text so each chunk's token count (per the supplied [`Tokenizer`])
/// stays at or under `max_tokens`. A recursive char splitter makes the
/// structural cuts; a token-aware re-pack pass on top enforces the budget.
pub struct TokenAwareSplitter {
    tokenizer: Arc<dyn Tokenizer>,
    max_tokens: usize,
    overlap_tokens: usize,
    inner: RecursiveCharSplitter,
}

impl TokenAwareSplitter {
    /// Build with a tokenizer + max-token cap.
    pub fn new(tokenizer: Arc<dyn Tokenizer>, max_tokens: usize) -> Self {
        Self {
            tokenizer,
            max_tokens,
            overlap_tokens: 0,
            // Approximate the char budget as 4x the token cap; the
            // recursive splitter is only a structural cutter, and the
            // second pass re-bounds by tokens anyway.
            inner: RecursiveCharSplitter::new()
                .with_chunk_size(max_tokens.saturating_mul(4).max(1)),
        }
    }

    /// Token overlap between adjacent chunks. Clamped to `max_tokens - 1`
    /// so the second-pass walker always makes progress: with a larger
    /// value, a chunk would feed itself back in whole as overlap, and the
    /// splitter would loop or emit chunks over the budget.
    pub fn with_overlap_tokens(mut self, n: usize) -> Self {
        let cap = self.max_tokens.saturating_sub(1);
        self.overlap_tokens = n.min(cap);
        self
    }
}

/// Take a suffix of `s` whose token count is exactly `n_tokens` (or
/// the largest suffix below it). Counts via `tok` so the result is
/// honest about the configured token unit.
fn token_tail(s: &str, n_tokens: usize, tok: &dyn Tokenizer) -> String {
    if n_tokens == 0 {
        return String::new();
    }
    // Walk back char-by-char, growing the tail until the tokenizer
    // reports n_tokens. Each step re-counts the candidate, so the cost is
    // O(tail_chars) tokenizer calls; fine for typical RAG overlap sizes.
    let chars: Vec<char> = s.chars().collect();
    let mut tail = String::new();
    for &c in chars.iter().rev() {
        let mut candidate = String::with_capacity(tail.len() + c.len_utf8());
        candidate.push(c);
        candidate.push_str(&tail);
        if tok.count(&candidate) > n_tokens {
            break;
        }
        tail = candidate;
    }
    tail
}

impl TextSplitter for TokenAwareSplitter {
    fn split(&self, doc: &Document) -> Vec<Document> {
        // First pass: structural cut.
        let intermediate = self.inner.split(doc);
        // Second pass: any chunk over budget gets re-packed greedily.
        let mut out: Vec<Document> = Vec::new();
        for d in intermediate {
            if self.tokenizer.count(&d.content) <= self.max_tokens {
                out.push(child_doc(doc, d.content, out.len()));
                continue;
            }
            // Greedy char-walk until we hit max_tokens.
            let mut buf = String::new();
            // True once `buf` holds anything beyond the re-seeded overlap,
            // so a leftover buffer that is pure overlap never gets emitted
            // as a duplicate trailing chunk.
            let mut has_fresh = false;
            for ch in d.content.chars() {
                buf.push(ch);
                has_fresh = true;
                if self.tokenizer.count(&buf) >= self.max_tokens {
                    out.push(child_doc(doc, std::mem::take(&mut buf), out.len()));
                    has_fresh = false;
                    if self.overlap_tokens > 0 {
                        let last = &out.last().unwrap().content;
                        buf.push_str(&token_tail(
                            last,
                            self.overlap_tokens,
                            self.tokenizer.as_ref(),
                        ));
                    }
                }
            }
            if has_fresh {
                out.push(child_doc(doc, buf, out.len()));
            }
        }
        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn char_tokenizer_caps_chunk_size() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let s = TokenAwareSplitter::new(tok, 10);
        let doc = Document::new("a".repeat(50));
        let chunks = s.split(&doc);
        assert!(chunks.iter().all(|c| c.content.chars().count() <= 10));
        assert!(!chunks.is_empty());
    }
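
    // Degenerate-budget sketch: with max_tokens = 1 the walker should
    // still make progress (one char per chunk under CharTokenizer's
    // one-token-per-char counting) rather than loop or emit oversized
    // chunks.
    #[test]
    fn max_tokens_of_one_still_progresses() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let s = TokenAwareSplitter::new(tok, 1);
        let chunks = s.split(&Document::new("abc".to_string()));
        assert!(!chunks.is_empty());
        assert!(chunks.iter().all(|c| c.content.chars().count() <= 1));
    }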

    #[test]
    fn fn_tokenizer_works() {
        // Pretend each whitespace-separated word is one token.
        let tok: Arc<dyn Tokenizer> = Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
        assert_eq!(tok.count("hello rust world"), 3);
    }
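
    // End-to-end sketch with a word tokenizer: the token budget, not the
    // 4x char heuristic, should bound each chunk. Assumes the word count
    // is monotone as chars are appended, which holds for split_whitespace.
    #[test]
    fn fn_tokenizer_bounds_words_per_chunk() {
        let tok: Arc<dyn Tokenizer> =
            Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
        let sp = TokenAwareSplitter::new(tok.clone(), 3);
        let doc = Document::new("one two three four five six seven eight".to_string());
        assert!(sp.split(&doc).iter().all(|c| tok.count(&c.content) <= 3));
    }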

    #[test]
    fn overlap_clamps_below_max_tokens() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let s = TokenAwareSplitter::new(tok, 5).with_overlap_tokens(20);
        // Clamped to max_tokens - 1 = 4.
        assert_eq!(s.overlap_tokens, 4);
    }
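
    // Overlap/budget interaction sketch: even with overlap enabled, the
    // re-seeded walker must never emit a chunk over the token cap.
    #[test]
    fn overlap_split_still_respects_token_cap() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let s = TokenAwareSplitter::new(tok, 10).with_overlap_tokens(3);
        let doc = Document::new("abcdefghij".repeat(8));
        let chunks = s.split(&doc);
        assert!(!chunks.is_empty());
        assert!(chunks.iter().all(|c| c.content.chars().count() <= 10));
    }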

    #[test]
    fn overlap_uses_token_count_not_char_count() {
        // Word tokenizer: tokens = whitespace-separated word count.
        let tok: Arc<dyn Tokenizer> = Arc::new(FnTokenizer(|s: &str| s.split_whitespace().count()));
        let tail = token_tail("alpha beta gamma delta", 2, tok.as_ref());
        // Last 2 word-tokens: "gamma delta" (11 chars), not the last 2
        // *characters* ("ta"). The tail-walker grows char-by-char, so
        // leading whitespace from word boundaries can be included.
        assert_eq!(tok.count(&tail), 2);
        assert!(tail.ends_with("gamma delta"), "tail = {tail:?}");
    }
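
    // Boundary sketch: asking for more tail tokens than the whole string
    // holds should return the string itself, not panic or truncate.
    #[test]
    fn token_tail_caps_at_whole_string() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        assert_eq!(token_tail("abc", 10, tok.as_ref()), "abc");
    }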

    #[test]
    fn overlap_zero_tokens_returns_empty_tail() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        assert_eq!(token_tail("anything", 0, tok.as_ref()), "");
    }
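
    // Trailing-buffer sketch (assumes the inner recursive splitter passes
    // this short input through as a single piece): when the text ends
    // exactly on a chunk boundary, the leftover overlap seed ("gh") must
    // not be emitted as a duplicate final chunk.
    #[test]
    fn trailing_pure_overlap_is_not_emitted() {
        let tok: Arc<dyn Tokenizer> = Arc::new(CharTokenizer);
        let s = TokenAwareSplitter::new(tok, 5).with_overlap_tokens(2);
        let chunks = s.split(&Document::new("abcdefgh".to_string()));
        let contents: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
        assert_eq!(contents, vec!["abcde", "defgh"]);
    }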
}