1use serde::Serialize;
8use std::time::Duration;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::UnixStream;
11use tonic::transport::Channel;
12use tracing::{debug, error, trace};
13
14use crate::errors::AgentProtocolError;
15use crate::grpc::{self, agent_processor_client::AgentProcessorClient};
16use crate::protocol::{
17 AgentRequest, AgentResponse, AuditMetadata, BodyMutation, Decision, EventType, HeaderOp,
18 RequestBodyChunkEvent, RequestCompleteEvent, RequestHeadersEvent, RequestMetadata,
19 ResponseBodyChunkEvent, ResponseHeadersEvent, WebSocketDecision, WebSocketFrameEvent,
20 MAX_MESSAGE_SIZE, PROTOCOL_VERSION,
21};
22
23pub struct AgentClient {
25 id: String,
27 connection: AgentConnection,
29 timeout: Duration,
31 #[allow(dead_code)]
33 max_retries: u32,
34}
35
36enum AgentConnection {
38 UnixSocket(UnixStream),
39 Grpc(AgentProcessorClient<Channel>),
40}
41
42impl AgentClient {
43 pub async fn unix_socket(
45 id: impl Into<String>,
46 path: impl AsRef<std::path::Path>,
47 timeout: Duration,
48 ) -> Result<Self, AgentProtocolError> {
49 let id = id.into();
50 let path = path.as_ref();
51
52 trace!(
53 agent_id = %id,
54 socket_path = %path.display(),
55 timeout_ms = timeout.as_millis() as u64,
56 "Connecting to agent via Unix socket"
57 );
58
59 let stream = UnixStream::connect(path).await.map_err(|e| {
60 error!(
61 agent_id = %id,
62 socket_path = %path.display(),
63 error = %e,
64 "Failed to connect to agent via Unix socket"
65 );
66 AgentProtocolError::ConnectionFailed(e.to_string())
67 })?;
68
69 debug!(
70 agent_id = %id,
71 socket_path = %path.display(),
72 "Connected to agent via Unix socket"
73 );
74
75 Ok(Self {
76 id,
77 connection: AgentConnection::UnixSocket(stream),
78 timeout,
79 max_retries: 3,
80 })
81 }
82
83 pub async fn grpc(
90 id: impl Into<String>,
91 address: impl Into<String>,
92 timeout: Duration,
93 ) -> Result<Self, AgentProtocolError> {
94 let id = id.into();
95 let address = address.into();
96
97 trace!(
98 agent_id = %id,
99 address = %address,
100 timeout_ms = timeout.as_millis() as u64,
101 "Connecting to agent via gRPC"
102 );
103
104 let channel = Channel::from_shared(address.clone())
105 .map_err(|e| {
106 error!(
107 agent_id = %id,
108 address = %address,
109 error = %e,
110 "Invalid gRPC URI"
111 );
112 AgentProtocolError::ConnectionFailed(format!("Invalid URI: {}", e))
113 })?
114 .timeout(timeout)
115 .connect()
116 .await
117 .map_err(|e| {
118 error!(
119 agent_id = %id,
120 address = %address,
121 error = %e,
122 "Failed to connect to agent via gRPC"
123 );
124 AgentProtocolError::ConnectionFailed(format!("gRPC connect failed: {}", e))
125 })?;
126
127 let client = AgentProcessorClient::new(channel);
128
129 debug!(
130 agent_id = %id,
131 address = %address,
132 "Connected to agent via gRPC"
133 );
134
135 Ok(Self {
136 id,
137 connection: AgentConnection::Grpc(client),
138 timeout,
139 max_retries: 3,
140 })
141 }
142
143 #[allow(dead_code)]
145 pub fn id(&self) -> &str {
146 &self.id
147 }
148
149 pub async fn send_event(
151 &mut self,
152 event_type: EventType,
153 payload: impl Serialize,
154 ) -> Result<AgentResponse, AgentProtocolError> {
155 match &mut self.connection {
156 AgentConnection::UnixSocket(_) => {
157 self.send_event_unix_socket(event_type, payload).await
158 }
159 AgentConnection::Grpc(_) => self.send_event_grpc(event_type, payload).await,
160 }
161 }
162
163 async fn send_event_unix_socket(
165 &mut self,
166 event_type: EventType,
167 payload: impl Serialize,
168 ) -> Result<AgentResponse, AgentProtocolError> {
169 let request = AgentRequest {
170 version: PROTOCOL_VERSION,
171 event_type,
172 payload: serde_json::to_value(payload)
173 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?,
174 };
175
176 let request_bytes = serde_json::to_vec(&request)
178 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
179
180 if request_bytes.len() > MAX_MESSAGE_SIZE {
182 return Err(AgentProtocolError::MessageTooLarge {
183 size: request_bytes.len(),
184 max: MAX_MESSAGE_SIZE,
185 });
186 }
187
188 let response = tokio::time::timeout(self.timeout, async {
190 self.send_raw_unix(&request_bytes).await?;
191 self.receive_raw_unix().await
192 })
193 .await
194 .map_err(|_| AgentProtocolError::Timeout(self.timeout))??;
195
196 let agent_response: AgentResponse = serde_json::from_slice(&response)
198 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))?;
199
200 if agent_response.version != PROTOCOL_VERSION {
202 return Err(AgentProtocolError::VersionMismatch {
203 expected: PROTOCOL_VERSION,
204 actual: agent_response.version,
205 });
206 }
207
208 Ok(agent_response)
209 }
210
211 async fn send_event_grpc(
213 &mut self,
214 event_type: EventType,
215 payload: impl Serialize,
216 ) -> Result<AgentResponse, AgentProtocolError> {
217 let grpc_request = Self::build_grpc_request(event_type, payload)?;
219
220 let AgentConnection::Grpc(client) = &mut self.connection else {
221 unreachable!()
222 };
223
224 let response = tokio::time::timeout(self.timeout, client.process_event(grpc_request))
226 .await
227 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
228 .map_err(|e| {
229 AgentProtocolError::ConnectionFailed(format!("gRPC call failed: {}", e))
230 })?;
231
232 Self::convert_grpc_response(response.into_inner())
234 }
235
236 fn build_grpc_request(
238 event_type: EventType,
239 payload: impl Serialize,
240 ) -> Result<grpc::AgentRequest, AgentProtocolError> {
241 let payload_json = serde_json::to_value(&payload)
242 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
243
244 let grpc_event_type = match event_type {
245 EventType::Configure => {
246 return Err(AgentProtocolError::Serialization(
247 "Configure events are not supported via gRPC".to_string(),
248 ))
249 }
250 EventType::RequestHeaders => grpc::EventType::RequestHeaders,
251 EventType::RequestBodyChunk => grpc::EventType::RequestBodyChunk,
252 EventType::ResponseHeaders => grpc::EventType::ResponseHeaders,
253 EventType::ResponseBodyChunk => grpc::EventType::ResponseBodyChunk,
254 EventType::RequestComplete => grpc::EventType::RequestComplete,
255 EventType::WebSocketFrame => grpc::EventType::WebsocketFrame,
256 };
257
258 let event = match event_type {
259 EventType::Configure => unreachable!("Configure handled above"),
260 EventType::RequestHeaders => {
261 let event: RequestHeadersEvent = serde_json::from_value(payload_json)
262 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
263 grpc::agent_request::Event::RequestHeaders(grpc::RequestHeadersEvent {
264 metadata: Some(Self::convert_metadata_to_grpc(&event.metadata)),
265 method: event.method,
266 uri: event.uri,
267 headers: event
268 .headers
269 .into_iter()
270 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
271 .collect(),
272 })
273 }
274 EventType::RequestBodyChunk => {
275 let event: RequestBodyChunkEvent = serde_json::from_value(payload_json)
276 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
277 grpc::agent_request::Event::RequestBodyChunk(grpc::RequestBodyChunkEvent {
278 correlation_id: event.correlation_id,
279 data: event.data.into_bytes(),
280 is_last: event.is_last,
281 total_size: event.total_size.map(|s| s as u64),
282 chunk_index: event.chunk_index,
283 bytes_received: event.bytes_received as u64,
284 })
285 }
286 EventType::ResponseHeaders => {
287 let event: ResponseHeadersEvent = serde_json::from_value(payload_json)
288 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
289 grpc::agent_request::Event::ResponseHeaders(grpc::ResponseHeadersEvent {
290 correlation_id: event.correlation_id,
291 status: event.status as u32,
292 headers: event
293 .headers
294 .into_iter()
295 .map(|(k, v)| (k, grpc::HeaderValues { values: v }))
296 .collect(),
297 })
298 }
299 EventType::ResponseBodyChunk => {
300 let event: ResponseBodyChunkEvent = serde_json::from_value(payload_json)
301 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
302 grpc::agent_request::Event::ResponseBodyChunk(grpc::ResponseBodyChunkEvent {
303 correlation_id: event.correlation_id,
304 data: event.data.into_bytes(),
305 is_last: event.is_last,
306 total_size: event.total_size.map(|s| s as u64),
307 chunk_index: event.chunk_index,
308 bytes_sent: event.bytes_sent as u64,
309 })
310 }
311 EventType::RequestComplete => {
312 let event: RequestCompleteEvent = serde_json::from_value(payload_json)
313 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
314 grpc::agent_request::Event::RequestComplete(grpc::RequestCompleteEvent {
315 correlation_id: event.correlation_id,
316 status: event.status as u32,
317 duration_ms: event.duration_ms,
318 request_body_size: event.request_body_size as u64,
319 response_body_size: event.response_body_size as u64,
320 upstream_attempts: event.upstream_attempts,
321 error: event.error,
322 })
323 }
324 EventType::WebSocketFrame => {
325 use base64::{engine::general_purpose::STANDARD, Engine as _};
326 let event: WebSocketFrameEvent = serde_json::from_value(payload_json)
327 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
328 grpc::agent_request::Event::WebsocketFrame(grpc::WebSocketFrameEvent {
329 correlation_id: event.correlation_id,
330 opcode: event.opcode,
331 data: STANDARD.decode(&event.data).unwrap_or_default(),
332 client_to_server: event.client_to_server,
333 frame_index: event.frame_index,
334 fin: event.fin,
335 route_id: event.route_id,
336 client_ip: event.client_ip,
337 })
338 }
339 };
340
341 Ok(grpc::AgentRequest {
342 version: PROTOCOL_VERSION,
343 event_type: grpc_event_type as i32,
344 event: Some(event),
345 })
346 }
347
348 fn convert_metadata_to_grpc(metadata: &RequestMetadata) -> grpc::RequestMetadata {
350 grpc::RequestMetadata {
351 correlation_id: metadata.correlation_id.clone(),
352 request_id: metadata.request_id.clone(),
353 client_ip: metadata.client_ip.clone(),
354 client_port: metadata.client_port as u32,
355 server_name: metadata.server_name.clone(),
356 protocol: metadata.protocol.clone(),
357 tls_version: metadata.tls_version.clone(),
358 tls_cipher: metadata.tls_cipher.clone(),
359 route_id: metadata.route_id.clone(),
360 upstream_id: metadata.upstream_id.clone(),
361 timestamp: metadata.timestamp.clone(),
362 }
363 }
364
365 fn convert_grpc_response(
367 response: grpc::AgentResponse,
368 ) -> Result<AgentResponse, AgentProtocolError> {
369 let decision = match response.decision {
370 Some(grpc::agent_response::Decision::Allow(_)) => Decision::Allow,
371 Some(grpc::agent_response::Decision::Block(b)) => Decision::Block {
372 status: b.status as u16,
373 body: b.body,
374 headers: if b.headers.is_empty() {
375 None
376 } else {
377 Some(b.headers)
378 },
379 },
380 Some(grpc::agent_response::Decision::Redirect(r)) => Decision::Redirect {
381 url: r.url,
382 status: r.status as u16,
383 },
384 Some(grpc::agent_response::Decision::Challenge(c)) => Decision::Challenge {
385 challenge_type: c.challenge_type,
386 params: c.params,
387 },
388 None => Decision::Allow, };
390
391 let request_headers: Vec<HeaderOp> = response
392 .request_headers
393 .into_iter()
394 .filter_map(Self::convert_header_op_from_grpc)
395 .collect();
396
397 let response_headers: Vec<HeaderOp> = response
398 .response_headers
399 .into_iter()
400 .filter_map(Self::convert_header_op_from_grpc)
401 .collect();
402
403 let audit = response.audit.map(|a| AuditMetadata {
404 tags: a.tags,
405 rule_ids: a.rule_ids,
406 confidence: a.confidence,
407 reason_codes: a.reason_codes,
408 custom: a
409 .custom
410 .into_iter()
411 .map(|(k, v)| (k, serde_json::Value::String(v)))
412 .collect(),
413 });
414
415 let request_body_mutation = response.request_body_mutation.map(|m| BodyMutation {
417 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
418 chunk_index: m.chunk_index,
419 });
420
421 let response_body_mutation = response.response_body_mutation.map(|m| BodyMutation {
422 data: m.data.map(|d| String::from_utf8_lossy(&d).to_string()),
423 chunk_index: m.chunk_index,
424 });
425
426 let websocket_decision = response
428 .websocket_decision
429 .map(|ws_decision| match ws_decision {
430 grpc::agent_response::WebsocketDecision::WebsocketAllow(_) => {
431 WebSocketDecision::Allow
432 }
433 grpc::agent_response::WebsocketDecision::WebsocketDrop(_) => {
434 WebSocketDecision::Drop
435 }
436 grpc::agent_response::WebsocketDecision::WebsocketClose(c) => {
437 WebSocketDecision::Close {
438 code: c.code as u16,
439 reason: c.reason,
440 }
441 }
442 });
443
444 Ok(AgentResponse {
445 version: response.version,
446 decision,
447 request_headers,
448 response_headers,
449 routing_metadata: response.routing_metadata,
450 audit: audit.unwrap_or_default(),
451 needs_more: response.needs_more,
452 request_body_mutation,
453 response_body_mutation,
454 websocket_decision,
455 })
456 }
457
458 fn convert_header_op_from_grpc(op: grpc::HeaderOp) -> Option<HeaderOp> {
460 match op.operation? {
461 grpc::header_op::Operation::Set(s) => Some(HeaderOp::Set {
462 name: s.name,
463 value: s.value,
464 }),
465 grpc::header_op::Operation::Add(a) => Some(HeaderOp::Add {
466 name: a.name,
467 value: a.value,
468 }),
469 grpc::header_op::Operation::Remove(r) => Some(HeaderOp::Remove { name: r.name }),
470 }
471 }
472
473 async fn send_raw_unix(&mut self, data: &[u8]) -> Result<(), AgentProtocolError> {
475 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
476 unreachable!()
477 };
478 let len_bytes = (data.len() as u32).to_be_bytes();
480 stream.write_all(&len_bytes).await?;
481 stream.write_all(data).await?;
483 stream.flush().await?;
484 Ok(())
485 }
486
487 async fn receive_raw_unix(&mut self) -> Result<Vec<u8>, AgentProtocolError> {
489 let AgentConnection::UnixSocket(stream) = &mut self.connection else {
490 unreachable!()
491 };
492 let mut len_bytes = [0u8; 4];
494 stream.read_exact(&mut len_bytes).await?;
495 let message_len = u32::from_be_bytes(len_bytes) as usize;
496
497 if message_len > MAX_MESSAGE_SIZE {
499 return Err(AgentProtocolError::MessageTooLarge {
500 size: message_len,
501 max: MAX_MESSAGE_SIZE,
502 });
503 }
504
505 let mut buffer = vec![0u8; message_len];
507 stream.read_exact(&mut buffer).await?;
508 Ok(buffer)
509 }
510
511 pub async fn close(self) -> Result<(), AgentProtocolError> {
513 match self.connection {
514 AgentConnection::UnixSocket(mut stream) => {
515 stream.shutdown().await?;
516 Ok(())
517 }
518 AgentConnection::Grpc(_) => Ok(()), }
520 }
521}