Skip to main content

cognis_rag/splitters/
recursive.rs

1//! Recursive character splitter — tries successively coarser separators
2//! (paragraph → line → sentence → word) until each chunk fits the target size.
3
4use crate::document::Document;
5
6use super::{child_doc, TextSplitter};
7
/// Splits text by character count, falling back through a list of
/// separators until each chunk is at most `chunk_size` characters.
///
/// Adjacent chunks share a `chunk_overlap` window so context isn't lost
/// at boundaries. Operates on Rust `&str` and counts characters, not tokens.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so a configured
/// splitter can be logged, duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecursiveCharSplitter {
    /// Target maximum chunk length, counted in `char`s.
    chunk_size: usize,
    /// Number of trailing `char`s of one chunk repeated at the start of the next.
    chunk_overlap: usize,
    /// Separators to try in order, coarsest first.
    separators: Vec<String>,
}
18
19impl Default for RecursiveCharSplitter {
20    fn default() -> Self {
21        Self {
22            chunk_size: 1000,
23            chunk_overlap: 200,
24            separators: vec![
25                "\n\n".to_string(),
26                "\n".to_string(),
27                ". ".to_string(),
28                " ".to_string(),
29                "".to_string(),
30            ],
31        }
32    }
33}
34
35impl RecursiveCharSplitter {
36    /// Construct with default settings.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Set the target maximum chunk size (in chars).
42    pub fn with_chunk_size(mut self, n: usize) -> Self {
43        self.chunk_size = n;
44        self
45    }
46
47    /// Set the overlap window between adjacent chunks.
48    pub fn with_overlap(mut self, n: usize) -> Self {
49        self.chunk_overlap = n;
50        self
51    }
52
53    /// Override the separator list (tried in order, coarsest first).
54    pub fn with_separators<I, S>(mut self, seps: I) -> Self
55    where
56        I: IntoIterator<Item = S>,
57        S: Into<String>,
58    {
59        self.separators = seps.into_iter().map(Into::into).collect();
60        self
61    }
62
63    fn split_text(&self, text: &str) -> Vec<String> {
64        let pieces = self.recurse(text, 0);
65        // Merge adjacent pieces up to chunk_size with overlap.
66        merge_with_overlap(pieces, self.chunk_size, self.chunk_overlap)
67    }
68
69    fn recurse(&self, text: &str, sep_idx: usize) -> Vec<String> {
70        if text.chars().count() <= self.chunk_size {
71            return if text.is_empty() {
72                vec![]
73            } else {
74                vec![text.to_string()]
75            };
76        }
77        let separator = match self.separators.get(sep_idx) {
78            Some(s) if !s.is_empty() => s.clone(),
79            _ => return hard_split(text, self.chunk_size),
80        };
81
82        let mut pieces = Vec::new();
83        for piece in text.split(&separator) {
84            if piece.chars().count() <= self.chunk_size {
85                if !piece.is_empty() {
86                    pieces.push(piece.to_string());
87                }
88            } else {
89                pieces.extend(self.recurse(piece, sep_idx + 1));
90            }
91        }
92        pieces
93    }
94}
95
/// Greedily packs `pieces` into chunks of at most `chunk_size` characters.
///
/// Pieces are joined with a single space. When the next piece no longer fits,
/// the current chunk is emitted and the next chunk is seeded with up to
/// `overlap` trailing characters of the chunk just emitted, a space, and the
/// piece. The carried tail is shortened (possibly to nothing) so the seeded
/// chunk still fits in `chunk_size`. Without that clamp a large `overlap`
/// yields oversized chunks, and `overlap >= chunk_size` makes every chunk
/// swallow the whole previous one.
///
/// Lengths are counted in `char`s, not bytes. Empty pieces are skipped. This
/// function never splits a piece, so a piece that is itself longer than
/// `chunk_size` is emitted as its own oversized chunk.
fn merge_with_overlap(pieces: Vec<String>, chunk_size: usize, overlap: usize) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    // Character length of `buf`, tracked so it is not recounted per piece.
    let mut buf_len = 0usize;

    for piece in pieces {
        if piece.is_empty() {
            continue;
        }
        let piece_len = piece.chars().count();
        // Length of `buf` after appending the piece, including the joining space.
        let joined_len = if buf.is_empty() {
            piece_len
        } else {
            buf_len + 1 + piece_len
        };

        if joined_len <= chunk_size {
            if !buf.is_empty() {
                buf.push(' ');
            }
            buf.push_str(&piece);
            buf_len = joined_len;
            continue;
        }

        // The piece does not fit: emit the current chunk without copying it.
        if !buf.is_empty() {
            out.push(std::mem::take(&mut buf));
        }

        // Carry at most `overlap` chars, and never more than the room left
        // next to the piece and its joining space.
        let keep = overlap.min(chunk_size.saturating_sub(piece_len.saturating_add(1)));
        let tail = match out.last() {
            Some(prev) if keep > 0 => {
                // Byte offset of the `keep`-th char from the end; 0 when the
                // previous chunk is shorter than `keep`.
                let start = prev
                    .char_indices()
                    .rev()
                    .nth(keep - 1)
                    .map_or(0, |(idx, _)| idx);
                &prev[start..]
            }
            _ => "",
        };

        if tail.is_empty() {
            buf_len = piece_len;
            buf = piece;
        } else {
            buf_len = tail.chars().count() + 1 + piece_len;
            buf = format!("{tail} {piece}");
        }
    }

    if !buf.is_empty() {
        out.push(buf);
    }
    out
}
133
/// Cuts `text` into consecutive runs of `chunk_size` characters; the final
/// run may be shorter.
///
/// Runs are measured in `char`s, so multi-byte characters are never split.
/// A `chunk_size` of zero is treated as one. Empty input yields no runs.
fn hard_split(text: &str, chunk_size: usize) -> Vec<String> {
    let width = chunk_size.max(1);
    let mut runs: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut filled = 0usize;

    for ch in text.chars() {
        current.push(ch);
        filled += 1;
        if filled == width {
            runs.push(std::mem::take(&mut current));
            filled = 0;
        }
    }
    // Whatever is left over forms the final, shorter run.
    if !current.is_empty() {
        runs.push(current);
    }
    runs
}
141
142impl TextSplitter for RecursiveCharSplitter {
143    fn split(&self, doc: &Document) -> Vec<Document> {
144        self.split_text(&doc.content)
145            .into_iter()
146            .enumerate()
147            .map(|(i, c)| child_doc(doc, c, i))
148            .collect()
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// Splitter with the given chunk size and overlap disabled.
    fn no_overlap(chunk_size: usize) -> RecursiveCharSplitter {
        RecursiveCharSplitter::new()
            .with_chunk_size(chunk_size)
            .with_overlap(0)
    }

    // Paragraph breaks are the first separator tried, so each short
    // paragraph should survive as (part of) a chunk within the size limit.
    #[test]
    fn splits_paragraphs_first() {
        let chunks = no_overlap(4).split_text("p1.\n\np2.\n\np3.");
        for chunk in &chunks {
            assert!(chunk.chars().count() <= 4);
        }
        for marker in ["p1", "p3"] {
            assert!(chunks.iter().any(|chunk| chunk.contains(marker)));
        }
    }

    // A run with no separators at all must still be cut down to size, and
    // no characters may be lost along the way.
    #[test]
    fn falls_through_to_chars_for_long_run() {
        let chunks = no_overlap(3).split_text("abcdefghij");
        assert!(chunks.iter().all(|chunk| chunk.chars().count() <= 3));
        let rejoined: String = chunks.concat().chars().filter(|&ch| ch != ' ').collect();
        assert_eq!(rejoined, "abcdefghij");
    }

    // Text shorter than the chunk size is returned untouched.
    #[test]
    fn small_text_returns_one_chunk() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(100);
        assert_eq!(splitter.split_text("hi"), vec!["hi".to_string()]);
    }

    // Child documents keep the parent's metadata and record their position.
    #[test]
    fn split_doc_propagates_metadata() {
        let parent = Document::new("a b c d e f g h i j k").with_metadata("source", "f.txt");
        let children = no_overlap(5).split(&parent);
        assert!(!children.is_empty());
        for (index, child) in children.iter().enumerate() {
            assert_eq!(child.metadata["source"], "f.txt");
            assert_eq!(child.metadata["chunk_index"], serde_json::json!(index));
        }
    }

    // Empty input produces no chunks at all.
    #[test]
    fn empty_input_yields_no_chunks() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(10);
        assert!(splitter.split_text("").is_empty());
    }
}