//! Token-aware text chunking.
//!
//! Splits text into sentence-aligned chunks that fit a token budget, with
//! optional token overlap between chunks and optional paragraph preservation.
//! Token counting uses `tiktoken-rs` (`cl100k_base` or `o200k_base`).
#![deny(unsafe_code)]
#![warn(missing_docs)]
#![warn(rust_2018_idioms)]
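//!
//! Illustrative usage; `text_chunker` below is a placeholder for this
//! crate's actual name:
//!
//! ```ignore
//! use text_chunker::{ChunkConfig, Chunker};
//!
//! let chunker = Chunker::new(ChunkConfig::default()).unwrap();
//! for chunk in chunker.split("First sentence. Second sentence.").unwrap() {
//!     println!("{:>3} tokens | {}", chunk.token_count, chunk.text);
//! }
//! ```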
use rayon::prelude::*;
use regex::Regex;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tiktoken_rs::CoreBPE;
/// Crate-wide result alias.
pub type Result<T> = std::result::Result<T, ChunkerError>;
/// Errors produced by this crate.
#[derive(Error, Debug)]
pub enum ChunkerError {
/// The requested encoding name is not supported.
#[error("unknown encoding: {0} (expected cl100k_base or o200k_base)")]
UnknownEncoding(String),
/// The [`ChunkConfig`] is internally inconsistent.
#[error("invalid config: {0}")]
InvalidConfig(String),
/// An error surfaced by `tiktoken-rs` while loading an encoding or decoding tokens.
#[error("tiktoken-rs error: {0}")]
Tiktoken(String),
}
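/// Configuration for a [`Chunker`].
///
/// `ChunkConfig` round-trips through serde; `preserve_paragraphs` defaults to
/// `false` when absent. A sketch, assuming `serde_json` is available and using
/// `text_chunker` as a stand-in crate name:
///
/// ```ignore
/// let cfg: text_chunker::ChunkConfig = serde_json::from_str(
///     r#"{"max_tokens":512,"overlap_tokens":64,"min_tokens":1,"encoding":"cl100k_base"}"#,
/// ).unwrap();
/// assert!(!cfg.preserve_paragraphs);
/// ```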
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChunkConfig {
/// Hard upper bound on tokens per chunk, before any overlap is prepended.
pub max_tokens: usize,
/// Number of trailing tokens from each chunk to prepend to the next one;
/// must be strictly less than `max_tokens`.
pub overlap_tokens: usize,
/// Chunks with fewer tokens than this (overlap included) are dropped.
pub min_tokens: usize,
/// Token encoding name: `"cl100k_base"` or `"o200k_base"`.
pub encoding: String,
/// When `true`, paragraphs separated by blank lines are chunked
/// independently, so no chunk spans a paragraph boundary.
#[serde(default)]
pub preserve_paragraphs: bool,
}
impl Default for ChunkConfig {
fn default() -> Self {
Self {
max_tokens: 512,
overlap_tokens: 0,
min_tokens: 1,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: false,
}
}
}
/// A contiguous chunk of the input text plus its token count.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Chunk {
/// Decoded chunk text, including any overlap tokens carried over from the
/// previous chunk.
pub text: String,
/// Byte offset where this chunk's own sentences start in the source text;
/// overlap tokens prepended to `text` are not reflected here.
pub start: usize,
/// Exclusive byte offset where this chunk ends in the source text.
pub end: usize,
/// Number of tokens in `text`, overlap included.
pub token_count: usize,
}
/// Sentence-aware, token-budgeted text splitter.
pub struct Chunker {
bpe: CoreBPE,
cfg: ChunkConfig,
sentence_re: Regex,
}
impl Chunker {
/// Validates `cfg` and loads the requested token encoding.
///
/// Returns [`ChunkerError::InvalidConfig`] for inconsistent limits and
/// [`ChunkerError::UnknownEncoding`] for unsupported encoding names.
pub fn new(cfg: ChunkConfig) -> Result<Self> {
if cfg.max_tokens == 0 {
return Err(ChunkerError::InvalidConfig("max_tokens must be > 0".into()));
}
if cfg.overlap_tokens >= cfg.max_tokens {
return Err(ChunkerError::InvalidConfig(format!(
"overlap_tokens ({}) must be < max_tokens ({})",
cfg.overlap_tokens, cfg.max_tokens
)));
}
if cfg.min_tokens > cfg.max_tokens {
return Err(ChunkerError::InvalidConfig(format!(
"min_tokens ({}) must be <= max_tokens ({})",
cfg.min_tokens, cfg.max_tokens
)));
}
let bpe = match cfg.encoding.as_str() {
"cl100k_base" => {
tiktoken_rs::cl100k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
}
"o200k_base" => {
tiktoken_rs::o200k_base().map_err(|e| ChunkerError::Tiktoken(e.to_string()))?
}
other => return Err(ChunkerError::UnknownEncoding(other.to_string())),
};
// A sentence boundary is ./!/? (optionally followed by a closing quote
// or bracket), whitespace, then something that looks like a sentence
// opener: an uppercase (possibly accented) letter, opening quote, or
// opening bracket. The regex crate has no lookahead, so the opener is
// consumed by the match; split positions are derived from `term` and
// `close` only.
let sentence_re = Regex::new(
r"(?P<term>[.!?])(?P<close>[\)\]\}\u{201d}\u{2019}\u{00bb}'\x22]?)\s+(?P<next>[A-Z\u{00c0}-\u{00de}\u{2018}\u{201c}\(\[\{])"
).expect("sentence regex compiles");
Ok(Self {
bpe,
cfg,
sentence_re,
})
}
/// Splits `text` into chunks according to the configuration.
///
/// Returns an empty vector for empty or whitespace-only input.
pub fn split(&self, text: &str) -> Result<Vec<Chunk>> {
// Chunk each paragraph independently so no chunk crosses a blank line;
// offsets are shifted back into the coordinates of the full text.
if self.cfg.preserve_paragraphs {
let mut all: Vec<Chunk> = Vec::new();
let paragraphs = split_paragraphs(text);
for (p_start, p_end) in paragraphs {
let para_text = &text[p_start..p_end];
let mut chunks = self.split_internal(para_text)?;
for c in &mut chunks {
c.start += p_start;
c.end += p_start;
}
all.extend(chunks);
}
return Ok(all);
}
self.split_internal(text)
}
fn split_internal(&self, text: &str) -> Result<Vec<Chunk>> {
let sentences = self.split_sentences(text);
if sentences.is_empty() {
return Ok(Vec::new());
}
let mut s_tokens: Vec<Vec<u32>> = Vec::with_capacity(sentences.len());
for &(start, end) in &sentences {
s_tokens.push(self.bpe.encode_ordinary(&text[start..end]));
}
// First pass: greedily pack whole sentences into (tokens, byte_start,
// byte_end) triples, flushing whenever the next sentence would overflow
// the budget.
let mut raw: Vec<(Vec<u32>, usize, usize)> = Vec::new();
let mut cur_tokens: Vec<u32> = Vec::new();
let mut cur_start: Option<usize> = None;
let mut cur_end: usize = 0;
for (i, &(s_start, s_end)) in sentences.iter().enumerate() {
let stoks = &s_tokens[i];
// A single sentence over budget: flush whatever has accumulated, then
// fall back to fixed-size token slicing for this sentence alone.
if stoks.len() > self.cfg.max_tokens {
if !cur_tokens.is_empty() {
raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
cur_start = None;
}
self.slice_long_sentence(stoks, s_start, s_end, &mut raw);
continue;
}
// Adding this sentence would overflow the budget: flush the current chunk.
if cur_tokens.len() + stoks.len() > self.cfg.max_tokens && !cur_tokens.is_empty() {
raw.push((std::mem::take(&mut cur_tokens), cur_start.unwrap(), cur_end));
cur_start = None;
}
if cur_start.is_none() {
cur_start = Some(s_start);
}
cur_tokens.extend_from_slice(stoks);
cur_end = s_end;
}
if !cur_tokens.is_empty() {
raw.push((cur_tokens, cur_start.unwrap(), cur_end));
}
// Second pass: prepend the previous chunk's tail (the configured
// overlap), decode to text, and drop anything under `min_tokens`.
let mut out: Vec<Chunk> = Vec::with_capacity(raw.len());
let mut prev_tail: Vec<u32> = Vec::new();
for (toks, start, end) in raw {
let mut full = Vec::with_capacity(prev_tail.len() + toks.len());
full.extend_from_slice(&prev_tail);
full.extend_from_slice(&toks);
let token_count = full.len();
let text = self
.bpe
.decode(full)
.map_err(|e| ChunkerError::Tiktoken(e.to_string()))?;
// Carry the tail of this chunk's *own* tokens (not its inherited
// overlap) into the next chunk; a chunk shorter than the overlap
// contributes all of its tokens.
prev_tail = if self.cfg.overlap_tokens == 0 {
Vec::new()
} else if toks.len() > self.cfg.overlap_tokens {
toks[toks.len() - self.cfg.overlap_tokens..].to_vec()
} else {
toks
};
if token_count < self.cfg.min_tokens {
// Dropped as too small, but its tail still seeds the next chunk.
continue;
}
out.push(Chunk {
text,
start,
end,
token_count,
});
}
Ok(out)
}
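/// Splits each text independently, in parallel via rayon when `parallel`
/// is `true`. A sketch of batch use, again with `text_chunker` as a
/// stand-in crate name:
///
/// ```ignore
/// let chunker = text_chunker::Chunker::new(Default::default()).unwrap();
/// let docs = ["First doc here.", "Second doc here."];
/// let per_doc = chunker.split_many(&docs, true).unwrap();
/// assert_eq!(per_doc.len(), docs.len());
/// ```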
pub fn split_many(&self, texts: &[&str], parallel: bool) -> Result<Vec<Vec<Chunk>>> {
if parallel {
texts.par_iter().map(|t| self.split(t)).collect()
} else {
texts.iter().map(|t| self.split(t)).collect()
}
}
/// Returns contiguous byte spans of sentences: trailing whitespace after
/// a terminator stays with the preceding sentence, so concatenating all
/// spans reproduces the input exactly.
fn split_sentences(&self, text: &str) -> Vec<(usize, usize)> {
if text.is_empty() {
return Vec::new();
}
let mut spans: Vec<(usize, usize)> = Vec::new();
let mut last = 0usize;
for caps in self.sentence_re.captures_iter(text) {
let m = caps.name("term").unwrap();
let cut = caps
.name("close")
.filter(|c| !c.as_str().is_empty())
.map(|c| c.end())
.unwrap_or_else(|| m.end());
if cut <= last {
continue;
}
if is_abbreviation(&text[..m.end()]) {
continue;
}
let mut next_start = cut;
while next_start < text.len() && text.as_bytes()[next_start].is_ascii_whitespace() {
next_start += 1;
}
// End the span after the trailing whitespace rather than at the
// terminator; otherwise decoding a chunk's concatenated sentence tokens
// would jam sentences together ("...gamma.Delta...").
spans.push((last, next_start));
last = next_start;
}
if last < text.len() {
spans.push((last, text.len()));
}
spans.retain(|&(s, e)| s < e && !text[s..e].trim().is_empty());
spans
}
/// Fallback for a single sentence whose token count exceeds `max_tokens`:
/// emit fixed-size token slices. Every slice reports the whole sentence's
/// byte span, since token boundaries do not map cleanly to byte offsets.
fn slice_long_sentence(
&self,
toks: &[u32],
s_start: usize,
s_end: usize,
out: &mut Vec<(Vec<u32>, usize, usize)>,
) {
let mut i = 0usize;
while i < toks.len() {
let end = (i + self.cfg.max_tokens).min(toks.len());
out.push((toks[i..end].to_vec(), s_start, s_end));
i = end;
}
}
}
/// Returns byte spans of paragraphs, treating runs of two or more
/// newlines as separators; the separators themselves are excluded.
fn split_paragraphs(text: &str) -> Vec<(usize, usize)> {
if text.is_empty() {
return Vec::new();
}
let mut spans: Vec<(usize, usize)> = Vec::new();
let mut start: Option<usize> = None;
let bytes = text.as_bytes();
let mut i = 0usize;
while i < bytes.len() {
if bytes[i] == b'\n' {
let mut nl_end = i;
while nl_end < bytes.len() && bytes[nl_end] == b'\n' {
nl_end += 1;
}
if nl_end - i >= 2 {
if let Some(s) = start.take() {
spans.push((s, i));
}
i = nl_end;
continue;
}
}
if start.is_none() {
start = Some(i);
}
i += 1;
}
if let Some(s) = start {
spans.push((s, text.len()));
}
spans
}
/// Heuristic check for common abbreviations so that, e.g., "Dr. Smith"
/// is not split into two sentences.
fn is_abbreviation(prefix: &str) -> bool {
const ABBREVS: &[&str] = &[
"mr.", "mrs.", "ms.", "dr.", "st.", "jr.", "sr.", "prof.", "rev.", "vs.", "etc.", "e.g.",
"i.e.", "fig.", "cf.", "no.", "vol.", "ch.", "sec.",
];
// Lowercase the last eight characters: enough to cover the longest entry
// plus one preceding character for the word-boundary check below.
let tail_start = prefix
.char_indices()
.rev()
.nth(7)
.map(|(i, _)| i)
.unwrap_or(0);
let lower_tail = prefix[tail_start..].to_lowercase();
// Require a word boundary before the match so that words merely ending
// in an abbreviation ("config." vs "fig.") do not suppress a split.
ABBREVS.iter().any(|a| {
lower_tail.ends_with(a)
&& lower_tail[..lower_tail.len() - a.len()]
.chars()
.next_back()
.map_or(true, |c| !c.is_alphanumeric())
})
}
#[cfg(test)]
mod tests {
use super::*;
fn cfg(max_tokens: usize) -> ChunkConfig {
ChunkConfig {
max_tokens,
overlap_tokens: 0,
min_tokens: 1,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: false,
}
}
#[test]
fn empty_input_yields_no_chunks() {
let c = Chunker::new(cfg(100)).unwrap();
assert!(c.split("").unwrap().is_empty());
}
#[test]
fn short_text_one_chunk() {
let c = Chunker::new(cfg(100)).unwrap();
let r = c.split("hello world").unwrap();
assert_eq!(r.len(), 1);
assert_eq!(r[0].text, "hello world");
}
#[test]
fn splits_at_sentence_boundary_under_budget() {
let c = Chunker::new(cfg(8)).unwrap();
let text = "Alpha beta gamma. Delta epsilon zeta. Eta theta iota.";
let chunks = c.split(text).unwrap();
assert!(
chunks.len() >= 2,
"expected >=2 chunks, got {}",
chunks.len()
);
for ch in &chunks {
assert!(
ch.token_count <= 8,
"chunk over budget: {} tokens",
ch.token_count
);
}
}
#[test]
fn long_sentence_falls_back_to_token_slicing() {
let c = Chunker::new(cfg(5)).unwrap();
let text = "the quick brown fox jumps over the lazy dog and runs through fields";
let chunks = c.split(text).unwrap();
assert!(chunks.len() > 1);
for ch in &chunks {
assert!(ch.token_count <= 5);
}
}
#[test]
fn overlap_prepends_tail_tokens() {
let c = Chunker::new(ChunkConfig {
max_tokens: 6,
overlap_tokens: 2,
min_tokens: 1,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: false,
})
.unwrap();
let text = "Alpha beta gamma. Delta epsilon zeta. Eta theta iota.";
let chunks = c.split(text).unwrap();
assert!(chunks.len() >= 2);
for ch in chunks.iter().skip(1) {
assert!(ch.token_count <= 6 + 2);
}
}
#[test]
fn min_tokens_drops_short_chunks() {
let c = Chunker::new(ChunkConfig {
max_tokens: 1000,
overlap_tokens: 0,
min_tokens: 50,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: false,
})
.unwrap();
let text = "tiny.";
assert!(c.split(text).unwrap().is_empty());
}
#[test]
fn invalid_config_overlap_ge_max() {
let bad = ChunkConfig {
max_tokens: 10,
overlap_tokens: 10,
..Default::default()
};
assert!(Chunker::new(bad).is_err());
}
#[test]
fn invalid_config_zero_max() {
let bad = ChunkConfig {
max_tokens: 0,
..Default::default()
};
assert!(Chunker::new(bad).is_err());
}
#[test]
fn unknown_encoding_rejected() {
let bad = ChunkConfig {
encoding: "nope_base".to_string(),
..Default::default()
};
assert!(matches!(
Chunker::new(bad),
Err(ChunkerError::UnknownEncoding(_))
));
}
#[test]
fn abbreviation_does_not_split_sentence() {
let c = Chunker::new(cfg(1000)).unwrap();
let text = "Dr. Smith arrived. He said hello.";
let sentences = c.split_sentences(text);
assert_eq!(sentences.len(), 2, "got: {:?}", sentences);
}
#[test]
fn split_many_serial_and_parallel_match() {
let c = Chunker::new(cfg(10)).unwrap();
let texts = vec!["Alpha beta gamma.", "Delta. Epsilon. Zeta."];
let serial = c.split_many(&texts, false).unwrap();
let parallel = c.split_many(&texts, true).unwrap();
assert_eq!(serial, parallel);
}
#[test]
fn chunk_text_decodes_to_token_count() {
let c = Chunker::new(cfg(10)).unwrap();
let text = "The quick brown fox jumps over the lazy dog.";
let chunks = c.split(text).unwrap();
let bpe = tiktoken_rs::cl100k_base().unwrap();
for ch in &chunks {
let actual = bpe.encode_ordinary(&ch.text).len();
assert_eq!(actual, ch.token_count);
}
}
#[test]
fn unicode_input_handled() {
let c = Chunker::new(cfg(100)).unwrap();
let text = "你好世界. Hello world. 🌍 done.";
let r = c.split(text).unwrap();
assert!(!r.is_empty());
for ch in &r {
assert!(!ch.text.is_empty());
}
}
#[test]
fn min_tokens_filters_single_word_input() {
let c = Chunker::new(ChunkConfig {
max_tokens: 100,
overlap_tokens: 0,
min_tokens: 5,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: false,
})
.unwrap();
let r = c.split("hi").unwrap();
assert!(r.is_empty());
}
#[test]
fn preserve_paragraphs_emits_per_paragraph_chunks() {
let c = Chunker::new(ChunkConfig {
max_tokens: 100,
overlap_tokens: 0,
min_tokens: 1,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: true,
})
.unwrap();
let text = "First paragraph here.\n\nSecond paragraph here.";
let r = c.split(text).unwrap();
assert_eq!(r.len(), 2);
for ch in &r {
assert!(ch.end <= text.len());
assert!(text.get(ch.start..ch.end).is_some());
}
}
#[test]
fn preserve_paragraphs_respects_token_budget_per_paragraph() {
let c = Chunker::new(ChunkConfig {
max_tokens: 5,
overlap_tokens: 0,
min_tokens: 1,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: true,
})
.unwrap();
let text = "alpha beta gamma delta epsilon zeta\n\nshort.";
let r = c.split(text).unwrap();
assert!(r.len() >= 2);
}
#[test]
fn split_paragraphs_helper_returns_disjoint_spans() {
let text = "para 1\n\n\npara 2\n\npara 3";
let spans = split_paragraphs(text);
assert_eq!(spans.len(), 3);
assert_eq!(&text[spans[0].0..spans[0].1], "para 1");
assert_eq!(&text[spans[1].0..spans[1].1], "para 2");
assert_eq!(&text[spans[2].0..spans[2].1], "para 3");
}
#[test]
fn preserve_paragraphs_default_off_keeps_existing_behavior() {
let c = Chunker::new(ChunkConfig {
max_tokens: 100,
overlap_tokens: 0,
min_tokens: 1,
encoding: "cl100k_base".to_string(),
preserve_paragraphs: false,
})
.unwrap();
let text = "First paragraph here.\n\nSecond paragraph here.";
let r = c.split(text).unwrap();
assert_eq!(r.len(), 1);
}
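#[test]
fn chunks_concatenate_back_to_input_without_overlap() {
// Sentence spans are contiguous (trailing whitespace stays with the
// preceding sentence), so with zero overlap the decoded chunk texts
// should reproduce the input exactly when concatenated.
let c = Chunker::new(cfg(8)).unwrap();
let text = "Alpha beta gamma. Delta epsilon zeta. Eta theta iota.";
let chunks = c.split(text).unwrap();
let joined: String = chunks.iter().map(|ch| ch.text.as_str()).collect();
assert_eq!(joined, text);
}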
}