// oxibonsai_rag/code_chunker.rs
//! Language-aware chunking for source files.
//!
//! [`CodeChunker`] splits code by language-specific structural markers so
//! that retrieval returns syntactically meaningful units rather than
//! arbitrary character windows.  Supported languages are enumerated by
//! [`Language`]; anything outside that set falls back to
//! [`RecursiveCharSplitter`].
//!
//! # Language splitters
//!
//! - [`Language::Rust`] — splits on top-level item markers (`fn `,
//!   `impl `, `struct `, `enum `, `mod `, `trait `, and their `pub `
//!   variants) that occur at the start of a line (preceded by `\n`).
//! - [`Language::Python`] — splits on `class `, `def `, and `async def `
//!   at the start of a line.
//! - [`Language::Json`] — if the document parses as a JSON array/object,
//!   each depth-1 child becomes a chunk.  Malformed JSON falls back to
//!   [`RecursiveCharSplitter`].
//! - [`Language::Plain`] — delegates to [`RecursiveCharSplitter`] with a
//!   configurable window size.
use serde::{Deserialize, Serialize};

use crate::advanced_chunker::{ChunkStrategy, RecursiveCharSplitter};
use crate::chunker::Chunk;
use crate::error::RagError;

// ─────────────────────────────────────────────────────────────────────────────
// Language enum
// ─────────────────────────────────────────────────────────────────────────────
31
/// Source-code languages recognised by [`CodeChunker`].
///
/// Unknown inputs classify as [`Language::Plain`]; see
/// [`Language::from_extension`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Language {
    /// Rust source — split on top-level item markers.
    Rust,
    /// Python source — split on top-level `class`/`def` markers.
    Python,
    /// JSON document (depth-1 child splitter).
    Json,
    /// Anything else — delegates to the character-recursive fallback.
    #[default]
    Plain,
}
45
46impl Language {
47    /// Classify a file based on its lower-cased extension.  Unknown
48    /// extensions map to [`Language::Plain`].
49    pub fn from_extension(ext: &str) -> Self {
50        match ext.to_ascii_lowercase().as_str() {
51            "rs" => Self::Rust,
52            "py" => Self::Python,
53            "json" => Self::Json,
54            _ => Self::Plain,
55        }
56    }
57}
58
// ─────────────────────────────────────────────────────────────────────────────
// CodeChunker
// ─────────────────────────────────────────────────────────────────────────────
62
/// Language-aware code chunker.
///
/// Construct with [`CodeChunker::new`] and tune via the `with_*`
/// builder methods.
pub struct CodeChunker {
    /// Selects which language-specific splitter `chunk` dispatches to.
    language: Language,
    /// Window size handed to [`RecursiveCharSplitter`] when no
    /// structural splitter applies (builder clamps this to >= 64).
    fallback_window: usize,
    /// Chunks whose trimmed body has fewer than this many characters
    /// are discarded.
    min_chunk_chars: usize,
}
69
70impl Default for CodeChunker {
71    fn default() -> Self {
72        Self {
73            language: Language::default(),
74            fallback_window: 1024,
75            min_chunk_chars: 16,
76        }
77    }
78}
79
80impl CodeChunker {
81    /// Create a chunker for `language` with default parameters.
82    pub fn new(language: Language) -> Self {
83        Self {
84            language,
85            ..Self::default()
86        }
87    }
88
89    /// Configure the fallback [`RecursiveCharSplitter`] window size.
90    #[must_use]
91    pub fn with_fallback_window(mut self, window: usize) -> Self {
92        self.fallback_window = window.max(64);
93        self
94    }
95
96    /// Discard chunks shorter than `min` characters.
97    #[must_use]
98    pub fn with_min_chunk_chars(mut self, min: usize) -> Self {
99        self.min_chunk_chars = min;
100        self
101    }
102
103    /// The language this chunker was constructed with.
104    pub fn language(&self) -> Language {
105        self.language
106    }
107
108    /// Split `text` into chunks appropriate for its language.
109    pub fn chunk(&self, text: &str, doc_id: usize) -> Result<Vec<Chunk>, RagError> {
110        if text.trim().is_empty() {
111            return Ok(Vec::new());
112        }
113
114        let raw_chunks: Vec<(usize, String)> = match self.language {
115            Language::Rust => split_rust(text),
116            Language::Python => split_python(text),
117            Language::Json => split_json(text).unwrap_or_else(|| self.split_plain(text)),
118            Language::Plain => self.split_plain(text),
119        };
120
121        let mut out = Vec::with_capacity(raw_chunks.len());
122        for (char_offset, body) in raw_chunks {
123            let trimmed = body.trim();
124            if trimmed.chars().count() < self.min_chunk_chars {
125                continue;
126            }
127            let chunk_idx = out.len();
128            out.push(Chunk::new(
129                trimmed.to_string(),
130                doc_id,
131                chunk_idx,
132                char_offset,
133            ));
134        }
135
136        // Fallback: if we produced nothing (e.g. a short file with no
137        // structural markers) treat the whole document as one chunk.
138        if out.is_empty() {
139            out.push(Chunk::new(text.trim().to_string(), doc_id, 0, 0));
140        }
141
142        Ok(out)
143    }
144
145    fn split_plain(&self, text: &str) -> Vec<(usize, String)> {
146        let splitter = RecursiveCharSplitter::new(self.fallback_window);
147        splitter
148            .chunk(text)
149            .into_iter()
150            .map(|rc| (rc.char_start, rc.text))
151            .collect()
152    }
153}
154
// ─────────────────────────────────────────────────────────────────────────────
// Language-specific splitters
// ─────────────────────────────────────────────────────────────────────────────
158
/// Markers that signal the start of a new top-level Rust item.
///
/// Each marker begins with `\n`, so only unindented (top-level) forms
/// act as chunk boundaries.  Covers plain items, their `pub ` variants,
/// and the `async`/`unsafe`/`const` function qualifiers, which the
/// previous list missed entirely.
const RUST_MARKERS: &[&str] = &[
    "\nfn ",
    "\npub fn ",
    "\nasync fn ",
    "\npub async fn ",
    "\nunsafe fn ",
    "\npub unsafe fn ",
    "\nconst fn ",
    "\npub const fn ",
    "\nimpl ",
    "\nstruct ",
    "\npub struct ",
    "\nenum ",
    "\npub enum ",
    "\nmod ",
    "\npub mod ",
    "\ntrait ",
    "\npub trait ",
];
173
/// Split Rust source on the top-level item markers in [`RUST_MARKERS`].
/// Offsets in the result are byte offsets into `text`.
fn split_rust(text: &str) -> Vec<(usize, String)> {
    split_by_line_prefixes(text, RUST_MARKERS)
}
177
/// Markers that signal the start of a new top-level Python definition.
/// Each begins with `\n`, so only unindented (module-level) defs match.
// NOTE(review): decorated definitions split at the `def` line, leaving
// the `@decorator` lines in the preceding chunk — confirm acceptable.
const PYTHON_MARKERS: &[&str] = &["\nclass ", "\ndef ", "\nasync def "];

