1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use brotli::{CompressorWriter, Decompressor};
8use std::io::{Read, Write};
9
10use super::{Algorithm, CompressionResult};
11use crate::error::{M2MError, Result};
12
/// Brotli quality used when none is specified: 11, the top of the range
/// accepted by `BrotliCodec::with_quality` (densest output, slowest encode).
const DEFAULT_QUALITY: u32 = 11;

/// Default encoder window, given as the base-2 logarithm of the window size
/// (the value handed to the Brotli encoder as its `lgwin` argument).
const DEFAULT_WINDOW_SIZE: u32 = 22;
18
/// Brotli compressor/decompressor carrying the encoder settings.
///
/// Construct with [`BrotliCodec::new`], [`BrotliCodec::with_quality`], or
/// `Default::default()`; the fields are public and may also be set directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BrotliCodec {
    /// Encoder quality level. `with_quality` clamps this to at most 11, but a
    /// value assigned directly to the field is passed to the encoder unchecked.
    pub quality: u32,
    /// Base-2 logarithm of the encoder's sliding window size (Brotli `lgwin`),
    /// passed to the encoder unchecked.
    pub window_size: u32,
}
27
28impl Default for BrotliCodec {
29 fn default() -> Self {
30 Self {
31 quality: DEFAULT_QUALITY,
32 window_size: DEFAULT_WINDOW_SIZE,
33 }
34 }
35}
36
37impl BrotliCodec {
38 pub fn new() -> Self {
40 Self::default()
41 }
42
43 pub fn with_quality(quality: u32) -> Self {
45 Self {
46 quality: quality.min(11),
47 ..Default::default()
48 }
49 }
50
51 pub fn compress_bytes(&self, data: &[u8]) -> Result<Vec<u8>> {
53 let mut compressed = Vec::new();
54 {
55 let mut writer =
56 CompressorWriter::new(&mut compressed, 4096, self.quality, self.window_size);
57 writer
58 .write_all(data)
59 .map_err(|e| M2MError::Compression(e.to_string()))?;
60 }
61 Ok(compressed)
62 }
63
64 pub fn decompress_bytes(&self, data: &[u8]) -> Result<Vec<u8>> {
66 let mut decompressor = Decompressor::new(data, 4096);
67 let mut decompressed = Vec::new();
68 decompressor
69 .read_to_end(&mut decompressed)
70 .map_err(|e| M2MError::Decompression(e.to_string()))?;
71 Ok(decompressed)
72 }
73
74 pub fn compress(&self, content: &str) -> Result<CompressionResult> {
76 let compressed = self.compress_bytes(content.as_bytes())?;
77 let encoded = BASE64.encode(&compressed);
78 let wire = format!("#M2M[v3.0]|DATA:{encoded}");
79 let wire_len = wire.len();
80
81 Ok(CompressionResult::new(
82 wire,
83 Algorithm::Brotli,
84 content.len(),
85 wire_len,
86 ))
87 }
88
89 pub fn decompress(&self, wire: &str) -> Result<String> {
91 let data = wire
92 .strip_prefix("#M2M[v3.0]|DATA:")
93 .ok_or_else(|| M2MError::InvalidMessage("Invalid Brotli wire format".to_string()))?;
94
95 let compressed = BASE64.decode(data)?;
96 let decompressed = self.decompress_bytes(&compressed)?;
97
98 String::from_utf8(decompressed)
99 .map_err(|e| M2MError::Decompression(format!("Invalid UTF-8: {e}")))
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Wire prefix the codec is expected to emit; duplicated here on purpose so
    /// the tests pin the on-the-wire format independently of the implementation.
    const PREFIX: &str = "#M2M[v3.0]|DATA:";

    #[test]
    fn test_compress_decompress() {
        let codec = BrotliCodec::new();
        let original =
            r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}"#;

        let result = codec.compress(original).unwrap();
        assert!(result.data.starts_with(PREFIX));

        let decompressed = codec.decompress(&result.data).unwrap();
        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_compression_ratio() {
        let codec = BrotliCodec::new();

        let original = r#"{"messages":[{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there!"},{"role":"user","content":"How are you?"},{"role":"assistant","content":"I'm doing great, thank you for asking!"}]}"#;

        let result = codec.compress(original).unwrap();

        println!(
            "Original: {} bytes, Compressed: {} bytes, Ratio: {:.2}",
            result.original_bytes,
            result.compressed_bytes,
            result.byte_ratio()
        );

        // The ratio is informational; what must hold is that the payload survives.
        assert!(result.data.starts_with(PREFIX));
        assert_eq!(codec.decompress(&result.data).unwrap(), original);
    }

    #[test]
    fn test_bytes_roundtrip() {
        let codec = BrotliCodec::new();
        let original = b"Hello, Brotli! This is a test of byte compression.";

        let compressed = codec.compress_bytes(original).unwrap();
        let decompressed = codec.decompress_bytes(&compressed).unwrap();

        assert_eq!(decompressed, original);
    }

    #[test]
    fn test_bytes_roundtrip_all_byte_values() {
        let codec = BrotliCodec::new();
        let original: Vec<u8> = (0..=255u8).cycle().take(1024).collect();

        let compressed = codec.compress_bytes(&original).unwrap();
        assert_eq!(codec.decompress_bytes(&compressed).unwrap(), original);
    }

    #[test]
    fn test_empty_roundtrip() {
        let codec = BrotliCodec::new();

        let result = codec.compress("").unwrap();
        assert!(result.data.starts_with(PREFIX));
        assert_eq!(codec.decompress(&result.data).unwrap(), "");
    }

    #[test]
    fn test_unicode_roundtrip() {
        let codec = BrotliCodec::new();
        let original = "héllo wörld — 你好 🌍";

        let result = codec.compress(original).unwrap();
        assert_eq!(codec.decompress(&result.data).unwrap(), original);
    }

    #[test]
    fn test_low_quality_roundtrip() {
        let codec = BrotliCodec::with_quality(0);
        let original = "quality zero still has to round-trip correctly";

        let result = codec.compress(original).unwrap();
        assert_eq!(codec.decompress(&result.data).unwrap(), original);
    }

    #[test]
    fn test_decompress_rejects_missing_prefix() {
        let codec = BrotliCodec::new();

        let err = codec.decompress("not a brotli message");
        assert!(matches!(err, Err(M2MError::InvalidMessage(_))));
    }

    #[test]
    fn test_decompress_rejects_invalid_base64() {
        let codec = BrotliCodec::new();

        assert!(codec.decompress("#M2M[v3.0]|DATA:!!!not base64!!!").is_err());
    }

    #[test]
    fn test_with_quality_clamps_and_keeps_default_window() {
        let clamped = BrotliCodec::with_quality(42);
        assert_eq!(clamped.quality, 11);
        assert_eq!(clamped.window_size, DEFAULT_WINDOW_SIZE);

        let explicit = BrotliCodec::with_quality(5);
        assert_eq!(explicit.quality, 5);
        assert_eq!(explicit.window_size, DEFAULT_WINDOW_SIZE);
    }
}