Skip to main content

cognis_rag/splitters/
character.rs

1//! Character-level splitter — split on a single user-supplied separator
2//! string with optional chunk overlap.
3//!
4//! Sister to [`super::recursive::RecursiveCharSplitter`]: where the
5//! recursive splitter walks down a list of separators looking for the
6//! coarsest fit, this one splits on exactly one separator and packs the
7//! resulting fragments into `chunk_size`-bounded chunks.
8//!
9//! Customization knobs:
10//! - `with_chunk_size(n)` — target maximum chunk size (chars).
11//! - `with_overlap(n)` — overlap window between adjacent chunks.
12//! - `with_separator(s)` — the separator. Default: `"\n\n"`.
13//! - `with_keep_separator(bool)` — whether the separator is retained at
14//!   the start of subsequent chunks (mirrors V1 behaviour).
15//! - `with_length_fn(fn)` — pluggable length measure. Defaults to chars
16//!   but accepts any `Fn(&str) -> usize` (e.g. a real tokenizer).
17
18use std::sync::Arc;
19
20use crate::document::Document;
21
22use super::{child_doc, TextSplitter};
23
/// Shared, thread-safe callback that reports the "size" of a piece of text.
///
/// The splitter compares the value returned by this callback against its
/// `chunk_size` budget. The built-in default counts Unicode scalar values
/// (`chars`); supply a tokenizer-backed closure to budget in tokens instead.
pub type LengthFn = Arc<dyn Fn(&str) -> usize + Send + Sync>;
27
28/// Single-separator character splitter.
29pub struct CharacterSplitter {
30    chunk_size: usize,
31    chunk_overlap: usize,
32    separator: String,
33    keep_separator: bool,
34    length_fn: LengthFn,
35}
36
37impl Default for CharacterSplitter {
38    fn default() -> Self {
39        Self {
40            chunk_size: 1000,
41            chunk_overlap: 200,
42            separator: "\n\n".to_string(),
43            keep_separator: false,
44            length_fn: Arc::new(|s: &str| s.chars().count()),
45        }
46    }
47}
48
49impl std::fmt::Debug for CharacterSplitter {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        f.debug_struct("CharacterSplitter")
52            .field("chunk_size", &self.chunk_size)
53            .field("chunk_overlap", &self.chunk_overlap)
54            .field("separator", &self.separator)
55            .field("keep_separator", &self.keep_separator)
56            .finish()
57    }
58}
59
60impl CharacterSplitter {
61    /// New splitter with default settings.
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    /// Set the maximum chunk size measured by `length_fn` (default: chars).
67    pub fn with_chunk_size(mut self, n: usize) -> Self {
68        self.chunk_size = n;
69        self
70    }
71
72    /// Set the overlap between adjacent chunks.
73    pub fn with_overlap(mut self, n: usize) -> Self {
74        self.chunk_overlap = n;
75        self
76    }
77
78    /// Set the separator string used to split text. Default: `"\n\n"`.
79    pub fn with_separator(mut self, s: impl Into<String>) -> Self {
80        self.separator = s.into();
81        self
82    }
83
84    /// If `true`, the separator is prepended to subsequent chunks
85    /// (preserving boundary context). Default: `false`.
86    pub fn with_keep_separator(mut self, keep: bool) -> Self {
87        self.keep_separator = keep;
88        self
89    }
90
91    /// Replace the length measure (default: chars).
92    pub fn with_length_fn<F>(mut self, f: F) -> Self
93    where
94        F: Fn(&str) -> usize + Send + Sync + 'static,
95    {
96        self.length_fn = Arc::new(f);
97        self
98    }
99
100    /// Pack a sequence of fragments into chunk-sized buckets, respecting
101    /// `chunk_size` and applying `chunk_overlap` between adjacent chunks.
102    fn pack_fragments(&self, fragments: Vec<String>) -> Vec<String> {
103        let len = |s: &str| (self.length_fn)(s);
104        let mut chunks: Vec<String> = Vec::new();
105        let mut current = String::new();
106
107        for frag in fragments.into_iter() {
108            if frag.is_empty() {
109                continue;
110            }
111            // If the fragment alone exceeds chunk_size, hard-split by
112            // characters so we never produce a chunk wider than the budget.
113            if len(&frag) > self.chunk_size {
114                if !current.is_empty() {
115                    chunks.push(std::mem::take(&mut current));
116                }
117                let mut buf = String::new();
118                for ch in frag.chars() {
119                    let cand_len = len(&buf) + len(&ch.to_string());
120                    if cand_len > self.chunk_size && !buf.is_empty() {
121                        chunks.push(std::mem::take(&mut buf));
122                    }
123                    buf.push(ch);
124                }
125                if !buf.is_empty() {
126                    chunks.push(buf);
127                }
128                continue;
129            }
130            // Try to append to `current` if there's room.
131            let separator_cost = if current.is_empty() {
132                0
133            } else {
134                len(&self.separator)
135            };
136            if len(&current) + separator_cost + len(&frag) <= self.chunk_size {
137                if !current.is_empty() {
138                    current.push_str(&self.separator);
139                }
140                current.push_str(&frag);
141            } else {
142                chunks.push(std::mem::take(&mut current));
143                current = frag;
144            }
145        }
146        if !current.is_empty() {
147            chunks.push(current);
148        }
149
150        if self.chunk_overlap > 0 && chunks.len() > 1 {
151            for i in 1..chunks.len() {
152                let prev = chunks[i - 1].clone();
153                let prev_chars: Vec<char> = prev.chars().collect();
154                let take = self.chunk_overlap.min(prev_chars.len());
155                let tail: String = prev_chars[prev_chars.len() - take..].iter().collect();
156                let mut new = tail;
157                new.push_str(&chunks[i]);
158                chunks[i] = new;
159            }
160        }
161        chunks
162    }
163}
164
165impl TextSplitter for CharacterSplitter {
166    fn split(&self, doc: &Document) -> Vec<Document> {
167        if doc.content.is_empty() {
168            return Vec::new();
169        }
170        // Split on the configured separator.
171        let raw: Vec<&str> = doc.content.split(self.separator.as_str()).collect();
172        let fragments: Vec<String> = if self.keep_separator {
173            // Re-attach the separator to the start of each non-first piece.
174            raw.iter()
175                .enumerate()
176                .map(|(i, p)| {
177                    if i == 0 {
178                        (*p).to_string()
179                    } else {
180                        format!("{}{}", self.separator, p)
181                    }
182                })
183                .collect()
184        } else {
185            raw.iter().map(|p| (*p).to_string()).collect()
186        };
187        let chunks = self.pack_fragments(fragments);
188        chunks
189            .into_iter()
190            .enumerate()
191            .map(|(i, c)| child_doc(doc, c, i))
192            .collect()
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `content` in a fresh document.
    fn doc(content: &str) -> Document {
        Document::new(content)
    }

    /// Splitter with an explicit budget, overlap and separator.
    fn splitter(chunk_size: usize, overlap: usize, separator: &str) -> CharacterSplitter {
        CharacterSplitter::new()
            .with_chunk_size(chunk_size)
            .with_overlap(overlap)
            .with_separator(separator)
    }

    #[test]
    fn splits_on_default_double_newline() {
        // Separator is left at its default on purpose.
        let s = CharacterSplitter::new().with_chunk_size(50).with_overlap(0);
        let out = s.split(&doc("para one\n\npara two\n\npara three"));
        assert_eq!(out.len(), 1);
        let only = &out[0].content;
        assert!(only.contains("para one"));
        assert!(only.contains("para three"));
    }

    #[test]
    fn packs_fragments_to_chunk_size() {
        let out = splitter(15, 0, "|").split(&doc("aaaaa|bbbbb|ccccc|ddddd"));
        // Two fragments cost 5 + 1 + 5 = 11 and fit in 15; a third would
        // bring the total to 17, so fragments are packed in pairs.
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].content, "aaaaa|bbbbb");
        assert_eq!(out[1].content, "ccccc|ddddd");
    }

    #[test]
    fn applies_overlap_between_chunks() {
        let out = splitter(10, 3, "|").split(&doc("aaaa|bbbb|cccc"));
        assert!(out.len() >= 2);
        // The second chunk must begin with the last 3 chars of the first.
        let first = &out[0].content;
        let skip = first.chars().count().saturating_sub(3);
        let expected_prefix: String = first.chars().skip(skip).collect();
        assert!(out[1].content.starts_with(&expected_prefix));
    }

    #[test]
    fn keep_separator_preserves_delimiter() {
        let s = splitter(20, 0, "\n\n").with_keep_separator(true);
        let out = s.split(&doc("one\n\ntwo\n\nthree"));
        // The separator must still precede "two" inside some chunk.
        assert!(out.iter().any(|c| c.content.contains("\n\ntwo")));
    }

    #[test]
    fn hard_splits_oversized_fragment() {
        // The input has no separator, so it is one 10-char fragment that
        // must be cut into two 5-char pieces.
        let out = splitter(5, 0, "|").split(&doc("abcdefghij"));
        assert_eq!(out.len(), 2);
        for piece in &out[..2] {
            assert_eq!(piece.content.chars().count(), 5);
        }
    }

    #[test]
    fn custom_length_fn_used() {
        // Every word costs 5 units, so each one-word fragment measures 5
        // and the separator " " (zero words) measures 0. Two fragments
        // would cost 5 + 0 + 5 = 10 > 7, so each lands in its own chunk.
        let s = splitter(7, 0, " ").with_length_fn(|s: &str| s.split_whitespace().count() * 5);
        let out = s.split(&doc("a b c"));
        assert!(out.len() >= 2);
    }

    #[test]
    fn empty_doc_returns_no_chunks() {
        let out = CharacterSplitter::new().split(&doc(""));
        assert!(out.is_empty());
    }

    #[test]
    fn metadata_propagates_to_children() {
        let expected = serde_json::Value::String("file.txt".into());
        let mut d = doc("aaa|bbb");
        d.metadata.insert("source".into(), expected.clone());

        let out = splitter(3, 0, "|").split(&d);
        for (i, child) in out.iter().enumerate() {
            // Parent metadata is carried over...
            assert!(child.metadata.get("source") == Some(&expected));
            // ...and every child records its own position.
            assert!(
                child.metadata.get("chunk_index").and_then(|v| v.as_u64()) == Some(i as u64)
            );
        }
    }
}