/// Split Python source on the markers in [`PYTHON_MARKERS`].
/// Offsets in the result are byte offsets into `text`.
fn split_python(text: &str) -> Vec<(usize, String)> {
    split_by_line_prefixes(text, PYTHON_MARKERS)
}
184
/// Split `text` wherever one of `markers` occurs at the start of a line.
///
/// A chunk begins at the character just after the `\n` that precedes a
/// marker (so the marker text itself opens the chunk) and extends up to
/// the next boundary.  Everything before the first marker becomes a
/// prologue chunk at offset 0.  Whitespace-only chunks are dropped.
/// Returned offsets are byte offsets into `text`.
fn split_by_line_prefixes(text: &str, markers: &[&str]) -> Vec<(usize, String)> {
    // Collect every chunk boundary as a byte offset.  Offset 0 is always
    // a boundary so the prologue is its own chunk.
    let mut starts = vec![0usize];
    for marker in markers {
        let mut from = 0usize;
        while let Some(rel) = text[from..].find(marker) {
            let newline_pos = from + rel;
            // The chunk starts just past the marker's leading '\n'.
            starts.push(newline_pos + 1);
            // Resume the search after this occurrence.
            from = newline_pos + marker.len();
        }
    }
    starts.sort_unstable();
    starts.dedup();

    let total = text.len();
    starts
        .iter()
        .enumerate()
        .filter_map(|(i, &begin)| {
            let end = starts.get(i + 1).copied().unwrap_or(total);
            let piece = &text[begin..end];
            if piece.trim().is_empty() {
                None
            } else {
                Some((begin, piece.to_string()))
            }
        })
        .collect()
}
220
221fn split_json(text: &str) -> Option<Vec<(usize, String)>> {
222    let value: serde_json::Value = serde_json::from_str(text).ok()?;
223    match value {
224        serde_json::Value::Array(items) => {
225            let mut out = Vec::with_capacity(items.len());
226            for (idx, item) in items.into_iter().enumerate() {
227                if let Ok(text) = serde_json::to_string_pretty(&item) {
228                    out.push((idx, text));
229                }
230            }
231            Some(out)
232        }
233        serde_json::Value::Object(obj) => {
234            let mut out = Vec::with_capacity(obj.len());
235            for (idx, (key, value)) in obj.into_iter().enumerate() {
236                let body = match serde_json::to_string_pretty(&value) {
237                    Ok(s) => s,
238                    Err(_) => continue,
239                };
240                out.push((idx, format!("\"{key}\": {body}")));
241            }
242            Some(out)
243        }
244        // Scalar JSON values aren't splittable — signal failure so we fall
245        // back to the character splitter.
246        _ => None,
247    }
248}
249
// ─────────────────────────────────────────────────────────────────────────────
// Inline tests
// ─────────────────────────────────────────────────────────────────────────────
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Three top-level `fn` items should produce at least three chunks.
    #[test]
    fn rust_splits_on_fn_markers() {
        let src = "\nfn one() {}\nfn two() {}\nfn three() {}\n";
        let chunks = CodeChunker::new(Language::Rust)
            .with_min_chunk_chars(1)
            .chunk(src, 0)
            .expect("chunk");
        assert!(chunks.len() >= 3, "got {} chunks", chunks.len());
    }

    /// One class plus two defs should produce at least three chunks.
    #[test]
    fn python_splits_on_def_and_class() {
        let src =
            "\nclass A:\n    pass\n\ndef foo():\n    return 1\n\ndef bar():\n    return 2\n";
        let chunks = CodeChunker::new(Language::Python)
            .with_min_chunk_chars(1)
            .chunk(src, 0)
            .expect("chunk");
        assert!(chunks.len() >= 3, "got {} chunks", chunks.len());
    }

    /// A four-element JSON array becomes exactly four chunks.
    #[test]
    fn json_array_splits_by_element() {
        let chunks = CodeChunker::new(Language::Json)
            .with_min_chunk_chars(1)
            .chunk("[1, 2, 3, 4]", 0)
            .expect("chunk");
        assert_eq!(chunks.len(), 4);
    }

    /// Plain text larger than the window is split into several chunks.
    #[test]
    fn plain_delegates_to_splitter() {
        let text = "a".repeat(4096);
        let chunks = CodeChunker::new(Language::Plain)
            .with_fallback_window(512)
            .with_min_chunk_chars(1)
            .chunk(&text, 0)
            .expect("chunk");
        assert!(chunks.len() > 1);
    }
}