sentinel_agent_protocol/
client.rs

//! Agent client for communicating with external agents.
//!
//! Supports two transport mechanisms:
//! - Unix domain sockets (JSON messages framed by a 4-byte big-endian length prefix)
//! - gRPC (Protocol Buffers over HTTP/2)

7use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace, warn};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17    AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
18    RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent,
19    ResponseHeadersEvent, MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
20};
21
22/// Agent client for communicating with external agents
23pub struct AgentClient {
24    /// Agent ID
25    id: String,
26    /// Connection to agent
27    connection: AgentConnection,
28    /// Timeout for agent calls
29    timeout: Duration,
30    /// Maximum retries
31    #[allow(dead_code)]
32    max_retries: u32,
33}
34
35/// Agent connection type
36enum AgentConnection {
37    UnixSocket(UnixStream),
38    Grpc(AgentProcessorClient<Channel>),
39}
40
41impl AgentClient {
42    /// Create a new Unix socket agent client
43    pub async fn unix_socket(
44        id: impl Into<String>,
45        path: impl AsRef<std::path::Path>,
46        timeout: Duration,
47    ) -> Result<Self, AgentProtocolError> {
48        let id = id.into();
49        let path = path.as_ref();
50
51        trace!(
52            agent_id = %id,
53            socket_path = %path.display(),
54            timeout_ms = timeout.as_millis() as u64,
55            "Connecting to agent via Unix socket"
56        );
57
58        let stream = UnixStream::connect(path)
59            .await
60            .map_err(|e| {
61                error!(
62                    agent_id = %id,
63                    socket_path = %path.display(),
64                    error = %e,
65                    "Failed to connect to agent via Unix socket"
66                );
67                AgentProtocolError::ConnectionFailed(e.to_string())
68            })?;
69
70        debug!(
71            agent_id = %id,
72            socket_path = %path.display(),
73            "Connected to agent via Unix socket"
74        );
75
76        Ok(Self {
77            id,
78            connection: AgentConnection::UnixSocket(stream),
79            timeout,
80            max_retries: 3,
81        })
82    }
83
84    /// Create a new gRPC agent client
85    ///
86    /// # Arguments
87    /// * `id` - Agent identifier
88    /// * `address` - gRPC server address (e.g., "http://localhost:50051")
89    /// * `timeout` - Timeout for agent calls
90    pub async fn grpc(
91        id: impl Into<String>,
92        address: impl Into<String>,
93        timeout: Duration,
94    ) -> Result<Self, AgentProtocolError> {
95        let id = id.into();
96        let address = address.into();
97
98        trace!(
99            agent_id = %id,
100            address = %address,
101            timeout_ms = timeout.as_millis() as u64,
102            "Connecting to agent via gRPC"
103        );
104
105        let channel = Channel::from_shared(address.clone())
106            .map_err(|e| {
107                error!(
108                    agent_id = %id,
109                    address = %address,
110                    error = %e,
111                    "Invalid gRPC URI"
112                );
113                AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
114            })?
115            .timeout(timeout)
116            .connect()
117            .await
118            .map_err(|e| {
119                error!(
120                    agent_id = %id,
121                    address = %address,
122                    error = %e,
123                    "Failed to connect to agent via gRPC"
124                );
125                AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
126            })?;
127
128        let client = AgentProcessorClient::new(channel);
129
130        debug!(
131            agent_id = %id,
132            address = %address,
133            "Connected to agent via gRPC"
134        );
135
136        Ok(Self {
137            id,
138            connection: AgentConnection::Grpc(client),
139            timeout,
140            max_retries: 3,
141        })
142    }
143
144    /// Get the agent ID
145    #[allow(dead_code)]
146    pub fn id(&self) -> &str {
147        &self.id
148    }
149
150    /// Send an event to the agent and get a response
151    pub async fn send_event(
152        &mut self,
153        event_type: EventType,
154        payload: impl Serialize,
155    ) -> Result<AgentResponse, AgentProtocolError> {
156        match &mut self.connection {
157            AgentConnection::UnixSocket(_) => {
158                self.send_event_unix_socket(event_type, payload).await
159            }
160            AgentConnection::Grpc(_) => {
161                self.send_event_grpc(event_type, payload).await
162            }
163        }
164    }
165
166    /// Send event via Unix socket (length-prefixed JSON)
167    async fn send_event_unix_socket(
168        &mut self,
169        event_type: EventType,
170        payload: impl Serialize,
171    ) -> Result<AgentResponse, AgentProtocolError> {
172        let request = AgentRequest {
173            version: PROTOCOL_VERSION,
174            event_type,
175            payload: serde_json::to_value(payload)
176                .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
177        };
178
179        // Serialize request
180        let request_bytes = serde_json::to_vec(&request)
181            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
182
183        // Check message size
184        if request_bytes.len() > MAX_MESSAGE_SIZE {
185            return Err(AgentProtocolError::MessageTooLarge {
186                size: request_bytes.len(),
187                max: MAX_MESSAGE_SIZE,
188            });
189        }
190
191        // Send with timeout
192        let response = tokio::time::timeout(self.timeout, async {
193            self.send_raw_unix(&request_bytes).await?;
194            self.receive_raw_unix().await
195        })
196        .await
197        .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
198
199        // Parse response
200        let agent_response: AgentResponse = serde_json::from_slice(&response)
201            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
202
203        // Verify protocol version
204        if agent_response.version != PROTOCOL_VERSION {
205            return Err(AgentProtocolError::VersionMismatch {
206                expected: PROTOCOL_VERSION,
207                actual: agent_response.version,
208            });
209        }
210
211        Ok(agent_response)
212    }
213
214    /// Send event via gRPC
215    async fn send_event_grpc(
216        &mut self,
217        event_type: EventType,
218        payload: impl Serialize,
219    ) -> Result<AgentResponse, AgentProtocolError> {
220        // Build request first (doesn't need mutable borrow)
221        let grpc_request = Self::build_grpc_request(event_type, payload)?;
222
223        let AgentConnection::Grpc(client) = &mut self.connection else {
224            unreachable!()
225        };
226
227        // Send with timeout
228        let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
229            .await
230            .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
231            .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e)))?;
232
233        // Convert gRPC response to internal format
234        Self::convert_grpc_response(response.into_inner())
235    }
236
237    /// Build a gRPC request from internal types
238    fn build_grpc_request(
239        event_type: EventType,
240        payload: impl Serialize,
241    ) -> Result<grpc::AgentRequest, AgentProtocolError> {
242        let payload_json = serde_json::to_value(&payload)
243            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
244
245        let grpc_event_type = match event_type {
246            EventType::RequestHeaders => grpc::EventType::RequestHeaders,
247            EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
248            EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
249            EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
250            EventType::RequestComplete => grpc::EventType::RequestComplete,
251        };
252
253        let event = match event_type {
254            EventType::RequestHeaders => {
255                let event: RequestHeadersEvent = serde_json::from_value(payload_json)
256                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
257                grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
258                    metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
259                    method: event.method,
260                    uri: event.uri,
261                    headers: event.headers.into_iter().map(|(k, v)| {
262                        (k, grpc::HeaderValues { values: v })
263                    }).collect(),
264                })
265            }
266            EventType::RequestBodyChunk => {
267                let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
268                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
269                grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
270                    correlation_id: event.correlation_id,
271                    data: event.data.into_bytes(),
272                    is_last: event.is_last,
273                    total_size: event.total_size.map(|s| s as u64),
274                })
275            }
276            EventType::ResponseHeaders => {
277                let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
278                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
279                grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
280                    correlation_id: event.correlation_id,
281                    status: event.status as u32,
282                    headers: event.headers.into_iter().map(|(k, v)| {
283                        (k, grpc::HeaderValues { values: v })
284                    }).collect(),
285                })
286            }
287            EventType::ResponseBodyChunk => {
288                let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
289                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
290                grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
291                    correlation_id: event.correlation_id,
292                    data: event.data.into_bytes(),
293                    is_last: event.is_last,
294                    total_size: event.total_size.map(|s| s as u64),
295                })
296            }
297            EventType::RequestComplete => {
298                let event: RequestCompleteEvent = serde_json::from_value(payload_json)
299                    .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
300                grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
301                    correlation_id: event.correlation_id,
302                    status: event.status as u32,
303                    duration_ms: event.duration_ms,
304                    request_body_size: event.request_body_size as u64,
305                    response_body_size: event.response_body_size as u64,
306                    upstream_attempts: event.upstream_attempts,
307                    error: event.error,
308                })
309            }
310        };
311
312        Ok(grpc::AgentRequest {
313            version: PROTOCOL_VERSION,
314            event_type: grpc_event_type as i32,
315            event: Some(event),
316        })
317    }
318
319    /// Convert internal metadata to gRPC format
320    fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
321        grpc::RequestMetadata {
322            correlation_id: metadata.correlation_id.clone(),
323            request_id: metadata.request_id.clone(),
324            client_ip: metadata.client_ip.clone(),
325            client_port: metadata.client_port as u32,
326            server_name: metadata.server_name.clone(),
327            protocol: metadata.protocol.clone(),
328            tls_version: metadata.tls_version.clone(),
329            tls_cipher: metadata.tls_cipher.clone(),
330            route_id: metadata.route_id.clone(),
331            upstream_id: metadata.upstream_id.clone(),
332            timestamp: metadata.timestamp.clone(),
333        }
334    }
335
336    /// Convert gRPC response to internal format
337    fn convert_grpc_response(
338        response: grpc::AgentResponse,
339    ) -> Result<AgentResponse, AgentProtocolError> {
340        let decision = match response.decision {
341            Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
342            Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
343                status: b.status as u16,
344                body: b.body,
345                headers: if b.headers.is_empty() { None } else { Some(b.headers) },
346            },
347            Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
348                url: r.url,
349                status: r.status as u16,
350            },
351            Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
352                challenge_type: c.challenge_type,
353                params: c.params,
354            },
355            None => Decision::Allow, // Default to allow if no decision
356        };
357
358        let request_headers: Vec<HeaderOp> = response.request_headers
359            .into_iter()
360            .filter_map(Self::convert_header_op_from_grpc)
361            .collect();
362
363        let response_headers: Vec<HeaderOp> = response.response_headers
364            .into_iter()
365            .filter_map(Self::convert_header_op_from_grpc)
366            .collect();
367
368        let audit = response.audit.map(|a| AuditMetadata {
369            tags: a.tags,
370            rule_ids: a.rule_ids,
371            confidence: a.confidence,
372            reason_codes: a.reason_codes,
373            custom: a.custom.into_iter().map(|(k, v)| {
374                (k, serde_json::Value::String(v))
375            }).collect(),
376        });
377
378        Ok(AgentResponse {
379            version: response.version,
380            decision,
381            request_headers,
382            response_headers,
383            routing_metadata: response.routing_metadata,
384            audit: audit.unwrap_or_default(),
385        })
386    }
387
388    /// Convert gRPC header operation to internal format
389    fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
390        match op.operation? {
391            grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
392                name: s.name,
393                value: s.value,
394            }),
395            grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
396                name: a.name,
397                value: a.value,
398            }),
399            grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove {
400                name: r.name,
401            }),
402        }
403    }
404
405    /// Send raw bytes to agent (Unix socket only)
406    async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
407        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
408            unreachable!()
409        };
410        // Write message length (4 bytes, big-endian)
411        let len_bytes = (data.len() as u32).to_be_bytes();
412        stream.write_all(&len_bytes).await?;
413        // Write message data
414        stream.write_all(data).await?;
415        stream.flush().await?;
416        Ok(())
417    }
418
419    /// Receive raw bytes from agent (Unix socket only)
420    async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
421        let AgentConnection::UnixSocket(stream) = &mut self.connection else {
422            unreachable!()
423        };
424        // Read message length (4 bytes, big-endian)
425        let mut len_bytes = [0u8; 4];
426        stream.read_exact(&mut len_bytes).await?;
427        let message_len = u32::from_be_bytes(len_bytes) as usize;
428
429        // Check message size
430        if message_len > MAX_MESSAGE_SIZE {
431            return Err(AgentProtocolError::MessageTooLarge {
432                size: message_len,
433                max: MAX_MESSAGE_SIZE,
434            });
435        }
436
437        // Read message data
438        let mut buffer = vec![0u8; message_len];
439        stream.read_exact(&mut buffer).await?;
440        Ok(buffer)
441    }
442
443    /// Close the agent connection
444    pub async fn close(self) -> Result<(), AgentProtocolError> {
445        match self.connection {
446            AgentConnection::UnixSocket(mut stream) => {
447                stream.shutdown().await?;
448                Ok(())
449            }
450            AgentConnection::Grpc(_) => Ok(()), // gRPC channels close automatically
451        }
452    }
453}