Skip to main content

grapsus_agent_protocol/v2/
uds_server.rs

1//! Unix Domain Socket server for Agent Protocol v2.
2//!
3//! Provides a UDS-based v2 server that speaks the same binary wire format as
4//! [`AgentClientV2Uds`](super::uds::AgentClientV2Uds). Agents implement
5//! [`AgentHandlerV2`] and pass it to this server.
6
7use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Instant;
10
11use tokio::io::{BufReader, BufWriter};
12use tokio::net::{UnixListener, UnixStream};
13use tracing::{debug, error, info, trace, warn};
14
15use crate::v2::server::AgentHandlerV2;
16use crate::v2::uds::{
17    read_message, write_message, MessageType, UdsCapabilities, UdsEncoding, UdsHandshakeRequest,
18    UdsHandshakeResponse,
19};
20use crate::v2::HandshakeRequest;
21use crate::{
22    AgentProtocolError, AgentResponse, RequestBodyChunkEvent, RequestCompleteEvent,
23    RequestHeadersEvent, ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketFrameEvent,
24};
25
26/// v2 agent server over Unix Domain Socket.
27///
28/// Listens on a Unix socket, accepts connections, and dispatches events to an
29/// [`AgentHandlerV2`] implementation using the v2 binary wire format.
30pub struct UdsAgentServerV2 {
31    id: String,
32    socket_path: PathBuf,
33    handler: Arc<dyn AgentHandlerV2>,
34}
35
36impl UdsAgentServerV2 {
37    /// Create a new UDS v2 agent server.
38    pub fn new(
39        id: impl Into<String>,
40        socket_path: impl Into<PathBuf>,
41        handler: Box<dyn AgentHandlerV2>,
42    ) -> Self {
43        let id = id.into();
44        let socket_path = socket_path.into();
45
46        debug!(
47            agent_id = %id,
48            socket_path = %socket_path.display(),
49            "Creating UDS agent server v2"
50        );
51
52        Self {
53            id,
54            socket_path,
55            handler: Arc::from(handler),
56        }
57    }
58
59    /// Start the server.
60    ///
61    /// Removes any stale socket file, binds, and enters an accept loop that
62    /// spawns a task per connection.
63    pub async fn run(&self) -> Result<(), AgentProtocolError> {
64        // Remove existing socket file if it exists
65        if self.socket_path.exists() {
66            trace!(
67                agent_id = %self.id,
68                socket_path = %self.socket_path.display(),
69                "Removing existing socket file"
70            );
71            std::fs::remove_file(&self.socket_path)?;
72        }
73
74        let listener = UnixListener::bind(&self.socket_path)?;
75
76        info!(
77            agent_id = %self.id,
78            socket_path = %self.socket_path.display(),
79            "UDS agent server v2 listening"
80        );
81
82        loop {
83            match listener.accept().await {
84                Ok((stream, _addr)) => {
85                    trace!(agent_id = %self.id, "Accepted new connection");
86                    let handler = Arc::clone(&self.handler);
87                    let agent_id = self.id.clone();
88                    tokio::spawn(async move {
89                        if let Err(e) = handle_connection(handler, stream, agent_id.clone()).await {
90                            if !matches!(e, AgentProtocolError::ConnectionClosed) {
91                                error!(
92                                    agent_id = %agent_id,
93                                    error = %e,
94                                    "Error handling UDS v2 connection"
95                                );
96                            }
97                        }
98                    });
99                }
100                Err(e) => {
101                    error!(
102                        agent_id = %self.id,
103                        error = %e,
104                        "Failed to accept connection"
105                    );
106                }
107            }
108        }
109    }
110}
111
112/// Handle a single connection: handshake then event loop.
113async fn handle_connection(
114    handler: Arc<dyn AgentHandlerV2>,
115    stream: UnixStream,
116    agent_id: String,
117) -> Result<(), AgentProtocolError> {
118    let (read_half, write_half) = stream.into_split();
119    let mut reader = BufReader::new(read_half);
120    let mut writer = BufWriter::new(write_half);
121
122    // ── Handshake (always JSON) ──────────────────────────────────────────
123
124    let (msg_type, payload) = read_message(&mut reader).await?;
125    if msg_type != MessageType::HandshakeRequest {
126        return Err(AgentProtocolError::InvalidMessage(format!(
127            "Expected HandshakeRequest, got {:?}",
128            msg_type
129        )));
130    }
131
132    let uds_req: UdsHandshakeRequest = serde_json::from_slice(&payload)
133        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
134
135    // Convert to domain-level HandshakeRequest
136    let handshake_req = HandshakeRequest {
137        supported_versions: uds_req.supported_versions,
138        proxy_id: uds_req.proxy_id,
139        proxy_version: uds_req.proxy_version,
140        config: uds_req.config.unwrap_or(serde_json::Value::Null),
141    };
142
143    let handshake_resp = handler.on_handshake(handshake_req).await;
144    let success = handshake_resp.success;
145
146    // Negotiate encoding: pick the first proxy-preferred encoding we support
147    let negotiated_encoding = negotiate_encoding(&uds_req.supported_encodings);
148
149    // Build UDS-level response
150    let uds_resp = UdsHandshakeResponse {
151        protocol_version: handshake_resp.protocol_version,
152        capabilities: UdsCapabilities::from(handshake_resp.capabilities),
153        success,
154        error: handshake_resp.error,
155        encoding: negotiated_encoding,
156    };
157
158    let resp_bytes = serde_json::to_vec(&uds_resp)
159        .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
160    write_message(&mut writer, MessageType::HandshakeResponse, &resp_bytes).await?;
161
162    if !success {
163        debug!(agent_id = %agent_id, "Handshake rejected, closing connection");
164        return Ok(());
165    }
166
167    info!(
168        agent_id = %agent_id,
169        encoding = ?negotiated_encoding,
170        "UDS v2 handshake completed"
171    );
172
173    // ── Event loop (uses negotiated encoding) ────────────────────────────
174
175    loop {
176        let (msg_type, payload) = read_message(&mut reader).await?;
177
178        match msg_type {
179            MessageType::Ping => {
180                trace!(agent_id = %agent_id, "Received ping, sending pong");
181                // Echo the payload back as pong
182                write_message(&mut writer, MessageType::Pong, &payload).await?;
183            }
184            MessageType::Cancel => {
185                // Extract correlation_id for logging
186                let cid = extract_correlation_id(&negotiated_encoding, &payload);
187                debug!(
188                    agent_id = %agent_id,
189                    correlation_id = %cid,
190                    "Request cancelled"
191                );
192            }
193            MessageType::RequestHeaders => {
194                let response =
195                    handle_request_headers(&handler, &negotiated_encoding, &payload).await;
196                write_response(&mut writer, &negotiated_encoding, response).await?;
197            }
198            MessageType::RequestBodyChunk => {
199                let response =
200                    handle_request_body_chunk(&handler, &negotiated_encoding, &payload).await;
201                write_response(&mut writer, &negotiated_encoding, response).await?;
202            }
203            MessageType::ResponseHeaders => {
204                let response =
205                    handle_response_headers(&handler, &negotiated_encoding, &payload).await;
206                write_response(&mut writer, &negotiated_encoding, response).await?;
207            }
208            MessageType::ResponseBodyChunk => {
209                let response =
210                    handle_response_body_chunk(&handler, &negotiated_encoding, &payload).await;
211                write_response(&mut writer, &negotiated_encoding, response).await?;
212            }
213            MessageType::RequestComplete => {
214                let response =
215                    handle_request_complete(&handler, &negotiated_encoding, &payload).await;
216                write_response(&mut writer, &negotiated_encoding, response).await?;
217            }
218            MessageType::WebSocketFrame => {
219                let response =
220                    handle_websocket_frame(&handler, &negotiated_encoding, &payload).await;
221                write_response(&mut writer, &negotiated_encoding, response).await?;
222            }
223            MessageType::Configure => {
224                let response = handle_configure(&handler, &negotiated_encoding, &payload).await;
225                write_response(&mut writer, &negotiated_encoding, response).await?;
226            }
227            _ => {
228                warn!(
229                    agent_id = %agent_id,
230                    msg_type = ?msg_type,
231                    "Received unhandled message type"
232                );
233            }
234        }
235    }
236}
237
238// ─── Encoding negotiation ────────────────────────────────────────────────────
239
240/// Pick the first proxy-preferred encoding that we support. Falls back to JSON.
241fn negotiate_encoding(proxy_encodings: &[UdsEncoding]) -> UdsEncoding {
242    for enc in proxy_encodings {
243        match enc {
244            UdsEncoding::Json => return UdsEncoding::Json,
245            UdsEncoding::MessagePack if cfg!(feature = "binary-uds") => {
246                return UdsEncoding::MessagePack;
247            }
248            _ => continue,
249        }
250    }
251    UdsEncoding::Json
252}
253
254// ─── Event handlers ──────────────────────────────────────────────────────────
255
256async fn handle_request_headers(
257    handler: &Arc<dyn AgentHandlerV2>,
258    encoding: &UdsEncoding,
259    payload: &[u8],
260) -> (String, AgentResponse, u64) {
261    let event: RequestHeadersEvent = match encoding.deserialize(payload) {
262        Ok(e) => e,
263        Err(e) => {
264            warn!(error = %e, "Failed to deserialize RequestHeaders");
265            let cid = extract_correlation_id(encoding, payload);
266            return (cid, AgentResponse::default_allow(), 0);
267        }
268    };
269    let cid = event.metadata.correlation_id.clone();
270    let start = Instant::now();
271    let resp = handler.on_request_headers(event).await;
272    (cid, resp, start.elapsed().as_millis() as u64)
273}
274
275async fn handle_request_body_chunk(
276    handler: &Arc<dyn AgentHandlerV2>,
277    encoding: &UdsEncoding,
278    payload: &[u8],
279) -> (String, AgentResponse, u64) {
280    let event: RequestBodyChunkEvent = match encoding.deserialize(payload) {
281        Ok(e) => e,
282        Err(e) => {
283            warn!(error = %e, "Failed to deserialize RequestBodyChunk");
284            let cid = extract_correlation_id(encoding, payload);
285            return (cid, AgentResponse::default_allow(), 0);
286        }
287    };
288    let cid = event.correlation_id.clone();
289    let start = Instant::now();
290    let resp = handler.on_request_body_chunk(event).await;
291    (cid, resp, start.elapsed().as_millis() as u64)
292}
293
294async fn handle_response_headers(
295    handler: &Arc<dyn AgentHandlerV2>,
296    encoding: &UdsEncoding,
297    payload: &[u8],
298) -> (String, AgentResponse, u64) {
299    let event: ResponseHeadersEvent = match encoding.deserialize(payload) {
300        Ok(e) => e,
301        Err(e) => {
302            warn!(error = %e, "Failed to deserialize ResponseHeaders");
303            let cid = extract_correlation_id(encoding, payload);
304            return (cid, AgentResponse::default_allow(), 0);
305        }
306    };
307    let cid = event.correlation_id.clone();
308    let start = Instant::now();
309    let resp = handler.on_response_headers(event).await;
310    (cid, resp, start.elapsed().as_millis() as u64)
311}
312
313async fn handle_response_body_chunk(
314    handler: &Arc<dyn AgentHandlerV2>,
315    encoding: &UdsEncoding,
316    payload: &[u8],
317) -> (String, AgentResponse, u64) {
318    let event: ResponseBodyChunkEvent = match encoding.deserialize(payload) {
319        Ok(e) => e,
320        Err(e) => {
321            warn!(error = %e, "Failed to deserialize ResponseBodyChunk");
322            let cid = extract_correlation_id(encoding, payload);
323            return (cid, AgentResponse::default_allow(), 0);
324        }
325    };
326    let cid = event.correlation_id.clone();
327    let start = Instant::now();
328    let resp = handler.on_response_body_chunk(event).await;
329    (cid, resp, start.elapsed().as_millis() as u64)
330}
331
332async fn handle_request_complete(
333    handler: &Arc<dyn AgentHandlerV2>,
334    encoding: &UdsEncoding,
335    payload: &[u8],
336) -> (String, AgentResponse, u64) {
337    let event: RequestCompleteEvent = match encoding.deserialize(payload) {
338        Ok(e) => e,
339        Err(e) => {
340            warn!(error = %e, "Failed to deserialize RequestComplete");
341            let cid = extract_correlation_id(encoding, payload);
342            return (cid, AgentResponse::default_allow(), 0);
343        }
344    };
345    let cid = event.correlation_id.clone();
346    let start = Instant::now();
347    let resp = handler.on_request_complete(event).await;
348    (cid, resp, start.elapsed().as_millis() as u64)
349}
350
351async fn handle_websocket_frame(
352    handler: &Arc<dyn AgentHandlerV2>,
353    encoding: &UdsEncoding,
354    payload: &[u8],
355) -> (String, AgentResponse, u64) {
356    let event: WebSocketFrameEvent = match encoding.deserialize(payload) {
357        Ok(e) => e,
358        Err(e) => {
359            warn!(error = %e, "Failed to deserialize WebSocketFrame");
360            let cid = extract_correlation_id(encoding, payload);
361            return (cid, AgentResponse::websocket_allow(), 0);
362        }
363    };
364    let cid = event.correlation_id.clone();
365    let start = Instant::now();
366    let resp = handler.on_websocket_frame(event).await;
367    (cid, resp, start.elapsed().as_millis() as u64)
368}
369
370async fn handle_configure(
371    handler: &Arc<dyn AgentHandlerV2>,
372    encoding: &UdsEncoding,
373    payload: &[u8],
374) -> (String, AgentResponse, u64) {
375    // Configure payloads carry config + optional version
376    #[derive(serde::Deserialize)]
377    struct ConfigurePayload {
378        #[serde(default)]
379        correlation_id: String,
380        #[serde(default)]
381        config: serde_json::Value,
382        #[serde(default)]
383        config_version: Option<String>,
384    }
385
386    let parsed: ConfigurePayload = match encoding.deserialize(payload) {
387        Ok(p) => p,
388        Err(e) => {
389            warn!(error = %e, "Failed to deserialize Configure");
390            let cid = extract_correlation_id(encoding, payload);
391            return (cid, AgentResponse::default_allow(), 0);
392        }
393    };
394
395    let cid = parsed.correlation_id;
396    let start = Instant::now();
397    let accepted = handler
398        .on_configure(parsed.config, parsed.config_version)
399        .await;
400    let resp = if accepted {
401        AgentResponse::default_allow()
402    } else {
403        AgentResponse::block(500, Some("Configuration rejected".to_string()))
404    };
405    (cid, resp, start.elapsed().as_millis() as u64)
406}
407
408// ─── Response serialization ──────────────────────────────────────────────────
409
410/// Serialize and write an agent response, injecting the correlation ID into
411/// `audit.custom` so the multiplexing client can route it.
412async fn write_response<W: tokio::io::AsyncWriteExt + Unpin>(
413    writer: &mut W,
414    encoding: &UdsEncoding,
415    (correlation_id, mut response, _processing_time_ms): (String, AgentResponse, u64),
416) -> Result<(), AgentProtocolError> {
417    // Inject correlation_id so the client can route the response
418    response.audit.custom.insert(
419        "correlation_id".to_string(),
420        serde_json::Value::String(correlation_id),
421    );
422
423    let resp_bytes = encoding.serialize(&response)?;
424    write_message(writer, MessageType::AgentResponse, &resp_bytes).await
425}
426
427// ─── Helpers ─────────────────────────────────────────────────────────────────
428
429/// Best-effort extraction of `correlation_id` from a payload (for error paths).
430fn extract_correlation_id(encoding: &UdsEncoding, payload: &[u8]) -> String {
431    #[derive(serde::Deserialize)]
432    struct CidOnly {
433        #[serde(default)]
434        correlation_id: String,
435        #[serde(default)]
436        metadata: Option<MetaCid>,
437    }
438    #[derive(serde::Deserialize)]
439    struct MetaCid {
440        #[serde(default)]
441        correlation_id: String,
442    }
443
444    if let Ok(parsed) = encoding.deserialize::<CidOnly>(payload) {
445        if !parsed.correlation_id.is_empty() {
446            return parsed.correlation_id;
447        }
448        if let Some(meta) = parsed.metadata {
449            if !meta.correlation_id.is_empty() {
450                return meta.correlation_id;
451            }
452        }
453    }
454    String::new()
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v2::AgentCapabilities;
    use crate::RequestMetadata;
    use async_trait::async_trait;

    /// Handler that tags request headers with the event's correlation ID.
    struct TestHandler;

    #[async_trait]
    impl AgentHandlerV2 for TestHandler {
        fn capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::new("test-uds-v2", "Test UDS V2 Agent", "1.0.0")
                .with_event(crate::EventType::RequestHeaders)
        }

        async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
            AgentResponse::default_allow().add_request_header(crate::HeaderOp::Set {
                name: "x-test-agent".to_string(),
                value: event.metadata.correlation_id.clone(),
            })
        }
    }

    #[test]
    fn test_negotiate_encoding_json() {
        let encodings = vec![UdsEncoding::Json];
        assert_eq!(negotiate_encoding(&encodings), UdsEncoding::Json);
    }

    #[test]
    fn test_negotiate_encoding_empty() {
        let encodings: Vec<UdsEncoding> = vec![];
        assert_eq!(negotiate_encoding(&encodings), UdsEncoding::Json);
    }

    #[test]
    fn test_negotiate_encoding_respects_proxy_order() {
        // MessagePack is listed first, so it wins whenever this build supports it.
        let encodings = vec![UdsEncoding::MessagePack, UdsEncoding::Json];
        let expected = if cfg!(feature = "binary-uds") {
            UdsEncoding::MessagePack
        } else {
            UdsEncoding::Json
        };
        assert_eq!(negotiate_encoding(&encodings), expected);
    }

    #[test]
    fn test_extract_correlation_id() {
        let json = UdsEncoding::Json;
        assert_eq!(
            extract_correlation_id(&json, br#"{"correlation_id":"top"}"#),
            "top"
        );
        assert_eq!(
            extract_correlation_id(&json, br#"{"metadata":{"correlation_id":"nested"}}"#),
            "nested"
        );
        assert_eq!(extract_correlation_id(&json, b"not a payload"), "");
    }

    #[test]
    fn test_create_server() {
        let server = UdsAgentServerV2::new("test", "/tmp/test-uds-v2.sock", Box::new(TestHandler));
        assert_eq!(server.id, "test");
    }

    #[tokio::test]
    async fn test_handshake_and_request_roundtrip() {
        use crate::v2::uds::AgentClientV2Uds;
        use std::time::{Duration, Instant};

        let socket_path = format!("/tmp/test-uds-v2-{}.sock", std::process::id());
        // Clear any leftover from an earlier run so the readiness poll below
        // only observes the socket bound by this test's server.
        let _ = std::fs::remove_file(&socket_path);

        // Start server in background
        let server = UdsAgentServerV2::new("test-roundtrip", &socket_path, Box::new(TestHandler));

        let server_handle = tokio::spawn(async move {
            let _ = server.run().await;
        });

        // Wait for the server to bind, with a deadline, instead of a fixed
        // sleep that races the listener on slow machines.
        let deadline = Instant::now() + Duration::from_secs(2);
        while !std::path::Path::new(&socket_path).exists() {
            assert!(Instant::now() < deadline, "server did not bind in time");
            tokio::time::sleep(Duration::from_millis(5)).await;
        }

        // Connect client
        let client = AgentClientV2Uds::new("test-agent", &socket_path, Duration::from_secs(5))
            .await
            .unwrap();
        client.connect().await.unwrap();

        assert!(client.is_connected().await);

        // Send a request headers event
        let event = RequestHeadersEvent {
            metadata: RequestMetadata {
                correlation_id: "test-cid-1".to_string(),
                request_id: "req-1".to_string(),
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_name: None,
                protocol: "HTTP/1.1".to_string(),
                tls_version: None,
                tls_cipher: None,
                route_id: None,
                upstream_id: None,
                timestamp: "0".to_string(),
                traceparent: None,
            },
            method: "GET".to_string(),
            uri: "/test".to_string(),
            headers: std::collections::HashMap::new(),
        };

        let response = client
            .send_request_headers("test-cid-1", &event)
            .await
            .unwrap();

        // Verify handler was called and response returned
        assert!(matches!(response.decision, crate::Decision::Allow));
        assert!(response.request_headers.iter().any(|h| matches!(
            h,
            crate::HeaderOp::Set { name, value }
                if name == "x-test-agent" && value == "test-cid-1"
        )));

        // Cleanup
        client.close().await.unwrap();
        server_handle.abort();
        let _ = std::fs::remove_file(&socket_path);
    }
}