1use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace, warn};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17 AgentRequest, AgentResponse, AuditMetadata, Decision, EventType, HeaderOp, RequestBodyChunkEvent,
18 RequestCompleteEvent, RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent,
19 ResponseHeadersEvent, MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
20};
21
22pub struct AgentClient {
24 id: String,
26 connection: AgentConnection,
28 timeout: Duration,
30 #[allow(dead_code)]
32 max_retries: u32,
33}
34
35enum AgentConnection {
37 UnixSocket(UnixStream),
38 Grpc(AgentProcessorClient<Channel>),
39}
40
41impl AgentClient {
42 pub async fn unix_socket(
44 id: impl Into<String>,
45 path: impl AsRef<std::path::Path>,
46 timeout: Duration,
47 ) -> Result<Self, AgentProtocolError> {
48 let id = id.into();
49 let path = path.as_ref();
50
51 trace!(
52 agent_id = %id,
53 socket_path = %path.display(),
54 timeout_ms = timeout.as_millis() as u64,
55 "Connecting to agent via Unix socket"
56 );
57
58 let stream = UnixStream::connect(path)
59 .await
60 .map_err(|e| {
61 error!(
62 agent_id = %id,
63 socket_path = %path.display(),
64 error = %e,
65 "Failed to connect to agent via Unix socket"
66 );
67 AgentProtocolError::ConnectionFailed(e.to_string())
68 })?;
69
70 debug!(
71 agent_id = %id,
72 socket_path = %path.display(),
73 "Connected to agent via Unix socket"
74 );
75
76 Ok(Self {
77 id,
78 connection: AgentConnection::UnixSocket(stream),
79 timeout,
80 max_retries: 3,
81 })
82 }
83
84 pub async fn grpc(
91 id: impl Into<String>,
92 address: impl Into<String>,
93 timeout: Duration,
94 ) -> Result<Self, AgentProtocolError> {
95 let id = id.into();
96 let address = address.into();
97
98 trace!(
99 agent_id = %id,
100 address = %address,
101 timeout_ms = timeout.as_millis() as u64,
102 "Connecting to agent via gRPC"
103 );
104
105 let channel = Channel::from_shared(address.clone())
106 .map_err(|e| {
107 error!(
108 agent_id = %id,
109 address = %address,
110 error = %e,
111 "Invalid gRPC URI"
112 );
113 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
114 })?
115 .timeout(timeout)
116 .connect()
117 .await
118 .map_err(|e| {
119 error!(
120 agent_id = %id,
121 address = %address,
122 error = %e,
123 "Failed to connect to agent via gRPC"
124 );
125 AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
126 })?;
127
128 let client = AgentProcessorClient::new(channel);
129
130 debug!(
131 agent_id = %id,
132 address = %address,
133 "Connected to agent via gRPC"
134 );
135
136 Ok(Self {
137 id,
138 connection: AgentConnection::Grpc(client),
139 timeout,
140 max_retries: 3,
141 })
142 }
143
144 #[allow(dead_code)]
146 pub fn id(&self) -> &str {
147 &self.id
148 }
149
150 pub async fn send_event(
152 &mut self,
153 event_type: EventType,
154 payload: impl Serialize,
155 ) -> Result<AgentResponse, AgentProtocolError> {
156 match &mut self.connection {
157 AgentConnection::UnixSocket(_) => {
158 self.send_event_unix_socket(event_type, payload).await
159 }
160 AgentConnection::Grpc(_) => {
161 self.send_event_grpc(event_type, payload).await
162 }
163 }
164 }
165
166 async fn send_event_unix_socket(
168 &mut self,
169 event_type: EventType,
170 payload: impl Serialize,
171 ) -> Result<AgentResponse, AgentProtocolError> {
172 let request = AgentRequest {
173 version: PROTOCOL_VERSION,
174 event_type,
175 payload: serde_json::to_value(payload)
176 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
177 };
178
179 let request_bytes = serde_json::to_vec(&request)
181 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
182
183 if request_bytes.len() > MAX_MESSAGE_SIZE {
185 return Err(AgentProtocolError::MessageTooLarge {
186 size: request_bytes.len(),
187 max: MAX_MESSAGE_SIZE,
188 });
189 }
190
191 let response = tokio::time::timeout(self.timeout, async {
193 self.send_raw_unix(&request_bytes).await?;
194 self.receive_raw_unix().await
195 })
196 .await
197 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
198
199 let agent_response: AgentResponse = serde_json::from_slice(&response)
201 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
202
203 if agent_response.version != PROTOCOL_VERSION {
205 return Err(AgentProtocolError::VersionMismatch {
206 expected: PROTOCOL_VERSION,
207 actual: agent_response.version,
208 });
209 }
210
211 Ok(agent_response)
212 }
213
214 async fn send_event_grpc(
216 &mut self,
217 event_type: EventType,
218 payload: impl Serialize,
219 ) -> Result<AgentResponse, AgentProtocolError> {
220 let grpc_request = Self::build_grpc_request(event_type, payload)?;
222
223 let AgentConnection::Grpc(client) = &mut self.connection else {
224 unreachable!()
225 };
226
227 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
229 .await
230 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
231 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e)))?;
232
233 Self::convert_grpc_response(response.into_inner())
235 }
236
237 fn build_grpc_request(
239 event_type: EventType,
240 payload: impl Serialize,
241 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
242 let payload_json = serde_json::to_value(&payload)
243 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
244
245 let grpc_event_type = match event_type {
246 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
247 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
248 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
249 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
250 EventType::RequestComplete => grpc::EventType::RequestComplete,
251 };
252
253 let event = match event_type {
254 EventType::RequestHeaders => {
255 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
256 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
257 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
258 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
259 method: event.method,
260 uri: event.uri,
261 headers: event.headers.into_iter().map(|(k, v)| {
262 (k, grpc::HeaderValues { values: v })
263 }).collect(),
264 })
265 }
266 EventType::RequestBodyChunk => {
267 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
268 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
269 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
270 correlation_id: event.correlation_id,
271 data: event.data.into_bytes(),
272 is_last: event.is_last,
273 total_size: event.total_size.map(|s| s as u64),
274 })
275 }
276 EventType::ResponseHeaders => {
277 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
278 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
279 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
280 correlation_id: event.correlation_id,
281 status: event.status as u32,
282 headers: event.headers.into_iter().map(|(k, v)| {
283 (k, grpc::HeaderValues { values: v })
284 }).collect(),
285 })
286 }
287 EventType::ResponseBodyChunk => {
288 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
289 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
290 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
291 correlation_id: event.correlation_id,
292 data: event.data.into_bytes(),
293 is_last: event.is_last,
294 total_size: event.total_size.map(|s| s as u64),
295 })
296 }
297 EventType::RequestComplete => {
298 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
299 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
300 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
301 correlation_id: event.correlation_id,
302 status: event.status as u32,
303 duration_ms: event.duration_ms,
304 request_body_size: event.request_body_size as u64,
305 response_body_size: event.response_body_size as u64,
306 upstream_attempts: event.upstream_attempts,
307 error: event.error,
308 })
309 }
310 };
311
312 Ok(grpc::AgentRequest {
313 version: PROTOCOL_VERSION,
314 event_type: grpc_event_type as i32,
315 event: Some(event),
316 })
317 }
318
319 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
321 grpc::RequestMetadata {
322 correlation_id: metadata.correlation_id.clone(),
323 request_id: metadata.request_id.clone(),
324 client_ip: metadata.client_ip.clone(),
325 client_port: metadata.client_port as u32,
326 server_name: metadata.server_name.clone(),
327 protocol: metadata.protocol.clone(),
328 tls_version: metadata.tls_version.clone(),
329 tls_cipher: metadata.tls_cipher.clone(),
330 route_id: metadata.route_id.clone(),
331 upstream_id: metadata.upstream_id.clone(),
332 timestamp: metadata.timestamp.clone(),
333 }
334 }
335
336 fn convert_grpc_response(
338 response: grpc::AgentResponse,
339 ) -> Result<AgentResponse, AgentProtocolError> {
340 let decision = match response.decision {
341 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
342 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
343 status: b.status as u16,
344 body: b.body,
345 headers: if b.headers.is_empty() { None } else { Some(b.headers) },
346 },
347 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
348 url: r.url,
349 status: r.status as u16,
350 },
351 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
352 challenge_type: c.challenge_type,
353 params: c.params,
354 },
355 None => Decision::Allow, };
357
358 let request_headers: Vec<HeaderOp> = response.request_headers
359 .into_iter()
360 .filter_map(Self::convert_header_op_from_grpc)
361 .collect();
362
363 let response_headers: Vec<HeaderOp> = response.response_headers
364 .into_iter()
365 .filter_map(Self::convert_header_op_from_grpc)
366 .collect();
367
368 let audit = response.audit.map(|a| AuditMetadata {
369 tags: a.tags,
370 rule_ids: a.rule_ids,
371 confidence: a.confidence,
372 reason_codes: a.reason_codes,
373 custom: a.custom.into_iter().map(|(k, v)| {
374 (k, serde_json::Value::String(v))
375 }).collect(),
376 });
377
378 Ok(AgentResponse {
379 version: response.version,
380 decision,
381 request_headers,
382 response_headers,
383 routing_metadata: response.routing_metadata,
384 audit: audit.unwrap_or_default(),
385 })
386 }
387
388 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
390 match op.operation? {
391 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
392 name: s.name,
393 value: s.value,
394 }),
395 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
396 name: a.name,
397 value: a.value,
398 }),
399 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove {
400 name: r.name,
401 }),
402 }
403 }
404
405 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
407 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
408 unreachable!()
409 };
410 let len_bytes = (data.len() as u32).to_be_bytes();
412 stream.write_all(&len_bytes).await?;
413 stream.write_all(data).await?;
415 stream.flush().await?;
416 Ok(())
417 }
418
419 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
421 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
422 unreachable!()
423 };
424 let mut len_bytes = [0u8; 4];
426 stream.read_exact(&mut len_bytes).await?;
427 let message_len = u32::from_be_bytes(len_bytes) as usize;
428
429 if message_len > MAX_MESSAGE_SIZE {
431 return Err(AgentProtocolError::MessageTooLarge {
432 size: message_len,
433 max: MAX_MESSAGE_SIZE,
434 });
435 }
436
437 let mut buffer = vec![0u8; message_len];
439 stream.read_exact(&mut buffer).await?;
440 Ok(buffer)
441 }
442
443 pub async fn close(self) -> Result<(), AgentProtocolError> {
445 match self.connection {
446 AgentConnection::UnixSocket(mut stream) => {
447 stream.shutdown().await?;
448 Ok(())
449 }
450 AgentConnection::Grpc(_) => Ok(()), }
452 }
453}