1use serde_json::{Map, Value};
28
29use super::tables::{
30 is_default_value, KEY_ABBREV, KEY_EXPAND, MODEL_ABBREV, MODEL_EXPAND, PATTERN_ABBREV,
31 PATTERN_EXPAND, ROLE_ABBREV, ROLE_EXPAND,
32};
33use crate::error::Result;
34
/// Marker placed in front of every wire string built by [`TokenCodec::compress`].
///
/// [`TokenCodec::decompress`] removes it again when it is present.
pub const TOKEN_PREFIX: &str = "#T1|";
37
/// Configuration for the token-oriented JSON codec.
///
/// Each flag enables one reduction applied while compressing a JSON value.
/// The flags are consulted only on the compression side: decompression always
/// expands keys, roles, models and textual patterns and re-inserts default
/// request parameters, whatever the flags say.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenCodec {
    /// Run the serialized JSON text through the `PATTERN_ABBREV` substitutions.
    pub apply_patterns: bool,
    /// Replace object keys that have an entry in `KEY_ABBREV` with their short form.
    pub abbreviate_keys: bool,
    /// Replace string values stored under a `role` / `r` key via `ROLE_ABBREV`.
    pub abbreviate_roles: bool,
    /// Replace string values stored under a `model` / `M` key via `MODEL_ABBREV`.
    pub abbreviate_models: bool,
    /// Drop object entries for which `is_default_value` reports `true`.
    pub remove_defaults: bool,
    /// Drop object entries whose value is JSON `null`.
    pub remove_nulls: bool,
}
54
55impl Default for TokenCodec {
56 fn default() -> Self {
57 Self {
58 apply_patterns: false,
61 abbreviate_keys: true,
62 abbreviate_roles: true,
63 abbreviate_models: true,
64 remove_defaults: true,
65 remove_nulls: true,
66 }
67 }
68}
69
70impl TokenCodec {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 #[deprecated(note = "Use M2M codec instead")]
80 pub fn compress(&self, value: &Value) -> Result<(String, usize, usize)> {
81 let original = serde_json::to_string(value)?;
82
83 let compressed_value = self.compress_value(value, None);
85 let mut compressed_json = serde_json::to_string(&compressed_value)?;
86
87 if self.apply_patterns {
90 compressed_json = self.apply_pattern_compression(&compressed_json);
91 }
92
93 let wire = format!("{TOKEN_PREFIX}{compressed_json}");
95 let wire_len = wire.len();
96
97 Ok((wire, original.len(), wire_len))
98 }
99
100 pub fn compress_raw(&self, value: &Value) -> String {
102 let compressed_value = self.compress_value(value, None);
103 let mut compressed_json = serde_json::to_string(&compressed_value).unwrap_or_default();
104
105 if self.apply_patterns {
106 compressed_json = self.apply_pattern_compression(&compressed_json);
107 }
108
109 compressed_json
110 }
111
112 pub fn decompress(&self, wire: &str) -> Result<Value> {
114 let json_str = wire.strip_prefix(TOKEN_PREFIX).unwrap_or(wire);
115
116 let expanded_json = self.apply_pattern_expansion(json_str);
118
119 let value: Value = serde_json::from_str(&expanded_json)?;
120 let expanded = self.expand_value(&value, None);
121 Ok(self.restore_defaults(&expanded))
123 }
124
125 pub fn decompress_raw(&self, json_str: &str) -> Result<Value> {
127 let expanded_json = self.apply_pattern_expansion(json_str);
128 let value: Value = serde_json::from_str(&expanded_json)?;
129 let expanded = self.expand_value(&value, None);
130 Ok(self.restore_defaults(&expanded))
131 }
132
133 fn apply_pattern_compression(&self, json: &str) -> String {
138 let mut result = json.to_string();
139
140 for (pattern, replacement) in PATTERN_ABBREV {
141 result = result.replace(pattern, replacement);
142 }
143
144 result
145 }
146
147 fn apply_pattern_expansion(&self, json: &str) -> String {
149 let mut result = json.to_string();
150
151 for (replacement, pattern) in PATTERN_EXPAND {
152 result = result.replace(replacement, pattern);
153 }
154
155 result
156 }
157
158 fn compress_value(&self, value: &Value, parent_key: Option<&str>) -> Value {
159 match value {
160 Value::Object(map) => {
161 let mut result = Map::new();
162
163 for (key, val) in map {
164 if self.remove_nulls && val.is_null() {
166 continue;
167 }
168
169 if self.remove_defaults && is_default_value(key, val) {
171 continue;
172 }
173
174 let new_key = if self.abbreviate_keys {
176 KEY_ABBREV
177 .get(key.as_str())
178 .map(|s| s.to_string())
179 .unwrap_or_else(|| key.clone())
180 } else {
181 key.clone()
182 };
183
184 let new_val = self.compress_value_with_context(val, key);
186 result.insert(new_key, new_val);
187 }
188
189 Value::Object(result)
190 },
191 Value::Array(arr) => Value::Array(
192 arr.iter()
193 .map(|v| self.compress_value(v, parent_key))
194 .collect(),
195 ),
196 _ => self.compress_value_with_context(value, parent_key.unwrap_or("")),
197 }
198 }
199
200 fn compress_value_with_context(&self, value: &Value, key: &str) -> Value {
201 match value {
202 Value::String(s) => {
203 if (key == "role" || key == "r") && self.abbreviate_roles {
205 if let Some(abbrev) = ROLE_ABBREV.get(s.as_str()) {
206 return Value::String(abbrev.to_string());
207 }
208 }
209
210 if (key == "model" || key == "M") && self.abbreviate_models {
212 if let Some(abbrev) = MODEL_ABBREV.get(s.as_str()) {
213 return Value::String(abbrev.to_string());
214 }
215 }
216
217 value.clone()
218 },
219 Value::Object(_) => self.compress_value(value, Some(key)),
220 Value::Array(arr) => Value::Array(
221 arr.iter()
222 .map(|v| self.compress_value(v, Some(key)))
223 .collect(),
224 ),
225 _ => value.clone(),
226 }
227 }
228
229 fn expand_value(&self, value: &Value, parent_key: Option<&str>) -> Value {
230 match value {
231 Value::Object(map) => {
232 let mut result = Map::new();
233
234 for (key, val) in map {
235 let expanded_key = KEY_EXPAND
237 .get(key.as_str())
238 .map(|s| s.to_string())
239 .unwrap_or_else(|| key.clone());
240
241 let expanded_val = self.expand_value(val, Some(&expanded_key));
243 result.insert(expanded_key, expanded_val);
244 }
245
246 Value::Object(result)
247 },
248 Value::Array(arr) => Value::Array(
249 arr.iter()
250 .map(|v| self.expand_value(v, parent_key))
251 .collect(),
252 ),
253 Value::String(s) => {
254 if let Some(key) = parent_key {
255 if key == "role" || key == "r" {
257 if let Some(expanded) = ROLE_EXPAND.get(s.as_str()) {
258 return Value::String(expanded.to_string());
259 }
260 }
261
262 if key == "model" || key == "M" {
264 if let Some(expanded) = MODEL_EXPAND.get(s.as_str()) {
265 return Value::String(expanded.to_string());
266 }
267 }
268 }
269
270 value.clone()
271 },
272 _ => value.clone(),
273 }
274 }
275
276 fn restore_defaults(&self, value: &Value) -> Value {
282 match value {
283 Value::Object(map) => {
284 let is_llm_request = map.contains_key("messages") || map.contains_key("model");
286
287 if is_llm_request {
288 let mut result = map.clone();
289
290 if !result.contains_key("temperature") {
292 result.insert("temperature".to_string(), Value::from(1.0));
293 }
294 if !result.contains_key("top_p") {
295 result.insert("top_p".to_string(), Value::from(1.0));
296 }
297 if !result.contains_key("n") {
298 result.insert("n".to_string(), Value::from(1));
299 }
300 if !result.contains_key("stream") {
301 result.insert("stream".to_string(), Value::Bool(false));
302 }
303 if !result.contains_key("frequency_penalty") {
304 result.insert("frequency_penalty".to_string(), Value::from(0));
305 }
306 if !result.contains_key("presence_penalty") {
307 result.insert("presence_penalty".to_string(), Value::from(0));
308 }
309
310 Value::Object(result)
311 } else {
312 let mut result = Map::new();
314 for (key, val) in map {
315 result.insert(key.clone(), self.restore_defaults(val));
316 }
317 Value::Object(result)
318 }
319 },
320 Value::Array(arr) => {
321 Value::Array(arr.iter().map(|v| self.restore_defaults(v)).collect())
322 },
323 _ => value.clone(),
324 }
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Compresses `input` with a default codec and returns only the wire string.
    #[allow(deprecated)]
    fn wire_for(input: &Value) -> String {
        let (wire, _original_len, _wire_len) = TokenCodec::new().compress(input).unwrap();
        wire
    }

    /// Decodes `wire` with a default codec.
    fn decode(wire: &str) -> Value {
        TokenCodec::new().decompress(wire).unwrap()
    }

    #[test]
    fn test_compress_basic() {
        let wire = wire_for(&json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "user", "content": "Hello"}
            ]
        }));

        assert!(wire.starts_with("#T1|"));
        // Model value and key are abbreviated; tolerate either spacing.
        assert!(wire.contains("\"M\":\"g4o\"") || wire.contains("\"M\": \"g4o\""));
        // `content` is abbreviated to `c`.
        assert!(wire.contains("\"c\""));
    }

    #[test]
    fn test_pattern_compression() {
        let wire = wire_for(&json!({
            "messages": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"}
            ]
        }));

        println!("Compressed: {}", wire);

        let restored = decode(&wire);
        assert_eq!(restored["messages"][0]["content"], "Hello");
        assert_eq!(restored["messages"][1]["content"], "Hi there!");
    }

    #[test]
    fn test_roundtrip() {
        let request = json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"}
            ],
            "max_tokens": 100
        });

        let restored = decode(&wire_for(&request));

        for index in 0..2 {
            assert_eq!(
                restored["messages"][index]["content"],
                request["messages"][index]["content"]
            );
        }
    }

    #[test]
    fn test_removes_defaults() {
        let wire = wire_for(&json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hi"}],
            "temperature": 1.0,
            "stream": false,
            "n": 1
        }));

        assert!(!wire.contains("temperature"));
        assert!(!wire.contains("stream"));
    }

    #[test]
    fn test_role_abbreviation() {
        let restored = decode(&wire_for(&json!({
            "messages": [
                {"role": "system", "content": "Be helpful"},
                {"role": "assistant", "content": "OK"}
            ]
        })));

        assert_eq!(restored["messages"][0]["role"], "system");
        assert_eq!(restored["messages"][1]["role"], "assistant");
    }

    #[test]
    fn test_default_restoration() {
        let restored = decode(&wire_for(&json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 1.0,
            "top_p": 1.0,
            "n": 1,
            "stream": false,
            "frequency_penalty": 0,
            "presence_penalty": 0
        })));

        // Defaults stripped on the way in come back on the way out.
        assert_eq!(restored["temperature"], 1.0);
        assert_eq!(restored["top_p"], 1.0);
        assert_eq!(restored["n"], 1);
        assert_eq!(restored["stream"], false);
        assert_eq!(restored["frequency_penalty"], 0);
        assert_eq!(restored["presence_penalty"], 0);

        // The rest of the request is untouched.
        assert_eq!(restored["model"], "gpt-4o");
        assert_eq!(restored["messages"][0]["content"], "Hello");
    }

    #[test]
    fn test_default_restoration_preserves_non_defaults() {
        let restored = decode(&wire_for(&json!({
            "model": "gpt-4o",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "top_p": 0.9,
            "n": 2,
            "stream": true,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.5
        })));

        assert_eq!(restored["temperature"], 0.7);
        assert_eq!(restored["top_p"], 0.9);
        assert_eq!(restored["n"], 2);
        assert_eq!(restored["stream"], true);
        assert_eq!(restored["frequency_penalty"], 0.5);
        assert_eq!(restored["presence_penalty"], 0.5);
    }

    #[test]
    fn test_roundtrip_with_defaults_complete() {
        let wire = wire_for(&json!({
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"}
            ],
            "temperature": 1.0,
            "stream": false
        }));

        // Default-valued parameters must not appear on the wire.
        assert!(!wire.contains("temperature"));
        assert!(!wire.contains("stream"));

        let restored = decode(&wire);

        // Every default parameter is present after decoding, including the
        // ones the original request never mentioned.
        for name in [
            "temperature",
            "stream",
            "top_p",
            "n",
            "frequency_penalty",
            "presence_penalty",
        ] {
            assert!(restored.get(name).is_some());
        }
    }

    #[test]
    fn test_pattern_expansion_roundtrip() {
        let codec = TokenCodec::new();

        for (pattern, abbrev) in super::super::tables::PATTERN_ABBREV {
            let shortened = codec.apply_pattern_compression(pattern);
            assert_eq!(
                shortened, *abbrev,
                "Pattern compression failed for: {}",
                pattern
            );

            let lengthened = codec.apply_pattern_expansion(&shortened);
            assert_eq!(
                lengthened, *pattern,
                "Pattern expansion failed for: {}",
                pattern
            );
        }
    }
}