1use std::path::PathBuf;
8use std::sync::Arc;
9use std::time::Instant;
10
11use tokio::io::{BufReader, BufWriter};
12use tokio::net::{UnixListener, UnixStream};
13use tracing::{debug, error, info, trace, warn};
14
15use crate::v2::server::AgentHandlerV2;
16use crate::v2::uds::{
17 read_message, write_message, MessageType, UdsCapabilities, UdsEncoding, UdsHandshakeRequest,
18 UdsHandshakeResponse,
19};
20use crate::v2::HandshakeRequest;
21use crate::{
22 AgentProtocolError, AgentResponse, RequestBodyChunkEvent, RequestCompleteEvent,
23 RequestHeadersEvent, ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketFrameEvent,
24};
25
26pub struct UdsAgentServerV2 {
31 id: String,
32 socket_path: PathBuf,
33 handler: Arc<dyn AgentHandlerV2>,
34}
35
36impl UdsAgentServerV2 {
37 pub fn new(
39 id: impl Into<String>,
40 socket_path: impl Into<PathBuf>,
41 handler: Box<dyn AgentHandlerV2>,
42 ) -> Self {
43 let id = id.into();
44 let socket_path = socket_path.into();
45
46 debug!(
47 agent_id = %id,
48 socket_path = %socket_path.display(),
49 "Creating UDS agent server v2"
50 );
51
52 Self {
53 id,
54 socket_path,
55 handler: Arc::from(handler),
56 }
57 }
58
59 pub async fn run(&self) -> Result<(), AgentProtocolError> {
64 if self.socket_path.exists() {
66 trace!(
67 agent_id = %self.id,
68 socket_path = %self.socket_path.display(),
69 "Removing existing socket file"
70 );
71 std::fs::remove_file(&self.socket_path)?;
72 }
73
74 let listener = UnixListener::bind(&self.socket_path)?;
75
76 info!(
77 agent_id = %self.id,
78 socket_path = %self.socket_path.display(),
79 "UDS agent server v2 listening"
80 );
81
82 loop {
83 match listener.accept().await {
84 Ok((stream, _addr)) => {
85 trace!(agent_id = %self.id, "Accepted new connection");
86 let handler = Arc::clone(&self.handler);
87 let agent_id = self.id.clone();
88 tokio::spawn(async move {
89 if let Err(e) = handle_connection(handler, stream, agent_id.clone()).await {
90 if !matches!(e, AgentProtocolError::ConnectionClosed) {
91 error!(
92 agent_id = %agent_id,
93 error = %e,
94 "Error handling UDS v2 connection"
95 );
96 }
97 }
98 });
99 }
100 Err(e) => {
101 error!(
102 agent_id = %self.id,
103 error = %e,
104 "Failed to accept connection"
105 );
106 }
107 }
108 }
109 }
110}
111
112async fn handle_connection(
114 handler: Arc<dyn AgentHandlerV2>,
115 stream: UnixStream,
116 agent_id: String,
117) -> Result<(), AgentProtocolError> {
118 let (read_half, write_half) = stream.into_split();
119 let mut reader = BufReader::new(read_half);
120 let mut writer = BufWriter::new(write_half);
121
122 let (msg_type, payload) = read_message(&mut reader).await?;
125 if msg_type != MessageType::HandshakeRequest {
126 return Err(AgentProtocolError::InvalidMessage(format!(
127 "Expected HandshakeRequest, got {:?}",
128 msg_type
129 )));
130 }
131
132 let uds_req: UdsHandshakeRequest = serde_json::from_slice(&payload)
133 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
134
135 let handshake_req = HandshakeRequest {
137 supported_versions: uds_req.supported_versions,
138 proxy_id: uds_req.proxy_id,
139 proxy_version: uds_req.proxy_version,
140 config: uds_req.config.unwrap_or(serde_json::Value::Null),
141 };
142
143 let handshake_resp = handler.on_handshake(handshake_req).await;
144 let success = handshake_resp.success;
145
146 let negotiated_encoding = negotiate_encoding(&uds_req.supported_encodings);
148
149 let uds_resp = UdsHandshakeResponse {
151 protocol_version: handshake_resp.protocol_version,
152 capabilities: UdsCapabilities::from(handshake_resp.capabilities),
153 success,
154 error: handshake_resp.error,
155 encoding: negotiated_encoding,
156 };
157
158 let resp_bytes = serde_json::to_vec(&uds_resp)
159 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
160 write_message(&mut writer, MessageType::HandshakeResponse, &resp_bytes).await?;
161
162 if !success {
163 debug!(agent_id = %agent_id, "Handshake rejected, closing connection");
164 return Ok(());
165 }
166
167 info!(
168 agent_id = %agent_id,
169 encoding = ?negotiated_encoding,
170 "UDS v2 handshake completed"
171 );
172
173 loop {
176 let (msg_type, payload) = read_message(&mut reader).await?;
177
178 match msg_type {
179 MessageType::Ping => {
180 trace!(agent_id = %agent_id, "Received ping, sending pong");
181 write_message(&mut writer, MessageType::Pong, &payload).await?;
183 }
184 MessageType::Cancel => {
185 let cid = extract_correlation_id(&negotiated_encoding, &payload);
187 debug!(
188 agent_id = %agent_id,
189 correlation_id = %cid,
190 "Request cancelled"
191 );
192 }
193 MessageType::RequestHeaders => {
194 let response =
195 handle_request_headers(&handler, &negotiated_encoding, &payload).await;
196 write_response(&mut writer, &negotiated_encoding, response).await?;
197 }
198 MessageType::RequestBodyChunk => {
199 let response =
200 handle_request_body_chunk(&handler, &negotiated_encoding, &payload).await;
201 write_response(&mut writer, &negotiated_encoding, response).await?;
202 }
203 MessageType::ResponseHeaders => {
204 let response =
205 handle_response_headers(&handler, &negotiated_encoding, &payload).await;
206 write_response(&mut writer, &negotiated_encoding, response).await?;
207 }
208 MessageType::ResponseBodyChunk => {
209 let response =
210 handle_response_body_chunk(&handler, &negotiated_encoding, &payload).await;
211 write_response(&mut writer, &negotiated_encoding, response).await?;
212 }
213 MessageType::RequestComplete => {
214 let response =
215 handle_request_complete(&handler, &negotiated_encoding, &payload).await;
216 write_response(&mut writer, &negotiated_encoding, response).await?;
217 }
218 MessageType::WebSocketFrame => {
219 let response =
220 handle_websocket_frame(&handler, &negotiated_encoding, &payload).await;
221 write_response(&mut writer, &negotiated_encoding, response).await?;
222 }
223 MessageType::Configure => {
224 let response = handle_configure(&handler, &negotiated_encoding, &payload).await;
225 write_response(&mut writer, &negotiated_encoding, response).await?;
226 }
227 _ => {
228 warn!(
229 agent_id = %agent_id,
230 msg_type = ?msg_type,
231 "Received unhandled message type"
232 );
233 }
234 }
235 }
236}
237
238fn negotiate_encoding(proxy_encodings: &[UdsEncoding]) -> UdsEncoding {
242 for enc in proxy_encodings {
243 match enc {
244 UdsEncoding::Json => return UdsEncoding::Json,
245 UdsEncoding::MessagePack if cfg!(feature = "binary-uds") => {
246 return UdsEncoding::MessagePack;
247 }
248 _ => continue,
249 }
250 }
251 UdsEncoding::Json
252}
253
254async fn handle_request_headers(
257 handler: &Arc<dyn AgentHandlerV2>,
258 encoding: &UdsEncoding,
259 payload: &[u8],
260) -> (String, AgentResponse, u64) {
261 let event: RequestHeadersEvent = match encoding.deserialize(payload) {
262 Ok(e) => e,
263 Err(e) => {
264 warn!(error = %e, "Failed to deserialize RequestHeaders");
265 let cid = extract_correlation_id(encoding, payload);
266 return (cid, AgentResponse::default_allow(), 0);
267 }
268 };
269 let cid = event.metadata.correlation_id.clone();
270 let start = Instant::now();
271 let resp = handler.on_request_headers(event).await;
272 (cid, resp, start.elapsed().as_millis() as u64)
273}
274
275async fn handle_request_body_chunk(
276 handler: &Arc<dyn AgentHandlerV2>,
277 encoding: &UdsEncoding,
278 payload: &[u8],
279) -> (String, AgentResponse, u64) {
280 let event: RequestBodyChunkEvent = match encoding.deserialize(payload) {
281 Ok(e) => e,
282 Err(e) => {
283 warn!(error = %e, "Failed to deserialize RequestBodyChunk");
284 let cid = extract_correlation_id(encoding, payload);
285 return (cid, AgentResponse::default_allow(), 0);
286 }
287 };
288 let cid = event.correlation_id.clone();
289 let start = Instant::now();
290 let resp = handler.on_request_body_chunk(event).await;
291 (cid, resp, start.elapsed().as_millis() as u64)
292}
293
294async fn handle_response_headers(
295 handler: &Arc<dyn AgentHandlerV2>,
296 encoding: &UdsEncoding,
297 payload: &[u8],
298) -> (String, AgentResponse, u64) {
299 let event: ResponseHeadersEvent = match encoding.deserialize(payload) {
300 Ok(e) => e,
301 Err(e) => {
302 warn!(error = %e, "Failed to deserialize ResponseHeaders");
303 let cid = extract_correlation_id(encoding, payload);
304 return (cid, AgentResponse::default_allow(), 0);
305 }
306 };
307 let cid = event.correlation_id.clone();
308 let start = Instant::now();
309 let resp = handler.on_response_headers(event).await;
310 (cid, resp, start.elapsed().as_millis() as u64)
311}
312
313async fn handle_response_body_chunk(
314 handler: &Arc<dyn AgentHandlerV2>,
315 encoding: &UdsEncoding,
316 payload: &[u8],
317) -> (String, AgentResponse, u64) {
318 let event: ResponseBodyChunkEvent = match encoding.deserialize(payload) {
319 Ok(e) => e,
320 Err(e) => {
321 warn!(error = %e, "Failed to deserialize ResponseBodyChunk");
322 let cid = extract_correlation_id(encoding, payload);
323 return (cid, AgentResponse::default_allow(), 0);
324 }
325 };
326 let cid = event.correlation_id.clone();
327 let start = Instant::now();
328 let resp = handler.on_response_body_chunk(event).await;
329 (cid, resp, start.elapsed().as_millis() as u64)
330}
331
332async fn handle_request_complete(
333 handler: &Arc<dyn AgentHandlerV2>,
334 encoding: &UdsEncoding,
335 payload: &[u8],
336) -> (String, AgentResponse, u64) {
337 let event: RequestCompleteEvent = match encoding.deserialize(payload) {
338 Ok(e) => e,
339 Err(e) => {
340 warn!(error = %e, "Failed to deserialize RequestComplete");
341 let cid = extract_correlation_id(encoding, payload);
342 return (cid, AgentResponse::default_allow(), 0);
343 }
344 };
345 let cid = event.correlation_id.clone();
346 let start = Instant::now();
347 let resp = handler.on_request_complete(event).await;
348 (cid, resp, start.elapsed().as_millis() as u64)
349}
350
351async fn handle_websocket_frame(
352 handler: &Arc<dyn AgentHandlerV2>,
353 encoding: &UdsEncoding,
354 payload: &[u8],
355) -> (String, AgentResponse, u64) {
356 let event: WebSocketFrameEvent = match encoding.deserialize(payload) {
357 Ok(e) => e,
358 Err(e) => {
359 warn!(error = %e, "Failed to deserialize WebSocketFrame");
360 let cid = extract_correlation_id(encoding, payload);
361 return (cid, AgentResponse::websocket_allow(), 0);
362 }
363 };
364 let cid = event.correlation_id.clone();
365 let start = Instant::now();
366 let resp = handler.on_websocket_frame(event).await;
367 (cid, resp, start.elapsed().as_millis() as u64)
368}
369
370async fn handle_configure(
371 handler: &Arc<dyn AgentHandlerV2>,
372 encoding: &UdsEncoding,
373 payload: &[u8],
374) -> (String, AgentResponse, u64) {
375 #[derive(serde::Deserialize)]
377 struct ConfigurePayload {
378 #[serde(default)]
379 correlation_id: String,
380 #[serde(default)]
381 config: serde_json::Value,
382 #[serde(default)]
383 config_version: Option<String>,
384 }
385
386 let parsed: ConfigurePayload = match encoding.deserialize(payload) {
387 Ok(p) => p,
388 Err(e) => {
389 warn!(error = %e, "Failed to deserialize Configure");
390 let cid = extract_correlation_id(encoding, payload);
391 return (cid, AgentResponse::default_allow(), 0);
392 }
393 };
394
395 let cid = parsed.correlation_id;
396 let start = Instant::now();
397 let accepted = handler
398 .on_configure(parsed.config, parsed.config_version)
399 .await;
400 let resp = if accepted {
401 AgentResponse::default_allow()
402 } else {
403 AgentResponse::block(500, Some("Configuration rejected".to_string()))
404 };
405 (cid, resp, start.elapsed().as_millis() as u64)
406}
407
408async fn write_response<W: tokio::io::AsyncWriteExt + Unpin>(
413 writer: &mut W,
414 encoding: &UdsEncoding,
415 (correlation_id, mut response, _processing_time_ms): (String, AgentResponse, u64),
416) -> Result<(), AgentProtocolError> {
417 response.audit.custom.insert(
419 "correlation_id".to_string(),
420 serde_json::Value::String(correlation_id),
421 );
422
423 let resp_bytes = encoding.serialize(&response)?;
424 write_message(writer, MessageType::AgentResponse, &resp_bytes).await
425}
426
427fn extract_correlation_id(encoding: &UdsEncoding, payload: &[u8]) -> String {
431 #[derive(serde::Deserialize)]
432 struct CidOnly {
433 #[serde(default)]
434 correlation_id: String,
435 #[serde(default)]
436 metadata: Option<MetaCid>,
437 }
438 #[derive(serde::Deserialize)]
439 struct MetaCid {
440 #[serde(default)]
441 correlation_id: String,
442 }
443
444 if let Ok(parsed) = encoding.deserialize::<CidOnly>(payload) {
445 if !parsed.correlation_id.is_empty() {
446 return parsed.correlation_id;
447 }
448 if let Some(meta) = parsed.metadata {
449 if !meta.correlation_id.is_empty() {
450 return meta.correlation_id;
451 }
452 }
453 }
454 String::new()
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v2::AgentCapabilities;
    use crate::RequestMetadata;
    use async_trait::async_trait;

    /// Handler that allows every request and tags it with an `x-test-agent`
    /// header carrying the request's correlation id.
    struct TestHandler;

    #[async_trait]
    impl AgentHandlerV2 for TestHandler {
        fn capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::new("test-uds-v2", "Test UDS V2 Agent", "1.0.0")
                .with_event(crate::EventType::RequestHeaders)
        }

        async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
            AgentResponse::default_allow().add_request_header(crate::HeaderOp::Set {
                name: "x-test-agent".to_string(),
                value: event.metadata.correlation_id.clone(),
            })
        }
    }

    /// Removes the wrapped socket path when dropped, so the file is cleaned up
    /// even if an assertion panics part-way through a test.
    struct SocketFileGuard(String);

    impl Drop for SocketFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Polls until `path` exists (the server has bound its listener) or about
    /// two seconds have passed. This replaces a fixed sleep, which races with
    /// server start-up on slow or heavily loaded machines.
    async fn wait_for_socket(path: &str) {
        for _ in 0..200 {
            if std::path::Path::new(path).exists() {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }
    }

    #[test]
    fn test_negotiate_encoding_json() {
        let encodings = vec![UdsEncoding::Json];
        assert_eq!(negotiate_encoding(&encodings), UdsEncoding::Json);
    }

    #[test]
    fn test_negotiate_encoding_empty() {
        let encodings: Vec<UdsEncoding> = vec![];
        assert_eq!(negotiate_encoding(&encodings), UdsEncoding::Json);
    }

    #[test]
    fn test_create_server() {
        let server = UdsAgentServerV2::new("test", "/tmp/test-uds-v2.sock", Box::new(TestHandler));
        assert_eq!(server.id, "test");
    }

    #[tokio::test]
    async fn test_handshake_and_request_roundtrip() {
        use crate::v2::uds::AgentClientV2Uds;
        use std::time::Duration;

        let socket_path = format!("/tmp/test-uds-v2-{}.sock", std::process::id());
        // Clear any leftover file first. Otherwise it would satisfy
        // `wait_for_socket` before the server has actually bound.
        let _ = std::fs::remove_file(&socket_path);
        let _cleanup = SocketFileGuard(socket_path.clone());

        let server = UdsAgentServerV2::new("test-roundtrip", &socket_path, Box::new(TestHandler));
        let server_handle = tokio::spawn(async move {
            let _ = server.run().await;
        });

        wait_for_socket(&socket_path).await;

        let client = AgentClientV2Uds::new("test-agent", &socket_path, Duration::from_secs(5))
            .await
            .unwrap();
        client.connect().await.unwrap();

        assert!(client.is_connected().await);

        let event = RequestHeadersEvent {
            metadata: RequestMetadata {
                correlation_id: "test-cid-1".to_string(),
                request_id: "req-1".to_string(),
                client_ip: "127.0.0.1".to_string(),
                client_port: 12345,
                server_name: None,
                protocol: "HTTP/1.1".to_string(),
                tls_version: None,
                tls_cipher: None,
                route_id: None,
                upstream_id: None,
                timestamp: "0".to_string(),
                traceparent: None,
            },
            method: "GET".to_string(),
            uri: "/test".to_string(),
            headers: std::collections::HashMap::new(),
        };

        let response = client
            .send_request_headers("test-cid-1", &event)
            .await
            .unwrap();

        assert!(matches!(response.decision, crate::Decision::Allow));
        assert!(response.request_headers.iter().any(|h| matches!(
            h,
            crate::HeaderOp::Set { name, value }
                if name == "x-test-agent" && value == "test-cid-1"
        )));

        client.close().await.unwrap();
        server_handle.abort();
        // `_cleanup` removes the socket file when it goes out of scope.
    }
}