1use bytes::{Buf, BufMut, Bytes, BytesMut};
25use std::collections::HashMap;
26use std::io;
27use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
28
29use crate::{AgentProtocolError, Decision, HeaderOp};
30
/// Upper bound (10 MiB) on a frame's declared length — the type byte plus
/// payload, excluding the 4-byte length prefix. `BinaryFrame::decode`
/// rejects anything larger with `MessageTooLarge`.
pub const MAX_BINARY_MESSAGE_SIZE: usize = 10 << 20;
33
/// Type tag written immediately after the length prefix of every frame.
///
/// The explicit discriminants are the on-the-wire byte values; the
/// `TryFrom<u8>` implementation maps bytes back to variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    /// Handshake request (`0x01`).
    HandshakeRequest = 0x01,
    /// Handshake response (`0x02`).
    HandshakeResponse = 0x02,
    /// Request headers (`0x10`).
    RequestHeaders = 0x10,
    /// Request body chunk (`0x11`).
    RequestBodyChunk = 0x11,
    /// Response headers (`0x12`).
    ResponseHeaders = 0x12,
    /// Response body chunk (`0x13`).
    ResponseBodyChunk = 0x13,
    /// Request completion notice (`0x14`).
    RequestComplete = 0x14,
    /// WebSocket frame (`0x15`).
    WebSocketFrame = 0x15,
    /// Agent response / verdict (`0x20`).
    AgentResponse = 0x20,
    /// Ping (`0x30`).
    Ping = 0x30,
    /// Pong (`0x31`).
    Pong = 0x31,
    /// Cancel (`0x40`).
    Cancel = 0x40,
    /// Error (`0xFF`).
    Error = 0xFF,
}
65
66impl TryFrom<u8> for MessageType {
67 type Error = AgentProtocolError;
68
69 fn try_from(value: u8) -> Result<Self, AgentProtocolError> {
70 match value {
71 0x01 => Ok(MessageType::HandshakeRequest),
72 0x02 => Ok(MessageType::HandshakeResponse),
73 0x10 => Ok(MessageType::RequestHeaders),
74 0x11 => Ok(MessageType::RequestBodyChunk),
75 0x12 => Ok(MessageType::ResponseHeaders),
76 0x13 => Ok(MessageType::ResponseBodyChunk),
77 0x14 => Ok(MessageType::RequestComplete),
78 0x15 => Ok(MessageType::WebSocketFrame),
79 0x20 => Ok(MessageType::AgentResponse),
80 0x30 => Ok(MessageType::Ping),
81 0x31 => Ok(MessageType::Pong),
82 0x40 => Ok(MessageType::Cancel),
83 0xFF => Ok(MessageType::Error),
84 _ => Err(AgentProtocolError::InvalidMessage(format!(
85 "Unknown message type: 0x{:02x}",
86 value
87 ))),
88 }
89 }
90}
91
92#[derive(Debug, Clone)]
94pub struct BinaryFrame {
95 pub msg_type: MessageType,
96 pub payload: Bytes,
97}
98
99impl BinaryFrame {
100 pub fn new(msg_type: MessageType, payload: impl Into<Bytes>) -> Self {
102 Self {
103 msg_type,
104 payload: payload.into(),
105 }
106 }
107
108 pub fn encode(&self) -> Bytes {
110 let payload_len = self.payload.len();
111 let total_len = 1 + payload_len; let mut buf = BytesMut::with_capacity(4 + total_len);
114 buf.put_u32(total_len as u32);
115 buf.put_u8(self.msg_type as u8);
116 buf.put_slice(&self.payload);
117
118 buf.freeze()
119 }
120
121 pub async fn decode<R: AsyncRead + Unpin>(reader: &mut R) -> Result<Self, AgentProtocolError> {
123 let mut len_buf = [0u8; 4];
125 reader.read_exact(&mut len_buf).await.map_err(|e| {
126 if e.kind() == io::ErrorKind::UnexpectedEof {
127 AgentProtocolError::ConnectionFailed("Connection closed".to_string())
128 } else {
129 AgentProtocolError::Io(e)
130 }
131 })?;
132 let total_len = u32::from_be_bytes(len_buf) as usize;
133
134 if total_len == 0 {
136 return Err(AgentProtocolError::InvalidMessage(
137 "Empty message".to_string(),
138 ));
139 }
140 if total_len > MAX_BINARY_MESSAGE_SIZE {
141 return Err(AgentProtocolError::MessageTooLarge {
142 size: total_len,
143 max: MAX_BINARY_MESSAGE_SIZE,
144 });
145 }
146
147 let mut type_buf = [0u8; 1];
149 reader.read_exact(&mut type_buf).await?;
150 let msg_type = MessageType::try_from(type_buf[0])?;
151
152 let payload_len = total_len - 1;
154 let mut payload = BytesMut::with_capacity(payload_len);
155 payload.resize(payload_len, 0);
156 reader.read_exact(&mut payload).await?;
157
158 Ok(Self {
159 msg_type,
160 payload: payload.freeze(),
161 })
162 }
163
164 pub async fn write<W: AsyncWrite + Unpin>(
166 &self,
167 writer: &mut W,
168 ) -> Result<(), AgentProtocolError> {
169 let encoded = self.encode();
170 writer.write_all(&encoded).await?;
171 writer.flush().await?;
172 Ok(())
173 }
174}
175
/// Request metadata as sent in a `RequestHeaders` frame.
#[derive(Debug, Clone)]
pub struct BinaryRequestHeaders {
    /// Identifier tying this message to related frames.
    pub correlation_id: String,
    /// HTTP method, e.g. `"POST"`.
    pub method: String,
    /// Request target, e.g. `"/api/test"`.
    pub uri: String,
    /// Header name mapped to every value received for it, in order.
    pub headers: HashMap<String, Vec<String>>,
    /// Client address, carried as text.
    pub client_ip: String,
    /// Client source port.
    pub client_port: u16,
}
194
195impl BinaryRequestHeaders {
196 pub fn encode(&self) -> Bytes {
198 let mut buf = BytesMut::with_capacity(256);
199
200 put_string(&mut buf, &self.correlation_id);
202 put_string(&mut buf, &self.method);
204 put_string(&mut buf, &self.uri);
206
207 let header_count: usize = self.headers.values().map(|v| v.len()).sum();
209 buf.put_u16(header_count as u16);
210
211 for (name, values) in &self.headers {
213 for value in values {
214 put_string(&mut buf, name);
215 put_string(&mut buf, value);
216 }
217 }
218
219 put_string(&mut buf, &self.client_ip);
221 buf.put_u16(self.client_port);
223
224 buf.freeze()
225 }
226
227 pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
229 let correlation_id = get_string(&mut data)?;
230 let method = get_string(&mut data)?;
231 let uri = get_string(&mut data)?;
232
233 if data.remaining() < 2 {
235 return Err(AgentProtocolError::InvalidMessage(
236 "Missing header count".to_string(),
237 ));
238 }
239 let header_count = data.get_u16() as usize;
240
241 let mut headers: HashMap<String, Vec<String>> = HashMap::new();
242 for _ in 0..header_count {
243 let name = get_string(&mut data)?;
244 let value = get_string(&mut data)?;
245 headers.entry(name).or_default().push(value);
246 }
247
248 let client_ip = get_string(&mut data)?;
249
250 if data.remaining() < 2 {
251 return Err(AgentProtocolError::InvalidMessage(
252 "Missing client port".to_string(),
253 ));
254 }
255 let client_port = data.get_u16();
256
257 Ok(Self {
258 correlation_id,
259 method,
260 uri,
261 headers,
262 client_ip,
263 client_port,
264 })
265 }
266}
267
268#[derive(Debug, Clone)]
277pub struct BinaryBodyChunk {
278 pub correlation_id: String,
279 pub chunk_index: u32,
280 pub is_last: bool,
281 pub data: Bytes,
282}
283
284impl BinaryBodyChunk {
285 pub fn encode(&self) -> Bytes {
287 let mut buf = BytesMut::with_capacity(32 + self.data.len());
288
289 put_string(&mut buf, &self.correlation_id);
290 buf.put_u32(self.chunk_index);
291 buf.put_u8(if self.is_last { 1 } else { 0 });
292 buf.put_u32(self.data.len() as u32);
293 buf.put_slice(&self.data);
294
295 buf.freeze()
296 }
297
298 pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
300 let correlation_id = get_string(&mut data)?;
301
302 if data.remaining() < 9 {
303 return Err(AgentProtocolError::InvalidMessage(
304 "Missing body chunk fields".to_string(),
305 ));
306 }
307
308 let chunk_index = data.get_u32();
309 let is_last = data.get_u8() != 0;
310 let data_len = data.get_u32() as usize;
311
312 if data.remaining() < data_len {
313 return Err(AgentProtocolError::InvalidMessage(
314 "Body data truncated".to_string(),
315 ));
316 }
317
318 let body_data = data.copy_to_bytes(data_len);
319
320 Ok(Self {
321 correlation_id,
322 chunk_index,
323 is_last,
324 data: body_data,
325 })
326 }
327}
328
329#[derive(Debug, Clone)]
339pub struct BinaryAgentResponse {
340 pub correlation_id: String,
341 pub decision: Decision,
342 pub request_headers: Vec<HeaderOp>,
343 pub response_headers: Vec<HeaderOp>,
344 pub needs_more: bool,
345}
346
347impl BinaryAgentResponse {
348 pub fn encode(&self) -> Bytes {
350 let mut buf = BytesMut::with_capacity(128);
351
352 put_string(&mut buf, &self.correlation_id);
353
354 match &self.decision {
356 Decision::Allow => {
357 buf.put_u8(0);
358 }
359 Decision::Block {
360 status,
361 body,
362 headers,
363 } => {
364 buf.put_u8(1);
365 buf.put_u16(*status);
366 put_optional_string(&mut buf, body.as_deref());
367 let h_count = headers.as_ref().map(|h| h.len()).unwrap_or(0);
369 buf.put_u16(h_count as u16);
370 if let Some(headers) = headers {
371 for (k, v) in headers {
372 put_string(&mut buf, k);
373 put_string(&mut buf, v);
374 }
375 }
376 }
377 Decision::Redirect { url, status } => {
378 buf.put_u8(2);
379 put_string(&mut buf, url);
380 buf.put_u16(*status);
381 }
382 Decision::Challenge {
383 challenge_type,
384 params,
385 } => {
386 buf.put_u8(3);
387 put_string(&mut buf, challenge_type);
388 buf.put_u16(params.len() as u16);
389 for (k, v) in params {
390 put_string(&mut buf, k);
391 put_string(&mut buf, v);
392 }
393 }
394 }
395
396 buf.put_u16(self.request_headers.len() as u16);
398 for op in &self.request_headers {
399 encode_header_op(&mut buf, op);
400 }
401
402 buf.put_u16(self.response_headers.len() as u16);
404 for op in &self.response_headers {
405 encode_header_op(&mut buf, op);
406 }
407
408 buf.put_u8(if self.needs_more { 1 } else { 0 });
410
411 buf.freeze()
412 }
413
414 pub fn decode(mut data: Bytes) -> Result<Self, AgentProtocolError> {
416 let correlation_id = get_string(&mut data)?;
417
418 if data.remaining() < 1 {
419 return Err(AgentProtocolError::InvalidMessage(
420 "Missing decision type".to_string(),
421 ));
422 }
423
424 let decision_type = data.get_u8();
425 let decision = match decision_type {
426 0 => Decision::Allow,
427 1 => {
428 if data.remaining() < 2 {
429 return Err(AgentProtocolError::InvalidMessage(
430 "Missing block status".to_string(),
431 ));
432 }
433 let status = data.get_u16();
434 let body = get_optional_string(&mut data)?;
435 if data.remaining() < 2 {
436 return Err(AgentProtocolError::InvalidMessage(
437 "Missing block headers count".to_string(),
438 ));
439 }
440 let h_count = data.get_u16() as usize;
441 let headers = if h_count > 0 {
442 let mut h = HashMap::new();
443 for _ in 0..h_count {
444 let k = get_string(&mut data)?;
445 let v = get_string(&mut data)?;
446 h.insert(k, v);
447 }
448 Some(h)
449 } else {
450 None
451 };
452 Decision::Block {
453 status,
454 body,
455 headers,
456 }
457 }
458 2 => {
459 let url = get_string(&mut data)?;
460 if data.remaining() < 2 {
461 return Err(AgentProtocolError::InvalidMessage(
462 "Missing redirect status".to_string(),
463 ));
464 }
465 let status = data.get_u16();
466 Decision::Redirect { url, status }
467 }
468 3 => {
469 let challenge_type = get_string(&mut data)?;
470 if data.remaining() < 2 {
471 return Err(AgentProtocolError::InvalidMessage(
472 "Missing challenge params count".to_string(),
473 ));
474 }
475 let p_count = data.get_u16() as usize;
476 let mut params = HashMap::new();
477 for _ in 0..p_count {
478 let k = get_string(&mut data)?;
479 let v = get_string(&mut data)?;
480 params.insert(k, v);
481 }
482 Decision::Challenge {
483 challenge_type,
484 params,
485 }
486 }
487 _ => {
488 return Err(AgentProtocolError::InvalidMessage(format!(
489 "Unknown decision type: {}",
490 decision_type
491 )));
492 }
493 };
494
495 if data.remaining() < 2 {
497 return Err(AgentProtocolError::InvalidMessage(
498 "Missing request headers count".to_string(),
499 ));
500 }
501 let req_h_count = data.get_u16() as usize;
502 let mut request_headers = Vec::with_capacity(req_h_count);
503 for _ in 0..req_h_count {
504 request_headers.push(decode_header_op(&mut data)?);
505 }
506
507 if data.remaining() < 2 {
509 return Err(AgentProtocolError::InvalidMessage(
510 "Missing response headers count".to_string(),
511 ));
512 }
513 let resp_h_count = data.get_u16() as usize;
514 let mut response_headers = Vec::with_capacity(resp_h_count);
515 for _ in 0..resp_h_count {
516 response_headers.push(decode_header_op(&mut data)?);
517 }
518
519 if data.remaining() < 1 {
521 return Err(AgentProtocolError::InvalidMessage(
522 "Missing needs_more".to_string(),
523 ));
524 }
525 let needs_more = data.get_u8() != 0;
526
527 Ok(Self {
528 correlation_id,
529 decision,
530 request_headers,
531 response_headers,
532 needs_more,
533 })
534 }
535}
536
537fn put_string(buf: &mut BytesMut, s: &str) {
542 let bytes = s.as_bytes();
543 buf.put_u16(bytes.len() as u16);
544 buf.put_slice(bytes);
545}
546
547fn get_string(data: &mut Bytes) -> Result<String, AgentProtocolError> {
548 if data.remaining() < 2 {
549 return Err(AgentProtocolError::InvalidMessage(
550 "Missing string length".to_string(),
551 ));
552 }
553 let len = data.get_u16() as usize;
554 if data.remaining() < len {
555 return Err(AgentProtocolError::InvalidMessage(
556 "String data truncated".to_string(),
557 ));
558 }
559 let bytes = data.copy_to_bytes(len);
560 String::from_utf8(bytes.to_vec())
561 .map_err(|e| AgentProtocolError::InvalidMessage(format!("Invalid UTF-8: {}", e)))
562}
563
564fn put_optional_string(buf: &mut BytesMut, s: Option<&str>) {
565 match s {
566 Some(s) => {
567 buf.put_u8(1);
568 put_string(buf, s);
569 }
570 None => {
571 buf.put_u8(0);
572 }
573 }
574}
575
576fn get_optional_string(data: &mut Bytes) -> Result<Option<String>, AgentProtocolError> {
577 if data.remaining() < 1 {
578 return Err(AgentProtocolError::InvalidMessage(
579 "Missing optional string flag".to_string(),
580 ));
581 }
582 let present = data.get_u8() != 0;
583 if present {
584 get_string(data).map(Some)
585 } else {
586 Ok(None)
587 }
588}
589
590fn encode_header_op(buf: &mut BytesMut, op: &HeaderOp) {
591 match op {
592 HeaderOp::Set { name, value } => {
593 buf.put_u8(0);
594 put_string(buf, name);
595 put_string(buf, value);
596 }
597 HeaderOp::Add { name, value } => {
598 buf.put_u8(1);
599 put_string(buf, name);
600 put_string(buf, value);
601 }
602 HeaderOp::Remove { name } => {
603 buf.put_u8(2);
604 put_string(buf, name);
605 }
606 }
607}
608
609fn decode_header_op(data: &mut Bytes) -> Result<HeaderOp, AgentProtocolError> {
610 if data.remaining() < 1 {
611 return Err(AgentProtocolError::InvalidMessage(
612 "Missing header op type".to_string(),
613 ));
614 }
615 let op_type = data.get_u8();
616 match op_type {
617 0 => {
618 let name = get_string(data)?;
619 let value = get_string(data)?;
620 Ok(HeaderOp::Set { name, value })
621 }
622 1 => {
623 let name = get_string(data)?;
624 let value = get_string(data)?;
625 Ok(HeaderOp::Add { name, value })
626 }
627 2 => {
628 let name = get_string(data)?;
629 Ok(HeaderOp::Remove { name })
630 }
631 _ => Err(AgentProtocolError::InvalidMessage(format!(
632 "Unknown header op type: {}",
633 op_type
634 ))),
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MessageType` variant must survive `as u8` -> `try_from`.
    #[test]
    fn test_message_type_roundtrip() {
        for t in [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::RequestBodyChunk,
            MessageType::ResponseHeaders,
            MessageType::ResponseBodyChunk,
            MessageType::RequestComplete,
            MessageType::WebSocketFrame,
            MessageType::AgentResponse,
            MessageType::Ping,
            MessageType::Pong,
            MessageType::Cancel,
            MessageType::Error,
        ] {
            let byte = t as u8;
            let decoded = MessageType::try_from(byte).unwrap();
            assert_eq!(t, decoded);
        }
    }

    #[test]
    fn test_message_type_rejects_unknown_bytes() {
        for byte in [0x00u8, 0x03, 0x16, 0x21, 0x41, 0xFE] {
            assert!(
                MessageType::try_from(byte).is_err(),
                "0x{:02x} should be rejected",
                byte
            );
        }
    }

    #[test]
    fn test_binary_frame_encode_decode() {
        let frame = BinaryFrame::new(MessageType::Ping, Bytes::from_static(b"hello"));
        let encoded = frame.encode();

        // 4-byte length prefix + 1 type byte + 5 payload bytes.
        assert_eq!(encoded.len(), 4 + 1 + 5);
        // The length prefix counts the type byte plus the payload.
        assert_eq!(&encoded[0..4], &[0, 0, 0, 6]);
        assert_eq!(encoded[4], MessageType::Ping as u8);
        assert_eq!(&encoded[5..], b"hello");
    }

    #[test]
    fn test_binary_frame_encode_empty_payload() {
        let encoded = BinaryFrame::new(MessageType::Pong, Bytes::new()).encode();
        assert_eq!(&encoded[..], &[0, 0, 0, 1, MessageType::Pong as u8]);
    }

    #[test]
    fn test_binary_request_headers_roundtrip() {
        let headers = BinaryRequestHeaders {
            correlation_id: "req-123".to_string(),
            method: "POST".to_string(),
            uri: "/api/test".to_string(),
            headers: {
                let mut h = HashMap::new();
                h.insert(
                    "content-type".to_string(),
                    vec!["application/json".to_string()],
                );
                h.insert(
                    "x-custom".to_string(),
                    vec!["value1".to_string(), "value2".to_string()],
                );
                h
            },
            client_ip: "192.168.1.1".to_string(),
            client_port: 12345,
        };

        let encoded = headers.encode();
        let decoded = BinaryRequestHeaders::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-123");
        assert_eq!(decoded.method, "POST");
        assert_eq!(decoded.uri, "/api/test");
        assert_eq!(decoded.client_ip, "192.168.1.1");
        assert_eq!(decoded.client_port, 12345);
        assert_eq!(decoded.headers.len(), 2);
        assert_eq!(
            decoded.headers.get("content-type").unwrap(),
            &vec!["application/json".to_string()]
        );
        // Multi-valued headers must keep every value, in order.
        assert_eq!(
            decoded.headers.get("x-custom").unwrap(),
            &vec!["value1".to_string(), "value2".to_string()]
        );
    }

    #[test]
    fn test_binary_body_chunk_roundtrip() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-456".to_string(),
            chunk_index: 2,
            is_last: true,
            data: Bytes::from_static(b"binary data here"),
        };

        let encoded = chunk.encode();
        let decoded = BinaryBodyChunk::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-456");
        assert_eq!(decoded.chunk_index, 2);
        assert!(decoded.is_last);
        assert_eq!(&decoded.data[..], b"binary data here");
    }

    #[test]
    fn test_binary_body_chunk_truncated_data_rejected() {
        let chunk = BinaryBodyChunk {
            correlation_id: "req-trunc".to_string(),
            chunk_index: 0,
            is_last: false,
            data: Bytes::from_static(b"payload"),
        };
        let encoded = chunk.encode();
        let truncated = encoded.slice(..encoded.len() - 1);
        assert!(BinaryBodyChunk::decode(truncated).is_err());
    }

    #[test]
    fn test_binary_agent_response_allow() {
        let response = BinaryAgentResponse {
            correlation_id: "req-789".to_string(),
            decision: Decision::Allow,
            request_headers: vec![HeaderOp::Set {
                name: "X-Added".to_string(),
                value: "true".to_string(),
            }],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-789");
        assert!(matches!(decoded.decision, Decision::Allow));
        assert_eq!(decoded.request_headers.len(), 1);
        assert!(!decoded.needs_more);
    }

    #[test]
    fn test_binary_agent_response_block() {
        let response = BinaryAgentResponse {
            correlation_id: "req-block".to_string(),
            decision: Decision::Block {
                status: 403,
                body: Some("Forbidden".to_string()),
                headers: None,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let encoded = response.encode();
        let decoded = BinaryAgentResponse::decode(encoded).unwrap();

        assert_eq!(decoded.correlation_id, "req-block");
        match decoded.decision {
            Decision::Block {
                status,
                body,
                headers,
            } => {
                assert_eq!(status, 403);
                assert_eq!(body, Some("Forbidden".to_string()));
                assert!(headers.is_none());
            }
            _ => panic!("Expected Block decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_block_with_headers() {
        let mut block_headers = HashMap::new();
        block_headers.insert("retry-after".to_string(), "30".to_string());
        let response = BinaryAgentResponse {
            correlation_id: "req-block-h".to_string(),
            decision: Decision::Block {
                status: 429,
                body: None,
                headers: Some(block_headers),
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        match decoded.decision {
            Decision::Block {
                status,
                body,
                headers,
            } => {
                assert_eq!(status, 429);
                assert!(body.is_none());
                let headers = headers.expect("block headers should be present");
                assert_eq!(headers.get("retry-after").map(String::as_str), Some("30"));
            }
            _ => panic!("Expected Block decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_redirect_with_header_ops() {
        let response = BinaryAgentResponse {
            correlation_id: "req-redirect".to_string(),
            decision: Decision::Redirect {
                url: "https://example.com/login".to_string(),
                status: 302,
            },
            request_headers: vec![],
            response_headers: vec![
                HeaderOp::Add {
                    name: "Set-Cookie".to_string(),
                    value: "a=b".to_string(),
                },
                HeaderOp::Remove {
                    name: "Server".to_string(),
                },
            ],
            needs_more: true,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        match decoded.decision {
            Decision::Redirect { url, status } => {
                assert_eq!(url, "https://example.com/login");
                assert_eq!(status, 302);
            }
            _ => panic!("Expected Redirect decision"),
        }
        assert!(decoded.request_headers.is_empty());
        match &decoded.response_headers[..] {
            [HeaderOp::Add { name, value }, HeaderOp::Remove { name: removed }] => {
                assert_eq!(name, "Set-Cookie");
                assert_eq!(value, "a=b");
                assert_eq!(removed, "Server");
            }
            other => panic!("unexpected response header ops: {:?}", other),
        }
        assert!(decoded.needs_more);
    }

    #[test]
    fn test_binary_agent_response_challenge() {
        let mut params = HashMap::new();
        params.insert("site_key".to_string(), "abc".to_string());
        let response = BinaryAgentResponse {
            correlation_id: "req-challenge".to_string(),
            decision: Decision::Challenge {
                challenge_type: "captcha".to_string(),
                params,
            },
            request_headers: vec![],
            response_headers: vec![],
            needs_more: false,
        };

        let decoded = BinaryAgentResponse::decode(response.encode()).unwrap();

        match decoded.decision {
            Decision::Challenge {
                challenge_type,
                params,
            } => {
                assert_eq!(challenge_type, "captcha");
                assert_eq!(params.get("site_key").map(String::as_str), Some("abc"));
            }
            _ => panic!("Expected Challenge decision"),
        }
    }

    #[test]
    fn test_binary_agent_response_unknown_decision_rejected() {
        // correlation_id "x" (length 1), then an undefined decision tag 9.
        let raw = Bytes::from_static(&[0, 1, b'x', 9]);
        assert!(BinaryAgentResponse::decode(raw).is_err());
    }
}