Skip to main content

grapsus_agent_protocol/
binary.rs

//! Binary protocol for Unix Domain Socket transport.
//!
//! This module provides a binary framing format for efficient communication
//! over UDS, eliminating JSON overhead and base64 encoding for body data.
//!
//! # Wire Format
//!
//! ```text
//! +----------------+---------------+-------------------+
//! | Length (4 BE)  | Type (1 byte) | Payload (N bytes) |
//! +----------------+---------------+-------------------+
//! ```
//!
//! - **Length**: 4-byte big-endian u32, total length of type + payload
//!   (must be non-zero and at most [`MAX_BINARY_MESSAGE_SIZE`])
//! - **Type**: 1-byte message type discriminator
//! - **Payload**: Variable-length payload (format depends on type)
//!
//! # Performance Benefits
//!
//! - No JSON parsing overhead (~10x faster for small messages)
//! - No base64 encoding for body data (base64 would inflate bodies by ~33%)
//! - Zero-copy with `bytes::Bytes` where possible
23
24use bytes::{Buf, BufMut, Bytes, BytesMut};
25use std::collections::HashMap;
26use std::io;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28
29use crate::{AgentProtocolError, Decision, HeaderOp};
30
/// Largest frame, in bytes, that [`BinaryFrame::decode`] accepts: 10 MiB.
///
/// The limit is compared against the frame's length field (type byte plus
/// payload), so it also bounds the buffer allocated for a single frame.
pub const MAX_BINARY_MESSAGE_SIZE: usize = 10 << 20;
33
/// One-byte discriminator that follows the length prefix of every frame.
///
/// Values are grouped by purpose: `0x0_` handshake, `0x1_` request/response
/// lifecycle events, `0x20` the agent's verdict, `0x3_` keep-alive, `0x40`
/// cancellation and `0xFF` errors. The numeric values are part of the wire
/// format and must never be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    /// Handshake opened by the proxy (proxy -> agent).
    HandshakeRequest = 0x01,
    /// Handshake reply from the agent (agent -> proxy).
    HandshakeResponse = 0x02,
    /// Request-headers event.
    RequestHeaders = 0x10,
    /// One chunk of the request body, carried as raw bytes.
    RequestBodyChunk = 0x11,
    /// Response-headers event.
    ResponseHeaders = 0x12,
    /// One chunk of the response body, carried as raw bytes.
    ResponseBodyChunk = 0x13,
    /// The request has finished.
    RequestComplete = 0x14,
    /// A WebSocket frame event.
    WebSocketFrame = 0x15,
    /// The agent's response / decision.
    AgentResponse = 0x20,
    /// Liveness probe.
    Ping = 0x30,
    /// Reply to a `Ping`.
    Pong = 0x31,
    /// Cancels an in-flight request.
    Cancel = 0x40,
    /// Error report.
    Error = 0xFF,
}
65
66impl TryFrom<u8> for MessageType {
67    type Error = AgentProtocolError;
68
69    fn try_from(value: u8) -> Result<Self, AgentProtocolError> {
70        match value {
71            0x01 => Ok(MessageType::HandshakeRequest),
72            0x02 => Ok(MessageType::HandshakeResponse),
73            0x10 => Ok(MessageType::RequestHeaders),
74            0x11 => Ok(MessageType::RequestBodyChunk),
75            0x12 => Ok(MessageType::ResponseHeaders),
76            0x13 => Ok(MessageType::ResponseBodyChunk),
77            0x14 => Ok(MessageType::RequestComplete),
78            0x15 => Ok(MessageType::WebSocketFrame),
79            0x20 => Ok(MessageType::AgentResponse),
80            0x30 => Ok(MessageType::Ping),
81            0x31 => Ok(MessageType::Pong),
82            0x40 => Ok(MessageType::Cancel),
83            0xFF => Ok(MessageType::Error),
84            _ => Err(AgentProtocolError::InvalidMessage(format!(
85                "Unknown message type: 0x{:02x}",
86                value
87            ))),
88        }
89    }
90}
91
92/// Binary frame with header and payload.
93#[derive(Debug, Clone)]
94pub struct BinaryFrame {
95    pub msg_type: MessageType,
96    pub payload: Bytes,
97}
98
99impl BinaryFrame {
100    /// Create a new binary frame.
101    pub fn new(msg_type: MessageType, payload: impl Into<Bytes>) -> Self {
102        Self {
103            msg_type,
104            payload: payload.into(),
105        }
106    }
107
108    /// Encode frame to bytes.
109    pub fn encode(&self) -> Bytes {
110        let payload_len = self.payload.len();
111        let total_len = 1 + payload_len; // type byte + payload
112
113        let mut buf = BytesMut::with_capacity(4 + total_len);
114        buf.put_u32(total_len as u32);
115        buf.put_u8(self.msg_type as u8);
116        buf.put_slice(&self.payload);
117
118        buf.freeze()
119    }
120
121    /// Decode frame from reader.
122    pub async fn decode<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self, AgentProtocolError> {
123        // Read length (4 bytes)
124        let mut len_buf = [0u8; 4];
125        reader.read_exact(&mut len_buf).await.map_err(|e| {
126            if e.kind() == io::ErrorKind::UnexpectedEof {
127                AgentProtocolError::ConnectionFailed("Connection closed".to_string())
128            } else {
129                AgentProtocolError::Io(e)
130            }
131        })?;
132        let total_len = u32::from_be_bytes(len_buf) as usize;
133
134        // Validate length
135        if total_len == 0 {
136            return Err(AgentProtocolError::InvalidMessage(
137                "Empty message".to_string(),
138            ));
139        }
140        if total_len > MAX_BINARY_MESSAGE_SIZE {
141            return Err(AgentProtocolError::MessageTooLarge {
142                size: total_len,
143                max: MAX_BINARY_MESSAGE_SIZE,
144            });
145        }
146
147        // Read type byte
148        let mut type_buf = [0u8; 1];
149        reader.read_exact(&mut type_buf).await?;
150        let msg_type = MessageType::try_from(type_buf[0])?;
151
152        // Read payload
153        let payload_len = total_len - 1;
154        let mut payload = BytesMut::with_capacity(payload_len);
155        payload.resize(payload_len, 0);
156        reader.read_exact(&mut payload).await?;
157
158        Ok(Self {
159            msg_type,
160            payload: payload.freeze(),
161        })
162    }
163
164    /// Write frame to writer.
165    pub async fn write<W: AsyncWrite + Unpin>(
166        &self,
167        writer: &mut W,
168    ) -> Result<(), AgentProtocolError> {
169        let encoded = self.encode();
170        writer.write_all(&encoded).await?;
171        writer.flush().await?;
172        Ok(())
173    }
174}
175
176/// Binary request headers event.
177///
178/// Wire format:
179/// - correlation_id: length-prefixed string
180/// - method: length-prefixed string
181/// - uri: length-prefixed string
182/// - headers: count (u16) + [(name_len, name, value_len, value), ...]
183/// - client_ip: length-prefixed string
184/// - client_port: u16
185#[derive(Debug, Clone)]
186pub struct BinaryRequestHeaders {
187    pub correlation_id: String,
188    pub method: String,
189    pub uri: String,
190    pub headers: HashMap<String, Vec<String>>,
191    pub client_ip: String,
192    pub client_port: u16,
193}
194
195impl BinaryRequestHeaders {
196    /// Encode to bytes.
197    pub fn encode(&self) -> Bytes {
198        let mut buf = BytesMut::with_capacity(256);
199
200        // Correlation ID
201        put_string(&mut buf, &self.correlation_id);
202        // Method
203        put_string(&mut buf, &self.method);
204        // URI
205        put_string(&mut buf, &self.uri);
206
207        // Headers count
208        let header_count: usize = self.headers.values().map(|v| v.len()).sum();
209        buf.put_u16(header_count as u16);
210
211        // Headers (flattened: each value gets its own entry)
212        for (name, values) in &self.headers {
213            for value in values {
214                put_string(&mut buf, name);
215                put_string(&mut buf, value);
216            }
217        }
218
219        // Client IP
220        put_string(&mut buf, &self.client_ip);
221        // Client port
222        buf.put_u16(self.client_port);
223
224        buf.freeze()
225    }
226
227    /// Decode from bytes.
228    pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
229        let correlation_id = get_string(&mut data)?;
230        let method = get_string(&mut data)?;
231        let uri = get_string(&mut data)?;
232
233        // Headers
234        if data.remaining() < 2 {
235            return Err(AgentProtocolError::InvalidMessage(
236                "Missing header count".to_string(),
237            ));
238        }
239        let header_count = data.get_u16() as usize;
240
241        let mut headers: HashMap<String, Vec<String>> = HashMap::new();
242        for _ in 0..header_count {
243            let name = get_string(&mut data)?;
244            let value = get_string(&mut data)?;
245            headers.entry(name).or_default().push(value);
246        }
247
248        let client_ip = get_string(&mut data)?;
249
250        if data.remaining() < 2 {
251            return Err(AgentProtocolError::InvalidMessage(
252                "Missing client port".to_string(),
253            ));
254        }
255        let client_port = data.get_u16();
256
257        Ok(Self {
258            correlation_id,
259            method,
260            uri,
261            headers,
262            client_ip,
263            client_port,
264        })
265    }
266}
267
268/// Binary body chunk event (zero-copy).
269///
270/// Wire format:
271/// - correlation_id: length-prefixed string
272/// - chunk_index: u32
273/// - is_last: u8 (0 or 1)
274/// - data_len: u32
275/// - data: raw bytes (no base64!)
276#[derive(Debug, Clone)]
277pub struct BinaryBodyChunk {
278    pub correlation_id: String,
279    pub chunk_index: u32,
280    pub is_last: bool,
281    pub data: Bytes,
282}
283
284impl BinaryBodyChunk {
285    /// Encode to bytes.
286    pub fn encode(&self) -> Bytes {
287        let mut buf = BytesMut::with_capacity(32 + self.data.len());
288
289        put_string(&mut buf, &self.correlation_id);
290        buf.put_u32(self.chunk_index);
291        buf.put_u8(if self.is_last { 1 } else { 0 });
292        buf.put_u32(self.data.len() as u32);
293        buf.put_slice(&self.data);
294
295        buf.freeze()
296    }
297
298    /// Decode from bytes.
299    pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
300        let correlation_id = get_string(&mut data)?;
301
302        if data.remaining() < 9 {
303            return Err(AgentProtocolError::InvalidMessage(
304                "Missing body chunk fields".to_string(),
305            ));
306        }
307
308        let chunk_index = data.get_u32();
309        let is_last = data.get_u8() != 0;
310        let data_len = data.get_u32() as usize;
311
312        if data.remaining() < data_len {
313            return Err(AgentProtocolError::InvalidMessage(
314                "Body data truncated".to_string(),
315            ));
316        }
317
318        let body_data = data.copy_to_bytes(data_len);
319
320        Ok(Self {
321            correlation_id,
322            chunk_index,
323            is_last,
324            data: body_data,
325        })
326    }
327}
328
329/// Binary agent response.
330///
331/// Wire format:
332/// - correlation_id: length-prefixed string
333/// - decision_type: u8 (0=Allow, 1=Block, 2=Redirect, 3=Challenge)
334/// - decision_data: varies by type
335/// - request_headers_ops: count (u16) + ops
336/// - response_headers_ops: count (u16) + ops
337/// - needs_more: u8
338#[derive(Debug, Clone)]
339pub struct BinaryAgentResponse {
340    pub correlation_id: String,
341    pub decision: Decision,
342    pub request_headers: Vec<HeaderOp>,
343    pub response_headers: Vec<HeaderOp>,
344    pub needs_more: bool,
345}
346
347impl BinaryAgentResponse {
348    /// Encode to bytes.
349    pub fn encode(&self) -> Bytes {
350        let mut buf = BytesMut::with_capacity(128);
351
352        put_string(&mut buf, &self.correlation_id);
353
354        // Decision
355        match &self.decision {
356            Decision::Allow => {
357                buf.put_u8(0);
358            }
359            Decision::Block {
360                status,
361                body,
362                headers,
363            } => {
364                buf.put_u8(1);
365                buf.put_u16(*status);
366                put_optional_string(&mut buf, body.as_deref());
367                // Block headers
368                let h_count = headers.as_ref().map(|h| h.len()).unwrap_or(0);
369                buf.put_u16(h_count as u16);
370                if let Some(headers) = headers {
371                    for (k, v) in headers {
372                        put_string(&mut buf, k);
373                        put_string(&mut buf, v);
374                    }
375                }
376            }
377            Decision::Redirect { url, status } => {
378                buf.put_u8(2);
379                put_string(&mut buf, url);
380                buf.put_u16(*status);
381            }
382            Decision::Challenge {
383                challenge_type,
384                params,
385            } => {
386                buf.put_u8(3);
387                put_string(&mut buf, challenge_type);
388                buf.put_u16(params.len() as u16);
389                for (k, v) in params {
390                    put_string(&mut buf, k);
391                    put_string(&mut buf, v);
392                }
393            }
394        }
395
396        // Request header ops
397        buf.put_u16(self.request_headers.len() as u16);
398        for op in &self.request_headers {
399            encode_header_op(&mut buf, op);
400        }
401
402        // Response header ops
403        buf.put_u16(self.response_headers.len() as u16);
404        for op in &self.response_headers {
405            encode_header_op(&mut buf, op);
406        }
407
408        // Needs more
409        buf.put_u8(if self.needs_more { 1 } else { 0 });
410
411        buf.freeze()
412    }
413
414    /// Decode from bytes.
415    pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
416        let correlation_id = get_string(&mut data)?;
417
418        if data.remaining() < 1 {
419            return Err(AgentProtocolError::InvalidMessage(
420                "Missing decision type".to_string(),
421            ));
422        }
423
424        let decision_type = data.get_u8();
425        let decision = match decision_type {
426            0 => Decision::Allow,
427            1 => {
428                if data.remaining() < 2 {
429                    return Err(AgentProtocolError::InvalidMessage(
430                        "Missing block status".to_string(),
431                    ));
432                }
433                let status = data.get_u16();
434                let body = get_optional_string(&mut data)?;
435                if data.remaining() < 2 {
436                    return Err(AgentProtocolError::InvalidMessage(
437                        "Missing block headers count".to_string(),
438                    ));
439                }
440                let h_count = data.get_u16() as usize;
441                let headers = if h_count > 0 {
442                    let mut h = HashMap::new();
443                    for _ in 0..h_count {
444                        let k = get_string(&mut data)?;
445                        let v = get_string(&mut data)?;
446                        h.insert(k, v);
447                    }
448                    Some(h)
449                } else {
450                    None
451                };
452                Decision::Block {
453                    status,
454                    body,
455                    headers,
456                }
457            }
458            2 => {
459                let url = get_string(&mut data)?;
460                if data.remaining() < 2 {
461                    return Err(AgentProtocolError::InvalidMessage(
462                        "Missing redirect status".to_string(),
463                    ));
464                }
465                let status = data.get_u16();
466                Decision::Redirect { url, status }
467            }
468            3 => {
469                let challenge_type = get_string(&mut data)?;
470                if data.remaining() < 2 {
471                    return Err(AgentProtocolError::InvalidMessage(
472                        "Missing challenge params count".to_string(),
473                    ));
474                }
475                let p_count = data.get_u16() as usize;
476                let mut params = HashMap::new();
477                for _ in 0..p_count {
478                    let k = get_string(&mut data)?;
479                    let v = get_string(&mut data)?;
480                    params.insert(k, v);
481                }
482                Decision::Challenge {
483                    challenge_type,
484                    params,
485                }
486            }
487            _ => {
488                return Err(AgentProtocolError::InvalidMessage(format!(
489                    "Unknown decision type: {}",
490                    decision_type
491                )));
492            }
493        };
494
495        // Request header ops
496        if data.remaining() < 2 {
497            return Err(AgentProtocolError::InvalidMessage(
498                "Missing request headers count".to_string(),
499            ));
500        }
501        let req_h_count = data.get_u16() as usize;
502        let mut request_headers = Vec::with_capacity(req_h_count);
503        for _ in 0..req_h_count {
504            request_headers.push(decode_header_op(&mut data)?);
505        }
506
507        // Response header ops
508        if data.remaining() < 2 {
509            return Err(AgentProtocolError::InvalidMessage(
510                "Missing response headers count".to_string(),
511            ));
512        }
513        let resp_h_count = data.get_u16() as usize;
514        let mut response_headers = Vec::with_capacity(resp_h_count);
515        for _ in 0..resp_h_count {
516            response_headers.push(decode_header_op(&mut data)?);
517        }
518
519        // Needs more
520        if data.remaining() < 1 {
521            return Err(AgentProtocolError::InvalidMessage(
522                "Missing needs_more".to_string(),
523            ));
524        }
525        let needs_more = data.get_u8() != 0;
526
527        Ok(Self {
528            correlation_id,
529            decision,
530            request_headers,
531            response_headers,
532            needs_more,
533        })
534    }
535}
536
// =============================================================================
// Helper Functions
// =============================================================================
540
541fn put_string(buf: &mut BytesMut, s: &str) {
542    let bytes = s.as_bytes();
543    buf.put_u16(bytes.len() as u16);
544    buf.put_slice(bytes);
545}
546
547fn get_string(data: &mut Bytes) -> Result<String, AgentProtocolError> {
548    if data.remaining() < 2 {
549        return Err(AgentProtocolError::InvalidMessage(
550            "Missing string length".to_string(),
551        ));
552    }
553    let len = data.get_u16() as usize;
554    if data.remaining() < len {
555        return Err(AgentProtocolError::InvalidMessage(
556            "String data truncated".to_string(),
557        ));
558    }
559    let bytes = data.copy_to_bytes(len);
560    String::from_utf8(bytes.to_vec())
561        .map_err(|e| AgentProtocolError::InvalidMessage(format!("Invalid UTF-8: {}", e)))
562}
563
564fn put_optional_string(buf: &mut BytesMut, s: Option<&str>) {
565    match s {
566        Some(s) => {
567            buf.put_u8(1);
568            put_string(buf, s);
569        }
570        None => {
571            buf.put_u8(0);
572        }
573    }
574}
575
576fn get_optional_string(data: &mut Bytes) -> Result<Option<String>, AgentProtocolError> {
577    if data.remaining() < 1 {
578        return Err(AgentProtocolError::InvalidMessage(
579            "Missing optional string flag".to_string(),
580        ));
581    }
582    let present = data.get_u8() != 0;
583    if present {
584        get_string(data).map(Some)
585    } else {
586        Ok(None)
587    }
588}
589
590fn encode_header_op(buf: &mut BytesMut, op: &HeaderOp) {
591    match op {
592        HeaderOp::Set { name, value } => {
593            buf.put_u8(0);
594            put_string(buf, name);
595            put_string(buf, value);
596        }
597        HeaderOp::Add { name, value } => {
598            buf.put_u8(1);
599            put_string(buf, name);
600            put_string(buf, value);
601        }
602        HeaderOp::Remove { name } => {
603            buf.put_u8(2);
604            put_string(buf, name);
605        }
606    }
607}
608
609fn decode_header_op(data: &mut Bytes) -> Result<HeaderOp, AgentProtocolError> {
610    if data.remaining() < 1 {
611        return Err(AgentProtocolError::InvalidMessage(
612            "Missing header op type".to_string(),
613        ));
614    }
615    let op_type = data.get_u8();
616    match op_type {
617        0 => {
618            let name = get_string(data)?;
619            let value = get_string(data)?;
620            Ok(HeaderOp::Set { name, value })
621        }
622        1 => {
623            let name = get_string(data)?;
624            let value = get_string(data)?;
625            Ok(HeaderOp::Add { name, value })
626        }
627        2 => {
628            let name = get_string(data)?;
629            Ok(HeaderOp::Remove { name })
630        }
631        _ => Err(AgentProtocolError::InvalidMessage(format!(
632            "Unknown header op type: {}",
633            op_type
634        ))),
635    }
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    // Every variant must survive a u8 round-trip, not just a subset.
    #[test]
    fn test_message_type_roundtrip() {
        for t in [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::RequestBodyChunk,
            MessageType::ResponseHeaders,
            MessageType::ResponseBodyChunk,
            MessageType::RequestComplete,
            MessageType::WebSocketFrame,
            MessageType::AgentResponse,
            MessageType::Ping,
            MessageType::Pong,
            MessageType::Cancel,
            MessageType::Error,
        ] {
            let byte = t as u8;
            let decoded = MessageType::try_from(byte).unwrap();
            assert_eq!(t, decoded);
        }
    }

    #[test]
    fn test_message_type_rejects_unknown_bytes() {
        for byte in [0x00u8, 0x03, 0x16, 0x21, 0x32, 0x41, 0xFE] {
            assert!(
                MessageType::try_from(byte).is_err(),
                "0x{:02x} should be rejected",
                byte
            );
        }
    }

    #[test]
    fn test_binary_frame_encode_decode() {
        let frame = BinaryFrame::new(MessageType::Ping, Bytes::from_static(b"hello"));
        let encoded = frame.encode();

        // Verify structure
        assert_eq!(encoded.len(), 4 + 1 + 5); // len + type + payload
        assert_eq!(&encoded[0..4], &[0, 0, 0, 6]); // length = 6 (type + payload)
        assert_eq!(encoded[4], MessageType::Ping as u8);
        assert_eq!(&encoded[5..], b"hello");
    }

    #[test]
    fn test_binary_request_headers_roundtrip() {
        let headers = BinaryRequestHeaders {
            correlation_id: "req-123".to_string(),
            method: "POST".to_string(),
            uri: "/api/test".to_string(),
            headers: {
                let mut h = HashMap::new();
                h.insert(
                    "content-type".to_string(),
                    vec!["application/json".to_string()],
                );
                h.insert(
                    "x-custom".to_string(),
                    vec!["value1".to_string(), "value2".to_string()],
                );
                h
            },
            client_ip: "192.168.1.1".to_string(),
            client_port: 12345,
        };

        let encoded = headers.encode();
        let decoded = BinaryRequestHeaders::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-123");
        assert_eq!(decoded.method, "POST");
        assert_eq!(decoded.uri, "/api/test");
        assert_eq!(decoded.client_ip, "192.168.1.1");
        assert_eq!(decoded.client_port, 12345);
        assert_eq!(
            decoded.headers.get("content-type").unwrap(),
            &vec!["application/json".to_string()]
        );
        // Multi-valued headers are flattened on the wire; values must come back
        // grouped and in their original order.
        assert_eq!(
            decoded.headers.get("x-custom").unwrap(),
            &vec!["value1".to_string(), "value2".to_string()]
        );
    }

    #[test]
    fn test_binary_body_chunk_roundtrip() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-456".to_string(),
            chunk_index: 2,
            is_last: true,
            data: Bytes::from_static(b"binary data here"),
        };

        let encoded = chunk.encode();
        let decoded = BinaryBodyChunk::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-456");
        assert_eq!(decoded.chunk_index, 2);
        assert!(decoded.is_last);
        assert_eq!(&decoded.data[..], b"binary data here");
    }

    #[test]
    fn test_binary_body_chunk_truncated_data_is_rejected() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-trunc".to_string(),
            chunk_index: 0,
            is_last: false,
            data: Bytes::from_static(b"0123456789"),
        };

        let encoded = chunk.encode();
        let truncated = encoded.slice(..encoded.len() - 1);
        assert!(BinaryBodyChunk::decode(truncated).is_err());
    }

    #[test]
    fn test_binary_agent_response_allow() {
        let response = BinaryAgentResponse {
            correlation_id: "req-789".to_string(),
            decision: Decision::Allow,
            request_headers: vec![HeaderOp::Set {
                name: "X-Added".to_string(),
                value: "true".to_string(),
            }],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-789");
        assert!(matches!(decoded.decision, Decision::Allow));
        assert_eq!(decoded.request_headers.len(), 1);
        assert!(!decoded.needs_more);
    }

    #[test]
    fn test_binary_agent_response_block() {
        let response = BinaryAgentResponse {
            correlation_id: "req-block".to_string(),
            decision: Decision::Block {
                status: 403,
                body: Some("Forbidden".to_string()),
                headers: None,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-block");
        match decoded.decision {
            Decision::Block {
                status,
                body,
                headers,
            } => {
                assert_eq!(status, 403);
                assert_eq!(body, Some("Forbidden".to_string()));
                assert!(headers.is_none());
            }
            _ => panic!("Expected Block decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_block_with_headers() {
        let mut block_headers = HashMap::new();
        block_headers.insert("retry-after".to_string(), "60".to_string());

        let response = BinaryAgentResponse {
            correlation_id: "req-block-h".to_string(),
            decision: Decision::Block {
                status: 429,
                body: None,
                headers: Some(block_headers.clone()),
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();
        match decoded.decision {
            Decision::Block {
                status,
                body,
                headers,
            } => {
                assert_eq!(status, 429);
                assert!(body.is_none());
                assert_eq!(headers, Some(block_headers));
            }
            other => panic!("Expected Block decision, got {:?}", other),
        }
    }

    #[test]
    fn test_binary_agent_response_redirect() {
        let response = BinaryAgentResponse {
            correlation_id: "req-redirect".to_string(),
            decision: Decision::Redirect {
                url: "https://example.com/login".to_string(),
                status: 302,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();
        match decoded.decision {
            Decision::Redirect { url, status } => {
                assert_eq!(url, "https://example.com/login");
                assert_eq!(status, 302);
            }
            other => panic!("Expected Redirect decision, got {:?}", other),
        }
    }

    #[test]
    fn test_binary_agent_response_challenge_and_header_ops() {
        let mut params = HashMap::new();
        params.insert("site_key".to_string(), "abc".to_string());

        let response = BinaryAgentResponse {
            correlation_id: "req-challenge".to_string(),
            decision: Decision::Challenge {
                challenge_type: "captcha".to_string(),
                params: params.clone(),
            },
            request_headers: vec![
                HeaderOp::Set {
                    name: "x-a".to_string(),
                    value: "1".to_string(),
                },
                HeaderOp::Add {
                    name: "x-b".to_string(),
                    value: "2".to_string(),
                },
                HeaderOp::Remove {
                    name: "x-c".to_string(),
                },
            ],
            response_headers: vec![HeaderOp::Remove {
                name: "server".to_string(),
            }],
            needs_more: true,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        assert_eq!(decoded.correlation_id, "req-challenge");
        assert!(decoded.needs_more);
        match decoded.decision {
            Decision::Challenge {
                challenge_type,
                params: decoded_params,
            } => {
                assert_eq!(challenge_type, "captcha");
                assert_eq!(decoded_params, params);
            }
            ref other => panic!("Expected Challenge decision, got {:?}", other),
        }

        // Op order and kind must be preserved.
        match decoded.request_headers.as_slice() {
            [HeaderOp::Set { name: n1, value: v1 }, HeaderOp::Add { name: n2, value: v2 }, HeaderOp::Remove { name: n3 }] =>
            {
                assert_eq!((n1.as_str(), v1.as_str()), ("x-a", "1"));
                assert_eq!((n2.as_str(), v2.as_str()), ("x-b", "2"));
                assert_eq!(n3, "x-c");
            }
            other => panic!("unexpected request header ops: {:?}", other),
        }
        assert!(matches!(
            decoded.response_headers.as_slice(),
            [HeaderOp::Remove { name }] if name == "server"
        ));
    }

    #[test]
    fn test_binary_agent_response_truncation_is_rejected() {
        let response = BinaryAgentResponse {
            correlation_id: "req-cut".to_string(),
            decision: Decision::Allow,
            request_headers: vec![HeaderOp::Remove {
                name: "x-gone".to_string(),
            }],
            response_headers: vec![],
            needs_more: false,
        };

        // Every field is mandatory, so every strict prefix must fail to decode.
        let encoded = response.encode();
        for cut in 0..encoded.len() {
            assert!(
                BinaryAgentResponse::decode(encoded.slice(..cut)).is_err(),
                "prefix of length {} decoded unexpectedly",
                cut
            );
        }
    }
}