1use async_trait::async_trait;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::mpsc;
11use tokio_stream::{wrappers::ReceiverStream, Stream, StreamExt};
12use tonic::{Request, Response, Status, Streaming};
13use tracing::{debug, error, info, trace, warn};
14
15use crate::grpc_v2::{
16 self, agent_service_v2_server::AgentServiceV2, agent_service_v2_server::AgentServiceV2Server,
17 AgentToProxy, ProxyToAgent,
18};
19use crate::v2::pool::CHANNEL_BUFFER_SIZE;
20use crate::v2::{AgentCapabilities, HandshakeRequest, HandshakeResponse, HealthStatus};
21use crate::{
22 AgentResponse, Decision, EventType, HeaderOp, RequestBodyChunkEvent, RequestCompleteEvent,
23 RequestHeadersEvent, RequestMetadata, ResponseBodyChunkEvent, ResponseHeadersEvent,
24 WebSocketFrameEvent,
25};
26
27#[async_trait]
36pub trait AgentHandlerV2: Send + Sync {
37 fn capabilities(&self) -> AgentCapabilities;
39
40 async fn on_handshake(&self, _request: HandshakeRequest) -> HandshakeResponse {
42 HandshakeResponse::success(self.capabilities())
44 }
45
46 async fn on_request_headers(&self, _event: RequestHeadersEvent) -> AgentResponse {
48 AgentResponse::default_allow()
49 }
50
51 async fn on_request_body_chunk(&self, _event: RequestBodyChunkEvent) -> AgentResponse {
53 AgentResponse::default_allow()
54 }
55
56 async fn on_response_headers(&self, _event: ResponseHeadersEvent) -> AgentResponse {
58 AgentResponse::default_allow()
59 }
60
61 async fn on_response_body_chunk(&self, _event: ResponseBodyChunkEvent) -> AgentResponse {
63 AgentResponse::default_allow()
64 }
65
66 async fn on_request_complete(&self, _event: RequestCompleteEvent) -> AgentResponse {
68 AgentResponse::default_allow()
69 }
70
71 async fn on_websocket_frame(&self, _event: WebSocketFrameEvent) -> AgentResponse {
73 AgentResponse::websocket_allow()
74 }
75
76 fn health_status(&self) -> HealthStatus {
78 HealthStatus::healthy(self.capabilities().agent_id.clone())
79 }
80
81 fn metrics_report(&self) -> Option<crate::v2::MetricsReport> {
83 None
84 }
85
86 async fn on_configure(&self, _config: serde_json::Value, _version: Option<String>) -> bool {
88 true
90 }
91
92 async fn on_shutdown(&self, _reason: ShutdownReason, _grace_period_ms: u64) {
94 }
96
97 async fn on_drain(&self, _duration_ms: u64, _reason: DrainReason) {
99 }
101
102 async fn on_stream_closed(&self) {}
104}
105
/// Reason passed to [`AgentHandlerV2::on_shutdown`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    /// Orderly shutdown.
    Graceful,
    /// Shutdown without waiting.
    Immediate,
    /// Shutdown triggered by a configuration reload.
    ConfigReload,
    /// Shutdown triggered by an upgrade.
    Upgrade,
}
114
/// Reason passed to [`AgentHandlerV2::on_drain`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainReason {
    /// Draining for a configuration reload.
    ConfigReload,
    /// Draining for maintenance.
    Maintenance,
    /// Draining after a failed health check.
    HealthCheckFailed,
    /// Draining on explicit request.
    Manual,
}
123
124pub struct GrpcAgentServerV2 {
126 id: String,
127 handler: Arc<dyn AgentHandlerV2>,
128}
129
130impl GrpcAgentServerV2 {
131 pub fn new(id: impl Into<String>, handler: Box<dyn AgentHandlerV2>) -> Self {
133 let id = id.into();
134 debug!(agent_id = %id, "Creating gRPC agent server v2");
135 Self {
136 id,
137 handler: Arc::from(handler),
138 }
139 }
140
141 pub fn into_service(self) -> AgentServiceV2Server<GrpcAgentHandlerV2> {
143 trace!(agent_id = %self.id, "Converting to tonic v2 service");
144 AgentServiceV2Server::new(GrpcAgentHandlerV2 {
145 id: self.id,
146 handler: self.handler,
147 })
148 }
149
150 pub async fn run(self, addr: std::net::SocketAddr) -> Result<(), crate::AgentProtocolError> {
152 info!(
153 agent_id = %self.id,
154 address = %addr,
155 "gRPC agent server v2 listening"
156 );
157
158 tonic::transport::Server::builder()
159 .add_service(self.into_service())
160 .serve(addr)
161 .await
162 .map_err(|e| {
163 error!(error = %e, "gRPC v2 server error");
164 crate::AgentProtocolError::ConnectionFailed(format!("gRPC v2 server error: {}", e))
165 })
166 }
167}
168
169pub struct GrpcAgentHandlerV2 {
171 id: String,
172 handler: Arc<dyn AgentHandlerV2>,
173}
174
175type ProcessResponseStream = Pin<Box<dyn Stream<Item = Result<AgentToProxy, Status>> + Send>>;
176type ControlResponseStream =
177 Pin<Box<dyn Stream<Item = Result<grpc_v2::ProxyControl, Status>> + Send>>;
178
179#[tonic::async_trait]
180impl AgentServiceV2 for GrpcAgentHandlerV2 {
181 type ProcessStreamStream = ProcessResponseStream;
182 type ControlStreamStream = ControlResponseStream;
183
184 async fn process_stream(
186 &self,
187 request: Request<Streaming<ProxyToAgent>>,
188 ) -> Result<Response<Self::ProcessStreamStream>, Status> {
189 let mut inbound = request.into_inner();
190 let (tx, rx) = mpsc::channel(CHANNEL_BUFFER_SIZE);
191 let handler = Arc::clone(&self.handler);
192 let agent_id = self.id.clone();
193
194 debug!(agent_id = %agent_id, "Starting v2 process stream");
195
196 tokio::spawn(async move {
197 let mut handshake_done = false;
198
199 while let Some(result) = inbound.next().await {
200 let msg = match result {
201 Ok(m) => m,
202 Err(e) => {
203 error!(agent_id = %agent_id, error = %e, "Stream error");
204 break;
205 }
206 };
207
208 let response = match msg.message {
209 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
210 trace!(agent_id = %agent_id, "Processing handshake");
211 let handshake_req = convert_handshake_request(req);
212 let resp = handler.on_handshake(handshake_req).await;
213 handshake_done = resp.success;
214 Some(AgentToProxy {
215 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
216 convert_handshake_response(resp),
217 )),
218 })
219 }
220 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
221 if !handshake_done {
222 warn!(agent_id = %agent_id, "Received event before handshake");
223 continue;
224 }
225 let event = convert_request_headers_from_grpc(e);
226 let correlation_id = event.metadata.correlation_id.clone();
227 let start = Instant::now();
228 let resp = handler.on_request_headers(event).await;
229 let processing_time_ms = start.elapsed().as_millis() as u64;
230 Some(create_agent_response(
231 correlation_id,
232 resp,
233 processing_time_ms,
234 ))
235 }
236 Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(e)) => {
237 if !handshake_done {
238 continue;
239 }
240 let event = convert_body_chunk_to_request(e);
241 let correlation_id = event.correlation_id.clone();
242 let start = Instant::now();
243 let resp = handler.on_request_body_chunk(event).await;
244 let processing_time_ms = start.elapsed().as_millis() as u64;
245 Some(create_agent_response(
246 correlation_id,
247 resp,
248 processing_time_ms,
249 ))
250 }
251 Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(e)) => {
252 if !handshake_done {
253 continue;
254 }
255 let event = convert_response_headers_from_grpc(e);
256 let correlation_id = event.correlation_id.clone();
257 let start = Instant::now();
258 let resp = handler.on_response_headers(event).await;
259 let processing_time_ms = start.elapsed().as_millis() as u64;
260 Some(create_agent_response(
261 correlation_id,
262 resp,
263 processing_time_ms,
264 ))
265 }
266 Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(e)) => {
267 if !handshake_done {
268 continue;
269 }
270 let event = convert_body_chunk_to_response(e);
271 let correlation_id = event.correlation_id.clone();
272 let start = Instant::now();
273 let resp = handler.on_response_body_chunk(event).await;
274 let processing_time_ms = start.elapsed().as_millis() as u64;
275 Some(create_agent_response(
276 correlation_id,
277 resp,
278 processing_time_ms,
279 ))
280 }
281 Some(grpc_v2::proxy_to_agent::Message::RequestComplete(e)) => {
282 if !handshake_done {
283 continue;
284 }
285 let event = convert_request_complete_from_grpc(e);
286 let correlation_id = event.correlation_id.clone();
287 let start = Instant::now();
288 let resp = handler.on_request_complete(event).await;
289 let processing_time_ms = start.elapsed().as_millis() as u64;
290 Some(create_agent_response(
291 correlation_id,
292 resp,
293 processing_time_ms,
294 ))
295 }
296 Some(grpc_v2::proxy_to_agent::Message::WebsocketFrame(e)) => {
297 if !handshake_done {
298 continue;
299 }
300 let event = convert_websocket_frame_from_grpc(e);
301 let correlation_id = event.correlation_id.clone();
302 let start = Instant::now();
303 let resp = handler.on_websocket_frame(event).await;
304 let processing_time_ms = start.elapsed().as_millis() as u64;
305 Some(create_agent_response(
306 correlation_id,
307 resp,
308 processing_time_ms,
309 ))
310 }
311 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => {
312 trace!(agent_id = %agent_id, sequence = ping.sequence, "Received ping");
313 Some(AgentToProxy {
314 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
315 sequence: ping.sequence,
316 ping_timestamp_ms: ping.timestamp_ms,
317 timestamp_ms: now_ms(),
318 })),
319 })
320 }
321 Some(grpc_v2::proxy_to_agent::Message::Cancel(cancel)) => {
322 debug!(
323 agent_id = %agent_id,
324 correlation_id = %cancel.correlation_id,
325 "Request cancelled"
326 );
327 None
328 }
329 Some(grpc_v2::proxy_to_agent::Message::Configure(_)) => {
330 None
332 }
333 Some(grpc_v2::proxy_to_agent::Message::Guardrail(_)) => {
334 None
336 }
337 None => {
338 warn!(agent_id = %agent_id, "Empty message received");
339 None
340 }
341 };
342
343 if let Some(resp) = response {
344 if tx.send(Ok(resp)).await.is_err() {
345 debug!(agent_id = %agent_id, "Stream closed by receiver");
346 break;
347 }
348 }
349 }
350
351 handler.on_stream_closed().await;
352 debug!(agent_id = %agent_id, "Process stream ended");
353 });
354
355 let output_stream = ReceiverStream::new(rx);
356 Ok(Response::new(
357 Box::pin(output_stream) as Self::ProcessStreamStream
358 ))
359 }
360
361 async fn control_stream(
367 &self,
368 request: Request<Streaming<grpc_v2::AgentControl>>,
369 ) -> Result<Response<Self::ControlStreamStream>, Status> {
370 let mut inbound = request.into_inner();
371 let (tx, rx) = mpsc::channel::<Result<grpc_v2::ProxyControl, Status>>(16);
372 let handler = Arc::clone(&self.handler);
373 let agent_id = self.id.clone();
374
375 debug!(agent_id = %agent_id, "Starting v2 control stream");
376
377 let _handler_clone = Arc::clone(&handler);
379 let tx_clone = tx.clone();
380 let agent_id_clone = agent_id.clone();
381 tokio::spawn(async move {
382 while let Some(result) = inbound.next().await {
383 let msg = match result {
384 Ok(m) => m,
385 Err(e) => {
386 error!(agent_id = %agent_id_clone, error = %e, "Control stream error");
387 break;
388 }
389 };
390
391 match msg.message {
395 Some(grpc_v2::agent_control::Message::Health(health)) => {
396 trace!(
397 agent_id = %agent_id_clone,
398 state = health.state,
399 "Received health status from agent"
400 );
401 }
403 Some(grpc_v2::agent_control::Message::Metrics(metrics)) => {
404 trace!(
405 agent_id = %agent_id_clone,
406 counters = metrics.counters.len(),
407 gauges = metrics.gauges.len(),
408 "Received metrics report from agent"
409 );
410 }
412 Some(grpc_v2::agent_control::Message::ConfigUpdate(update)) => {
413 debug!(
414 agent_id = %agent_id_clone,
415 request_id = %update.request_id,
416 "Received config update request from agent"
417 );
418 let response = grpc_v2::ProxyControl {
420 message: Some(grpc_v2::proxy_control::Message::ConfigResponse(
421 grpc_v2::ConfigUpdateResponse {
422 request_id: update.request_id,
423 accepted: true,
424 error: None,
425 timestamp_ms: now_ms(),
426 },
427 )),
428 };
429 if tx_clone.send(Ok(response)).await.is_err() {
430 break;
431 }
432 }
433 Some(grpc_v2::agent_control::Message::Log(log)) => {
434 match log.level {
436 1 => {
437 trace!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
438 }
439 2 => {
440 debug!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
441 }
442 3 => warn!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
443 4 => {
444 error!(agent_id = %agent_id_clone, msg = %log.message, "Agent log")
445 }
446 _ => info!(agent_id = %agent_id_clone, msg = %log.message, "Agent log"),
447 }
448 }
449 None => {
450 warn!(agent_id = %agent_id_clone, "Empty control message received");
451 }
452 }
453 }
454
455 debug!(agent_id = %agent_id_clone, "Control stream inbound handler ended");
456 });
457
458 let capabilities = handler.capabilities();
460 let health_interval_ms = capabilities.health.report_interval_ms;
461 let metrics_enabled = capabilities.features.metrics_export;
462
463 if health_interval_ms > 0 || metrics_enabled {
464 let handler_for_health = Arc::clone(&handler);
465 let tx_for_health = tx;
466 let agent_id_for_health = agent_id.clone();
467
468 tokio::spawn(async move {
469 let health_interval = std::time::Duration::from_millis(health_interval_ms as u64);
470 let mut interval = tokio::time::interval(health_interval);
471
472 loop {
473 interval.tick().await;
474
475 let health = handler_for_health.health_status();
477 trace!(
478 agent_id = %agent_id_for_health,
479 state = ?health.state,
480 message = ?health.message,
481 "Agent health status collected"
482 );
483
484 let heartbeat = grpc_v2::ProxyControl {
487 message: Some(grpc_v2::proxy_control::Message::Configure(
488 grpc_v2::ConfigureEvent {
489 config_json: "{}".to_string(),
490 config_version: None,
491 is_initial: false,
492 timestamp_ms: now_ms(),
493 },
494 )),
495 };
496
497 if tx_for_health.send(Ok(heartbeat)).await.is_err() {
498 debug!(
499 agent_id = %agent_id_for_health,
500 "Control stream closed, stopping health reporter"
501 );
502 break;
503 }
504 }
505 });
506 }
507
508 let output_stream = ReceiverStream::new(rx);
509 Ok(Response::new(
510 Box::pin(output_stream) as Self::ControlStreamStream
511 ))
512 }
513
514 async fn process_event(
516 &self,
517 request: Request<ProxyToAgent>,
518 ) -> Result<Response<AgentToProxy>, Status> {
519 let msg = request.into_inner();
520
521 trace!(agent_id = %self.id, "Processing single event (v1 compat)");
522
523 let response = match msg.message {
524 Some(grpc_v2::proxy_to_agent::Message::Handshake(req)) => {
525 let handshake_req = convert_handshake_request(req);
526 let resp = self.handler.on_handshake(handshake_req).await;
527 AgentToProxy {
528 message: Some(grpc_v2::agent_to_proxy::Message::Handshake(
529 convert_handshake_response(resp),
530 )),
531 }
532 }
533 Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(e)) => {
534 let event = convert_request_headers_from_grpc(e);
535 let correlation_id = event.metadata.correlation_id.clone();
536 let start = Instant::now();
537 let resp = self.handler.on_request_headers(event).await;
538 let processing_time_ms = start.elapsed().as_millis() as u64;
539 create_agent_response(correlation_id, resp, processing_time_ms)
540 }
541 Some(grpc_v2::proxy_to_agent::Message::Ping(ping)) => AgentToProxy {
542 message: Some(grpc_v2::agent_to_proxy::Message::Pong(grpc_v2::Pong {
543 sequence: ping.sequence,
544 ping_timestamp_ms: ping.timestamp_ms,
545 timestamp_ms: now_ms(),
546 })),
547 },
548 _ => {
549 return Err(Status::invalid_argument("Unsupported event type"));
550 }
551 };
552
553 Ok(Response::new(response))
554 }
555}
556
557fn convert_handshake_request(req: grpc_v2::HandshakeRequest) -> HandshakeRequest {
562 HandshakeRequest {
563 supported_versions: req.supported_versions,
564 proxy_id: req.proxy_id,
565 proxy_version: req.proxy_version,
566 config: serde_json::from_str(&req.config_json).unwrap_or(serde_json::Value::Null),
567 }
568}
569
570fn convert_handshake_response(resp: HandshakeResponse) -> grpc_v2::HandshakeResponse {
571 grpc_v2::HandshakeResponse {
572 protocol_version: resp.protocol_version,
573 capabilities: Some(convert_capabilities_to_grpc(&resp.capabilities)),
574 success: resp.success,
575 error: resp.error,
576 }
577}
578
579fn convert_capabilities_to_grpc(caps: &AgentCapabilities) -> grpc_v2::AgentCapabilities {
580 grpc_v2::AgentCapabilities {
581 protocol_version: caps.protocol_version,
582 agent_id: caps.agent_id.clone(),
583 name: caps.name.clone(),
584 version: caps.version.clone(),
585 supported_events: caps
586 .supported_events
587 .iter()
588 .map(|e| event_type_to_i32(*e))
589 .collect(),
590 features: Some(grpc_v2::AgentFeatures {
591 streaming_body: caps.features.streaming_body,
592 websocket: caps.features.websocket,
593 guardrails: caps.features.guardrails,
594 config_push: caps.features.config_push,
595 metrics_export: caps.features.metrics_export,
596 concurrent_requests: caps.features.concurrent_requests,
597 cancellation: caps.features.cancellation,
598 flow_control: caps.features.flow_control,
599 health_reporting: caps.features.health_reporting,
600 }),
601 limits: Some(grpc_v2::AgentLimits {
602 max_body_size: caps.limits.max_body_size as u64,
603 max_concurrency: caps.limits.max_concurrency,
604 preferred_chunk_size: caps.limits.preferred_chunk_size as u64,
605 max_memory: caps.limits.max_memory.map(|m| m as u64),
606 max_processing_time_ms: caps.limits.max_processing_time_ms,
607 }),
608 health_config: Some(grpc_v2::HealthConfig {
609 report_interval_ms: caps.health.report_interval_ms,
610 include_load_metrics: caps.health.include_load_metrics,
611 include_resource_metrics: caps.health.include_resource_metrics,
612 }),
613 }
614}
615
616pub(crate) fn event_type_to_i32(event_type: EventType) -> i32 {
617 match event_type {
618 EventType::Configure => 8,
619 EventType::RequestHeaders => 1,
620 EventType::RequestBodyChunk => 2,
621 EventType::ResponseHeaders => 3,
622 EventType::ResponseBodyChunk => 4,
623 EventType::RequestComplete => 5,
624 EventType::WebSocketFrame => 6,
625 EventType::GuardrailInspect => 7,
626 }
627}
628
629fn convert_request_headers_from_grpc(e: grpc_v2::RequestHeadersEvent) -> RequestHeadersEvent {
630 let metadata = match e.metadata {
631 Some(m) => RequestMetadata {
632 correlation_id: m.correlation_id,
633 request_id: m.request_id,
634 client_ip: m.client_ip,
635 client_port: m.client_port as u16,
636 server_name: m.server_name,
637 protocol: m.protocol,
638 tls_version: m.tls_version,
639 tls_cipher: None,
640 route_id: m.route_id,
641 upstream_id: m.upstream_id,
642 timestamp: format!("{}", m.timestamp_ms),
643 traceparent: m.traceparent,
644 },
645 None => RequestMetadata {
646 correlation_id: String::new(),
647 request_id: String::new(),
648 client_ip: String::new(),
649 client_port: 0,
650 server_name: None,
651 protocol: String::new(),
652 tls_version: None,
653 tls_cipher: None,
654 route_id: None,
655 upstream_id: None,
656 timestamp: String::new(),
657 traceparent: None,
658 },
659 };
660
661 let headers = e
662 .headers
663 .into_iter()
664 .fold(std::collections::HashMap::new(), |mut map, h| {
665 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
666 map
667 });
668
669 RequestHeadersEvent {
670 metadata,
671 method: e.method,
672 uri: e.uri,
673 headers,
674 }
675}
676
677fn convert_body_chunk_to_request(e: grpc_v2::BodyChunkEvent) -> RequestBodyChunkEvent {
678 use base64::{engine::general_purpose::STANDARD, Engine as _};
679 RequestBodyChunkEvent {
680 correlation_id: e.correlation_id,
681 data: STANDARD.encode(&e.data),
682 is_last: e.is_last,
683 total_size: e.total_size.map(|s| s as usize),
684 chunk_index: e.chunk_index,
685 bytes_received: e.bytes_transferred as usize,
686 }
687}
688
689fn convert_body_chunk_to_response(e: grpc_v2::BodyChunkEvent) -> ResponseBodyChunkEvent {
690 use base64::{engine::general_purpose::STANDARD, Engine as _};
691 ResponseBodyChunkEvent {
692 correlation_id: e.correlation_id,
693 data: STANDARD.encode(&e.data),
694 is_last: e.is_last,
695 total_size: e.total_size.map(|s| s as usize),
696 chunk_index: e.chunk_index,
697 bytes_sent: e.bytes_transferred as usize,
698 }
699}
700
701fn convert_response_headers_from_grpc(e: grpc_v2::ResponseHeadersEvent) -> ResponseHeadersEvent {
702 let headers = e
703 .headers
704 .into_iter()
705 .fold(std::collections::HashMap::new(), |mut map, h| {
706 map.entry(h.name).or_insert_with(Vec::new).push(h.value);
707 map
708 });
709
710 ResponseHeadersEvent {
711 correlation_id: e.correlation_id,
712 status: e.status_code as u16,
713 headers,
714 }
715}
716
717fn convert_request_complete_from_grpc(e: grpc_v2::RequestCompleteEvent) -> RequestCompleteEvent {
718 RequestCompleteEvent {
719 correlation_id: e.correlation_id,
720 status: e.status_code as u16,
721 duration_ms: e.duration_ms,
722 request_body_size: e.bytes_received as usize,
723 response_body_size: e.bytes_sent as usize,
724 upstream_attempts: 1,
725 error: e.error,
726 }
727}
728
729fn convert_websocket_frame_from_grpc(e: grpc_v2::WebSocketFrameEvent) -> WebSocketFrameEvent {
730 use base64::{engine::general_purpose::STANDARD, Engine as _};
731 WebSocketFrameEvent {
732 correlation_id: e.correlation_id,
733 opcode: format!("{}", e.frame_type),
734 data: STANDARD.encode(&e.payload),
735 client_to_server: e.client_to_server,
736 frame_index: 0,
737 fin: true,
738 route_id: None,
739 client_ip: String::new(),
740 }
741}
742
743fn create_agent_response(
744 correlation_id: String,
745 resp: AgentResponse,
746 processing_time_ms: u64,
747) -> AgentToProxy {
748 let decision = match resp.decision {
749 Decision::Allow => Some(grpc_v2::agent_response::Decision::Allow(
750 grpc_v2::AllowDecision {},
751 )),
752 Decision::Block {
753 status,
754 body,
755 headers,
756 } => Some(grpc_v2::agent_response::Decision::Block(
757 grpc_v2::BlockDecision {
758 status: status as u32,
759 body,
760 headers: headers
761 .unwrap_or_default()
762 .into_iter()
763 .map(|(k, v)| grpc_v2::Header { name: k, value: v })
764 .collect(),
765 },
766 )),
767 Decision::Redirect { url, status } => Some(grpc_v2::agent_response::Decision::Redirect(
768 grpc_v2::RedirectDecision {
769 url,
770 status: status as u32,
771 },
772 )),
773 Decision::Challenge {
774 challenge_type,
775 params,
776 } => Some(grpc_v2::agent_response::Decision::Challenge(
777 grpc_v2::ChallengeDecision {
778 challenge_type,
779 params,
780 },
781 )),
782 };
783
784 let request_headers: Vec<grpc_v2::HeaderOp> = resp
785 .request_headers
786 .into_iter()
787 .map(convert_header_op_to_grpc)
788 .collect();
789
790 let response_headers: Vec<grpc_v2::HeaderOp> = resp
791 .response_headers
792 .into_iter()
793 .map(convert_header_op_to_grpc)
794 .collect();
795
796 let audit = Some(grpc_v2::AuditMetadata {
797 tags: resp.audit.tags,
798 rule_ids: resp.audit.rule_ids,
799 confidence: resp.audit.confidence,
800 reason_codes: resp.audit.reason_codes,
801 custom: resp
802 .audit
803 .custom
804 .into_iter()
805 .map(|(k, v)| (k, v.to_string()))
806 .collect(),
807 });
808
809 AgentToProxy {
810 message: Some(grpc_v2::agent_to_proxy::Message::Response(
811 grpc_v2::AgentResponse {
812 correlation_id,
813 decision,
814 request_headers,
815 response_headers,
816 audit,
817 processing_time_ms: Some(processing_time_ms),
818 needs_more: resp.needs_more,
819 },
820 )),
821 }
822}
823
824fn convert_header_op_to_grpc(op: HeaderOp) -> grpc_v2::HeaderOp {
825 let operation = match op {
826 HeaderOp::Set { name, value } => {
827 Some(grpc_v2::header_op::Operation::Set(grpc_v2::Header {
828 name,
829 value,
830 }))
831 }
832 HeaderOp::Add { name, value } => {
833 Some(grpc_v2::header_op::Operation::Add(grpc_v2::Header {
834 name,
835 value,
836 }))
837 }
838 HeaderOp::Remove { name } => Some(grpc_v2::header_op::Operation::Remove(name)),
839 };
840 grpc_v2::HeaderOp { operation }
841}
842
/// Milliseconds since the Unix epoch, or 0 if the system clock reads earlier
/// than the epoch.
fn now_ms() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
849
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that supplies only capabilities and relies on every default hook.
    struct DefaultsOnlyHandler;

    #[async_trait]
    impl AgentHandlerV2 for DefaultsOnlyHandler {
        fn capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::new("test-v2", "Test Agent V2", "1.0.0")
        }
    }

    #[test]
    fn test_create_server() {
        let handler: Box<dyn AgentHandlerV2> = Box::new(DefaultsOnlyHandler);
        let server = GrpcAgentServerV2::new("test", handler);
        assert_eq!(server.id, "test");
    }
}