Skip to main content

cersei_compression/
code.rs

//! Language-aware code filtering.
//!
//! Credits: adapted from rtk (Rust Token Killer) — `rtk/src/core/filter.rs`.
//! MIT © Patrick Szymkowiak. See LICENSE.
5
6use crate::level::CompressionLevel;
7use once_cell::sync::Lazy;
8use regex::Regex;
9
/// Language family of a source file, used to choose comment-stripping rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Language {
    Rust,
    Python,
    JavaScript,
    TypeScript,
    Go,
    C,
    Cpp,
    Java,
    Ruby,
    Shell,
    /// Data, markup and plain-text formats (JSON, YAML, TOML, XML, CSV, ...);
    /// these have no comment table and are never code-stripped.
    Data,
    /// Anything that could not be classified; also has no comment table.
    Unknown,
}

impl Language {
    /// Classifies a bare file extension (no leading dot).
    ///
    /// The comparison ignores ASCII case. Extensions that are not listed map
    /// to [`Language::Unknown`].
    pub fn from_extension(ext: &str) -> Self {
        let lowered = ext.to_ascii_lowercase();
        match lowered.as_str() {
            "rs" => Self::Rust,
            "py" | "pyw" => Self::Python,
            "js" | "mjs" | "cjs" => Self::JavaScript,
            "ts" | "tsx" => Self::TypeScript,
            "go" => Self::Go,
            "c" | "h" => Self::C,
            "cpp" | "cc" | "cxx" | "hpp" | "hh" => Self::Cpp,
            "java" => Self::Java,
            "rb" => Self::Ruby,
            "sh" | "bash" | "zsh" => Self::Shell,
            "json" | "jsonc" | "json5" | "yaml" | "yml" | "toml" | "xml" | "csv" | "tsv"
            | "graphql" | "gql" | "sql" | "md" | "markdown" | "txt" | "env" | "lock" => Self::Data,
            _ => Self::Unknown,
        }
    }

    /// Classifies a path by the text after its last `.`.
    ///
    /// Paths without a dot, or ending in a dot, are [`Language::Unknown`].
    /// The text after the last dot is passed on verbatim, so a dot that only
    /// occurs in a directory name (`dir.rs/Makefile`) yields `Unknown` too.
    pub fn from_path(path: &str) -> Self {
        match path.rsplit_once('.') {
            Some((_, ext)) if !ext.is_empty() => Self::from_extension(ext),
            _ => Self::Unknown,
        }
    }

    /// Returns the comment delimiters of this language.
    ///
    /// `Data` and `Unknown` get an all-`None` table.
    fn comment_patterns(&self) -> CommentPatterns {
        // Delimiters shared by every language with `//` and `/* ... */` comments.
        let c_family = CommentPatterns {
            line: Some("//"),
            block_start: Some("/*"),
            block_end: Some("*/"),
            doc_line: None,
            doc_block_start: Some("/**"),
        };

        match self {
            // Rust is the C-family table plus `///` line doc comments.
            Self::Rust => CommentPatterns {
                doc_line: Some("///"),
                ..c_family
            },
            Self::JavaScript
            | Self::TypeScript
            | Self::Go
            | Self::C
            | Self::Cpp
            | Self::Java => c_family,
            Self::Python => {
                // Triple quotes serve as block start, block end and doc block start.
                let triple_quote = Some("\"\"\"");
                CommentPatterns {
                    line: Some("#"),
                    block_start: triple_quote,
                    block_end: triple_quote,
                    doc_line: None,
                    doc_block_start: triple_quote,
                }
            }
            Self::Ruby => CommentPatterns {
                line: Some("#"),
                block_start: Some("=begin"),
                block_end: Some("=end"),
                ..CommentPatterns::default()
            },
            Self::Shell => CommentPatterns {
                line: Some("#"),
                ..CommentPatterns::default()
            },
            Self::Data | Self::Unknown => CommentPatterns::default(),
        }
    }
}

