1use std::collections::HashMap;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::Duration;
42
43use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
44use tokio::net::UnixStream;
45use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
46use tracing::{debug, error, info, trace, warn};
47
48use crate::v2::pool::CHANNEL_BUFFER_SIZE;
49use crate::v2::{AgentCapabilities, AgentFeatures, AgentLimits, HealthConfig, PROTOCOL_VERSION_2};
50use crate::{AgentProtocolError, AgentResponse, EventType};
51
52use super::client::{ConfigUpdateCallback, FlowState, MetricsCallback};
53
/// Upper bound, in bytes, on a single UDS frame as declared by its 4-byte
/// length prefix (the message-type byte plus the payload). 16 MiB.
pub const MAX_UDS_MESSAGE_SIZE: usize = 16 * 1024 * 1024;
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
62#[serde(rename_all = "lowercase")]
63pub enum UdsEncoding {
64 #[default]
66 Json,
67 #[serde(rename = "msgpack")]
69 MessagePack,
70}
71
72impl UdsEncoding {
73 #[inline]
77 pub fn serialize<T: serde::Serialize>(&self, value: &T) -> Result<Vec<u8>, AgentProtocolError> {
78 match self {
79 UdsEncoding::Json => serde_json::to_vec(value)
80 .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
81 #[cfg(feature = "binary-uds")]
82 UdsEncoding::MessagePack => rmp_serde::to_vec(value)
83 .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
84 #[cfg(not(feature = "binary-uds"))]
85 UdsEncoding::MessagePack => {
86 serde_json::to_vec(value)
88 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
89 }
90 }
91 }
92
93 #[inline]
97 pub fn deserialize<'a, T: serde::Deserialize<'a>>(
98 &self,
99 bytes: &'a [u8],
100 ) -> Result<T, AgentProtocolError> {
101 match self {
102 UdsEncoding::Json => serde_json::from_slice(bytes)
103 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
104 #[cfg(feature = "binary-uds")]
105 UdsEncoding::MessagePack => rmp_serde::from_slice(bytes)
106 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
107 #[cfg(not(feature = "binary-uds"))]
108 UdsEncoding::MessagePack => {
109 serde_json::from_slice(bytes)
111 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
112 }
113 }
114 }
115}
116
/// One-byte frame type that follows the length prefix of every UDS frame.
///
/// Direction notes below describe how this client uses each type.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    /// Handshake sent by this client right after connecting.
    HandshakeRequest = 0x01,
    /// Agent's reply to the handshake.
    HandshakeResponse = 0x02,

    // Events this client sends; each expects an `AgentResponse`.
    /// Request headers event.
    RequestHeaders = 0x10,
    /// Request body chunk event.
    RequestBodyChunk = 0x11,
    /// Response headers event.
    ResponseHeaders = 0x12,
    /// Response body chunk event.
    ResponseBodyChunk = 0x13,
    /// Request completion event.
    RequestComplete = 0x14,
    /// WebSocket frame event.
    WebSocketFrame = 0x15,
    /// Guardrail inspection event.
    GuardrailInspect = 0x16,
    /// Configuration event.
    Configure = 0x17,

    /// Agent's answer to an event, matched by correlation id.
    AgentResponse = 0x20,

    // Messages the agent may push at any time.
    /// Agent health report.
    HealthStatus = 0x30,
    /// Agent metrics report.
    MetricsReport = 0x31,
    /// Agent-initiated configuration update request.
    ConfigUpdateRequest = 0x32,
    /// Agent flow-control signal (pause/resume).
    FlowControl = 0x33,

    /// Cancellation of an outstanding request, sent by this client.
    Cancel = 0x40,
    /// Keep-alive probe sent by this client.
    Ping = 0x41,
    /// Agent's reply to `Ping`.
    Pong = 0x42,
}
149
150impl TryFrom<u8> for MessageType {
151 type Error = AgentProtocolError;
152
153 fn try_from(value: u8) -> Result<Self, Self::Error> {
154 match value {
155 0x01 => Ok(MessageType::HandshakeRequest),
156 0x02 => Ok(MessageType::HandshakeResponse),
157 0x10 => Ok(MessageType::RequestHeaders),
158 0x11 => Ok(MessageType::RequestBodyChunk),
159 0x12 => Ok(MessageType::ResponseHeaders),
160 0x13 => Ok(MessageType::ResponseBodyChunk),
161 0x14 => Ok(MessageType::RequestComplete),
162 0x15 => Ok(MessageType::WebSocketFrame),
163 0x16 => Ok(MessageType::GuardrailInspect),
164 0x17 => Ok(MessageType::Configure),
165 0x20 => Ok(MessageType::AgentResponse),
166 0x30 => Ok(MessageType::HealthStatus),
167 0x31 => Ok(MessageType::MetricsReport),
168 0x32 => Ok(MessageType::ConfigUpdateRequest),
169 0x33 => Ok(MessageType::FlowControl),
170 0x40 => Ok(MessageType::Cancel),
171 0x41 => Ok(MessageType::Ping),
172 0x42 => Ok(MessageType::Pong),
173 _ => Err(AgentProtocolError::InvalidMessage(format!(
174 "Unknown message type: 0x{:02x}",
175 value
176 ))),
177 }
178 }
179}
180
181#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
183pub struct UdsHandshakeRequest {
184 pub supported_versions: Vec<u32>,
185 pub proxy_id: String,
186 pub proxy_version: String,
187 pub config: Option<serde_json::Value>,
188 #[serde(default, skip_serializing_if = "Vec::is_empty")]
191 pub supported_encodings: Vec<UdsEncoding>,
192}
193
194#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
196pub struct UdsHandshakeResponse {
197 pub protocol_version: u32,
198 pub capabilities: UdsCapabilities,
199 pub success: bool,
200 pub error: Option<String>,
201 #[serde(default)]
204 pub encoding: UdsEncoding,
205}
206
207#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct UdsCapabilities {
210 pub agent_id: String,
211 pub name: String,
212 pub version: String,
213 pub supported_events: Vec<i32>,
214 pub features: UdsFeatures,
215 pub limits: UdsLimits,
216}
217
218#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
220pub struct UdsFeatures {
221 pub streaming_body: bool,
222 pub websocket: bool,
223 pub guardrails: bool,
224 pub config_push: bool,
225 pub metrics_export: bool,
226 pub concurrent_requests: u32,
227 pub cancellation: bool,
228 pub flow_control: bool,
229 pub health_reporting: bool,
230}
231
232#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
234pub struct UdsLimits {
235 pub max_body_size: u64,
236 pub max_concurrency: u32,
237 pub preferred_chunk_size: u64,
238}
239
240impl From<UdsCapabilities> for AgentCapabilities {
241 fn from(caps: UdsCapabilities) -> Self {
242 AgentCapabilities {
243 protocol_version: PROTOCOL_VERSION_2,
244 agent_id: caps.agent_id,
245 name: caps.name,
246 version: caps.version,
247 supported_events: caps
248 .supported_events
249 .into_iter()
250 .filter_map(event_type_from_i32)
251 .collect(),
252 features: AgentFeatures {
253 streaming_body: caps.features.streaming_body,
254 websocket: caps.features.websocket,
255 guardrails: caps.features.guardrails,
256 config_push: caps.features.config_push,
257 metrics_export: caps.features.metrics_export,
258 concurrent_requests: caps.features.concurrent_requests,
259 cancellation: caps.features.cancellation,
260 flow_control: caps.features.flow_control,
261 health_reporting: caps.features.health_reporting,
262 },
263 limits: AgentLimits {
264 max_body_size: caps.limits.max_body_size as usize,
265 max_concurrency: caps.limits.max_concurrency,
266 preferred_chunk_size: caps.limits.preferred_chunk_size as usize,
267 max_memory: None,
268 max_processing_time_ms: None,
269 },
270 health: HealthConfig::default(),
271 }
272 }
273}
274
275impl From<AgentCapabilities> for UdsCapabilities {
276 fn from(caps: AgentCapabilities) -> Self {
277 use crate::v2::server::event_type_to_i32;
278 UdsCapabilities {
279 agent_id: caps.agent_id,
280 name: caps.name,
281 version: caps.version,
282 supported_events: caps
283 .supported_events
284 .iter()
285 .map(|e| event_type_to_i32(*e))
286 .collect(),
287 features: UdsFeatures {
288 streaming_body: caps.features.streaming_body,
289 websocket: caps.features.websocket,
290 guardrails: caps.features.guardrails,
291 config_push: caps.features.config_push,
292 metrics_export: caps.features.metrics_export,
293 concurrent_requests: caps.features.concurrent_requests,
294 cancellation: caps.features.cancellation,
295 flow_control: caps.features.flow_control,
296 health_reporting: caps.features.health_reporting,
297 },
298 limits: UdsLimits {
299 max_body_size: caps.limits.max_body_size as u64,
300 max_concurrency: caps.limits.max_concurrency,
301 preferred_chunk_size: caps.limits.preferred_chunk_size as u64,
302 },
303 }
304 }
305}
306
307fn event_type_from_i32(value: i32) -> Option<EventType> {
309 match value {
310 0 => Some(EventType::Configure),
311 1 => Some(EventType::RequestHeaders),
312 2 => Some(EventType::RequestBodyChunk),
313 3 => Some(EventType::ResponseHeaders),
314 4 => Some(EventType::ResponseBodyChunk),
315 5 => Some(EventType::RequestComplete),
316 6 => Some(EventType::WebSocketFrame),
317 7 => Some(EventType::GuardrailInspect),
318 _ => None,
319 }
320}
321
322pub struct AgentClientV2Uds {
327 agent_id: String,
329 socket_path: String,
331 timeout: Duration,
333 capabilities: RwLock<Option<AgentCapabilities>>,
335 protocol_version: AtomicU64,
337 encoding: RwLock<UdsEncoding>,
339 pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
341 #[allow(clippy::type_complexity)]
343 outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
344 ping_sequence: AtomicU64,
346 connected: RwLock<bool>,
348 flow_state: RwLock<FlowState>,
350 health_state: RwLock<i32>,
352 in_flight: AtomicU64,
354 metrics_callback: Option<MetricsCallback>,
356 config_update_callback: Option<ConfigUpdateCallback>,
358}
359
360impl AgentClientV2Uds {
361 pub async fn new(
363 agent_id: impl Into<String>,
364 socket_path: impl Into<String>,
365 timeout: Duration,
366 ) -> Result<Self, AgentProtocolError> {
367 let agent_id = agent_id.into();
368 let socket_path = socket_path.into();
369
370 debug!(
371 agent_id = %agent_id,
372 socket_path = %socket_path,
373 timeout_ms = timeout.as_millis(),
374 "Creating UDS v2 client"
375 );
376
377 Ok(Self {
378 agent_id,
379 socket_path,
380 timeout,
381 capabilities: RwLock::new(None),
382 protocol_version: AtomicU64::new(0),
383 encoding: RwLock::new(UdsEncoding::Json),
384 pending: Arc::new(Mutex::new(HashMap::new())),
385 outbound_tx: Mutex::new(None),
386 ping_sequence: AtomicU64::new(0),
387 connected: RwLock::new(false),
388 flow_state: RwLock::new(FlowState::Normal),
389 health_state: RwLock::new(1), in_flight: AtomicU64::new(0),
391 metrics_callback: None,
392 config_update_callback: None,
393 })
394 }
395
396 fn supported_encodings() -> Vec<UdsEncoding> {
400 #[cfg(feature = "binary-uds")]
401 {
402 vec![UdsEncoding::MessagePack, UdsEncoding::Json]
403 }
404 #[cfg(not(feature = "binary-uds"))]
405 {
406 vec![UdsEncoding::Json]
407 }
408 }
409
410 pub async fn encoding(&self) -> UdsEncoding {
412 *self.encoding.read().await
413 }
414
415 pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
417 self.metrics_callback = Some(callback);
418 }
419
420 pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
422 self.config_update_callback = Some(callback);
423 }
424
425 pub async fn connect(&self) -> Result<(), AgentProtocolError> {
427 info!(
428 agent_id = %self.agent_id,
429 socket_path = %self.socket_path,
430 "Connecting to agent via UDS v2"
431 );
432
433 let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
435 error!(
436 agent_id = %self.agent_id,
437 socket_path = %self.socket_path,
438 error = %e,
439 "Failed to connect to agent via UDS"
440 );
441 AgentProtocolError::ConnectionFailed(e.to_string())
442 })?;
443
444 let (read_half, write_half) = stream.into_split();
445 let mut reader = BufReader::new(read_half);
446 let mut writer = BufWriter::new(write_half);
447
448 let handshake_req = UdsHandshakeRequest {
450 supported_versions: vec![PROTOCOL_VERSION_2],
451 proxy_id: "grapsus-proxy".to_string(),
452 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
453 config: None,
454 supported_encodings: Self::supported_encodings(),
455 };
456
457 let payload = serde_json::to_vec(&handshake_req)
459 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
460
461 write_message(&mut writer, MessageType::HandshakeRequest, &payload).await?;
462
463 let (msg_type, response_bytes) = read_message(&mut reader).await?;
465
466 if msg_type != MessageType::HandshakeResponse {
467 return Err(AgentProtocolError::InvalidMessage(format!(
468 "Expected HandshakeResponse, got {:?}",
469 msg_type
470 )));
471 }
472
473 let response: UdsHandshakeResponse = serde_json::from_slice(&response_bytes)
474 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
475
476 if !response.success {
477 return Err(AgentProtocolError::ConnectionFailed(
478 response
479 .error
480 .unwrap_or_else(|| "Unknown handshake error".to_string()),
481 ));
482 }
483
484 let capabilities: AgentCapabilities = response.capabilities.into();
486 *self.capabilities.write().await = Some(capabilities);
487 self.protocol_version
488 .store(response.protocol_version as u64, Ordering::SeqCst);
489
490 let negotiated_encoding = response.encoding;
492 *self.encoding.write().await = negotiated_encoding;
493
494 info!(
495 agent_id = %self.agent_id,
496 protocol_version = response.protocol_version,
497 encoding = ?negotiated_encoding,
498 "UDS v2 handshake successful"
499 );
500
501 let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
503 *self.outbound_tx.lock().await = Some(tx);
504 *self.connected.write().await = true;
505
506 let agent_id_clone = self.agent_id.clone();
508 tokio::spawn(async move {
509 while let Some((msg_type, payload)) = rx.recv().await {
510 if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
511 error!(
512 agent_id = %agent_id_clone,
513 error = %e,
514 "Failed to write message to UDS"
515 );
516 break;
517 }
518 }
519 debug!(agent_id = %agent_id_clone, "UDS writer task ended");
520 });
521
522 let pending = Arc::clone(&self.pending);
524 let agent_id = self.agent_id.clone();
525 let flow_state = Arc::new(RwLock::new(FlowState::Normal));
526 let health_state = Arc::new(RwLock::new(1i32));
527 let flow_state_clone = Arc::clone(&flow_state);
528 let health_state_clone = Arc::clone(&health_state);
529 let metrics_callback = self.metrics_callback.clone();
530 let config_update_callback = self.config_update_callback.clone();
531 let reader_encoding = negotiated_encoding;
533
534 tokio::spawn(async move {
535 loop {
536 match read_message(&mut reader).await {
537 Ok((msg_type, payload)) => {
538 match msg_type {
539 MessageType::AgentResponse => {
540 match reader_encoding.deserialize::<AgentResponse>(&payload) {
541 Ok(response) => {
542 if let Some(sender) = pending.lock().await.remove(
545 &response
546 .audit
547 .custom
548 .get("correlation_id")
549 .and_then(|v| v.as_str())
550 .unwrap_or("")
551 .to_string(),
552 ) {
553 let _ = sender.send(response);
554 }
555 }
556 Err(e) => {
557 warn!(
558 agent_id = %agent_id,
559 error = %e,
560 encoding = ?reader_encoding,
561 "Failed to parse agent response"
562 );
563 }
564 }
565 }
566 MessageType::HealthStatus => {
567 #[derive(serde::Deserialize)]
569 struct HealthStatusMsg {
570 state: Option<i64>,
571 }
572 if let Ok(health) =
573 reader_encoding.deserialize::<HealthStatusMsg>(&payload)
574 {
575 if let Some(state) = health.state {
576 *health_state_clone.write().await = state as i32;
577 }
578 }
579 }
580 MessageType::MetricsReport => {
581 if let Some(ref callback) = metrics_callback {
582 if let Ok(report) = reader_encoding.deserialize(&payload) {
583 callback(report);
584 }
585 }
586 }
587 MessageType::FlowControl => {
588 #[derive(serde::Deserialize)]
589 struct FlowControlMsg {
590 action: Option<i64>,
591 }
592 if let Ok(fc) =
593 reader_encoding.deserialize::<FlowControlMsg>(&payload)
594 {
595 let action = fc.action.unwrap_or(0);
596 let new_state = match action {
597 1 => FlowState::Paused,
598 2 => FlowState::Normal,
599 _ => FlowState::Normal,
600 };
601 *flow_state_clone.write().await = new_state;
602 }
603 }
604 MessageType::ConfigUpdateRequest => {
605 if let Some(ref callback) = config_update_callback {
606 if let Ok(request) = reader_encoding.deserialize(&payload) {
607 let _response = callback(agent_id.clone(), request);
608 }
609 }
610 }
611 MessageType::Pong => {
612 trace!(agent_id = %agent_id, "Received pong");
613 }
614 _ => {
615 trace!(
616 agent_id = %agent_id,
617 msg_type = ?msg_type,
618 "Received unhandled message type"
619 );
620 }
621 }
622 }
623 Err(e) => {
624 if !matches!(e, AgentProtocolError::ConnectionClosed) {
625 error!(
626 agent_id = %agent_id,
627 error = %e,
628 "Error reading from UDS"
629 );
630 }
631 break;
632 }
633 }
634 }
635 debug!(agent_id = %agent_id, "UDS reader task ended");
636 });
637
638 Ok(())
639 }
640
641 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
643 self.capabilities.read().await.clone()
644 }
645
646 pub async fn is_connected(&self) -> bool {
648 *self.connected.read().await
649 }
650
651 pub async fn send_request_headers(
653 &self,
654 correlation_id: &str,
655 event: &crate::RequestHeadersEvent,
656 ) -> Result<AgentResponse, AgentProtocolError> {
657 self.send_event(MessageType::RequestHeaders, correlation_id, event)
658 .await
659 }
660
661 pub async fn send_request_body_chunk(
663 &self,
664 correlation_id: &str,
665 event: &crate::RequestBodyChunkEvent,
666 ) -> Result<AgentResponse, AgentProtocolError> {
667 self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
668 .await
669 }
670
671 pub async fn send_response_headers(
673 &self,
674 correlation_id: &str,
675 event: &crate::ResponseHeadersEvent,
676 ) -> Result<AgentResponse, AgentProtocolError> {
677 self.send_event(MessageType::ResponseHeaders, correlation_id, event)
678 .await
679 }
680
681 pub async fn send_response_body_chunk(
683 &self,
684 correlation_id: &str,
685 event: &crate::ResponseBodyChunkEvent,
686 ) -> Result<AgentResponse, AgentProtocolError> {
687 self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
688 .await
689 }
690
691 pub async fn send_request_complete(
693 &self,
694 correlation_id: &str,
695 event: &crate::RequestCompleteEvent,
696 ) -> Result<AgentResponse, AgentProtocolError> {
697 self.send_event(MessageType::RequestComplete, correlation_id, event)
698 .await
699 }
700
701 pub async fn send_websocket_frame(
703 &self,
704 correlation_id: &str,
705 event: &crate::WebSocketFrameEvent,
706 ) -> Result<AgentResponse, AgentProtocolError> {
707 self.send_event(MessageType::WebSocketFrame, correlation_id, event)
708 .await
709 }
710
711 pub async fn send_guardrail_inspect(
713 &self,
714 correlation_id: &str,
715 event: &crate::GuardrailInspectEvent,
716 ) -> Result<AgentResponse, AgentProtocolError> {
717 self.send_event(MessageType::GuardrailInspect, correlation_id, event)
718 .await
719 }
720
721 pub async fn send_configure(
723 &self,
724 correlation_id: &str,
725 event: &serde_json::Value,
726 ) -> Result<AgentResponse, AgentProtocolError> {
727 self.send_event(MessageType::Configure, correlation_id, event)
728 .await
729 }
730
731 pub async fn send_request_body_chunk_binary(
745 &self,
746 event: &crate::BinaryRequestBodyChunkEvent,
747 ) -> Result<AgentResponse, AgentProtocolError> {
748 let correlation_id = &event.correlation_id;
749 self.send_binary_body_chunk(
750 MessageType::RequestBodyChunk,
751 correlation_id,
752 &event.data,
753 event.is_last,
754 event.total_size,
755 event.chunk_index,
756 Some(event.bytes_received),
757 None,
758 )
759 .await
760 }
761
762 pub async fn send_response_body_chunk_binary(
767 &self,
768 event: &crate::BinaryResponseBodyChunkEvent,
769 ) -> Result<AgentResponse, AgentProtocolError> {
770 let correlation_id = &event.correlation_id;
771 self.send_binary_body_chunk(
772 MessageType::ResponseBodyChunk,
773 correlation_id,
774 &event.data,
775 event.is_last,
776 event.total_size,
777 event.chunk_index,
778 None,
779 Some(event.bytes_sent),
780 )
781 .await
782 }
783
784 #[allow(clippy::too_many_arguments)]
786 async fn send_binary_body_chunk(
787 &self,
788 msg_type: MessageType,
789 correlation_id: &str,
790 data: &bytes::Bytes,
791 is_last: bool,
792 total_size: Option<usize>,
793 chunk_index: u32,
794 bytes_received: Option<usize>,
795 bytes_sent: Option<usize>,
796 ) -> Result<AgentResponse, AgentProtocolError> {
797 let (tx, rx) = oneshot::channel();
799 self.pending
800 .lock()
801 .await
802 .insert(correlation_id.to_string(), tx);
803
804 let encoding = *self.encoding.read().await;
806
807 let payload_bytes = match encoding {
809 UdsEncoding::Json => {
810 use base64::{engine::general_purpose::STANDARD, Engine as _};
812 let json = serde_json::json!({
813 "correlation_id": correlation_id,
814 "data": STANDARD.encode(data),
815 "is_last": is_last,
816 "total_size": total_size,
817 "chunk_index": chunk_index,
818 "bytes_received": bytes_received,
819 "bytes_sent": bytes_sent,
820 });
821 serde_json::to_vec(&json)
822 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
823 }
824 UdsEncoding::MessagePack => {
825 #[derive(serde::Serialize)]
827 struct BinaryBodyChunk<'a> {
828 correlation_id: &'a str,
829 #[serde(with = "serde_bytes")]
830 data: &'a [u8],
831 is_last: bool,
832 #[serde(skip_serializing_if = "Option::is_none")]
833 total_size: Option<usize>,
834 chunk_index: u32,
835 #[serde(skip_serializing_if = "Option::is_none")]
836 bytes_received: Option<usize>,
837 #[serde(skip_serializing_if = "Option::is_none")]
838 bytes_sent: Option<usize>,
839 }
840 let chunk = BinaryBodyChunk {
841 correlation_id,
842 data: data.as_ref(),
843 is_last,
844 total_size,
845 chunk_index,
846 bytes_received,
847 bytes_sent,
848 };
849 encoding.serialize(&chunk)?
850 }
851 };
852
853 {
855 let outbound = self.outbound_tx.lock().await;
856 if let Some(tx) = outbound.as_ref() {
857 tx.send((msg_type, payload_bytes))
858 .await
859 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
860 } else {
861 return Err(AgentProtocolError::ConnectionClosed);
862 }
863 }
864
865 self.in_flight.fetch_add(1, Ordering::Relaxed);
866
867 let response = tokio::time::timeout(self.timeout, rx)
869 .await
870 .map_err(|_| {
871 self.pending
872 .try_lock()
873 .ok()
874 .map(|mut p| p.remove(correlation_id));
875 AgentProtocolError::Timeout(self.timeout)
876 })?
877 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
878
879 self.in_flight.fetch_sub(1, Ordering::Relaxed);
880
881 Ok(response)
882 }
883
884 async fn send_event<T: serde::Serialize>(
886 &self,
887 msg_type: MessageType,
888 correlation_id: &str,
889 event: &T,
890 ) -> Result<AgentResponse, AgentProtocolError> {
891 let (tx, rx) = oneshot::channel();
893 self.pending
894 .lock()
895 .await
896 .insert(correlation_id.to_string(), tx);
897
898 let encoding = *self.encoding.read().await;
900
901 let payload_bytes = match encoding {
903 UdsEncoding::Json => {
904 let mut payload = serde_json::to_value(event)
906 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
907 if let Some(obj) = payload.as_object_mut() {
908 obj.insert(
909 "correlation_id".to_string(),
910 serde_json::Value::String(correlation_id.to_string()),
911 );
912 }
913 serde_json::to_vec(&payload)
914 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
915 }
916 UdsEncoding::MessagePack => {
917 #[derive(serde::Serialize)]
919 struct EventWithCorrelation<'a, T: serde::Serialize> {
920 correlation_id: &'a str,
921 #[serde(flatten)]
922 event: &'a T,
923 }
924 let wrapped = EventWithCorrelation {
925 correlation_id,
926 event,
927 };
928 encoding.serialize(&wrapped)?
929 }
930 };
931
932 {
934 let outbound = self.outbound_tx.lock().await;
935 if let Some(tx) = outbound.as_ref() {
936 tx.send((msg_type, payload_bytes))
937 .await
938 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
939 } else {
940 return Err(AgentProtocolError::ConnectionClosed);
941 }
942 }
943
944 self.in_flight.fetch_add(1, Ordering::Relaxed);
945
946 let response = tokio::time::timeout(self.timeout, rx)
948 .await
949 .map_err(|_| {
950 self.pending
951 .try_lock()
952 .ok()
953 .map(|mut p| p.remove(correlation_id));
954 AgentProtocolError::Timeout(self.timeout)
955 })?
956 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
957
958 self.in_flight.fetch_sub(1, Ordering::Relaxed);
959
960 Ok(response)
961 }
962
963 pub async fn cancel_request(
965 &self,
966 correlation_id: &str,
967 reason: super::client::CancelReason,
968 ) -> Result<(), AgentProtocolError> {
969 let cancel = serde_json::json!({
970 "correlation_id": correlation_id,
971 "reason": reason as i32,
972 "timestamp_ms": now_ms(),
973 });
974
975 let payload = serde_json::to_vec(&cancel)
976 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
977
978 let outbound = self.outbound_tx.lock().await;
979 if let Some(tx) = outbound.as_ref() {
980 tx.send((MessageType::Cancel, payload))
981 .await
982 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
983 }
984
985 self.pending.lock().await.remove(correlation_id);
987
988 Ok(())
989 }
990
991 pub async fn cancel_all(
993 &self,
994 reason: super::client::CancelReason,
995 ) -> Result<usize, AgentProtocolError> {
996 let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
997 let count = pending_ids.len();
998
999 for correlation_id in pending_ids {
1000 let _ = self.cancel_request(&correlation_id, reason).await;
1001 }
1002
1003 Ok(count)
1004 }
1005
1006 pub async fn ping(&self) -> Result<(), AgentProtocolError> {
1008 let seq = self.ping_sequence.fetch_add(1, Ordering::Relaxed);
1009 let ping = serde_json::json!({
1010 "sequence": seq,
1011 "timestamp_ms": now_ms(),
1012 });
1013
1014 let payload = serde_json::to_vec(&ping)
1015 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
1016
1017 let outbound = self.outbound_tx.lock().await;
1018 if let Some(tx) = outbound.as_ref() {
1019 tx.send((MessageType::Ping, payload))
1020 .await
1021 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
1022 }
1023
1024 Ok(())
1025 }
1026
1027 pub async fn close(&self) -> Result<(), AgentProtocolError> {
1029 *self.connected.write().await = false;
1030 *self.outbound_tx.lock().await = None;
1031 Ok(())
1032 }
1033
1034 pub fn in_flight(&self) -> u64 {
1036 self.in_flight.load(Ordering::Relaxed)
1037 }
1038
1039 pub fn agent_id(&self) -> &str {
1041 &self.agent_id
1042 }
1043
1044 pub async fn is_paused(&self) -> bool {
1049 matches!(*self.flow_state.read().await, FlowState::Paused)
1050 }
1051
1052 pub async fn can_accept_requests(&self) -> bool {
1056 !self.is_paused().await
1057 }
1058}
1059
1060pub async fn write_message<W: AsyncWriteExt + Unpin>(
1062 writer: &mut W,
1063 msg_type: MessageType,
1064 payload: &[u8],
1065) -> Result<(), AgentProtocolError> {
1066 if payload.len() > MAX_UDS_MESSAGE_SIZE {
1067 return Err(AgentProtocolError::MessageTooLarge {
1068 size: payload.len(),
1069 max: MAX_UDS_MESSAGE_SIZE,
1070 });
1071 }
1072
1073 let total_len = (payload.len() + 1) as u32;
1075 writer.write_all(&total_len.to_be_bytes()).await?;
1076
1077 writer.write_all(&[msg_type as u8]).await?;
1079
1080 writer.write_all(payload).await?;
1082 writer.flush().await?;
1083
1084 Ok(())
1085}
1086
1087pub async fn read_message<R: AsyncReadExt + Unpin>(
1089 reader: &mut R,
1090) -> Result<(MessageType, Vec<u8>), AgentProtocolError> {
1091 let mut len_bytes = [0u8; 4];
1093 match reader.read_exact(&mut len_bytes).await {
1094 Ok(_) => {}
1095 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1096 return Err(AgentProtocolError::ConnectionClosed);
1097 }
1098 Err(e) => return Err(e.into()),
1099 }
1100
1101 let total_len = u32::from_be_bytes(len_bytes) as usize;
1102
1103 if total_len == 0 {
1104 return Err(AgentProtocolError::InvalidMessage(
1105 "Zero-length message".to_string(),
1106 ));
1107 }
1108
1109 if total_len > MAX_UDS_MESSAGE_SIZE {
1110 return Err(AgentProtocolError::MessageTooLarge {
1111 size: total_len,
1112 max: MAX_UDS_MESSAGE_SIZE,
1113 });
1114 }
1115
1116 let mut type_byte = [0u8; 1];
1118 reader.read_exact(&mut type_byte).await?;
1119 let msg_type = MessageType::try_from(type_byte[0])?;
1120
1121 let payload_len = total_len - 1;
1123 let mut payload = vec![0u8; payload_len];
1124 if payload_len > 0 {
1125 reader.read_exact(&mut payload).await?;
1126 }
1127
1128 Ok((msg_type, payload))
1129}
1130
/// Milliseconds since the Unix epoch, or 0 if the system clock reads earlier
/// than the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_type_roundtrip() {
        let samples = [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::AgentResponse,
            MessageType::HealthStatus,
            MessageType::Ping,
            MessageType::Pong,
        ];

        samples.iter().copied().for_each(|expected| {
            let decoded = MessageType::try_from(expected as u8).unwrap();
            assert_eq!(decoded, expected);
        });
    }

    #[test]
    fn test_invalid_message_type() {
        assert!(MessageType::try_from(0xFF).is_err());
    }

    #[test]
    fn test_handshake_serialization() {
        let original = UdsHandshakeRequest {
            supported_versions: vec![2],
            proxy_id: "test-proxy".to_string(),
            proxy_version: "1.0.0".to_string(),
            config: None,
            supported_encodings: vec![],
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: UdsHandshakeRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.supported_versions, vec![2]);
        assert_eq!(decoded.proxy_id, "test-proxy");
    }

    #[tokio::test]
    async fn test_write_read_message() {
        let (mut sending_end, mut receiving_end) = tokio::io::duplex(1024);
        let body: &[u8] = b"test payload";

        write_message(&mut sending_end, MessageType::Ping, body)
            .await
            .unwrap();
        let (kind, received) = read_message(&mut receiving_end).await.unwrap();

        assert_eq!(kind, MessageType::Ping);
        assert_eq!(received, body);
    }

    #[test]
    fn test_binary_body_chunk_json_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let raw = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        // Same shape as the JSON branch of `send_binary_body_chunk`.
        let message = serde_json::json!({
            "correlation_id": correlation_id,
            "data": STANDARD.encode(&raw),
            "is_last": true,
            "total_size": 100usize,
            "chunk_index": 0u32,
            "bytes_received": 100usize,
        });

        let wire = serde_json::to_vec(&message).unwrap();
        let reparsed: serde_json::Value = serde_json::from_slice(&wire).unwrap();

        let encoded_data = reparsed["data"].as_str().unwrap();
        let round_tripped = STANDARD.decode(encoded_data).unwrap();
        assert_eq!(round_tripped, raw.as_ref());
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_binary_body_chunk_msgpack_serialization() {
        use base64::Engine as _;

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct BinaryBodyChunk {
            correlation_id: String,
            #[serde(with = "serde_bytes")]
            data: Vec<u8>,
            is_last: bool,
            chunk_index: u32,
        }

        let raw = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let correlation_id = "test-123";

        let chunk = BinaryBodyChunk {
            correlation_id: correlation_id.to_string(),
            data: raw.to_vec(),
            is_last: true,
            chunk_index: 0,
        };

        let packed = rmp_serde::to_vec(&chunk).unwrap();
        let unpacked: BinaryBodyChunk = rmp_serde::from_slice(&packed).unwrap();

        assert_eq!(unpacked.correlation_id, correlation_id);
        assert_eq!(unpacked.data, raw.as_ref());
        assert!(unpacked.is_last);

        // The same chunk as JSON with base64 data, for a size comparison.
        let json_len = serde_json::to_vec(&serde_json::json!({
            "correlation_id": correlation_id,
            "data": base64::engine::general_purpose::STANDARD.encode(&raw),
            "is_last": true,
            "chunk_index": 0u32,
        }))
        .unwrap()
        .len();

        assert!(
            packed.len() < json_len,
            "MessagePack ({}) should be smaller than JSON+base64 ({})",
            packed.len(),
            json_len
        );
    }

    #[test]
    fn test_uds_encoding_default() {
        assert_eq!(UdsEncoding::default(), UdsEncoding::Json);
    }

    #[test]
    fn test_uds_encoding_serialize_json() {
        let value = serde_json::json!({"key": "value"});
        let bytes = UdsEncoding::Json.serialize(&value).unwrap();
        let reparsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(reparsed, value);
    }

    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_uds_encoding_serialize_msgpack() {
        let value = serde_json::json!({"key": "value"});
        let bytes = UdsEncoding::MessagePack.serialize(&value).unwrap();
        let reparsed: serde_json::Value = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(reparsed, value);
    }
}