1use crate::chunking::traits::{ChunkMetadata, Chunker};
7use crate::chunking::{DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP};
8use crate::core::Chunk;
9use crate::error::Result;
10use regex::Regex;
11use std::ops::Range;
12use std::sync::OnceLock;
13
/// Chunker that splits source code at structural boundaries.
///
/// The two fields are the fallback settings used when a call does not
/// supply its own values through chunk metadata.
#[derive(Debug, Clone)]
pub struct CodeChunker {
    /// Fallback chunk size, measured in bytes of the input text.
    chunk_size: usize,
    /// Fallback overlap between consecutive chunks, measured in bytes.
    overlap: usize,
}
58
59impl Default for CodeChunker {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
65impl CodeChunker {
66 #[must_use]
68 pub const fn new() -> Self {
69 Self {
70 chunk_size: DEFAULT_CHUNK_SIZE,
71 overlap: DEFAULT_OVERLAP,
72 }
73 }
74
75 #[must_use]
77 pub const fn with_size(chunk_size: usize) -> Self {
78 Self {
79 chunk_size,
80 overlap: DEFAULT_OVERLAP,
81 }
82 }
83
84 #[must_use]
86 pub const fn with_size_and_overlap(chunk_size: usize, overlap: usize) -> Self {
87 Self {
88 chunk_size,
89 overlap,
90 }
91 }
92
93 fn detect_language(metadata: Option<&ChunkMetadata>) -> Language {
95 let ext = metadata
96 .and_then(|m| {
97 m.content_type
98 .as_deref()
99 .or_else(|| m.source.as_deref().and_then(|s| s.rsplit('.').next()))
100 })
101 .unwrap_or("");
102
103 Language::from_extension(ext)
104 }
105
106 #[allow(clippy::unused_self)]
108 fn find_boundaries(&self, text: &str, lang: Language) -> Vec<usize> {
109 let patterns = lang.boundary_patterns();
110 let mut boundaries = Vec::new();
111
112 for pattern in patterns {
113 let re = pattern.regex();
114 for m in re.find_iter(text) {
115 let line_start = text[..m.start()].rfind('\n').map_or(0, |pos| pos + 1);
117 if !boundaries.contains(&line_start) {
118 boundaries.push(line_start);
119 }
120 }
121 }
122
123 boundaries.sort_unstable();
124 boundaries
125 }
126
127 fn chunk_at_boundaries(
129 &self,
130 buffer_id: i64,
131 text: &str,
132 boundaries: &[usize],
133 chunk_size: usize,
134 overlap: usize,
135 ) -> Vec<Chunk> {
136 let mut chunks = Vec::new();
137 let mut chunk_start = 0;
138 let mut chunk_index = 0;
139
140 while chunk_start < text.len() {
141 let ideal_end = (chunk_start + chunk_size).min(text.len());
143
144 let chunk_end = self.find_best_boundary(text, chunk_start, ideal_end, boundaries);
146
147 let content = &text[chunk_start..chunk_end];
149
150 if !content.trim().is_empty() {
151 chunks.push(Chunk::new(
152 buffer_id,
153 content.to_string(),
154 Range {
155 start: chunk_start,
156 end: chunk_end,
157 },
158 chunk_index,
159 ));
160 chunk_index += 1;
161 }
162
163 if chunk_end >= text.len() {
165 break;
166 }
167
168 let next_start = if overlap > 0 {
170 self.find_overlap_start(text, chunk_end, overlap, boundaries)
171 } else {
172 chunk_end
173 };
174
175 chunk_start = next_start;
176 }
177
178 chunks
179 }
180
181 fn find_best_boundary(
183 &self,
184 text: &str,
185 start: usize,
186 ideal_end: usize,
187 boundaries: &[usize],
188 ) -> usize {
189 if ideal_end >= text.len() {
191 return text.len();
192 }
193
194 let search_start = start + (ideal_end - start) / 2; let search_end = (ideal_end + self.chunk_size / 4).min(text.len());
197
198 let candidates: Vec<usize> = boundaries
200 .iter()
201 .copied()
202 .filter(|&b| b > search_start && b <= search_end)
203 .collect();
204
205 #[allow(clippy::cast_possible_wrap)]
207 if let Some(&boundary) = candidates
208 .iter()
209 .min_by_key(|&&b| (b as i64 - ideal_end as i64).abs())
210 {
211 return boundary;
212 }
213
214 if let Some(newline) = text[search_start..ideal_end].rfind('\n') {
216 return search_start + newline + 1;
217 }
218
219 ideal_end
220 }
221
222 #[allow(clippy::unused_self)]
224 fn find_overlap_start(
225 &self,
226 text: &str,
227 current_end: usize,
228 overlap: usize,
229 boundaries: &[usize],
230 ) -> usize {
231 let target = current_end.saturating_sub(overlap);
232
233 if let Some(&boundary) = boundaries
235 .iter()
236 .rev()
237 .find(|&&b| b <= target && b < current_end)
238 {
239 return boundary;
240 }
241
242 if let Some(newline) = text[..target.min(text.len())].rfind('\n') {
244 return newline + 1;
245 }
246
247 target.min(current_end)
248 }
249}
250
251impl Chunker for CodeChunker {
252 fn chunk(
253 &self,
254 buffer_id: i64,
255 text: &str,
256 metadata: Option<&ChunkMetadata>,
257 ) -> Result<Vec<Chunk>> {
258 self.validate(metadata)?;
259
260 if text.is_empty() {
261 return Ok(vec![]);
262 }
263
264 let chunk_size = metadata.map_or(self.chunk_size, |m| {
265 if m.chunk_size > 0 {
266 m.chunk_size
267 } else {
268 self.chunk_size
269 }
270 });
271 let overlap = metadata.map_or(self.overlap, |m| m.overlap);
272
273 let lang = Self::detect_language(metadata);
275
276 let boundaries = self.find_boundaries(text, lang);
278
279 Ok(self.chunk_at_boundaries(buffer_id, text, &boundaries, chunk_size, overlap))
281 }
282
283 fn name(&self) -> &'static str {
284 "code"
285 }
286
287 fn description(&self) -> &'static str {
288 "Code-aware chunking at function/class boundaries"
289 }
290}
291
/// Programming languages for which dedicated boundary patterns exist.
///
/// `Unknown` is the fallback for any extension that is not recognised.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    /// Rust source.
    Rust,
    /// Python source.
    Python,
    /// JavaScript source.
    JavaScript,
    /// TypeScript source.
    TypeScript,
    /// Go source.
    Go,
    /// Java source.
    Java,
    /// C source or header.
    C,
    /// C++ source or header.
    Cpp,
    /// Ruby source.
    Ruby,
    /// PHP source.
    Php,
    /// Anything that matched no known extension.
    Unknown,
}
307
308impl Language {
309 fn from_extension(ext: &str) -> Self {
311 match ext.to_lowercase().as_str() {
312 "rs" => Self::Rust,
313 "py" | "pyw" | "pyi" => Self::Python,
314 "js" | "mjs" | "cjs" | "jsx" => Self::JavaScript,
315 "ts" | "tsx" | "mts" | "cts" => Self::TypeScript,
316 "go" => Self::Go,
317 "java" => Self::Java,
318 "c" | "h" => Self::C,
319 "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Self::Cpp,
320 "rb" | "rake" | "gemspec" => Self::Ruby,
321 "php" | "phtml" => Self::Php,
322 _ => Self::Unknown,
323 }
324 }
325
326 fn boundary_patterns(self) -> Vec<BoundaryPattern> {
328 match self {
329 Self::Rust => vec![
330 BoundaryPattern::RustFn,
331 BoundaryPattern::RustImpl,
332 BoundaryPattern::RustStruct,
333 BoundaryPattern::RustEnum,
334 BoundaryPattern::RustTrait,
335 BoundaryPattern::RustMod,
336 ],
337 Self::Python => vec![
338 BoundaryPattern::PythonDef,
339 BoundaryPattern::PythonClass,
340 BoundaryPattern::PythonAsync,
341 ],
342 Self::JavaScript | Self::TypeScript => vec![
343 BoundaryPattern::JsFunction,
344 BoundaryPattern::JsClass,
345 BoundaryPattern::JsArrowNamed,
346 BoundaryPattern::JsMethod,
347 ],
348 Self::Go => vec![BoundaryPattern::GoFunc, BoundaryPattern::GoType],
349 Self::Java => vec![
350 BoundaryPattern::JavaClass,
351 BoundaryPattern::JavaMethod,
352 BoundaryPattern::JavaInterface,
353 ],
354 Self::C | Self::Cpp => vec![
355 BoundaryPattern::CFunction,
356 BoundaryPattern::CppClass,
357 BoundaryPattern::CppNamespace,
358 ],
359 Self::Ruby => vec![
360 BoundaryPattern::RubyDef,
361 BoundaryPattern::RubyClass,
362 BoundaryPattern::RubyModule,
363 ],
364 Self::Php => vec![BoundaryPattern::PhpFunction, BoundaryPattern::PhpClass],
365 Self::Unknown => vec![BoundaryPattern::GenericFunction],
366 }
367 }
368}
369
/// Identifies one kind of structural line (function, class, module, ...)
/// that may serve as a chunk boundary.
#[derive(Debug, Clone, Copy)]
enum BoundaryPattern {
    // Rust
    /// `fn` item.
    RustFn,
    /// `impl` block.
    RustImpl,
    /// `struct` definition.
    RustStruct,
    /// `enum` definition.
    RustEnum,
    /// `trait` definition.
    RustTrait,
    /// `mod` declaration.
    RustMod,

