use std::sync::Arc;
use std::{path::Path, pin::Pin};
mod code;
mod docx;
mod html;
mod markdown;
mod pdf;
mod text;
use async_trait::async_trait;
pub use code::CodeChunker;
pub use docx::DocxChunker;
pub use html::HtmlChunker;
pub use markdown::MarkdownChunker;
pub use pdf::PdfChunker;
pub use text::TextChunker;
use text_splitter::ChunkSizer;
use tokio::io::AsyncRead;
use crate::{ChunkerConfig, Tokenizer, languages::*, types::*};
/// Concrete chunk-size measurers unified behind one enum, so a chunker can
/// hold a single sizer value regardless of which `Tokenizer` was configured.
#[derive(Clone)]
pub enum ConcreteSizer {
    /// Sizes chunks with `text_splitter`'s character-based sizer.
    Characters(text_splitter::Characters),
    /// Sizes chunks via a tiktoken BPE (OpenAI-style encodings); `Arc` keeps
    /// `Clone` cheap since the BPE tables are large.
    Tiktoken(Arc<tiktoken_rs::CoreBPE>),
    /// Sizes chunks via a HuggingFace `tokenizers` tokenizer, also shared via `Arc`.
    HuggingFace(Arc<tokenizers::tokenizer::Tokenizer>),
}
impl ChunkSizer for ConcreteSizer {
    /// Measure `chunk` by delegating to whichever concrete sizer this enum wraps.
    fn size(&self, chunk: &str) -> usize {
        match self {
            Self::Characters(inner) => inner.size(chunk),
            Self::Tiktoken(inner) => inner.size(chunk),
            Self::HuggingFace(inner) => inner.size(chunk),
        }
    }
}
impl TryFrom<Tokenizer> for ConcreteSizer {
    type Error = ChunkError;

