Skip to main content

rlm_rs/chunking/
code.rs

1//! Code-aware chunking strategy.
2//!
3//! Chunks source code at natural boundaries (functions, classes, methods)
4//! using regex-based pattern matching for multiple languages.
5
6use crate::chunking::traits::{ChunkMetadata, Chunker};
7use crate::chunking::{DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP};
8use crate::core::Chunk;
9use crate::error::Result;
10use regex::Regex;
11use std::ops::Range;
12use std::sync::OnceLock;
13
/// Chunker that prefers to cut source code where a definition begins.
///
/// Definition starts (functions, classes, impls, ...) are located with
/// per-language regular expressions. The language is picked from the
/// metadata's content type or, failing that, from the suffix of its
/// source name. Unrecognised languages are matched against a generic
/// function-keyword pattern, and when no definition start lies near the
/// target size the cut falls back to a line break.
///
/// # Supported Languages
///
/// - Rust (.rs)
/// - Python (.py)
/// - JavaScript (.js, .jsx)
/// - TypeScript (.ts, .tsx)
/// - Go (.go)
/// - Java (.java)
/// - C/C++ (.c, .cpp, .h, .hpp)
/// - Ruby (.rb)
/// - PHP (.php)
///
/// # Examples
///
/// ```
/// use rlm_rs::chunking::{Chunker, CodeChunker, ChunkerMetadata};
///
/// let chunker = CodeChunker::new();
/// let code = r#"
/// fn main() {
///     println!("Hello");
/// }
///
/// fn helper() {
///     println!("Helper");
/// }
/// "#;
///
/// let meta = ChunkerMetadata::new().content_type("rs");
/// let chunks = chunker.chunk(1, code, Some(&meta)).unwrap();
/// assert!(!chunks.is_empty());
/// ```
#[derive(Debug, Clone)]
pub struct CodeChunker {
    /// Size each chunk aims for (the chunking code treats it as a byte
    /// offset into the text).
    chunk_size: usize,
    /// How far the next chunk reaches back into the previous one.
    overlap: usize,
}
58
59impl Default for CodeChunker {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl CodeChunker {
66    /// Creates a new code chunker with default settings.
67    #[must_use]
68    pub const fn new() -> Self {
69        Self {
70            chunk_size: DEFAULT_CHUNK_SIZE,
71            overlap: DEFAULT_OVERLAP,
72        }
73    }
74
75    /// Creates a code chunker with custom chunk size.
76    #[must_use]
77    pub const fn with_size(chunk_size: usize) -> Self {
78        Self {
79            chunk_size,
80            overlap: DEFAULT_OVERLAP,
81        }
82    }
83
84    /// Creates a code chunker with custom size and overlap.
85    #[must_use]
86    pub const fn with_size_and_overlap(chunk_size: usize, overlap: usize) -> Self {
87        Self {
88            chunk_size,
89            overlap,
90        }
91    }
92
93    /// Detects language from file extension or content type.
94    fn detect_language(metadata: Option<&ChunkMetadata>) -> Language {
95        let ext = metadata
96            .and_then(|m| {
97                m.content_type
98                    .as_deref()
99                    .or_else(|| m.source.as_deref().and_then(|s| s.rsplit('.').next()))
100            })
101            .unwrap_or("");
102
103        Language::from_extension(ext)
104    }
105
106    /// Finds code structure boundaries in the text.
107    #[allow(clippy::unused_self)]
108    fn find_boundaries(&self, text: &str, lang: Language) -> Vec<usize> {
109        let patterns = lang.boundary_patterns();
110        let mut boundaries = Vec::new();
111
112        for pattern in patterns {
113            let re = pattern.regex();
114            for m in re.find_iter(text) {
115                // Find the start of the line containing this match
116                let line_start = text[..m.start()].rfind('\n').map_or(0, |pos| pos + 1);
117                if !boundaries.contains(&line_start) {
118                    boundaries.push(line_start);
119                }
120            }
121        }
122
123        boundaries.sort_unstable();
124        boundaries
125    }
126
127    /// Chunks text at code boundaries.
128    fn chunk_at_boundaries(
129        &self,
130        buffer_id: i64,
131        text: &str,
132        boundaries: &[usize],
133        chunk_size: usize,
134        overlap: usize,
135    ) -> Vec<Chunk> {
136        let mut chunks = Vec::new();
137        let mut chunk_start = 0;
138        let mut chunk_index = 0;
139
140        while chunk_start < text.len() {
141            // Find the end of this chunk
142            let ideal_end = (chunk_start + chunk_size).min(text.len());
143
144            // Try to find a boundary near the ideal end
145            let chunk_end = self.find_best_boundary(text, chunk_start, ideal_end, boundaries);
146
147            // Extract content
148            let content = &text[chunk_start..chunk_end];
149
150            if !content.trim().is_empty() {
151                chunks.push(Chunk::new(
152                    buffer_id,
153                    content.to_string(),
154                    Range {
155                        start: chunk_start,
156                        end: chunk_end,
157                    },
158                    chunk_index,
159                ));
160                chunk_index += 1;
161            }
162
163            // Move to next chunk with overlap
164            if chunk_end >= text.len() {
165                break;
166            }
167
168            // Calculate next start with overlap
169            let next_start = if overlap > 0 {
170                self.find_overlap_start(text, chunk_end, overlap, boundaries)
171            } else {
172                chunk_end
173            };
174
175            chunk_start = next_start;
176        }
177
178        chunks
179    }
180
181    /// Finds the best boundary near the ideal end position.
182    fn find_best_boundary(
183        &self,
184        text: &str,
185        start: usize,
186        ideal_end: usize,
187        boundaries: &[usize],
188    ) -> usize {
189        // If we're at the end of text, use that
190        if ideal_end >= text.len() {
191            return text.len();
192        }
193
194        // Look for a code boundary near the ideal end
195        let search_start = start + (ideal_end - start) / 2; // Start from halfway
196        let search_end = (ideal_end + self.chunk_size / 4).min(text.len());
197
198        // Find boundaries in the search range
199        let candidates: Vec<usize> = boundaries
200            .iter()
201            .copied()
202            .filter(|&b| b > search_start && b <= search_end)
203            .collect();
204
205        // Prefer a boundary closer to ideal_end
206        #[allow(clippy::cast_possible_wrap)]
207        if let Some(&boundary) = candidates
208            .iter()
209            .min_by_key(|&&b| (b as i64 - ideal_end as i64).abs())
210        {
211            return boundary;
212        }
213
214        // Fall back to line boundary
215        if let Some(newline) = text[search_start..ideal_end].rfind('\n') {
216            return search_start + newline + 1;
217        }
218
219        ideal_end
220    }
221
222    /// Finds the start position for overlap.
223    #[allow(clippy::unused_self)]
224    fn find_overlap_start(
225        &self,
226        text: &str,
227        current_end: usize,
228        overlap: usize,
229        boundaries: &[usize],
230    ) -> usize {
231        let target = current_end.saturating_sub(overlap);
232
233        // Try to find a boundary at or before the target
234        if let Some(&boundary) = boundaries
235            .iter()
236            .rev()
237            .find(|&&b| b <= target && b < current_end)
238        {
239            return boundary;
240        }
241
242        // Fall back to line boundary
243        if let Some(newline) = text[..target.min(text.len())].rfind('\n') {
244            return newline + 1;
245        }
246
247        target.min(current_end)
248    }
249}
250
251impl Chunker for CodeChunker {
252    fn chunk(
253        &self,
254        buffer_id: i64,
255        text: &str,
256        metadata: Option<&ChunkMetadata>,
257    ) -> Result<Vec<Chunk>> {
258        self.validate(metadata)?;
259
260        if text.is_empty() {
261            return Ok(vec![]);
262        }
263
264        let chunk_size = metadata.map_or(self.chunk_size, |m| {
265            if m.chunk_size > 0 {
266                m.chunk_size
267            } else {
268                self.chunk_size
269            }
270        });
271        let overlap = metadata.map_or(self.overlap, |m| m.overlap);
272
273        // Detect language
274        let lang = Self::detect_language(metadata);
275
276        // Find code structure boundaries
277        let boundaries = self.find_boundaries(text, lang);
278
279        // Chunk at boundaries
280        Ok(self.chunk_at_boundaries(buffer_id, text, &boundaries, chunk_size, overlap))
281    }
282
283    fn name(&self) -> &'static str {
284        "code"
285    }
286
287    fn description(&self) -> &'static str {
288        "Code-aware chunking at function/class boundaries"
289    }
290}
291
/// Languages for which dedicated boundary patterns exist.
///
/// `Unknown` covers every other input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    /// Rust source.
    Rust,
    /// Python source.
    Python,
    /// JavaScript source.
    JavaScript,
    /// TypeScript source.
    TypeScript,
    /// Go source.
    Go,
    /// Java source.
    Java,
    /// C source.
    C,
    /// C++ source.
    Cpp,
    /// Ruby source.
    Ruby,
    /// PHP source.
    Php,
    /// Anything not recognised above.
    Unknown,
}
307
308impl Language {
309    /// Detects language from file extension.
310    fn from_extension(ext: &str) -> Self {
311        match ext.to_lowercase().as_str() {
312            "rs" => Self::Rust,
313            "py" | "pyw" | "pyi" => Self::Python,
314            "js" | "mjs" | "cjs" | "jsx" => Self::JavaScript,
315            "ts" | "tsx" | "mts" | "cts" => Self::TypeScript,
316            "go" => Self::Go,
317            "java" => Self::Java,
318            "c" | "h" => Self::C,
319            "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "hh" => Self::Cpp,
320            "rb" | "rake" | "gemspec" => Self::Ruby,
321            "php" | "phtml" => Self::Php,
322            _ => Self::Unknown,
323        }
324    }
325
326    /// Returns regex patterns for detecting code boundaries.
327    fn boundary_patterns(self) -> Vec<BoundaryPattern> {
328        match self {
329            Self::Rust => vec![
330                BoundaryPattern::RustFn,
331                BoundaryPattern::RustImpl,
332                BoundaryPattern::RustStruct,
333                BoundaryPattern::RustEnum,
334                BoundaryPattern::RustTrait,
335                BoundaryPattern::RustMod,
336            ],
337            Self::Python => vec![
338                BoundaryPattern::PythonDef,
339                BoundaryPattern::PythonClass,
340                BoundaryPattern::PythonAsync,
341            ],
342            Self::JavaScript | Self::TypeScript => vec![
343                BoundaryPattern::JsFunction,
344                BoundaryPattern::JsClass,
345                BoundaryPattern::JsArrowNamed,
346                BoundaryPattern::JsMethod,
347            ],
348            Self::Go => vec![BoundaryPattern::GoFunc, BoundaryPattern::GoType],
349            Self::Java => vec![
350                BoundaryPattern::JavaClass,
351                BoundaryPattern::JavaMethod,
352                BoundaryPattern::JavaInterface,
353            ],
354            Self::C | Self::Cpp => vec![
355                BoundaryPattern::CFunction,
356                BoundaryPattern::CppClass,
357                BoundaryPattern::CppNamespace,
358            ],
359            Self::Ruby => vec![
360                BoundaryPattern::RubyDef,
361                BoundaryPattern::RubyClass,
362                BoundaryPattern::RubyModule,
363            ],
364            Self::Php => vec![BoundaryPattern::PhpFunction, BoundaryPattern::PhpClass],
365            Self::Unknown => vec![BoundaryPattern::GenericFunction],
366        }
367    }
368}
369
/// Identifies one regular expression used to spot the start of a code
/// structure. Variants are grouped by the language whose table uses them.
#[derive(Debug, Clone, Copy)]
enum BoundaryPattern {
    // --- Rust ---
    RustFn,
    RustImpl,
    RustStruct,
    RustEnum,
    RustTrait,
    RustMod,