/// Comment delimiters of one language; `None` means the language has no such form.
#[derive(Debug, Default, Clone)]
struct CommentPatterns {
    /// Prefix of a single-line comment.
    line: Option<&'static str>,
    /// Opening delimiter of a block comment.
    block_start: Option<&'static str>,
    /// Closing delimiter of a block comment.
    block_end: Option<&'static str>,
    /// Prefix of a single-line doc comment.
    doc_line: Option<&'static str>,
    /// Opening delimiter of a block doc comment.
    doc_block_start: Option<&'static str>,
}
112
113static MULTIPLE_BLANK_LINES: Lazy<Regex> = Lazy::new(|| Regex::new(r"\n{3,}").unwrap());
114static IMPORT_PATTERN: Lazy<Regex> =
115    Lazy::new(|| Regex::new(r"^(use |import |from |require\(|#include)").unwrap());
116static FUNC_SIGNATURE: Lazy<Regex> = Lazy::new(|| {
117    Regex::new(
118        r"^(pub\s+)?(async\s+)?(fn|def|function|func|class|struct|enum|trait|interface|type)\s+\w+",
119    )
120    .unwrap()
121});
122
123/// Apply code-aware filtering. For `Data` or `Unknown` languages, returns input
124/// unchanged to avoid corrupting JSON/YAML/TOML (rtk issue #464).
125pub fn filter(content: &str, lang: Language, level: CompressionLevel) -> String {
126    match level {
127        CompressionLevel::Off => content.to_string(),
128        CompressionLevel::Minimal => minimal(content, lang),
129        CompressionLevel::Aggressive => aggressive(content, lang),
130    }
131}
132
133fn minimal(content: &str, lang: Language) -> String {
134    if matches!(lang, Language::Data | Language::Unknown) {
135        return content.to_string();
136    }
137    let patterns = lang.comment_patterns();
138    let mut out = String::with_capacity(content.len());
139    let mut in_block_comment = false;
140    let mut in_docstring = false;
141
142    for line in content.lines() {
143        let trimmed = line.trim();
144
145        // Handle block comments
146        if let (Some(start), Some(end)) = (patterns.block_start, patterns.block_end) {
147            if !in_docstring
148                && trimmed.contains(start)
149                && !trimmed.starts_with(patterns.doc_block_start.unwrap_or("\0"))
150            {
151                in_block_comment = true;
152            }
153            if in_block_comment {
154                if trimmed.contains(end) {
155                    in_block_comment = false;
156                }
157                continue;
158            }
159        }
160
161        // Python docstrings: keep in minimal mode
162        if lang == Language::Python && trimmed.starts_with("\"\"\"") {
163            in_docstring = !in_docstring;
164            out.push_str(line);
165            out.push('\n');
166            continue;
167        }
168        if in_docstring {
169            out.push_str(line);
170            out.push('\n');
171            continue;
172        }
173
174        // Single-line comments (keep doc comments if language has them)
175        if let Some(line_comment) = patterns.line {
176            if trimmed.starts_with(line_comment) {
177                if let Some(doc) = patterns.doc_line {
178                    if trimmed.starts_with(doc) {
179                        out.push_str(line);
180                        out.push('\n');
181                    }
182                }
183                continue;
184            }
185        }
186
187        if trimmed.is_empty() {
188            out.push('\n');
189            continue;
190        }
191
192        out.push_str(line);
193        out.push('\n');
194    }
195
196    MULTIPLE_BLANK_LINES
197        .replace_all(&out, "\n\n")
198        .trim()
199        .to_string()
200}
201
202fn aggressive(content: &str, lang: Language) -> String {
203    if matches!(lang, Language::Data | Language::Unknown) {
204        return minimal(content, lang);
205    }
206    let minimal_out = minimal(content, lang);
207    let mut out = String::with_capacity(minimal_out.len() / 2);
208    let mut brace_depth: i32 = 0;
209    let mut in_impl_body = false;
210
211    for line in minimal_out.lines() {
212        let trimmed = line.trim();
213
214        if IMPORT_PATTERN.is_match(trimmed) {
215            out.push_str(line);
216            out.push('\n');
217            continue;
218        }
219
220        if FUNC_SIGNATURE.is_match(trimmed) {
221            out.push_str(line);
222            out.push('\n');
223            in_impl_body = true;
224            brace_depth = 0;
225            continue;
226        }
227
228        let open = trimmed.matches('{').count() as i32;
229        let close = trimmed.matches('}').count() as i32;
230
231        if in_impl_body {
232            brace_depth += open;
233            brace_depth -= close;
234
235            if brace_depth <= 1 && (trimmed == "{" || trimmed == "}" || trimmed.ends_with('{')) {
236                out.push_str(line);
237                out.push('\n');
238            }
239
240            if brace_depth <= 0 {
241                in_impl_body = false;
242                if !trimmed.is_empty() && trimmed != "}" {
243                    out.push_str("    // ... implementation\n");
244                }
245            }
246            continue;
247        }
248
249        if trimmed.starts_with("const ")
250            || trimmed.starts_with("static ")
251            || trimmed.starts_with("let ")
252            || trimmed.starts_with("pub const ")
253            || trimmed.starts_with("pub static ")
254        {
255            out.push_str(line);
256            out.push('\n');
257        }
258    }
259
260    out.trim().to_string()
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension- and path-based detection, including the `Data` and
    /// `Unknown` fallbacks.
    #[test]
    fn language_detection() {
        let by_extension = [
            ("rs", Language::Rust),
            ("py", Language::Python),
            ("json", Language::Data),
            ("lock", Language::Data),
            ("xyz", Language::Unknown),
        ];
        for &(ext, expected) in by_extension.iter() {
            assert_eq!(Language::from_extension(ext), expected);
        }

        let by_path = [
            ("/a/b/c.rs", Language::Rust),
            ("Dockerfile", Language::Unknown),
        ];
        for &(path, expected) in by_path.iter() {
            assert_eq!(Language::from_path(path), expected);
        }
    }

    /// `minimal` drops `//` comments but keeps `///` doc comments and code.
    #[test]
    fn minimal_strips_rust_line_comments_keeps_doc() {
        let src = concat!(
            "// normal comment\n",
            "/// doc comment\n",
            "fn main() {\n",
            "    println!(\"hi\");\n",
            "}\n",
        );
        let stripped = minimal(src, Language::Rust);
        assert!(!stripped.contains("// normal comment"));
        assert!(stripped.contains("/// doc comment"));
        assert!(stripped.contains("fn main()"));
    }

    /// Data formats must come back untouched, even when they contain `/*`.
    #[test]
    fn minimal_preserves_json() {
        let json =
            r#"{"pkgs": ["packages/*"], "scripts": {"build": "bun run --workspaces build"}}"#;
        let filtered = minimal(json, Language::Data);
        assert_eq!(filtered, json);
    }

    /// `aggressive` keeps imports and signatures and elides the body.
    #[test]
    fn aggressive_preserves_signatures_and_imports() {
        let src = concat!(
            "use std::io;\n",
            "fn do_thing() {\n",
            "    let x = 1;\n",
            "    println!(\"{}\", x);\n",
            "}\n",
        );
        let outline = aggressive(src, Language::Rust);
        assert!(outline.contains("use std::io"));
        assert!(outline.contains("fn do_thing"));
        let body_elided = outline.contains("... implementation") || !outline.contains("println");
        assert!(body_elided);
    }
}