cognis_rag/splitters/
code.rs1use crate::document::Document;
9
10use super::{recursive::RecursiveCharSplitter, TextSplitter};
11
/// Languages for which a dedicated separator list exists.
///
/// The variant chosen decides which list [`CodeLanguage::separators`]
/// returns. The type is a plain fieldless enum, so it is cheap to copy,
/// compare, and (via `Hash`) use as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CodeLanguage {
    /// Rust source code.
    Rust,
    /// Python source code.
    Python,
    /// JavaScript source code.
    JavaScript,
    /// Go source code.
    Go,
    /// Java source code.
    Java,
    /// C++ source code.
    Cpp,
    /// Fallback with no language-specific separators.
    Generic,
}
30
31impl CodeLanguage {
32 pub fn separators(&self) -> Vec<&'static str> {
34 match self {
35 Self::Rust => vec![
36 "\nimpl ",
37 "\nfn ",
38 "\nstruct ",
39 "\nenum ",
40 "\ntrait ",
41 "\nmod ",
42 "\n\n",
43 "\n",
44 " ",
45 "",
46 ],
47 Self::Python => vec!["\nclass ", "\ndef ", "\nasync def ", "\n\n", "\n", " ", ""],
48 Self::JavaScript => vec![
49 "\nfunction ",
50 "\nclass ",
51 "\nconst ",
52 "\nlet ",
53 "\nvar ",
54 "\n\n",
55 "\n",
56 " ",
57 "",
58 ],
59 Self::Go => vec!["\nfunc ", "\ntype ", "\n\n", "\n", " ", ""],
60 Self::Java => vec![
61 "\npublic class ",
62 "\nclass ",
63 "\npublic ",
64 "\nprivate ",
65 "\nprotected ",
66 "\n\n",
67 "\n",
68 " ",
69 "",
70 ],
71 Self::Cpp => vec!["\nclass ", "\nstruct ", "\nvoid ", "\n\n", "\n", " ", ""],
72 Self::Generic => vec!["\n\n", "\n", " ", ""],
73 }
74 }
75}
76
77pub struct CodeSplitter {
80 inner: RecursiveCharSplitter,
81}
82
83impl CodeSplitter {
84 pub fn new(language: CodeLanguage) -> Self {
86 Self {
87 inner: RecursiveCharSplitter::new().with_separators(language.separators()),
88 }
89 }
90
91 pub fn with_chunk_size(mut self, n: usize) -> Self {
93 self.inner = self.inner.with_chunk_size(n);
94 self
95 }
96
97 pub fn with_overlap(mut self, n: usize) -> Self {
99 self.inner = self.inner.with_overlap(n);
100 self
101 }
102}
103
104impl TextSplitter for CodeSplitter {
105 fn split(&self, doc: &Document) -> Vec<Document> {
106 self.inner.split(doc)
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// With a chunk size of 15 and no overlap, the three-function snippet is
    /// expected to yield at least two chunks, one of which contains `fn a`.
    #[test]
    fn rust_splits_at_fn_boundary() {
        let source = Document::new("fn a() { 1 }\n\nfn b() { 2 }\n\nfn c() { 3 }\n");
        let splitter = CodeSplitter::new(CodeLanguage::Rust)
            .with_chunk_size(15)
            .with_overlap(0);

        let pieces = splitter.split(&source);

        assert!(pieces.len() >= 2);
        assert!(pieces.iter().any(|piece| piece.content.contains("fn a")));
    }

    /// Splitting a two-function Python snippet must produce some output.
    #[test]
    fn python_splits_at_def_boundary() {
        let source = Document::new("def a():\n return 1\n\ndef b():\n return 2\n");
        let splitter = CodeSplitter::new(CodeLanguage::Python).with_chunk_size(20);

        let pieces = splitter.split(&source);

        assert!(!pieces.is_empty());
    }
}