1use std::collections::HashMap;
14use std::io::Write;
15use std::process::{Command, Stdio};
16use std::sync::{Arc, Mutex};
17
18use sha2::{Digest, Sha256};
19use thiserror::Error;
20
21#[derive(Debug, Error)]
22pub enum CompressError {
23 #[error("compression failed: {0}")]
24 Internal(String),
25}
26
/// Outcome of a single compression request.
#[derive(Debug, Clone)]
pub struct CompressedContext {
    /// Text to use in place of the original context.
    pub text: String,
    /// Token count of the input.
    pub original_tokens: usize,
    /// Token count of `text`.
    pub compressed_tokens: usize,
    /// `true` when the result was served from the engine's cache.
    pub cache_hit: bool,
    /// `true` when the text was produced by the external `sqz` tool.
    pub used_sqz: bool,
}
36
/// Running token totals attributed to one agent.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default)]
pub struct AgentTokenStats {
    /// Sum of input token counts across this agent's requests.
    pub tokens_processed: u64,
    /// Sum of tokens removed by compression across this agent's requests.
    pub tokens_saved: u64,
}
42
43#[derive(Debug, Clone)]
44pub struct CompressionStats {
45 pub total_tokens_processed: u64,
46 pub total_tokens_saved: u64,
47 pub cache_entries: usize,
48 pub per_agent: HashMap<u32, AgentTokenStats>,
49 pub sqz_stats: Option<String>,
51}
52
53pub trait CompressionEngine {
54 fn compress(&self, context: &str, threshold: usize, agent_pid: u32) -> Result<CompressedContext, CompressError>;
55 fn get_stats(&self) -> CompressionStats;
56 fn invalidate_cache(&self);
57}
58
/// Reports whether `sqz --version` could be launched and run to completion.
///
/// Only the ability to run the command is checked; its exit status is not
/// inspected.
pub fn sqz_available() -> bool {
    let probe = Command::new("sqz").arg("--version").output();
    probe.is_ok()
}
65
/// Pipes `input` through `sqz compress` and returns what it prints on stdout.
///
/// `--no-cache` is passed to `sqz` when `no_cache` is set. The child's stderr
/// is discarded.
///
/// Returns `None` when the process cannot be started, the input cannot be
/// delivered in full, the process exits unsuccessfully, or its output is not
/// valid UTF-8.
fn sqz_compress(input: &str, no_cache: bool) -> Option<String> {
    let mut cmd = Command::new("sqz");
    cmd.arg("compress");
    if no_cache {
        cmd.arg("--no-cache");
    }
    cmd.stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::null());

    let mut child = cmd.spawn().ok()?;

    // Feed stdin from a separate thread. Writing the whole input before
    // draining stdout can deadlock once both pipe buffers fill up: the child
    // blocks writing output while we block writing input.
    let writer = child.stdin.take().map(|mut stdin| {
        let payload = input.as_bytes().to_vec();
        // `stdin` is dropped when the closure returns, which closes the pipe
        // and signals end-of-input to the child.
        std::thread::spawn(move || stdin.write_all(&payload))
    });

    // Always wait for the child, even if feeding it fails, so that it is
    // reaped rather than left behind as a zombie process.
    let output = child.wait_with_output();
    let delivered = match writer {
        Some(handle) => matches!(handle.join(), Ok(Ok(()))),
        None => false,
    };

    let output = output.ok()?;
    if delivered && output.status.success() {
        String::from_utf8(output.stdout).ok()
    } else {
        None
    }
}
85
/// Runs `sqz stats` and returns its stdout verbatim.
///
/// Yields `None` when the command cannot be run, exits unsuccessfully, or
/// prints bytes that are not valid UTF-8.
pub fn sqz_stats_raw() -> Option<String> {
    Command::new("sqz")
        .arg("stats")
        .output()
        .ok()
        .filter(|run| run.status.success())
        .and_then(|run| String::from_utf8(run.stdout).ok())
}
95
/// Runs `sqz gain` and returns its stdout verbatim.
///
/// Yields `None` when the command cannot be run, exits unsuccessfully, or
/// prints bytes that are not valid UTF-8.
pub fn sqz_gain_raw() -> Option<String> {
    Command::new("sqz")
        .arg("gain")
        .output()
        .ok()
        .filter(|run| run.status.success())
        .and_then(|run| String::from_utf8(run.stdout).ok())
}
105
/// Approximates a token count as the number of whitespace-separated words.
fn count_tokens(text: &str) -> usize {
    let words = text.split_whitespace();
    words.count()
}
111
112fn sha256_hex(input: &str) -> String {
113 let mut hasher = Sha256::new();
114 hasher.update(input.as_bytes());
115 hex::encode(hasher.finalize())
116}
117
118struct Inner {
121 cache: HashMap<String, CompressedContext>,
122 per_agent: HashMap<u32, AgentTokenStats>,
123 total_processed: u64,
124 total_saved: u64,
125}
126
127pub struct SqzEngine {
136 inner: Arc<Mutex<Inner>>,
137}
138
139impl SqzEngine {
140 pub fn new() -> Self {
141 Self {
142 inner: Arc::new(Mutex::new(Inner {
143 cache: HashMap::new(),
144 per_agent: HashMap::new(),
145 total_processed: 0,
146 total_saved: 0,
147 })),
148 }
149 }
150
151 fn fallback_compress(context: &str, threshold: usize) -> String {
154 let mut seen = std::collections::HashSet::new();
156 let deduped: Vec<&str> = context.split(". ")
157 .filter(|s| !s.trim().is_empty() && seen.insert(s.trim()))
158 .collect();
159 let deduped_text = deduped.join(". ");
160
161 let tokens: Vec<&str> = deduped_text.split_whitespace().collect();
163 if tokens.len() <= threshold {
164 return deduped_text;
165 }
166 let start = tokens.len() - threshold;
167 tokens[start..].join(" ")
168 }
169}
170
171impl Default for SqzEngine {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl CompressionEngine for SqzEngine {
178 fn compress(&self, context: &str, threshold: usize, agent_pid: u32) -> Result<CompressedContext, CompressError> {
179 let original_tokens = count_tokens(context);
180
181 if original_tokens <= threshold {
183 return Ok(CompressedContext {
184 text: context.to_string(),
185 original_tokens,
186 compressed_tokens: original_tokens,
187 cache_hit: false,
188 used_sqz: false,
189 });
190 }
191
192 let hash = sha256_hex(context);
193 let mut inner = self.inner.lock().unwrap();
194
195 if let Some(cached) = inner.cache.get(&hash) {
197 let mut result = cached.clone();
198 result.cache_hit = true;
199 let saved = (original_tokens.saturating_sub(result.compressed_tokens)) as u64;
200 inner.total_processed += original_tokens as u64;
201 inner.total_saved += saved;
202 let entry = inner.per_agent.entry(agent_pid).or_default();
203 entry.tokens_processed += original_tokens as u64;
204 entry.tokens_saved += saved;
205 return Ok(result);
206 }
207
208 let (compressed_text, used_sqz) = if let Some(out) = sqz_compress(context, false) {
210 (out.trim_end().to_string(), true)
211 } else {
212 (Self::fallback_compress(context, threshold), false)
214 };
215
216 let compressed_tokens = count_tokens(&compressed_text);
217 let saved = (original_tokens.saturating_sub(compressed_tokens)) as u64;
218
219 inner.total_processed += original_tokens as u64;
220 inner.total_saved += saved;
221 let entry = inner.per_agent.entry(agent_pid).or_default();
222 entry.tokens_processed += original_tokens as u64;
223 entry.tokens_saved += saved;
224
225 let result = CompressedContext {
226 text: compressed_text,
227 original_tokens,
228 compressed_tokens,
229 cache_hit: false,
230 used_sqz,
231 };
232
233 inner.cache.insert(hash, result.clone());
234 Ok(result)
235 }
236
237 fn get_stats(&self) -> CompressionStats {
238 let inner = self.inner.lock().unwrap();
239 CompressionStats {
240 total_tokens_processed: inner.total_processed,
241 total_tokens_saved: inner.total_saved,
242 cache_entries: inner.cache.len(),
243 per_agent: inner.per_agent.clone(),
244 sqz_stats: sqz_stats_raw(),
246 }
247 }
248
249 fn invalidate_cache(&self) {
250 self.inner.lock().unwrap().cache.clear();
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context of `words` distinct tokens: `word0 word1 ...`.
    fn large_context(words: usize) -> String {
        (0..words)
            .map(|i| format!("word{i}"))
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Builds `unique_sentences` ten-word sentences and repeats the whole set
    /// `repeat` times, joined by `". "`.
    fn large_context_with_duplicates(unique_sentences: usize, repeat: usize) -> String {
        let sentences: Vec<String> = (0..unique_sentences)
            .map(|i| {
                (0..10)
                    .map(|w| format!("s{i}w{w}"))
                    .collect::<Vec<_>>()
                    .join(" ")
            })
            .collect();
        let mut all = Vec::new();
        for _ in 0..repeat {
            all.extend(sentences.iter().cloned());
        }
        all.join(". ")
    }

    #[test]
    fn compression_reduces_tokens_by_at_least_20_percent() {
        let engine = SqzEngine::new();
        let ctx = large_context(200);
        let result = engine.compress(&ctx, 100, 1).unwrap();
        assert!(!result.cache_hit);
        assert_eq!(result.original_tokens, 200);
        // Only the fallback path is checked: it caps its output at `threshold`
        // tokens, whereas the size of `sqz` output is decided by the tool.
        if !result.used_sqz {
            let reduction =
                1.0 - (result.compressed_tokens as f64 / result.original_tokens as f64);
            assert!(
                reduction >= 0.20,
                "expected >= 20% reduction, got {:.1}%",
                reduction * 100.0
            );
        }
    }

    #[test]
    fn below_threshold_returns_as_is() {
        let engine = SqzEngine::new();
        let ctx = "hello world foo bar";
        let result = engine.compress(ctx, 100, 1).unwrap();
        assert_eq!(result.text, ctx);
        assert_eq!(result.original_tokens, result.compressed_tokens);
        assert!(!result.cache_hit);
    }

    #[test]
    fn deduplication_cache_returns_hit_for_identical_input() {
        let engine = SqzEngine::new();
        let ctx = large_context(200);
        let first = engine.compress(&ctx, 100, 1).unwrap();
        assert!(!first.cache_hit);
        let second = engine.compress(&ctx, 100, 1).unwrap();
        assert!(second.cache_hit);
        assert_eq!(first.compressed_tokens, second.compressed_tokens);
        assert_eq!(first.text, second.text);
    }

    #[test]
    fn cache_invalidation_clears_all_entries() {
        let engine = SqzEngine::new();
        let ctx = large_context(200);
        engine.compress(&ctx, 100, 1).unwrap();
        assert_eq!(engine.get_stats().cache_entries, 1);
        engine.invalidate_cache();
        assert_eq!(engine.get_stats().cache_entries, 0);
        let result = engine.compress(&ctx, 100, 1).unwrap();
        assert!(!result.cache_hit);
    }

    #[test]
    fn stats_reflect_compression_operations() {
        let engine = SqzEngine::new();
        let ctx = large_context(200);
        let result = engine.compress(&ctx, 100, 42).unwrap();
        let stats = engine.get_stats();
        assert_eq!(stats.total_tokens_processed, 200);
        // Mirror the engine's saturating arithmetic: a plain `200 - x` would
        // overflow and panic here if `sqz` ever returned more tokens than it
        // was given, hiding the real comparison.
        let expected_saved = 200usize.saturating_sub(result.compressed_tokens) as u64;
        assert_eq!(stats.total_tokens_saved, expected_saved);
        let agent_stats = stats
            .per_agent
            .get(&42)
            .expect("agent 42 should have stats");
        assert_eq!(agent_stats.tokens_processed, 200);
        assert_eq!(agent_stats.tokens_saved, expected_saved);
    }

    #[test]
    fn stats_accumulate_across_multiple_calls() {
        let engine = SqzEngine::new();
        engine.compress(&large_context(100), 50, 1).unwrap();
        engine.compress(&large_context(120), 50, 1).unwrap();
        let stats = engine.get_stats();
        assert_eq!(stats.total_tokens_processed, 220);
    }

    #[test]
    fn deduplication_reduces_repeated_sentences() {
        let engine = SqzEngine::new();
        let ctx = large_context_with_duplicates(5, 10);
        let original_tokens = count_tokens(&ctx);
        // A threshold just below the input size forces the compression path.
        let result = engine.compress(&ctx, original_tokens - 1, 1).unwrap();
        assert!(result.compressed_tokens < original_tokens);
    }

    #[test]
    fn different_contexts_produce_different_cache_entries() {
        let engine = SqzEngine::new();
        engine.compress(&large_context(100), 50, 1).unwrap();
        engine.compress(&large_context(110), 50, 1).unwrap();
        assert_eq!(engine.get_stats().cache_entries, 2);
    }

    #[test]
    fn per_agent_stats_are_tracked_separately() {
        let engine = SqzEngine::new();
        let ctx = large_context(100);
        engine.compress(&ctx, 50, 1).unwrap();
        engine.compress(&ctx, 50, 2).unwrap();
        let stats = engine.get_stats();
        assert!(stats.per_agent.contains_key(&1));
        assert!(stats.per_agent.contains_key(&2));
        // The second call is a cache hit, which must still be attributed to
        // the agent that made it.
        assert_eq!(stats.per_agent[&1].tokens_processed, 100);
        assert_eq!(stats.per_agent[&2].tokens_processed, 100);
    }

    #[test]
    fn sqz_available_check_does_not_panic() {
        let _ = sqz_available();
    }

    #[test]
    fn sqz_stats_raw_returns_option() {
        let _ = sqz_stats_raw();
    }
}