    // Python
    /// `def` statement.
    PythonDef,
    /// `class` statement.
    PythonClass,
    /// `async def` statement.
    PythonAsync,

    // JavaScript / TypeScript
    /// `function` declaration.
    JsFunction,
    /// `class` declaration.
    JsClass,
    /// Arrow function bound to a named `const`/`let`/`var`.
    JsArrowNamed,
    /// Method-shaped line: a name, a parameter list and an opening brace.
    JsMethod,

    // Go
    /// `func` declaration.
    GoFunc,
    /// `type ... struct` or `type ... interface` declaration.
    GoType,

    // Java
    /// `class` declaration.
    JavaClass,
    /// Method declaration that carries a visibility modifier.
    JavaMethod,
    /// `interface` declaration.
    JavaInterface,

    // C / C++
    /// Function definition.
    CFunction,
    /// `class` or `struct` definition.
    CppClass,
    /// `namespace` declaration.
    CppNamespace,

    // Ruby
    /// `def` statement.
    RubyDef,
    /// `class` statement.
    RubyClass,
    /// `module` statement.
    RubyModule,

    // PHP
    /// `function` declaration.
    PhpFunction,
    /// `class` declaration.
    PhpClass,

    // Fallback
    /// Common function keywords, used for unrecognised languages.
    GenericFunction,
}
418
419impl BoundaryPattern {
420 fn regex(self) -> &'static Regex {
422 macro_rules! static_regex {
423 ($name:ident, $pattern:expr) => {{
424 static $name: OnceLock<Regex> = OnceLock::new();
425 $name.get_or_init(|| Regex::new($pattern).expect("valid regex"))
426 }};
427 }
428
429 match self {
430 Self::RustFn => static_regex!(
432 RUST_FN,
433 r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?(async\s+)?(unsafe\s+)?(extern\s+\S+\s+)?fn\s+\w+"
434 ),
435 Self::RustImpl => static_regex!(RUST_IMPL, r"(?m)^[ \t]*(unsafe\s+)?impl(<[^>]*>)?\s+"),
436 Self::RustStruct => static_regex!(
437 RUST_STRUCT,
438 r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?struct\s+\w+"
439 ),
440 Self::RustEnum => {
441 static_regex!(RUST_ENUM, r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?enum\s+\w+")
442 }
443 Self::RustTrait => static_regex!(
444 RUST_TRAIT,
445 r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?(unsafe\s+)?trait\s+\w+"
446 ),
447 Self::RustMod => {
448 static_regex!(RUST_MOD, r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?mod\s+\w+")
449 }
450
451 Self::PythonDef => static_regex!(PYTHON_DEF, r"(?m)^[ \t]*def\s+\w+"),
453 Self::PythonClass => static_regex!(PYTHON_CLASS, r"(?m)^[ \t]*class\s+\w+"),
454 Self::PythonAsync => static_regex!(PYTHON_ASYNC, r"(?m)^[ \t]*async\s+def\s+\w+"),
455
456 Self::JsFunction => static_regex!(
458 JS_FUNCTION,
459 r"(?m)^[ \t]*(export\s+)?(async\s+)?function\s*\*?\s*\w+"
460 ),
461 Self::JsClass => static_regex!(
462 JS_CLASS,
463 r"(?m)^[ \t]*(export\s+)?(abstract\s+)?class\s+\w+"
464 ),
465 Self::JsArrowNamed => static_regex!(
466 JS_ARROW,
467 r"(?m)^[ \t]*(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?\([^)]*\)\s*=>"
468 ),
469 Self::JsMethod => static_regex!(
470 JS_METHOD,
471 r"(?m)^[ \t]*(static\s+)?(async\s+)?(get\s+|set\s+)?\w+\s*\([^)]*\)\s*\{"
472 ),
473
474 Self::GoFunc => static_regex!(GO_FUNC, r"(?m)^func\s+(\([^)]+\)\s*)?\w+"),
476 Self::GoType => static_regex!(GO_TYPE, r"(?m)^type\s+\w+\s+(struct|interface)"),
477
478 Self::JavaClass => static_regex!(
480 JAVA_CLASS,
481 r"(?m)^[ \t]*(public|private|protected)?\s*(abstract\s+)?(final\s+)?class\s+\w+"
482 ),
483 Self::JavaMethod => static_regex!(
484 JAVA_METHOD,
485 r"(?m)^[ \t]*(public|private|protected)\s+(static\s+)?(\w+\s+)+\w+\s*\([^)]*\)\s*(\{|throws)"
486 ),
487 Self::JavaInterface => {
488 static_regex!(JAVA_INTERFACE, r"(?m)^[ \t]*(public\s+)?interface\s+\w+")
489 }
490
491 Self::CFunction => static_regex!(
493 C_FUNCTION,
494 r"(?m)^[ \t]*(\w+\s+)+\**\s*\w+\s*\([^)]*\)\s*\{"
495 ),
496 Self::CppClass => static_regex!(
497 CPP_CLASS,
498 r"(?m)^[ \t]*(template\s*<[^>]*>\s*)?(class|struct)\s+\w+"
499 ),
500 Self::CppNamespace => static_regex!(CPP_NAMESPACE, r"(?m)^[ \t]*namespace\s+\w+"),
501
502 Self::RubyDef => static_regex!(RUBY_DEF, r"(?m)^[ \t]*def\s+\w+"),
504 Self::RubyClass => static_regex!(RUBY_CLASS, r"(?m)^[ \t]*class\s+\w+"),
505 Self::RubyModule => static_regex!(RUBY_MODULE, r"(?m)^[ \t]*module\s+\w+"),
506
507 Self::PhpFunction => static_regex!(
509 PHP_FUNCTION,
510 r"(?m)^[ \t]*(public|private|protected)?\s*(static\s+)?function\s+\w+"
511 ),
512 Self::PhpClass => {
513 static_regex!(PHP_CLASS, r"(?m)^[ \t]*(abstract\s+|final\s+)?class\s+\w+")
514 }
515
516 Self::GenericFunction => static_regex!(
518 GENERIC_FUNCTION,
519 r"(?m)^[ \t]*(function|def|fn|func|sub|proc)\s+\w+"
520 ),
521 }
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` reports the "code" strategy name and uses the default chunk size.
    #[test]
    fn test_code_chunker_new() {
        let subject = CodeChunker::new();
        assert_eq!(subject.name(), "code");
        assert_eq!(subject.chunk_size, DEFAULT_CHUNK_SIZE);
    }

    /// `with_size` overrides only the chunk size.
    #[test]
    fn test_code_chunker_with_size() {
        let subject = CodeChunker::with_size(1000);
        assert_eq!(subject.chunk_size, 1000);
        assert_eq!(subject.overlap, DEFAULT_OVERLAP);
    }

    /// A content type of "rs" selects Rust.
    #[test]
    fn test_detect_language_rust() {
        let metadata = ChunkMetadata::new().content_type("rs");
        assert_eq!(
            CodeChunker::detect_language(Some(&metadata)),
            Language::Rust
        );
    }

    /// Without a content type the extension of the source path is used.
    #[test]
    fn test_detect_language_from_source() {
        let metadata = ChunkMetadata::new().source("src/main.py");
        assert_eq!(
            CodeChunker::detect_language(Some(&metadata)),
            Language::Python
        );
    }

    /// Unrecognised content types fall back to `Unknown`.
    #[test]
    fn test_detect_language_unknown() {
        let metadata = ChunkMetadata::new().content_type("xyz");
        assert_eq!(
            CodeChunker::detect_language(Some(&metadata)),
            Language::Unknown
        );
    }

    /// Rust input yields at least one chunk and no blank chunks.
    #[test]
    fn test_chunk_rust_code() {
        let source = r#"
fn main() {
    println!("Hello");
}

fn helper() {
    println!("Helper");
}

pub fn public_fn() {
    println!("Public");
}
"#;

        let metadata = ChunkMetadata::with_size(200).content_type("rs");
        let produced = CodeChunker::with_size(200)
            .chunk(1, source, Some(&metadata))
            .unwrap();

        assert!(!produced.is_empty());
        for piece in &produced {
            assert!(!piece.content.trim().is_empty());
        }
    }

    /// Python input yields at least one chunk.
    #[test]
    fn test_chunk_python_code() {
        let source = r#"
def main():
    print("Hello")

class MyClass:
    def method(self):
        pass

async def async_func():
    await something()
"#;

        let metadata = ChunkMetadata::with_size(150).content_type("py");
        let produced = CodeChunker::with_size(150)
            .chunk(1, source, Some(&metadata))
            .unwrap();

        assert!(!produced.is_empty());
    }

    /// JavaScript input yields at least one chunk.
    #[test]
    fn test_chunk_javascript_code() {
        let source = r#"
function greet(name) {
    console.log("Hello " + name);
}

class Person {
    constructor(name) {
        this.name = name;
    }
}

const arrow = (x) => x * 2;

export async function fetchData() {
    return await fetch("/api");
}
"#;

        let metadata = ChunkMetadata::with_size(200).content_type("js");
        let produced = CodeChunker::with_size(200)
            .chunk(1, source, Some(&metadata))
            .unwrap();

        assert!(!produced.is_empty());
    }

    /// Empty input produces no chunks.
    #[test]
    fn test_chunk_empty_text() {
        let produced = CodeChunker::new().chunk(1, "", None).unwrap();
        assert!(produced.is_empty());
    }

    /// Text without recognisable structure is still chunked.
    #[test]
    fn test_chunk_unknown_language() {
        let plain = "some random text without code structure";
        let produced = CodeChunker::with_size(100).chunk(1, plain, None).unwrap();
        assert!(!produced.is_empty());
    }

    /// Rust has boundary patterns and the `fn` pattern matches a public function.
    #[test]
    fn test_boundary_patterns_rust() {
        assert!(!Language::Rust.boundary_patterns().is_empty());

        let sample = "pub fn my_function() {}";
        assert!(BoundaryPattern::RustFn.regex().is_match(sample));
    }

    /// The Python `def` and `class` patterns match their statements.
    #[test]
    fn test_boundary_patterns_python() {
        let function_line = "def my_function():";
        assert!(BoundaryPattern::PythonDef.regex().is_match(function_line));

        let class_line = "class MyClass:";
        assert!(BoundaryPattern::PythonClass.regex().is_match(class_line));
    }

    /// One representative extension per language, plus an unknown one.
    #[test]
    fn test_language_extensions() {
        let expectations = [
            ("rs", Language::Rust),
            ("py", Language::Python),
            ("js", Language::JavaScript),
            ("ts", Language::TypeScript),
            ("go", Language::Go),
            ("java", Language::Java),
            ("c", Language::C),
            ("cpp", Language::Cpp),
            ("rb", Language::Ruby),
            ("php", Language::Php),
            ("unknown", Language::Unknown),
        ];

        for (extension, expected) in expectations {
            assert_eq!(Language::from_extension(extension), expected);
        }
    }

    /// The description is not empty.
    #[test]
    fn test_chunker_description() {
        let subject = CodeChunker::new();
        assert!(!subject.description().is_empty());
    }
}