// phago_embeddings/chunker.rs

/// Parameters controlling how a `Chunker` splits text.
///
/// All sizes are byte counts of UTF-8 text (they are compared against
/// `str::len`), not character or token counts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkConfig {
    /// Target maximum chunk length in bytes. Text no longer than this is
    /// returned as a single chunk.
    pub max_size: usize,
    /// How many bytes the next chunk's start is moved back from the previous
    /// chunk's end; also sizes the sentence-break search window (`2 * overlap`).
    pub overlap: usize,
    /// Minimum chunk length in bytes: break points closer than this to the
    /// chunk start are ignored, and short non-final chunks are extended.
    pub min_size: usize,
    /// When `true`, try to end chunks just after `.`, `!` or `?` followed by
    /// whitespace, falling back to a newline.
    pub respect_sentences: bool,
}

19impl Default for ChunkConfig {
20 fn default() -> Self {
21 Self {
22 max_size: 512,
23 overlap: 64,
24 min_size: 50,
25 respect_sentences: true,
26 }
27 }
28}
29
30impl ChunkConfig {
31 pub fn short() -> Self {
33 Self {
34 max_size: 128,
35 overlap: 16,
36 min_size: 10,
37 respect_sentences: false,
38 }
39 }
40
41 pub fn medium() -> Self {
43 Self {
44 max_size: 512,
45 overlap: 64,
46 min_size: 50,
47 respect_sentences: true,
48 }
49 }
50
51 pub fn long() -> Self {
53 Self {
54 max_size: 2048,
55 overlap: 256,
56 min_size: 100,
57 respect_sentences: true,
58 }
59 }
60}
61
62pub struct Chunker {
64 config: ChunkConfig,
65}
66
67impl Chunker {
68 pub fn new(config: ChunkConfig) -> Self {
70 Self { config }
71 }
72
73 pub fn default_config() -> Self {
75 Self::new(ChunkConfig::default())
76 }
77
78 pub fn chunk(&self, text: &str) -> Vec<Chunk> {
80 if text.len() <= self.config.max_size {
81 return vec![Chunk {
82 text: text.to_string(),
83 start: 0,
84 end: text.len(),
85 index: 0,
86 }];
87 }
88
89 let mut chunks = Vec::new();
90 let mut start = 0;
91 let mut index = 0;
92
93 while start < text.len() {
94 let mut end = (start + self.config.max_size).min(text.len());
95
96 if self.config.respect_sentences && end < text.len() {
98 if let Some(break_point) = self.find_sentence_break(&text[start..end]) {
99 end = start + break_point;
100 }
101 }
102
103 if end - start < self.config.min_size && end < text.len() {
105 end = (start + self.config.min_size).min(text.len());
106 }
107
108 chunks.push(Chunk {
109 text: text[start..end].to_string(),
110 start,
111 end,
112 index,
113 });
114
115 start = if end >= text.len() {
117 end
118 } else {
119 (end - self.config.overlap).max(start + 1)
120 };
121 index += 1;
122 }
123
124 chunks
125 }
126
127 fn find_sentence_break(&self, text: &str) -> Option<usize> {
129 let search_start = text.len().saturating_sub(self.config.overlap * 2);
131
132 for (i, c) in text[search_start..].char_indices().rev() {
133 let pos = search_start + i;
134 if (c == '.' || c == '!' || c == '?') && pos > self.config.min_size {
135 let next_pos = pos + c.len_utf8();
137 if next_pos >= text.len()
138 || text[next_pos..].starts_with(char::is_whitespace)
139 {
140 return Some(next_pos);
141 }
142 }
143 }
144
145 for (i, c) in text[search_start..].char_indices().rev() {
147 if c == '\n' {
148 let pos = search_start + i;
149 if pos > self.config.min_size {
150 return Some(pos + 1);
151 }
152 }
153 }
154
155 None
156 }
157}
158
159impl Default for Chunker {
160 fn default() -> Self {
161 Self::default_config()
162 }
163}
164
/// One piece of text produced by `Chunker::chunk`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Chunk {
    /// The chunk's contents (an owned copy of `source[start..end]`).
    pub text: String,
    /// Byte offset in the source text where the chunk begins (inclusive).
    pub start: usize,
    /// Byte offset in the source text where the chunk ends (exclusive).
    pub end: usize,
    /// Zero-based position of this chunk in the returned sequence.
    pub index: usize,
}

178impl Chunk {
179 pub fn len(&self) -> usize {
181 self.text.len()
182 }
183
184 pub fn is_empty(&self) -> bool {
186 self.text.is_empty()
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the structural invariants every multi-chunk result should satisfy:
    /// chunks cover the text from start to end, indices are sequential, each
    /// chunk's text matches its byte span, and starts strictly increase.
    fn assert_well_formed(text: &str, chunks: &[Chunk]) {
        assert!(!chunks.is_empty());
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks.last().unwrap().end, text.len());
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.index, i);
            assert_eq!(&text[chunk.start..chunk.end], chunk.text);
        }
        for pair in chunks.windows(2) {
            assert!(pair[1].start > pair[0].start);
        }
    }

    #[test]
    fn test_short_text_no_chunking() {
        let chunker = Chunker::default_config();
        let chunks = chunker.chunk("Hello world");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "Hello world");
    }

    #[test]
    fn test_long_text_chunking() {
        let chunker = Chunker::new(ChunkConfig {
            max_size: 50,
            overlap: 10,
            min_size: 10,
            respect_sentences: false,
        });

        let text = "This is a longer text that should be split into multiple chunks for processing.";
        let chunks = chunker.chunk(text);

        assert!(chunks.len() > 1);
        // Without sentence snapping, no chunk may exceed max_size.
        for chunk in &chunks {
            assert!(chunk.len() <= 50);
        }
        assert_well_formed(text, &chunks);
    }

    #[test]
    fn test_sentence_boundary_respect() {
        let chunker = Chunker::new(ChunkConfig {
            max_size: 100,
            overlap: 20,
            min_size: 10,
            respect_sentences: true,
        });

        // Longer than max_size so that at least one non-final chunk is produced
        // and the sentence-break logic is actually exercised.
        let text = "First sentence here. Second sentence follows. Third sentence ends. \
                    Fourth one is here. Fifth comes next. Sixth closes it.";
        let chunks = chunker.chunk(text);

        assert!(chunks.len() > 1, "text must be long enough to be split");
        for chunk in &chunks {
            // Every chunk except the last should end at sentence punctuation.
            if chunk.end < text.len() {
                let last_char = chunk.text.trim_end().chars().last();
                assert!(
                    last_char == Some('.') || last_char == Some('!') || last_char == Some('?')
                        || chunk.text == text
                );
            }
        }
        assert_well_formed(text, &chunks);
    }
}