cognis_rag/splitters/
recursive.rs1use crate::document::Document;
5
6use super::{child_doc, TextSplitter};
7
/// Text splitter that breaks text on a prioritized list of separators and
/// then merges the resulting pieces back into size-bounded chunks.
pub struct RecursiveCharSplitter {
    /// Target chunk size, measured in Unicode scalar values (`char`s), not bytes.
    chunk_size: usize,
    /// Number of trailing characters of the previous chunk that are carried
    /// over to the start of the next chunk when pieces are merged.
    chunk_overlap: usize,
    /// Separators tried in order; an empty separator (or running out of
    /// separators) switches to fixed-width character splitting.
    separators: Vec<String>,
}
18
19impl Default for RecursiveCharSplitter {
20 fn default() -> Self {
21 Self {
22 chunk_size: 1000,
23 chunk_overlap: 200,
24 separators: vec![
25 "\n\n".to_string(),
26 "\n".to_string(),
27 ". ".to_string(),
28 " ".to_string(),
29 "".to_string(),
30 ],
31 }
32 }
33}
34
35impl RecursiveCharSplitter {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn with_chunk_size(mut self, n: usize) -> Self {
43 self.chunk_size = n;
44 self
45 }
46
47 pub fn with_overlap(mut self, n: usize) -> Self {
49 self.chunk_overlap = n;
50 self
51 }
52
53 pub fn with_separators<I, S>(mut self, seps: I) -> Self
55 where
56 I: IntoIterator<Item = S>,
57 S: Into<String>,
58 {
59 self.separators = seps.into_iter().map(Into::into).collect();
60 self
61 }
62
63 fn split_text(&self, text: &str) -> Vec<String> {
64 let pieces = self.recurse(text, 0);
65 merge_with_overlap(pieces, self.chunk_size, self.chunk_overlap)
67 }
68
69 fn recurse(&self, text: &str, sep_idx: usize) -> Vec<String> {
70 if text.chars().count() <= self.chunk_size {
71 return if text.is_empty() {
72 vec![]
73 } else {
74 vec![text.to_string()]
75 };
76 }
77 let separator = match self.separators.get(sep_idx) {
78 Some(s) if !s.is_empty() => s.clone(),
79 _ => return hard_split(text, self.chunk_size),
80 };
81
82 let mut pieces = Vec::new();
83 for piece in text.split(&separator) {
84 if piece.chars().count() <= self.chunk_size {
85 if !piece.is_empty() {
86 pieces.push(piece.to_string());
87 }
88 } else {
89 pieces.extend(self.recurse(piece, sep_idx + 1));
90 }
91 }
92 pieces
93 }
94}
95
/// Greedily packs `pieces` into chunks of at most `chunk_size` characters.
///
/// Pieces are joined with a single space, so the original separator text is
/// not preserved. When a piece no longer fits, the current chunk is emitted
/// and the next chunk starts with up to `overlap` trailing characters of the
/// chunk just emitted, followed by a space and the piece.
///
/// The carried-over tail is shortened (possibly to nothing) so that the
/// seeded chunk still fits in `chunk_size`; previously an overlap could push
/// chunks up to `overlap + 1` characters past the limit. A single piece that
/// is itself longer than `chunk_size` is still emitted unchanged.
///
/// All lengths are counted in `char`s, not bytes.
fn merge_with_overlap(pieces: Vec<String>, chunk_size: usize, overlap: usize) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    // Character length of `buf`, tracked incrementally to avoid rescanning it.
    let mut buf_len = 0usize;

    for piece in pieces {
        let piece_len = piece.chars().count();

        // Strict `<` leaves room for the joining space when `buf` is non-empty.
        if buf_len + piece_len < chunk_size {
            if !buf.is_empty() {
                buf.push(' ');
                buf_len += 1;
            }
            buf.push_str(&piece);
            buf_len += piece_len;
            continue;
        }

        if !buf.is_empty() {
            out.push(std::mem::take(&mut buf));
        }

        // Room left for an overlap tail once the piece and the joining space
        // are accounted for.
        let room = chunk_size.saturating_sub(piece_len.saturating_add(1));
        let tail_len = overlap.min(room);
        let tail: String = match out.last() {
            Some(prev) if tail_len > 0 => {
                let skip = prev.chars().count().saturating_sub(tail_len);
                prev.chars().skip(skip).collect()
            }
            _ => String::new(),
        };

        if tail.is_empty() {
            buf_len = piece_len;
            buf = piece;
        } else {
            buf_len = tail.chars().count() + 1 + piece_len;
            buf = format!("{tail} {piece}");
        }
    }

    if !buf.is_empty() {
        out.push(buf);
    }
    out
}
133
/// Cuts `text` into consecutive runs of at most `chunk_size` characters.
///
/// Widths are counted in `char`s, so multi-byte characters are never split.
/// A `chunk_size` of zero is treated as one.
fn hard_split(text: &str, chunk_size: usize) -> Vec<String> {
    let width = chunk_size.max(1);
    let mut out = Vec::new();
    let mut current = String::new();
    let mut filled = 0usize;

    for ch in text.chars() {
        current.push(ch);
        filled += 1;
        if filled == width {
            out.push(std::mem::take(&mut current));
            filled = 0;
        }
    }

    // Flush the final, shorter run if there is one.
    if !current.is_empty() {
        out.push(current);
    }
    out
}
141
142impl TextSplitter for RecursiveCharSplitter {
143 fn split(&self, doc: &Document) -> Vec<Document> {
144 self.split_text(&doc.content)
145 .into_iter()
146 .enumerate()
147 .map(|(i, c)| child_doc(doc, c, i))
148 .collect()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a splitter with the given chunk size and overlap disabled.
    fn no_overlap(chunk_size: usize) -> RecursiveCharSplitter {
        RecursiveCharSplitter::new()
            .with_chunk_size(chunk_size)
            .with_overlap(0)
    }

    #[test]
    fn splits_paragraphs_first() {
        let chunks = no_overlap(4).split_text("p1.\n\np2.\n\np3.");
        for chunk in &chunks {
            assert!(chunk.chars().count() <= 4);
        }
        assert!(chunks.iter().any(|c| c.contains("p1")));
        assert!(chunks.iter().any(|c| c.contains("p3")));
    }

    #[test]
    fn falls_through_to_chars_for_long_run() {
        let chunks = no_overlap(3).split_text("abcdefghij");
        for chunk in &chunks {
            assert!(chunk.chars().count() <= 3);
        }
        // Merging may insert spaces between pieces; ignore them when
        // checking that no characters were lost.
        let rejoined = chunks.concat().replace(' ', "");
        assert_eq!(rejoined, "abcdefghij");
    }

    #[test]
    fn small_text_returns_one_chunk() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(100);
        assert_eq!(splitter.split_text("hi"), vec!["hi".to_string()]);
    }

    #[test]
    fn split_doc_propagates_metadata() {
        let doc = Document::new("a b c d e f g h i j k").with_metadata("source", "f.txt");
        let chunks = no_overlap(5).split(&doc);
        assert!(!chunks.is_empty());
        for (i, c) in chunks.iter().enumerate() {
            assert_eq!(c.metadata["source"], "f.txt");
            assert_eq!(c.metadata["chunk_index"], serde_json::json!(i));
        }
    }

    #[test]
    fn empty_input_yields_no_chunks() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(10);
        assert!(splitter.split_text("").is_empty());
    }
}