entelix_rag/splitter/token_count.rs

//! `TokenCountSplitter` — token-budget splitter built on top of an
//! operator-supplied [`TokenCounter`].
//!
//! Shares the recursive-merge core with
//! [`RecursiveCharacterSplitter`](super::RecursiveCharacterSplitter)
//! via [`splitter::common`](super::common); the two differ only in
//! the size metric. Operators reach for this splitter when chunks
//! must fit a vendor token-count budget — typical embedding models
//! cap context at 512 / 8192 tokens, and char-count approximation
//! mis-estimates by 20-50% on multilingual or code-heavy corpora.
//!
//! ## Algorithm
//!
//! Same separator-priority recursion as the char splitter — the
//! only differences are:
//!
//! 1. **Size metric** — [`TokenCounter::count`] replaces character
//!    counting. Multilingual and CJK corpora get vendor-accurate
//!    chunk sizing instead of a heuristic that under-counts CJK
//!    tokens (one Korean syllable is a single char yet typically
//!    tokenises to 2-3 BPE tokens, so a chars-per-token heuristic
//!    under-budgets it severalfold).
//! 2. **Tail extraction** — overlap seeding bisects on suffix length
//!    until the largest suffix fitting `chunk_overlap` tokens is
//!    found. Bisection cost is `O(log N)` `count` calls per chunk
//!    seal, amortised across the chunk's tokens.
//! 3. **Soft-cap discipline** — `chunk_size` is honoured strictly
//!    when `chunk_overlap = 0`. With overlap engaged, the splitter
//!    follows the same soft-cap contract as
//!    [`RecursiveCharacterSplitter`](super::RecursiveCharacterSplitter):
//!    the overlap-seeded prefix plus the next segment can briefly
//!    land above `chunk_size` before the next split point, and BPE
//!    seam effects can additionally shift the concatenated count by
//!    a token or two at chunk boundaries (see the worked example
//!    after this list). Operators wanting a strict cap configure
//!    `chunk_overlap = 0`.
//!
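//! A worked example (numbers illustrative, not normative): with
//! `chunk_size = 100` and `chunk_overlap = 10`, sealing a chunk seeds
//! a ~10-token tail into the next buffer, so a 95-token next segment
//! leaves that buffer at ~105 tokens until the following split point.
//! That is the brief soft-cap excursion item 3 describes.
//!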
//! ## Pairing with a `TokenCounter`
//!
//! Wire any [`TokenCounter`] impl — the
//! [`entelix_core::ByteCountTokenCounter`] zero-dep default, the
//! [`TiktokenCounter`](https://docs.rs/entelix-tokenizer-tiktoken)
//! companion for OpenAI BPE accuracy, or any other vendor- or
//! locale-specific counter implementing the same trait.
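//!
//! ## Example
//!
//! A minimal sketch of the common wiring (fenced `ignore`: the
//! `document` value stands in for whatever the pipeline loaded, and
//! `TextSplitter` must be in scope for `split`):
//!
//! ```ignore
//! use std::sync::Arc;
//! use entelix_core::ByteCountTokenCounter;
//!
//! let splitter = TokenCountSplitter::new(Arc::new(ByteCountTokenCounter::new()))
//!     .with_chunk_size(256)
//!     .with_chunk_overlap(32);
//! let chunks = splitter.split(&document);
//! // Every chunk carries lineage naming this splitter.
//! assert!(chunks.iter().all(|c| c.lineage.as_ref().unwrap().splitter == "token-count"));
//! ```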

use std::sync::Arc;

use entelix_core::TokenCounter;

use crate::document::{Document, Lineage};
use crate::splitter::TextSplitter;
use crate::splitter::common::{merge_with_overlap_metric, recurse_with_metric};
use crate::splitter::recursive::DEFAULT_RECURSIVE_SEPARATORS;

/// Default chunk size in tokens. `512` fits well inside typical
/// embedding context windows (`text-embedding-3-small` and `-large`
/// both cap at 8191 tokens), leaving headroom for query +
/// instruction tokens at retrieval time.
pub const DEFAULT_CHUNK_SIZE_TOKENS: usize = 512;

/// Default overlap between consecutive chunks in tokens. ~12.5% of
/// [`DEFAULT_CHUNK_SIZE_TOKENS`] preserves enough trailing context
/// for retrieval grounding without bloating the index.
pub const DEFAULT_CHUNK_OVERLAP_TOKENS: usize = 64;

/// Stable identifier surfaced on every produced chunk's
/// [`Lineage::splitter`](crate::Lineage::splitter) field.
const SPLITTER_NAME: &str = "token-count";

/// Recursive token-budget splitter.
///
/// Construct via [`Self::new`] with any `Arc<C>` where
/// `C: TokenCounter + ?Sized + 'static`. Both concrete
/// (`Arc<TiktokenCounter>`) and type-erased
/// (`Arc<dyn TokenCounter>`) inputs are accepted — the splitter
/// monomorphises per concrete counter for inlined hot-path
/// dispatch, or falls through to dyn dispatch when the operator
/// passes an erased Arc. Chain [`Self::with_chunk_size`] /
/// [`Self::with_chunk_overlap`] / [`Self::with_separators`] for
/// tuning. Cloning is cheap — the counter sits behind an [`Arc`]
/// and the separator list is held by `Arc<[String]>`.
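///
/// A minimal sketch of both dispatch modes (`ignore` fence; only
/// items this module already references are assumed):
///
/// ```ignore
/// use std::sync::Arc;
/// use entelix_core::{ByteCountTokenCounter, TokenCounter};
///
/// // Monomorphised: C = ByteCountTokenCounter, static dispatch.
/// let concrete = TokenCountSplitter::new(Arc::new(ByteCountTokenCounter::new()));
///
/// // Type-erased: C = dyn TokenCounter (the default type parameter).
/// let erased: TokenCountSplitter =
///     TokenCountSplitter::new(Arc::new(ByteCountTokenCounter::new()) as Arc<dyn TokenCounter>);
/// ```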
#[derive(Clone)]
pub struct TokenCountSplitter<C: TokenCounter + ?Sized + 'static = dyn TokenCounter> {
    counter: Arc<C>,
    chunk_size: usize,
    chunk_overlap: usize,
    separators: Arc<[String]>,
}

impl<C: TokenCounter + ?Sized + 'static> std::fmt::Debug for TokenCountSplitter<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenCountSplitter")
            .field("counter", &self.counter.encoding_name())
            .field("chunk_size", &self.chunk_size)
            .field("chunk_overlap", &self.chunk_overlap)
            .field("separators", &self.separators)
            .finish()
    }
}

impl<C: TokenCounter + ?Sized + 'static> TokenCountSplitter<C> {
    /// Build with the supplied [`TokenCounter`] and the default
    /// 512-token / 64-token shape.
    #[must_use]
    pub fn new(counter: Arc<C>) -> Self {
        Self {
            counter,
            chunk_size: DEFAULT_CHUNK_SIZE_TOKENS,
            chunk_overlap: DEFAULT_CHUNK_OVERLAP_TOKENS,
            separators: DEFAULT_RECURSIVE_SEPARATORS
                .iter()
                .map(|s| (*s).to_owned())
                .collect(),
        }
    }

    /// Override the target chunk size in tokens.
    #[must_use]
    pub const fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Override the overlap (in tokens) between consecutive chunks.
    /// Values at or above the chunk size silently clamp to
    /// `chunk_size - 1` at split time so the recursion terminates.
    #[must_use]
    pub const fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.chunk_overlap = chunk_overlap;
        self
    }

    /// Override the separator priority list. Defaults to
    /// `["\n\n", "\n", " ", ""]` — paragraph → line → word → unit
    /// fallback. Pipelines splitting source code or LaTeX ship
    /// alternative priorities.
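    ///
    /// A sketch with a hypothetical Markdown-leaning priority list:
    ///
    /// ```ignore
    /// let splitter = splitter.with_separators(["\n## ", "\n\n", "\n", " ", ""]);
    /// ```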
    #[must_use]
    pub fn with_separators<I, S>(mut self, separators: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.separators = separators.into_iter().map(Into::into).collect();
        self
    }

    /// Effective chunk size in tokens.
    #[must_use]
    pub const fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Effective chunk overlap in tokens.
    #[must_use]
    pub const fn chunk_overlap(&self) -> usize {
        self.chunk_overlap
    }

    /// Borrow the wired token counter — surfaces
    /// [`TokenCounter::encoding_name`] for OTel attribute emission
    /// and operator diagnostics.
    #[must_use]
    pub const fn counter(&self) -> &Arc<C> {
        &self.counter
    }
}

impl<C: TokenCounter + ?Sized + 'static> TextSplitter for TokenCountSplitter<C> {
    fn name(&self) -> &'static str {
        SPLITTER_NAME
    }

    fn split(&self, document: &Document) -> Vec<Document> {
        let chunk_size = self.chunk_size.max(1);
        let chunk_overlap = self.chunk_overlap.min(chunk_size.saturating_sub(1));

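        // Three separate Arc handles: each `move` closure owns its
        // own counter reference, so all three callbacks can be
        // handed to the common core independently.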
        let counter = Arc::clone(&self.counter);
        let measure = move |text: &str| count_tokens(&*counter, text);
        let counter_for_tail = Arc::clone(&self.counter);
        let take_tail = move |text: &str, n: usize| take_tail_tokens(&*counter_for_tail, text, n);
        let counter_for_fallback = Arc::clone(&self.counter);
        let fallback = move |text: &str, n: usize| token_chunks(&*counter_for_fallback, text, n);

        let segments = recurse_with_metric(
            &document.content,
            &self.separators,
            chunk_size,
            &measure,
            &fallback,
        );
        let texts =
            merge_with_overlap_metric(segments, chunk_size, chunk_overlap, &measure, &take_tail);

        let total = texts.len();
        if total == 0 {
            return Vec::new();
        }
        #[allow(clippy::cast_possible_truncation)]
        let total_u32 = total.min(u32::MAX as usize) as u32;
        texts
            .into_iter()
            .enumerate()
            .map(|(idx, text)| {
                #[allow(clippy::cast_possible_truncation)]
                let idx_u32 = idx.min(u32::MAX as usize) as u32;
                let lineage =
                    Lineage::from_split(document.id.clone(), idx_u32, total_u32, SPLITTER_NAME);
                document.child(text, lineage)
            })
            .collect()
    }
}

fn count_tokens<C: TokenCounter + ?Sized>(counter: &C, text: &str) -> usize {
    usize::try_from(counter.count(text)).unwrap_or(usize::MAX)
}

/// Token-aware tail extraction: bisect on suffix char-length to find
/// the largest suffix whose token count is `<= target`. Cost is
/// `O(log L)` token counts where `L` is the input char-length.
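/// Assumes the counter is monotone in suffix length (a longer suffix
/// never counts fewer tokens): exact for byte counting, and close
/// enough for BPE, whose boundary wobble the module-level soft-cap
/// contract already absorbs.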
fn take_tail_tokens<C: TokenCounter + ?Sized>(counter: &C, text: &str, target: usize) -> String {
    if text.is_empty() || target == 0 {
        return String::new();
    }
    let total = count_tokens(counter, text);
    if target >= total {
        return text.to_owned();
    }
    let chars: Vec<char> = text.chars().collect();
    let total_chars = chars.len();
    let mut lo: usize = 0;
    let mut hi: usize = total_chars;
    while lo < hi {
        let mid = lo + (hi - lo).div_ceil(2);
        let suffix_start = total_chars.saturating_sub(mid);
        let suffix: String = chars.iter().skip(suffix_start).collect();
        if count_tokens(counter, &suffix) <= target {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    let suffix_start = total_chars.saturating_sub(lo);
    chars.iter().skip(suffix_start).collect()
}

/// Always-fits token-budget fallback. Walks chars greedily,
/// flushing a chunk every time the next char would push the chunk's
/// token count over `chunk_size`. Cost is `O(C)` token counts where
/// `C` is the input char-length — acceptable because this is the
/// terminator for the recursion and runs only on segments that
/// every separator already failed to split below the cap.
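/// (Each `count` call re-tokenises the accumulated chunk, so for
/// counters that scan their input, wall-clock cost is quadratic in
/// the chunk's char-length; tolerable only on these separator-free
/// remnants.)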
fn token_chunks<C: TokenCounter + ?Sized>(
    counter: &C,
    text: &str,
    chunk_size: usize,
) -> Vec<String> {
    if chunk_size == 0 || text.is_empty() {
        return Vec::new();
    }
    let mut out = Vec::new();
    let mut current = String::new();
    for ch in text.chars() {
        current.push(ch);
        if count_tokens(counter, &current) > chunk_size {
            // Roll back the last char into the next chunk.
            current.pop();
            if !current.is_empty() {
                out.push(std::mem::take(&mut current));
            }
            current.push(ch);
        }
    }
    if !current.is_empty() {
        out.push(current);
    }
    out
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::document::Source;
    use entelix_core::ByteCountTokenCounter;
    use entelix_memory::Namespace;

    fn ns() -> Namespace {
        Namespace::new(entelix_core::TenantId::new("acme"))
    }

    fn doc(content: &str) -> Document {
        Document::root("doc", content, Source::now("test://", "test"), ns())
    }

    fn byte_counter() -> Arc<dyn TokenCounter> {
        Arc::new(ByteCountTokenCounter::new())
    }

    #[test]
    fn empty_input_produces_no_chunks() {
        let chunks = TokenCountSplitter::new(byte_counter()).split(&doc(""));
        assert!(chunks.is_empty());
    }

    #[test]
    fn small_input_produces_single_chunk_with_lineage() {
        let chunks = TokenCountSplitter::new(byte_counter()).split(&doc("short"));
        assert_eq!(chunks.len(), 1);
        let lineage = chunks[0].lineage.as_ref().unwrap();
        assert_eq!(lineage.chunk_index, 0);
        assert_eq!(lineage.total_chunks, 1);
        assert_eq!(lineage.splitter, "token-count");
        assert_eq!(lineage.parent_id.as_str(), "doc");
    }

    #[test]
    fn paragraph_split_prefers_double_newline_boundary() {
        // ByteCountTokenCounter counts bytes.div_ceil(4). Each
        // paragraph is 14-15 bytes, i.e. 4 tokens; chunk_size=5 fits
        // one paragraph but not two.
        let text = "alpha paragraph\n\nbeta paragraph\n\ngamma paragraph";
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(5)
            .with_chunk_overlap(0);
        let chunks = splitter.split(&doc(text));
        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].content.contains("alpha"));
        assert!(chunks[1].content.contains("beta"));
        assert!(chunks[2].content.contains("gamma"));
    }

    #[test]
    fn cap_enforced_on_every_chunk() {
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(8)
            .with_chunk_overlap(0);
        let text = "alpha bravo charlie delta echo foxtrot golf hotel india juliet kilo lima mike november";
        let chunks = splitter.split(&doc(text));
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            let count = byte_counter().count(&chunk.content);
            assert!(
                count <= 8,
                "chunk over cap: {} tokens, content={:?}",
                count,
                chunk.content
            );
        }
    }

    #[test]
    fn overlap_seeds_tail_into_next_chunk() {
        let text = "0123456789 abcdefghij KLMNOPQRST uvwxyz0123";
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(5)
            .with_chunk_overlap(1);
        let chunks = splitter.split(&doc(text));
        assert!(chunks.len() >= 2);
        for window in chunks.windows(2) {
            let tail = take_tail_tokens(&*byte_counter(), &window[0].content, 1);
            // An empty tail (chunk shorter than one token) is the
            // trivial case; only non-empty tails carry a semantic
            // claim.
            if !tail.is_empty() {
                assert!(
                    window[1].content.starts_with(&tail),
                    "next chunk must begin with previous tail: tail={tail:?}, next={:?}",
                    window[1].content
                );
            }
        }
    }

    #[test]
    fn unicode_input_split_preserves_grapheme_boundary() {
        // Korean text: the byte counter hits the cap fast (each
        // syllable is 3 UTF-8 bytes, so chunk_size=2 tokens holds at
        // most two syllables). The fallback walks `chars()`, so a
        // chunk can only break on char boundaries, and precomposed
        // Hangul syllables are single chars. The round-trip
        // assertion proves no bytes were dropped or duplicated.
        let text = "안녕하세요반갑습니다오늘은좋은날이에요";
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(2)
            .with_chunk_overlap(0)
            .with_separators(["", ""]);
        let chunks = splitter.split(&doc(text));
        for chunk in &chunks {
            let chars: String = chunk.content.chars().collect();
            assert_eq!(
                chars, chunk.content,
                "chunk must be valid UTF-8 with no mid-grapheme cut"
            );
        }
        let joined: String = chunks.iter().map(|c| c.content.as_str()).collect();
        assert_eq!(joined, text, "round-trip must reproduce input");
    }

    #[test]
    fn child_id_carries_chunk_index_suffix() {
        let chunks = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(2)
            .with_chunk_overlap(0)
            .split(&doc("alpha beta gamma delta"));
        for (idx, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.id.as_str(), format!("doc:{idx}"));
        }
    }

    #[test]
    fn lineage_total_chunks_matches_emitted_count() {
        let text = "para one.\n\npara two.\n\npara three.";
        let chunks = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(4)
            .with_chunk_overlap(0)
            .split(&doc(text));
        let total = chunks.len();
        for (idx, chunk) in chunks.iter().enumerate() {
            let lineage = chunk.lineage.as_ref().unwrap();
            #[allow(clippy::cast_possible_truncation)]
            let idx_u32 = idx as u32;
            #[allow(clippy::cast_possible_truncation)]
            let total_u32 = total as u32;
            assert_eq!(lineage.chunk_index, idx_u32);
            assert_eq!(lineage.total_chunks, total_u32);
        }
    }

    #[test]
    fn overlap_clamped_below_chunk_size_terminates() {
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(3)
            .with_chunk_overlap(100);
        let chunks = splitter.split(&doc("0123456789 abcdefghij KLMNOP uvwxyz"));
        assert!(
            !chunks.is_empty() && chunks.len() < 1000,
            "split terminated with bounded chunk count, got {}",
            chunks.len()
        );
    }

    #[test]
    fn counter_accessor_exposes_encoding_name() {
        let splitter = TokenCountSplitter::new(byte_counter());
        assert_eq!(splitter.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn debug_lists_encoding_not_arc_pointer() {
        let splitter = TokenCountSplitter::new(byte_counter());
        let debug = format!("{splitter:?}");
        assert!(debug.contains("byte-count-naive"));
        assert!(debug.contains("chunk_size"));
    }

    #[test]
    fn take_tail_tokens_handles_empty_and_oversize_target() {
        let counter = byte_counter();
        assert_eq!(take_tail_tokens(&*counter, "", 5), "");
        assert_eq!(take_tail_tokens(&*counter, "abc", 0), "");
        assert_eq!(take_tail_tokens(&*counter, "abc", 1000), "abc");
    }

    #[test]
    fn take_tail_tokens_returns_largest_fitting_suffix() {
        let counter = byte_counter();
        // ByteCountTokenCounter: 4 bytes per token (rounds up).
        // "abcdefgh" = 8 bytes = 2 tokens. Asking for a 1-token tail
        // should return the trailing 4-byte slice.
        let tail = take_tail_tokens(&*counter, "abcdefgh", 1);
        assert_eq!(counter.count(&tail), 1);
        assert!("abcdefgh".ends_with(&tail));
    }
}