1use std::sync::Arc;
45
46use entelix_core::TokenCounter;
47
48use crate::document::{Document, Lineage};
49use crate::splitter::TextSplitter;
50use crate::splitter::common::{merge_with_overlap_metric, recurse_with_metric};
51use crate::splitter::recursive::DEFAULT_RECURSIVE_SEPARATORS;
52
/// Default maximum number of tokens per emitted chunk.
pub const DEFAULT_CHUNK_SIZE_TOKENS: usize = 512;

/// Default number of tokens carried over from the tail of one chunk into
/// the start of the next chunk.
pub const DEFAULT_CHUNK_OVERLAP_TOKENS: usize = 64;

// Splitter identifier recorded in each chunk's `Lineage`.
const SPLITTER_NAME: &str = "token-count";
67
/// Splits a [`Document`] into chunks whose size is measured in tokens (as
/// reported by a [`TokenCounter`]) rather than in characters or bytes.
///
/// Defaults to a trait object (`dyn TokenCounter`) so the counter can be
/// chosen at runtime; a concrete `C` avoids dynamic dispatch.
#[derive(Clone)]
pub struct TokenCountSplitter<C: TokenCounter + ?Sized + 'static = dyn TokenCounter> {
    // Shared counter used for every size measurement during splitting.
    counter: Arc<C>,
    // Maximum tokens per chunk; clamped to >= 1 at split time.
    chunk_size: usize,
    // Tokens of trailing context repeated at the start of the next chunk;
    // clamped strictly below `chunk_size` at split time.
    chunk_overlap: usize,
    // Separator strings tried in order during recursive splitting.
    separators: Arc<[String]>,
}
87
88impl<C: TokenCounter + ?Sized + 'static> std::fmt::Debug for TokenCountSplitter<C> {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("TokenCountSplitter")
91 .field("counter", &self.counter.encoding_name())
92 .field("chunk_size", &self.chunk_size)
93 .field("chunk_overlap", &self.chunk_overlap)
94 .field("separators", &self.separators)
95 .finish()
96 }
97}
98
impl<C: TokenCounter + ?Sized + 'static> TokenCountSplitter<C> {
    /// Creates a splitter with the default chunk size, overlap, and the
    /// default recursive separator set.
    #[must_use]
    pub fn new(counter: Arc<C>) -> Self {
        Self {
            counter,
            chunk_size: DEFAULT_CHUNK_SIZE_TOKENS,
            chunk_overlap: DEFAULT_CHUNK_OVERLAP_TOKENS,
            separators: DEFAULT_RECURSIVE_SEPARATORS
                .iter()
                .map(|s| (*s).to_owned())
                .collect(),
        }
    }

    /// Sets the maximum tokens per chunk (builder style).
    /// Values of 0 are clamped to 1 at split time.
    #[must_use]
    pub const fn with_chunk_size(mut self, chunk_size: usize) -> Self {
        self.chunk_size = chunk_size;
        self
    }

    /// Sets the token overlap between consecutive chunks (builder style).
    /// Clamped strictly below the chunk size at split time.
    #[must_use]
    pub const fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.chunk_overlap = chunk_overlap;
        self
    }

    /// Replaces the separator list used during recursive splitting
    /// (builder style); separators are tried in iteration order.
    #[must_use]
    pub fn with_separators<I, S>(mut self, separators: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.separators = separators.into_iter().map(Into::into).collect();
        self
    }

    /// Returns the configured (unclamped) chunk size in tokens.
    #[must_use]
    pub const fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    /// Returns the configured (unclamped) chunk overlap in tokens.
    #[must_use]
    pub const fn chunk_overlap(&self) -> usize {
        self.chunk_overlap
    }

    /// Returns a reference to the shared token counter.
    #[must_use]
    pub const fn counter(&self) -> &Arc<C> {
        &self.counter
    }
}
165
impl<C: TokenCounter + ?Sized + 'static> TextSplitter for TokenCountSplitter<C> {
    /// Stable identifier recorded in each chunk's lineage.
    fn name(&self) -> &'static str {
        SPLITTER_NAME
    }

    /// Splits `document` into token-bounded chunks with optional overlap.
    ///
    /// The chunk size is clamped to at least 1 and the overlap strictly
    /// below the chunk size, so the merge step always makes progress.
    fn split(&self, document: &Document) -> Vec<Document> {
        let chunk_size = self.chunk_size.max(1);
        let chunk_overlap = self.chunk_overlap.min(chunk_size.saturating_sub(1));

        // Each closure moves its own Arc clone so all three can be
        // independent `move` closures over the same counter.
        let counter = Arc::clone(&self.counter);
        let measure = move |text: &str| count_tokens(&*counter, text);
        let counter_for_tail = Arc::clone(&self.counter);
        let take_tail = move |text: &str, n: usize| take_tail_tokens(&*counter_for_tail, text, n);
        let counter_for_fallback = Arc::clone(&self.counter);
        let fallback = move |text: &str, n: usize| token_chunks(&*counter_for_fallback, text, n);

        // First split recursively on separators (falling back to raw
        // token packing), then merge small neighbors while seeding each
        // chunk with the previous chunk's token tail.
        let segments = recurse_with_metric(
            &document.content,
            &self.separators,
            chunk_size,
            &measure,
            &fallback,
        );
        let texts =
            merge_with_overlap_metric(segments, chunk_size, chunk_overlap, &measure, &take_tail);

        let total = texts.len();
        if total == 0 {
            return Vec::new();
        }
        // Lineage indices are u32; saturate rather than truncate if a
        // document somehow yields more than u32::MAX chunks.
        #[allow(clippy::cast_possible_truncation)]
        let total_u32 = total.min(u32::MAX as usize) as u32;
        texts
            .into_iter()
            .enumerate()
            .map(|(idx, text)| {
                #[allow(clippy::cast_possible_truncation)]
                let idx_u32 = idx.min(u32::MAX as usize) as u32;
                let lineage =
                    Lineage::from_split(document.id.clone(), idx_u32, total_u32, SPLITTER_NAME);
                document.child(text, lineage)
            })
            .collect()
    }
}
211
212fn count_tokens<C: TokenCounter + ?Sized>(counter: &C, text: &str) -> usize {
213 usize::try_from(counter.count(text)).unwrap_or(usize::MAX)
214}
215
/// Returns the longest suffix of `text` (cut on `char` boundaries) whose
/// token count is at most `target`.
///
/// Binary-searches over the suffix length in chars. This assumes token
/// counts are monotone in suffix length (a longer suffix never counts
/// fewer tokens) — NOTE(review): confirm this holds for every
/// `TokenCounter` implementation; a non-monotone counter could make the
/// search return a shorter-than-optimal suffix.
fn take_tail_tokens<C: TokenCounter + ?Sized>(counter: &C, text: &str, target: usize) -> String {
    // Nothing to take from empty input or a zero token budget.
    if text.is_empty() || target == 0 {
        return String::new();
    }
    let total = count_tokens(counter, text);
    // The whole text already fits within the budget.
    if target >= total {
        return text.to_owned();
    }
    // Work in char indices so we never cut inside a UTF-8 sequence.
    let chars: Vec<char> = text.chars().collect();
    let total_chars = chars.len();
    // Invariant: a suffix of `lo` chars fits; one longer than `hi` does not.
    let mut lo: usize = 0;
    let mut hi: usize = total_chars;
    while lo < hi {
        // Bias mid upward so `lo = mid` always makes progress; since
        // lo < hi, mid >= lo + 1 and `mid - 1` cannot underflow.
        let mid = lo + (hi - lo).div_ceil(2);
        let suffix_start = total_chars.saturating_sub(mid);
        let suffix: String = chars.iter().skip(suffix_start).collect();
        if count_tokens(counter, &suffix) <= target {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    // `lo` is now the largest fitting suffix length in chars.
    let suffix_start = total_chars.saturating_sub(lo);
    chars.iter().skip(suffix_start).collect()
}
244
245fn token_chunks<C: TokenCounter + ?Sized>(
252 counter: &C,
253 text: &str,
254 chunk_size: usize,
255) -> Vec<String> {
256 if chunk_size == 0 || text.is_empty() {
257 return Vec::new();
258 }
259 let mut out = Vec::new();
260 let mut current = String::new();
261 for ch in text.chars() {
262 current.push(ch);
263 if count_tokens(counter, ¤t) > chunk_size {
264 current.pop();
266 if !current.is_empty() {
267 out.push(std::mem::take(&mut current));
268 }
269 current.push(ch);
270 }
271 }
272 if !current.is_empty() {
273 out.push(current);
274 }
275 out
276}
277
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::document::Source;
    use entelix_core::ByteCountTokenCounter;
    use entelix_memory::Namespace;

    // Fixed tenant namespace for test documents.
    fn ns() -> Namespace {
        Namespace::new(entelix_core::TenantId::new("acme"))
    }

    // Root document with id "doc" and the given content.
    fn doc(content: &str) -> Document {
        Document::root("doc", content, Source::now("test://", "test"), ns())
    }

    // Counter where 1 token == 1 byte, making expected counts easy to reason about.
    fn byte_counter() -> Arc<dyn TokenCounter> {
        Arc::new(ByteCountTokenCounter::new())
    }

    #[test]
    fn empty_input_produces_no_chunks() {
        let chunks = TokenCountSplitter::new(byte_counter()).split(&doc(""));
        assert!(chunks.is_empty());
    }

    #[test]
    fn small_input_produces_single_chunk_with_lineage() {
        let chunks = TokenCountSplitter::new(byte_counter()).split(&doc("short"));
        assert_eq!(chunks.len(), 1);
        let lineage = chunks[0].lineage.as_ref().unwrap();
        assert_eq!(lineage.chunk_index, 0);
        assert_eq!(lineage.total_chunks, 1);
        assert_eq!(lineage.splitter, "token-count");
        assert_eq!(lineage.parent_id.as_str(), "doc");
    }

    #[test]
    fn paragraph_split_prefers_double_newline_boundary() {
        // Each paragraph is ~15 bytes; with a 5-token cap the splitter must
        // break, and should do so on the "\n\n" separators.
        let text = "alpha paragraph\n\nbeta paragraph\n\ngamma paragraph";
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(5)
            .with_chunk_overlap(0);
        let chunks = splitter.split(&doc(text));
        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].content.contains("alpha"));
        assert!(chunks[1].content.contains("beta"));
        assert!(chunks[2].content.contains("gamma"));
    }

    #[test]
    fn cap_enforced_on_every_chunk() {
        // No chunk may exceed the cap, even via the merge/fallback paths.
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(8)
            .with_chunk_overlap(0);
        let text = "alpha bravo charlie delta echo foxtrot golf hotel india juliet kilo lima mike november";
        let chunks = splitter.split(&doc(text));
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            let count = byte_counter().count(&chunk.content);
            assert!(
                count <= 8,
                "chunk over cap: {} tokens, content={:?}",
                count,
                chunk.content
            );
        }
    }

    #[test]
    fn overlap_seeds_tail_into_next_chunk() {
        // With overlap=1, each chunk after the first must start with the
        // 1-token tail of its predecessor.
        let text = "0123456789 abcdefghij KLMNOPQRST uvwxyz0123";
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(5)
            .with_chunk_overlap(1);
        let chunks = splitter.split(&doc(text));
        assert!(chunks.len() >= 2);
        for window in chunks.windows(2) {
            let tail = take_tail_tokens(&byte_counter(), &window[0].content, 1);
            if !tail.is_empty() {
                assert!(
                    window[1].content.starts_with(&tail),
                    "next chunk must begin with previous tail: tail={tail:?}, next={:?}",
                    window[1].content
                );
            }
        }
    }

    #[test]
    fn unicode_input_split_preserves_grapheme_boundary() {
        // Hangul syllables are 3 bytes each; a 2-byte cap forces the
        // fallback packer, which must still cut only on char boundaries.
        let text = "안녕하세요반갑습니다오늘은좋은날이에요";
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(2)
            .with_chunk_overlap(0)
            .with_separators(["", ""]);
        let chunks = splitter.split(&doc(text));
        for chunk in &chunks {
            let chars: String = chunk.content.chars().collect();
            assert_eq!(
                chars, chunk.content,
                "chunk must be valid UTF-8 with no mid-grapheme cut"
            );
        }
        // With zero overlap, concatenating chunks must reproduce the input.
        let joined: String = chunks.iter().map(|c| c.content.as_str()).collect();
        assert_eq!(joined, text, "round-trip must reproduce input");
    }

    #[test]
    fn child_id_carries_chunk_index_suffix() {
        let chunks = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(2)
            .with_chunk_overlap(0)
            .split(&doc("alpha beta gamma delta"));
        for (idx, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.id.as_str(), format!("doc:{idx}"));
        }
    }

    #[test]
    fn lineage_total_chunks_matches_emitted_count() {
        let text = "para one.\n\npara two.\n\npara three.";
        let chunks = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(4)
            .with_chunk_overlap(0)
            .split(&doc(text));
        let total = chunks.len();
        for (idx, chunk) in chunks.iter().enumerate() {
            let lineage = chunk.lineage.as_ref().unwrap();
            #[allow(clippy::cast_possible_truncation)]
            let idx_u32 = idx as u32;
            #[allow(clippy::cast_possible_truncation)]
            let total_u32 = total as u32;
            assert_eq!(lineage.chunk_index, idx_u32);
            assert_eq!(lineage.total_chunks, total_u32);
        }
    }

    #[test]
    fn overlap_clamped_below_chunk_size_terminates() {
        // overlap >= chunk_size would loop forever without the clamp in
        // `split`; this just checks termination with a bounded chunk count.
        let splitter = TokenCountSplitter::new(byte_counter())
            .with_chunk_size(3)
            .with_chunk_overlap(100);
        let chunks = splitter.split(&doc("0123456789 abcdefghij KLMNOP uvwxyz"));
        assert!(
            !chunks.is_empty() && chunks.len() < 1000,
            "split terminated with bounded chunk count, got {}",
            chunks.len()
        );
    }

    #[test]
    fn counter_accessor_exposes_encoding_name() {
        let splitter = TokenCountSplitter::new(byte_counter());
        assert_eq!(splitter.counter().encoding_name(), "byte-count-naive");
    }

    #[test]
    fn debug_lists_encoding_not_arc_pointer() {
        let splitter = TokenCountSplitter::new(byte_counter());
        let debug = format!("{splitter:?}");
        assert!(debug.contains("byte-count-naive"));
        assert!(debug.contains("chunk_size"));
    }

    #[test]
    fn take_tail_tokens_handles_empty_and_oversize_target() {
        let counter = byte_counter();
        assert_eq!(take_tail_tokens(&counter, "", 5), "");
        assert_eq!(take_tail_tokens(&counter, "abc", 0), "");
        assert_eq!(take_tail_tokens(&counter, "abc", 1000), "abc");
    }

    #[test]
    fn take_tail_tokens_returns_largest_fitting_suffix() {
        // For a byte counter, the largest 1-token suffix of "abcdefgh" is "h".
        let counter = byte_counter();
        let tail = take_tail_tokens(&counter, "abcdefgh", 1);
        assert_eq!(counter.count(&tail), 1);
        assert!("abcdefgh".ends_with(&tail));
    }
}