1use async_trait::async_trait;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{UnixListener, UnixStream};
12use tokio_stream::StreamExt;
13use tonic::{Request, Response, Status, Streaming};
14use tracing::{debug, error, info, trace, warn};
15
16use crate::errors::AgentProtocolError;
17use crate::grpc::{self, agent_processor_server::AgentProcessor, agent_processor_server::AgentProcessorServer};
18use crate::protocol::{
19 AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
20 RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent, ResponseHeadersEvent,
21 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
22};
23
24pub struct AgentServer {
26 id: String,
28 socket_path: std::path::PathBuf,
30 handler: Arc<dyn AgentHandler>,
32}
33
34#[async_trait]
36pub trait AgentHandler: Send + Sync {
37 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
39 AgentResponse::default_allow()
40 }
41
42 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
44 AgentResponse::default_allow()
45 }
46
47 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
49 AgentResponse::default_allow()
50 }
51
52 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
54 AgentResponse::default_allow()
55 }
56
57 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
59 AgentResponse::default_allow()
60 }
61}
62
63impl AgentServer {
64 pub fn new(
66 id: impl Into<String>,
67 socket_path: impl Into<std::path::PathBuf>,
68 handler: Box<dyn AgentHandler>,
69 ) -> Self {
70 let id = id.into();
71 let socket_path = socket_path.into();
72
73 debug!(
74 agent_id = %id,
75 socket_path = %socket_path.display(),
76 "Creating agent server"
77 );
78
79 Self {
80 id,
81 socket_path,
82 handler: Arc::from(handler),
83 }
84 }
85
86 pub async fn run(&self) -> Result<(), AgentProtocolError> {
88 if self.socket_path.exists() {
90 trace!(
91 agent_id = %self.id,
92 socket_path = %self.socket_path.display(),
93 "Removing existing socket file"
94 );
95 std::fs::remove_file(&self.socket_path)?;
96 }
97
98 let listener = UnixListener::bind(&self.socket_path)?;
100
101 info!(
102 agent_id = %self.id,
103 socket_path = %self.socket_path.display(),
104 "Agent server listening"
105 );
106
107 loop {
108 match listener.accept().await {
109 Ok((stream, _addr)) => {
110 trace!(
111 agent_id = %self.id,
112 "Accepted new connection"
113 );
114 let handler = Arc::clone(&self.handler);
115 let agent_id = self.id.clone();
116 tokio::spawn(async move {
117 if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
118 error!(
119 agent_id = %agent_id,
120 error = %e,
121 "Error handling agent connection"
122 );
123 }
124 });
125 }
126 Err(e) => {
127 error!(
128 agent_id = %self.id,
129 error = %e,
130 "Failed to accept connection"
131 );
132 }
133 }
134 }
135 }
136
137 async fn handle_connection(
139 mut stream: UnixStream,
140 handler: &dyn AgentHandler,
141 ) -> Result<(), AgentProtocolError> {
142 trace!("Starting connection handler");
143
144 loop {
145 let mut len_bytes = [0u8; 4];
147 match stream.read_exact(&mut len_bytes).await {
148 Ok(_) => {}
149 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
150 trace!("Client disconnected (EOF)");
152 return Ok(());
153 }
154 Err(e) => {
155 error!(error = %e, "Error reading message length");
156 return Err(e.into());
157 }
158 }
159
160 let message_len = u32::from_be_bytes(len_bytes) as usize;
161
162 if message_len > MAX_MESSAGE_SIZE {
164 warn!(
165 message_len = message_len,
166 max_size = MAX_MESSAGE_SIZE,
167 "Message too large"
168 );
169 return Err(AgentProtocolError::MessageTooLarge {
170 size: message_len,
171 max: MAX_MESSAGE_SIZE,
172 });
173 }
174
175 trace!(message_len = message_len, "Reading message data");
176
177 let mut buffer = vec![0u8; message_len];
179 stream.read_exact(&mut buffer).await?;
180
181 let request: AgentRequest = serde_json::from_slice(&buffer)
183 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
184
185 trace!(
186 event_type = ?request.event_type,
187 version = request.version,
188 "Received agent request"
189 );
190
191 let response = match request.event_type {
193 EventType::RequestHeaders => {
194 let event: RequestHeadersEvent = serde_json::from_value(request.payload)
195 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
196 trace!(
197 correlation_id = %event.metadata.correlation_id,
198 method = %event.method,
199 uri = %event.uri,
200 "Processing request_headers event"
201 );
202 handler.on_request_headers(event).await
203 }
204 EventType::RequestBodyChunk => {
205 let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
206 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
207 trace!(
208 correlation_id = %event.correlation_id,
209 is_last = event.is_last,
210 data_len = event.data.len(),
211 "Processing request_body_chunk event"
212 );
213 handler.on_request_body_chunk(event).await
214 }
215 EventType::ResponseHeaders => {
216 let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
217 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
218 trace!(
219 correlation_id = %event.correlation_id,
220 status = event.status,
221 "Processing response_headers event"
222 );
223 handler.on_response_headers(event).await
224 }
225 EventType::ResponseBodyChunk => {
226 let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
227 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
228 trace!(
229 correlation_id = %event.correlation_id,
230 is_last = event.is_last,
231 data_len = event.data.len(),
232 "Processing response_body_chunk event"
233 );
234 handler.on_response_body_chunk(event).await
235 }
236 EventType::RequestComplete => {
237 let event: RequestCompleteEvent = serde_json::from_value(request.payload)
238 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
239 trace!(
240 correlation_id = %event.correlation_id,
241 status = event.status,
242 duration_ms = event.duration_ms,
243 "Processing request_complete event"
244 );
245 handler.on_request_complete(event).await
246 }
247 };
248
249 trace!(
250 decision = ?response.decision,
251 "Sending agent response"
252 );
253
254 let response_bytes = serde_json::to_vec(&response)
256 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
257
258 let len_bytes = (response_bytes.len() as u32).to_be_bytes();
260 stream.write_all(&len_bytes).await?;
261 stream.write_all(&response_bytes).await?;
263 stream.flush().await?;
264
265 trace!(response_len = response_bytes.len(), "Response sent");
266 }
267 }
268}
269
270pub struct EchoAgent;
272
273#[async_trait]
274impl AgentHandler for EchoAgent {
275 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
276 debug!(
277 "Echo agent: request headers for {}",
278 event.metadata.correlation_id
279 );
280
281 AgentResponse::default_allow()
283 .add_request_header(HeaderOp::Set {
284 name: "X-Echo-Agent".to_string(),
285 value: event.metadata.correlation_id.clone(),
286 })
287 .with_audit(AuditMetadata {
288 tags: vec!["echo".to_string()],
289 ..Default::default()
290 })
291 }
292}
293
/// Agent that blocks requests by URI prefix or by client IP.
pub struct DenylistAgent {
    /// URI prefixes to block; compared with `starts_with` against the request URI.
    blocked_paths: Vec<String>,
    /// Client IP addresses to block, in textual form.
    blocked_ips: Vec<String>,
}

