use crate::character::validate_chunk_config;
use crate::chunk::{TextChunk, TextSpan};
use crate::error::ChunkError;
/// Half-open byte range `[start_byte, end_byte)` of one token inside the
/// input string passed to [`TokenBoundaryProvider::token_spans`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenSpan {
    /// Byte offset where the token starts (inclusive).
    pub start_byte: usize,
    /// Byte offset where the token ends (exclusive).
    pub end_byte: usize,
}
/// Source of token boundaries for [`TokenChunker`]: maps an input string to
/// the byte spans of its tokens, in order.
pub trait TokenBoundaryProvider: Clone + Send + Sync + 'static {
    /// Returns the byte spans of the tokens in `input`, in input order.
    ///
    /// The chunker validates the returned spans before use: each must be
    /// non-empty, end within `input`, start and end on UTF-8 character
    /// boundaries, and not overlap the preceding span.
    fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError>;
}
/// Splits text into chunks whose size is measured in tokens, with the
/// tokenization supplied by a pluggable [`TokenBoundaryProvider`].
#[derive(Debug, Clone)]
pub struct TokenChunker<P> {
    // Supplies the token byte spans for each input string.
    provider: P,
    // Maximum number of tokens per chunk.
    chunk_size: usize,
    // Number of tokens shared between consecutive chunks.
    chunk_overlap: usize,
    // When true (the default set in `new`), each chunk's span is trimmed of
    // surrounding whitespace before the chunk is emitted.
    strip_whitespace: bool,
}
impl<P> TokenChunker<P>
where
    P: TokenBoundaryProvider,
{
    /// Creates a chunker over `provider` with whitespace stripping enabled.
    ///
    /// # Errors
    /// Returns the error from [`validate_chunk_config`] when the
    /// `chunk_size`/`chunk_overlap` pair is not usable.
    pub fn new(provider: P, chunk_size: usize, chunk_overlap: usize) -> Result<Self, ChunkError> {
        validate_chunk_config(chunk_size, chunk_overlap)?;
        Ok(Self {
            provider,
            chunk_size,
            chunk_overlap,
            strip_whitespace: true,
        })
    }

    /// Builder-style setter controlling whether each chunk's span is trimmed
    /// of surrounding whitespace (on by default).
    pub fn strip_whitespace(mut self, strip_whitespace: bool) -> Self {
        self.strip_whitespace = strip_whitespace;
        self
    }

    /// Splits `input` into chunks of at most `chunk_size` tokens, with
    /// consecutive chunks sharing `chunk_overlap` tokens. Chunks whose span
    /// trims down to nothing are dropped when whitespace stripping is on.
    ///
    /// # Errors
    /// Propagates tokenizer failures and rejects providers whose spans fail
    /// validation (empty, out of range, off UTF-8 boundaries, overlapping).
    pub fn try_split_chunks<'a>(&self, input: &'a str) -> Result<Vec<TextChunk<'a>>, ChunkError> {
        let spans = self.provider.token_spans(input)?;
        validate_token_spans(input, &spans)?;
        if spans.is_empty() {
            return Ok(Vec::new());
        }
        let mut chunks = Vec::new();
        let mut window_start = 0usize;
        loop {
            // Window of up to `chunk_size` tokens beginning at `window_start`.
            let window_end = usize::min(window_start + self.chunk_size, spans.len());
            let raw = TextSpan::new(spans[window_start].start_byte, spans[window_end - 1].end_byte);
            let resolved = if self.strip_whitespace {
                raw.trim(input)
            } else {
                Some(raw)
            };
            if let Some(span) = resolved {
                chunks.push(TextChunk::from_byte_range(
                    input,
                    span.start,
                    span.end,
                    window_end - window_start,
                ));
            }
            if window_end == spans.len() {
                break;
            }
            // Advance, stepping back by the overlap so neighbouring chunks
            // share tokens. `validate_chunk_config` in `new` is what keeps
            // this making forward progress.
            window_start = match self.chunk_overlap {
                0 => window_end,
                overlap => window_end.saturating_sub(overlap),
            };
        }
        Ok(chunks)
    }

    /// Like [`Self::try_split_chunks`], but returns an owning iterator.
    pub fn try_chunks<'a>(
        &self,
        input: &'a str,
    ) -> Result<std::vec::IntoIter<TextChunk<'a>>, ChunkError> {
        let chunks = self.try_split_chunks(input)?;
        Ok(chunks.into_iter())
    }

    /// Like [`Self::try_split_chunks`], but returns owned `String`s.
    pub fn try_split_text(&self, input: &str) -> Result<Vec<String>, ChunkError> {
        let mut texts = Vec::new();
        for chunk in self.try_split_chunks(input)? {
            texts.push(chunk.text.to_string());
        }
        Ok(texts)
    }
}
fn validate_token_spans(input: &str, tokens: &[TokenSpan]) -> Result<(), ChunkError> {
let mut previous_end = 0usize;
for (index, token) in tokens.iter().enumerate() {
if token.start_byte >= token.end_byte {
return Err(ChunkError::invalid_configuration(format!(
"token span at index {index} must be non-empty"
)));
}
if token.end_byte > input.len() {
return Err(ChunkError::invalid_configuration(format!(
"token span at index {index} ends past input length"
)));
}
if !input.is_char_boundary(token.start_byte) || !input.is_char_boundary(token.end_byte) {
return Err(ChunkError::invalid_configuration(format!(
"token span at index {index} is not on UTF-8 boundaries"
)));
}
if token.start_byte < previous_end {
return Err(ChunkError::invalid_configuration(format!(
"token span at index {index} overlaps a previous token"
)));
}
previous_end = token.end_byte;
}
Ok(())
}
#[cfg(feature = "tiktoken-rs")]
pub mod tiktoken {
    use std::sync::Arc;

