Skip to main content

cognis_rag/splitters/
recursive.rs

//! Recursive character splitter — tries successively finer separators
//! (paragraph → line → sentence → word → character) until each chunk fits the
//! target size.

4use crate::document::Document;
5
6use super::{child_doc, TextSplitter};

/// Splits text by character count, falling back through a list of
/// separators (coarsest first) until each piece fits within `chunk_size`.
///
/// Sizes are measured in `char`s (Unicode scalar values), not bytes or tokens.
/// Pieces are then merged back into chunks, and adjacent chunks share up to
/// `chunk_overlap` characters of context so meaning isn't lost at boundaries.
/// Modelled on LangChain's `RecursiveCharacterTextSplitter`.
#[derive(Debug, Clone)]
pub struct RecursiveCharSplitter {
    /// Target maximum chunk length, in characters.
    chunk_size: usize,
    /// Number of trailing characters of one chunk carried into the next.
    chunk_overlap: usize,
    /// Separators tried in order; an empty string means "hard split by chars".
    separators: Vec<String>,
}

impl Default for RecursiveCharSplitter {
    /// 1000-character chunks with a 200-character overlap, splitting on blank
    /// lines, then newlines, then `". "`, then spaces, then single characters.
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            chunk_overlap: 200,
            separators: ["\n\n", "\n", ". ", " ", ""]
                .iter()
                .map(|&s| String::from(s))
                .collect(),
        }
    }
}

36impl RecursiveCharSplitter {
37    /// Construct with default settings.
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Set the target maximum chunk size (in chars).
43    pub fn with_chunk_size(mut self, n: usize) -> Self {
44        self.chunk_size = n;
45        self
46    }
47
48    /// Set the overlap window between adjacent chunks.
49    pub fn with_overlap(mut self, n: usize) -> Self {
50        self.chunk_overlap = n;
51        self
52    }
53
54    /// Override the separator list (tried in order, coarsest first).
55    pub fn with_separators<I, S>(mut self, seps: I) -> Self
56    where
57        I: IntoIterator<Item = S>,
58        S: Into<String>,
59    {
60        self.separators = seps.into_iter().map(Into::into).collect();
61        self
62    }
63
64    fn split_text(&self, text: &str) -> Vec<String> {
65        let pieces = self.recurse(text, 0);
66        // Merge adjacent pieces up to chunk_size with overlap.
67        merge_with_overlap(pieces, self.chunk_size, self.chunk_overlap)
68    }
69
70    fn recurse(&self, text: &str, sep_idx: usize) -> Vec<String> {
71        if text.chars().count() <= self.chunk_size {
72            return if text.is_empty() {
73                vec![]
74            } else {
75                vec![text.to_string()]
76            };
77        }
78        let separator = match self.separators.get(sep_idx) {
79            Some(s) if !s.is_empty() => s.clone(),
80            _ => return hard_split(text, self.chunk_size),
81        };
82
83        let mut pieces = Vec::new();
84        for piece in text.split(&separator) {
85            if piece.chars().count() <= self.chunk_size {
86                if !piece.is_empty() {
87                    pieces.push(piece.to_string());
88                }
89            } else {
90                pieces.extend(self.recurse(piece, sep_idx + 1));
91            }
92        }
93        pieces
94    }
95}
96
/// Greedily packs `pieces` into chunks of at most `chunk_size` characters.
///
/// Pieces are joined with a single space, and that space counts toward the
/// limit. When the next piece no longer fits, the current chunk is emitted.
/// The next chunk is then seeded with up to `overlap` trailing characters of
/// the emitted chunk. The seed is shrunk as needed so that seed + space +
/// piece still fits in `chunk_size`, and leading whitespace is stripped from
/// it. A piece that is itself longer than `chunk_size` is emitted on its own,
/// without overlap.
fn merge_with_overlap(pieces: Vec<String>, chunk_size: usize, overlap: usize) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    // Character count of `buf`, tracked so it is not rescanned per piece.
    let mut buf_len = 0usize;
    for piece in pieces {
        let piece_len = piece.chars().count();
        let joined_len = if buf.is_empty() {
            piece_len
        } else {
            buf_len + 1 + piece_len
        };
        if joined_len <= chunk_size {
            if !buf.is_empty() {
                buf.push(' ');
            }
            buf.push_str(&piece);
            buf_len = joined_len;
            continue;
        }

        if !buf.is_empty() {
            out.push(std::mem::take(&mut buf));
        }
        // Room left for overlap once the piece and its joining space are
        // placed. Without this cap the overlap could push the chunk past
        // `chunk_size`.
        let budget = overlap.min(chunk_size.saturating_sub(piece_len + 1));
        let tail = out
            .last()
            .map_or("", |prev| last_chars(prev, budget))
            .trim_start();
        if tail.is_empty() {
            buf = piece;
            buf_len = piece_len;
        } else {
            buf_len = tail.chars().count() + 1 + piece_len;
            buf = format!("{tail} {piece}");
        }
    }
    if !buf.is_empty() {
        out.push(buf);
    }
    out
}

