1use std::io::{Cursor, Read};
45
46use crate::error::{M2MError, Result};
47
/// ASCII marker at the start of every M3 payload.
///
/// `encode_request` writes it before anything else, and `decode_request` /
/// `is_m3_format` check for it.
pub const M3_PREFIX: &str = "#M3|";
50
/// One-byte tag naming the layout carried by an M3 payload.
///
/// The discriminants are the byte values written to the wire, so they must stay stable.
#[repr(u8)]
// All variants share the `Chat` prefix on purpose; silence the naming lint.
#[allow(clippy::enum_variant_names)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Schema {
    ChatCompletionRequest = 0x01,
    ChatCompletionResponse = 0x02,
    ChatMessage = 0x03,
}

impl Schema {
    /// Every schema tag that has a defined wire value.
    const KNOWN: [Schema; 3] = [
        Schema::ChatCompletionRequest,
        Schema::ChatCompletionResponse,
        Schema::ChatMessage,
    ];

    /// Maps a raw tag byte back to its schema, or `None` for an unassigned value.
    fn from_byte(b: u8) -> Option<Self> {
        Self::KNOWN
            .iter()
            .copied()
            .find(|schema| *schema as u8 == b)
    }
}
74
/// Author of a chat message.
///
/// The discriminant is the single byte `encode_request` stores for each message.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Role {
    System = 0,
    User = 1,
    Assistant = 2,
    Tool = 3,
}

impl Role {
    /// All roles; the byte and string lookups below scan this table.
    const ALL: [Role; 4] = [Role::System, Role::User, Role::Assistant, Role::Tool];

    /// Maps a wire byte to its role, or `None` when the byte is not assigned.
    fn from_byte(b: u8) -> Option<Self> {
        Self::ALL.iter().copied().find(|role| *role as u8 == b)
    }

    /// Parses the lowercase JSON role name (exact, case-sensitive match).
    fn from_str(s: &str) -> Option<Self> {
        Self::ALL.iter().copied().find(|role| role.as_str() == s)
    }

    /// Returns the lowercase JSON name of the role.
    fn as_str(&self) -> &'static str {
        match *self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        }
    }
}
115
/// Bit set stored in one byte of an encoded request.
///
/// The `HAS_*` bits say which optional parameters follow the message list;
/// `STREAM` carries the boolean `stream` value itself.
#[derive(Clone, Copy, Debug, Default)]
pub struct ParamFlags(u8);

impl ParamFlags {
    pub const HAS_TEMPERATURE: u8 = 0x01;
    pub const HAS_MAX_TOKENS: u8 = 0x02;
    pub const HAS_TOP_P: u8 = 0x04;
    pub const STREAM: u8 = 0x08;
    pub const HAS_STOP: u8 = 0x10;

    /// Creates a flag set with no bits set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns on every bit present in `flag`; bits already set stay set.
    pub fn set(&mut self, flag: u8) {
        let Self(bits) = self;
        *bits |= flag;
    }

    /// Returns `true` when at least one bit of `flag` is set.
    pub fn has(&self, flag: u8) -> bool {
        let Self(bits) = *self;
        (bits & flag) > 0
    }

    /// Returns the raw byte as it is written to the wire.
    pub fn as_byte(&self) -> u8 {
        let Self(bits) = *self;
        bits
    }