    // --- Python ---
    PythonDef,
    PythonClass,
    PythonAsync,

    // --- JavaScript / TypeScript ---
    JsFunction,
    JsClass,
    JsArrowNamed,
    JsMethod,

    // --- Go ---
    GoFunc,
    GoType,

    // --- Java ---
    JavaClass,
    JavaMethod,
    JavaInterface,

    // --- C / C++ ---
    CFunction,
    CppClass,
    CppNamespace,

    // --- Ruby ---
    RubyDef,
    RubyClass,
    RubyModule,

    // --- PHP ---
    PhpFunction,
    PhpClass,

    // --- Fallback for unrecognised languages ---
    GenericFunction,
}
418
419impl BoundaryPattern {
420    /// Returns the compiled regex for this pattern.
421    fn regex(self) -> &'static Regex {
422        macro_rules! static_regex {
423            ($name:ident, $pattern:expr) => {{
424                static $name: OnceLock<Regex> = OnceLock::new();
425                $name.get_or_init(|| Regex::new($pattern).expect("valid regex"))
426            }};
427        }
428
429        match self {
430            // Rust
431            Self::RustFn => static_regex!(
432                RUST_FN,
433                r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?(async\s+)?(unsafe\s+)?(extern\s+\S+\s+)?fn\s+\w+"
434            ),
435            Self::RustImpl => static_regex!(RUST_IMPL, r"(?m)^[ \t]*(unsafe\s+)?impl(<[^>]*>)?\s+"),
436            Self::RustStruct => static_regex!(
437                RUST_STRUCT,
438                r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?struct\s+\w+"
439            ),
440            Self::RustEnum => {
441                static_regex!(RUST_ENUM, r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?enum\s+\w+")
442            }
443            Self::RustTrait => static_regex!(
444                RUST_TRAIT,
445                r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?(unsafe\s+)?trait\s+\w+"
446            ),
447            Self::RustMod => {
448                static_regex!(RUST_MOD, r"(?m)^[ \t]*(pub(\s*\([^)]*\))?\s+)?mod\s+\w+")
449            }
450
451            // Python
452            Self::PythonDef => static_regex!(PYTHON_DEF, r"(?m)^[ \t]*def\s+\w+"),
453            Self::PythonClass => static_regex!(PYTHON_CLASS, r"(?m)^[ \t]*class\s+\w+"),
454            Self::PythonAsync => static_regex!(PYTHON_ASYNC, r"(?m)^[ \t]*async\s+def\s+\w+"),
455
456            // JavaScript/TypeScript
457            Self::JsFunction => static_regex!(
458                JS_FUNCTION,
459                r"(?m)^[ \t]*(export\s+)?(async\s+)?function\s*\*?\s*\w+"
460            ),
461            Self::JsClass => static_regex!(
462                JS_CLASS,
463                r"(?m)^[ \t]*(export\s+)?(abstract\s+)?class\s+\w+"
464            ),
465            Self::JsArrowNamed => static_regex!(
466                JS_ARROW,
467                r"(?m)^[ \t]*(export\s+)?(const|let|var)\s+\w+\s*=\s*(async\s+)?\([^)]*\)\s*=>"
468            ),
469            Self::JsMethod => static_regex!(
470                JS_METHOD,
471                r"(?m)^[ \t]*(static\s+)?(async\s+)?(get\s+|set\s+)?\w+\s*\([^)]*\)\s*\{"
472            ),
473
474            // Go
475            Self::GoFunc => static_regex!(GO_FUNC, r"(?m)^func\s+(\([^)]+\)\s*)?\w+"),
476            Self::GoType => static_regex!(GO_TYPE, r"(?m)^type\s+\w+\s+(struct|interface)"),
477
478            // Java
479            Self::JavaClass => static_regex!(
480                JAVA_CLASS,
481                r"(?m)^[ \t]*(public|private|protected)?\s*(abstract\s+)?(final\s+)?class\s+\w+"
482            ),
483            Self::JavaMethod => static_regex!(
484                JAVA_METHOD,
485                r"(?m)^[ \t]*(public|private|protected)\s+(static\s+)?(\w+\s+)+\w+\s*\([^)]*\)\s*(\{|throws)"
486            ),
487            Self::JavaInterface => {
488                static_regex!(JAVA_INTERFACE, r"(?m)^[ \t]*(public\s+)?interface\s+\w+")
489            }
490
491            // C/C++
492            Self::CFunction => static_regex!(
493                C_FUNCTION,
494                r"(?m)^[ \t]*(\w+\s+)+\**\s*\w+\s*\([^)]*\)\s*\{"
495            ),
496            Self::CppClass => static_regex!(
497                CPP_CLASS,
498                r"(?m)^[ \t]*(template\s*<[^>]*>\s*)?(class|struct)\s+\w+"
499            ),
500            Self::CppNamespace => static_regex!(CPP_NAMESPACE, r"(?m)^[ \t]*namespace\s+\w+"),
501
502            // Ruby
503            Self::RubyDef => static_regex!(RUBY_DEF, r"(?m)^[ \t]*def\s+\w+"),
504            Self::RubyClass => static_regex!(RUBY_CLASS, r"(?m)^[ \t]*class\s+\w+"),
505            Self::RubyModule => static_regex!(RUBY_MODULE, r"(?m)^[ \t]*module\s+\w+"),
506
507            // PHP
508            Self::PhpFunction => static_regex!(
509                PHP_FUNCTION,
510                r"(?m)^[ \t]*(public|private|protected)?\s*(static\s+)?function\s+\w+"
511            ),
512            Self::PhpClass => {
513                static_regex!(PHP_CLASS, r"(?m)^[ \t]*(abstract\s+|final\s+)?class\s+\w+")
514            }
515
516            // Generic
517            Self::GenericFunction => static_regex!(
518                GENERIC_FUNCTION,
519                r"(?m)^[ \t]*(function|def|fn|func|sub|proc)\s+\w+"
520            ),
521        }
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    // ---- construction ----------------------------------------------------

    #[test]
    fn test_code_chunker_new() {
        let subject = CodeChunker::new();
        assert_eq!(subject.name(), "code");
        assert_eq!(subject.chunk_size, DEFAULT_CHUNK_SIZE);
    }

    #[test]
    fn test_code_chunker_with_size() {
        let subject = CodeChunker::with_size(1000);
        assert_eq!(subject.chunk_size, 1000);
        assert_eq!(subject.overlap, DEFAULT_OVERLAP);
    }

    // ---- language detection ----------------------------------------------

    #[test]
    fn test_detect_language_rust() {
        let meta = ChunkMetadata::new().content_type("rs");
        assert_eq!(CodeChunker::detect_language(Some(&meta)), Language::Rust);
    }

    #[test]
    fn test_detect_language_from_source() {
        let meta = ChunkMetadata::new().source("src/main.py");
        assert_eq!(CodeChunker::detect_language(Some(&meta)), Language::Python);
    }

    #[test]
    fn test_detect_language_unknown() {
        let meta = ChunkMetadata::new().content_type("xyz");
        assert_eq!(CodeChunker::detect_language(Some(&meta)), Language::Unknown);
    }

    // ---- chunking end to end ---------------------------------------------

    #[test]
    fn test_chunk_rust_code() {
        let source = r#"
fn main() {
    println!("Hello");
}

fn helper() {
    println!("Helper");
}

pub fn public_fn() {
    println!("Public");
}
"#;

        let meta = ChunkMetadata::with_size(200).content_type("rs");
        let produced = CodeChunker::with_size(200)
            .chunk(1, source, Some(&meta))
            .unwrap();

        assert!(!produced.is_empty());
        // No chunk may consist of whitespace only.
        assert!(produced.iter().all(|c| !c.content.trim().is_empty()));
    }

    #[test]
    fn test_chunk_python_code() {
        let source = r#"
def main():
    print("Hello")

class MyClass:
    def method(self):
        pass

async def async_func():
    await something()
"#;

        let meta = ChunkMetadata::with_size(150).content_type("py");
        let produced = CodeChunker::with_size(150)
            .chunk(1, source, Some(&meta))
            .unwrap();

        assert!(!produced.is_empty());
    }

    #[test]
    fn test_chunk_javascript_code() {
        let source = r#"
function greet(name) {
    console.log("Hello " + name);
}

class Person {
    constructor(name) {
        this.name = name;
    }
}

const arrow = (x) => x * 2;

export async function fetchData() {
    return await fetch("/api");
}
"#;

        let meta = ChunkMetadata::with_size(200).content_type("js");
        let produced = CodeChunker::with_size(200)
            .chunk(1, source, Some(&meta))
            .unwrap();

        assert!(!produced.is_empty());
    }

    #[test]
    fn test_chunk_empty_text() {
        let produced = CodeChunker::new().chunk(1, "", None).unwrap();
        assert!(produced.is_empty());
    }

    #[test]
    fn test_chunk_unknown_language() {
        let source = "some random text without code structure";
        let produced = CodeChunker::with_size(100).chunk(1, source, None).unwrap();
        assert!(!produced.is_empty());
    }

    // ---- pattern tables ---------------------------------------------------

    #[test]
    fn test_boundary_patterns_rust() {
        assert!(!Language::Rust.boundary_patterns().is_empty());
        assert!(BoundaryPattern::RustFn
            .regex()
            .is_match("pub fn my_function() {}"));
    }

    #[test]
    fn test_boundary_patterns_python() {
        assert!(BoundaryPattern::PythonDef
            .regex()
            .is_match("def my_function():"));
        assert!(BoundaryPattern::PythonClass
            .regex()
            .is_match("class MyClass:"));
    }

    #[test]
    fn test_language_extensions() {
        let expectations = [
            ("rs", Language::Rust),
            ("py", Language::Python),
            ("js", Language::JavaScript),
            ("ts", Language::TypeScript),
            ("go", Language::Go),
            ("java", Language::Java),
            ("c", Language::C),
            ("cpp", Language::Cpp),
            ("rb", Language::Ruby),
            ("php", Language::Php),
            ("unknown", Language::Unknown),
        ];

        for (ext, expected) in expectations {
            assert_eq!(Language::from_extension(ext), expected);
        }
    }

    #[test]
    fn test_chunker_description() {
        assert!(!CodeChunker::new().description().is_empty());
    }
}