1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
46use std::sync::OnceLock;
47use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
48
49use super::{Algorithm, CompressionResult};
50use crate::error::{M2MError, Result};
51use crate::models::Encoding;
52
// Process-wide, lazily initialized tokenizers. OnceLock guarantees the
// (expensive) vocabulary load happens at most once, thread-safely.
static CL100K: OnceLock<CoreBPE> = OnceLock::new();
static O200K: OnceLock<CoreBPE> = OnceLock::new();
56
57fn get_cl100k() -> &'static CoreBPE {
58 CL100K.get_or_init(|| cl100k_base().expect("Failed to load cl100k_base tokenizer"))
59}
60
61fn get_o200k() -> &'static CoreBPE {
62 O200K.get_or_init(|| o200k_base().expect("Failed to load o200k_base tokenizer"))
63}
64
/// Codec that compresses text by tokenizing it with a BPE vocabulary and
/// serializing the resulting token IDs as varints, either base64-wrapped in a
/// `#TK|…` text wire format or as a tagged binary buffer.
#[derive(Debug, Clone, Copy)]
pub struct TokenNativeCodec {
    /// BPE encoding used for tokenization; also recorded in the wire formats.
    encoding: Encoding,
}
73
74impl TokenNativeCodec {
75 pub fn new(encoding: Encoding) -> Self {
77 Self { encoding }
78 }
79
80 pub fn cl100k() -> Self {
82 Self::new(Encoding::Cl100kBase)
83 }
84
85 pub fn o200k() -> Self {
87 Self::new(Encoding::O200kBase)
88 }
89
90 pub fn encoding(&self) -> Encoding {
92 self.encoding
93 }
94
95 fn tokenizer_id(&self) -> char {
97 match self.encoding {
98 Encoding::Cl100kBase => 'C',
99 Encoding::O200kBase => 'O',
100 Encoding::LlamaBpe => 'L',
101 Encoding::Heuristic => 'C', }
103 }
104
105 fn encoding_from_id(id: char) -> Encoding {
107 match id {
108 'C' => Encoding::Cl100kBase,
109 'O' => Encoding::O200kBase,
110 'L' => Encoding::LlamaBpe,
111 _ => Encoding::Cl100kBase, }
113 }
114
115 fn tokenize(&self, text: &str) -> Vec<u32> {
117 match self.encoding {
118 Encoding::Cl100kBase => get_cl100k().encode_with_special_tokens(text),
119 Encoding::O200kBase => get_o200k().encode_with_special_tokens(text),
120 Encoding::LlamaBpe => {
121 get_cl100k().encode_with_special_tokens(text)
123 },
124 Encoding::Heuristic => {
125 get_cl100k().encode_with_special_tokens(text)
127 },
128 }
129 }
130
131 fn detokenize(&self, tokens: &[u32]) -> Result<String> {
133 let result = match self.encoding {
134 Encoding::Cl100kBase => get_cl100k().decode(tokens.to_vec()),
135 Encoding::O200kBase => get_o200k().decode(tokens.to_vec()),
136 Encoding::LlamaBpe => get_cl100k().decode(tokens.to_vec()),
137 Encoding::Heuristic => get_cl100k().decode(tokens.to_vec()),
138 };
139
140 result.map_err(|e| M2MError::Decompression(format!("Detokenization failed: {}", e)))
141 }
142
143 pub fn compress(&self, text: &str) -> Result<CompressionResult> {
145 let original_bytes = text.len();
146
147 let tokens = self.tokenize(text);
149 let token_count = tokens.len();
150
151 let varint_bytes = varint_encode(&tokens);
153
154 let encoded = BASE64.encode(&varint_bytes);
156
157 let wire = format!("#TK|{}|{}", self.tokenizer_id(), encoded);
159 let compressed_bytes = wire.len();
160
161 Ok(CompressionResult {
162 data: wire,
163 algorithm: Algorithm::TokenNative,
164 original_bytes,
165 compressed_bytes,
166 original_tokens: Some(token_count),
167 compressed_tokens: Some(token_count), })
169 }
170
171 pub fn decompress(&self, wire: &str) -> Result<String> {
173 let content = wire
175 .strip_prefix("#TK|")
176 .ok_or_else(|| M2MError::Decompression("Invalid token-native format".to_string()))?;
177
178 let mut parts = content.splitn(2, '|');
180 let tokenizer_id = parts
181 .next()
182 .and_then(|s| s.chars().next())
183 .ok_or_else(|| M2MError::Decompression("Missing tokenizer ID".to_string()))?;
184
185 let encoded_data = parts
186 .next()
187 .ok_or_else(|| M2MError::Decompression("Missing encoded data".to_string()))?;
188
189 let wire_encoding = Self::encoding_from_id(tokenizer_id);
191
192 let varint_bytes = BASE64
194 .decode(encoded_data)
195 .map_err(|e| M2MError::Decompression(format!("Base64 decode failed: {}", e)))?;
196
197 let tokens = varint_decode(&varint_bytes)?;
199
200 let wire_codec = TokenNativeCodec::new(wire_encoding);
202 wire_codec.detokenize(&tokens)
203 }
204
205 pub fn compress_raw(&self, text: &str) -> Vec<u8> {
207 let tokens = self.tokenize(text);
208 varint_encode(&tokens)
209 }
210
211 pub fn decompress_raw(&self, bytes: &[u8]) -> Result<String> {
213 let tokens = varint_decode(bytes)?;
214 self.detokenize(&tokens)
215 }
216
217 pub fn compress_binary(&self, text: &str) -> Vec<u8> {
226 let tokens = self.tokenize(text);
227 let mut result = Vec::with_capacity(1 + tokens.len() * 2);
228
229 result.push(self.tokenizer_id_byte());
231
232 result.extend(varint_encode(&tokens));
234
235 result
236 }
237
238 pub fn decompress_binary(bytes: &[u8]) -> Result<String> {
240 if bytes.is_empty() {
241 return Err(M2MError::Decompression("Empty binary data".to_string()));
242 }
243
244 let tokenizer_byte = bytes[0];
246 let encoding = Self::encoding_from_byte(tokenizer_byte);
247
248 let tokens = varint_decode(&bytes[1..])?;
250
251 let codec = TokenNativeCodec::new(encoding);
253 codec.detokenize(&tokens)
254 }
255
256 fn tokenizer_id_byte(&self) -> u8 {
258 match self.encoding {
259 Encoding::Cl100kBase => 0,
260 Encoding::O200kBase => 1,
261 Encoding::LlamaBpe => 2,
262 Encoding::Heuristic => 0, }
264 }
265
266 fn encoding_from_byte(byte: u8) -> Encoding {
268 match byte {
269 0 => Encoding::Cl100kBase,
270 1 => Encoding::O200kBase,
271 2 => Encoding::LlamaBpe,
272 _ => Encoding::Cl100kBase, }
274 }
275}
276
277impl Default for TokenNativeCodec {
278 fn default() -> Self {
279 Self::cl100k()
280 }
281}
282
/// Serializes token IDs as LEB128-style varints: each value is written
/// little-endian in 7-bit groups, with the high bit of every byte except the
/// last set as a continuation marker. IDs below 128 take a single byte.
fn varint_encode(tokens: &[u32]) -> Vec<u8> {
    // Two bytes per token is a reasonable capacity guess for BPE vocabularies.
    let mut out = Vec::with_capacity(tokens.len() * 2);

    for &token in tokens {
        let mut remaining = token;
        // Emit continuation bytes while more than 7 bits remain.
        while remaining >= 0x80 {
            out.push((remaining as u8 & 0x7F) | 0x80);
            remaining >>= 7;
        }
        // Final byte: high bit clear terminates the value.
        out.push(remaining as u8);
    }

    out
}
311
312fn varint_decode(bytes: &[u8]) -> Result<Vec<u32>> {
314 let mut tokens = Vec::new();
315 let mut i = 0;
316
317 while i < bytes.len() {
318 let mut value: u32 = 0;
319 let mut shift = 0;
320
321 loop {
322 if i >= bytes.len() {
323 return Err(M2MError::Decompression("Truncated VarInt data".to_string()));
324 }
325
326 let byte = bytes[i];
327 i += 1;
328
329 value |= ((byte & 0x7F) as u32) << shift;
330 shift += 7;
331
332 if byte & 0x80 == 0 {
333 break; }
335
336 if shift > 35 {
337 return Err(M2MError::Decompression("VarInt overflow".to_string()));
338 }
339 }
340
341 tokens.push(value);
342 }
343
344 Ok(tokens)
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    // Varint round-trip across the interesting boundaries: 0, the 1/2-byte
    // boundary (127/128), the 2/3-byte boundary (16383/16384), and a
    // realistic large token ID.
    #[test]
    fn test_varint_encode_decode() {
        let tokens: Vec<u32> = vec![0, 1, 127, 128, 255, 256, 16383, 16384, 100000];
        let encoded = varint_encode(&tokens);
        let decoded = varint_decode(&encoded).unwrap();
        assert_eq!(tokens, decoded);
    }

    // Small token IDs (0..1000) should average well under 2 bytes each,
    // since values below 128 take a single byte.
    #[test]
    fn test_varint_efficiency() {
        let small_tokens: Vec<u32> = (0..1000).collect();
        let encoded = varint_encode(&small_tokens);

        let avg_bytes = encoded.len() as f64 / small_tokens.len() as f64;
        assert!(
            avg_bytes < 2.0,
            "Average bytes per token: {} (expected < 2.0)",
            avg_bytes
        );
    }

    // Text wire format round-trip, including the "#TK|C|" header check.
    #[test]
    fn test_compress_decompress_roundtrip() {
        let codec = TokenNativeCodec::cl100k();

        let original =
            r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello, world!"}]}"#;

        let compressed = codec.compress(original).unwrap();
        assert!(compressed.data.starts_with("#TK|C|"));

        let decompressed = codec.decompress(&compressed.data).unwrap();
        assert_eq!(original, decompressed);
    }

    // The wire string should be meaningfully smaller than the JSON input
    // even with base64 overhead.
    #[test]
    fn test_compression_ratio() {
        let codec = TokenNativeCodec::cl100k();

        let original = r#"{"model":"gpt-4o","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"What is the capital of France?"}]}"#;

        let compressed = codec.compress(original).unwrap();

        let ratio = compressed.compressed_bytes as f64 / compressed.original_bytes as f64;
        println!(
            "Compression: {} -> {} bytes ({:.1}% of original)",
            compressed.original_bytes,
            compressed.compressed_bytes,
            ratio * 100.0
        );

        assert!(
            ratio < 0.85,
            "Expected compression ratio < 0.85, got {}",
            ratio
        );
    }

    // Both real tokenizers (cl100k and o200k) must round-trip.
    #[test]
    fn test_different_encodings() {
        let original = "Hello, how are you today?";

        let codec_cl100k = TokenNativeCodec::cl100k();
        let compressed = codec_cl100k.compress(original).unwrap();
        let decompressed = codec_cl100k.decompress(&compressed.data).unwrap();
        assert_eq!(original, decompressed);

        let codec_o200k = TokenNativeCodec::o200k();
        let compressed = codec_o200k.compress(original).unwrap();
        let decompressed = codec_o200k.decompress(&compressed.data).unwrap();
        assert_eq!(original, decompressed);
    }

    // Round-trip a larger payload with highly repetitive content.
    #[test]
    fn test_large_content() {
        let codec = TokenNativeCodec::cl100k();

        let original = format!(
            r#"{{"model":"gpt-4o","messages":[{{"role":"system","content":"You are helpful."}},{{"role":"user","content":"{}"}}]}}"#,
            "Hello world! ".repeat(100)
        );

        let compressed = codec.compress(&original).unwrap();
        let decompressed = codec.decompress(&compressed.data).unwrap();

        assert_eq!(original, decompressed);

        let ratio = compressed.compressed_bytes as f64 / compressed.original_bytes as f64;
        println!(
            "Large content: {} -> {} bytes ({:.1}% of original)",
            compressed.original_bytes,
            compressed.compressed_bytes,
            ratio * 100.0
        );
    }

    // Headerless raw varint path (compress_raw / decompress_raw).
    #[test]
    fn test_raw_compression() {
        let codec = TokenNativeCodec::cl100k();

        let original = "Hello, world!";
        let raw_bytes = codec.compress_raw(original);
        let decompressed = codec.decompress_raw(&raw_bytes).unwrap();

        assert_eq!(original, decompressed);
    }

    // tokenizer_id <-> encoding_from_id must round-trip for every encoding
    // with a distinct tag (Heuristic intentionally aliases 'C', so excluded).
    #[test]
    fn test_tokenizer_id_roundtrip() {
        for encoding in [
            Encoding::Cl100kBase,
            Encoding::O200kBase,
            Encoding::LlamaBpe,
        ] {
            let codec = TokenNativeCodec::new(encoding);
            let id = codec.tokenizer_id();
            let recovered = TokenNativeCodec::encoding_from_id(id);
            assert_eq!(
                encoding, recovered,
                "Tokenizer ID roundtrip failed for {:?}",
                encoding
            );
        }
    }

    // Binary format: leading tag byte (0 = cl100k), round-trip, and it must
    // beat the base64 text format on size.
    #[test]
    fn test_binary_format() {
        let codec = TokenNativeCodec::cl100k();

        let original = r#"{"model":"gpt-4o","messages":[{"role":"user","content":"Hello!"}]}"#;

        let binary = codec.compress_binary(original);

        assert_eq!(binary[0], 0);

        let decompressed = TokenNativeCodec::decompress_binary(&binary).unwrap();
        assert_eq!(original, decompressed);

        let wire_result = codec.compress(original).unwrap();
        println!(
            "Binary: {} bytes, Wire: {} bytes, Original: {} bytes",
            binary.len(),
            wire_result.compressed_bytes,
            original.len()
        );

        assert!(
            binary.len() < wire_result.compressed_bytes,
            "Binary format should be smaller than wire format"
        );
    }

    // Binary round-trip for both real tokenizers.
    #[test]
    fn test_binary_format_different_encodings() {
        let original = "Hello, how are you today?";

        for encoding in [Encoding::Cl100kBase, Encoding::O200kBase] {
            let codec = TokenNativeCodec::new(encoding);
            let binary = codec.compress_binary(original);
            let decompressed = TokenNativeCodec::decompress_binary(&binary).unwrap();
            assert_eq!(original, decompressed);
        }
    }
}