/// Returns the last `n` characters of `s`, or all of `s` if it is shorter.
fn last_chars(s: &str, n: usize) -> &str {
    if n == 0 {
        return "";
    }
    s.char_indices()
        .rev()
        .nth(n - 1)
        .map_or(s, |(idx, _)| &s[idx..])
}

/// Cuts `text` into consecutive runs of `chunk_size` characters; the final
/// run holds whatever remains. A `chunk_size` of 0 is treated as 1.
fn hard_split(text: &str, chunk_size: usize) -> Vec<String> {
    let width = chunk_size.max(1);
    let mut runs = Vec::new();
    let mut current = String::new();
    let mut filled = 0usize;
    for ch in text.chars() {
        current.push(ch);
        filled += 1;
        if filled == width {
            runs.push(std::mem::take(&mut current));
            filled = 0;
        }
    }
    if !current.is_empty() {
        runs.push(current);
    }
    runs
}

143impl TextSplitter for RecursiveCharSplitter {
144    fn split(&self, doc: &Document) -> Vec<Document> {
145        self.split_text(&doc.content)
146            .into_iter()
147            .enumerate()
148            .map(|(i, c)| child_doc(doc, c, i))
149            .collect()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Splitter with the given chunk size and overlap disabled.
    fn without_overlap(chunk_size: usize) -> RecursiveCharSplitter {
        RecursiveCharSplitter::new()
            .with_chunk_size(chunk_size)
            .with_overlap(0)
    }

    #[test]
    fn splits_paragraphs_first() {
        let chunks = without_overlap(4).split_text("p1.\n\np2.\n\np3.");
        let fits = |c: &String| c.chars().count() <= 4;
        assert!(chunks.iter().all(fits));
        for needle in ["p1", "p3"].iter() {
            assert!(chunks.iter().any(|c| c.contains(*needle)));
        }
    }

    #[test]
    fn falls_through_to_chars_for_long_run() {
        let chunks = without_overlap(3).split_text("abcdefghij");
        assert!(chunks.iter().all(|c| c.chars().count() <= 3));
        let rejoined: String = chunks.concat().chars().filter(|&ch| ch != ' ').collect();
        assert_eq!(rejoined, "abcdefghij");
    }

    #[test]
    fn small_text_returns_one_chunk() {
        let chunks = RecursiveCharSplitter::new()
            .with_chunk_size(100)
            .split_text("hi");
        assert_eq!(chunks, vec![String::from("hi")]);
    }

    #[test]
    fn split_doc_propagates_metadata() {
        let doc = Document::new("a b c d e f g h i j k").with_metadata("source", "f.txt");
        let children = without_overlap(5).split(&doc);
        assert!(!children.is_empty());
        for (expected_index, child) in children.iter().enumerate() {
            assert_eq!(child.metadata["source"], "f.txt");
            assert_eq!(child.metadata["chunk_index"], serde_json::json!(expected_index));
        }
    }

    #[test]
    fn empty_input_yields_no_chunks() {
        let splitter = RecursiveCharSplitter::new().with_chunk_size(10);
        assert!(splitter.split_text("").is_empty());
    }
}