mod chunk;
mod markdown;
pub(crate) mod recursive;
#[cfg(feature = "semantic")]
mod semantic;
pub use chunk::Chunk;
/// Returns how many bytes into `parent` the slice `sub` begins.
///
/// The result is computed purely from the two start addresses, so it is only
/// meaningful when `sub` was sliced out of `parent`. Debug builds assert that
/// `sub` starts somewhere in `parent`'s byte range (one-past-the-end is
/// accepted, which is where an empty trailing slice points). If `sub` starts
/// before `parent`, the subtraction saturates and `0` is returned.
pub(crate) fn byte_offset_of(sub: &str, parent: &str) -> usize {
    let parent_start = parent.as_ptr() as usize;
    let sub_start = sub.as_ptr() as usize;
    debug_assert!(
        (parent_start..=parent_start + parent.len()).contains(&sub_start),
        "substring pointer is not within parent string bounds"
    );
    sub_start.saturating_sub(parent_start)
}
/// Errors returned by this crate's fallible operations.
///
/// The only variant is gated behind the `semantic` feature, so with that
/// feature disabled this enum has no variants at all.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
/// Wraps an `embedrs::Error`.
#[cfg(feature = "semantic")]
Embed(embedrs::Error),
}
impl std::fmt::Display for Error {
    /// Writes a human-readable description of the error.
    ///
    /// Without the `semantic` feature the enum has no variants, so this
    /// method can never actually be invoked on a value.
    // `f` is only read by feature-gated arms; when they are all compiled out
    // it would otherwise trigger an unused-variable warning.
    #[allow(unused_variables)]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "semantic")]
            Self::Embed(e) => write!(f, "embedding error: {e}"),
            #[cfg(not(feature = "semantic"))]
            _ => unreachable!("Error is uninhabited without semantic feature"),
        }
    }
}
// Empty impl: relies entirely on the trait's default methods, so `source()`
// is not overridden for the wrapped embedding error.
impl std::error::Error for Error {}
/// Shorthand for `std::result::Result` with this crate's `Error` as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Starts a chunking request for `text` and returns a builder to configure it.
///
/// The builder starts out with a 512-token limit, zero overlap, no explicit
/// model or encoding name, and the recursive strategy. With the `semantic`
/// feature enabled it additionally starts with no embedding client and a
/// threshold of `0.5`.
pub fn chunk(text: &str) -> ChunkBuilder<'_> {
    const DEFAULT_MAX_TOKENS: usize = 512;
    const DEFAULT_OVERLAP: usize = 0;
    #[cfg(feature = "semantic")]
    const DEFAULT_SEMANTIC_THRESHOLD: f64 = 0.5;

    ChunkBuilder {
        strategy: Strategy::Recursive,
        text,
        max_tokens: DEFAULT_MAX_TOKENS,
        overlap: DEFAULT_OVERLAP,
        encoding_name: None,
        model_name: None,
        #[cfg(feature = "semantic")]
        semantic_threshold: DEFAULT_SEMANTIC_THRESHOLD,
        #[cfg(feature = "semantic")]
        semantic_client: None,
    }
}
/// Which splitting routine a `ChunkBuilder` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Strategy {
/// The initial strategy set by `chunk()`; handled by `recursive::split_recursive`.
Recursive,
/// Selected via `ChunkBuilder::markdown`; handled by `markdown::split_markdown`.
Markdown,
/// Selected via `ChunkBuilder::semantic`; handled by `semantic::split_semantic`
/// from `split_async` (the synchronous `split` panics for this strategy).
#[cfg(feature = "semantic")]
Semantic,
}
/// Builder holding the input text and every option for one chunking run.
///
/// Created by `chunk()`; all borrowed data shares the lifetime `'a`.
pub struct ChunkBuilder<'a> {
// Text to be split.
text: &'a str,
// Token limit handed to the splitter; the `max_tokens` setter keeps it >= 1.
max_tokens: usize,
// Overlap value handed to the splitter unchanged.
// NOTE(review): presumably the number of tokens shared between neighbouring
// chunks — confirm against the splitter modules.
overlap: usize,
// Optional model name used to look up a tokenizer encoding.
model_name: Option<&'a str>,
// Optional encoding name; checked before `model_name` when resolving the encoder.
encoding_name: Option<&'a str>,
// Which splitting routine to run.
strategy: Strategy,
// Embedding client used by the semantic strategy; set by `semantic()`.
#[cfg(feature = "semantic")]
semantic_client: Option<&'a embedrs::Client>,
// Threshold forwarded to `semantic::split_semantic`.
// NOTE(review): exact meaning (e.g. similarity cut-off) is defined there — confirm.
#[cfg(feature = "semantic")]
semantic_threshold: f64,
}
impl<'a> ChunkBuilder<'a> {
    /// Sets the per-chunk token limit.
    ///
    /// A value of `0` is raised to `1`, so the stored limit is never zero.
    pub fn max_tokens(mut self, n: usize) -> Self {
        self.max_tokens = n.max(1);
        self
    }

