rag_toolchain/chunkers/character_chunker.rs

1use crate::chunkers::Chunker;
2use crate::common::{Chunk, Chunks};
3use std::convert::Infallible;
4use std::num::NonZeroUsize;
5
/// Splits text into fixed-size windows of characters, where neighbouring
/// windows may share a configurable number of characters.
///
/// Construct with [`CharacterChunker::try_new`], which rejects an overlap that
/// is not strictly smaller than the chunk size.
#[derive(Debug, Clone)]
pub struct CharacterChunker {
    /// chunk_size: the number of characters in each chunk
    chunk_size: NonZeroUsize,
    /// chunk_overlap: the number of characters
    /// shared between neighbouring chunks
    chunk_overlap: usize,
}
13
14impl CharacterChunker {
15    /// [`TokenChunker::try_new`]
16    ///
17    /// # Arguements
18    /// * `chunk_size`: [`NonZeroUsize`] - The number of characters in each chunk
19    /// * `chunk_overlap`: [`usize`] - The number of characters shared between
20    ///                   neighbouring chunks
21    ///
22    /// # Errors
23    /// This function will error if you provide a chunk_overlap greater than or equal to
24    /// the chunk_size.
25    ///
26    /// # Returns
27    /// [`TokenChunker`]
28    pub fn try_new(chunk_size: NonZeroUsize, chunk_overlap: usize) -> Result<Self, String> {
29        if chunk_overlap >= chunk_size.into() {
30            return Err("chunk_overlap cannot be greater than or equal to chunk_size".into());
31        }
32
33        Ok(Self {
34            chunk_size,
35            chunk_overlap,
36        })
37    }
38}
39
40impl Chunker for CharacterChunker {
41    type ErrorType = Infallible;
42    fn generate_chunks(&self, raw_text: &str) -> Result<Chunks, Self::ErrorType> {
43        let mut chunks: Chunks = Vec::new();
44        let chunk_size: usize = self.chunk_size.into();
45
46        let mut i = 0;
47        while i < raw_text.len() {
48            let end = std::cmp::min(i + chunk_size, raw_text.len());
49            let chunk: Chunk = Chunk::new(&raw_text[i..end]);
50            chunks.push(chunk);
51            i += chunk_size - self.chunk_overlap;
52        }
53
54        Ok(chunks)
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chunker from the given settings, chunks `raw_text`, and returns
    /// each chunk's content as an owned string for easy comparison.
    fn chunk_contents(raw_text: &str, chunk_size: usize, chunk_overlap: usize) -> Vec<String> {
        let size = NonZeroUsize::new(chunk_size).unwrap();
        let chunker = CharacterChunker::try_new(size, chunk_overlap).unwrap();
        chunker
            .generate_chunks(raw_text)
            .unwrap()
            .into_iter()
            .map(|chunk| chunk.content().to_string())
            .collect()
    }

    #[test]
    fn test_generate_chunks_with_valid_input() {
        let expected = vec![
            "Th", "hi", "is", "s ", " i", "is", "s ", " a", "a ", " t", "te", "es", "st", "t ",
            " s", "st", "tr", "ri", "in", "ng", "g",
        ];
        assert_eq!(chunk_contents("This is a test string", 2, 1), expected);
    }

    #[test]
    fn test_generate_chunks_with_empty_string() {
        assert!(chunk_contents("", 2, 1).is_empty());
    }

    #[test]
    fn test_try_new_with_invalid_arguments() {
        let chunk_size = NonZeroUsize::new(2).unwrap();
        // Both an overlap larger than and equal to the chunk size are rejected.
        for chunk_overlap in [3, 2] {
            assert!(CharacterChunker::try_new(chunk_size, chunk_overlap).is_err());
        }
    }
}