    use super::{TokenBoundaryProvider, TokenSpan};
    use crate::error::ChunkError;

    /// [`TokenBoundaryProvider`] backed by a `tiktoken-rs` BPE vocabulary.
    ///
    /// The tokenizer is held behind an [`Arc`], so cloning the provider is
    /// cheap.
    #[derive(Clone)]
    pub struct TiktokenBoundaryProvider {
        bpe: Arc<tiktoken_rs::CoreBPE>,
    }

    impl TiktokenBoundaryProvider {
        /// Wraps a `CoreBPE` tokenizer for use as a boundary provider.
        pub fn new(bpe: tiktoken_rs::CoreBPE) -> Self {
            Self { bpe: Arc::new(bpe) }
        }
    }

    impl TokenBoundaryProvider for TiktokenBoundaryProvider {
        fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError> {
            let tokens = self.bpe.encode_with_special_tokens(input);
            let mut spans = Vec::with_capacity(tokens.len());
            let mut cursor = 0usize;
            // BPE operates on bytes, so a single multi-byte character (emoji,
            // CJK, ...) can be split across several tokens, none of which
            // decodes to valid UTF-8 on its own. Such tokens are accumulated
            // here until the group decodes cleanly, and the group is emitted
            // as one combined span. (Decoding each token individually, as the
            // previous implementation did, made the whole call fail on any
            // such input.) A merged group counts as a single span, so chunk
            // token counts may slightly undercount in that rare case.
            let mut pending = Vec::new();
            for token in tokens {
                pending.push(token);
                let token_text = match self.bpe.decode(pending.clone()) {
                    Ok(text) => text,
                    // Incomplete UTF-8 sequence so far: keep accumulating.
                    Err(_) => continue,
                };
                // Locate the decoded text in the not-yet-consumed tail of the
                // input; for a faithful round-trip this match is at relative
                // offset 0.
                let relative = input[cursor..].find(&token_text).ok_or_else(|| {
                    ChunkError::invalid_configuration("token text did not map to source input")
                })?;
                let start = cursor + relative;
                let end = start + token_text.len();
                spans.push(TokenSpan {
                    start_byte: start,
                    end_byte: end,
                });
                cursor = end;
                pending.clear();
            }
            if !pending.is_empty() {
                // The token stream ended mid-way through an undecodable group.
                return Err(ChunkError::invalid_configuration(
                    "trailing tokens could not be decoded to UTF-8 text",
                ));
            }
            Ok(spans)
        }
    }
}
#[cfg(feature = "tokenizers")]
pub mod huggingface {
    use std::sync::Arc;

    use super::{TokenBoundaryProvider, TokenSpan};
    use crate::error::ChunkError;

    /// [`TokenBoundaryProvider`] backed by a Hugging Face `tokenizers`
    /// tokenizer.
    ///
    /// The tokenizer is held behind an [`Arc`], so cloning the provider is
    /// cheap.
    #[derive(Clone)]
    pub struct HuggingFaceBoundaryProvider {
        tokenizer: Arc<tokenizers::Tokenizer>,
    }

    impl HuggingFaceBoundaryProvider {
        /// Wraps the given tokenizer for use as a boundary provider.
        pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
            Self {
                tokenizer: Arc::new(tokenizer),
            }
        }
    }

    impl TokenBoundaryProvider for HuggingFaceBoundaryProvider {
        fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError> {
            // Encode without special tokens (`false`) so every reported
            // offset corresponds to text actually present in `input`.
            let encoding = self
                .tokenizer
                .encode(input, false)
                .map_err(|err| ChunkError::invalid_configuration(err.to_string()))?;
            let offsets = encoding.get_offsets();
            let mut spans = Vec::with_capacity(offsets.len());
            for &(start, end) in offsets {
                spans.push(TokenSpan {
                    start_byte: start,
                    end_byte: end,
                });
            }
            Ok(spans)
        }
    }
}