ragkit_ai/chunk/
simple.rs

1use super::{
2  Chunk,
3  Chunker,
4  SimpleChunk,
5};
6use crate::{
7  error::Error,
8  loc::Loc,
9};
10use derive_builder::Builder;
11
/// Simple chunking algorithm. Splits a string into consecutive pieces of
/// roughly `chunk_size` bytes, moving each cut forward when necessary so that
/// it always lands on a UTF-8 character boundary. This should not be used on
/// its own. It serves as a building block for more advanced chunking
/// algorithms.
///
/// The `Builder` derive generates `SimpleChunkerBuilder`, whose `build`
/// reports failures as `crate::error::Error`. Note that the derived `Default`
/// leaves `chunk_size` at `0`, a value that `chunk` rejects with
/// `Error::InvalidChunkSize`.
#[derive(Default, Builder, Debug)]
#[builder(setter(into))]
#[builder(build_fn(error = "crate::error::Error"))]
pub struct SimpleChunker {
  /// How large each chunk should be, measured in bytes of the input string.
  /// A chunk can come out slightly larger when the cut would otherwise fall
  /// inside a multi-byte character. Must be non-zero for `chunk` to succeed.
  chunk_size: u32,

  /// An offset to use when generating chunk `Loc`s. It is added to both the
  /// `start` and the `end` of every `Loc`. Useful when this chunker is used
  /// within other chunker implementations. The builder defaults it to `0`.
  #[builder(default = "0")]
  loc_offset: usize,
}
27
28impl<'a> Chunker<'a> for SimpleChunker {
29  type Input = &'a str;
30
31  fn chunk(&self, input: Self::Input) -> Result<Vec<Chunk<'a>>, Error> {
32    let chunk_size = self.chunk_size as usize;
33    if chunk_size == 0 {
34      return Err(Error::InvalidChunkSize(chunk_size as u32));
35    }
36
37    let estimated_chunks = input.len() / chunk_size + 1;
38    let mut chunks: Vec<Chunk<'a>> = Vec::with_capacity(estimated_chunks);
39
40    // This always corresponds to the first byte in a valid UTF-8 code point
41    // sequence.
42    let mut start = 0;
43    // This might temporarily point to the midle of a UTF-8 code point sequence.
44    let mut end = 0;
45
46    while start < input.len() {
47      end = std::cmp::min(input.len(), end + chunk_size);
48      // Naively incrementing by `chunk_size` could put us in the middle of a
49      // UTF-8 code point sequence. We have to adjust `end` accordingly.
50      end = next_boundary(input, end);
51      chunks.push(Chunk::Simple(SimpleChunk {
52        content: &input[start..end],
53        loc: Loc {
54          start: start + self.loc_offset,
55          end: end + self.loc_offset,
56        },
57        tags: Default::default(),
58      }));
59      start = end;
60    }
61
62    Ok(chunks)
63  }
64}
65
// Returns the smallest valid character boundary in `string` that is
// >= `index`. When no such boundary exists before the end of the string
// (including when `index` is already past the end), `string.len()` is
// returned, which is always considered a valid character boundary.
fn next_boundary(string: &str, index: usize) -> usize {
  let len = string.len();
  (index..len)
    .find(|&pos| string.is_char_boundary(pos))
    .unwrap_or(len)
}
79
#[cfg(test)]
mod tests {
  use super::*;

  /// ASCII input whose length is not a multiple of the chunk size: the last
  /// chunk holds the remainder and the `Loc`s are contiguous byte ranges.
  #[test]
  fn basic() {
    let chunker = SimpleChunkerBuilder::default()
      .chunk_size(5u32)
      .build()
      .unwrap();

    // Indices:                 01234567890123
    let chunks = chunker.chunk("this is a test").unwrap();
    let content = chunks.iter().map(|c| c.content()).collect::<Vec<_>>();
    assert_eq!(vec!["this ", "is a ", "test"], content);

    let locs = chunks
      .iter()
      .map(|c| c.loc().as_tuple())
      .collect::<Vec<_>>();
    assert_eq!(vec![(0, 5), (5, 10), (10, 14)], locs);
  }

  /// A zero chunk size builds fine but is rejected when chunking.
  #[test]
  fn chunk_size_0() {
    let chunker = SimpleChunkerBuilder::default()
      .chunk_size(0u32)
      .build()
      .unwrap();

    let chunks = chunker.chunk("test");
    assert!(chunks.is_err());
  }

  /// The smallest valid chunk size yields one chunk per ASCII character.
  #[test]
  fn chunk_size_1() {
    let chunker = SimpleChunkerBuilder::default()
      .chunk_size(1u32)
      .build()
      .unwrap();

    let chunks = chunker.chunk("test").unwrap();
    let content = chunks.iter().map(|c| c.content()).collect::<Vec<_>>();
    assert_eq!(vec!["t", "e", "s", "t"], content);
  }

  /// Empty input is not an error; it simply produces no chunks.
  #[test]
  fn empty_input() {
    let chunker = SimpleChunkerBuilder::default()
      .chunk_size(5u32)
      .build()
      .unwrap();

    let chunks = chunker.chunk("").unwrap();
    assert!(chunks.is_empty());
  }

  /// A cut that would fall inside a multi-byte character is moved forward to
  /// the next character boundary, so such chunks exceed `chunk_size` bytes.
  #[test]
  fn multibyte_boundaries() {
    let chunker = SimpleChunkerBuilder::default()
      .chunk_size(2u32)
      .build()
      .unwrap();

    // U+00E9 occupies bytes 1..3, so the first cut at byte 2 moves to byte 3.
    let chunks = chunker.chunk("h\u{e9}llo").unwrap();
    let content = chunks.iter().map(|c| c.content()).collect::<Vec<_>>();
    assert_eq!(vec!["h\u{e9}", "ll", "o"], content);

    let locs = chunks
      .iter()
      .map(|c| c.loc().as_tuple())
      .collect::<Vec<_>>();
    assert_eq!(vec![(0, 3), (3, 5), (5, 6)], locs);

    // U+20AC occupies bytes 1..4 and ends the string, so the cut at byte 2
    // runs all the way to the end and everything lands in a single chunk.
    let chunks = chunker.chunk("a\u{20ac}").unwrap();
    let content = chunks.iter().map(|c| c.content()).collect::<Vec<_>>();
    assert_eq!(vec!["a\u{20ac}"], content);
  }

  /// `loc_offset` shifts both ends of every `Loc` without affecting content.
  #[test]
  fn loc_offset() {
    let chunker = SimpleChunkerBuilder::default()
      .chunk_size(5u32)
      .loc_offset(10usize)
      .build()
      .unwrap();

    let chunks = chunker.chunk("this is a test").unwrap();
    let content = chunks.iter().map(|c| c.content()).collect::<Vec<_>>();
    assert_eq!(vec!["this ", "is a ", "test"], content);

    let locs = chunks
      .iter()
      .map(|c| c.loc().as_tuple())
      .collect::<Vec<_>>();
    assert_eq!(vec![(10, 15), (15, 20), (20, 24)], locs);
  }
}