1use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
11use tonic::transport::Channel;
12use tracing::{debug, info, trace, warn};
13
14use crate::grpc_v2::{self, agent_service_v2_client::AgentServiceV2Client, ProxyToAgent};
15use crate::headers::iter_flat;
16use crate::v2::pool::CHANNEL_BUFFER_SIZE;
17use crate::v2::{AgentCapabilities, PROTOCOL_VERSION_2};
18use crate::{AgentProtocolError, AgentResponse, Decision, EventType, HeaderOp};
19
/// Reason attached to a cancellation sent to the agent.
///
/// Each variant is encoded as an integer wire code by `CancelReason::to_grpc`;
/// `ProxyShutdown` is the reason used when the client shuts the agent down.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CancelReason {
    ClientDisconnect,
    Timeout,
    BlockedByAgent,
    UpstreamError,
    ProxyShutdown,
    Manual,
}
36
37impl CancelReason {
38 fn to_grpc(self) -> i32 {
39 match self {
40 CancelReason::ClientDisconnect => 1,
41 CancelReason::Timeout => 2,
42 CancelReason::BlockedByAgent => 3,
43 CancelReason::UpstreamError => 4,
44 CancelReason::ProxyShutdown => 5,
45 CancelReason::Manual => 6,
46 }
47 }
48}
49
/// Flow-control state of an agent connection.
///
/// Only `Normal` counts as able to accept new requests. `Paused` is entered on
/// an agent flow-control signal, and `Draining` is set locally by a drain request.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FlowState {
    /// Starting state of every client.
    #[default]
    Normal,
    Paused,
    Draining,
}
61
62pub type MetricsCallback = Arc<dyn Fn(crate::v2::MetricsReport) + Send + Sync>;
64
65pub type ConfigUpdateCallback = Arc<
70 dyn Fn(String, crate::v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateResponse + Send + Sync,
71>;
72
73pub struct AgentClientV2 {
86 agent_id: String,
88 channel: Channel,
90 timeout: Duration,
92 capabilities: RwLock<Option<AgentCapabilities>>,
94 protocol_version: AtomicU64,
96 pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
98 outbound_tx: Mutex<Option<mpsc::Sender<ProxyToAgent>>>,
100 ping_sequence: AtomicU64,
102 connected: RwLock<bool>,
104 flow_state: RwLock<FlowState>,
106 health_state: RwLock<i32>,
108 in_flight: AtomicU64,
110 metrics_callback: Option<MetricsCallback>,
112 config_update_callback: Option<ConfigUpdateCallback>,
114}
115
116impl AgentClientV2 {
117 pub async fn new(
119 agent_id: impl Into<String>,
120 endpoint: impl Into<String>,
121 timeout: Duration,
122 ) -> Result<Self, AgentProtocolError> {
123 let agent_id = agent_id.into();
124 let endpoint = endpoint.into();
125
126 debug!(agent_id = %agent_id, endpoint = %endpoint, "Creating v2 client");
127
128 let channel = Channel::from_shared(endpoint.clone())
129 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid endpoint: {}", e)))?
130 .connect_timeout(timeout)
131 .timeout(timeout)
132 .connect()
133 .await
134 .map_err(|e| {
135 AgentProtocolError::ConnectionFailed(format!("Failed to connect: {}", e))
136 })?;
137
138 Ok(Self {
139 agent_id,
140 channel,
141 timeout,
142 capabilities: RwLock::new(None),
143 protocol_version: AtomicU64::new(1), pending: Arc::new(Mutex::new(HashMap::new())),
145 outbound_tx: Mutex::new(None),
146 ping_sequence: AtomicU64::new(0),
147 connected: RwLock::new(false),
148 flow_state: RwLock::new(FlowState::Normal),
149 health_state: RwLock::new(1), in_flight: AtomicU64::new(0),
151 metrics_callback: None,
152 config_update_callback: None,
153 })
154 }
155
156 pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
161 self.metrics_callback = Some(callback);
162 }
163
164 pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
169 self.config_update_callback = Some(callback);
170 }
171
172 pub async fn connect(&self) -> Result<(), AgentProtocolError> {
174 let mut client = AgentServiceV2Client::new(self.channel.clone());
175
176 let (tx, rx) = mpsc::channel::<ProxyToAgent>(CHANNEL_BUFFER_SIZE);
178 let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
179
180 let response_stream = client
181 .process_stream(rx_stream)
182 .await
183 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream failed: {}", e)))?;
184
185 let mut inbound = response_stream.into_inner();
186
187 let handshake = ProxyToAgent {
189 message: Some(grpc_v2::proxy_to_agent::Message::Handshake(
190 grpc_v2::HandshakeRequest {
191 supported_versions: vec![PROTOCOL_VERSION_2, 1],
192 proxy_id: "grapsus-proxy".to_string(),
193 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
194 config_json: "{}".to_string(),
195 },
196 )),
197 };
198
199 tx.send(handshake).await.map_err(|e| {
200 AgentProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e))
201 })?;
202
203 let handshake_resp = tokio::time::timeout(self.timeout, inbound.message())
205 .await
206 .map_err(|_| AgentProtocolError::Timeout(self.timeout))?
207 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream error: {}", e)))?
208 .ok_or_else(|| {
209 AgentProtocolError::ConnectionFailed("Empty handshake response".to_string())
210 })?;
211
212 if let Some(grpc_v2::agent_to_proxy::Message::Handshake(resp)) = handshake_resp.message {
214 if !resp.success {
215 return Err(AgentProtocolError::ConnectionFailed(format!(
216 "Handshake failed: {}",
217 resp.error.unwrap_or_default()
218 )));
219 }
220
221 self.protocol_version
222 .store(resp.protocol_version as u64, Ordering::SeqCst);
223
224 if let Some(caps) = resp.capabilities {
225 let capabilities = convert_capabilities_from_grpc(caps);
226 *self.capabilities.write().await = Some(capabilities);
227 }
228
229 info!(
230 agent_id = %self.agent_id,
231 protocol_version = resp.protocol_version,
232 "v2 handshake successful"
233 );
234 } else {
235 return Err(AgentProtocolError::ConnectionFailed(
236 "Invalid handshake response".to_string(),
237 ));
238 }
239
240 *self.outbound_tx.lock().await = Some(tx);
242 *self.connected.write().await = true;
243
244 let pending = Arc::clone(&self.pending);
246 let agent_id = self.agent_id.clone();
247 let flow_state = Arc::new(RwLock::new(FlowState::Normal));
248 let health_state = Arc::new(RwLock::new(1i32));
249 let _in_flight = Arc::new(AtomicU64::new(0));
250
251 let flow_state_clone = Arc::clone(&flow_state);
253 let health_state_clone = Arc::clone(&health_state);
254 let metrics_callback = self.metrics_callback.clone();
255 let config_update_callback = self.config_update_callback.clone();
256
257 tokio::spawn(async move {
258 while let Ok(Some(msg)) = inbound.message().await {
259 match msg.message {
260 Some(grpc_v2::agent_to_proxy::Message::Response(resp)) => {
261 let correlation_id = resp.correlation_id.clone();
262 if let Some(sender) = pending.lock().await.remove(&correlation_id) {
263 let response = convert_response_from_grpc(resp);
264 let _ = sender.send(response);
265 } else {
266 warn!(
267 agent_id = %agent_id,
268 correlation_id = %correlation_id,
269 "Received response for unknown correlation ID"
270 );
271 }
272 }
273 Some(grpc_v2::agent_to_proxy::Message::Health(health)) => {
274 trace!(
275 agent_id = %agent_id,
276 state = health.state,
277 "Received health status"
278 );
279 *health_state_clone.write().await = health.state;
280 }
281 Some(grpc_v2::agent_to_proxy::Message::Metrics(metrics)) => {
282 trace!(
283 agent_id = %agent_id,
284 counters = metrics.counters.len(),
285 gauges = metrics.gauges.len(),
286 histograms = metrics.histograms.len(),
287 "Received metrics report"
288 );
289 if let Some(ref callback) = metrics_callback {
290 let report = convert_metrics_from_grpc(metrics, &agent_id);
291 callback(report);
292 }
293 }
294 Some(grpc_v2::agent_to_proxy::Message::FlowControl(fc)) => {
295 let new_state = match fc.action {
297 1 => FlowState::Paused, 2 => FlowState::Normal, _ => FlowState::Normal,
300 };
301 debug!(
302 agent_id = %agent_id,
303 action = fc.action,
304 correlation_id = ?fc.correlation_id,
305 "Received flow control signal"
306 );
307 *flow_state_clone.write().await = new_state;
308 }
309 Some(grpc_v2::agent_to_proxy::Message::Pong(pong)) => {
310 trace!(
311 agent_id = %agent_id,
312 sequence = pong.sequence,
313 latency_ms = pong.timestamp_ms.saturating_sub(pong.ping_timestamp_ms),
314 "Received pong"
315 );
316 }
317 Some(grpc_v2::agent_to_proxy::Message::ConfigUpdate(update)) => {
318 debug!(
319 agent_id = %agent_id,
320 request_id = %update.request_id,
321 "Received config update request from agent"
322 );
323 if let Some(ref callback) = config_update_callback {
324 let request = convert_config_update_from_grpc(update);
325 let _response = callback(agent_id.clone(), request);
326 }
329 }
330 Some(grpc_v2::agent_to_proxy::Message::Log(log_msg)) => {
331 match log_msg.level {
333 1 => {
334 trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent debug log")
335 }
336 2 => {
337 debug!(agent_id = %agent_id, msg = %log_msg.message, "Agent info log")
338 }
339 3 => {
340 warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent warning")
341 }
342 4 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent error"),
343 _ => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent log"),
344 }
345 }
346 _ => {}
347 }
348 }
349
350 debug!(agent_id = %agent_id, "Response handler ended");
351 });
352
353 Ok(())
354 }
355
356 pub async fn send_request_headers(
358 &self,
359 correlation_id: &str,
360 event: &crate::RequestHeadersEvent,
361 ) -> Result<AgentResponse, AgentProtocolError> {
362 let msg = ProxyToAgent {
363 message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
364 convert_request_headers_to_grpc(event),
365 )),
366 };
367
368 self.send_and_wait(correlation_id, msg).await
369 }
370
371 pub async fn send_request_body_chunk(
376 &self,
377 correlation_id: &str,
378 event: &crate::RequestBodyChunkEvent,
379 ) -> Result<AgentResponse, AgentProtocolError> {
380 let msg = ProxyToAgent {
381 message: Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(
382 convert_body_chunk_to_grpc(event),
383 )),
384 };
385
386 self.send_and_wait(correlation_id, msg).await
387 }
388
389 pub async fn send_response_headers(
394 &self,
395 correlation_id: &str,
396 event: &crate::ResponseHeadersEvent,
397 ) -> Result<AgentResponse, AgentProtocolError> {
398 let msg = ProxyToAgent {
399 message: Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(
400 convert_response_headers_to_grpc(event),
401 )),
402 };
403
404 self.send_and_wait(correlation_id, msg).await
405 }
406
407 pub async fn send_response_body_chunk(
412 &self,
413 correlation_id: &str,
414 event: &crate::ResponseBodyChunkEvent,
415 ) -> Result<AgentResponse, AgentProtocolError> {
416 let msg = ProxyToAgent {
417 message: Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(
418 convert_response_body_chunk_to_grpc(event),
419 )),
420 };
421
422 self.send_and_wait(correlation_id, msg).await
423 }
424
425 pub async fn send_event<T: serde::Serialize>(
427 &self,
428 event_type: EventType,
429 event: &T,
430 ) -> Result<AgentResponse, AgentProtocolError> {
431 let correlation_id = extract_correlation_id(event);
433
434 let msg = match event_type {
435 EventType::RequestHeaders => {
436 if let Ok(e) = serde_json::from_value::<crate::RequestHeadersEvent>(
437 serde_json::to_value(event).unwrap_or_default(),
438 ) {
439 ProxyToAgent {
440 message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
441 convert_request_headers_to_grpc(&e),
442 )),
443 }
444 } else {
445 return Err(AgentProtocolError::InvalidMessage(
446 "Failed to convert event".to_string(),
447 ));
448 }
449 }
450 _ => {
451 return Err(AgentProtocolError::InvalidMessage(format!(
453 "Event type {:?} not yet supported in v2 streaming mode",
454 event_type
455 )));
456 }
457 };
458
459 self.send_and_wait(&correlation_id, msg).await
460 }
461
462 async fn send_and_wait(
464 &self,
465 correlation_id: &str,
466 msg: ProxyToAgent,
467 ) -> Result<AgentResponse, AgentProtocolError> {
468 let (tx, rx) = oneshot::channel();
470
471 self.pending
473 .lock()
474 .await
475 .insert(correlation_id.to_string(), tx);
476
477 {
479 let outbound = self.outbound_tx.lock().await;
480 if let Some(sender) = outbound.as_ref() {
481 sender.send(msg).await.map_err(|e| {
482 AgentProtocolError::ConnectionFailed(format!("Send failed: {}", e))
483 })?;
484 } else {
485 return Err(AgentProtocolError::ConnectionFailed(
486 "Not connected".to_string(),
487 ));
488 }
489 }
490
491 match tokio::time::timeout(self.timeout, rx).await {
493 Ok(Ok(response)) => Ok(response),
494 Ok(Err(_)) => {
495 self.pending.lock().await.remove(correlation_id);
496 Err(AgentProtocolError::ConnectionFailed(
497 "Response channel closed".to_string(),
498 ))
499 }
500 Err(_) => {
501 self.pending.lock().await.remove(correlation_id);
502 Err(AgentProtocolError::Timeout(self.timeout))
503 }
504 }
505 }
506
507 pub async fn ping(&self) -> Result<Duration, AgentProtocolError> {
509 let sequence = self.ping_sequence.fetch_add(1, Ordering::SeqCst);
510 let timestamp_ms = now_ms();
511
512 let msg = ProxyToAgent {
513 message: Some(grpc_v2::proxy_to_agent::Message::Ping(grpc_v2::Ping {
514 sequence,
515 timestamp_ms,
516 })),
517 };
518
519 let outbound = self.outbound_tx.lock().await;
520 if let Some(sender) = outbound.as_ref() {
521 sender
522 .send(msg)
523 .await
524 .map_err(|e| AgentProtocolError::ConnectionFailed(format!("Ping failed: {}", e)))?;
525 }
526
527 Ok(Duration::from_millis(0))
530 }
531
532 pub fn protocol_version(&self) -> u32 {
534 self.protocol_version.load(Ordering::SeqCst) as u32
535 }
536
537 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
539 self.capabilities.read().await.clone()
540 }
541
542 pub async fn is_connected(&self) -> bool {
544 *self.connected.read().await
545 }
546
547 pub async fn close(&self) -> Result<(), AgentProtocolError> {
549 *self.outbound_tx.lock().await = None;
550 *self.connected.write().await = false;
551 Ok(())
552 }
553
554 pub async fn cancel_request(
559 &self,
560 correlation_id: &str,
561 reason: CancelReason,
562 ) -> Result<(), AgentProtocolError> {
563 self.pending.lock().await.remove(correlation_id);
565
566 let msg = ProxyToAgent {
568 message: Some(grpc_v2::proxy_to_agent::Message::Cancel(
569 grpc_v2::CancelRequest {
570 correlation_id: correlation_id.to_string(),
571 reason: reason.to_grpc(),
572 timestamp_ms: now_ms(),
573 blocking_agent_id: None,
574 manual_reason: None,
575 },
576 )),
577 };
578
579 let outbound = self.outbound_tx.lock().await;
580 if let Some(sender) = outbound.as_ref() {
581 sender.send(msg).await.map_err(|e| {
582 AgentProtocolError::ConnectionFailed(format!("Cancel send failed: {}", e))
583 })?;
584 }
585
586 debug!(
587 agent_id = %self.agent_id,
588 correlation_id = %correlation_id,
589 reason = ?reason,
590 "Cancelled request"
591 );
592
593 Ok(())
594 }
595
596 pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
600 let correlation_ids: Vec<String> = {
601 let pending = self.pending.lock().await;
602 pending.keys().cloned().collect()
603 };
604
605 let count = correlation_ids.len();
606 for cid in correlation_ids {
607 let _ = self.cancel_request(&cid, reason).await;
608 }
609
610 debug!(
611 agent_id = %self.agent_id,
612 count = count,
613 reason = ?reason,
614 "Cancelled all requests"
615 );
616
617 Ok(count)
618 }
619
620 pub async fn flow_state(&self) -> FlowState {
622 *self.flow_state.read().await
623 }
624
625 pub async fn can_accept_requests(&self) -> bool {
629 matches!(*self.flow_state.read().await, FlowState::Normal)
630 }
631
632 pub async fn wait_for_flow_control(&self, timeout: Duration) -> Result<(), AgentProtocolError> {
637 let deadline = tokio::time::Instant::now() + timeout;
638
639 loop {
640 if self.can_accept_requests().await {
641 return Ok(());
642 }
643
644 if tokio::time::Instant::now() >= deadline {
645 return Err(AgentProtocolError::Timeout(timeout));
646 }
647
648 tokio::time::sleep(Duration::from_millis(10)).await;
650 }
651 }
652
653 pub async fn health_state(&self) -> i32 {
661 *self.health_state.read().await
662 }
663
664 pub async fn is_healthy(&self) -> bool {
666 *self.health_state.read().await == 1
667 }
668
669 pub fn in_flight_count(&self) -> u64 {
671 self.in_flight.load(Ordering::Relaxed)
672 }
673
674 pub async fn send_configure(
680 &self,
681 config: serde_json::Value,
682 version: Option<String>,
683 ) -> Result<(), AgentProtocolError> {
684 let msg = ProxyToAgent {
685 message: Some(grpc_v2::proxy_to_agent::Message::Configure(
686 grpc_v2::ConfigureEvent {
687 config_json: serde_json::to_string(&config).unwrap_or_default(),
688 config_version: version,
689 is_initial: false,
690 timestamp_ms: now_ms(),
691 },
692 )),
693 };
694
695 let outbound = self.outbound_tx.lock().await;
696 if let Some(sender) = outbound.as_ref() {
697 sender.send(msg).await.map_err(|e| {
698 AgentProtocolError::ConnectionFailed(format!("Configure send failed: {}", e))
699 })?;
700 } else {
701 return Err(AgentProtocolError::ConnectionFailed(
702 "Not connected".to_string(),
703 ));
704 }
705
706 debug!(agent_id = %self.agent_id, "Sent configuration update");
707 Ok(())
708 }
709
710 pub async fn send_shutdown(
712 &self,
713 reason: ShutdownReason,
714 grace_period_ms: u64,
715 ) -> Result<(), AgentProtocolError> {
716 info!(
717 agent_id = %self.agent_id,
718 reason = ?reason,
719 grace_period_ms = grace_period_ms,
720 "Requesting agent shutdown"
721 );
722
723 let _ = self.cancel_all(CancelReason::ProxyShutdown).await;
725
726 self.close().await
728 }
729
730 pub async fn send_drain(
732 &self,
733 duration_ms: u64,
734 reason: DrainReason,
735 ) -> Result<(), AgentProtocolError> {
736 info!(
737 agent_id = %self.agent_id,
738 duration_ms = duration_ms,
739 reason = ?reason,
740 "Requesting agent drain"
741 );
742
743 *self.flow_state.write().await = FlowState::Draining;
745
746 Ok(())
747 }
748
749 pub fn agent_id(&self) -> &str {
751 &self.agent_id
752 }
753}
754
/// Reason given when asking an agent to shut down.
///
/// It is passed to `AgentClientV2::send_shutdown`, which currently only logs it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownReason {
    Graceful,
    Immediate,
    ConfigReload,
    Upgrade,
}
763
/// Reason given when asking an agent connection to drain.
///
/// It is passed to `AgentClientV2::send_drain`, which currently only logs it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainReason {
    ConfigReload,
    Maintenance,
    HealthCheckFailed,
    Manual,
}
772
773fn convert_capabilities_from_grpc(caps: grpc_v2::AgentCapabilities) -> AgentCapabilities {
778 use crate::v2::{AgentFeatures, AgentLimits, HealthConfig};
779
780 let features = caps
781 .features
782 .map(|f| AgentFeatures {
783 streaming_body: f.streaming_body,
784 websocket: f.websocket,
785 guardrails: f.guardrails,
786 config_push: f.config_push,
787 metrics_export: f.metrics_export,
788 concurrent_requests: f.concurrent_requests,
789 cancellation: f.cancellation,
790 flow_control: f.flow_control,
791 health_reporting: f.health_reporting,
792 })
793 .unwrap_or_default();
794
795 let limits = caps
796 .limits
797 .map(|l| AgentLimits {
798 max_body_size: l.max_body_size as usize,
799 max_concurrency: l.max_concurrency,
800 preferred_chunk_size: l.preferred_chunk_size as usize,
801 max_memory: l.max_memory.map(|m| m as usize),
802 max_processing_time_ms: l.max_processing_time_ms,
803 })
804 .unwrap_or_default();
805
806 let health = caps
807 .health_config
808 .map(|h| HealthConfig {
809 report_interval_ms: h.report_interval_ms,
810 include_load_metrics: h.include_load_metrics,
811 include_resource_metrics: h.include_resource_metrics,
812 })
813 .unwrap_or_default();
814
815 AgentCapabilities {
816 protocol_version: caps.protocol_version,
817 agent_id: caps.agent_id,
818 name: caps.name,
819 version: caps.version,
820 supported_events: caps
821 .supported_events
822 .into_iter()
823 .filter_map(i32_to_event_type)
824 .collect(),
825 features,
826 limits,
827 health,
828 }
829}
830
831fn i32_to_event_type(i: i32) -> Option<EventType> {
832 match i {
833 1 => Some(EventType::RequestHeaders),
834 2 => Some(EventType::RequestBodyChunk),
835 3 => Some(EventType::ResponseHeaders),
836 4 => Some(EventType::ResponseBodyChunk),
837 5 => Some(EventType::RequestComplete),
838 6 => Some(EventType::WebSocketFrame),
839 7 => Some(EventType::GuardrailInspect),
840 8 => Some(EventType::Configure),
841 _ => None,
842 }
843}
844
845fn convert_request_headers_to_grpc(
846 event: &crate::RequestHeadersEvent,
847) -> grpc_v2::RequestHeadersEvent {
848 let metadata = Some(grpc_v2::RequestMetadata {
849 correlation_id: event.metadata.correlation_id.clone(),
850 request_id: event.metadata.request_id.clone(),
851 client_ip: event.metadata.client_ip.clone(),
852 client_port: event.metadata.client_port as u32,
853 server_name: event.metadata.server_name.clone(),
854 protocol: event.metadata.protocol.clone(),
855 tls_version: event.metadata.tls_version.clone(),
856 route_id: event.metadata.route_id.clone(),
857 upstream_id: event.metadata.upstream_id.clone(),
858 timestamp_ms: now_ms(),
859 traceparent: event.metadata.traceparent.clone(),
860 });
861
862 let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
864 .map(|(name, value)| grpc_v2::Header {
865 name: name.to_string(),
866 value: value.to_string(),
867 })
868 .collect();
869
870 grpc_v2::RequestHeadersEvent {
871 metadata,
872 method: event.method.clone(),
873 uri: event.uri.clone(),
874 http_version: "HTTP/1.1".to_string(),
875 headers,
876 }
877}
878
879fn convert_body_chunk_to_grpc(event: &crate::RequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
880 let binary: crate::BinaryRequestBodyChunkEvent = event.into();
882 convert_binary_body_chunk_to_grpc(&binary)
883}
884
885fn convert_binary_body_chunk_to_grpc(
889 event: &crate::BinaryRequestBodyChunkEvent,
890) -> grpc_v2::BodyChunkEvent {
891 grpc_v2::BodyChunkEvent {
892 correlation_id: event.correlation_id.clone(),
893 chunk_index: event.chunk_index,
894 data: event.data.to_vec(), is_last: event.is_last,
896 total_size: event.total_size.map(|s| s as u64),
897 bytes_transferred: event.bytes_received as u64,
898 proxy_buffer_available: 0, timestamp_ms: now_ms(),
900 }
901}
902
903fn convert_response_headers_to_grpc(
904 event: &crate::ResponseHeadersEvent,
905) -> grpc_v2::ResponseHeadersEvent {
906 let headers: Vec<grpc_v2::Header> = iter_flat(&event.headers)
908 .map(|(name, value)| grpc_v2::Header {
909 name: name.to_string(),
910 value: value.to_string(),
911 })
912 .collect();
913
914 grpc_v2::ResponseHeadersEvent {
915 correlation_id: event.correlation_id.clone(),
916 status_code: event.status as u32,
917 headers,
918 }
919}
920
921fn convert_response_body_chunk_to_grpc(
922 event: &crate::ResponseBodyChunkEvent,
923) -> grpc_v2::BodyChunkEvent {
924 let binary: crate::BinaryResponseBodyChunkEvent = event.into();
926 convert_binary_response_body_chunk_to_grpc(&binary)
927}
928
929fn convert_binary_response_body_chunk_to_grpc(
933 event: &crate::BinaryResponseBodyChunkEvent,
934) -> grpc_v2::BodyChunkEvent {
935 grpc_v2::BodyChunkEvent {
936 correlation_id: event.correlation_id.clone(),
937 chunk_index: event.chunk_index,
938 data: event.data.to_vec(), is_last: event.is_last,
940 total_size: event.total_size.map(|s| s as u64),
941 bytes_transferred: event.bytes_sent as u64,
942 proxy_buffer_available: 0,
943 timestamp_ms: now_ms(),
944 }
945}
946
947fn convert_response_from_grpc(resp: grpc_v2::AgentResponse) -> AgentResponse {
948 let decision = match resp.decision {
949 Some(grpc_v2::agent_response::Decision::Allow(_)) => Decision::Allow,
950 Some(grpc_v2::agent_response::Decision::Block(b)) => Decision::Block {
951 status: b.status as u16,
952 body: b.body,
953 headers: if b.headers.is_empty() {
954 None
955 } else {
956 Some(b.headers.into_iter().map(|h| (h.name, h.value)).collect())
957 },
958 },
959 Some(grpc_v2::agent_response::Decision::Redirect(r)) => Decision::Redirect {
960 url: r.url,
961 status: r.status as u16,
962 },
963 Some(grpc_v2::agent_response::Decision::Challenge(c)) => Decision::Challenge {
964 challenge_type: c.challenge_type,
965 params: c.params,
966 },
967 None => Decision::Allow,
968 };
969
970 let request_headers: Vec<HeaderOp> = resp
971 .request_headers
972 .into_iter()
973 .filter_map(convert_header_op_from_grpc)
974 .collect();
975
976 let response_headers: Vec<HeaderOp> = resp
977 .response_headers
978 .into_iter()
979 .filter_map(convert_header_op_from_grpc)
980 .collect();
981
982 let audit = resp
983 .audit
984 .map(|a| crate::AuditMetadata {
985 tags: a.tags,
986 rule_ids: a.rule_ids,
987 confidence: a.confidence,
988 reason_codes: a.reason_codes,
989 custom: a
990 .custom
991 .into_iter()
992 .map(|(k, v)| (k, serde_json::Value::String(v)))
993 .collect(),
994 })
995 .unwrap_or_default();
996
997 AgentResponse {
998 version: PROTOCOL_VERSION_2,
999 decision,
1000 request_headers,
1001 response_headers,
1002 routing_metadata: HashMap::new(),
1003 audit,
1004 needs_more: resp.needs_more,
1005 request_body_mutation: None,
1006 response_body_mutation: None,
1007 websocket_decision: None,
1008 }
1009}
1010
1011fn convert_header_op_from_grpc(op: grpc_v2::HeaderOp) -> Option<HeaderOp> {
1012 match op.operation {
1013 Some(grpc_v2::header_op::Operation::Set(h)) => Some(HeaderOp::Set {
1014 name: h.name,
1015 value: h.value,
1016 }),
1017 Some(grpc_v2::header_op::Operation::Add(h)) => Some(HeaderOp::Add {
1018 name: h.name,
1019 value: h.value,
1020 }),
1021 Some(grpc_v2::header_op::Operation::Remove(name)) => Some(HeaderOp::Remove { name }),
1022 None => None,
1023 }
1024}
1025
1026fn convert_metrics_from_grpc(
1027 report: grpc_v2::MetricsReport,
1028 agent_id: &str,
1029) -> crate::v2::MetricsReport {
1030 use crate::v2::metrics::{CounterMetric, GaugeMetric, HistogramBucket, HistogramMetric};
1031
1032 let counters = report
1033 .counters
1034 .into_iter()
1035 .map(|c| CounterMetric {
1036 name: c.name,
1037 help: c.help.filter(|s| !s.is_empty()),
1038 labels: c.labels,
1039 value: c.value,
1040 })
1041 .collect();
1042
1043 let gauges = report
1044 .gauges
1045 .into_iter()
1046 .map(|g| GaugeMetric {
1047 name: g.name,
1048 help: g.help.filter(|s| !s.is_empty()),
1049 labels: g.labels,
1050 value: g.value,
1051 })
1052 .collect();
1053
1054 let histograms = report
1055 .histograms
1056 .into_iter()
1057 .map(|h| HistogramMetric {
1058 name: h.name,
1059 help: h.help.filter(|s| !s.is_empty()),
1060 labels: h.labels,
1061 sum: h.sum,
1062 count: h.count,
1063 buckets: h
1064 .buckets
1065 .into_iter()
1066 .map(|b| HistogramBucket {
1067 le: b.le,
1068 count: b.count,
1069 })
1070 .collect(),
1071 })
1072 .collect();
1073
1074 crate::v2::MetricsReport {
1075 agent_id: agent_id.to_string(),
1076 timestamp_ms: report.timestamp_ms,
1077 interval_ms: report.interval_ms,
1078 counters,
1079 gauges,
1080 histograms,
1081 }
1082}
1083
1084fn convert_config_update_from_grpc(
1085 update: grpc_v2::ConfigUpdateRequest,
1086) -> crate::v2::ConfigUpdateRequest {
1087 use crate::v2::control::{ConfigUpdateType, RuleDefinition};
1088
1089 let update_type = match update.update_type {
1090 Some(grpc_v2::config_update_request::UpdateType::RequestReload(_)) => {
1091 ConfigUpdateType::RequestReload
1092 }
1093 Some(grpc_v2::config_update_request::UpdateType::RuleUpdate(ru)) => {
1094 ConfigUpdateType::RuleUpdate {
1095 rule_set: ru.rule_set,
1096 rules: ru
1097 .rules
1098 .into_iter()
1099 .map(|r| RuleDefinition {
1100 id: r.id,
1101 priority: r.priority,
1102 definition: serde_json::from_str(&r.definition_json).unwrap_or_default(),
1103 enabled: r.enabled,
1104 description: r.description,
1105 tags: r.tags,
1106 })
1107 .collect(),
1108 remove_rules: ru.remove_rules,
1109 }
1110 }
1111 Some(grpc_v2::config_update_request::UpdateType::ListUpdate(lu)) => {
1112 ConfigUpdateType::ListUpdate {
1113 list_id: lu.list_id,
1114 add: lu.add,
1115 remove: lu.remove,
1116 }
1117 }
1118 Some(grpc_v2::config_update_request::UpdateType::RestartRequired(rr)) => {
1119 ConfigUpdateType::RestartRequired {
1120 reason: rr.reason,
1121 grace_period_ms: rr.grace_period_ms,
1122 }
1123 }
1124 Some(grpc_v2::config_update_request::UpdateType::ConfigError(ce)) => {
1125 ConfigUpdateType::ConfigError {
1126 error: ce.error,
1127 field: ce.field,
1128 }
1129 }
1130 None => ConfigUpdateType::RequestReload, };
1132
1133 crate::v2::ConfigUpdateRequest {
1134 update_type,
1135 request_id: update.request_id,
1136 timestamp_ms: update.timestamp_ms,
1137 }
1138}
1139
1140fn extract_correlation_id<T: serde::Serialize>(event: &T) -> String {
1141 if let Ok(value) = serde_json::to_value(event) {
1143 if let Some(metadata) = value.get("metadata") {
1144 if let Some(cid) = metadata.get("correlation_id").and_then(|v| v.as_str()) {
1145 return cid.to_string();
1146 }
1147 }
1148 if let Some(cid) = value.get("correlation_id").and_then(|v| v.as_str()) {
1149 return cid.to_string();
1150 }
1151 }
1152 uuid::Uuid::new_v4().to_string()
1153}
1154
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch.
fn now_ms() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
1161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_type_conversion() {
        assert_eq!(i32_to_event_type(1), Some(EventType::RequestHeaders));
        assert_eq!(i32_to_event_type(2), Some(EventType::RequestBodyChunk));
        assert_eq!(i32_to_event_type(8), Some(EventType::Configure));
        assert_eq!(i32_to_event_type(0), None);
        assert_eq!(i32_to_event_type(99), None);
    }

    #[test]
    fn test_extract_correlation_id() {
        #[derive(serde::Serialize)]
        struct TestEvent {
            correlation_id: String,
        }

        let event = TestEvent {
            correlation_id: "test-123".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "test-123");
    }

    /// `metadata.correlation_id` takes precedence over a top-level field.
    #[test]
    fn test_extract_correlation_id_prefers_metadata() {
        #[derive(serde::Serialize)]
        struct Metadata {
            correlation_id: String,
        }

        #[derive(serde::Serialize)]
        struct NestedEvent {
            metadata: Metadata,
            correlation_id: String,
        }

        let event = NestedEvent {
            metadata: Metadata {
                correlation_id: "nested".to_string(),
            },
            correlation_id: "top".to_string(),
        };

        assert_eq!(extract_correlation_id(&event), "nested");
    }

    /// Events without any correlation ID get a generated UUID.
    #[test]
    fn test_extract_correlation_id_falls_back_to_uuid() {
        #[derive(serde::Serialize)]
        struct NoId {
            value: u32,
        }

        let id = extract_correlation_id(&NoId { value: 1 });
        assert!(uuid::Uuid::parse_str(&id).is_ok());
    }

    #[test]
    fn test_cancel_reason_codes() {
        assert_eq!(CancelReason::ClientDisconnect.to_grpc(), 1);
        assert_eq!(CancelReason::Timeout.to_grpc(), 2);
        assert_eq!(CancelReason::Manual.to_grpc(), 6);
    }

    #[test]
    fn test_flow_state_default_is_normal() {
        assert_eq!(FlowState::default(), FlowState::Normal);
    }
}
1186}