cognis_rag/splitters/
recursive.rs1use crate::document::Document;
5
6use super::{child_doc, TextSplitter};
7
/// Text splitter that walks a list of separators from coarsest to finest,
/// descending to the next separator only for pieces that are still too long.
///
/// All lengths are measured in `char`s (Unicode scalar values), not bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecursiveCharSplitter {
    /// Size threshold in `char`s: text at or below it is not split further.
    chunk_size: usize,
    /// Up to this many trailing `char`s of a finished chunk are repeated at
    /// the start of the chunk that follows it.
    chunk_overlap: usize,
    /// Separators tried in order. An empty string, or running past the end of
    /// the list, switches to fixed-width splitting by character count.
    separators: Vec<String>,
}
19
20impl Default for RecursiveCharSplitter {
21 fn default() -> Self {
22 Self {
23 chunk_size: 1000,
24 chunk_overlap: 200,
25 separators: vec![
26 "\n\n".to_string(),
27 "\n".to_string(),
28 ". ".to_string(),
29 " ".to_string(),
30 "".to_string(),
31 ],
32 }
33 }
34}
35
36impl RecursiveCharSplitter {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn with_chunk_size(mut self, n: usize) -> Self {
44 self.chunk_size = n;
45 self
46 }
47
48 pub fn with_overlap(mut self, n: usize) -> Self {
50 self.chunk_overlap = n;
51 self
52 }
53
54 pub fn with_separators<I, S>(mut self, seps: I) -> Self
56 where
57 I: IntoIterator<Item = S>,
58 S: Into<String>,
59 {
60 self.separators = seps.into_iter().map(Into::into).collect();
61 self
62 }
63
64 fn split_text(&self, text: &str) -> Vec<String> {
65 let pieces = self.recurse(text, 0);
66 merge_with_overlap(pieces, self.chunk_size, self.chunk_overlap)
68 }
69
70 fn recurse(&self, text: &str, sep_idx: usize) -> Vec<String> {
71 if text.chars().count() <= self.chunk_size {
72 return if text.is_empty() {
73 vec![]
74 } else {
75 vec![text.to_string()]
76 };
77 }
78 let separator = match self.separators.get(sep_idx) {
79 Some(s) if !s.is_empty() => s.clone(),
80 _ => return hard_split(text, self.chunk_size),
81 };
82
83 let mut pieces = Vec::new();
84 for piece in text.split(&separator) {
85 if piece.chars().count() <= self.chunk_size {
86 if !piece.is_empty() {
87 pieces.push(piece.to_string());
88 }
89 } else {
90 pieces.extend(self.recurse(piece, sep_idx + 1));
91 }
92 }
93 pieces
94 }
95}
96
/// Greedily packs `pieces` into chunks, joining neighbours with a single space.
///
/// A piece is appended to the current chunk while the chunk, the joining space
/// and the piece together stay within `chunk_size` chars. Otherwise the current
/// chunk is emitted and a new one is started from the piece, prefixed by a tail
/// of the chunk just emitted.
///
/// The tail is at most `overlap` chars and is additionally capped so that
/// `tail + ' ' + piece` never exceeds `chunk_size`. Without that cap a chunk
/// could reach `overlap + 1 + chunk_size` chars, and an `overlap` of
/// `chunk_size` or more made every chunk carry the whole previous one.
///
/// A piece that is itself longer than `chunk_size` is emitted unchanged with
/// no overlap prefix. Lengths are counted in `char`s, not bytes.
fn merge_with_overlap(pieces: Vec<String>, chunk_size: usize, overlap: usize) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    // Char count of `buf`, tracked so the buffer is not rescanned per piece.
    let mut buf_len = 0usize;

    for piece in pieces {
        let piece_len = piece.chars().count();

        // Strict `<` leaves room for the joining space.
        if buf_len + piece_len < chunk_size {
            if !buf.is_empty() {
                buf.push(' ');
                buf_len += 1;
            }
            buf.push_str(&piece);
            buf_len += piece_len;
            continue;
        }

        if !buf.is_empty() {
            // Move the buffer out instead of cloning it.
            out.push(std::mem::take(&mut buf));
        }

        // Room left for the tail once the piece and the joining space are in.
        let budget = overlap.min(chunk_size.saturating_sub(piece_len + 1));
        let tail: &str = match out.last() {
            Some(prev) if budget > 0 => {
                // Byte offset of the `budget`-th char from the end; the whole
                // chunk when it has fewer chars than that.
                let start = prev
                    .char_indices()
                    .rev()
                    .nth(budget - 1)
                    .map_or(0, |(idx, _)| idx);
                &prev[start..]
            }
            _ => "",
        };

        buf = if tail.is_empty() {
            piece
        } else {
            format!("{tail} {piece}")
        };
        buf_len = buf.chars().count();
    }

    if !buf.is_empty() {
        out.push(buf);
    }
    out
}
134
/// Cuts `text` into consecutive runs of `chunk_size` chars; the final run may
/// be shorter. A `chunk_size` of zero is treated as one. Empty input yields an
/// empty vector.
fn hard_split(text: &str, chunk_size: usize) -> Vec<String> {
    let width = chunk_size.max(1);
    let mut runs: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut filled = 0usize;

    for ch in text.chars() {
        current.push(ch);
        filled += 1;
        if filled == width {
            runs.push(std::mem::take(&mut current));
            filled = 0;
        }
    }

    // Whatever is left over forms the short final run.
    if !current.is_empty() {
        runs.push(current);
    }
    runs
}
142
143impl TextSplitter for RecursiveCharSplitter {
144 fn split(&self, doc: &Document) -> Vec<Document> {
145 self.split_text(&doc.content)
146 .into_iter()
147 .enumerate()
148 .map(|(i, c)| child_doc(doc, c, i))
149 .collect()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a splitter with the given chunk size and overlap disabled.
    fn without_overlap(chunk_size: usize) -> RecursiveCharSplitter {
        RecursiveCharSplitter::new()
            .with_chunk_size(chunk_size)
            .with_overlap(0)
    }

    #[test]
    fn splits_paragraphs_first() {
        let splitter = without_overlap(4);
        let chunks = splitter.split_text("p1.\n\np2.\n\np3.");
        for chunk in &chunks {
            assert!(chunk.chars().count() <= 4);
        }
        assert!(chunks.iter().any(|chunk| chunk.contains("p1")));
        assert!(chunks.iter().any(|chunk| chunk.contains("p3")));
    }

    #[test]
    fn falls_through_to_chars_for_long_run() {
        let splitter = without_overlap(3);
        let chunks = splitter.split_text("abcdefghij");
        for chunk in &chunks {
            assert!(chunk.chars().count() <= 3);
        }
        // Ignoring any joining spaces, no character may be lost or reordered.
        let rejoined = chunks.concat().replace(' ', "");
        assert_eq!(rejoined, "abcdefghij");
    }

    #[test]
    fn small_text_returns_one_chunk() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(100);
        assert_eq!(splitter.split_text("hi"), vec!["hi".to_string()]);
    }

    #[test]
    fn split_doc_propagates_metadata() {
        let doc = Document::new("a b c d e f g h i j k").with_metadata("source", "f.txt");
        let children = without_overlap(5).split(&doc);
        assert!(!children.is_empty());
        for (index, child) in children.iter().enumerate() {
            assert_eq!(child.metadata["source"], "f.txt");
            assert_eq!(child.metadata["chunk_index"], serde_json::json!(index));
        }
    }

    #[test]
    fn empty_input_yields_no_chunks() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(10);
        let chunks = splitter.split_text("");
        assert!(chunks.is_empty());
    }
}