    /// Wraps a raw byte without validating it; unknown bits are kept as-is.
    pub fn from_byte(b: u8) -> Self {
        ParamFlags(b)
    }
}
147
148#[derive(Debug, Clone)]
150pub struct M3Message {
151 pub role: Role,
153 pub content: String,
155 pub name: Option<String>,
157}
158
159#[derive(Debug, Clone, Default)]
161pub struct M3ChatRequest {
162 pub model: String,
164 pub messages: Vec<M3Message>,
166 pub temperature: Option<f32>,
168 pub max_tokens: Option<u32>,
170 pub top_p: Option<f32>,
172 pub stream: bool,
174 pub stop: Option<Vec<String>>,
176}
177
/// Stateless encoder/decoder for the M3 chat-request format.
#[derive(Clone, Debug, Default)]
pub struct M3Codec;
181
182impl M3Codec {
183 pub fn new() -> Self {
185 Self
186 }
187
188 pub fn encode_request(&self, req: &M3ChatRequest) -> Result<Vec<u8>> {
190 let mut buf = Vec::with_capacity(256);
191
192 buf.extend_from_slice(M3_PREFIX.as_bytes());
194
195 buf.push(Schema::ChatCompletionRequest as u8);
197
198 write_varint(&mut buf, req.model.len() as u64);
200 buf.extend_from_slice(req.model.as_bytes());
201
202 let mut flags = ParamFlags::new();
204 if req.temperature.is_some() {
205 flags.set(ParamFlags::HAS_TEMPERATURE);
206 }
207 if req.max_tokens.is_some() {
208 flags.set(ParamFlags::HAS_MAX_TOKENS);
209 }
210 if req.top_p.is_some() {
211 flags.set(ParamFlags::HAS_TOP_P);
212 }
213 if req.stream {
214 flags.set(ParamFlags::STREAM);
215 }
216 if req.stop.is_some() {
217 flags.set(ParamFlags::HAS_STOP);
218 }
219 buf.push(flags.as_byte());
220
221 write_varint(&mut buf, req.messages.len() as u64);
223
224 for msg in &req.messages {
226 buf.push(msg.role as u8);
227 write_varint(&mut buf, msg.content.len() as u64);
228 buf.extend_from_slice(msg.content.as_bytes());
229 }
230
231 if let Some(temp) = req.temperature {
233 let quantized = (temp * 100.0).round() as u8;
235 buf.push(quantized);
236 }
237 if let Some(max_tok) = req.max_tokens {
238 write_varint(&mut buf, max_tok as u64);
239 }
240 if let Some(top_p) = req.top_p {
241 let quantized = (top_p * 100.0).round() as u8;
242 buf.push(quantized);
243 }
244 if let Some(ref stops) = req.stop {
246 write_varint(&mut buf, stops.len() as u64);
247 for stop in stops {
248 write_varint(&mut buf, stop.len() as u64);
249 buf.extend_from_slice(stop.as_bytes());
250 }
251 }
252
253 Ok(buf)
254 }
255
256 pub fn decode_request(&self, data: &[u8]) -> Result<M3ChatRequest> {
258 if !data.starts_with(M3_PREFIX.as_bytes()) {
260 return Err(M2MError::Decompression("Invalid M3 prefix".to_string()));
261 }
262
263 let mut cursor = Cursor::new(&data[M3_PREFIX.len()..]);
264
265 let mut schema_byte = [0u8; 1];
267 cursor
268 .read_exact(&mut schema_byte)
269 .map_err(|e| M2MError::Decompression(e.to_string()))?;
270
271 if Schema::from_byte(schema_byte[0]) != Some(Schema::ChatCompletionRequest) {
272 return Err(M2MError::Decompression(format!(
273 "Expected ChatCompletionRequest schema, got {:02x}",
274 schema_byte[0]
275 )));
276 }
277
278 let model_len = read_varint(&mut cursor)? as usize;
280 let mut model_bytes = vec![0u8; model_len];
281 cursor
282 .read_exact(&mut model_bytes)
283 .map_err(|e| M2MError::Decompression(e.to_string()))?;
284 let model =
285 String::from_utf8(model_bytes).map_err(|e| M2MError::Decompression(e.to_string()))?;
286
287 let mut flags_byte = [0u8; 1];
289 cursor
290 .read_exact(&mut flags_byte)
291 .map_err(|e| M2MError::Decompression(e.to_string()))?;
292 let flags = ParamFlags::from_byte(flags_byte[0]);
293
294 let num_messages = read_varint(&mut cursor)? as usize;
296
297 let mut messages = Vec::with_capacity(num_messages);
299 for _ in 0..num_messages {
300 let mut role_byte = [0u8; 1];
301 cursor
302 .read_exact(&mut role_byte)
303 .map_err(|e| M2MError::Decompression(e.to_string()))?;
304 let role = Role::from_byte(role_byte[0])
305 .ok_or_else(|| M2MError::Decompression("Invalid role byte".to_string()))?;
306
307 let content_len = read_varint(&mut cursor)? as usize;
308 let mut content_bytes = vec![0u8; content_len];
309 cursor
310 .read_exact(&mut content_bytes)
311 .map_err(|e| M2MError::Decompression(e.to_string()))?;
312 let content = String::from_utf8(content_bytes)
313 .map_err(|e| M2MError::Decompression(e.to_string()))?;
314
315 messages.push(M3Message {
316 role,
317 content,
318 name: None,
319 });
320 }
321
322 let temperature = if flags.has(ParamFlags::HAS_TEMPERATURE) {
324 let mut temp_byte = [0u8; 1];
325 cursor
326 .read_exact(&mut temp_byte)
327 .map_err(|e| M2MError::Decompression(e.to_string()))?;
328 Some(temp_byte[0] as f32 / 100.0)
329 } else {
330 None
331 };
332
333 let max_tokens = if flags.has(ParamFlags::HAS_MAX_TOKENS) {
334 Some(read_varint(&mut cursor)? as u32)
335 } else {
336 None
337 };
338
339 let top_p = if flags.has(ParamFlags::HAS_TOP_P) {
340 let mut top_p_byte = [0u8; 1];
341 cursor
342 .read_exact(&mut top_p_byte)
343 .map_err(|e| M2MError::Decompression(e.to_string()))?;
344 Some(top_p_byte[0] as f32 / 100.0)
345 } else {
346 None
347 };
348
349 let stop = if flags.has(ParamFlags::HAS_STOP) {
350 let num_stops = read_varint(&mut cursor)? as usize;
351 let mut stops = Vec::with_capacity(num_stops);
352 for _ in 0..num_stops {
353 let stop_len = read_varint(&mut cursor)? as usize;
354 let mut stop_bytes = vec![0u8; stop_len];
355 cursor
356 .read_exact(&mut stop_bytes)
357 .map_err(|e| M2MError::Decompression(e.to_string()))?;
358 let stop_str = String::from_utf8(stop_bytes)
359 .map_err(|e| M2MError::Decompression(e.to_string()))?;
360 stops.push(stop_str);
361 }
362 Some(stops)
363 } else {
364 None
365 };
366
367 Ok(M3ChatRequest {
368 model,
369 messages,
370 temperature,
371 max_tokens,
372 top_p,
373 stream: flags.has(ParamFlags::STREAM),
374 stop,
375 })
376 }
377
378 pub fn from_json(&self, json: &str) -> Result<M3ChatRequest> {
380 let value: serde_json::Value = serde_json::from_str(json)
381 .map_err(|e| M2MError::Decompression(format!("Invalid JSON: {}", e)))?;
382
383 let model = value
384 .get("model")
385 .and_then(|v| v.as_str())
386 .unwrap_or("")
387 .to_string();
388
389 let messages = value
390 .get("messages")
391 .and_then(|v| v.as_array())
392 .map(|arr| {
393 arr.iter()
394 .filter_map(|m| {
395 let role_str = m.get("role").and_then(|r| r.as_str())?;
396 let role = Role::from_str(role_str)?;
397 let content = m
398 .get("content")
399 .and_then(|c| c.as_str())
400 .unwrap_or("")
401 .to_string();
402 let name = m
403 .get("name")
404 .and_then(|n| n.as_str())
405 .map(|s| s.to_string());
406 Some(M3Message {
407 role,
408 content,
409 name,
410 })
411 })
412 .collect()
413 })
414 .unwrap_or_default();
415
416 let temperature = value
417 .get("temperature")
418 .and_then(|v| v.as_f64())
419 .map(|f| f as f32);
420 let max_tokens = value
421 .get("max_tokens")
422 .and_then(|v| v.as_u64())
423 .map(|n| n as u32);
424 let top_p = value
425 .get("top_p")
426 .and_then(|v| v.as_f64())
427 .map(|f| f as f32);
428 let stream = value
429 .get("stream")
430 .and_then(|v| v.as_bool())
431 .unwrap_or(false);
432 let stop = value.get("stop").and_then(|v| {
433 if let Some(arr) = v.as_array() {
434 Some(
435 arr.iter()
436 .filter_map(|s| s.as_str().map(|s| s.to_string()))
437 .collect(),
438 )
439 } else if let Some(s) = v.as_str() {
440 Some(vec![s.to_string()])
441 } else {
442 None
443 }
444 });
445
446 Ok(M3ChatRequest {
447 model,
448 messages,
449 temperature,
450 max_tokens,
451 top_p,
452 stream,
453 stop,
454 })
455 }
456
457 pub fn to_json(&self, req: &M3ChatRequest) -> String {
459 let mut obj = serde_json::Map::new();
460
461 obj.insert("model".to_string(), serde_json::json!(req.model));
462
463 let messages: Vec<serde_json::Value> = req
464 .messages
465 .iter()
466 .map(|m| {
467 let mut msg = serde_json::Map::new();
468 msg.insert("role".to_string(), serde_json::json!(m.role.as_str()));
469 msg.insert("content".to_string(), serde_json::json!(m.content));
470 if let Some(ref name) = m.name {
471 msg.insert("name".to_string(), serde_json::json!(name));
472 }
473 serde_json::Value::Object(msg)
474 })
475 .collect();
476 obj.insert("messages".to_string(), serde_json::Value::Array(messages));
477
478 if let Some(temp) = req.temperature {
479 obj.insert("temperature".to_string(), serde_json::json!(temp));
480 }
481 if let Some(max_tok) = req.max_tokens {
482 obj.insert("max_tokens".to_string(), serde_json::json!(max_tok));
483 }
484 if let Some(top_p) = req.top_p {
485 obj.insert("top_p".to_string(), serde_json::json!(top_p));
486 }
487 if req.stream {
488 obj.insert("stream".to_string(), serde_json::json!(true));
489 }
490 if let Some(ref stop) = req.stop {
491 obj.insert("stop".to_string(), serde_json::json!(stop));
492 }
493
494 serde_json::to_string(&serde_json::Value::Object(obj)).unwrap_or_default()
495 }
496
497 #[deprecated(note = "Use M2M codec instead")]
501 pub fn compress(&self, json: &str) -> Result<(String, usize, usize)> {
502 let req = self.from_json(json)?;
503 let encoded = self.encode_request(&req)?;
504
505 let wire = format!("{}", String::from_utf8_lossy(&encoded));
507
508 Ok((wire, json.len(), encoded.len()))
509 }
510
511 pub fn decompress(&self, wire: &str) -> Result<String> {
513 let req = self.decode_request(wire.as_bytes())?;
514 Ok(self.to_json(&req))
515 }
516
517 pub fn is_m3_format(content: &str) -> bool {
519 content.starts_with(M3_PREFIX)
520 }
521}
522
/// Appends `value` to `buf` as a little-endian base-128 varint.
///
/// Each byte carries seven payload bits, least-significant group first; the high
/// bit is set on every byte except the last. Zero is written as a single `0x00`.
fn write_varint(buf: &mut Vec<u8>, mut value: u64) {
    // Emit continuation bytes while more than seven bits remain.
    while value >= 0x80 {
        buf.push(((value & 0x7F) as u8) | 0x80);
        value >>= 7;
    }
    // The final group fits in seven bits, so the high bit stays clear.
    buf.push(value as u8);
}
537
538fn read_varint<R: Read>(reader: &mut R) -> Result<u64> {
539 let mut result: u64 = 0;
540 let mut shift = 0;
541
542 loop {
543 let mut byte = [0u8; 1];
544 reader
545 .read_exact(&mut byte)
546 .map_err(|e| M2MError::Decompression(format!("VarInt read error: {}", e)))?;
547
548 result |= ((byte[0] & 0x7F) as u64) << shift;
549
550 if byte[0] & 0x80 == 0 {
551 break;
552 }
553
554 shift += 7;
555 if shift >= 64 {
556 return Err(M2MError::Decompression("VarInt overflow".to_string()));
557 }
558 }
559
560 Ok(result)
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message with no `name`.
    fn message(role: Role, content: &str) -> M3Message {
        M3Message {
            role,
            content: content.to_string(),
            name: None,
        }
    }

    /// Binary encode followed by decode keeps model, message contents and limits.
    #[test]
    fn test_encode_decode_roundtrip() {
        let original = M3ChatRequest {
            model: "gpt-4o".to_string(),
            messages: vec![
                message(Role::System, "You are a helpful assistant."),
                message(Role::User, "Hello!"),
            ],
            temperature: Some(0.7),
            max_tokens: Some(1000),
            top_p: None,
            stream: false,
            stop: None,
        };

        let codec = M3Codec::new();
        let wire = codec.encode_request(&original).unwrap();
        let decoded = codec.decode_request(&wire).unwrap();

        assert_eq!(original.model, decoded.model);
        assert_eq!(original.messages.len(), decoded.messages.len());
        for index in 0..2 {
            assert_eq!(
                original.messages[index].content,
                decoded.messages[index].content
            );
        }
        // Temperature is stored as hundredths, so allow for quantization error.
        let drift = (original.temperature.unwrap() - decoded.temperature.unwrap()).abs();
        assert!(drift < 0.02);
        assert_eq!(original.max_tokens, decoded.max_tokens);
    }

    /// JSON -> request -> JSON keeps the model and the message contents.
    #[test]
    fn test_json_roundtrip() {
        let json = r#"{"model":"gpt-4o","messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"Hi!"}],"temperature":0.7,"max_tokens":100}"#;

        let codec = M3Codec::new();
        let parsed = codec.from_json(json).unwrap();
        let rendered = codec.to_json(&parsed);

        let original: serde_json::Value = serde_json::from_str(json).unwrap();
        let recovered: serde_json::Value = serde_json::from_str(&rendered).unwrap();

        assert_eq!(original["model"], recovered["model"]);
        for index in 0..2 {
            assert_eq!(
                original["messages"][index]["content"],
                recovered["messages"][index]["content"]
            );
        }
    }

    /// The encoded form of a small request is shorter than its JSON.
    #[test]
    #[allow(deprecated)]
    fn test_compression_savings() {
        let json = r#"{"model":"gpt-4o","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"What is 2+2?"}],"temperature":0.7}"#;

        let codec = M3Codec::new();
        let (_, original_bytes, compressed_bytes) = codec.compress(json).unwrap();

        let savings = (1.0 - compressed_bytes as f64 / original_bytes as f64) * 100.0;
        println!("Original JSON: {} bytes", json.len());
        println!("M3 encoded: {} bytes", compressed_bytes);
        println!("Savings: {:.1}%", savings);

        assert!(
            compressed_bytes < original_bytes,
            "M3 should compress the data"
        );
    }

    /// Known varint byte patterns, plus a write/read round trip.
    #[test]
    fn test_varint_encoding() {
        let known: Vec<(u64, Vec<u8>)> = vec![
            (0, vec![0]),
            (127, vec![127]),
            (128, vec![0x80, 0x01]),
            (300, vec![0xAC, 0x02]),
        ];
        for (value, expected) in known {
            let mut buf = Vec::new();
            write_varint(&mut buf, value);
            assert_eq!(buf, expected);
        }

        let mut buf = Vec::new();
        write_varint(&mut buf, 12345);
        let mut cursor = Cursor::new(&buf);
        let value = read_varint(&mut cursor).unwrap();
        assert_eq!(value, 12345);
    }
}