Skip to main content

cognis_rag/splitters/
code.rs

//! Language-aware code splitter that uses per-language separator preferences.
//!
//! This is a thin specialization of the recursive splitter that picks
//! sensible separators per language: function/class boundaries first, then
//! blank lines, single newlines, spaces, and finally individual characters.
//! It is not a parser — for AST-aware splitting, pull in `tree-sitter` in a
//! downstream crate.
7
8use crate::document::Document;
9
10use super::{recursive::RecursiveCharSplitter, TextSplitter};
11
/// Common programming languages we ship default separator orderings for.
///
/// Pass one of these to [`CodeSplitter::new`] to get a splitter whose
/// coarsest split points are that language's top-level item boundaries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeLanguage {
    /// Rust.
    Rust,
    /// Python.
    Python,
    /// JavaScript / TypeScript.
    JavaScript,
    /// Go.
    Go,
    /// Java.
    Java,
    /// C / C++ / similar curly-brace languages.
    Cpp,
    /// Generic / fallback (paragraph → line → space → char).
    Generic,
}

impl CodeLanguage {
    /// Coarsest-first separator list for this language.
    ///
    /// Every list starts with language-specific item boundaries. Each of these
    /// is anchored on a leading `\n`, so it only matches at the start of a
    /// line. Every list ends with the shared fallback tail
    /// `"\n\n"` → `"\n"` → `" "` → `""`.
    ///
    /// Visibility-prefixed and generic forms (`pub fn`, `impl<T>`, `export …`,
    /// `async function`) are listed explicitly. A separator such as `"\nfn "`
    /// does not match a line beginning with `pub fn`, so without these forms
    /// most public items would never be recognised as boundaries.
    pub fn separators(&self) -> Vec<&'static str> {
        match self {
            Self::Rust => vec![
                "\nimpl ",
                // Generic impls (`impl<T> Foo for Bar`) have no space after `impl`.
                "\nimpl<",
                "\npub fn ",
                "\nfn ",
                "\npub struct ",
                "\nstruct ",
                "\npub enum ",
                "\nenum ",
                "\npub trait ",
                "\ntrait ",
                "\npub mod ",
                "\nmod ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Self::Python => vec!["\nclass ", "\ndef ", "\nasync def ", "\n\n", "\n", " ", ""],
            Self::JavaScript => vec![
                // ES module exports (`export function`, `export class`, `export const`, …).
                "\nexport ",
                "\nfunction ",
                "\nasync function ",
                "\nclass ",
                "\nconst ",
                "\nlet ",
                "\nvar ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Self::Go => vec!["\nfunc ", "\ntype ", "\n\n", "\n", " ", ""],
            Self::Java => vec![
                "\npublic class ",
                "\nclass ",
                "\npublic ",
                "\nprivate ",
                "\nprotected ",
                "\n\n",
                "\n",
                " ",
                "",
            ],
            Self::Cpp => vec!["\nclass ", "\nstruct ", "\nvoid ", "\n\n", "\n", " ", ""],
            Self::Generic => vec!["\n\n", "\n", " ", ""],
        }
    }
}
76
77/// Code-aware splitter. Wraps [`RecursiveCharSplitter`] with language-tuned
78/// separators.
79pub struct CodeSplitter {
80    inner: RecursiveCharSplitter,
81}
82
83impl CodeSplitter {
84    /// Build a splitter for the given language.
85    pub fn new(language: CodeLanguage) -> Self {
86        Self {
87            inner: RecursiveCharSplitter::new().with_separators(language.separators()),
88        }
89    }
90
91    /// Cap chunk size.
92    pub fn with_chunk_size(mut self, n: usize) -> Self {
93        self.inner = self.inner.with_chunk_size(n);
94        self
95    }
96
97    /// Set chunk overlap.
98    pub fn with_overlap(mut self, n: usize) -> Self {
99        self.inner = self.inner.with_overlap(n);
100        self
101    }
102}
103
104impl TextSplitter for CodeSplitter {
105    fn split(&self, doc: &Document) -> Vec<Document> {
106        self.inner.split(doc)
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rust_splits_at_fn_boundary() {
        let code = "fn a() { 1 }\n\nfn b() { 2 }\n\nfn c() { 3 }\n";
        let s = CodeSplitter::new(CodeLanguage::Rust)
            .with_chunk_size(15)
            .with_overlap(0);
        let chunks = s.split(&Document::new(code));
        assert!(chunks.len() >= 2);
        assert!(chunks.iter().any(|c| c.content.contains("fn a")));
    }

    #[test]
    fn python_splits_at_def_boundary() {
        let code = "def a():\n    return 1\n\ndef b():\n    return 2\n";
        // Pin overlap so the result does not depend on the inner splitter's default.
        let s = CodeSplitter::new(CodeLanguage::Python)
            .with_chunk_size(20)
            .with_overlap(0);
        let chunks = s.split(&Document::new(code));
        // The input is longer than one chunk, so it must actually be split.
        assert!(chunks.len() >= 2);
        // "def a" starts the document, so it cannot be lost to a dropped separator.
        assert!(chunks.iter().any(|c| c.content.contains("def a")));
    }

    #[test]
    fn every_language_ends_with_char_fallback() {
        let all = [
            CodeLanguage::Rust,
            CodeLanguage::Python,
            CodeLanguage::JavaScript,
            CodeLanguage::Go,
            CodeLanguage::Java,
            CodeLanguage::Cpp,
            CodeLanguage::Generic,
        ];
        for lang in all {
            let seps = lang.separators();
            assert!(seps.ends_with(&["\n\n", "\n", " ", ""]), "{lang:?}");
        }
    }
}
132}