    /// Resolve a `Tokenizer` selection into a ready-to-use sizer, loading
    /// tiktoken encodings or HuggingFace models on demand and wrapping the
    /// preloaded variants as-is.
    fn try_from(value: Tokenizer) -> Result<Self, Self::Error> {
        match value {
            Tokenizer::Characters => Ok(Self::Characters(text_splitter::Characters)),
            Tokenizer::Tiktoken(encoding) => {
                // Map the encoding name onto the matching tiktoken constructor;
                // unknown names are rejected before any construction happens.
                let built = match encoding.as_str() {
                    "cl100k_base" => tiktoken_rs::cl100k_base(),
                    "p50k_base" => tiktoken_rs::p50k_base(),
                    "p50k_edit" => tiktoken_rs::p50k_edit(),
                    "r50k_base" => tiktoken_rs::r50k_base(),
                    "o200k_base" => tiktoken_rs::o200k_base(),
                    _ => {
                        return Err(ChunkError::ParseError(format!(
                            "Unknown tiktoken encoding: {}",
                            encoding
                        )));
                    }
                };
                let bpe = built
                    .map_err(|e| ChunkError::ParseError(format!("Failed to create tiktoken: {}", e)))?;
                Ok(Self::Tiktoken(Arc::new(bpe)))
            }
            Tokenizer::PreloadedTiktoken(bpe) => Ok(Self::Tiktoken(bpe)),
            Tokenizer::HuggingFace(model) => {
                // NOTE: `from_pretrained` fetches/loads the model by name; failures
                // surface as ParseError with the underlying cause attached.
                let loaded = tokenizers::tokenizer::Tokenizer::from_pretrained(&model, None)
                    .map_err(|e| ChunkError::ParseError(format!("Failed to load HF tokenizer: {}", e)))?;
                Ok(Self::HuggingFace(Arc::new(loaded)))
            }
            Tokenizer::PreloadedHuggingFace(tok) => Ok(Self::HuggingFace(tok)),
        }
    }
}
/// Boxed, pinned, `Send` stream of chunk results produced by a [`Chunker`].
pub type ChunkStream = Pin<Box<dyn futures::Stream<Item = Result<Chunk, ChunkError>> + Send>>;
/// A content-type-specific chunking strategy.
///
/// Implementations are tried in a chain: each first gets a chance to claim an
/// input via [`Chunker::applies`], and the first one that accepts produces the
/// chunks via [`Chunker::chunk`].
#[async_trait]
pub trait Chunker: Send + Sync {
    /// Decide whether this chunker can handle the given file.
    ///
    /// Ownership of the peekable reader is passed in so the implementation can
    /// inspect the content, and it is handed back in both outcomes:
    /// `Ok(reader)` means "this chunker applies", `Err(reader)` means "pass,
    /// try the next chunker" (see how `select_chunker` threads the reader).
    async fn applies(
        &self,
        path: &Path,
        reader: PeekableReader<Box<dyn AsyncRead + Unpin + Send>>,
    ) -> Result<PeekableReader<Box<dyn AsyncRead + Unpin + Send>>, PeekableReader<Box<dyn AsyncRead + Unpin + Send>>>;
    /// Produce the stream of chunks for the file's content.
    async fn chunk(&self, path: &Path, reader: Box<dyn AsyncRead + Unpin + Send>) -> ChunkStream;
}
/// Optional replacements for individual chunkers in the default chain.
///
/// `Default` leaves every slot `None`, i.e. the built-in chunkers are used.
#[derive(Default)]
pub struct ChunkerOverrides {
    /// When set, used in place of the built-in [`CodeChunker`].
    pub code_chunker: Option<Box<dyn Chunker>>,
}
/// Test helper: wrap in-memory bytes in a boxed async reader.
#[cfg(test)]
pub fn memory_async_reader(bytes: Vec<u8>) -> Box<dyn AsyncRead + Unpin + Send> {
    // `std::io::Cursor` implements tokio's `AsyncRead`, so no real I/O happens.
    Box::new(std::io::Cursor::new(bytes))
}
/// A node in a singly linked list of chunkers, consulted front-to-back until
/// one claims the input.
struct ChunkerChainNode {
    chunker: Box<dyn Chunker>,
    // `None` marks the tail of the chain (the last chunker tried).
    next: Option<Box<ChunkerChainNode>>,
}
impl ChunkerChainNode {
fn new(chunker: Box<dyn Chunker>) -> Self {
Self { chunker, next: None }
}
fn prepend(self: Box<Self>, chunker: Box<dyn Chunker>) -> Box<Self> {
Box::new(Self {
chunker,
next: Some(self),
})
}
}
/// Translate the configured overlap percentage into an absolute overlap size.
///
/// # Errors
/// Returns `ChunkError::ParseError` when `overlap_percentage` falls outside
/// `[0.0, 1.0]` (NaN also fails the range check and is rejected).
fn compute_overlap(config: &ChunkerConfig) -> Result<usize, ChunkError> {
    let pct = config.overlap_percentage;
    if !(0.0..=1.0).contains(&pct) {
        return Err(ChunkError::ParseError(format!(
            "Invalid overlap percentage: {}",
            pct
        )));
    }
    let max = config.max_chunk_size;
    let raw = ((max as f32) * pct).round() as usize;
    // Cap at max - 1 so consecutive chunks always make forward progress; when
    // max == 0 the raw value is already 0 and no cap is applied.
    let overlap = if max > 0 {
        raw.min(max.saturating_sub(1))
    } else {
        raw
    };
    Ok(overlap)
}
fn build_chunker_chain(config: &ChunkerConfig) -> Result<Box<ChunkerChainNode>, ChunkError> {
let overlap = compute_overlap(config)?;
let chain = Box::new(ChunkerChainNode::new(Box::new(TextChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?)));
let chain = chain.prepend(Box::new(CodeChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(MarkdownChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(HtmlChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(DocxChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(PdfChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
Ok(chain)
}
fn build_chunker_chain_with_overrides(
config: &ChunkerConfig,
overrides: ChunkerOverrides,
) -> Result<Box<ChunkerChainNode>, ChunkError> {
let overlap = compute_overlap(config)?;
let code_chunker = match overrides.code_chunker {
Some(chunker) => chunker,
None => Box::new(CodeChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?),
};
let chain = Box::new(ChunkerChainNode::new(Box::new(TextChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?)));
let chain = chain.prepend(code_chunker);
let chain = chain.prepend(Box::new(MarkdownChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(HtmlChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(DocxChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
let chain = chain.prepend(Box::new(PdfChunker::new(
config.max_chunk_size,
config.tokenizer.clone(),
overlap,
)?));
Ok(chain)
}
/// Walk the chain, asking each chunker in turn whether it applies to `path`.
///
/// Each `applies` call takes ownership of the peekable reader and hands it back
/// in both outcomes, so the reader threads through the whole walk intact.
///
/// # Errors
/// Returns `ChunkError::UnsupportedFileType` when every chunker declines.
async fn select_chunker(
    chain: Box<ChunkerChainNode>,
    path: &Path,
    mut peekable: PeekableReader<Box<dyn AsyncRead + Unpin + Send>>,
) -> Result<(Box<dyn Chunker>, PeekableReader<Box<dyn AsyncRead + Unpin + Send>>), ChunkError> {
    let mut node = chain;
    loop {
        // Ok = this chunker claims the input; Err = declined, reader returned.
        peekable = match node.chunker.applies(path, peekable).await {
            Ok(accepted) => return Ok((node.chunker, accepted)),
            Err(rejected) => rejected,
        };
        node = match node.next {
            Some(next) => next,
            None => {
                return Err(ChunkError::UnsupportedFileType(
                    path.to_string_lossy().to_string(),
                ));
            }
        };
    }
}
pub async fn get_chunker<P, R>(
path: P,
reader: R,
config: ChunkerConfig,
) -> Result<(Box<dyn Chunker>, impl AsyncRead + Unpin + Send + 'static), ChunkError>
where
P: AsRef<Path>,
R: AsyncRead + Unpin + Send + 'static,
{
let path = path.as_ref().to_owned();
let peekable: PeekableReader<Box<dyn AsyncRead + Unpin + Send>> = PeekableReader::new(Box::new(reader), 51200);
let chain = build_chunker_chain(&config)?;
let (selected_chunker, peekable) = select_chunker(chain, &path, peekable).await?;
let combined = peekable.into_async_read();
Ok((selected_chunker, combined))
}
/// Select the appropriate chunker for `path`/`reader`, honoring overrides.
///
/// Wraps the reader in a `PeekableReader` so each chunker can sniff the head
/// of the stream during selection, then returns the chosen chunker plus a
/// reader that yields the full stream (peeked bytes included — presumably
/// `into_async_read` replays the buffered prefix; confirm in `PeekableReader`).
///
/// # Errors
/// Returns `ChunkError::UnsupportedFileType` when no chunker claims the input,
/// or a configuration error from building the chain.
pub async fn get_chunker_with_overrides<P, R>(
    path: P,
    reader: R,
    config: ChunkerConfig,
    overrides: ChunkerOverrides,
) -> Result<(Box<dyn Chunker>, impl AsyncRead + Unpin + Send + 'static), ChunkError>
where
    P: AsRef<Path>,
    R: AsyncRead + Unpin + Send + 'static,
{
    // 50 KiB peek window for content sniffing during selection.
    const PEEK_LIMIT: usize = 51200;
    let owned_path = path.as_ref().to_owned();
    let boxed: Box<dyn AsyncRead + Unpin + Send> = Box::new(reader);
    let peekable = PeekableReader::new(boxed, PEEK_LIMIT);
    let chain = build_chunker_chain_with_overrides(&config, overrides)?;
    let (chosen, remaining) = select_chunker(chain, &owned_path, peekable).await?;
    Ok((chosen, remaining.into_async_read()))
}