Skip to main content

grapsus_agent_protocol/v2/
uds.rs

//! Unix Domain Socket transport for Agent Protocol v2.
//!
//! This module provides a binary protocol implementation for v2 over UDS,
//! supporting bidirectional streaming with connection multiplexing.
//!
//! # Wire Format
//!
//! All messages use a length-prefixed binary format:
//! ```text
//! +--------+--------+------------------+
//! | Length | Type   | Payload          |
//! | 4 bytes| 1 byte | variable         |
//! | BE u32 | u8     | MessagePack/JSON |
//! +--------+--------+------------------+
//! ```
//!
//! # Message Types
//!
//! - 0x01: Handshake Request (proxy -> agent)
//! - 0x02: Handshake Response (agent -> proxy)
//! - 0x10: Request Headers Event
//! - 0x11: Request Body Chunk Event
//! - 0x12: Response Headers Event
//! - 0x13: Response Body Chunk Event
//! - 0x14: Request Complete Event
//! - 0x15: WebSocket Frame Event
//! - 0x16: Guardrail Inspect Event
//! - 0x17: Configure Event
//! - 0x20: Agent Response
//! - 0x30: Health Status
//! - 0x31: Metrics Report
//! - 0x32: Config Update Request
//! - 0x33: Flow Control Signal
//! - 0x40: Cancel Request
//! - 0x41: Ping
//! - 0x42: Pong
37
38use std::collections::HashMap;
39use std::sync::atomic::{AtomicU64, Ordering};
40use std::sync::Arc;
41use std::time::Duration;
42
43use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
44use tokio::net::UnixStream;
45use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
46use tracing::{debug, error, info, trace, warn};
47
48use crate::v2::pool::CHANNEL_BUFFER_SIZE;
49use crate::v2::{AgentCapabilities, AgentFeatures, AgentLimits, HealthConfig, PROTOCOL_VERSION_2};
50use crate::{AgentProtocolError, AgentResponse, EventType};
51
52use super::client::{ConfigUpdateCallback, FlowState, MetricsCallback};
53
/// Upper bound, in bytes, on a single UDS transport message (16 MiB).
pub const MAX_UDS_MESSAGE_SIZE: usize = 16 << 20;
56
57/// Payload encoding for UDS transport.
58///
59/// Negotiated during handshake. The proxy sends its supported encodings,
60/// and the agent responds with the chosen encoding.
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
62#[serde(rename_all = "lowercase")]
63pub enum UdsEncoding {
64    /// JSON encoding (default, always supported)
65    #[default]
66    Json,
67    /// MessagePack binary encoding (requires `binary-uds` feature)
68    #[serde(rename = "msgpack")]
69    MessagePack,
70}
71
72impl UdsEncoding {
73    /// Serialize a value using this encoding.
74    ///
75    /// Returns the serialized bytes, or an error if serialization fails.
76    #[inline]
77    pub fn serialize<T: serde::Serialize>(&self, value: &T) -> Result<Vec<u8>, AgentProtocolError> {
78        match self {
79            UdsEncoding::Json => serde_json::to_vec(value)
80                .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
81            #[cfg(feature = "binary-uds")]
82            UdsEncoding::MessagePack => rmp_serde::to_vec(value)
83                .map_err(|e| AgentProtocolError::Serialization(e.to_string())),
84            #[cfg(not(feature = "binary-uds"))]
85            UdsEncoding::MessagePack => {
86                // Fall back to JSON if binary-uds feature is not enabled
87                serde_json::to_vec(value)
88                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))
89            }
90        }
91    }
92
93    /// Deserialize a value using this encoding.
94    ///
95    /// Returns the deserialized value, or an error if deserialization fails.
96    #[inline]
97    pub fn deserialize<'a, T: serde::Deserialize<'a>>(
98        &self,
99        bytes: &'a [u8],
100    ) -> Result<T, AgentProtocolError> {
101        match self {
102            UdsEncoding::Json => serde_json::from_slice(bytes)
103                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
104            #[cfg(feature = "binary-uds")]
105            UdsEncoding::MessagePack => rmp_serde::from_slice(bytes)
106                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string())),
107            #[cfg(not(feature = "binary-uds"))]
108            UdsEncoding::MessagePack => {
109                // Fall back to JSON if binary-uds feature is not enabled
110                serde_json::from_slice(bytes)
111                    .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
112            }
113        }
114    }
115}
116
/// One-byte tag identifying the kind of message carried in a UDS frame.
///
/// The discriminant of each variant is the exact byte written on the wire.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessageType {
    // ---- Handshake ----
    /// Handshake request (proxy -> agent).
    HandshakeRequest = 0x01,
    /// Handshake response (agent -> proxy).
    HandshakeResponse = 0x02,

    // ---- Events (proxy -> agent) ----
    /// Request headers event.
    RequestHeaders = 0x10,
    /// Request body chunk event.
    RequestBodyChunk = 0x11,
    /// Response headers event.
    ResponseHeaders = 0x12,
    /// Response body chunk event.
    ResponseBodyChunk = 0x13,
    /// Request complete event.
    RequestComplete = 0x14,
    /// WebSocket frame event.
    WebSocketFrame = 0x15,
    /// Guardrail inspect event.
    GuardrailInspect = 0x16,
    /// Configure event.
    Configure = 0x17,

    // ---- Response (agent -> proxy) ----
    /// Agent response to an event.
    AgentResponse = 0x20,

    // ---- Control messages (bidirectional) ----
    /// Health status.
    HealthStatus = 0x30,
    /// Metrics report.
    MetricsReport = 0x31,
    /// Config update request.
    ConfigUpdateRequest = 0x32,
    /// Flow control signal.
    FlowControl = 0x33,

    // ---- Management ----
    /// Cancel request.
    Cancel = 0x40,
    /// Ping.
    Ping = 0x41,
    /// Pong.
    Pong = 0x42,
}
149
150impl TryFrom<u8> for MessageType {
151    type Error = AgentProtocolError;
152
153    fn try_from(value: u8) -> Result<Self, Self::Error> {
154        match value {
155            0x01 => Ok(MessageType::HandshakeRequest),
156            0x02 => Ok(MessageType::HandshakeResponse),
157            0x10 => Ok(MessageType::RequestHeaders),
158            0x11 => Ok(MessageType::RequestBodyChunk),
159            0x12 => Ok(MessageType::ResponseHeaders),
160            0x13 => Ok(MessageType::ResponseBodyChunk),
161            0x14 => Ok(MessageType::RequestComplete),
162            0x15 => Ok(MessageType::WebSocketFrame),
163            0x16 => Ok(MessageType::GuardrailInspect),
164            0x17 => Ok(MessageType::Configure),
165            0x20 => Ok(MessageType::AgentResponse),
166            0x30 => Ok(MessageType::HealthStatus),
167            0x31 => Ok(MessageType::MetricsReport),
168            0x32 => Ok(MessageType::ConfigUpdateRequest),
169            0x33 => Ok(MessageType::FlowControl),
170            0x40 => Ok(MessageType::Cancel),
171            0x41 => Ok(MessageType::Ping),
172            0x42 => Ok(MessageType::Pong),
173            _ => Err(AgentProtocolError::InvalidMessage(format!(
174                "Unknown message type: 0x{:02x}",
175                value
176            ))),
177        }
178    }
179}
180
181/// Handshake request sent from proxy to agent over UDS.
182#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
183pub struct UdsHandshakeRequest {
184    pub supported_versions: Vec<u32>,
185    pub proxy_id: String,
186    pub proxy_version: String,
187    pub config: Option<serde_json::Value>,
188    /// Supported payload encodings (in order of preference).
189    /// If empty or missing, only JSON is supported.
190    #[serde(default, skip_serializing_if = "Vec::is_empty")]
191    pub supported_encodings: Vec<UdsEncoding>,
192}
193
194/// Handshake response from agent to proxy over UDS.
195#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
196pub struct UdsHandshakeResponse {
197    pub protocol_version: u32,
198    pub capabilities: UdsCapabilities,
199    pub success: bool,
200    pub error: Option<String>,
201    /// Negotiated encoding for subsequent messages.
202    /// If missing, defaults to JSON for backwards compatibility.
203    #[serde(default)]
204    pub encoding: UdsEncoding,
205}
206
207/// Agent capabilities for UDS protocol.
208#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
209pub struct UdsCapabilities {
210    pub agent_id: String,
211    pub name: String,
212    pub version: String,
213    pub supported_events: Vec<i32>,
214    pub features: UdsFeatures,
215    pub limits: UdsLimits,
216}
217
218/// Agent features.
219#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
220pub struct UdsFeatures {
221    pub streaming_body: bool,
222    pub websocket: bool,
223    pub guardrails: bool,
224    pub config_push: bool,
225    pub metrics_export: bool,
226    pub concurrent_requests: u32,
227    pub cancellation: bool,
228    pub flow_control: bool,
229    pub health_reporting: bool,
230}
231
232/// Agent limits.
233#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
234pub struct UdsLimits {
235    pub max_body_size: u64,
236    pub max_concurrency: u32,
237    pub preferred_chunk_size: u64,
238}
239
240impl From<UdsCapabilities> for AgentCapabilities {
241    fn from(caps: UdsCapabilities) -> Self {
242        AgentCapabilities {
243            protocol_version: PROTOCOL_VERSION_2,
244            agent_id: caps.agent_id,
245            name: caps.name,
246            version: caps.version,
247            supported_events: caps
248                .supported_events
249                .into_iter()
250                .filter_map(event_type_from_i32)
251                .collect(),
252            features: AgentFeatures {
253                streaming_body: caps.features.streaming_body,
254                websocket: caps.features.websocket,
255                guardrails: caps.features.guardrails,
256                config_push: caps.features.config_push,
257                metrics_export: caps.features.metrics_export,
258                concurrent_requests: caps.features.concurrent_requests,
259                cancellation: caps.features.cancellation,
260                flow_control: caps.features.flow_control,
261                health_reporting: caps.features.health_reporting,
262            },
263            limits: AgentLimits {
264                max_body_size: caps.limits.max_body_size as usize,
265                max_concurrency: caps.limits.max_concurrency,
266                preferred_chunk_size: caps.limits.preferred_chunk_size as usize,
267                max_memory: None,
268                max_processing_time_ms: None,
269            },
270            health: HealthConfig::default(),
271        }
272    }
273}
274
275impl From<AgentCapabilities> for UdsCapabilities {
276    fn from(caps: AgentCapabilities) -> Self {
277        use crate::v2::server::event_type_to_i32;
278        UdsCapabilities {
279            agent_id: caps.agent_id,
280            name: caps.name,
281            version: caps.version,
282            supported_events: caps
283                .supported_events
284                .iter()
285                .map(|e| event_type_to_i32(*e))
286                .collect(),
287            features: UdsFeatures {
288                streaming_body: caps.features.streaming_body,
289                websocket: caps.features.websocket,
290                guardrails: caps.features.guardrails,
291                config_push: caps.features.config_push,
292                metrics_export: caps.features.metrics_export,
293                concurrent_requests: caps.features.concurrent_requests,
294                cancellation: caps.features.cancellation,
295                flow_control: caps.features.flow_control,
296                health_reporting: caps.features.health_reporting,
297            },
298            limits: UdsLimits {
299                max_body_size: caps.limits.max_body_size as u64,
300                max_concurrency: caps.limits.max_concurrency,
301                preferred_chunk_size: caps.limits.preferred_chunk_size as u64,
302            },
303        }
304    }
305}
306
307/// Convert i32 to EventType.
308fn event_type_from_i32(value: i32) -> Option<EventType> {
309    match value {
310        0 => Some(EventType::Configure),
311        1 => Some(EventType::RequestHeaders),
312        2 => Some(EventType::RequestBodyChunk),
313        3 => Some(EventType::ResponseHeaders),
314        4 => Some(EventType::ResponseBodyChunk),
315        5 => Some(EventType::RequestComplete),
316        6 => Some(EventType::WebSocketFrame),
317        7 => Some(EventType::GuardrailInspect),
318        _ => None,
319    }
320}
321
322/// v2 agent client over Unix Domain Socket.
323///
324/// This client maintains a single connection and multiplexes multiple requests
325/// over it using correlation IDs, similar to the gRPC client.
326pub struct AgentClientV2Uds {
327    /// Agent identifier
328    agent_id: String,
329    /// Socket path
330    socket_path: String,
331    /// Request timeout
332    timeout: Duration,
333    /// Negotiated capabilities
334    capabilities: RwLock<Option<AgentCapabilities>>,
335    /// Negotiated protocol version
336    protocol_version: AtomicU64,
337    /// Negotiated payload encoding
338    encoding: RwLock<UdsEncoding>,
339    /// Pending requests by correlation ID
340    pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
341    /// Sender for outbound messages
342    #[allow(clippy::type_complexity)]
343    outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
344    /// Sequence counter for pings
345    ping_sequence: AtomicU64,
346    /// Connection state
347    connected: RwLock<bool>,
348    /// Flow control state
349    flow_state: RwLock<FlowState>,
350    /// Last known health state
351    health_state: RwLock<i32>,
352    /// In-flight request count
353    in_flight: AtomicU64,
354    /// Callback for metrics reports
355    metrics_callback: Option<MetricsCallback>,
356    /// Callback for config update requests
357    config_update_callback: Option<ConfigUpdateCallback>,
358}
359
360impl AgentClientV2Uds {
361    /// Create a new UDS v2 client.
362    pub async fn new(
363        agent_id: impl Into<String>,
364        socket_path: impl Into<String>,
365        timeout: Duration,
366    ) -> Result<Self, AgentProtocolError> {
367        let agent_id = agent_id.into();
368        let socket_path = socket_path.into();
369
370        debug!(
371            agent_id = %agent_id,
372            socket_path = %socket_path,
373            timeout_ms = timeout.as_millis(),
374            "Creating UDS v2 client"
375        );
376
377        Ok(Self {
378            agent_id,
379            socket_path,
380            timeout,
381            capabilities: RwLock::new(None),
382            protocol_version: AtomicU64::new(0),
383            encoding: RwLock::new(UdsEncoding::Json),
384            pending: Arc::new(Mutex::new(HashMap::new())),
385            outbound_tx: Mutex::new(None),
386            ping_sequence: AtomicU64::new(0),
387            connected: RwLock::new(false),
388            flow_state: RwLock::new(FlowState::Normal),
389            health_state: RwLock::new(1), // HEALTHY
390            in_flight: AtomicU64::new(0),
391            metrics_callback: None,
392            config_update_callback: None,
393        })
394    }
395
396    /// Returns the list of supported encodings for this client.
397    ///
398    /// When compiled with `binary-uds` feature, MessagePack is preferred.
399    fn supported_encodings() -> Vec<UdsEncoding> {
400        #[cfg(feature = "binary-uds")]
401        {
402            vec![UdsEncoding::MessagePack, UdsEncoding::Json]
403        }
404        #[cfg(not(feature = "binary-uds"))]
405        {
406            vec![UdsEncoding::Json]
407        }
408    }
409
410    /// Get the current negotiated encoding.
411    pub async fn encoding(&self) -> UdsEncoding {
412        *self.encoding.read().await
413    }
414
415    /// Set the metrics callback.
416    pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
417        self.metrics_callback = Some(callback);
418    }
419
420    /// Set the config update callback.
421    pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
422        self.config_update_callback = Some(callback);
423    }
424
425    /// Connect and perform handshake.
426    pub async fn connect(&self) -> Result<(), AgentProtocolError> {
427        info!(
428            agent_id = %self.agent_id,
429            socket_path = %self.socket_path,
430            "Connecting to agent via UDS v2"
431        );
432
433        // Connect to Unix socket
434        let stream = UnixStream::connect(&self.socket_path).await.map_err(|e| {
435            error!(
436                agent_id = %self.agent_id,
437                socket_path = %self.socket_path,
438                error = %e,
439                "Failed to connect to agent via UDS"
440            );
441            AgentProtocolError::ConnectionFailed(e.to_string())
442        })?;
443
444        let (read_half, write_half) = stream.into_split();
445        let mut reader = BufReader::new(read_half);
446        let mut writer = BufWriter::new(write_half);
447
448        // Send handshake request with supported encodings
449        let handshake_req = UdsHandshakeRequest {
450            supported_versions: vec![PROTOCOL_VERSION_2],
451            proxy_id: "grapsus-proxy".to_string(),
452            proxy_version: env!("CARGO_PKG_VERSION").to_string(),
453            config: None,
454            supported_encodings: Self::supported_encodings(),
455        };
456
457        // Handshake always uses JSON (before encoding is negotiated)
458        let payload = serde_json::to_vec(&handshake_req)
459            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
460
461        write_message(&mut writer, MessageType::HandshakeRequest, &payload).await?;
462
463        // Read handshake response (always JSON)
464        let (msg_type, response_bytes) = read_message(&mut reader).await?;
465
466        if msg_type != MessageType::HandshakeResponse {
467            return Err(AgentProtocolError::InvalidMessage(format!(
468                "Expected HandshakeResponse, got {:?}",
469                msg_type
470            )));
471        }
472
473        let response: UdsHandshakeResponse = serde_json::from_slice(&response_bytes)
474            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
475
476        if !response.success {
477            return Err(AgentProtocolError::ConnectionFailed(
478                response
479                    .error
480                    .unwrap_or_else(|| "Unknown handshake error".to_string()),
481            ));
482        }
483
484        // Store capabilities and negotiated encoding
485        let capabilities: AgentCapabilities = response.capabilities.into();
486        *self.capabilities.write().await = Some(capabilities);
487        self.protocol_version
488            .store(response.protocol_version as u64, Ordering::SeqCst);
489
490        // Store the negotiated encoding for subsequent messages
491        let negotiated_encoding = response.encoding;
492        *self.encoding.write().await = negotiated_encoding;
493
494        info!(
495            agent_id = %self.agent_id,
496            protocol_version = response.protocol_version,
497            encoding = ?negotiated_encoding,
498            "UDS v2 handshake successful"
499        );
500
501        // Create message channel
502        let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
503        *self.outbound_tx.lock().await = Some(tx);
504        *self.connected.write().await = true;
505
506        // Spawn writer task
507        let agent_id_clone = self.agent_id.clone();
508        tokio::spawn(async move {
509            while let Some((msg_type, payload)) = rx.recv().await {
510                if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
511                    error!(
512                        agent_id = %agent_id_clone,
513                        error = %e,
514                        "Failed to write message to UDS"
515                    );
516                    break;
517                }
518            }
519            debug!(agent_id = %agent_id_clone, "UDS writer task ended");
520        });
521
522        // Spawn reader task with the negotiated encoding
523        let pending = Arc::clone(&self.pending);
524        let agent_id = self.agent_id.clone();
525        let flow_state = Arc::new(RwLock::new(FlowState::Normal));
526        let health_state = Arc::new(RwLock::new(1i32));
527        let flow_state_clone = Arc::clone(&flow_state);
528        let health_state_clone = Arc::clone(&health_state);
529        let metrics_callback = self.metrics_callback.clone();
530        let config_update_callback = self.config_update_callback.clone();
531        // Encoding is fixed after handshake, so we can copy it
532        let reader_encoding = negotiated_encoding;
533
534        tokio::spawn(async move {
535            loop {
536                match read_message(&mut reader).await {
537                    Ok((msg_type, payload)) => {
538                        match msg_type {
539                            MessageType::AgentResponse => {
540                                match reader_encoding.deserialize::<AgentResponse>(&payload) {
541                                    Ok(response) => {
542                                        // Extract correlation ID from the response
543                                        // For UDS, we include correlation_id in the response
544                                        if let Some(sender) = pending.lock().await.remove(
545                                            &response
546                                                .audit
547                                                .custom
548                                                .get("correlation_id")
549                                                .and_then(|v| v.as_str())
550                                                .unwrap_or("")
551                                                .to_string(),
552                                        ) {
553                                            let _ = sender.send(response);
554                                        }
555                                    }
556                                    Err(e) => {
557                                        warn!(
558                                            agent_id = %agent_id,
559                                            error = %e,
560                                            encoding = ?reader_encoding,
561                                            "Failed to parse agent response"
562                                        );
563                                    }
564                                }
565                            }
566                            MessageType::HealthStatus => {
567                                // Health status uses a simple struct, try both encodings for robustness
568                                #[derive(serde::Deserialize)]
569                                struct HealthStatusMsg {
570                                    state: Option<i64>,
571                                }
572                                if let Ok(health) =
573                                    reader_encoding.deserialize::<HealthStatusMsg>(&payload)
574                                {
575                                    if let Some(state) = health.state {
576                                        *health_state_clone.write().await = state as i32;
577                                    }
578                                }
579                            }
580                            MessageType::MetricsReport => {
581                                if let Some(ref callback) = metrics_callback {
582                                    if let Ok(report) = reader_encoding.deserialize(&payload) {
583                                        callback(report);
584                                    }
585                                }
586                            }
587                            MessageType::FlowControl => {
588                                #[derive(serde::Deserialize)]
589                                struct FlowControlMsg {
590                                    action: Option<i64>,
591                                }
592                                if let Ok(fc) =
593                                    reader_encoding.deserialize::<FlowControlMsg>(&payload)
594                                {
595                                    let action = fc.action.unwrap_or(0);
596                                    let new_state = match action {
597                                        1 => FlowState::Paused,
598                                        2 => FlowState::Normal,
599                                        _ => FlowState::Normal,
600                                    };
601                                    *flow_state_clone.write().await = new_state;
602                                }
603                            }
604                            MessageType::ConfigUpdateRequest => {
605                                if let Some(ref callback) = config_update_callback {
606                                    if let Ok(request) = reader_encoding.deserialize(&payload) {
607                                        let _response = callback(agent_id.clone(), request);
608                                    }
609                                }
610                            }
611                            MessageType::Pong => {
612                                trace!(agent_id = %agent_id, "Received pong");
613                            }
614                            _ => {
615                                trace!(
616                                    agent_id = %agent_id,
617                                    msg_type = ?msg_type,
618                                    "Received unhandled message type"
619                                );
620                            }
621                        }
622                    }
623                    Err(e) => {
624                        if !matches!(e, AgentProtocolError::ConnectionClosed) {
625                            error!(
626                                agent_id = %agent_id,
627                                error = %e,
628                                "Error reading from UDS"
629                            );
630                        }
631                        break;
632                    }
633                }
634            }
635            debug!(agent_id = %agent_id, "UDS reader task ended");
636        });
637
638        Ok(())
639    }
640
641    /// Get negotiated capabilities.
642    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
643        self.capabilities.read().await.clone()
644    }
645
646    /// Check if connected.
647    pub async fn is_connected(&self) -> bool {
648        *self.connected.read().await
649    }
650
651    /// Send a request headers event.
652    pub async fn send_request_headers(
653        &self,
654        correlation_id: &str,
655        event: &crate::RequestHeadersEvent,
656    ) -> Result<AgentResponse, AgentProtocolError> {
657        self.send_event(MessageType::RequestHeaders, correlation_id, event)
658            .await
659    }
660
661    /// Send a request body chunk event.
662    pub async fn send_request_body_chunk(
663        &self,
664        correlation_id: &str,
665        event: &crate::RequestBodyChunkEvent,
666    ) -> Result<AgentResponse, AgentProtocolError> {
667        self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
668            .await
669    }
670
671    /// Send a response headers event.
672    pub async fn send_response_headers(
673        &self,
674        correlation_id: &str,
675        event: &crate::ResponseHeadersEvent,
676    ) -> Result<AgentResponse, AgentProtocolError> {
677        self.send_event(MessageType::ResponseHeaders, correlation_id, event)
678            .await
679    }
680
681    /// Send a response body chunk event.
682    pub async fn send_response_body_chunk(
683        &self,
684        correlation_id: &str,
685        event: &crate::ResponseBodyChunkEvent,
686    ) -> Result<AgentResponse, AgentProtocolError> {
687        self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
688            .await
689    }
690
691    /// Send a request complete event.
692    pub async fn send_request_complete(
693        &self,
694        correlation_id: &str,
695        event: &crate::RequestCompleteEvent,
696    ) -> Result<AgentResponse, AgentProtocolError> {
697        self.send_event(MessageType::RequestComplete, correlation_id, event)
698            .await
699    }
700
701    /// Send a WebSocket frame event.
702    pub async fn send_websocket_frame(
703        &self,
704        correlation_id: &str,
705        event: &crate::WebSocketFrameEvent,
706    ) -> Result<AgentResponse, AgentProtocolError> {
707        self.send_event(MessageType::WebSocketFrame, correlation_id, event)
708            .await
709    }
710
711    /// Send a guardrail inspect event.
712    pub async fn send_guardrail_inspect(
713        &self,
714        correlation_id: &str,
715        event: &crate::GuardrailInspectEvent,
716    ) -> Result<AgentResponse, AgentProtocolError> {
717        self.send_event(MessageType::GuardrailInspect, correlation_id, event)
718            .await
719    }
720
721    /// Send a configure event.
722    pub async fn send_configure(
723        &self,
724        correlation_id: &str,
725        event: &serde_json::Value,
726    ) -> Result<AgentResponse, AgentProtocolError> {
727        self.send_event(MessageType::Configure, correlation_id, event)
728            .await
729    }
730
731    /// Send a binary request body chunk event (zero-copy path).
732    ///
733    /// This method avoids base64 encoding when using MessagePack encoding,
734    /// sending raw bytes directly over the wire for better throughput.
735    ///
736    /// # Performance
737    ///
738    /// When MessagePack encoding is negotiated:
739    /// - Bytes are serialized directly (no base64 encode/decode)
740    /// - Reduces CPU usage and latency for large bodies
741    ///
742    /// When JSON encoding is used:
743    /// - Falls back to base64 encoding for JSON compatibility
744    pub async fn send_request_body_chunk_binary(
745        &self,
746        event: &crate::BinaryRequestBodyChunkEvent,
747    ) -> Result<AgentResponse, AgentProtocolError> {
748        let correlation_id = &event.correlation_id;
749        self.send_binary_body_chunk(
750            MessageType::RequestBodyChunk,
751            correlation_id,
752            &event.data,
753            event.is_last,
754            event.total_size,
755            event.chunk_index,
756            Some(event.bytes_received),
757            None,
758        )
759        .await
760    }
761
762    /// Send a binary response body chunk event (zero-copy path).
763    ///
764    /// This method avoids base64 encoding when using MessagePack encoding,
765    /// sending raw bytes directly over the wire for better throughput.
766    pub async fn send_response_body_chunk_binary(
767        &self,
768        event: &crate::BinaryResponseBodyChunkEvent,
769    ) -> Result<AgentResponse, AgentProtocolError> {
770        let correlation_id = &event.correlation_id;
771        self.send_binary_body_chunk(
772            MessageType::ResponseBodyChunk,
773            correlation_id,
774            &event.data,
775            event.is_last,
776            event.total_size,
777            event.chunk_index,
778            None,
779            Some(event.bytes_sent),
780        )
781        .await
782    }
783
784    /// Internal helper to send binary body chunks with encoding-aware serialization.
785    #[allow(clippy::too_many_arguments)]
786    async fn send_binary_body_chunk(
787        &self,
788        msg_type: MessageType,
789        correlation_id: &str,
790        data: &bytes::Bytes,
791        is_last: bool,
792        total_size: Option<usize>,
793        chunk_index: u32,
794        bytes_received: Option<usize>,
795        bytes_sent: Option<usize>,
796    ) -> Result<AgentResponse, AgentProtocolError> {
797        // Create response channel
798        let (tx, rx) = oneshot::channel();
799        self.pending
800            .lock()
801            .await
802            .insert(correlation_id.to_string(), tx);
803
804        // Get the current encoding
805        let encoding = *self.encoding.read().await;
806
807        // Serialize body chunk using encoding-optimized format
808        let payload_bytes = match encoding {
809            UdsEncoding::Json => {
810                // JSON path: must use base64 encoding for binary data
811                use base64::{engine::general_purpose::STANDARD, Engine as _};
812                let json = serde_json::json!({
813                    "correlation_id": correlation_id,
814                    "data": STANDARD.encode(data),
815                    "is_last": is_last,
816                    "total_size": total_size,
817                    "chunk_index": chunk_index,
818                    "bytes_received": bytes_received,
819                    "bytes_sent": bytes_sent,
820                });
821                serde_json::to_vec(&json)
822                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
823            }
824            UdsEncoding::MessagePack => {
825                // MessagePack path: raw bytes via serde_bytes for zero-copy serialization
826                #[derive(serde::Serialize)]
827                struct BinaryBodyChunk<'a> {
828                    correlation_id: &'a str,
829                    #[serde(with = "serde_bytes")]
830                    data: &'a [u8],
831                    is_last: bool,
832                    #[serde(skip_serializing_if = "Option::is_none")]
833                    total_size: Option<usize>,
834                    chunk_index: u32,
835                    #[serde(skip_serializing_if = "Option::is_none")]
836                    bytes_received: Option<usize>,
837                    #[serde(skip_serializing_if = "Option::is_none")]
838                    bytes_sent: Option<usize>,
839                }
840                let chunk = BinaryBodyChunk {
841                    correlation_id,
842                    data: data.as_ref(),
843                    is_last,
844                    total_size,
845                    chunk_index,
846                    bytes_received,
847                    bytes_sent,
848                };
849                encoding.serialize(&chunk)?
850            }
851        };
852
853        // Send message
854        {
855            let outbound = self.outbound_tx.lock().await;
856            if let Some(tx) = outbound.as_ref() {
857                tx.send((msg_type, payload_bytes))
858                    .await
859                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
860            } else {
861                return Err(AgentProtocolError::ConnectionClosed);
862            }
863        }
864
865        self.in_flight.fetch_add(1, Ordering::Relaxed);
866
867        // Wait for response with timeout
868        let response = tokio::time::timeout(self.timeout, rx)
869            .await
870            .map_err(|_| {
871                self.pending
872                    .try_lock()
873                    .ok()
874                    .map(|mut p| p.remove(correlation_id));
875                AgentProtocolError::Timeout(self.timeout)
876            })?
877            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
878
879        self.in_flight.fetch_sub(1, Ordering::Relaxed);
880
881        Ok(response)
882    }
883
884    /// Send an event and wait for response.
885    async fn send_event<T: serde::Serialize>(
886        &self,
887        msg_type: MessageType,
888        correlation_id: &str,
889        event: &T,
890    ) -> Result<AgentResponse, AgentProtocolError> {
891        // Create response channel
892        let (tx, rx) = oneshot::channel();
893        self.pending
894            .lock()
895            .await
896            .insert(correlation_id.to_string(), tx);
897
898        // Get the current encoding
899        let encoding = *self.encoding.read().await;
900
901        // Serialize event using negotiated encoding
902        let payload_bytes = match encoding {
903            UdsEncoding::Json => {
904                // JSON path: use Value mutation for backwards compatibility
905                let mut payload = serde_json::to_value(event)
906                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
907                if let Some(obj) = payload.as_object_mut() {
908                    obj.insert(
909                        "correlation_id".to_string(),
910                        serde_json::Value::String(correlation_id.to_string()),
911                    );
912                }
913                serde_json::to_vec(&payload)
914                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?
915            }
916            UdsEncoding::MessagePack => {
917                // MessagePack path: use wrapper struct for efficient serialization
918                #[derive(serde::Serialize)]
919                struct EventWithCorrelation<'a, T: serde::Serialize> {
920                    correlation_id: &'a str,
921                    #[serde(flatten)]
922                    event: &'a T,
923                }
924                let wrapped = EventWithCorrelation {
925                    correlation_id,
926                    event,
927                };
928                encoding.serialize(&wrapped)?
929            }
930        };
931
932        // Send message
933        {
934            let outbound = self.outbound_tx.lock().await;
935            if let Some(tx) = outbound.as_ref() {
936                tx.send((msg_type, payload_bytes))
937                    .await
938                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
939            } else {
940                return Err(AgentProtocolError::ConnectionClosed);
941            }
942        }
943
944        self.in_flight.fetch_add(1, Ordering::Relaxed);
945
946        // Wait for response with timeout
947        let response = tokio::time::timeout(self.timeout, rx)
948            .await
949            .map_err(|_| {
950                self.pending
951                    .try_lock()
952                    .ok()
953                    .map(|mut p| p.remove(correlation_id));
954                AgentProtocolError::Timeout(self.timeout)
955            })?
956            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
957
958        self.in_flight.fetch_sub(1, Ordering::Relaxed);
959
960        Ok(response)
961    }
962
963    /// Send a cancel request for a specific correlation ID.
964    pub async fn cancel_request(
965        &self,
966        correlation_id: &str,
967        reason: super::client::CancelReason,
968    ) -> Result<(), AgentProtocolError> {
969        let cancel = serde_json::json!({
970            "correlation_id": correlation_id,
971            "reason": reason as i32,
972            "timestamp_ms": now_ms(),
973        });
974
975        let payload = serde_json::to_vec(&cancel)
976            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
977
978        let outbound = self.outbound_tx.lock().await;
979        if let Some(tx) = outbound.as_ref() {
980            tx.send((MessageType::Cancel, payload))
981                .await
982                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
983        }
984
985        // Remove pending request
986        self.pending.lock().await.remove(correlation_id);
987
988        Ok(())
989    }
990
991    /// Cancel all in-flight requests.
992    pub async fn cancel_all(
993        &self,
994        reason: super::client::CancelReason,
995    ) -> Result<usize, AgentProtocolError> {
996        let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
997        let count = pending_ids.len();
998
999        for correlation_id in pending_ids {
1000            let _ = self.cancel_request(&correlation_id, reason).await;
1001        }
1002
1003        Ok(count)
1004    }
1005
1006    /// Send a ping.
1007    pub async fn ping(&self) -> Result<(), AgentProtocolError> {
1008        let seq = self.ping_sequence.fetch_add(1, Ordering::Relaxed);
1009        let ping = serde_json::json!({
1010            "sequence": seq,
1011            "timestamp_ms": now_ms(),
1012        });
1013
1014        let payload = serde_json::to_vec(&ping)
1015            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
1016
1017        let outbound = self.outbound_tx.lock().await;
1018        if let Some(tx) = outbound.as_ref() {
1019            tx.send((MessageType::Ping, payload))
1020                .await
1021                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
1022        }
1023
1024        Ok(())
1025    }
1026
1027    /// Close the connection.
1028    pub async fn close(&self) -> Result<(), AgentProtocolError> {
1029        *self.connected.write().await = false;
1030        *self.outbound_tx.lock().await = None;
1031        Ok(())
1032    }
1033
1034    /// Get in-flight request count.
1035    pub fn in_flight(&self) -> u64 {
1036        self.in_flight.load(Ordering::Relaxed)
1037    }
1038
1039    /// Get agent ID.
1040    pub fn agent_id(&self) -> &str {
1041        &self.agent_id
1042    }
1043
1044    /// Check if the agent has requested flow control pause.
1045    ///
1046    /// Returns true if the agent sent a `FlowAction::Pause` signal,
1047    /// indicating it cannot accept more requests.
1048    pub async fn is_paused(&self) -> bool {
1049        matches!(*self.flow_state.read().await, FlowState::Paused)
1050    }
1051
1052    /// Check if the transport can accept new requests.
1053    ///
1054    /// Returns false if the agent has requested a flow control pause.
1055    pub async fn can_accept_requests(&self) -> bool {
1056        !self.is_paused().await
1057    }
1058}
1059
1060/// Write a message to the stream.
1061pub async fn write_message<W: AsyncWriteExt + Unpin>(
1062    writer: &mut W,
1063    msg_type: MessageType,
1064    payload: &[u8],
1065) -> Result<(), AgentProtocolError> {
1066    if payload.len() > MAX_UDS_MESSAGE_SIZE {
1067        return Err(AgentProtocolError::MessageTooLarge {
1068            size: payload.len(),
1069            max: MAX_UDS_MESSAGE_SIZE,
1070        });
1071    }
1072
1073    // Write length (4 bytes, big-endian) - includes type byte
1074    let total_len = (payload.len() + 1) as u32;
1075    writer.write_all(&total_len.to_be_bytes()).await?;
1076
1077    // Write message type (1 byte)
1078    writer.write_all(&[msg_type as u8]).await?;
1079
1080    // Write payload
1081    writer.write_all(payload).await?;
1082    writer.flush().await?;
1083
1084    Ok(())
1085}
1086
1087/// Read a message from the stream.
1088pub async fn read_message<R: AsyncReadExt + Unpin>(
1089    reader: &mut R,
1090) -> Result<(MessageType, Vec<u8>), AgentProtocolError> {
1091    // Read length (4 bytes, big-endian)
1092    let mut len_bytes = [0u8; 4];
1093    match reader.read_exact(&mut len_bytes).await {
1094        Ok(_) => {}
1095        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
1096            return Err(AgentProtocolError::ConnectionClosed);
1097        }
1098        Err(e) => return Err(e.into()),
1099    }
1100
1101    let total_len = u32::from_be_bytes(len_bytes) as usize;
1102
1103    if total_len == 0 {
1104        return Err(AgentProtocolError::InvalidMessage(
1105            "Zero-length message".to_string(),
1106        ));
1107    }
1108
1109    if total_len > MAX_UDS_MESSAGE_SIZE {
1110        return Err(AgentProtocolError::MessageTooLarge {
1111            size: total_len,
1112            max: MAX_UDS_MESSAGE_SIZE,
1113        });
1114    }
1115
1116    // Read message type (1 byte)
1117    let mut type_byte = [0u8; 1];
1118    reader.read_exact(&mut type_byte).await?;
1119    let msg_type = MessageType::try_from(type_byte[0])?;
1120
1121    // Read payload
1122    let payload_len = total_len - 1;
1123    let mut payload = vec![0u8; payload_len];
1124    if payload_len > 0 {
1125        reader.read_exact(&mut payload).await?;
1126    }
1127
1128    Ok((msg_type, payload))
1129}
1130
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1137
#[cfg(test)]
mod tests {
    use super::*;

