1use serde::{Deserialize, Serialize};
23
24use crate::advanced_chunker::{ChunkStrategy, RecursiveCharSplitter};
25use crate::chunker::Chunk;
26use crate::error::RagError;
27
/// Source languages the chunker can split at structural boundaries.
///
/// `Plain` is the default and is the strategy used for unrecognized input.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Language {
    /// Rust source; split at top-level item markers (`fn`, `impl`, `struct`, ...).
    Rust,
    /// Python source; split at `class` / `def` / `async def` lines.
    Python,
    /// JSON documents; split per top-level array element or object entry.
    Json,
    /// Anything else; split with the fixed-window recursive character splitter.
    #[default]
    Plain,
}
45
46impl Language {
47 pub fn from_extension(ext: &str) -> Self {
50 match ext.to_ascii_lowercase().as_str() {
51 "rs" => Self::Rust,
52 "py" => Self::Python,
53 "json" => Self::Json,
54 _ => Self::Plain,
55 }
56 }
57}
58
/// Language-aware code chunker: splits source text at structural boundaries
/// for known languages and falls back to fixed-window character splitting.
pub struct CodeChunker {
    /// Strategy selector used by `chunk` to pick a splitter.
    language: Language,
    /// Window passed to `RecursiveCharSplitter::new` for the plain fallback
    /// (floored at 64 by `with_fallback_window`).
    fallback_window: usize,
    /// Chunks whose trimmed text has fewer chars than this are dropped.
    min_chunk_chars: usize,
}
69
70impl Default for CodeChunker {
71 fn default() -> Self {
72 Self {
73 language: Language::default(),
74 fallback_window: 1024,
75 min_chunk_chars: 16,
76 }
77 }
78}
79
80impl CodeChunker {
81 pub fn new(language: Language) -> Self {
83 Self {
84 language,
85 ..Self::default()
86 }
87 }
88
89 #[must_use]
91 pub fn with_fallback_window(mut self, window: usize) -> Self {
92 self.fallback_window = window.max(64);
93 self
94 }
95
96 #[must_use]
98 pub fn with_min_chunk_chars(mut self, min: usize) -> Self {
99 self.min_chunk_chars = min;
100 self
101 }
102
103 pub fn language(&self) -> Language {
105 self.language
106 }
107
108 pub fn chunk(&self, text: &str, doc_id: usize) -> Result<Vec<Chunk>, RagError> {
110 if text.trim().is_empty() {
111 return Ok(Vec::new());
112 }
113
114 let raw_chunks: Vec<(usize, String)> = match self.language {
115 Language::Rust => split_rust(text),
116 Language::Python => split_python(text),
117 Language::Json => split_json(text).unwrap_or_else(|| self.split_plain(text)),
118 Language::Plain => self.split_plain(text),
119 };
120
121 let mut out = Vec::with_capacity(raw_chunks.len());
122 for (char_offset, body) in raw_chunks {
123 let trimmed = body.trim();
124 if trimmed.chars().count() < self.min_chunk_chars {
125 continue;
126 }
127 let chunk_idx = out.len();
128 out.push(Chunk::new(
129 trimmed.to_string(),
130 doc_id,
131 chunk_idx,
132 char_offset,
133 ));
134 }
135
136 if out.is_empty() {
139 out.push(Chunk::new(text.trim().to_string(), doc_id, 0, 0));
140 }
141
142 Ok(out)
143 }
144
145 fn split_plain(&self, text: &str) -> Vec<(usize, String)> {
146 let splitter = RecursiveCharSplitter::new(self.fallback_window);
147 splitter
148 .chunk(text)
149 .into_iter()
150 .map(|rc| (rc.char_start, rc.text))
151 .collect()
152 }
153}
154
155const RUST_MARKERS: &[&str] = &[
161 "\nfn ",
162 "\npub fn ",
163 "\nimpl ",
164 "\nstruct ",
165 "\nenum ",
166 "\nmod ",
167 "\npub mod ",
168 "\ntrait ",
169 "\npub struct ",
170 "\npub enum ",
171 "\npub trait ",
172];
173
174fn split_rust(text: &str) -> Vec<(usize, String)> {
175 split_by_line_prefixes(text, RUST_MARKERS)
176}
177
178const PYTHON_MARKERS: &[&str] = &["\nclass ", "\ndef ", "\nasync def "];
180
181fn split_python(text: &str) -> Vec<(usize, String)> {
182 split_by_line_prefixes(text, PYTHON_MARKERS)
183}
184
/// Splits `text` into segments starting wherever a marker appears at the
/// beginning of a line.
///
/// Each marker embeds its own leading `'\n'`, anchoring matches to line
/// starts; a segment begins at the character right after that newline. The
/// first segment always starts at offset 0, and whitespace-only segments are
/// dropped.
///
/// Fix: offsets are now reported in *characters* to match `split_plain`
/// (which forwards `RecursiveCharSplitter`'s `char_start`); previously this
/// function leaked byte indices into the shared `char_offset` slot.
fn split_by_line_prefixes(text: &str, markers: &[&str]) -> Vec<(usize, String)> {
    // Byte positions where a new segment begins (one past each matched '\n').
    let mut boundaries: Vec<usize> = Vec::new();
    for marker in markers {
        let mut search_from = 0usize;
        while let Some(rel) = text[search_from..].find(marker) {
            let newline_at = search_from + rel;
            boundaries.push(newline_at + 1);
            // Resume after the matched marker; the next hit needs its own '\n'.
            search_from = newline_at + marker.len();
        }
    }
    boundaries.sort_unstable();
    boundaries.dedup();

    // Prepend 0 so any text before the first marker is kept as a segment.
    let mut starts = Vec::with_capacity(boundaries.len() + 1);
    starts.push(0usize);
    starts.extend(boundaries);
    starts.dedup();

    let mut out = Vec::with_capacity(starts.len());
    for (i, &begin) in starts.iter().enumerate() {
        let end = starts.get(i + 1).copied().unwrap_or(text.len());
        // Slicing is safe: every boundary sits right after an ASCII '\n'.
        let body = &text[begin..end];
        if !body.trim().is_empty() {
            // Convert the byte index to a char offset before emitting.
            let char_begin = text[..begin].chars().count();
            out.push((char_begin, body.to_string()));
        }
    }
    out
}
220
221fn split_json(text: &str) -> Option<Vec<(usize, String)>> {
222 let value: serde_json::Value = serde_json::from_str(text).ok()?;
223 match value {
224 serde_json::Value::Array(items) => {
225 let mut out = Vec::with_capacity(items.len());
226 for (idx, item) in items.into_iter().enumerate() {
227 if let Ok(text) = serde_json::to_string_pretty(&item) {
228 out.push((idx, text));
229 }
230 }
231 Some(out)
232 }
233 serde_json::Value::Object(obj) => {
234 let mut out = Vec::with_capacity(obj.len());
235 for (idx, (key, value)) in obj.into_iter().enumerate() {
236 let body = match serde_json::to_string_pretty(&value) {
237 Ok(s) => s,
238 Err(_) => continue,
239 };
240 out.push((idx, format!("\"{key}\": {body}")));
241 }
242 Some(out)
243 }
244 _ => None,
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rust_splits_on_fn_markers() {
        // Three functions, each introduced by a "\nfn " marker.
        let source = "\nfn one() {}\nfn two() {}\nfn three() {}\n";
        let chunks = CodeChunker::new(Language::Rust)
            .with_min_chunk_chars(1)
            .chunk(source, 0)
            .expect("chunk");
        assert!(chunks.len() >= 3, "got {} chunks", chunks.len());
    }

    #[test]
    fn python_splits_on_def_and_class() {
        // One class plus two functions.
        let source =
            "\nclass A:\n pass\n\ndef foo():\n return 1\n\ndef bar():\n return 2\n";
        let chunks = CodeChunker::new(Language::Python)
            .with_min_chunk_chars(1)
            .chunk(source, 0)
            .expect("chunk");
        assert!(chunks.len() >= 3, "got {} chunks", chunks.len());
    }

    #[test]
    fn json_array_splits_by_element() {
        // A four-element top-level array yields exactly four chunks.
        let chunks = CodeChunker::new(Language::Json)
            .with_min_chunk_chars(1)
            .chunk("[1, 2, 3, 4]", 0)
            .expect("chunk");
        assert_eq!(chunks.len(), 4);
    }

    #[test]
    fn plain_delegates_to_splitter() {
        // 4096 chars against a 512-char window must produce multiple chunks.
        let text = "a".repeat(4096);
        let chunker = CodeChunker::new(Language::Plain)
            .with_fallback_window(512)
            .with_min_chunk_chars(1);
        let chunks = chunker.chunk(&text, 0).expect("chunk");
        assert!(chunks.len() > 1);
    }
}