Skip to main content

reddb_server/runtime/ai/
text_chunker.rs

1//! Text chunker — issue #277.
2//!
3//! Splits text into chunks using strategy: paragraph > sentence > character.
4//! Approximate tokenisation: 1 token ≈ 4 bytes (ASCII-biased heuristic).
5//!
6//! Single mode (default): returns only the first chunk (preserves 1:1 mapping
7//! to embeddings). Multi mode: returns all chunks.
8
9use std::sync::atomic::{AtomicU64, Ordering};
10
/// Config key naming the embedding chunk mode (presumably parsed with
/// [`ChunkMode::from_str`] — confirm at the call site).
pub const CONFIG_CHUNK_MODE: &str = "runtime.ai.embedding_chunk_mode";
/// Config key holding the per-chunk token budget.
pub const CONFIG_MAX_TOKENS: &str = "runtime.ai.embedding_max_tokens";
/// Default token budget (8 Ki tokens); presumably used when
/// [`CONFIG_MAX_TOKENS`] is unset.
pub const DEFAULT_MAX_TOKENS: usize = 8 * 1024;

/// Heuristic byte width of one token, tuned for mostly-ASCII text.
const BYTES_PER_TOKEN: usize = 4;
17
/// Process-wide count of texts that exceeded the token budget and were split
/// by [`chunk`]. Only ever incremented within this module.
static CHUNKED_TOTAL: AtomicU64 = AtomicU64::new(0u64);

/// Returns how many texts [`chunk`] has had to split so far.
///
/// The load is relaxed: the counter is a statistic, not a synchronisation
/// point, so no ordering with other memory operations is implied.
pub fn chunked_total() -> u64 {
    AtomicU64::load(&CHUNKED_TOTAL, Ordering::Relaxed)
}
24
/// How many chunks of an over-long text are kept for embedding.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChunkMode {
    /// Keep only the first chunk, so every input yields exactly one embedding.
    #[default]
    Single,
    /// Keep every chunk; merging the resulting embeddings is left to the caller.
    Multi,
}

