use std::marker::PhantomData;
use std::sync::Arc;
/// Settings that control how text is split into chunks.
///
/// `S` is the sizer used to measure text; it defaults to [`CharSizer`]. The
/// struct itself places no trait bound on `S`, so the derived traits (`Copy`,
/// `Debug`, `PartialEq`, ...) are only available when `S` implements them too.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChunkConfig<S = CharSizer> {
    /// Chunk size limit, expressed in the units reported by `sizer`.
    // NOTE(review): presumably an upper bound per chunk — confirm with the splitter.
    pub chunk_size: usize,
    /// Overlap between consecutive chunks (presumably), in the same units as
    /// `chunk_size`.
    // NOTE(review): nothing here checks `chunk_overlap < chunk_size`; confirm
    // that callers validate this.
    pub chunk_overlap: usize,
    /// Measure used to compute the size of a piece of text.
    pub sizer: S,
}
impl<S> ChunkConfig<S> {
pub fn new(chunk_size: usize, chunk_overlap: usize, sizer: S) -> Self {
Self {
chunk_size,
chunk_overlap,
sizer,
}
}
}
impl Default for ChunkConfig<CharSizer> {
fn default() -> Self {
Self {
chunk_size: 1000,
chunk_overlap: 200,
sizer: CharSizer,
}
}
}
/// Measures the size of a piece of text in some unit.
///
/// Implementations in this module count Unicode scalar values
/// ([`CharSizer`]), UTF-8 bytes ([`ByteSizer`]), whitespace-separated words
/// ([`WordSizer`]), or delegate to a user-supplied function
/// ([`FunctionSizer`]).
///
/// The `Clone + Send + Sync + 'static` supertraits mean every sizer can be
/// cloned, moved, and shared across threads without borrowed data.
pub trait ChunkSizer: Clone + Send + Sync + 'static {
    /// Returns the size of `text` in this sizer's unit.
    fn size(&self, text: &str) -> usize;
}
/// Sizes text by its number of Unicode scalar values (`char`s).
///
/// This is neither the byte length nor the number of user-perceived
/// characters (grapheme clusters).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CharSizer;

impl ChunkSizer for CharSizer {
    fn size(&self, text: &str) -> usize {
        // `char_indices` yields exactly one item per `char`.
        text.char_indices().count()
    }
}
/// Sizes text by the length of its UTF-8 encoding, in bytes.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ByteSizer;

impl ChunkSizer for ByteSizer {
    fn size(&self, text: &str) -> usize {
        // `str::len` reports bytes, not characters.
        str::len(text)
    }
}
/// Sizes text by its number of whitespace-separated words.
///
/// Any Unicode whitespace counts as a separator, and runs of whitespace
/// (including leading/trailing whitespace) never produce empty words.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct WordSizer;

impl ChunkSizer for WordSizer {
    fn size(&self, text: &str) -> usize {
        // Same semantics as `split_whitespace`: split on every Unicode
        // whitespace char, then discard the empty pieces between separators.
        text.split(char::is_whitespace)
            .filter(|word| !word.is_empty())
            .count()
    }
}
/// Sizes text with an arbitrary user-supplied length function.
///
/// Construct one with [`FunctionSizer::new`], or convert any
/// `Fn(&str) -> usize + Send + Sync + 'static` closure with `From`/`Into`.
#[derive(Clone)]
pub struct FunctionSizer {
    length_fn: crate::LengthFn,
}

// The stored length function is an arbitrary closure with no `Debug` impl,
// so `Debug` cannot be derived. Printing only the type name keeps
// `FunctionSizer` (and containers such as `ChunkConfig<FunctionSizer>`)
// usable with `{:?}`.
impl std::fmt::Debug for FunctionSizer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FunctionSizer").finish_non_exhaustive()
    }
}
impl FunctionSizer {
pub fn new(length_fn: crate::LengthFn) -> Self {
Self { length_fn }
}
}
impl<F> From<F> for FunctionSizer
where
F: Fn(&str) -> usize + Send + Sync + 'static,
{
fn from(value: F) -> Self {
Self {
length_fn: Arc::new(value),
}
}
}
impl ChunkSizer for FunctionSizer {
    fn size(&self, text: &str) -> usize {
        // Deref through the shared pointer to reach the stored callable.
        let measure = &*self.length_fn;
        measure(text)
    }
}
/// Sizes text by its number of extended grapheme clusters (user-perceived
/// characters), using the `unicode-segmentation` crate.
///
/// Only compiled when the `unicode-segmentation` feature is enabled.
#[cfg(feature = "unicode-segmentation")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct GraphemeSizer;

#[cfg(feature = "unicode-segmentation")]
impl ChunkSizer for GraphemeSizer {
    fn size(&self, text: &str) -> usize {
        use unicode_segmentation::UnicodeSegmentation as _;
        // `true` selects extended (not legacy) grapheme clusters.
        text.graphemes(true).count()
    }
}
/// Sizes text by its number of words as found by
/// `UnicodeSegmentation::unicode_words` (UAX #29 word boundaries, keeping
/// only segments that contain alphanumeric characters).
///
/// Only compiled when the `unicode-segmentation` feature is enabled.
#[cfg(feature = "unicode-segmentation")]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UnicodeWordSizer;

#[cfg(feature = "unicode-segmentation")]
impl ChunkSizer for UnicodeWordSizer {
    fn size(&self, text: &str) -> usize {
        use unicode_segmentation::UnicodeSegmentation as _;
        text.unicode_words().count()
    }
}
/// Zero-sized marker that carries a sizer type `S` at the type level only.
///
/// No value of `S` is stored, so the trait impls below are written by hand
/// rather than derived: `#[derive]` would add `S: Clone`, `S: Copy`,
/// `S: Default`, `S: Debug`, `S: PartialEq`, and `S: Eq` bounds. For example,
/// `SizerKind<FunctionSizer>` would then not be `Copy` or `Default`, even
/// though the marker holds no data.
pub struct SizerKind<S> {
    _marker: PhantomData<S>,
}

impl<S> std::fmt::Debug for SizerKind<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as the former derived impl; `PhantomData<S>` is `Debug`
        // for every `S`.
        f.debug_struct("SizerKind")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<S> Clone for SizerKind<S> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<S> Copy for SizerKind<S> {}

impl<S> Default for SizerKind<S> {
    fn default() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<S> PartialEq for SizerKind<S> {
    fn eq(&self, _other: &Self) -> bool {
        // Markers carry no data, so any two markers of the same type are equal.
        true
    }
}

impl<S> Eq for SizerKind<S> {}