use std::{
borrow::Cow,
cell::{Ref, RefMut},
cmp::Ordering,
fmt,
ops::{Deref, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive},
rc::Rc,
sync::Arc,
};
use ahash::AHashMap;
use itertools::Itertools;
use thiserror::Error;
mod characters;
#[cfg(feature = "tokenizers")]
mod huggingface;
#[cfg(feature = "tiktoken-rs")]
mod tiktoken;
use crate::trim::Trim;
pub use characters::Characters;
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ChunkCapacityError(#[from] ChunkCapacityErrorRepr);
#[derive(Error, Debug)]
enum ChunkCapacityErrorRepr {
#[error("Max chunk size must be greater than or equal to the desired chunk size")]
MaxLessThanDesired,
}
/// The size a chunk should have: a `desired` size and a `max` size, both in
/// whatever unit the configured [`ChunkSizer`] reports.
///
/// A size "fits" when it lies in the inclusive range `desired..=max` (see
/// [`ChunkCapacity::fits`]). The constructors in this module keep
/// `max >= desired`.
///
/// Besides [`ChunkCapacity::new`], a capacity can be built from a plain `usize`
/// (an exact size) or from any of the standard range types, e.g. `5..=10`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ChunkCapacity {
    /// Smallest size that counts as a fit.
    pub(crate) desired: usize,
    /// Largest size that counts as a fit.
    pub(crate) max: usize,
}

impl ChunkCapacity {
    /// Creates a capacity that only fits chunks of exactly `size`
    /// (`desired == max == size`).
    #[must_use]
    pub fn new(size: usize) -> Self {
        Self {
            desired: size,
            max: size,
        }
    }

    /// The smallest size that fits.
    #[must_use]
    pub fn desired(&self) -> usize {
        self.desired
    }

    /// The largest size that fits.
    #[must_use]
    pub fn max(&self) -> usize {
        self.max
    }

    /// Returns this capacity with its maximum set to `max`.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkCapacityError`] if `max` is smaller than the desired size.
    pub fn with_max(mut self, max: usize) -> Result<Self, ChunkCapacityError> {
        if max < self.desired {
            return Err(ChunkCapacityError(
                ChunkCapacityErrorRepr::MaxLessThanDesired,
            ));
        }
        self.max = max;
        Ok(self)
    }

    /// Classifies `chunk_size` against this capacity.
    ///
    /// Returns [`Ordering::Less`] when it is below `desired`,
    /// [`Ordering::Greater`] when it is above `max`, and [`Ordering::Equal`]
    /// when it lies within `desired..=max`.
    #[must_use]
    pub fn fits(&self, chunk_size: usize) -> Ordering {
        if chunk_size < self.desired {
            Ordering::Less
        } else if chunk_size > self.max {
            Ordering::Greater
        } else {
            Ordering::Equal
        }
    }
}

/// An exact size: `desired` and `max` are both `size`.
impl From<usize> for ChunkCapacity {
    fn from(size: usize) -> Self {
        Self::new(size)
    }
}

/// `start..end`: `desired` is `start` and `max` is the last value inside the
/// range (`end - 1`). An empty or reversed range (`end <= start`) collapses to
/// an exact capacity of `start` instead of panicking.
impl From<Range<usize>> for ChunkCapacity {
    fn from(range: Range<usize>) -> Self {
        let Range { start, end } = range;
        // Clamp to `start` so empty/reversed ranges never produce `max < desired`.
        let last = end.saturating_sub(1).max(start);
        Self::new(start).with_max(last).expect("invalid range")
    }
}

/// `start..`: `desired` is `start` with no upper limit (`max` is `usize::MAX`).
impl From<RangeFrom<usize>> for ChunkCapacity {
    fn from(range: RangeFrom<usize>) -> Self {
        let RangeFrom { start } = range;
        Self::new(start).with_max(usize::MAX).expect("invalid range")
    }
}

/// `..`: every size fits (`desired` is 0, `max` is `usize::MAX`).
impl From<RangeFull> for ChunkCapacity {
    fn from(_: RangeFull) -> Self {
        Self::new(usize::MIN)
            .with_max(usize::MAX)
            .expect("invalid range")
    }
}

/// `start..=end`: `desired` is `start` and `max` is `end`.
///
/// # Panics
///
/// Panics with `"invalid range"` if `end < start`.
impl From<RangeInclusive<usize>> for ChunkCapacity {
    fn from(range: RangeInclusive<usize>) -> Self {
        let (start, end) = range.into_inner();
        Self::new(start).with_max(end).expect("invalid range")
    }
}

