1use serde_json::Value;
8
9use super::brotli::BrotliCodec;
10use super::m2m::M2MCodec;
11use super::token_native::TokenNativeCodec;
12use super::{Algorithm, CompressionResult};
13use crate::error::{M2MError, Result};
14use crate::inference::HydraModel;
15use crate::models::Encoding;
16use crate::security::SecurityScanner;
17use crate::tokenizer::count_tokens_with_encoding;
18
/// Summary of the content traits that drive codec selection.
///
/// Produced by [`ContentAnalysis::analyze`]. All fields are plain data, so the
/// type also derives `PartialEq` to allow direct comparison in assertions.
#[derive(Debug, Clone, PartialEq)]
pub struct ContentAnalysis {
    /// Length of the content in bytes (`str::len`).
    pub length: usize,
    /// Whether the content parsed successfully as JSON.
    pub is_json: bool,
    /// Whether the parsed JSON has a top-level `messages`, `model` or `choices` key.
    pub is_llm_api: bool,
    /// Share of repeated 4-character windows, `0.0` for content under 100 bytes.
    pub repetition_ratio: f32,
    /// Whether the parsed JSON has a top-level `tools`, `tool_calls` or `functions` key.
    pub has_tools: bool,
    /// Rough token estimate computed as `length / 4`.
    pub estimated_tokens: usize,
}
35
36impl ContentAnalysis {
37 pub fn analyze(content: &str) -> Self {
39 let length = content.len();
40 let parsed: Option<Value> = serde_json::from_str(content).ok();
41 let is_json = parsed.is_some();
42
43 let (is_llm_api, has_tools) = if let Some(ref value) = parsed {
44 let is_api = value.get("messages").is_some()
45 || value.get("model").is_some()
46 || value.get("choices").is_some();
47 let tools = value.get("tools").is_some()
48 || value.get("tool_calls").is_some()
49 || value.get("functions").is_some();
50 (is_api, tools)
51 } else {
52 (false, false)
53 };
54
55 let repetition_ratio = Self::calculate_repetition(content);
57
58 let estimated_tokens = length / 4;
60
61 Self {
62 length,
63 is_json,
64 is_llm_api,
65 repetition_ratio,
66 has_tools,
67 estimated_tokens,
68 }
69 }
70
/// Estimates how repetitive `content` is from its 4-character windows.
///
/// Returns `1.0 - distinct_windows / total_windows`, so `0.0` means every
/// window is unique and values close to `1.0` mean heavy repetition.
/// Content shorter than 100 bytes is not measured and yields `0.0`.
fn calculate_repetition(content: &str) -> f32 {
    const MIN_BYTES: usize = 100;
    const GRAM_LEN: usize = 4;

    if content.len() < MIN_BYTES {
        return 0.0;
    }

    // Windows are taken over chars (not bytes) so multi-byte text is never split.
    let chars: Vec<char> = content.chars().collect();
    let total = chars.len().saturating_sub(GRAM_LEN - 1);
    if total == 0 {
        return 0.0;
    }

    // Hash the borrowed windows directly instead of allocating a `String` per window.
    let distinct: std::collections::HashSet<&[char]> = chars.windows(GRAM_LEN).collect();

    1.0 - (distinct.len() as f32 / total as f32)
}
92}
93
94#[derive(Clone)]
96pub struct CodecEngine {
97 token_native: TokenNativeCodec,
99 m2m: M2MCodec,
101 brotli: BrotliCodec,
103 hydra: Option<HydraModel>,
105 pub ml_routing: bool,
107 pub brotli_threshold: usize,
109 pub prefer_m2m_for_api: bool,
111}
112
113impl Default for CodecEngine {
114 fn default() -> Self {
115 Self {
116 token_native: TokenNativeCodec::default(),
117 m2m: M2MCodec::new(),
118 brotli: BrotliCodec::new(),
119 hydra: None,
120 ml_routing: false,
121 brotli_threshold: 1024, prefer_m2m_for_api: true,
123 }
124 }
125}
126
127impl CodecEngine {
128 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub fn with_ml_routing(mut self, enabled: bool) -> Self {
135 self.ml_routing = enabled;
136 self
137 }
138
139 pub fn with_hydra(mut self, model: HydraModel) -> Self {
141 self.hydra = Some(model);
142 self.ml_routing = true;
143 self
144 }
145
146 pub fn with_brotli_threshold(mut self, threshold: usize) -> Self {
148 self.brotli_threshold = threshold;
149 self
150 }
151
152 pub fn with_encoding(mut self, encoding: Encoding) -> Self {
154 self.token_native = TokenNativeCodec::new(encoding);
155 self
156 }
157
158 pub fn compress_with_tokens(
164 &self,
165 content: &str,
166 algorithm: Algorithm,
167 encoding: Encoding,
168 ) -> Result<CompressionResult> {
169 let original_tokens = count_tokens_with_encoding(content, encoding);
170 let mut result = self.compress(content, algorithm)?;
171
172 let compressed_tokens = count_tokens_with_encoding(&result.data, encoding);
173 result.original_tokens = Some(original_tokens);
174 result.compressed_tokens = Some(compressed_tokens);
175
176 Ok(result)
177 }
178
179 pub fn secure_compress(
191 &self,
192 content: &str,
193 scanner: &SecurityScanner,
194 ) -> Result<CompressionResult> {
195 let scan_result = scanner.scan_and_validate(content)?;
197
198 if scan_result.should_block {
199 let threat_desc = scan_result
200 .threats
201 .first()
202 .map(|t| t.name.clone())
203 .unwrap_or_else(|| "unknown".to_string());
204
205 return Err(M2MError::ContentBlocked(format!(
206 "Content blocked: {} (confidence: {:.2})",
207 threat_desc, scan_result.confidence
208 )));
209 }
210
211 let analysis = ContentAnalysis::analyze(content);
213 let algorithm = self.select_algorithm(&analysis);
214 self.compress(content, algorithm)
215 }
216
217 pub fn secure_compress_ml(&self, content: &str) -> Result<(CompressionResult, bool)> {
221 let is_safe = if let Some(ref hydra) = self.hydra {
223 let security = hydra.predict_security(content)?;
224 security.safe
225 } else {
226 let fallback = HydraModel::fallback_only();
228 fallback.predict_security(content)?.safe
229 };
230
231 let algorithm = self.select_algorithm_for_content(content);
233 let result = self.compress(content, algorithm)?;
234
235 Ok((result, is_safe))
236 }
237
238 pub fn compress_auto_with_tokens(
240 &self,
241 content: &str,
242 encoding: Encoding,
243 ) -> Result<(CompressionResult, Algorithm)> {
244 let analysis = ContentAnalysis::analyze(content);
245 let algorithm = self.select_algorithm(&analysis);
246
247 let result = self.compress_with_tokens(content, algorithm, encoding)?;
248 Ok((result, algorithm))
249 }
250
251 pub fn compress(&self, content: &str, algorithm: Algorithm) -> Result<CompressionResult> {
253 match algorithm {
254 Algorithm::None => Ok(CompressionResult::new(
255 content.to_string(),
256 Algorithm::None,
257 content.len(),
258 content.len(),
259 )),
260 Algorithm::M2M => {
261 let wire = self.m2m.encode_string(content)?;
264 Ok(CompressionResult::new(
265 wire.clone(),
266 Algorithm::M2M,
267 content.len(),
268 wire.len(),
269 ))
270 },
271 Algorithm::TokenNative => self.token_native.compress(content),
272 Algorithm::Brotli => self.brotli.compress(content),
273 }
274 }
275
276 pub fn compress_auto(&self, content: &str) -> Result<(CompressionResult, Algorithm)> {
278 let analysis = ContentAnalysis::analyze(content);
279 let algorithm = self.select_algorithm(&analysis);
280
281 let result = self.compress(content, algorithm)?;
282 Ok((result, algorithm))
283 }
284
285 pub fn compress_value(&self, value: &Value) -> Result<(CompressionResult, Algorithm)> {
287 let content = serde_json::to_string(value)?;
288 self.compress_auto(&content)
289 }
290
291 pub fn select_algorithm(&self, analysis: &ContentAnalysis) -> Algorithm {
293 if self.ml_routing {
295 return self.ml_select_algorithm(analysis);
296 }
297
298 self.heuristic_select_algorithm(analysis)
300 }
301
302 fn ml_select_algorithm(&self, analysis: &ContentAnalysis) -> Algorithm {
304 if let Some(ref hydra) = self.hydra {
306 if let Ok(decision) = hydra.predict_compression("") {
311 return decision.algorithm;
312 }
313 }
314
315 self.heuristic_select_algorithm(analysis)
317 }
318
319 pub fn select_algorithm_for_content(&self, content: &str) -> Algorithm {
321 if self.ml_routing {
323 if let Some(ref hydra) = self.hydra {
324 if let Ok(decision) = hydra.predict_compression(content) {
325 return decision.algorithm;
326 }
327 }
328 }
329
330 let analysis = ContentAnalysis::analyze(content);
332 self.heuristic_select_algorithm(&analysis)
333 }
334
335 fn heuristic_select_algorithm(&self, analysis: &ContentAnalysis) -> Algorithm {
342 if analysis.length < 100 {
345 return Algorithm::None;
346 }
347
348 if analysis.length > self.brotli_threshold {
351 return Algorithm::Brotli;
352 }
353
354 if analysis.is_llm_api && self.prefer_m2m_for_api {
357 return Algorithm::M2M;
358 }
359
360 if analysis.repetition_ratio > 0.3 {
362 return Algorithm::Brotli;
363 }
364
365 if analysis.is_json {
367 Algorithm::M2M
368 } else {
369 Algorithm::None
370 }
371 }
372
373 pub fn decompress(&self, wire: &str) -> Result<String> {
375 let algorithm = super::detect_algorithm(wire).unwrap_or(Algorithm::None);
376
377 match algorithm {
378 Algorithm::None => Ok(wire.to_string()),
379 Algorithm::M2M => {
380 self.m2m.decode_string(wire)
382 },
383 Algorithm::TokenNative => self.token_native.decompress(wire),
384 Algorithm::Brotli => self.brotli.decompress(wire),
385 }
386 }
387
388 pub fn decompress_value(&self, wire: &str) -> Result<Value> {
390 let json = self.decompress(wire)?;
391 serde_json::from_str(&json).map_err(|e| M2MError::Decompression(e.to_string()))
392 }
393
394 pub fn compress_best(&self, content: &str) -> Result<CompressionResult> {
396 let mut best: Option<CompressionResult> = None;
397
398 for algo in [Algorithm::M2M, Algorithm::TokenNative, Algorithm::Brotli] {
400 if let Ok(result) = self.compress(content, algo) {
401 let is_better = match &best {
402 None => true,
403 Some(current) => result.compressed_bytes < current.compressed_bytes,
404 };
405
406 if is_better {
407 best = Some(result);
408 }
409 }
410 }
411
412 best.ok_or_else(|| M2MError::Compression("All algorithms failed".to_string()))
413 }
414
415 pub fn analyze(&self, content: &str) -> ContentAnalysis {
417 ContentAnalysis::analyze(content)
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    /// Content under 100 bytes is left uncompressed.
    #[test]
    fn test_auto_select_small() {
        let tiny = ContentAnalysis::analyze("small");
        let chosen = CodecEngine::new().select_algorithm(&tiny);
        assert_eq!(chosen, Algorithm::None);
    }

    /// A mid-sized LLM API payload is routed to M2M by the heuristics.
    #[test]
    fn test_auto_select_llm_api() {
        let payload = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, how are you doing today? This is a longer message to test the compression algorithm selection."}]}"#;
        let analysis = ContentAnalysis::analyze(payload);

        assert!(analysis.is_json);
        assert!(analysis.is_llm_api);
        assert!(analysis.length >= 100 && analysis.length <= 1024);

        let engine = CodecEngine::new();
        assert_eq!(engine.select_algorithm(&analysis), Algorithm::M2M);
    }

    /// Auto-compressed content decompresses back to the same message text.
    #[test]
    fn test_compress_decompress_auto() {
        let content = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, how are you doing today?"}],"temperature":0.7}"#;
        let engine = CodecEngine::new();

        let (result, algo) = engine.compress_auto(content).unwrap();
        println!("Selected algorithm: {algo:?}");
        println!(
            "Original: {} bytes, Compressed: {} bytes",
            result.original_bytes, result.compressed_bytes
        );

        let restored_text = engine.decompress(&result.data).unwrap();
        let original: Value = serde_json::from_str(content).unwrap();
        let recovered: Value = serde_json::from_str(&restored_text).unwrap();

        // Only the message content is compared, not the whole document.
        assert_eq!(
            original["messages"][0]["content"],
            recovered["messages"][0]["content"]
        );
    }

    /// `compress_best` succeeds on a typical conversation and reports its pick.
    #[test]
    fn test_compress_best() {
        let conversation = r#"{"model":"gpt-4o","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"What is the capital of France?"},{"role":"assistant","content":"The capital of France is Paris."}]}"#;

        let best = CodecEngine::new().compress_best(conversation).unwrap();
        println!(
            "Best algorithm: {:?}, ratio: {:.2}",
            best.algorithm,
            best.byte_ratio()
        );
    }

    /// JSON, LLM API and tool detection all fire on a payload with `tools`.
    #[test]
    fn test_content_analysis() {
        let payload = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"test"}],"tools":[{"type":"function"}]}"#;
        let analysis = ContentAnalysis::analyze(payload);

        assert!(analysis.is_json);
        assert!(analysis.is_llm_api);
        assert!(analysis.has_tools);
    }

    /// Content above the default 1024-byte threshold is routed to Brotli.
    #[test]
    fn test_large_content_selects_brotli() {
        let big = "hello world ".repeat(100);
        let analysis = ContentAnalysis::analyze(&big);

        assert!(
            analysis.length > 1024,
            "Content should be >1024 bytes for Brotli selection"
        );

        let engine = CodecEngine::new();
        assert_eq!(engine.select_algorithm(&analysis), Algorithm::Brotli);
    }

    /// Content-based selection with a fallback-only Hydra model picks M2M here.
    #[test]
    fn test_ml_routing_with_hydra() {
        let engine = CodecEngine::new().with_hydra(HydraModel::fallback_only());

        let content = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Test message"}]}"#;
        let algo = engine.select_algorithm_for_content(content);

        assert_eq!(algo, Algorithm::M2M);
    }

    /// TokenNative output carries the `#TK|` prefix and round-trips exactly.
    #[test]
    fn test_token_native_roundtrip() {
        let content = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello!"}]}"#;
        let engine = CodecEngine::new();

        let packed = engine.compress(content, Algorithm::TokenNative).unwrap();
        assert!(packed.data.starts_with("#TK|"));

        let decompressed = engine.decompress(&packed.data).unwrap();
        assert_eq!(content, decompressed);
    }
}
530}