julienne 0.1.0

Range-preserving Rust text chunkers for retrieval and embedding pipelines
use crate::character::validate_chunk_config;
use crate::chunk::{TextChunk, TextSpan};
use crate::error::ChunkError;

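/// Byte range of a single token within the source input. Offsets are
/// half-open (`start_byte..end_byte`) and must fall on UTF-8 character
/// boundaries.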
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TokenSpan {
    pub start_byte: usize,
    pub end_byte: usize,
}

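/// Supplies token byte spans for an input string. Implementations wrap a
/// tokenizer and report where each token lives in the original text; spans
/// must be non-empty, non-overlapping, sorted, and on UTF-8 boundaries, as
/// enforced by `validate_token_spans` below.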
pub trait TokenBoundaryProvider: Clone + Send + Sync + 'static {
    fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError>;
}
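
// Illustration only: a minimal `TokenBoundaryProvider` that treats each run
// of non-whitespace characters as one token. This type is not part of the
// crate; it sketches what a hand-rolled provider looks like.
#[derive(Debug, Clone)]
pub struct WhitespaceBoundaryProvider;

impl TokenBoundaryProvider for WhitespaceBoundaryProvider {
    fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError> {
        let mut spans = Vec::new();
        let mut start = None;
        for (index, ch) in input.char_indices() {
            if ch.is_whitespace() {
                // Close the current word, if any.
                if let Some(begin) = start.take() {
                    spans.push(TokenSpan { start_byte: begin, end_byte: index });
                }
            } else if start.is_none() {
                start = Some(index);
            }
        }
        // Flush a trailing word that runs to the end of the input.
        if let Some(begin) = start {
            spans.push(TokenSpan { start_byte: begin, end_byte: input.len() });
        }
        Ok(spans)
    }
}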

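/// Splits text into windows of at most `chunk_size` tokens, with adjacent
/// windows sharing `chunk_overlap` tokens. Chunks borrow from the input and
/// retain their original byte ranges.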
#[derive(Debug, Clone)]
pub struct TokenChunker<P> {
    provider: P,
    chunk_size: usize,
    chunk_overlap: usize,
    strip_whitespace: bool,
}

impl<P> TokenChunker<P>
where
    P: TokenBoundaryProvider,
{
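    /// Builds a chunker over the given boundary provider. The window
    /// configuration is validated up front (via `validate_chunk_config`),
    /// and whitespace stripping defaults to enabled.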
    pub fn new(provider: P, chunk_size: usize, chunk_overlap: usize) -> Result<Self, ChunkError> {
        validate_chunk_config(chunk_size, chunk_overlap)?;
        Ok(Self {
            provider,
            chunk_size,
            chunk_overlap,
            strip_whitespace: true,
        })
    }

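    /// Builder-style toggle for trimming leading and trailing whitespace
    /// from each chunk (enabled by default). Chunks that trim to nothing
    /// are dropped.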
    pub fn strip_whitespace(mut self, strip_whitespace: bool) -> Self {
        self.strip_whitespace = strip_whitespace;
        self
    }

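    /// Tokenizes the input, validates the reported spans, and returns
    /// borrowed chunks covering successive token windows. An input that
    /// produces no tokens yields no chunks.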
    pub fn try_split_chunks<'a>(&self, input: &'a str) -> Result<Vec<TextChunk<'a>>, ChunkError> {
        let tokens = self.provider.token_spans(input)?;
        validate_token_spans(input, &tokens)?;
        if tokens.is_empty() {
            return Ok(Vec::new());
        }

        let mut chunks = Vec::new();
        let mut start = 0usize;
        while start < tokens.len() {
            let end = (start + self.chunk_size).min(tokens.len());
            let span = TextSpan::new(tokens[start].start_byte, tokens[end - 1].end_byte);
            // Optionally trim surrounding whitespace from the span; a span
            // that trims to nothing is skipped rather than emitted empty.
            let trimmed = if self.strip_whitespace {
                span.trim(input)
            } else {
                Some(span)
            };
            if let Some(span) = trimmed {
                chunks.push(TextChunk::from_byte_range(
                    input,
                    span.start,
                    span.end,
                    end - start,
                ));
            }
            if end == tokens.len() {
                break;
            }
            // Step the window forward, sharing `chunk_overlap` tokens with
            // the chunk just emitted (with zero overlap this is simply `end`).
            start = end.saturating_sub(self.chunk_overlap);
        }

        Ok(chunks)
    }

    pub fn try_chunks<'a>(
        &self,
        input: &'a str,
    ) -> Result<std::vec::IntoIter<TextChunk<'a>>, ChunkError> {
        Ok(self.try_split_chunks(input)?.into_iter())
    }

    pub fn try_split_text(&self, input: &str) -> Result<Vec<String>, ChunkError> {
        Ok(self
            .try_split_chunks(input)?
            .into_iter()
            .map(|chunk| chunk.text.to_string())
            .collect())
    }
}
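
// Usage sketch, assuming the hypothetical `WhitespaceBoundaryProvider`
// defined above: windows of 128 tokens, each overlapping the previous
// window by 16 tokens. The sizes are arbitrary example values.
#[allow(dead_code)]
fn example_split(input: &str) -> Result<Vec<String>, ChunkError> {
    let chunker =
        TokenChunker::new(WhitespaceBoundaryProvider, 128, 16)?.strip_whitespace(true);
    chunker.try_split_text(input)
}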

fn validate_token_spans(input: &str, tokens: &[TokenSpan]) -> Result<(), ChunkError> {
    let mut previous_end = 0usize;
    for (index, token) in tokens.iter().enumerate() {
        if token.start_byte >= token.end_byte {
            return Err(ChunkError::invalid_configuration(format!(
                "token span at index {index} must be non-empty"
            )));
        }
        if token.end_byte > input.len() {
            return Err(ChunkError::invalid_configuration(format!(
                "token span at index {index} ends past input length"
            )));
        }
        if !input.is_char_boundary(token.start_byte) || !input.is_char_boundary(token.end_byte) {
            return Err(ChunkError::invalid_configuration(format!(
                "token span at index {index} is not on UTF-8 boundaries"
            )));
        }
        if token.start_byte < previous_end {
            return Err(ChunkError::invalid_configuration(format!(
                "token span at index {index} overlaps a previous token"
            )));
        }
        previous_end = token.end_byte;
    }
    Ok(())
}
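
#[cfg(test)]
mod validation_tests {
    use super::*;

    // Sketch of the invariant `validate_token_spans` enforces: a span that
    // starts before the previous span ends is rejected.
    #[test]
    fn rejects_overlapping_spans() {
        let input = "abcdef";
        let spans = [
            TokenSpan { start_byte: 0, end_byte: 3 },
            TokenSpan { start_byte: 2, end_byte: 5 },
        ];
        assert!(validate_token_spans(input, &spans).is_err());
    }
}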

#[cfg(feature = "tiktoken-rs")]
pub mod tiktoken {
    use std::sync::Arc;

    use super::{TokenBoundaryProvider, TokenSpan};
    use crate::error::ChunkError;

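    /// `TokenBoundaryProvider` backed by a tiktoken BPE encoder. Spans are
    /// recovered by decoding each token and locating its text in the input,
    /// so tokens that are not valid UTF-8 on their own surface as errors.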
    #[derive(Clone)]
    pub struct TiktokenBoundaryProvider {
        bpe: Arc<tiktoken_rs::CoreBPE>,
    }

    impl TiktokenBoundaryProvider {
        pub fn new(bpe: tiktoken_rs::CoreBPE) -> Self {
            Self { bpe: Arc::new(bpe) }
        }
    }

    impl TokenBoundaryProvider for TiktokenBoundaryProvider {
        fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError> {
            let tokens = self.bpe.encode_with_special_tokens(input);
            let mut spans = Vec::with_capacity(tokens.len());
            let mut cursor = 0usize;
            for token in tokens {
                // Decode each token on its own; this can fail for BPE tokens
                // that are not valid UTF-8 in isolation (for example, a token
                // that splits a multi-byte character).
                let token_text = self.bpe.decode(vec![token]).map_err(|err| {
                    ChunkError::invalid_configuration(format!("failed to decode token: {err}"))
                })?;
                // Re-locate the decoded text in the unconsumed input to
                // recover its byte offsets.
                let relative = input[cursor..].find(&token_text).ok_or_else(|| {
                    ChunkError::invalid_configuration("token text did not map to source input")
                })?;
                let start = cursor + relative;
                let end = start + token_text.len();
                spans.push(TokenSpan {
                    start_byte: start,
                    end_byte: end,
                });
                cursor = end;
            }
            Ok(spans)
        }
    }
}
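
// Usage sketch for the tiktoken-backed provider. `cl100k_base()` is the
// stock encoder constructor exported by tiktoken-rs; the window sizes are
// arbitrary example values.
#[cfg(feature = "tiktoken-rs")]
#[allow(dead_code)]
fn example_tiktoken_split(input: &str) -> Result<Vec<String>, ChunkError> {
    let bpe = tiktoken_rs::cl100k_base()
        .map_err(|err| ChunkError::invalid_configuration(err.to_string()))?;
    let provider = tiktoken::TiktokenBoundaryProvider::new(bpe);
    TokenChunker::new(provider, 256, 32)?.try_split_text(input)
}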

#[cfg(feature = "tokenizers")]
pub mod huggingface {
    use std::sync::Arc;

    use super::{TokenBoundaryProvider, TokenSpan};
    use crate::error::ChunkError;

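    /// `TokenBoundaryProvider` backed by a Hugging Face `tokenizers`
    /// tokenizer. Offsets come directly from the encoding, and special
    /// tokens are not added.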
    #[derive(Clone)]
    pub struct HuggingFaceBoundaryProvider {
        tokenizer: Arc<tokenizers::Tokenizer>,
    }

    impl HuggingFaceBoundaryProvider {
        pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
            Self {
                tokenizer: Arc::new(tokenizer),
            }
        }
    }

    impl TokenBoundaryProvider for HuggingFaceBoundaryProvider {
        fn token_spans(&self, input: &str) -> Result<Vec<TokenSpan>, ChunkError> {
            // Encode without adding special tokens so every reported offset
            // maps back into the caller's input.
            let encoding = self
                .tokenizer
                .encode(input, false)
                .map_err(|err| ChunkError::invalid_configuration(err.to_string()))?;
            Ok(encoding
                .get_offsets()
                .iter()
                .map(|(start, end)| TokenSpan {
                    start_byte: *start,
                    end_byte: *end,
                })
                .collect())
        }
    }
}
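
// Usage sketch for the Hugging Face provider, loading a serialized
// `tokenizer.json` from disk; the path and window sizes are placeholder
// example values.
#[cfg(feature = "tokenizers")]
#[allow(dead_code)]
fn example_huggingface_split(input: &str) -> Result<Vec<String>, ChunkError> {
    let tokenizer = tokenizers::Tokenizer::from_file("tokenizer.json")
        .map_err(|err| ChunkError::invalid_configuration(err.to_string()))?;
    let provider = huggingface::HuggingFaceBoundaryProvider::new(tokenizer);
    TokenChunker::new(provider, 256, 32)?.try_split_text(input)
}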