/// `..end`: `desired` is 0 and `max` is `end - 1`, saturating so that `..0`
/// only fits a size of 0.
impl From<RangeTo<usize>> for ChunkCapacity {
    fn from(range: RangeTo<usize>) -> Self {
        let RangeTo { end } = range;
        Self::new(usize::MIN)
            .with_max(end.saturating_sub(1))
            .expect("invalid range")
    }
}

/// `..=end`: `desired` is 0 and `max` is `end`.
impl From<RangeToInclusive<usize>> for ChunkCapacity {
    fn from(range: RangeToInclusive<usize>) -> Self {
        let RangeToInclusive { end } = range;
        Self::new(usize::MIN).with_max(end).expect("invalid range")
    }
}
/// Measures the size of a chunk of text, in a unit chosen by the implementation
/// (for example [`Characters`]).
///
/// The trait is also implemented for the common reference and smart-pointer
/// wrappers (`&T`, `Ref<'_, T>`, `RefMut<'_, T>`, `Box<T>`, `Cow<'_, T>`,
/// `Rc<T>`, `Arc<T>`). Each of these delegates to the wrapped sizer. They accept
/// unsized sizers too, so e.g. `Box<dyn ChunkSizer>` or `&dyn ChunkSizer` can be
/// used wherever a `ChunkSizer` is expected.
pub trait ChunkSizer {
    /// Returns the size of `chunk`.
    fn size(&self, chunk: &str) -> usize;
}

/// Delegates to the referenced sizer.
impl<T> ChunkSizer for &T
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        (*self).size(chunk)
    }
}

/// Delegates to the sizer borrowed from a `RefCell`.
impl<T> ChunkSizer for Ref<'_, T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

/// Delegates to the sizer mutably borrowed from a `RefCell`.
impl<T> ChunkSizer for RefMut<'_, T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

/// Delegates to the boxed sizer (including `Box<dyn ChunkSizer>`).
impl<T> ChunkSizer for Box<T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

/// Delegates to the borrowed or owned sizer. Only `&T` is ever used, so the
/// owned form does not itself need to implement `ChunkSizer`.
impl<T> ChunkSizer for Cow<'_, T>
where
    T: ChunkSizer + ToOwned + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.as_ref().size(chunk)
    }
}

/// Delegates to the reference-counted sizer.
impl<T> ChunkSizer for Rc<T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.deref().size(chunk)
    }
}

