Skip to main content

cersei_compression/
dispatch.rs

//! Route a raw tool output through the right compression stage based on the
//! tool name and its input JSON.
//!
//! Every invocation that reaches a compression stage (that is, the level is not
//! `Off` and the content is non-empty) emits a single `tracing::info!` event on
//! the `cersei_compression` target with before/after bytes, lines, savings
//! percent, strategy, and the matched rule / language. Subscribers can filter
//! with `RUST_LOG=cersei_compression=info`.
8
9use crate::{ansi, code, level::CompressionLevel, toml_rules, truncate};
10use serde_json::Value;
11
/// Line budget used by `safety_cap` at `CompressionLevel::Minimal`;
/// `CompressionLevel::Aggressive` uses half of this value.
const MAX_LINES_SAFETY: usize = 600;
13
14/// Infallible entry point. On any internal panic-free error we fall back to
15/// the raw content unchanged so the agent loop never breaks.
16pub fn compress_tool_output(
17    tool_name: &str,
18    tool_input: &Value,
19    content: &str,
20    level: CompressionLevel,
21) -> String {
22    if level.is_off() || content.is_empty() {
23        return content.to_string();
24    }
25
26    let lowered = tool_name.to_ascii_lowercase();
27
28    // strategy = short tag describing the branch we took
29    // detail   = finer identifier (rule name, detected language, empty string, …)
30    let (out, strategy, detail): (String, &'static str, String) = match lowered.as_str() {
31        // ─── Shell-like tools → TOML rules DSL ───────────────────────────
32        "bash" | "exec" | "execshell" | "shell" | "run" | "runshell" => {
33            let command = tool_input
34                .get("command")
35                .and_then(Value::as_str)
36                .unwrap_or("");
37            let (out, rule) = compress_command(command, content, level);
38            (out, "shell", rule)
39        }
40
41        // ─── File read → code filter ────────────────────────────────────
42        "read" | "readfile" | "read_file" | "view" => {
43            let path = tool_input
44                .get("file_path")
45                .or_else(|| tool_input.get("path"))
46                .and_then(Value::as_str)
47                .unwrap_or("");
48            let lang = code::Language::from_path(path);
49            let filtered = code::filter(content, lang, level);
50            let capped = safety_cap(&filtered, level);
51            (capped, "code", format!("{lang:?}"))
52        }
53
54        // ─── Structured retrieval tools pass straight through ────────────
55        "grep" | "glob" | "list" | "ls" | "find" | "tree" => {
56            (content.to_string(), "passthrough", String::new())
57        }
58
59        // ─── Web fetch → strip ANSI + generic TOML catch-all ─────────────
60        "webfetch" | "web_fetch" | "fetch" | "http" => {
61            let stripped = ansi::strip_ansi(content);
62            let (out, rule) = compress_command("webfetch", &stripped, level);
63            (out, "web", rule)
64        }
65
66        // ─── Anything else → minimal safety cap at Aggressive, else noop ─
67        _ => {
68            if matches!(level, CompressionLevel::Aggressive) {
69                (safety_cap(content, level), "unknown-capped", String::new())
70            } else {
71                (content.to_string(), "unknown", String::new())
72            }
73        }
74    };
75
76    log_compression(tool_name, level, strategy, &detail, content, &out);
77    out
78}
79
80/// Returns (filtered_output, matched_rule_name_or_empty).
81fn compress_command(command: &str, content: &str, level: CompressionLevel) -> (String, String) {
82    let stripped = ansi::strip_ansi(content);
83    let (out, rule) = if let Some(filter) = toml_rules::find_matching(command.trim()) {
84        (toml_rules::apply(filter, &stripped), filter.name.clone())
85    } else {
86        (stripped, String::new())
87    };
88    (safety_cap(&out, level), rule)
89}
90
91fn safety_cap(content: &str, level: CompressionLevel) -> String {
92    let cap = match level {
93        CompressionLevel::Off => return content.to_string(),
94        CompressionLevel::Minimal => MAX_LINES_SAFETY,
95        CompressionLevel::Aggressive => MAX_LINES_SAFETY / 2,
96    };
97    if content.lines().count() <= cap {
98        content.to_string()
99    } else {
100        truncate::smart_truncate(content, cap)
101    }
102}
103
104fn log_compression(
105    tool: &str,
106    level: CompressionLevel,
107    strategy: &str,
108    detail: &str,
109    before: &str,
110    after: &str,
111) {
112    let before_bytes = before.len();
113    let after_bytes = after.len();
114    let before_lines = before.lines().count();
115    let after_lines = after.lines().count();
116    let savings_pct = if before_bytes > 0 {
117        100.0 * (before_bytes as f64 - after_bytes as f64) / before_bytes as f64
118    } else {
119        0.0
120    };
121
122    tracing::info!(
123        target: "cersei_compression",
124        tool,
125        level = %level,
126        strategy,
127        detail,
128        before_bytes,
129        after_bytes,
130        before_lines,
131        after_lines,
132        savings_pct = format!("{savings_pct:.1}"),
133        "tool-output compressed"
134    );
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn off_is_noop() {
        let raw = "\x1b[31mhello\x1b[0m";
        let out = compress_tool_output("Bash", &json!({}), raw, CompressionLevel::Off);
        assert_eq!(out, raw);
    }

    #[test]
    fn empty_content_is_noop() {
        let out = compress_tool_output(
            "Bash",
            &json!({"command": "git status"}),
            "",
            CompressionLevel::Aggressive,
        );
        assert_eq!(out, "");
    }

    #[test]
    fn bash_strips_ansi_at_minimal() {
        let raw = "\x1b[31mfatal: not a git repo\x1b[0m";
        let out = compress_tool_output(
            "Bash",
            &json!({"command": "git status"}),
            raw,
            CompressionLevel::Minimal,
        );
        assert!(!out.contains("\x1b["));
        assert!(out.contains("fatal"));
    }

    #[test]
    fn read_preserves_json_when_data_file() {
        let raw = r#"{"a": 1, "packages": ["x/*"]}"#;
        let out = compress_tool_output(
            "Read",
            &json!({"file_path": "/x/package.json"}),
            raw,
            CompressionLevel::Aggressive,
        );
        assert!(out.contains("packages"));
        assert!(out.contains("x/*"));
    }

    #[test]
    fn read_strips_rust_comments_in_aggressive() {
        let raw = "\
// normal comment
/// doc comment
fn main() {
    let x = 1;
    println!(\"{}\", x);
}
";
        let out = compress_tool_output(
            "Read",
            &json!({"file_path": "src/main.rs"}),
            raw,
            CompressionLevel::Aggressive,
        );
        assert!(!out.contains("// normal comment"));
        assert!(out.contains("fn main"));
    }

    #[test]
    fn grep_passthrough() {
        let raw = "file.rs:1:hit\nfile.rs:2:hit2";
        let out = compress_tool_output(
            "Grep",
            &json!({"pattern": "hit"}),
            raw,
            CompressionLevel::Aggressive,
        );
        assert_eq!(out, raw);
    }

    #[test]
    fn tool_name_dispatch_is_case_insensitive() {
        let raw = "a.rs:1:x";
        let out = compress_tool_output("GREP", &json!({}), raw, CompressionLevel::Aggressive);
        assert_eq!(out, raw);
    }

    #[test]
    fn unknown_tool_untouched_at_minimal() {
        // The catch-all branch only caps at Aggressive; at Minimal it must not
        // even strip ANSI escapes.
        let raw = "\x1b[32mok\x1b[0m\nline two";
        let out = compress_tool_output(
            "SomeCustomTool",
            &json!({}),
            raw,
            CompressionLevel::Minimal,
        );
        assert_eq!(out, raw);
    }

    #[test]
    fn safety_cap_off_is_identity() {
        assert_eq!(safety_cap("a\nb\nc", CompressionLevel::Off), "a\nb\nc");
    }

    #[test]
    fn safety_cap_keeps_output_exactly_at_the_cap() {
        let text = vec!["x"; MAX_LINES_SAFETY / 2].join("\n");
        assert_eq!(safety_cap(&text, CompressionLevel::Aggressive), text);

        let text = vec!["y"; MAX_LINES_SAFETY].join("\n");
        assert_eq!(safety_cap(&text, CompressionLevel::Minimal), text);
    }
}