1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
10use serde_json::Value;
11use std::collections::HashMap;
12
13use crate::error::Result;
14
/// Marker placed at the start of every wire string produced by
/// `DictionaryCodec::compress`; `DictionaryCodec::decompress` strips it when present.
pub const DICTIONARY_PREFIX: &str = "#M2M|";

/// Lowest byte value the decoder looks up in the pattern table; bytes below it
/// are always copied through as single characters.
const PATTERN_START: u8 = 0x80;
20
21lazy_static::lazy_static! {
22 static ref PATTERN_ENCODE: HashMap<&'static str, u8> = {
24 let mut m = HashMap::new();
25 m.insert(r#"{"role":"user","content":"#, 0x80);
27 m.insert(r#"{"role":"assistant","content":"#, 0x81);
28 m.insert(r#"{"role":"system","content":"#, 0x82);
29 m.insert(r#""}"#, 0x83);
30 m.insert(r#"},"#, 0x84);
31 m.insert(r#""}]"#, 0x85);
32 m.insert(r#"{"messages":["#, 0x86);
33 m.insert(r#"{"model":"#, 0x87);
34 m.insert(r#","messages":["#, 0x88);
35 m.insert(r#","max_tokens":"#, 0x89);
36 m.insert(r#","temperature":"#, 0x8A);
37 m.insert(r#","stream":true"#, 0x8B);
38 m.insert(r#","stream":false"#, 0x8C);
39 m.insert(r#""gpt-4"#, 0x90);
41 m.insert(r#""gpt-4o"#, 0x91);
42 m.insert(r#""gpt-4o-mini"#, 0x92);
43 m.insert(r#""gpt-3.5-turbo"#, 0x93);
44 m.insert(r#""claude-3"#, 0x94);
45 m.insert(r#""llama"#, 0x95);
46 m.insert(r#"{"choices":[{"#, 0xA0);
48 m.insert(r#""finish_reason":"stop""#, 0xA1);
49 m.insert(r#""finish_reason":"length""#, 0xA2);
50 m.insert(r#","usage":{"#, 0xA3);
51 m.insert(r#""prompt_tokens":"#, 0xA4);
52 m.insert(r#","completion_tokens":"#, 0xA5);
53 m.insert(r#","total_tokens":"#, 0xA6);
54 m.insert(r#""index":0,"#, 0xA7);
55 m.insert(r#""message":{"#, 0xA8);
56 m.insert(r#""delta":{"#, 0xA9);
57 m.insert(r#""tool_calls":[{"#, 0xB0);
59 m.insert(r#""type":"function","#, 0xB1);
60 m.insert(r#""function":{"#, 0xB2);
61 m.insert(r#""name":"#, 0xB3);
62 m.insert(r#","arguments":"#, 0xB4);
63 m
64 };
65
66 static ref PATTERN_DECODE: HashMap<u8, &'static str> = {
68 PATTERN_ENCODE.iter().map(|(k, v)| (*v, *k)).collect()
69 };
70
71 static ref PATTERNS_SORTED: Vec<(&'static str, u8)> = {
73 let mut patterns: Vec<_> = PATTERN_ENCODE.iter().map(|(k, v)| (*k, *v)).collect();
74 patterns.sort_by(|a, b| b.0.len().cmp(&a.0.len()));
76 patterns
77 };
78}
79
/// Configuration for the dictionary codec.
///
/// Both fields are public so callers can build a codec with struct-update
/// syntax on top of `Default::default()`.
#[derive(Debug, Clone)]
pub struct DictionaryCodec {
    /// When `true`, `compress` substitutes known JSON fragments with one-byte
    /// codes and `decompress` expands those codes again.
    pub use_patterns: bool,
    /// Content whose byte length is below this threshold skips the
    /// pattern/base64 stage in `compress`.
    pub min_length: usize,
}
88
89impl Default for DictionaryCodec {
90 fn default() -> Self {
91 Self {
92 use_patterns: true,
93 min_length: 50,
94 }
95 }
96}
97
98impl DictionaryCodec {
99 pub fn new() -> Self {
101 Self::default()
102 }
103
104 #[deprecated(note = "Use M2M codec instead")]
108 pub fn compress(&self, content: &str) -> Result<(String, usize, usize)> {
109 if content.len() < self.min_length {
110 let wire = format!("{DICTIONARY_PREFIX}{content}");
112 let wire_len = wire.len();
113 return Ok((wire, content.len(), wire_len));
114 }
115
116 let compressed = if self.use_patterns {
117 self.compress_with_patterns(content)
118 } else {
119 content.as_bytes().to_vec()
120 };
121
122 let encoded = BASE64.encode(&compressed);
124 let wire = format!("{DICTIONARY_PREFIX}{encoded}");
125 let wire_len = wire.len();
126
127 Ok((wire, content.len(), wire_len))
128 }
129
130 pub fn decompress(&self, wire: &str) -> Result<String> {
132 let data = wire.strip_prefix(DICTIONARY_PREFIX).unwrap_or(wire);
133
134 match BASE64.decode(data) {
136 Ok(decoded) => {
137 if self.use_patterns {
138 self.decompress_with_patterns(&decoded)
139 } else {
140 String::from_utf8(decoded)
141 .map_err(|e| crate::error::M2MError::Decompression(e.to_string()))
142 }
143 },
144 Err(_) => {
145 Ok(data.to_string())
147 },
148 }
149 }
150
151 fn compress_with_patterns(&self, content: &str) -> Vec<u8> {
153 let mut result = Vec::with_capacity(content.len());
154 let bytes = content.as_bytes();
155 let mut i = 0;
156
157 while i < bytes.len() {
158 let remaining = &content[i..];
159 let mut matched = false;
160
161 for (pattern, code) in PATTERNS_SORTED.iter() {
163 if remaining.starts_with(pattern) {
164 result.push(*code);
165 i += pattern.len();
166 matched = true;
167 break;
168 }
169 }
170
171 if !matched {
172 result.push(bytes[i]);
173 i += 1;
174 }
175 }
176
177 result
178 }
179
180 fn decompress_with_patterns(&self, data: &[u8]) -> Result<String> {
182 let mut result = String::with_capacity(data.len() * 2);
183
184 for &byte in data {
185 if byte >= PATTERN_START {
186 if let Some(&pattern) = PATTERN_DECODE.get(&byte) {
187 result.push_str(pattern);
188 } else {
189 result.push(byte as char);
191 }
192 } else {
193 result.push(byte as char);
194 }
195 }
196
197 Ok(result)
198 }
199
200 #[deprecated(note = "Use M2M codec instead")]
204 #[allow(deprecated)]
205 pub fn compress_value(&self, value: &Value) -> Result<(String, usize, usize)> {
206 let json = serde_json::to_string(value)?;
207 self.compress(&json)
208 }
209
210 pub fn decompress_value(&self, wire: &str) -> Result<Value> {
212 let json = self.decompress(wire)?;
213 serde_json::from_str(&json)
214 .map_err(|e| crate::error::M2MError::Decompression(e.to_string()))
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Content below the default `min_length` keeps the prefix and round-trips.
    #[test]
    #[allow(deprecated)]
    fn test_compress_short() {
        let codec = DictionaryCodec::new();
        let content = r#"{"model":"gpt-4o"}"#;

        let (data, _, _) = codec.compress(content).unwrap();
        assert!(data.starts_with("#M2M|"));

        let decompressed = codec.decompress(&data).unwrap();
        assert_eq!(decompressed, content);
    }

    /// With `min_length` set to 0 the pattern path is exercised and round-trips.
    #[test]
    #[allow(deprecated)]
    fn test_compress_with_patterns() {
        let content = r#"{"messages":[{"role":"user","content":"Hello"}]}"#;

        let mut codec = DictionaryCodec::new();
        codec.min_length = 0;

        let (data, _, _) = codec.compress(content).unwrap();
        assert!(data.starts_with("#M2M|"));

        let decompressed = codec.decompress(&data).unwrap();
        assert_eq!(decompressed, content);
    }

    /// A full chat request round-trips and the reported lengths are accurate.
    #[test]
    #[allow(deprecated)]
    fn test_compress_request() {
        let codec = DictionaryCodec {
            min_length: 0,
            ..Default::default()
        };

        let content = r#"{"model":"gpt-4o","messages":[{"role":"system","content":"Be helpful"},{"role":"user","content":"Hello"}]}"#;

        let (data, original_len, compressed_len) = codec.compress(content).unwrap();

        // Assert the reported sizes instead of only printing them.
        assert_eq!(original_len, content.len());
        assert_eq!(compressed_len, data.len());
        assert!(
            compressed_len < original_len,
            "wire ({compressed_len} bytes) should be smaller than input ({original_len} bytes)"
        );

        let decompressed = codec.decompress(&data).unwrap();
        assert_eq!(decompressed, content);
    }

    /// The encode and decode tables are exact inverses of each other.
    #[test]
    fn test_pattern_encode_decode() {
        // The decode table is derived from the encode table, so a duplicated
        // code would silently drop an entry; equal sizes rule that out.
        assert_eq!(
            PATTERN_ENCODE.len(),
            PATTERN_DECODE.len(),
            "two patterns share the same code"
        );

        for (pattern, code) in PATTERN_ENCODE.iter() {
            // The decoder only consults the table for bytes >= PATTERN_START.
            assert!(
                *code >= PATTERN_START,
                "Pattern '{pattern}' uses code 0x{code:02X} below PATTERN_START"
            );
            assert_eq!(
                PATTERN_DECODE.get(code),
                Some(pattern),
                "Pattern '{pattern}' (0x{code:02X}) does not decode back to itself"
            );
        }
    }
}