    /// Sets the overlap value that is passed through to the splitter.
    ///
    /// The value is stored as given; it is not validated against
    /// `max_tokens` here.
    pub fn overlap(mut self, tokens: usize) -> Self {
        self.overlap = tokens;
        self
    }

    /// Chooses the tokenizer by model name.
    ///
    /// Only consulted when no encoding name has been set with
    /// [`encoding`](Self::encoding). An unrecognised name silently falls back
    /// to the `o200k_base` encoding.
    pub fn model(mut self, model: &'a str) -> Self {
        self.model_name = Some(model);
        self
    }

    /// Chooses the tokenizer by encoding name.
    ///
    /// Takes precedence over a model name set with [`model`](Self::model).
    /// An unrecognised name silently falls back to the `o200k_base` encoding.
    pub fn encoding(mut self, encoding: &'a str) -> Self {
        self.encoding_name = Some(encoding);
        self
    }

    /// Switches to the markdown splitting strategy.
    pub fn markdown(mut self) -> Self {
        self.strategy = Strategy::Markdown;
        self
    }

    /// Switches to the semantic splitting strategy using `client`.
    ///
    /// A builder in this mode must be finished with
    /// [`split_async`](Self::split_async); [`split`](Self::split) panics.
    #[cfg(feature = "semantic")]
    pub fn semantic(mut self, client: &'a embedrs::Client) -> Self {
        self.strategy = Strategy::Semantic;
        self.semantic_client = Some(client);
        self
    }

    /// Sets the threshold forwarded to the semantic splitter.
    ///
    /// The value is stored as given (no clamping). It is only read when the
    /// semantic strategy runs.
    #[cfg(feature = "semantic")]
    pub fn threshold(mut self, t: f64) -> Self {
        self.semantic_threshold = t;
        self
    }

    /// Runs the configured strategy synchronously and returns the chunks.
    ///
    /// # Panics
    ///
    /// Panics if the semantic strategy was selected, because that strategy
    /// is only available through [`split_async`](Self::split_async). Also
    /// panics if the fallback `o200k_base` encoding cannot be loaded.
    pub fn split(self) -> Vec<Chunk> {
        let encoder = self.resolve_encoder();
        match self.strategy {
            // NOTE(review): the literal `0` and `&None` look like the base
            // byte offset and section label for a top-level call — confirm
            // against `recursive::split_recursive`.
            Strategy::Recursive => recursive::split_recursive(
                self.text,
                0,
                self.max_tokens,
                self.overlap,
                encoder,
                &None,
            ),
            Strategy::Markdown => {
                markdown::split_markdown(self.text, self.max_tokens, self.overlap, encoder)
            }
            #[cfg(feature = "semantic")]
            Strategy::Semantic => {
                panic!(
                    "semantic strategy requires async: use .split_async().await instead of .split()"
                )
            }
        }
    }