    /// Each listed message type must survive a `u8` round trip unchanged.
    #[test]
    fn test_message_type_roundtrip() {
        let samples = [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::RequestHeaders,
            MessageType::AgentResponse,
            MessageType::HealthStatus,
            MessageType::Ping,
            MessageType::Pong,
        ];

        for original in samples {
            let decoded = MessageType::try_from(original as u8).unwrap();
            assert_eq!(decoded, original);
        }
    }

    /// A byte that maps to no message type must be rejected.
    #[test]
    fn test_invalid_message_type() {
        assert!(MessageType::try_from(0xFF).is_err());
    }

    /// A handshake request must round-trip through JSON.
    #[test]
    fn test_handshake_serialization() {
        let request = UdsHandshakeRequest {
            supported_versions: vec![2],
            proxy_id: "test-proxy".to_string(),
            proxy_version: "1.0.0".to_string(),
            config: None,
            supported_encodings: vec![],
        };

        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: UdsHandshakeRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.supported_versions, vec![2]);
        assert_eq!(decoded.proxy_id, "test-proxy");
    }

    /// A frame written by `write_message` must be read back intact.
    #[tokio::test]
    async fn test_write_read_message() {
        use tokio::io::duplex;

        let (mut writer_end, mut reader_end) = duplex(1024);
        let body = b"test payload";

        // Write on one end of the in-memory pipe...
        write_message(&mut writer_end, MessageType::Ping, body)
            .await
            .unwrap();

        // ...and read the same frame from the other end.
        let (kind, received) = read_message(&mut reader_end).await.unwrap();
        assert_eq!(kind, MessageType::Ping);
        assert_eq!(received, body);
    }

    /// Binary data embedded in JSON as base64 must decode to the original bytes.
    #[test]
    fn test_binary_body_chunk_json_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let raw = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let id = "test-123";

        // JSON encoding must use base64
        let document = serde_json::json!({
            "correlation_id": id,
            "data": STANDARD.encode(&raw),
            "is_last": true,
            "total_size": 100usize,
            "chunk_index": 0u32,
            "bytes_received": 100usize,
        });

        let wire = serde_json::to_vec(&document).unwrap();
        let reparsed: serde_json::Value = serde_json::from_slice(&wire).unwrap();

        // Verify base64 can be decoded back to original
        let encoded_field = reparsed["data"].as_str().unwrap();
        let round_tripped = STANDARD.decode(encoded_field).unwrap();
        assert_eq!(round_tripped, raw.as_ref());
    }

    /// MessagePack must round-trip raw bytes and be smaller than JSON+base64.
    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_binary_body_chunk_msgpack_serialization() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        // MessagePack uses serde_bytes for efficient serialization
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct BinaryBodyChunk {
            correlation_id: String,
            #[serde(with = "serde_bytes")]
            data: Vec<u8>,
            is_last: bool,
            chunk_index: u32,
        }

        let raw = bytes::Bytes::from_static(b"test binary data with \x00 null bytes");
        let id = "test-123";

        let original = BinaryBodyChunk {
            correlation_id: id.to_string(),
            data: raw.to_vec(),
            is_last: true,
            chunk_index: 0,
        };

        // Serialize with MessagePack, then deserialize and verify
        let packed = rmp_serde::to_vec(&original).unwrap();
        let unpacked: BinaryBodyChunk = rmp_serde::from_slice(&packed).unwrap();
        assert_eq!(unpacked.correlation_id, id);
        assert_eq!(unpacked.data, raw.as_ref());
        assert!(unpacked.is_last);

        // Equivalent JSON document with the bytes base64-encoded
        let json_document = serde_json::json!({
            "correlation_id": id,
            "data": STANDARD.encode(&raw),
            "is_last": true,
            "chunk_index": 0u32,
        });
        let json_size = serde_json::to_vec(&json_document).unwrap().len();

        // MessagePack should be smaller (raw bytes vs base64 ~33% overhead)
        assert!(
            packed.len() < json_size,
            "MessagePack ({}) should be smaller than JSON+base64 ({})",
            packed.len(),
            json_size
        );
    }

    /// JSON is the default encoding.
    #[test]
    fn test_uds_encoding_default() {
        let default_encoding = UdsEncoding::default();
        assert_eq!(default_encoding, UdsEncoding::Json);
    }

    /// The JSON encoder must emit bytes that parse back to the same value.
    #[test]
    fn test_uds_encoding_serialize_json() {
        let input = serde_json::json!({"key": "value"});
        let bytes = UdsEncoding::Json.serialize(&input).unwrap();
        let output: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(output, input);
    }

    /// The MessagePack encoder must emit bytes that decode to the same value.
    #[test]
    #[cfg(feature = "binary-uds")]
    fn test_uds_encoding_serialize_msgpack() {
        let input = serde_json::json!({"key": "value"});
        let bytes = UdsEncoding::MessagePack.serialize(&input).unwrap();
        // Verify it's valid MessagePack by deserializing
        let output: serde_json::Value = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(output, input);
    }
}