impl DenylistAgent {
    /// Creates a denylist agent from the given blocked URI prefixes and
    /// blocked client IPs.
    pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
        DenylistAgent {
            blocked_paths,
            blocked_ips,
        }
    }
}
308
309#[async_trait]
310impl AgentHandler for DenylistAgent {
311 async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
312 trace!(
313 correlation_id = %event.metadata.correlation_id,
314 uri = %event.uri,
315 client_ip = %event.metadata.client_ip,
316 "Denylist agent checking request"
317 );
318
319 for blocked_path in &self.blocked_paths {
321 if event.uri.starts_with(blocked_path) {
322 debug!(
323 correlation_id = %event.metadata.correlation_id,
324 blocked_path = %blocked_path,
325 uri = %event.uri,
326 "Blocking request: path matched denylist"
327 );
328 return AgentResponse::block(403, Some("Forbidden path".to_string())).with_audit(
329 AuditMetadata {
330 tags: vec!["denylist".to_string(), "blocked_path".to_string()],
331 reason_codes: vec!["PATH_BLOCKED".to_string()],
332 ..Default::default()
333 },
334 );
335 }
336 }
337
338 if self.blocked_ips.contains(&event.metadata.client_ip) {
340 debug!(
341 correlation_id = %event.metadata.correlation_id,
342 client_ip = %event.metadata.client_ip,
343 "Blocking request: IP matched denylist"
344 );
345 return AgentResponse::block(403, Some("Forbidden IP".to_string())).with_audit(
346 AuditMetadata {
347 tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
348 reason_codes: vec!["IP_BLOCKED".to_string()],
349 ..Default::default()
350 },
351 );
352 }
353
354 trace!(
355 correlation_id = %event.metadata.correlation_id,
356 "Request allowed by denylist agent"
357 );
358 AgentResponse::default_allow()
359 }
360}
361
362pub struct GrpcAgentServer {
368 id: String,
370 handler: Arc<dyn AgentHandler>,
372}
373
374impl GrpcAgentServer {
375 pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandler>) -> Self {
377 let id = id.into();
378 debug!(agent_id = %id, "Creating gRPC agent server");
379 Self {
380 id,
381 handler: Arc::from(handler),
382 }
383 }
384
385 pub fn into_service(self) -> AgentProcessorServer<GrpcAgentHandler> {
387 trace!(agent_id = %self.id, "Converting to tonic service");
388 AgentProcessorServer::new(GrpcAgentHandler {
389 id: self.id,
390 handler: self.handler,
391 })
392 }
393
394 pub async fn run(self, addr: SocketAddr) -> Result<(), AgentProtocolError> {
396 info!(
397 agent_id = %self.id,
398 address = %addr,
399 "gRPC agent server listening"
400 );
401
402 tonic::transport::Server::builder()
403 .add_service(self.into_service())
404 .serve(addr)
405 .await
406 .map_err(|e| {
407 error!(error = %e, "gRPC server error");
408 AgentProtocolError::ConnectionFailed(format!("gRPC server error: {}", e))
409 })
410 }
411}
412
413pub struct GrpcAgentHandler {
415 id: String,
416 handler: Arc<dyn AgentHandler>,
417}
418
419#[tonic::async_trait]
420impl AgentProcessor for GrpcAgentHandler {
421 async fn process_event(
422 &self,
423 request: Request<grpc::AgentRequest>,
424 ) -> Result<Response<grpc::AgentResponse>, Status> {
425 let grpc_request = request.into_inner();
426
427 trace!(
428 agent_id = %self.id,
429 event_type = grpc_request.event_type,
430 version = grpc_request.version,
431 "Processing gRPC event"
432 );
433
434 let response = match grpc_request.event {
436 Some(grpc::agent_request::Event::RequestHeaders(e)) => {
437 let event = Self::convert_request_headers_from_grpc(e);
438 trace!(
439 agent_id = %self.id,
440 correlation_id = %event.metadata.correlation_id,
441 "Processing request_headers via gRPC"
442 );
443 self.handler.on_request_headers(event).await
444 }
445 Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
446 let event = Self::convert_request_body_chunk_from_grpc(e);
447 trace!(
448 agent_id = %self.id,
449 correlation_id = %event.correlation_id,
450 "Processing request_body_chunk via gRPC"
451 );
452 self.handler.on_request_body_chunk(event).await
453 }
454 Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
455 let event = Self::convert_response_headers_from_grpc(e);
456 trace!(
457 agent_id = %self.id,
458 correlation_id = %event.correlation_id,
459 "Processing response_headers via gRPC"
460 );
461 self.handler.on_response_headers(event).await
462 }
463 Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
464 let event = Self::convert_response_body_chunk_from_grpc(e);
465 trace!(
466 agent_id = %self.id,
467 correlation_id = %event.correlation_id,
468 "Processing response_body_chunk via gRPC"
469 );
470 self.handler.on_response_body_chunk(event).await
471 }
472 Some(grpc::agent_request::Event::RequestComplete(e)) => {
473 let event = Self::convert_request_complete_from_grpc(e);
474 trace!(
475 agent_id = %self.id,
476 correlation_id = %event.correlation_id,
477 "Processing request_complete via gRPC"
478 );
479 self.handler.on_request_complete(event).await
480 }
481 None => {
482 warn!(agent_id = %self.id, "Missing event in gRPC request");
483 return Err(Status::invalid_argument("Missing event in request"));
484 }
485 };
486
487 trace!(
488 agent_id = %self.id,
489 decision = ?response.decision,
490 "Returning gRPC response"
491 );
492
493 let grpc_response = Self::convert_response_to_grpc(response);
495 Ok(Response::new(grpc_response))
496 }
497
498 async fn process_event_stream(
499 &self,
500 request: Request<Streaming<grpc::AgentRequest>>,
501 ) -> Result<Response<grpc::AgentResponse>, Status> {
502 let mut stream = request.into_inner();
503
504 trace!(agent_id = %self.id, "Processing gRPC event stream");
505
506 let mut final_response = AgentResponse::default_allow();
508 let mut event_count = 0u32;
509
510 while let Some(result) = stream.next().await {
511 let grpc_request = result.map_err(|e| {
512 error!(agent_id = %self.id, error = %e, "Stream error");
513 Status::internal(format!("Stream error: {}", e))
514 })?;
515
516 event_count += 1;
517 trace!(
518 agent_id = %self.id,
519 event_count = event_count,
520 "Processing stream event"
521 );
522
523 let response = match grpc_request.event {
524 Some(grpc::agent_request::Event::RequestHeaders(e)) => {
525 let event = Self::convert_request_headers_from_grpc(e);
526 self.handler.on_request_headers(event).await
527 }
528 Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
529 let event = Self::convert_request_body_chunk_from_grpc(e);
530 self.handler.on_request_body_chunk(event).await
531 }
532 Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
533 let event = Self::convert_response_headers_from_grpc(e);
534 self.handler.on_response_headers(event).await
535 }
536 Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
537 let event = Self::convert_response_body_chunk_from_grpc(e);
538 self.handler.on_response_body_chunk(event).await
539 }
540 Some(grpc::agent_request::Event::RequestComplete(e)) => {
541 let event = Self::convert_request_complete_from_grpc(e);
542 self.handler.on_request_complete(event).await
543 }
544 None => continue,
545 };
546
547 if !matches!(response.decision, Decision::Allow) {
549 debug!(
550 agent_id = %self.id,
551 decision = ?response.decision,
552 event_count = event_count,
553 "Non-allow decision in stream, terminating early"
554 );
555 final_response = response;
556 break;
557 }
558 final_response = response;
559 }
560
561 trace!(
562 agent_id = %self.id,
563 event_count = event_count,
564 decision = ?final_response.decision,
565 "Stream processing complete"
566 );
567
568 let grpc_response = Self::convert_response_to_grpc(final_response);
569 Ok(Response::new(grpc_response))
570 }
571}
572
573impl GrpcAgentHandler {
574 fn convert_request_headers_from_grpc(e: grpc::RequestHeadersEvent) -> RequestHeadersEvent {
576 RequestHeadersEvent {
577 metadata: Self::convert_metadata_from_grpc(e.metadata),
578 method: e.method,
579 uri: e.uri,
580 headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
581 }
582 }
583
584 fn convert_request_body_chunk_from_grpc(e: grpc::RequestBodyChunkEvent) -> RequestBodyChunkEvent {
586 RequestBodyChunkEvent {
587 correlation_id: e.correlation_id,
588 data: String::from_utf8_lossy(&e.data).to_string(),
589 is_last: e.is_last,
590 total_size: e.total_size.map(|s| s as usize),
591 }
592 }
593
594 fn convert_response_headers_from_grpc(e: grpc::ResponseHeadersEvent) -> ResponseHeadersEvent {
596 ResponseHeadersEvent {
597 correlation_id: e.correlation_id,
598 status: e.status as u16,
599 headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
600 }
601 }
602
603 fn convert_response_body_chunk_from_grpc(e: grpc::ResponseBodyChunkEvent) -> ResponseBodyChunkEvent {
605 ResponseBodyChunkEvent {
606 correlation_id: e.correlation_id,
607 data: String::from_utf8_lossy(&e.data).to_string(),
608 is_last: e.is_last,
609 total_size: e.total_size.map(|s| s as usize),
610 }
611 }
612
613 fn convert_request_complete_from_grpc(e: grpc::RequestCompleteEvent) -> RequestCompleteEvent {
615 RequestCompleteEvent {
616 correlation_id: e.correlation_id,
617 status: e.status as u16,
618 duration_ms: e.duration_ms,
619 request_body_size: e.request_body_size as usize,
620 response_body_size: e.response_body_size as usize,
621 upstream_attempts: e.upstream_attempts,
622 error: e.error,
623 }
624 }
625
626 fn convert_metadata_from_grpc(metadata: Option<grpc::RequestMetadata>) -> RequestMetadata {
628 match metadata {
629 Some(m) => RequestMetadata {
630 correlation_id: m.correlation_id,
631 request_id: m.request_id,
632 client_ip: m.client_ip,
633 client_port: m.client_port as u16,
634 server_name: m.server_name,
635 protocol: m.protocol,
636 tls_version: m.tls_version,
637 tls_cipher: m.tls_cipher,
638 route_id: m.route_id,
639 upstream_id: m.upstream_id,
640 timestamp: m.timestamp,
641 },
642 None => RequestMetadata {
643 correlation_id: String::new(),
644 request_id: String::new(),
645 client_ip: String::new(),
646 client_port: 0,
647 server_name: None,
648 protocol: String::new(),
649 tls_version: None,
650 tls_cipher: None,
651 route_id: None,
652 upstream_id: None,
653 timestamp: String::new(),
654 },
655 }
656 }
657
658 fn convert_response_to_grpc(response: AgentResponse) -> grpc::AgentResponse {
660 let decision = match response.decision {
661 Decision::Allow => Some(grpc::agent_response::Decision::Allow(grpc::AllowDecision {})),
662 Decision::Block { status, body, headers } => {
663 Some(grpc::agent_response::Decision::Block(grpc::BlockDecision {
664 status: status as u32,
665 body,
666 headers: headers.unwrap_or_default(),
667 }))
668 }
669 Decision::Redirect { url, status } => {
670 Some(grpc::agent_response::Decision::Redirect(grpc::RedirectDecision {
671 url,
672 status: status as u32,
673 }))
674 }
675 Decision::Challenge { challenge_type, params } => {
676 Some(grpc::agent_response::Decision::Challenge(grpc::ChallengeDecision {
677 challenge_type,
678 params,
679 }))
680 }
681 };
682
683 let request_headers: Vec<grpc::HeaderOp> = response.request_headers
684 .into_iter()
685 .map(Self::convert_header_op_to_grpc)
686 .collect();
687
688 let response_headers: Vec<grpc::HeaderOp> = response.response_headers
689 .into_iter()
690 .map(Self::convert_header_op_to_grpc)
691 .collect();
692
693 let audit = Some(grpc::AuditMetadata {
694 tags: response.audit.tags,
695 rule_ids: response.audit.rule_ids,
696 confidence: response.audit.confidence,
697 reason_codes: response.audit.reason_codes,
698 custom: response.audit.custom.into_iter().map(|(k, v)| {
699 (k, v.to_string())
700 }).collect(),
701 });
702
703 grpc::AgentResponse {
704 version: PROTOCOL_VERSION,
705 decision,
706 request_headers,
707 response_headers,
708 routing_metadata: response.routing_metadata,
709 audit,
710 }
711 }
712
713 fn convert_header_op_to_grpc(op: HeaderOp) -> grpc::HeaderOp {
715 let operation = match op {
716 HeaderOp::Set { name, value } => {
717 Some(grpc::header_op::Operation::Set(grpc::SetHeader { name, value }))
718 }
719 HeaderOp::Add { name, value } => {
720 Some(grpc::header_op::Operation::Add(grpc::AddHeader { name, value }))
721 }
722 HeaderOp::Remove { name } => {
723 Some(grpc::header_op::Operation::Remove(grpc::RemoveHeader { name }))
724 }
725 };
726 grpc::HeaderOp { operation }
727 }
728}