/// Delegates to the atomically reference-counted sizer.
impl<T> ChunkSizer for Arc<T>
where
    T: ChunkSizer + ?Sized,
{
    fn size(&self, chunk: &str) -> usize {
        self.as_ref().size(chunk)
    }
}
#[derive(Error, Debug)]
#[error(transparent)]
pub struct ChunkConfigError(#[from] ChunkConfigErrorRepr);
#[derive(Error, Debug)]
enum ChunkConfigErrorRepr {
#[error("The overlap is larger than or equal to the desired chunk capacity")]
OverlapLargerThanCapacity,
}
/// Settings used when producing chunks: the [`ChunkCapacity`], how much
/// consecutive chunks may overlap, the [`ChunkSizer`] that measures chunks, and
/// whether chunks are trimmed.
#[derive(Debug)]
pub struct ChunkConfig<Sizer>
where
    Sizer: ChunkSizer,
{
    /// Desired and maximum chunk size.
    pub(crate) capacity: ChunkCapacity,
    /// Overlap between chunks. It is validated against `capacity.desired`, so it
    /// is expressed in the same units as the capacity.
    pub(crate) overlap: usize,
    /// Measures the size of a chunk.
    pub(crate) sizer: Sizer,
    /// Whether chunks are trimmed (defaults to `true`).
    pub(crate) trim: bool,
}

impl ChunkConfig<Characters> {
    /// Creates a configuration with the given capacity, no overlap, the
    /// [`Characters`] sizer, and trimming enabled.
    #[must_use]
    pub fn new(capacity: impl Into<ChunkCapacity>) -> Self {
        Self {
            capacity: capacity.into(),
            overlap: 0,
            sizer: Characters,
            trim: true,
        }
    }
}

impl<Sizer> ChunkConfig<Sizer>
where
    Sizer: ChunkSizer,
{
    /// The configured chunk capacity.
    #[must_use]
    pub fn capacity(&self) -> &ChunkCapacity {
        &self.capacity
    }

    /// The configured overlap between chunks (0 by default).
    #[must_use]
    pub fn overlap(&self) -> usize {
        self.overlap
    }

    /// Returns this configuration with the overlap set to `overlap`.
    ///
    /// # Errors
    ///
    /// Returns [`ChunkConfigError`] if `overlap` is non-zero and greater than or
    /// equal to the desired chunk capacity. An overlap of `0` (the default) is
    /// always accepted.
    pub fn with_overlap(mut self, overlap: usize) -> Result<Self, ChunkConfigError> {
        // `new` already produces `overlap == 0` for every capacity, including
        // ones with `desired == 0` (e.g. `..max`). Rejecting an explicit 0 there
        // would make the default state impossible to request.
        if overlap != 0 && overlap >= self.capacity.desired {
            return Err(ChunkConfigError(
                ChunkConfigErrorRepr::OverlapLargerThanCapacity,
            ));
        }
        self.overlap = overlap;
        Ok(self)
    }

    /// The sizer used to measure chunks.
    #[must_use]
    pub fn sizer(&self) -> &Sizer {
        &self.sizer
    }

    /// Returns this configuration with `sizer` replacing the current sizer.
    /// Capacity, overlap and trim settings are kept.
    #[must_use]
    pub fn with_sizer<S: ChunkSizer>(self, sizer: S) -> ChunkConfig<S> {
        ChunkConfig {
            capacity: self.capacity,
            overlap: self.overlap,
            sizer,
            trim: self.trim,
        }
    }

    /// Whether chunks are trimmed.
    #[must_use]
    pub fn trim(&self) -> bool {
        self.trim
    }

    /// Returns this configuration with trimming enabled or disabled.
    #[must_use]
    pub fn with_trim(mut self, trim: bool) -> Self {
        self.trim = trim;
        self
    }
}

/// Anything convertible into a [`ChunkCapacity`] converts into a default
/// [`ChunkConfig`] (see [`ChunkConfig::new`]).
impl<T> From<T> for ChunkConfig<Characters>
where
    T: Into<ChunkCapacity>,
{
    fn from(capacity: T) -> Self {
        Self::new(capacity)
    }
}
/// Wraps a [`ChunkSizer`] and caches the sizes it computes.
///
/// Entries are keyed only by the byte range `offset..offset + len` of the
/// trimmed chunk, not by its text. Cached values are therefore only meaningful
/// while every offset refers to the same source text; use
/// [`MemoizedChunkSizer::clear_cache`] before reusing the sizer on different
/// text.
#[derive(Debug)]
pub struct MemoizedChunkSizer<'sizer, Sizer>
where
    Sizer: ChunkSizer,
{
    /// Computed sizes, keyed by the byte range of the trimmed chunk.
    size_cache: AHashMap<Range<usize>, usize>,
    /// Sizer consulted on a cache miss.
    sizer: &'sizer Sizer,
}

impl<'sizer, Sizer> MemoizedChunkSizer<'sizer, Sizer>
where
    Sizer: ChunkSizer,
{
    /// Creates a caching wrapper around `sizer`, starting with an empty cache.
    pub fn new(sizer: &'sizer Sizer) -> Self {
        Self {
            size_cache: AHashMap::new(),
            sizer,
        }
    }

    /// Returns the size of `chunk`, which starts at byte `offset`, after
    /// applying `trim`.
    ///
    /// The wrapped sizer is only invoked the first time a given trimmed byte
    /// range is requested; later requests are answered from the cache.
    pub fn chunk_size(&mut self, offset: usize, chunk: &str, trim: Trim) -> usize {
        let (start, trimmed) = trim.trim(offset, chunk);
        // Copy the reference out so the closure does not borrow `self`.
        let sizer = self.sizer;
        let cached = self
            .size_cache
            .entry(start..start + trimmed.len())
            .or_insert_with(|| sizer.size(trimmed));
        *cached
    }

    /// Walks `levels_with_first_chunk` and finds the last level whose first
    /// chunk still fits within `capacity`.
    ///
    /// Before walking, consecutive entries are merged: whenever a later entry's
    /// chunk is not longer (in bytes) than the previous one, the earlier entry is
    /// dropped in favour of the later one.
    ///
    /// Returns `(level, max_offset)`:
    /// - `level` is the last visited level whose first chunk was not too large,
    ///   or `None` if the very first one was.
    /// - `max_offset` is `Some(offset + len)` for the first chunk found to exceed
    ///   `capacity.max`, or `None` if none did.
    ///
    /// NOTE(review): chunks whose byte length is `<= capacity.max` are never
    /// measured. This assumes a sizer never reports a size larger than the
    /// chunk's byte length — confirm for custom sizers.
    pub fn find_correct_level<'text, L: fmt::Debug>(
        &mut self,
        offset: usize,
        capacity: &ChunkCapacity,
        levels_with_first_chunk: impl Iterator<Item = (L, &'text str)>,
        trim: Trim,
    ) -> (Option<L>, Option<usize>) {
        let candidates = levels_with_first_chunk.coalesce(|prev, next| {
            if prev.1.len() >= next.1.len() {
                Ok(next)
            } else {
                Err((prev, next))
            }
        });

        let mut deepest_fitting = None;
        for (level, first_chunk) in candidates {
            let byte_len = first_chunk.len();
            // `&&` short-circuits, so the (possibly expensive) size is only
            // computed when the byte length alone already exceeds `max`.
            let too_big = byte_len > capacity.max
                && capacity
                    .fits(self.chunk_size(offset, first_chunk, trim))
                    .is_gt();
            if too_big {
                return (deepest_fitting, Some(offset + byte_len));
            }
            deepest_fitting = Some(level);
        }
        (deepest_fitting, None)
    }

    /// Discards every cached size.
    pub fn clear_cache(&mut self) {
        self.size_cache.clear();
    }
}
#[cfg(test)]
mod tests {
    use std::{
        cell::RefCell,
        sync::atomic::{self, AtomicUsize},
    };

    use crate::trim::Trim;

    use super::*;

    /// Text measured by every `check_chunk_capacity*` test (five characters).
    const SAMPLE: &str = "12345";

    /// Message produced by every rejected `with_overlap` call.
    const OVERLAP_ERROR: &str =
        "The overlap is larger than or equal to the desired chunk capacity";

    /// Asserts that `ChunkCapacity::fits` reports `expected` for the size of
    /// [`SAMPLE`] under each capacity in `cases`.
    fn assert_fits<C>(cases: impl IntoIterator<Item = (C, Ordering)>)
    where
        C: Into<ChunkCapacity>,
    {
        for (capacity, expected) in cases {
            let capacity: ChunkCapacity = capacity.into();
            assert_eq!(capacity.fits(Characters.size(SAMPLE)), expected);
        }
    }

    #[test]
    fn check_chunk_capacity() {
        assert_fits([
            (4_usize, Ordering::Greater),
            (5, Ordering::Equal),
            (6, Ordering::Less),
        ]);
    }

    #[test]
    fn check_chunk_capacity_for_range() {
        assert_fits([
            (0_usize..0, Ordering::Greater),
            (0..5, Ordering::Greater),
            (5..6, Ordering::Equal),
            (6..100, Ordering::Less),
        ]);
    }

    #[test]
    fn check_chunk_capacity_for_range_from() {
        assert_fits([
            (0_usize.., Ordering::Equal),
            (5.., Ordering::Equal),
            (6.., Ordering::Less),
        ]);
    }

    #[test]
    fn check_chunk_capacity_for_range_full() {
        assert_fits([(.., Ordering::Equal)]);
    }

    #[test]
    fn check_chunk_capacity_for_range_inclusive() {
        assert_fits([
            (0_usize..=4, Ordering::Greater),
            (5..=6, Ordering::Equal),
            (4..=5, Ordering::Equal),
            (6..=100, Ordering::Less),
        ]);
    }

    #[test]
    fn check_chunk_capacity_for_range_to() {
        assert_fits([
            (..0_usize, Ordering::Greater),
            (..5, Ordering::Greater),
            (..6, Ordering::Equal),
        ]);
    }

    #[test]
    fn check_chunk_capacity_for_range_to_inclusive() {
        assert_fits([
            (..=4_usize, Ordering::Greater),
            (..=5, Ordering::Equal),
            (..=6, Ordering::Equal),
        ]);
    }

    /// Character-counting sizer that records how many times it was invoked.
    #[derive(Default)]
    struct CountingSizer {
        calls: AtomicUsize,
    }

    impl CountingSizer {
        /// Number of `size` calls made so far.
        fn call_count(&self) -> usize {
            self.calls.load(atomic::Ordering::SeqCst)
        }
    }

    impl ChunkSizer for CountingSizer {
        fn size(&self, chunk: &str) -> usize {
            self.calls.fetch_add(1, atomic::Ordering::SeqCst);
            Characters.size(chunk)
        }
    }

    #[test]
    fn memoized_sizer_only_calculates_once_per_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, text, Trim::All);
        });
        assert_eq!(sizer.call_count(), 1);
    }

    #[test]
    fn memoized_sizer_calculates_once_per_different_text() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        (0..10).for_each(|end| {
            memoized_sizer.chunk_size(0, &text[..end], Trim::All);
        });
        assert_eq!(sizer.call_count(), 10);
    }

    #[test]
    fn can_clear_cache_on_memoized_sizer() {
        let sizer = CountingSizer::default();
        let mut memoized_sizer = MemoizedChunkSizer::new(&sizer);
        let text = "1234567890";
        (0..10).for_each(|_| {
            memoized_sizer.chunk_size(0, text, Trim::All);
            memoized_sizer.clear_cache();
        });
        assert_eq!(sizer.call_count(), 10);
    }

    #[test]
    fn basic_chunk_config() {
        let config = ChunkConfig::new(10);
        assert_eq!(*config.capacity(), ChunkCapacity::from(10));
        assert_eq!(*config.sizer(), Characters);
        assert!(config.trim());
    }

    #[test]
    fn disable_trimming() {
        assert!(!ChunkConfig::new(10).with_trim(false).trim());
    }

    #[test]
    fn new_sizer() {
        #[derive(Debug, PartialEq)]
        struct BasicSizer;

        impl ChunkSizer for BasicSizer {
            fn size(&self, _chunk: &str) -> usize {
                unimplemented!()
            }
        }

        let config = ChunkConfig::new(10).with_sizer(BasicSizer);
        assert_eq!(*config.capacity(), ChunkCapacity::from(10));
        assert_eq!(*config.sizer(), BasicSizer);
        assert!(config.trim());
    }

    #[test]
    fn chunk_capacity_max_and_desired_equal() {
        let capacity = ChunkCapacity::new(10);
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn chunk_capacity_can_adjust_max() {
        let capacity = ChunkCapacity::new(10).with_max(20).unwrap();
        assert_eq!((capacity.desired(), capacity.max()), (10, 20));
    }

    #[test]
    fn chunk_capacity_max_cant_be_less_than_desired() {
        let capacity = ChunkCapacity::new(10);
        let message = capacity.with_max(5).unwrap_err().to_string();
        assert_eq!(
            message,
            "Max chunk size must be greater than or equal to the desired chunk size"
        );
        // `ChunkCapacity` is `Copy`, so the failed call left the original untouched.
        assert_eq!((capacity.desired(), capacity.max()), (10, 10));
    }

    #[test]
    fn set_chunk_overlap() {
        let config = ChunkConfig::new(10).with_overlap(5).unwrap();
        assert_eq!(config.overlap(), 5);
    }

    #[test]
    fn cant_set_overlap_larger_than_capacity() {
        let err = ChunkConfig::new(5).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_ERROR);
    }

    #[test]
    fn cant_set_overlap_larger_than_desired() {
        let err = ChunkConfig::new(5..15).with_overlap(10).unwrap_err();
        assert_eq!(err.to_string(), OVERLAP_ERROR);
    }

    /// Builds a config around `sizer` and measures a sample chunk through it.
    fn measure_with<S: ChunkSizer>(sizer: S) {
        let config = ChunkConfig::new(1).with_sizer(sizer);
        config.sizer().size("chunk");
    }

    #[test]
    fn chunk_size_reference() {
        measure_with(&Characters);
    }

    #[test]
    fn chunk_size_cow() {
        let owned: Cow<'_, Characters> = Cow::Owned(Characters);
        measure_with(owned);
        measure_with(Cow::Borrowed(&Characters));
    }

    #[test]
    fn chunk_size_arc() {
        measure_with(Arc::new(Characters));
    }

    #[test]
    fn chunk_size_ref() {
        let cell = RefCell::new(Characters);
        measure_with(cell.borrow());
    }

    #[test]
    fn chunk_size_ref_mut() {
        let cell = RefCell::new(Characters);
        measure_with(cell.borrow_mut());
    }

    #[test]
    fn chunk_size_box() {
        measure_with(Box::new(Characters));
    }

    #[test]
    fn chunk_size_rc() {
        measure_with(Rc::new(Characters));
    }
}