sentinel_agent_protocol/
server.rs

1//! Agent server for implementing external agents.
2//!
3//! Supports two transport mechanisms:
4//! - Unix domain sockets (length-prefixed JSON)
5//! - gRPC (Protocol Buffers over HTTP/2)
6
7use async_trait::async_trait;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{UnixListener, UnixStream};
12use tokio_stream::StreamExt;
13use tonic::{Request, Response, Status, Streaming};
14use tracing::{debug, error, info, trace, warn};
15
16use crate::errors::AgentProtocolError;
17use crate::grpc::{self, agent_processor_server::AgentProcessor, agent_processor_server::AgentProcessorServer};
18use crate::protocol::{
19    AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
20    RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent, ResponseHeadersEvent,
21    MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
22};
23
24/// Agent server for testing and reference implementations
25pub struct AgentServer {
26    /// Agent ID
27    id: String,
28    /// Unix socket path
29    socket_path: std::path::PathBuf,
30    /// Request handler
31    handler: Arc<dyn AgentHandler>,
32}
33
34/// Trait for implementing agent logic
35#[async_trait]
36pub trait AgentHandler: Send + Sync {
37    /// Handle a request headers event
38    async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
39        AgentResponse::default_allow()
40    }
41
42    /// Handle a request body chunk event
43    async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
44        AgentResponse::default_allow()
45    }
46
47    /// Handle a response headers event
48    async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
49        AgentResponse::default_allow()
50    }
51
52    /// Handle a response body chunk event
53    async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
54        AgentResponse::default_allow()
55    }
56
57    /// Handle a request complete event
58    async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
59        AgentResponse::default_allow()
60    }
61}
62
63impl AgentServer {
64    /// Create a new agent server
65    pub fn new(
66        id: impl Into<String>,
67        socket_path: impl Into<std::path::PathBuf>,
68        handler: Box<dyn AgentHandler>,
69    ) -> Self {
70        let id = id.into();
71        let socket_path = socket_path.into();
72
73        debug!(
74            agent_id = %id,
75            socket_path = %socket_path.display(),
76            "Creating agent server"
77        );
78
79        Self {
80            id,
81            socket_path,
82            handler: Arc::from(handler),
83        }
84    }
85
86    /// Start the agent server
87    pub async fn run(&self) -> Result<(), AgentProtocolError> {
88        // Remove existing socket file if it exists
89        if self.socket_path.exists() {
90            trace!(
91                agent_id = %self.id,
92                socket_path = %self.socket_path.display(),
93                "Removing existing socket file"
94            );
95            std::fs::remove_file(&self.socket_path)?;
96        }
97
98        // Create Unix socket listener
99        let listener = UnixListener::bind(&self.socket_path)?;
100
101        info!(
102            agent_id = %self.id,
103            socket_path = %self.socket_path.display(),
104            "Agent server listening"
105        );
106
107        loop {
108            match listener.accept().await {
109                Ok((stream, _addr)) => {
110                    trace!(
111                        agent_id = %self.id,
112                        "Accepted new connection"
113                    );
114                    let handler = Arc::clone(&self.handler);
115                    let agent_id = self.id.clone();
116                    tokio::spawn(async move {
117                        if let Err(e) = Self::handle_connection(stream, handler.as_ref()).await {
118                            error!(
119                                agent_id = %agent_id,
120                                error = %e,
121                                "Error handling agent connection"
122                            );
123                        }
124                    });
125                }
126                Err(e) => {
127                    error!(
128                        agent_id = %self.id,
129                        error = %e,
130                        "Failed to accept connection"
131                    );
132                }
133            }
134        }
135    }
136
137    /// Handle a single connection
138    async fn handle_connection(
139        mut stream: UnixStream,
140        handler: &dyn AgentHandler,
141    ) -> Result<(), AgentProtocolError> {
142        trace!("Starting connection handler");
143
144        loop {
145            // Read message length
146            let mut len_bytes = [0u8; 4];
147            match stream.read_exact(&mut len_bytes).await {
148                Ok(_) => {}
149                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
150                    // Client disconnected
151                    trace!("Client disconnected (EOF)");
152                    return Ok(());
153                }
154                Err(e) => {
155                    error!(error = %e, "Error reading message length");
156                    return Err(e.into());
157                }
158            }
159
160            let message_len = u32::from_be_bytes(len_bytes) as usize;
161
162            // Check message size
163            if message_len > MAX_MESSAGE_SIZE {
164                warn!(
165                    message_len = message_len,
166                    max_size = MAX_MESSAGE_SIZE,
167                    "Message too large"
168                );
169                return Err(AgentProtocolError::MessageTooLarge {
170                    size: message_len,
171                    max: MAX_MESSAGE_SIZE,
172                });
173            }
174
175            trace!(message_len = message_len, "Reading message data");
176
177            // Read message data
178            let mut buffer = vec![0u8; message_len];
179            stream.read_exact(&mut buffer).await?;
180
181            // Parse request
182            let request: AgentRequest = serde_json::from_slice(&buffer)
183                .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
184
185            trace!(
186                event_type = ?request.event_type,
187                version = request.version,
188                "Received agent request"
189            );
190
191            // Handle request based on event type
192            let response = match request.event_type {
193                EventType::RequestHeaders => {
194                    let event: RequestHeadersEvent = serde_json::from_value(request.payload)
195                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
196                    trace!(
197                        correlation_id = %event.metadata.correlation_id,
198                        method = %event.method,
199                        uri = %event.uri,
200                        "Processing request_headers event"
201                    );
202                    handler.on_request_headers(event).await
203                }
204                EventType::RequestBodyChunk => {
205                    let event: RequestBodyChunkEvent = serde_json::from_value(request.payload)
206                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
207                    trace!(
208                        correlation_id = %event.correlation_id,
209                        is_last = event.is_last,
210                        data_len = event.data.len(),
211                        "Processing request_body_chunk event"
212                    );
213                    handler.on_request_body_chunk(event).await
214                }
215                EventType::ResponseHeaders => {
216                    let event: ResponseHeadersEvent = serde_json::from_value(request.payload)
217                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
218                    trace!(
219                        correlation_id = %event.correlation_id,
220                        status = event.status,
221                        "Processing response_headers event"
222                    );
223                    handler.on_response_headers(event).await
224                }
225                EventType::ResponseBodyChunk => {
226                    let event: ResponseBodyChunkEvent = serde_json::from_value(request.payload)
227                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
228                    trace!(
229                        correlation_id = %event.correlation_id,
230                        is_last = event.is_last,
231                        data_len = event.data.len(),
232                        "Processing response_body_chunk event"
233                    );
234                    handler.on_response_body_chunk(event).await
235                }
236                EventType::RequestComplete => {
237                    let event: RequestCompleteEvent = serde_json::from_value(request.payload)
238                        .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
239                    trace!(
240                        correlation_id = %event.correlation_id,
241                        status = event.status,
242                        duration_ms = event.duration_ms,
243                        "Processing request_complete event"
244                    );
245                    handler.on_request_complete(event).await
246                }
247            };
248
249            trace!(
250                decision = ?response.decision,
251                "Sending agent response"
252            );
253
254            // Send response
255            let response_bytes = serde_json::to_vec(&response)
256                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
257
258            // Write message length
259            let len_bytes = (response_bytes.len() as u32).to_be_bytes();
260            stream.write_all(&len_bytes).await?;
261            // Write message data
262            stream.write_all(&response_bytes).await?;
263            stream.flush().await?;
264
265            trace!(response_len = response_bytes.len(), "Response sent");
266        }
267    }
268}
269
/// Reference implementation: an agent that allows every request and sets an
/// `X-Echo-Agent` request header to the request's correlation ID (for testing).
#[derive(Debug, Clone, Copy, Default)]
pub struct EchoAgent;
272
273#[async_trait]
274impl AgentHandler for EchoAgent {
275    async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
276        debug!(
277            "Echo agent: request headers for {}",
278            event.metadata.correlation_id
279        );
280
281        // Echo back correlation ID as a header
282        AgentResponse::default_allow()
283            .add_request_header(HeaderOp::Set {
284                name: "X-Echo-Agent".to_string(),
285                value: event.metadata.correlation_id.clone(),
286            })
287            .with_audit(AuditMetadata {
288                tags: vec!["echo".to_string()],
289                ..Default::default()
290            })
291    }
292}
293
/// Reference implementation: an agent that blocks requests by URI prefix or by
/// client IP.
#[derive(Debug, Clone)]
pub struct DenylistAgent {
    /// URI prefixes to block, checked in configuration order.
    blocked_paths: Vec<String>,
    /// Client IP strings to block, matched by exact string equality. A hash set
    /// keeps the per-request lookup O(1) instead of a linear scan.
    blocked_ips: std::collections::HashSet<String>,
}

