1use std::{cmp::Ordering, fmt, iter::once, ops::Range};
2
3use either::Either;
4use itertools::Itertools;
5use strum::IntoEnumIterator;
6
7use self::fallback::FallbackLevel;
8use crate::{chunk_size::MemoizedChunkSizer, trim::Trim, ChunkCapacity, ChunkConfig, ChunkSizer};
9
10#[cfg(feature = "code")]
11mod code;
12mod fallback;
13#[cfg(feature = "markdown")]
14mod markdown;
15mod text;
16
17#[cfg(feature = "code")]
18pub use code::{CodeSplitter, CodeSplitterError};
19#[cfg(feature = "markdown")]
20pub use markdown::MarkdownSplitter;
21pub use text::TextSplitter;
22
/// Shared interface implemented by every splitter (text, markdown, code).
///
/// Implementors supply parsing of the text into semantic separator ranges;
/// the chunk-production logic lives in the provided methods, which drive
/// `TextChunks`/`TextChunksWithCharIndices` over the parsed ranges.
trait Splitter<Sizer>
where
    Sizer: ChunkSizer,
{
    /// The levels of semantic granularity this splitter understands.
    type Level: SemanticLevel;

    /// Trim behavior applied to produced chunks. Defaults to trimming all
    /// surrounding whitespace; implementors may override.
    const TRIM: Trim = Trim::All;

    /// The chunk configuration (capacity, overlap, sizer, trim flag) in use.
    fn chunk_config(&self) -> &ChunkConfig<Sizer>;

    /// Parses `text` into `(level, byte range)` pairs locating each semantic
    /// separator in the text.
    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)>;

    /// Returns an iterator of `(byte offset, chunk)` pairs over `text`.
    fn chunk_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter
    where
        Sizer: 'splitter,
    {
        TextChunks::<Sizer, Self::Level>::new(
            self.chunk_config(),
            text,
            self.parse(text),
            Self::TRIM,
        )
    }

    /// Returns an iterator of chunks annotated with both their byte offset
    /// and their character offset within `text`.
    fn chunk_char_indices<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter
    where
        Sizer: 'splitter,
    {
        TextChunksWithCharIndices::<Sizer, Self::Level>::new(
            self.chunk_config(),
            text,
            self.parse(text),
            Self::TRIM,
        )
    }

    /// Returns an iterator over just the chunk strings, without offsets.
    fn chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        text: &'text str,
    ) -> impl Iterator<Item = &'text str> + 'splitter
    where
        Sizer: 'splitter,
    {
        self.chunk_indices(text).map(|(_, t)| t)
    }
}
90
/// A level of semantic granularity a splitter can split on. Ordering is
/// significant: greater values are treated as higher (coarser) levels by the
/// range-filtering logic.
trait SemanticLevel: Copy + fmt::Debug + Ord + PartialOrd + 'static {
    /// Splits `text` into `(byte offset, section)` pairs using the given
    /// separator ranges. Ranges must be relative to the start of `text`.
    ///
    /// Each separator produces up to two items: the text preceding it and
    /// the separator itself. Any text remaining after the last separator is
    /// emitted as a final section. Empty sections are filtered out.
    fn sections(
        text: &str,
        level_ranges: impl Iterator<Item = (Self, Range<usize>)>,
    ) -> impl Iterator<Item = (usize, &str)> {
        let mut cursor = 0;
        let mut final_match = false;
        level_ranges
            .batching(move |it| {
                loop {
                    match it.next() {
                        // Trailing section already emitted on a previous call.
                        None if final_match => return None,
                        // Separators exhausted: emit everything after the
                        // last separator as one final section.
                        None => {
                            final_match = true;
                            return text.get(cursor..).map(|t| Either::Left(once((cursor, t))));
                        }
                        Some((_, range)) => {
                            // Skip ranges overlapping text already consumed.
                            if range.start < cursor {
                                continue;
                            }
                            let offset = cursor;
                            let prev_section = text
                                .get(offset..range.start)
                                .expect("invalid character sequence");
                            let separator = text
                                .get(range.start..range.end)
                                .expect("invalid character sequence");
                            cursor = range.end;
                            // Emit the text before the separator, then the
                            // separator itself.
                            return Some(Either::Right(
                                [(offset, prev_section), (range.start, separator)].into_iter(),
                            ));
                        }
                    }
                }
            })
            .flatten()
            .filter(|(_, s)| !s.is_empty())
    }
}
138
/// Stores the sorted semantic separator ranges for a text, plus a cursor
/// marking how far chunking has progressed through them.
#[derive(Debug)]
struct SemanticSplitRanges<Level>
where
    Level: SemanticLevel,
{
    // Index into `ranges` of the first range chunking has not yet passed;
    // everything before it is skipped on subsequent lookups.
    cursor: usize,
    // `(level, byte range)` separator pairs, sorted by start ascending and,
    // for equal starts, by end descending (widest range first).
    ranges: Vec<(Level, Range<usize>)>,
}
152
impl<Level> SemanticSplitRanges<Level>
where
    Level: SemanticLevel,
{
    /// Builds the range store, sorting ranges by start ascending and, for
    /// equal starts, by end descending so that enclosing ranges come first.
    fn new(mut ranges: Vec<(Level, Range<usize>)>) -> Self {
        ranges.sort_unstable_by(|(_, a), (_, b)| {
            a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end))
        });
        Self { cursor: 0, ranges }
    }

    /// Iterates over all ranges (from the internal cursor onward) whose
    /// start is at or after `offset`.
    fn ranges_after_offset(
        &self,
        offset: usize,
    ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
        self.ranges[self.cursor..]
            .iter()
            .filter(move |(_, sep)| sep.start >= offset)
            .map(|(l, r)| (*l, r.start..r.end))
    }
    /// Iterates over ranges at or above `level` starting at or after
    /// `offset`, skipping leading ranges that would swallow the first
    /// occurrence of `level` itself.
    fn level_ranges_after_offset(
        &self,
        offset: usize,
        level: Level,
    ) -> impl Iterator<Item = (Level, Range<usize>)> + '_ {
        // Find the first range exactly at `level`. If it is the very first
        // range, coalesce runs sharing its start (sorted widest-first, so
        // the last kept is the narrowest).
        let first_item = self
            .ranges_after_offset(offset)
            .position(|(l, _)| l == level)
            .and_then(|i| {
                self.ranges_after_offset(offset)
                    .skip(i)
                    .coalesce(|(a_level, a_range), (b_level, b_range)| {
                        if a_level == b_level && a_range.start == b_range.start && i == 0 {
                            Ok((b_level, b_range))
                        } else {
                            Err(((a_level, a_range), (b_level, b_range)))
                        }
                    })
                    .next()
            });
        self.ranges_after_offset(offset)
            .filter(move |(l, _)| l >= &level)
            .skip_while(move |(l, r)| {
                first_item.as_ref().is_some_and(|(_, fir)| {
                    // Drop higher-level ranges containing the first item's
                    // start, and same-level ranges that enclose it.
                    (l > &level && r.contains(&fir.start))
                        || (l == &level && r.start == fir.start && r.end > fir.end)
                })
            })
    }

    /// All distinct levels present at or after `offset`, in ascending order.
    fn levels_in_remaining_text(&self, offset: usize) -> impl Iterator<Item = Level> + '_ {
        self.ranges_after_offset(offset)
            .map(|(l, _)| l)
            .sorted()
            .dedup()
    }

    /// Splits `text` (which begins at byte `offset` of the full text) into
    /// sections at the given semantic level, yielding absolute offsets.
    fn semantic_chunks<'splitter, 'text: 'splitter>(
        &'splitter self,
        offset: usize,
        text: &'text str,
        semantic_level: Level,
    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
        Level::sections(
            text,
            // Shift separator ranges to be relative to `text`...
            self.level_ranges_after_offset(offset, semantic_level)
                .map(move |(l, sep)| (l, sep.start - offset..sep.end - offset)),
        )
        // ...then shift the section offsets back to absolute positions.
        .map(move |(i, str)| (offset + i, str))
    }

    /// Advances the internal cursor past every range starting before
    /// `cursor`, so future lookups skip already-consumed ranges.
    fn update_cursor(&mut self, cursor: usize) {
        self.cursor += self.ranges[self.cursor..]
            .iter()
            .position(|(_, range)| range.start >= cursor)
            .unwrap_or_else(|| self.ranges.len() - self.cursor);
    }
}
243
/// Iterator producing `(byte offset, chunk)` pairs over a text, built by
/// searching the semantic split ranges for the largest span that fits the
/// configured chunk capacity.
#[derive(Debug)]
struct TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    // Target size constraints for each chunk.
    capacity: ChunkCapacity,
    // Caches size measurements so identical spans aren't re-measured.
    chunk_sizer: MemoizedChunkSizer<'sizer, Sizer>,
    // Running statistics (max chunk size so far) used to tune lookahead.
    chunk_stats: ChunkStats,
    // Byte offset in `text` where the next chunk will start.
    cursor: usize,
    // Candidate `(offset, section)` pairs for assembling the next chunk.
    next_sections: Vec<(usize, &'text str)>,
    // Desired overlap between consecutive chunks.
    overlap: ChunkCapacity,
    // End offset of the previously emitted chunk; used to suppress chunks
    // that end no later than an already-emitted one.
    prev_item_end: usize,
    // Parsed semantic separator ranges for `text`.
    semantic_split: SemanticSplitRanges<Level>,
    // The full text being chunked.
    text: &'text str,
    // Whitespace trimming behavior applied to emitted chunks.
    trim: Trim,
}
272
impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunks<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    /// Creates a chunk iterator over `text` from parsed separator `offsets`
    /// and the capacity/overlap/sizer settings in `chunk_config`.
    fn new(
        chunk_config: &'sizer ChunkConfig<Sizer>,
        text: &'text str,
        offsets: Vec<(Level, Range<usize>)>,
        trim: Trim,
    ) -> Self {
        let ChunkConfig {
            capacity,
            overlap,
            sizer,
            trim: trim_enabled,
        } = chunk_config;
        Self {
            capacity: *capacity,
            chunk_sizer: MemoizedChunkSizer::new(sizer),
            chunk_stats: ChunkStats::new(),
            cursor: 0,
            next_sections: Vec::new(),
            overlap: (*overlap).into(),
            prev_item_end: 0,
            semantic_split: SemanticSplitRanges::new(offsets),
            text,
            // A disabled trim flag in the config overrides the splitter's
            // requested trim behavior.
            trim: if *trim_enabled { trim } else { Trim::None },
        }
    }

    /// Produces the next `(offset, chunk)` pair starting at the current
    /// cursor, or `None` if no further chunk can be formed.
    fn next_chunk(&mut self) -> Option<(usize, &'text str)> {
        // Skip separator ranges that start before the cursor so later scans
        // don't revisit them.
        self.semantic_split.update_cursor(self.cursor);
        let low = self.update_next_sections();
        let (start, end) = self.binary_search_next_chunk(low)?;
        let chunk = self.text.get(start..end)?;
        self.chunk_stats.update_max_chunk_size(end - start);

        // Cached sizes are keyed to the current chunk start; invalidate
        // before moving on.
        self.chunk_sizer.clear_cache();
        // Advance (or, with overlap, partially rewind) the cursor.
        self.update_cursor(end);

        Some(self.trim.trim(start, chunk))
    }

    /// Binary-searches `next_sections` for the largest span starting at the
    /// cursor that still fits the chunk capacity, beginning at index `low`.
    /// Returns the chosen `(start, end)` byte range.
    fn binary_search_next_chunk(&mut self, mut low: usize) -> Option<(usize, usize)> {
        let start = self.cursor;
        let mut end = self.cursor;
        let mut equals_found = false;
        let mut high = self.next_sections.len().saturating_sub(1);
        let mut successful_index = None;
        let mut successful_chunk_size = None;

        while low <= high {
            let mid = low + (high - low) / 2;
            let (offset, str) = self.next_sections[mid];
            let text_end = offset + str.len();
            let chunk = self.text.get(start..text_end)?;
            let chunk_size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
            let fits = self.capacity.fits(chunk_size);

            match fits {
                Ordering::Less => {
                    // Under capacity: remember the widest end seen so far.
                    if text_end > end {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                }
                Ordering::Equal => {
                    // Exactly within the capacity window: prefer the first
                    // such end found, then the smallest one.
                    if text_end < end || !equals_found {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                    equals_found = true;
                }
                Ordering::Greater => {
                    // Over capacity: only acceptable if even the very first
                    // section overflows, so the chunk isn't empty.
                    if mid == 0 && start == end {
                        end = text_end;
                        successful_index = Some(mid);
                        successful_chunk_size = Some(chunk_size);
                    }
                }
            }

            if fits.is_lt() {
                low = mid + 1;
            } else if mid > 0 {
                high = mid - 1;
            } else {
                break;
            }
        }

        if let (Some(successful_index), Some(chunk_size)) =
            (successful_index, successful_chunk_size)
        {
            // Extend past the chosen section while the measured size does
            // not grow (e.g. trailing whitespace-only sections), starting at
            // the section after the successful one.
            let mut range = successful_index..self.next_sections.len();
            range.next();

            for index in range {
                let (offset, str) = self.next_sections[index];
                let text_end = offset + str.len();
                let chunk = self.text.get(start..text_end)?;
                let size = self.chunk_sizer.chunk_size(start, chunk, self.trim);
                if size <= chunk_size {
                    if text_end > end {
                        end = text_end;
                    }
                } else {
                    break;
                }
            }
        }

        Some((start, end))
    }

    /// Moves the cursor for the next chunk. Without overlap it jumps to
    /// `end`; with overlap it binary-searches for the earliest section start
    /// whose span up to `end` still fits the overlap capacity, so the next
    /// chunk begins inside the current one.
    fn update_cursor(&mut self, end: usize) {
        if self.overlap.max == 0 {
            self.cursor = end;
            return;
        }

        let mut start = end;
        let mut low = 0;
        // Only sections ending at or before `end` are overlap candidates.
        let mut high = match self
            .next_sections
            .binary_search_by_key(&end, |(offset, str)| offset + str.len())
        {
            Ok(i) | Err(i) => i,
        };

        while low <= high {
            let mid = low + (high - low) / 2;
            let (offset, _) = self.next_sections[mid];
            let chunk_size = self.chunk_sizer.chunk_size(
                offset,
                self.text.get(offset..end).expect("Invalid range"),
                self.trim,
            );
            let fits = self.overlap.fits(chunk_size);

            // Must still make forward progress: never move at or before the
            // current cursor.
            if fits.is_le() && offset < start && offset > self.cursor {
                start = offset;
            }

            if fits.is_lt() && mid > 0 {
                high = mid - 1;
            } else {
                low = mid + 1;
            }
        }

        self.cursor = start;
    }

    /// Rebuilds `next_sections` with candidate sections starting at the
    /// cursor. Picks the semantic level whose first section fits the
    /// capacity (falling back to `FallbackLevel`s, ultimately chars, when
    /// none does), then pulls sections lazily in growing batches until one
    /// exceeds the capacity. Returns a lower-bound index to seed the
    /// subsequent binary search.
    #[expect(clippy::too_many_lines)]
    fn update_next_sections(&mut self) -> usize {
        self.next_sections.clear();

        let remaining_text = self.text.get(self.cursor..).unwrap();

        // Probe each available semantic level with its first section to find
        // the deepest level that still fits.
        let (semantic_level, mut max_offset) = self.chunk_sizer.find_correct_level(
            self.cursor,
            &self.capacity,
            self.semantic_split
                .levels_in_remaining_text(self.cursor)
                .filter_map(|level| {
                    self.semantic_split
                        .semantic_chunks(self.cursor, remaining_text, level)
                        .next()
                        .map(|(_, str)| (level, str))
                }),
            self.trim,
        );

        let sections = if let Some(semantic_level) = semantic_level {
            Either::Left(self.semantic_split.semantic_chunks(
                self.cursor,
                remaining_text,
                semantic_level,
            ))
        } else {
            // No semantic level fits; repeat the probe over the fallback
            // levels (sentences/words/graphemes/chars).
            let (semantic_level, fallback_max_offset) = self.chunk_sizer.find_correct_level(
                self.cursor,
                &self.capacity,
                FallbackLevel::iter().filter_map(|level| {
                    level
                        .sections(remaining_text)
                        .next()
                        .map(|(_, str)| (level, str))
                }),
                self.trim,
            );

            // Use the tighter of the two offset caps when both exist.
            max_offset = match (fallback_max_offset, max_offset) {
                (Some(fallback), Some(max)) => Some(fallback.min(max)),
                (fallback, max) => fallback.or(max),
            };

            let fallback_level = semantic_level.unwrap_or(FallbackLevel::Char);

            Either::Right(
                fallback_level
                    .sections(remaining_text)
                    .map(|(offset, text)| (self.cursor + offset, text)),
            )
        };

        // Cap lookahead at `max_offset` and drop empty sections.
        let mut sections = sections
            .take_while(move |(offset, _)| max_offset.is_none_or(|max| *offset <= max))
            .filter(|(_, str)| !str.is_empty());

        let mut low = 0;
        let mut prev_equals: Option<usize> = None;
        let max = self.capacity.max;
        // Start the lookahead guess from the largest chunk seen so far.
        let mut target_offset = self.chunk_stats.max_chunk_size.unwrap_or(max);

        loop {
            let prev_num = self.next_sections.len();
            // Pull sections until we pass the current target offset.
            for (offset, str) in sections.by_ref() {
                self.next_sections.push((offset, str));
                if offset + str.len() > (self.cursor.saturating_add(target_offset)) {
                    break;
                }
            }
            let new_num = self.next_sections.len();
            // Iterator exhausted: nothing more to pull.
            if new_num - prev_num == 0 {
                break;
            }

            if let Some(&(offset, str)) = self.next_sections.last() {
                let text_end = offset + str.len();
                // Didn't even reach the target: sections ran out early.
                if (text_end - self.cursor) < target_offset {
                    break;
                }
                let chunk_size = self.chunk_sizer.chunk_size(
                    offset,
                    self.text.get(self.cursor..text_end).expect("Invalid range"),
                    self.trim,
                );
                let fits = self.capacity.fits(chunk_size);

                if fits.is_le() {
                    // Still under capacity: extrapolate a new target offset
                    // from the average bytes-per-size-unit so far, plus a
                    // 10% margin.
                    let final_offset = offset + str.len() - self.cursor;
                    let size = chunk_size.max(1);
                    let diff = (max - size).max(1);
                    let avg_size = final_offset.div_ceil(size);

                    target_offset = final_offset
                        .saturating_add(diff.saturating_mul(avg_size))
                        .saturating_add(final_offset.div_ceil(10));
                }

                match fits {
                    Ordering::Less => {
                        // Everything so far fits; the binary search can
                        // start from the last pulled section.
                        low = new_num.saturating_sub(1);
                    }
                    Ordering::Equal => {
                        // Stop once a further pull makes the size grow past
                        // a previously seen equal fit.
                        if let Some(prev) = prev_equals {
                            if prev < chunk_size {
                                break;
                            }
                        }
                        prev_equals = Some(chunk_size);
                    }
                    Ordering::Greater => {
                        break;
                    }
                }
            }
        }

        low
    }
}
582
583impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator for TextChunks<'text, 'sizer, Sizer, Level>
584where
585 Sizer: ChunkSizer,
586 Level: SemanticLevel,
587{
588 type Item = (usize, &'text str);
589
590 fn next(&mut self) -> Option<Self::Item> {
591 loop {
592 if self.cursor >= self.text.len() {
594 return None;
595 }
596
597 match self.next_chunk()? {
598 (_, "") => {}
601 c => {
602 let item_end = c.0 + c.1.len();
603 if item_end <= self.prev_item_end {
605 continue;
606 }
607 self.prev_item_end = item_end;
608 return Some(c);
609 }
610 }
611 }
612 }
613}
614
/// A chunk of text together with its byte and character offsets within the
/// original source text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkCharIndex<'text> {
    /// The chunk's text content.
    pub chunk: &'text str,
    /// Byte offset of the chunk's start within the source text.
    pub byte_offset: usize,
    /// Character offset of the chunk's start within the source text.
    pub char_offset: usize,
}
625
/// Wraps `TextChunks` to additionally report character offsets, tracking the
/// character count incrementally as chunks are produced.
#[derive(Debug)]
struct TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
where
    Sizer: ChunkSizer,
    Level: SemanticLevel,
{
    // The full source text, needed to count characters between chunks.
    text: &'text str,
    // The underlying byte-offset chunk iterator.
    text_chunks: TextChunks<'text, 'sizer, Sizer, Level>,
    // Byte offset of the most recently emitted chunk's start.
    byte_offset: usize,
    // Character offset corresponding to `byte_offset`.
    char_offset: usize,
}
642
643impl<'sizer, 'text: 'sizer, Sizer, Level> TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
644where
645 Sizer: ChunkSizer,
646 Level: SemanticLevel,
647{
648 fn new(
651 chunk_config: &'sizer ChunkConfig<Sizer>,
652 text: &'text str,
653 offsets: Vec<(Level, Range<usize>)>,
654 trim: Trim,
655 ) -> Self {
656 Self {
657 text,
658 text_chunks: TextChunks::new(chunk_config, text, offsets, trim),
659 byte_offset: 0,
660 char_offset: 0,
661 }
662 }
663}
664
665impl<'sizer, 'text: 'sizer, Sizer, Level> Iterator
666 for TextChunksWithCharIndices<'text, 'sizer, Sizer, Level>
667where
668 Sizer: ChunkSizer,
669 Level: SemanticLevel,
670{
671 type Item = ChunkCharIndex<'text>;
672
673 fn next(&mut self) -> Option<Self::Item> {
674 let (byte_offset, chunk) = self.text_chunks.next()?;
675 let preceding_text = self
676 .text
677 .get(self.byte_offset..byte_offset)
678 .expect("Invalid byte sequence");
679 self.byte_offset = byte_offset;
680 self.char_offset += preceding_text.chars().count();
681 Some(ChunkCharIndex {
682 chunk,
683 byte_offset,
684 char_offset: self.char_offset,
685 })
686 }
687}
688
/// Running statistics about emitted chunks, used to tune section lookahead.
#[derive(Debug, Default)]
struct ChunkStats {
    /// Largest chunk (in bytes) produced so far, if any chunk was produced.
    max_chunk_size: Option<usize>,
}

impl ChunkStats {
    /// Creates an empty tracker with no recorded chunks.
    fn new() -> Self {
        Self::default()
    }

    /// Records a chunk size, retaining the running maximum.
    fn update_max_chunk_size(&mut self, size: usize) {
        self.max_chunk_size = Some(self.max_chunk_size.map_or(size, |prev| prev.max(size)));
    }
}
706
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh tracker has recorded no chunks yet.
    #[test]
    fn chunk_stats_empty() {
        let stats = ChunkStats::new();
        assert_eq!(stats.max_chunk_size, None);
    }

    // A single update records that size as the maximum.
    #[test]
    fn chunk_stats_one() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);
        assert_eq!(stats.max_chunk_size, Some(10));
    }

    // Multiple updates keep the largest size seen.
    #[test]
    fn chunk_stats_multiple() {
        let mut stats = ChunkStats::new();
        stats.update_max_chunk_size(10);
        stats.update_max_chunk_size(20);
        stats.update_max_chunk_size(30);
        assert_eq!(stats.max_chunk_size, Some(30));
    }

    // `usize` is a convenient stand-in semantic level for testing; the
    // trait's default `sections` implementation is all that's required.
    impl SemanticLevel for usize {}

    // Ranges sort by start ascending, then end descending (widest first).
    #[test]
    fn semantic_ranges_are_sorted() {
        let ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        assert_eq!(
            ranges.ranges,
            vec![(2, 0..4), (1, 0..2), (0, 0..1), (0, 1..2)]
        );
    }

    // Advancing the cursor drops ranges starting before it from lookups,
    // even when a smaller offset is queried afterwards.
    #[test]
    fn semantic_ranges_skip_previous_ranges() {
        let mut ranges = SemanticSplitRanges::new(vec![(0, 0..1), (1, 0..2), (0, 1..2), (2, 0..4)]);

        ranges.update_cursor(1);

        assert_eq!(
            ranges.ranges_after_offset(0).collect::<Vec<_>>(),
            vec![(0, 1..2)]
        );
    }
}