    /// Runs the configured strategy, awaiting the semantic splitter if needed.
    ///
    /// The recursive and markdown strategies are delegated to
    /// [`split`](Self::split) and always yield `Ok`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `semantic::split_semantic` reports when the
    /// semantic strategy is selected.
    ///
    /// # Panics
    ///
    /// Panics if the semantic strategy is active without a client. This
    /// cannot happen through the public API, since
    /// [`semantic`](Self::semantic) sets both together.
    #[cfg(feature = "semantic")]
    pub async fn split_async(self) -> Result<Vec<Chunk>> {
        match self.strategy {
            Strategy::Semantic => {
                // Resolved only on this path; `split()` resolves its own
                // encoder, so doing it up front would be a duplicate lookup.
                let encoder = self.resolve_encoder();
                let client = self
                    .semantic_client
                    .expect("semantic() must be called before split_async()");
                semantic::split_semantic(
                    self.text,
                    self.max_tokens,
                    self.overlap,
                    encoder,
                    client,
                    self.semantic_threshold,
                )
                .await
            }
            // Spelled out instead of `_` so that adding a `Strategy` variant
            // forces a deliberate routing decision here.
            Strategy::Recursive | Strategy::Markdown => Ok(self.split()),
        }
    }

    /// Picks the tokenizer encoding for this run.
    ///
    /// Resolution order:
    /// 1. the explicit encoding name, if one was set;
    /// 2. otherwise the model name, tried first as a model and then as an
    ///    encoding name;
    /// 3. otherwise, or whenever a lookup fails, `o200k_base`.
    ///
    /// Note that a failed lookup of an explicit encoding name goes straight
    /// to the default; the model name is not consulted in that case.
    ///
    /// # Panics
    ///
    /// Panics if the `o200k_base` fallback itself cannot be loaded.
    fn resolve_encoder(&self) -> &'static tiktoken::CoreBpe {
        let default = || tiktoken::get_encoding("o200k_base").expect("o200k_base encoding");
        if let Some(name) = self.encoding_name {
            return tiktoken::get_encoding(name).unwrap_or_else(default);
        }
        if let Some(model) = self.model_name {
            return tiktoken::encoding_for_model(model)
                .or_else(|| tiktoken::get_encoding(model))
                .unwrap_or_else(default);
        }
        default()
    }
}
// Unit tests for the builder API: defaults, option setters, tokenizer
// resolution, and token-limit guarantees of the resulting chunks.
#[cfg(test)]
mod tests {
use super::*;
// Input well under the default limit yields one chunk covering bytes 0..11.
#[test]
fn chunk_short_text() {
let chunks = chunk("hello world").split();
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].content, "hello world");
assert_eq!(chunks[0].index, 0);
assert_eq!(chunks[0].start_byte, 0);
assert_eq!(chunks[0].end_byte, 11);
assert!(chunks[0].token_count > 0);
}
// Empty input produces no chunks at all.
#[test]
fn chunk_empty_text() {
let chunks = chunk("").split();
assert!(chunks.is_empty());
}
// Every chunk must stay within the configured token limit.
#[test]
fn chunk_respects_max_tokens() {
let text = "The quick brown fox. ".repeat(100);
let chunks = chunk(&text).max_tokens(20).split();
for c in &chunks {
assert!(
c.token_count <= 20,
"chunk {} has {} tokens",
c.index,
c.token_count
);
}
}
// Smoke test: splitting with a non-zero overlap still yields multiple chunks.
// Only the chunk count is checked, not the overlapping content itself.
#[test]
fn chunk_with_overlap() {
let text = "Sentence one. Sentence two. Sentence three. Sentence four. Sentence five. Sentence six.";
let chunks = chunk(text).max_tokens(10).overlap(3).split();
assert!(chunks.len() >= 2);
}
// `max_tokens(0)` is raised to 1 by the setter, so output is still produced.
#[test]
fn chunk_max_tokens_minimum_one() {
let chunks = chunk("hello").max_tokens(0).split();
assert!(!chunks.is_empty());
}
// Selecting the tokenizer by model name works for short input.
#[test]
fn chunk_with_model() {
let chunks = chunk("hello world").model("gpt-4o").split();
assert_eq!(chunks.len(), 1);
}
// Selecting the tokenizer by encoding name works for short input.
#[test]
fn chunk_with_encoding() {
let chunks = chunk("hello world").encoding("cl100k_base").split();
assert_eq!(chunks.len(), 1);
}
// Markdown mode splits on headings and records the heading as the section.
#[test]
fn chunk_markdown_mode() {
let md = "# Title\n\nSome content.\n\n## Section\n\nMore content.\n";
let chunks = chunk(md).markdown().split();
assert!(chunks.len() >= 2);
assert_eq!(chunks[0].section.as_deref(), Some("# Title"));
}
// Chunk indices are 0-based and contiguous in output order.
#[test]
fn chunk_sequential_indices() {
let text = "Word. ".repeat(200);
let chunks = chunk(&text).max_tokens(10).split();
for (i, c) in chunks.iter().enumerate() {
assert_eq!(c.index, i);
}
}
// Multi-byte (Chinese) text is split and still honours the token limit.
#[test]
fn chunk_chinese_text() {
let text = "这是一段中文文本。它包含多个句子。每个句子都应该被正确分割。更多的内容在这里。还有更多。最后一句话。";
let chunks = chunk(text).max_tokens(10).split();
assert!(chunks.len() >= 2);
for c in &chunks {
assert!(c.token_count <= 10);
}
}
// Multi-byte (Japanese) text produces chunks that honour the token limit.
#[test]
fn chunk_japanese_text() {
let text =
"これは日本語のテキストです。複数の文が含まれています。正しく分割されるべきです。";
let chunks = chunk(text).max_tokens(10).split();
assert!(!chunks.is_empty());
for c in &chunks {
assert!(c.token_count <= 10);
}
}
// Spot-check that a word from each paragraph survives splitting.
// This checks three marker words only, not byte-for-byte preservation.
#[test]
fn chunk_preserves_all_content() {
let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
let chunks = chunk(text).max_tokens(5).split();
let combined: String = chunks
.iter()
.map(|c| c.content.as_str())
.collect::<Vec<_>>()
.join("");
assert!(combined.contains("First"));
assert!(combined.contains("Second"));
assert!(combined.contains("Third"));
}
// A large repeated document is split into many chunks, each within the limit.
#[test]
fn chunk_large_document() {
let text = "Lorem ipsum dolor sit amet. ".repeat(1000);
let chunks = chunk(&text).max_tokens(100).split();
assert!(chunks.len() >= 10);
for c in &chunks {
assert!(c.token_count <= 100);
}
}
// With a limit of one token, every chunk holds at most a single token.
#[test]
fn chunk_single_token_max() {
let chunks = chunk("hello world foo bar").max_tokens(1).split();
assert!(chunks.len() >= 4);
for c in &chunks {
assert!(c.token_count <= 1);
}
}
// An unknown model name falls back to a usable default encoder.
#[test]
fn resolve_encoder_unknown_falls_back() {
let builder = chunk("test").model("nonexistent-model-xyz");
let enc = builder.resolve_encoder();
assert!(enc.count("hello") > 0);
}
// An explicit encoding overrides the model: the two resolved encoders must
// tokenize at least one sample differently.
// NOTE(review): assumes "gpt-4o" maps to o200k_base in tiktoken — confirm.
#[test]
fn model_and_encoding_are_independent() {
let enc_cl100k = chunk("test")
.model("gpt-4o")
.encoding("cl100k_base")
.resolve_encoder();
let enc_o200k = chunk("test").model("gpt-4o").resolve_encoder();
let test_texts = [
"hello_world_123_test",
"foo::bar::baz::qux",
"αβγδεζηθ",
"1234567890",
];
let any_different = test_texts
.iter()
.any(|t| enc_cl100k.count(t) != enc_o200k.count(t));
assert!(
any_different,
"cl100k_base and o200k_base should produce different token counts for at least one test string"
);
}
// Setting only an encoding name resolves to a working encoder.
#[test]
fn encoding_only_without_model() {
let builder = chunk("test").encoding("cl100k_base");
let enc = builder.resolve_encoder();
assert!(enc.count("hello") > 0);
}
// Setting only a model name resolves to a working encoder.
#[test]
fn model_only_without_encoding() {
let builder = chunk("test").model("gpt-4o");
let enc = builder.resolve_encoder();
assert!(enc.count("hello") > 0);
}
}