impl DenylistAgent {
    /// Create a denylist agent.
    ///
    /// * `blocked_paths` – URI prefixes to block; checked in the given order.
    /// * `blocked_ips` – client IP strings to block. They are compared by exact
    ///   string equality and duplicates are collapsed.
    pub fn new(blocked_paths: Vec<String>, blocked_ips: Vec<String>) -> Self {
        Self {
            blocked_paths,
            blocked_ips: blocked_ips.into_iter().collect(),
        }
    }
}
308
309#[async_trait]
310impl AgentHandler for DenylistAgent {
311    async fn on_request_headers(&self, event: RequestHeadersEvent) -> AgentResponse {
312        trace!(
313            correlation_id = %event.metadata.correlation_id,
314            uri = %event.uri,
315            client_ip = %event.metadata.client_ip,
316            "Denylist agent checking request"
317        );
318
319        // Check if path is blocked
320        for blocked_path in &self.blocked_paths {
321            if event.uri.starts_with(blocked_path) {
322                debug!(
323                    correlation_id = %event.metadata.correlation_id,
324                    blocked_path = %blocked_path,
325                    uri = %event.uri,
326                    "Blocking request: path matched denylist"
327                );
328                return AgentResponse::block(403, Some("Forbidden path".to_string())).with_audit(
329                    AuditMetadata {
330                        tags: vec!["denylist".to_string(), "blocked_path".to_string()],
331                        reason_codes: vec!["PATH_BLOCKED".to_string()],
332                        ..Default::default()
333                    },
334                );
335            }
336        }
337
338        // Check if IP is blocked
339        if self.blocked_ips.contains(&event.metadata.client_ip) {
340            debug!(
341                correlation_id = %event.metadata.correlation_id,
342                client_ip = %event.metadata.client_ip,
343                "Blocking request: IP matched denylist"
344            );
345            return AgentResponse::block(403, Some("Forbidden IP".to_string())).with_audit(
346                AuditMetadata {
347                    tags: vec!["denylist".to_string(), "blocked_ip".to_string()],
348                    reason_codes: vec!["IP_BLOCKED".to_string()],
349                    ..Default::default()
350                },
351            );
352        }
353
354        trace!(
355            correlation_id = %event.metadata.correlation_id,
356            "Request allowed by denylist agent"
357        );
358        AgentResponse::default_allow()
359    }
360}
361
362// ============================================================================
363// gRPC Server Implementation
364// ============================================================================
365
366/// gRPC agent server for implementing external agents
367pub struct GrpcAgentServer {
368    /// Agent ID
369    id: String,
370    /// Request handler
371    handler: Arc<dyn AgentHandler>,
372}
373
374impl GrpcAgentServer {
375    /// Create a new gRPC agent server
376    pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandler>) -> Self {
377        let id = id.into();
378        debug!(agent_id = %id, "Creating gRPC agent server");
379        Self {
380            id,
381            handler: Arc::from(handler),
382        }
383    }
384
385    /// Get the tonic service for this agent
386    pub fn into_service(self) -> AgentProcessorServer<GrpcAgentHandler> {
387        trace!(agent_id = %self.id, "Converting to tonic service");
388        AgentProcessorServer::new(GrpcAgentHandler {
389            id: self.id,
390            handler: self.handler,
391        })
392    }
393
394    /// Start the gRPC server on the given address
395    pub async fn run(self, addr: SocketAddr) -> Result<(), AgentProtocolError> {
396        info!(
397            agent_id = %self.id,
398            address = %addr,
399            "gRPC agent server listening"
400        );
401
402        tonic::transport::Server::builder()
403            .add_service(self.into_service())
404            .serve(addr)
405            .await
406            .map_err(|e| {
407                error!(error = %e, "gRPC server error");
408                AgentProtocolError::ConnectionFailed(format!("gRPC server error: {}", e))
409            })
410    }
411}
412
413/// Internal handler that implements the gRPC AgentProcessor trait
414pub struct GrpcAgentHandler {
415    id: String,
416    handler: Arc<dyn AgentHandler>,
417}
418
419#[tonic::async_trait]
420impl AgentProcessor for GrpcAgentHandler {
421    async fn process_event(
422        &self,
423        request: Request<grpc::AgentRequest>,
424    ) -> Result<Response<grpc::AgentResponse>, Status> {
425        let grpc_request = request.into_inner();
426
427        trace!(
428            agent_id = %self.id,
429            event_type = grpc_request.event_type,
430            version = grpc_request.version,
431            "Processing gRPC event"
432        );
433
434        // Convert gRPC event to internal event and dispatch
435        let response = match grpc_request.event {
436            Some(grpc::agent_request::Event::RequestHeaders(e)) => {
437                let event = Self::convert_request_headers_from_grpc(e);
438                trace!(
439                    agent_id = %self.id,
440                    correlation_id = %event.metadata.correlation_id,
441                    "Processing request_headers via gRPC"
442                );
443                self.handler.on_request_headers(event).await
444            }
445            Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
446                let event = Self::convert_request_body_chunk_from_grpc(e);
447                trace!(
448                    agent_id = %self.id,
449                    correlation_id = %event.correlation_id,
450                    "Processing request_body_chunk via gRPC"
451                );
452                self.handler.on_request_body_chunk(event).await
453            }
454            Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
455                let event = Self::convert_response_headers_from_grpc(e);
456                trace!(
457                    agent_id = %self.id,
458                    correlation_id = %event.correlation_id,
459                    "Processing response_headers via gRPC"
460                );
461                self.handler.on_response_headers(event).await
462            }
463            Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
464                let event = Self::convert_response_body_chunk_from_grpc(e);
465                trace!(
466                    agent_id = %self.id,
467                    correlation_id = %event.correlation_id,
468                    "Processing response_body_chunk via gRPC"
469                );
470                self.handler.on_response_body_chunk(event).await
471            }
472            Some(grpc::agent_request::Event::RequestComplete(e)) => {
473                let event = Self::convert_request_complete_from_grpc(e);
474                trace!(
475                    agent_id = %self.id,
476                    correlation_id = %event.correlation_id,
477                    "Processing request_complete via gRPC"
478                );
479                self.handler.on_request_complete(event).await
480            }
481            None => {
482                warn!(agent_id = %self.id, "Missing event in gRPC request");
483                return Err(Status::invalid_argument("Missing event in request"));
484            }
485        };
486
487        trace!(
488            agent_id = %self.id,
489            decision = ?response.decision,
490            "Returning gRPC response"
491        );
492
493        // Convert internal response to gRPC response
494        let grpc_response = Self::convert_response_to_grpc(response);
495        Ok(Response::new(grpc_response))
496    }
497
498    async fn process_event_stream(
499        &self,
500        request: Request<Streaming<grpc::AgentRequest>>,
501    ) -> Result<Response<grpc::AgentResponse>, Status> {
502        let mut stream = request.into_inner();
503
504        trace!(agent_id = %self.id, "Processing gRPC event stream");
505
506        // Process all events in the stream, returning the final response
507        let mut final_response = AgentResponse::default_allow();
508        let mut event_count = 0u32;
509
510        while let Some(result) = stream.next().await {
511            let grpc_request = result.map_err(|e| {
512                error!(agent_id = %self.id, error = %e, "Stream error");
513                Status::internal(format!("Stream error: {}", e))
514            })?;
515
516            event_count += 1;
517            trace!(
518                agent_id = %self.id,
519                event_count = event_count,
520                "Processing stream event"
521            );
522
523            let response = match grpc_request.event {
524                Some(grpc::agent_request::Event::RequestHeaders(e)) => {
525                    let event = Self::convert_request_headers_from_grpc(e);
526                    self.handler.on_request_headers(event).await
527                }
528                Some(grpc::agent_request::Event::RequestBodyChunk(e)) => {
529                    let event = Self::convert_request_body_chunk_from_grpc(e);
530                    self.handler.on_request_body_chunk(event).await
531                }
532                Some(grpc::agent_request::Event::ResponseHeaders(e)) => {
533                    let event = Self::convert_response_headers_from_grpc(e);
534                    self.handler.on_response_headers(event).await
535                }
536                Some(grpc::agent_request::Event::ResponseBodyChunk(e)) => {
537                    let event = Self::convert_response_body_chunk_from_grpc(e);
538                    self.handler.on_response_body_chunk(event).await
539                }
540                Some(grpc::agent_request::Event::RequestComplete(e)) => {
541                    let event = Self::convert_request_complete_from_grpc(e);
542                    self.handler.on_request_complete(event).await
543                }
544                None => continue,
545            };
546
547            // If any event results in a block/redirect, that becomes the final response
548            if !matches!(response.decision, Decision::Allow) {
549                debug!(
550                    agent_id = %self.id,
551                    decision = ?response.decision,
552                    event_count = event_count,
553                    "Non-allow decision in stream, terminating early"
554                );
555                final_response = response;
556                break;
557            }
558            final_response = response;
559        }
560
561        trace!(
562            agent_id = %self.id,
563            event_count = event_count,
564            decision = ?final_response.decision,
565            "Stream processing complete"
566        );
567
568        let grpc_response = Self::convert_response_to_grpc(final_response);
569        Ok(Response::new(grpc_response))
570    }
571}
572
573impl GrpcAgentHandler {
574    /// Convert gRPC RequestHeadersEvent to internal format
575    fn convert_request_headers_from_grpc(e: grpc::RequestHeadersEvent) -> RequestHeadersEvent {
576        RequestHeadersEvent {
577            metadata: Self::convert_metadata_from_grpc(e.metadata),
578            method: e.method,
579            uri: e.uri,
580            headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
581        }
582    }
583
584    /// Convert gRPC RequestBodyChunkEvent to internal format
585    fn convert_request_body_chunk_from_grpc(e: grpc::RequestBodyChunkEvent) -> RequestBodyChunkEvent {
586        RequestBodyChunkEvent {
587            correlation_id: e.correlation_id,
588            data: String::from_utf8_lossy(&e.data).to_string(),
589            is_last: e.is_last,
590            total_size: e.total_size.map(|s| s as usize),
591        }
592    }
593
594    /// Convert gRPC ResponseHeadersEvent to internal format
595    fn convert_response_headers_from_grpc(e: grpc::ResponseHeadersEvent) -> ResponseHeadersEvent {
596        ResponseHeadersEvent {
597            correlation_id: e.correlation_id,
598            status: e.status as u16,
599            headers: e.headers.into_iter().map(|(k, v)| (k, v.values)).collect(),
600        }
601    }
602
603    /// Convert gRPC ResponseBodyChunkEvent to internal format
604    fn convert_response_body_chunk_from_grpc(e: grpc::ResponseBodyChunkEvent) -> ResponseBodyChunkEvent {
605        ResponseBodyChunkEvent {
606            correlation_id: e.correlation_id,
607            data: String::from_utf8_lossy(&e.data).to_string(),
608            is_last: e.is_last,
609            total_size: e.total_size.map(|s| s as usize),
610        }
611    }
612
613    /// Convert gRPC RequestCompleteEvent to internal format
614    fn convert_request_complete_from_grpc(e: grpc::RequestCompleteEvent) -> RequestCompleteEvent {
615        RequestCompleteEvent {
616            correlation_id: e.correlation_id,
617            status: e.status as u16,
618            duration_ms: e.duration_ms,
619            request_body_size: e.request_body_size as usize,
620            response_body_size: e.response_body_size as usize,
621            upstream_attempts: e.upstream_attempts,
622            error: e.error,
623        }
624    }
625
626    /// Convert gRPC metadata to internal format
627    fn convert_metadata_from_grpc(metadata: Option<grpc::RequestMetadata>) -> RequestMetadata {
628        match metadata {
629            Some(m) => RequestMetadata {
630                correlation_id: m.correlation_id,
631                request_id: m.request_id,
632                client_ip: m.client_ip,
633                client_port: m.client_port as u16,
634                server_name: m.server_name,
635                protocol: m.protocol,
636                tls_version: m.tls_version,
637                tls_cipher: m.tls_cipher,
638                route_id: m.route_id,
639                upstream_id: m.upstream_id,
640                timestamp: m.timestamp,
641            },
642            None => RequestMetadata {
643                correlation_id: String::new(),
644                request_id: String::new(),
645                client_ip: String::new(),
646                client_port: 0,
647                server_name: None,
648                protocol: String::new(),
649                tls_version: None,
650                tls_cipher: None,
651                route_id: None,
652                upstream_id: None,
653                timestamp: String::new(),
654            },
655        }
656    }
657
658    /// Convert internal response to gRPC format
659    fn convert_response_to_grpc(response: AgentResponse) -> grpc::AgentResponse {
660        let decision = match response.decision {
661            Decision::Allow => Some(grpc::agent_response::Decision::Allow(grpc::AllowDecision {})),
662            Decision::Block { status, body, headers } => {
663                Some(grpc::agent_response::Decision::Block(grpc::BlockDecision {
664                    status: status as u32,
665                    body,
666                    headers: headers.unwrap_or_default(),
667                }))
668            }
669            Decision::Redirect { url, status } => {
670                Some(grpc::agent_response::Decision::Redirect(grpc::RedirectDecision {
671                    url,
672                    status: status as u32,
673                }))
674            }
675            Decision::Challenge { challenge_type, params } => {
676                Some(grpc::agent_response::Decision::Challenge(grpc::ChallengeDecision {
677                    challenge_type,
678                    params,
679                }))
680            }
681        };
682
683        let request_headers: Vec<grpc::HeaderOp> = response.request_headers
684            .into_iter()
685            .map(Self::convert_header_op_to_grpc)
686            .collect();
687
688        let response_headers: Vec<grpc::HeaderOp> = response.response_headers
689            .into_iter()
690            .map(Self::convert_header_op_to_grpc)
691            .collect();
692
693        let audit = Some(grpc::AuditMetadata {
694            tags: response.audit.tags,
695            rule_ids: response.audit.rule_ids,
696            confidence: response.audit.confidence,
697            reason_codes: response.audit.reason_codes,
698            custom: response.audit.custom.into_iter().map(|(k, v)| {
699                (k, v.to_string())
700            }).collect(),
701        });
702
703        grpc::AgentResponse {
704            version: PROTOCOL_VERSION,
705            decision,
706            request_headers,
707            response_headers,
708            routing_metadata: response.routing_metadata,
709            audit,
710        }
711    }
712
713    /// Convert internal header operation to gRPC format
714    fn convert_header_op_to_grpc(op: HeaderOp) -> grpc::HeaderOp {
715        let operation = match op {
716            HeaderOp::Set { name, value } => {
717                Some(grpc::header_op::Operation::Set(grpc::SetHeader { name, value }))
718            }
719            HeaderOp::Add { name, value } => {
720                Some(grpc::header_op::Operation::Add(grpc::AddHeader { name, value }))
721            }
722            HeaderOp::Remove { name } => {
723                Some(grpc::header_op::Operation::Remove(grpc::RemoveHeader { name }))
724            }
725        };
726        grpc::HeaderOp { operation }
727    }
728}