sentinel_agent_protocol/
client.rs

1//! Agent client for communicating with external agents.
2//!
3//! Supports two transport mechanisms:
4//! - Unix domain sockets (length-prefixed JSON)
5//! - gRPC (Protocol Buffers over HTTP/2)
6
7use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17    AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
18    RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
19    ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
20    MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
21};
22
23/// Agent client for communicating with external agents
24pub struct AgentClient {
25    /// Agent ID
26    id: String,
27    /// Connection to agent
28    connection: AgentConnection,
29    /// Timeout for agent calls
30    timeout: Duration,
31    /// Maximum retries
32    #[allow(dead_code)]
33    max_retries: u32,
34}
35
36/// Agent connection type
37enum AgentConnection {
38    UnixSocket(UnixStream),
39    Grpc(AgentProcessorClient<Channel>),
40}
41
42impl AgentClient {
43    /// Create a new Unix socket agent client
44    pub async fn unix_socket(
45        id: impl Into<String>,
46        path: impl AsRef<std::path::Path>,
47        timeout: Duration,
48    ) -> Result<Self, AgentProtocolError> {
49        let id = id.into();
50        let path = path.as_ref();
51
52        trace!(
53            agent_id = %id,
54            socket_path = %path.display(),
55            timeout_ms = timeout.as_millis() as u64,
56            "Connecting to agent via Unix socket"
57        );
58
59        let stream = UnixStream::connect(path).await.map_err(|e| {
60            error!(
61                agent_id = %id,
62                socket_path = %path.display(),
63                error = %e,
64                "Failed to connect to agent via Unix socket"
65            );
66            AgentProtocolError::ConnectionFailed(e.to_string())
67        })?;
68
69        debug!(
70            agent_id = %id,
71            socket_path = %path.display(),
72            "Connected to agent via Unix socket"
73        );
74
75        Ok(Self {
76            id,
77            connection: AgentConnection::UnixSocket(stream),
78            timeout,
79            max_retries: 3,
80        })
81    }
82
83    /// Create a new gRPC agent client
84    ///
85    /// # Arguments
86    /// * `id` - Agent identifier
87    /// * `address` - gRPC server address (e.g., "http://localhost:50051")
88    /// * `timeout` - Timeout for agent calls
89    pub async fn grpc(
90        id: impl Into<String>,
91        address: impl Into<String>,
92        timeout: Duration,
93    ) -> Result<Self, AgentProtocolError> {
94        let id = id.into();
95        let address = address.into();
96
97        trace!(
98            agent_id = %id,
99            address = %address,
100            timeout_ms = timeout.as_millis() as u64,
101            "Connecting to agent via gRPC"
102        );
103
104        let channel = Channel::from_shared(address.clone())
105            .map_err(|e| {
106                error!(
107                    agent_id = %id,
108                    address = %address,
109                    error = %e,
110                    "Invalid gRPC URI"
111                );
112                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
113            })?
114            .timeout(timeout)
115            .connect()
116            .await
117            .map_err(|e| {
118                error!(
119                    agent_id = %id,
120                    address = %address,
121                    error = %e,
122                    "Failed to connect to agent via gRPC"
123                );
124                AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
125            })?;
126
127        let client = AgentProcessorClient::new(channel);
128
129        debug!(
130            agent_id = %id,
131            address = %address,
132            "Connected to agent via gRPC"
133        );
134
135        Ok(Self {
136            id,
137            connection: AgentConnection::Grpc(client),
138            timeout,
139            max_retries: 3,
140        })
141    }
142
143    /// Get the agent ID
144    #[allow(dead_code)]
145    pub fn id(&self) -> &str {
146        &self.id
147    }
148
149    /// Send an event to the agent and get a response
150    pub async fn send_event(
151        &mut self,
152        event_type: EventType,
153        payload: impl Serialize,
154    ) -> Result<AgentResponse, AgentProtocolError> {
155        match &mut self.connection {
156            AgentConnection::UnixSocket(_) => {
157                self.send_event_unix_socket(event_type, payload).await
158            }
159            AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
160        }
161    }
162
163    /// Send event via Unix socket (length-prefixed JSON)
164    async fn send_event_unix_socket(
165        &mut self,
166        event_type: EventType,
167        payload: impl Serialize,
168    ) -> Result<AgentResponse, AgentProtocolError> {
169        let request = AgentRequest {
170            version: PROTOCOL_VERSION,
171            event_type,
172            payload: serde_json::to_value(payload)
173                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
174        };
175
176        // Serialize request
177        let request_bytes = serde_json::to_vec(&request)
178            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
179
180        // Check message size
181        if request_bytes.len() > MAX_MESSAGE_SIZE {
182            return Err(AgentProtocolError::MessageTooLarge {
183                size: request_bytes.len(),
184                max: MAX_MESSAGE_SIZE,
185            });
186        }
187
188        // Send with timeout
189        let response = tokio::time::timeout(self.timeout, async {
190            self.send_raw_unix(&request_bytes).await?;
191            self.receive_raw_unix().await
192        })
193        .await
194        .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
195
196        // Parse response
197        let agent_response: AgentResponse = serde_json::from_slice(&response)
198            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
199
200        // Verify protocol version
201        if agent_response.version != PROTOCOL_VERSION {
202            return Err(AgentProtocolError::VersionMismatch {
203                expected: PROTOCOL_VERSION,
204                actual: agent_response.version,
205            });
206        }
207
208        Ok(agent_response)
209    }
210
211    /// Send event via gRPC
212    async fn send_event_grpc(
213        &mut self,
214        event_type: EventType,
215        payload: impl Serialize,
216    ) -> Result<AgentResponse, AgentProtocolError> {
217        // Build request first (doesn't need mutable borrow)
218        let grpc_request = Self::build_grpc_request(event_type, payload)?;
219
220        let AgentConnection::Grpc(client) = &mut self.connection else {
221            unreachable!()
222        };
223
224        // Send with timeout
225        let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
226            .await
227            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
228            .map_err(|e| {
229                AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
230            })?;
231
232        // Convert gRPC response to internal format
233        Self::convert_grpc_response(response.into_inner())
234    }
235
236    /// Build a gRPC request from internal types
237    fn build_grpc_request(
238        event_type: EventType,
239        payload: impl Serialize,
240    ) -> Result<grpc::AgentRequest, AgentProtocolError> {
241        let payload_json = serde_json::to_value(&payload)
242            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
243
244        let grpc_event_type = match event_type {
245            EventType::Configure => {
246                return Err(AgentProtocolError::Serialization(
247                    "Configure events are not supported via gRPC".to_string(),
248                ))
249            }
250            EventType::RequestHeaders => grpc::EventType::RequestHeaders,
251            EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
252            EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
253            EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
254            EventType::RequestComplete => grpc::EventType::RequestComplete,
255            EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
256        };
257
258        let event = match event_type {
259            EventType::Configure => unreachable!("Configure handled above"),
260            EventType::RequestHeaders => {
261                let event: RequestHeadersEvent = serde_json::from_value(payload_json)
262                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
263                grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
264                    metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
265                    method: event.method,
266                    uri: event.uri,
267                    headers: event
268                        .headers
269                        .into_iter()
270                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
271                        .collect(),
272                })
273            }
274            EventType::RequestBodyChunk => {
275                let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
276                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
277                grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
278                    correlation_id: event.correlation_id,
279                    data: event.data.into_bytes(),
280                    is_last: event.is_last,
281                    total_size: event.total_size.map(|s| s as u64),
282                    chunk_index: event.chunk_index,
283                    bytes_received: event.bytes_received as u64,
284                })
285            }
286            EventType::ResponseHeaders => {
287                let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
288                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
289                grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
290                    correlation_id: event.correlation_id,
291                    status: event.status as u32,
292                    headers: event
293                        .headers
294                        .into_iter()
295                        .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
296                        .collect(),
297                })
298            }
299            EventType::ResponseBodyChunk => {
300                let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
301                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
302                grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
303                    correlation_id: event.correlation_id,
304                    data: event.data.into_bytes(),
305                    is_last: event.is_last,
306                    total_size: event.total_size.map(|s| s as u64),
307                    chunk_index: event.chunk_index,
308                    bytes_sent: event.bytes_sent as u64,
309                })
310            }
311            EventType::RequestComplete => {
312                let event: RequestCompleteEvent = serde_json::from_value(payload_json)
313                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
314                grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
315                    correlation_id: event.correlation_id,
316                    status: event.status as u32,
317                    duration_ms: event.duration_ms,
318                    request_body_size: event.request_body_size as u64,
319                    response_body_size: event.response_body_size as u64,
320                    upstream_attempts: event.upstream_attempts,
321                    error: event.error,
322                })
323            }
324            EventType::WebSocketFrame => {
325                use base64::{engine::general_purpose::STANDARD, Engine as _};
326                let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
327                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
328                grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
329                    correlation_id: event.correlation_id,
330                    opcode: event.opcode,
331                    data: STANDARD.decode(&event.data).unwrap_or_default(),
332                    client_to_server: event.client_to_server,
333                    frame_index: event.frame_index,
334                    fin: event.fin,
335                    route_id: event.route_id,
336                    client_ip: event.client_ip,
337                })
338            }
339        };
340
341        Ok(grpc::AgentRequest {
342            version: PROTOCOL_VERSION,
343            event_type: grpc_event_type as i32,
344            event: Some(event),
345        })
346    }
347
348    /// Convert internal metadata to gRPC format
349    fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
350        grpc::RequestMetadata {
351            correlation_id: metadata.correlation_id.clone(),
352            request_id: metadata.request_id.clone(),
353            client_ip: metadata.client_ip.clone(),
354            client_port: metadata.client_port as u32,
355            server_name: metadata.server_name.clone(),
356            protocol: metadata.protocol.clone(),
357            tls_version: metadata.tls_version.clone(),
358            tls_cipher: metadata.tls_cipher.clone(),
359            route_id: metadata.route_id.clone(),
360            upstream_id: metadata.upstream_id.clone(),
361            timestamp: metadata.timestamp.clone(),
362        }
363    }
364
365    /// Convert gRPC response to internal format
366    fn convert_grpc_response(
367        response: grpc::AgentResponse,
368    ) -> Result<AgentResponse, AgentProtocolError> {
369        let decision = match response.decision {
370            Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
371            Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
372                status: b.status as u16,
373                body: b.body,
374                headers: if b.headers.is_empty() {
375                    None
376                } else {
377                    Some(b.headers)
378                },
379            },
380            Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
381                url: r.url,
382                status: r.status as u16,
383            },
384            Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
385                challenge_type: c.challenge_type,
386                params: c.params,
387            },
388            None => Decision::Allow, // Default to allow if no decision
389        };
390
391        let request_headers: Vec<HeaderOp> = response
392            .request_headers
393            .into_iter()
394            .filter_map(Self::convert_header_op_from_grpc)
395            .collect();
396
397        let response_headers: Vec<HeaderOp> = response
398            .response_headers
399            .into_iter()
400            .filter_map(Self::convert_header_op_from_grpc)
401            .collect();
402
403        let audit = response.audit.map(|a| AuditMetadata {
404            tags: a.tags,
405            rule_ids: a.rule_ids,
406            confidence: a.confidence,
407            reason_codes: a.reason_codes,
408            custom: a
409                .custom
410                .into_iter()
411                .map(|(k, v)| (k, serde_json::Value::String(v)))
412                .collect(),
413        });
414
415        // Convert body mutations
416        let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
417            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
418            chunk_index: m.chunk_index,
419        });
420
421        let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
422            data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
423            chunk_index: m.chunk_index,
424        });
425
426        // Convert WebSocket decision
427        let websocket_decision = response
428            .websocket_decision
429            .map(|ws_decision| match ws_decision {
430                grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
431                    WebSocketDecision::Allow
432                }
433                grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
434                    WebSocketDecision::Drop
435                }
436                grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
437                    WebSocketDecision::Close {
438                        code: c.code as u16,
439                        reason: c.reason,
440                    }
441                }
442            });
443
444        Ok(AgentResponse {
445            version: response.version,
446            decision,
447            request_headers,
448            response_headers,
449            routing_metadata: response.routing_metadata,
450            audit: audit.unwrap_or_default(),
451            needs_more: response.needs_more,
452            request_body_mutation,
453            response_body_mutation,
454            websocket_decision,
455        })
456    }
457
458    /// Convert gRPC header operation to internal format
459    fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
460        match op.operation? {
461            grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
462                name: s.name,
463                value: s.value,
464            }),
465            grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
466                name: a.name,
467                value: a.value,
468            }),
469            grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
470        }
471    }
472
473    /// Send raw bytes to agent (Unix socket only)
474    async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
475        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
476            unreachable!()
477        };
478        // Write message length (4 bytes, big-endian)
479        let len_bytes = (data.len() as u32).to_be_bytes();
480        stream.write_all(&len_bytes).await?;
481        // Write message data
482        stream.write_all(data).await?;
483        stream.flush().await?;
484        Ok(())
485    }
486
487    /// Receive raw bytes from agent (Unix socket only)
488    async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
489        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
490            unreachable!()
491        };
492        // Read message length (4 bytes, big-endian)
493        let mut len_bytes = [0u8; 4];
494        stream.read_exact(&mut len_bytes).await?;
495        let message_len = u32::from_be_bytes(len_bytes) as usize;
496
497        // Check message size
498        if message_len > MAX_MESSAGE_SIZE {
499            return Err(AgentProtocolError::MessageTooLarge {
500                size: message_len,
501                max: MAX_MESSAGE_SIZE,
502            });
503        }
504
505        // Read message data
506        let mut buffer = vec![0u8; message_len];
507        stream.read_exact(&mut buffer).await?;
508        Ok(buffer)
509    }
510
511    /// Close the agent connection
512    pub async fn close(self) -> Result<(), AgentProtocolError> {
513        match self.connection {
514            AgentConnection::UnixSocket(mut stream) => {
515                stream.shutdown().await?;
516                Ok(())
517            }
518            AgentConnection::Grpc(_) => Ok(()), // gRPC channels close automatically
519        }
520    }
521}