impl ChunkMode {
    /// Parses a mode name, ignoring surrounding whitespace and letter case.
    ///
    /// Only `"multi"` selects [`ChunkMode::Multi`]; anything else, including
    /// empty or unrecognised input, falls back to [`ChunkMode::Single`].
    pub fn from_str(s: &str) -> Self {
        let normalized = s.trim().to_lowercase();
        if normalized == "multi" {
            Self::Multi
        } else {
            Self::Single
        }
    }
}
42
43/// Chunk `text` into pieces where each piece is ≤ `max_tokens` tokens.
44/// Strategy (greedy, in priority order):
45///   1. Split on blank lines (paragraphs)
46///   2. Split on sentence boundaries (`. `, `! `, `? `)
47///   3. Hard-split on character boundary
48///
49/// Returns a `Vec<String>`. The caller applies `ChunkMode`:
50/// - `Single`: take `chunks[0]`
51/// - `Multi`: use all chunks
52pub fn chunk(text: &str, max_tokens: usize) -> Vec<String> {
53    let max_bytes = max_tokens * BYTES_PER_TOKEN;
54    if text.len() <= max_bytes {
55        return vec![text.to_string()];
56    }
57    CHUNKED_TOTAL.fetch_add(1, Ordering::Relaxed);
58    split_into_chunks(text, max_bytes)
59}
60
61/// Apply chunk mode to pre-chunked `Vec<String>`.
62/// Single → first element only. Multi → all.
63pub fn apply_mode(chunks: Vec<String>, mode: ChunkMode) -> Vec<String> {
64    match mode {
65        ChunkMode::Single => chunks.into_iter().take(1).collect(),
66        ChunkMode::Multi => chunks,
67    }
68}
69
70fn split_into_chunks(text: &str, max_bytes: usize) -> Vec<String> {
71    let mut chunks = Vec::new();
72    let mut current = String::new();
73
74    for paragraph in split_paragraphs(text) {
75        if paragraph.is_empty() {
76            continue;
77        }
78        if current.len() + paragraph.len() <= max_bytes {
79            if !current.is_empty() {
80                current.push('\n');
81            }
82            current.push_str(&paragraph);
83        } else if paragraph.len() <= max_bytes {
84            if !current.is_empty() {
85                chunks.push(std::mem::take(&mut current));
86            }
87            current = paragraph;
88        } else {
89            // Paragraph itself is too large — split by sentence
90            if !current.is_empty() {
91                chunks.push(std::mem::take(&mut current));
92            }
93            for sentence in split_sentences(&paragraph) {
94                if current.len() + sentence.len() <= max_bytes {
95                    if !current.is_empty() {
96                        current.push(' ');
97                    }
98                    current.push_str(&sentence);
99                } else if sentence.len() <= max_bytes {
100                    if !current.is_empty() {
101                        chunks.push(std::mem::take(&mut current));
102                    }
103                    current = sentence;
104                } else {
105                    // Sentence too large — hard split by bytes
106                    if !current.is_empty() {
107                        chunks.push(std::mem::take(&mut current));
108                    }
109                    chunks.extend(hard_split(&sentence, max_bytes));
110                }
111            }
112        }
113    }
114
115    if !current.is_empty() {
116        chunks.push(current);
117    }
118
119    if chunks.is_empty() {
120        chunks.push(String::new());
121    }
122
123    chunks
124}
125
/// Breaks `text` on blank lines (`"\n\n"`), trimming each paragraph and
/// discarding any that end up empty.
fn split_paragraphs(text: &str) -> Vec<String> {
    let mut paragraphs = Vec::new();
    for raw in text.split("\n\n") {
        let trimmed = raw.trim();
        if !trimmed.is_empty() {
            paragraphs.push(trimmed.to_owned());
        }
    }
    paragraphs
}
132
/// Splits `text` into sentences.
///
/// A break happens after `.`, `!` or `?` when the next character is a single
/// space. That one space is consumed, and each sentence is trimmed. Trailing
/// text after the last break becomes a final sentence if it is not blank. When
/// nothing non-blank is found, `text` itself is returned as the sole element.
fn split_sentences(text: &str) -> Vec<String> {
    let mut sentences = Vec::new();
    let mut buf = String::new();
    let mut chars = text.chars().peekable();

    while let Some(ch) = chars.next() {
        buf.push(ch);
        let ends_sentence = matches!(ch, '.' | '!' | '?') && chars.peek() == Some(&' ');
        if ends_sentence {
            chars.next(); // drop the space that follows the punctuation
            sentences.push(buf.trim().to_string());
            buf.clear();
        }
    }

    let tail = buf.trim();
    if !tail.is_empty() {
        sentences.push(tail.to_string());
    }
    if sentences.is_empty() {
        sentences.push(text.to_string());
    }
    sentences
}
160
/// Cuts `text` into consecutive pieces of at most `max_bytes` bytes, never
/// splitting inside a UTF-8 character.
///
/// The budget may be narrower than the character at a cut position, which
/// includes `max_bytes == 0`. In that case the whole character is emitted as
/// its own (oversize) piece so the loop always makes progress. The previous
/// "advance one byte" fallback sliced through multibyte characters and
/// panicked.
///
/// Concatenating the result reproduces `text`. An empty `text` yields an
/// empty `Vec`.
fn hard_split(text: &str, max_bytes: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut rest = text;
    while !rest.is_empty() {
        // Largest prefix within budget, walked back to a char boundary.
        let mut end = max_bytes.min(rest.len());
        while end > 0 && !rest.is_char_boundary(end) {
            end -= 1;
        }
        if end == 0 {
            // Budget is narrower than the leading character: take that whole
            // character rather than slicing through it.
            end = rest.chars().next().map_or(rest.len(), char::len_utf8);
        }
        let (head, tail) = rest.split_at(end);
        chunks.push(head.to_string());
        rest = tail;
    }
    chunks
}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn short_text_not_chunked() {
        let text = "hello world";
        let chunks = chunk(text, 8192);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn empty_text_yields_single_empty_chunk() {
        assert_eq!(chunk("", 8192), vec![String::new()]);
    }

    #[test]
    fn long_text_chunked_single_mode() {
        // Budget: 8192 tokens * 4 bytes = 32768 bytes; the input is 50000 bytes.
        let long_text = "word ".repeat(10_000);
        let chunks = chunk(&long_text, 8192);
        assert!(chunks.len() > 1, "long text should produce multiple chunks");
        let first = apply_mode(chunks, ChunkMode::Single);
        assert_eq!(first.len(), 1);
        assert!(
            first[0].len() <= 8192 * BYTES_PER_TOKEN,
            "first chunk exceeds budget: {}",
            first[0].len()
        );
    }

    #[test]
    fn long_text_multi_mode_returns_all() {
        let long_text = "word ".repeat(10_000);
        let chunks = chunk(&long_text, 8192);
        let all = apply_mode(chunks.clone(), ChunkMode::Multi);
        assert_eq!(all, chunks);
    }

    #[test]
    fn chunking_increments_counter() {
        let before = chunked_total();
        let _ = chunk(&"word ".repeat(10_000), 8192);
        // Other tests may also bump the counter concurrently; it only grows.
        assert!(chunked_total() > before);
    }

    #[test]
    fn paragraph_split_preference() {
        // 8 tokens = 32-byte budget. The first paragraph (34 bytes) cannot fit
        // on its own; the second (32 bytes) fits exactly and must stay intact.
        let text = "First paragraph with some content.\n\nSecond paragraph with more text.";
        let chunks = chunk(text, 8);
        assert!(chunks.len() >= 2);
        for c in &chunks {
            assert!(c.len() <= 8 * BYTES_PER_TOKEN, "chunk too large: {}", c.len());
        }
        assert_eq!(
            chunks.last().map(String::as_str),
            Some("Second paragraph with more text.")
        );
    }

    #[test]
    fn chunk_mode_from_str() {
        assert_eq!(ChunkMode::from_str("multi"), ChunkMode::Multi);
        assert_eq!(ChunkMode::from_str("  Multi "), ChunkMode::Multi);
        assert_eq!(ChunkMode::from_str("single"), ChunkMode::Single);
        assert_eq!(ChunkMode::from_str("unknown"), ChunkMode::Single);
        assert_eq!(ChunkMode::default(), ChunkMode::Single);
    }

    #[test]
    fn hard_split_handles_multibyte_utf8() {
        // "café " is 6 bytes ("é" is 2), so some 32-byte cuts land mid-"é" and
        // must walk back to a character boundary.
        let text = "café ".repeat(2000);
        let chunks = hard_split(&text, 32);
        assert!(chunks.iter().all(|c| !c.is_empty() && c.len() <= 32));
        assert_eq!(chunks.concat(), text);
    }
}