cognis_rag/splitters/
character.rs1use std::sync::Arc;
19
20use crate::document::Document;
21
22use super::{child_doc, TextSplitter};
23
/// Shared, thread-safe function that measures the "length" of a piece of text
/// in whatever unit the splitter budgets chunks by (characters by default).
pub type LengthFn = std::sync::Arc<dyn Fn(&str) -> usize + Send + Sync>;
27
28pub struct CharacterSplitter {
30 chunk_size: usize,
31 chunk_overlap: usize,
32 separator: String,
33 keep_separator: bool,
34 length_fn: LengthFn,
35}
36
37impl Default for CharacterSplitter {
38 fn default() -> Self {
39 Self {
40 chunk_size: 1000,
41 chunk_overlap: 200,
42 separator: "\n\n".to_string(),
43 keep_separator: false,
44 length_fn: Arc::new(|s: &str| s.chars().count()),
45 }
46 }
47}
48
49impl std::fmt::Debug for CharacterSplitter {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 f.debug_struct("CharacterSplitter")
52 .field("chunk_size", &self.chunk_size)
53 .field("chunk_overlap", &self.chunk_overlap)
54 .field("separator", &self.separator)
55 .field("keep_separator", &self.keep_separator)
56 .finish()
57 }
58}
59
60impl CharacterSplitter {
61 pub fn new() -> Self {
63 Self::default()
64 }
65
66 pub fn with_chunk_size(mut self, n: usize) -> Self {
68 self.chunk_size = n;
69 self
70 }
71
72 pub fn with_overlap(mut self, n: usize) -> Self {
74 self.chunk_overlap = n;
75 self
76 }
77
78 pub fn with_separator(mut self, s: impl Into<String>) -> Self {
80 self.separator = s.into();
81 self
82 }
83
84 pub fn with_keep_separator(mut self, keep: bool) -> Self {
87 self.keep_separator = keep;
88 self
89 }
90
91 pub fn with_length_fn<F>(mut self, f: F) -> Self
93 where
94 F: Fn(&str) -> usize + Send + Sync + 'static,
95 {
96 self.length_fn = Arc::new(f);
97 self
98 }
99
100 fn pack_fragments(&self, fragments: Vec<String>) -> Vec<String> {
103 let len = |s: &str| (self.length_fn)(s);
104 let mut chunks: Vec<String> = Vec::new();
105 let mut current = String::new();
106
107 for frag in fragments.into_iter() {
108 if frag.is_empty() {
109 continue;
110 }
111 if len(&frag) > self.chunk_size {
114 if !current.is_empty() {
115 chunks.push(std::mem::take(&mut current));
116 }
117 let mut buf = String::new();
118 for ch in frag.chars() {
119 let cand_len = len(&buf) + len(&ch.to_string());
120 if cand_len > self.chunk_size && !buf.is_empty() {
121 chunks.push(std::mem::take(&mut buf));
122 }
123 buf.push(ch);
124 }
125 if !buf.is_empty() {
126 chunks.push(buf);
127 }
128 continue;
129 }
130 let separator_cost = if current.is_empty() {
132 0
133 } else {
134 len(&self.separator)
135 };
136 if len(¤t) + separator_cost + len(&frag) <= self.chunk_size {
137 if !current.is_empty() {
138 current.push_str(&self.separator);
139 }
140 current.push_str(&frag);
141 } else {
142 chunks.push(std::mem::take(&mut current));
143 current = frag;
144 }
145 }
146 if !current.is_empty() {
147 chunks.push(current);
148 }
149
150 if self.chunk_overlap > 0 && chunks.len() > 1 {
151 for i in 1..chunks.len() {
152 let prev = chunks[i - 1].clone();
153 let prev_chars: Vec<char> = prev.chars().collect();
154 let take = self.chunk_overlap.min(prev_chars.len());
155 let tail: String = prev_chars[prev_chars.len() - take..].iter().collect();
156 let mut new = tail;
157 new.push_str(&chunks[i]);
158 chunks[i] = new;
159 }
160 }
161 chunks
162 }
163}
164
165impl TextSplitter for CharacterSplitter {
166 fn split(&self, doc: &Document) -> Vec<Document> {
167 if doc.content.is_empty() {
168 return Vec::new();
169 }
170 let raw: Vec<&str> = doc.content.split(self.separator.as_str()).collect();
172 let fragments: Vec<String> = if self.keep_separator {
173 raw.iter()
175 .enumerate()
176 .map(|(i, p)| {
177 if i == 0 {
178 (*p).to_string()
179 } else {
180 format!("{}{}", self.separator, p)
181 }
182 })
183 .collect()
184 } else {
185 raw.iter().map(|p| (*p).to_string()).collect()
186 };
187 let chunks = self.pack_fragments(fragments);
188 chunks
189 .into_iter()
190 .enumerate()
191 .map(|(i, c)| child_doc(doc, c, i))
192 .collect()
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain document from a string.
    fn doc(content: &str) -> Document {
        Document::new(content)
    }

    /// Splitter with the given chunk size and overlap disabled.
    fn no_overlap(size: usize) -> CharacterSplitter {
        CharacterSplitter::new().with_chunk_size(size).with_overlap(0)
    }

    #[test]
    fn splits_on_default_double_newline() {
        let chunks = no_overlap(50).split(&doc("para one\n\npara two\n\npara three"));
        assert_eq!(chunks.len(), 1);
        let only = &chunks[0].content;
        assert!(only.contains("para one"));
        assert!(only.contains("para three"));
    }

    #[test]
    fn packs_fragments_to_chunk_size() {
        let chunks = no_overlap(15)
            .with_separator("|")
            .split(&doc("aaaaa|bbbbb|ccccc|ddddd"));
        let expected = ["aaaaa|bbbbb", "ccccc|ddddd"];
        assert_eq!(chunks.len(), expected.len());
        for (chunk, want) in chunks.iter().zip(expected.iter()) {
            assert_eq!(chunk.content, *want);
        }
    }

    #[test]
    fn applies_overlap_between_chunks() {
        let splitter = CharacterSplitter::new()
            .with_chunk_size(10)
            .with_overlap(3)
            .with_separator("|");
        let out = splitter.split(&doc("aaaa|bbbb|cccc"));
        assert!(out.len() >= 2);
        // Last (up to) three characters of the first chunk.
        let first: Vec<char> = out[0].content.chars().collect();
        let tail: String = first[first.len().saturating_sub(3)..].iter().collect();
        assert!(out[1].content.starts_with(&tail));
    }

    #[test]
    fn keep_separator_preserves_delimiter() {
        let splitter = no_overlap(20)
            .with_separator("\n\n")
            .with_keep_separator(true);
        let out = splitter.split(&doc("one\n\ntwo\n\nthree"));
        let pieces: Vec<String> = out.iter().map(|c| c.content.clone()).collect();
        assert!(pieces.join("|").contains("\n\ntwo"));
    }

    #[test]
    fn hard_splits_oversized_fragment() {
        let out = no_overlap(5).with_separator("|").split(&doc("abcdefghij"));
        let sizes: Vec<usize> = out.iter().map(|c| c.content.chars().count()).collect();
        assert_eq!(sizes, vec![5, 5]);
    }

    #[test]
    fn custom_length_fn_used() {
        // Each whitespace-separated word costs 5 units, so two words (10) exceed 7.
        let splitter = no_overlap(7)
            .with_separator(" ")
            .with_length_fn(|s: &str| s.split_whitespace().count() * 5);
        assert!(splitter.split(&doc("a b c")).len() >= 2);
    }

    #[test]
    fn empty_doc_returns_no_chunks() {
        assert!(CharacterSplitter::new().split(&doc("")).is_empty());
    }

    #[test]
    fn metadata_propagates_to_children() {
        let source = serde_json::Value::String("file.txt".into());
        let mut d = doc("aaa|bbb");
        d.metadata.insert("source".into(), source.clone());
        let out = no_overlap(3).with_separator("|").split(&d);
        for (i, chunk) in out.iter().enumerate() {
            assert_eq!(chunk.metadata.get("source"), Some(&source));
            assert_eq!(
                chunk.metadata.get("chunk_index").and_then(|v| v.as_u64()),
                Some(i as u64)
            );
        }
    }
}