Skip to main content

grapsus_proxy/agents/
agent_v2.rs

1//! Protocol v2 agent implementation.
2//!
3//! This module provides v2 agent support using the bidirectional streaming
4//! protocol with capabilities, health reporting, and metrics export.
5
6use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use tracing::{debug, error, info, trace, warn};
11use grapsus_agent_protocol::v2::{
12    AgentCapabilities, AgentPool, AgentPoolConfig as ProtocolPoolConfig, AgentPoolStats,
13    CancelReason, ConfigPusher, ConfigUpdateType, LoadBalanceStrategy as ProtocolLBStrategy,
14    MetricsCollector,
15};
16use grapsus_agent_protocol::{
17    AgentResponse, EventType, GuardrailInspectEvent, RequestBodyChunkEvent, RequestHeadersEvent,
18    ResponseBodyChunkEvent, ResponseHeadersEvent,
19};
20use grapsus_common::{
21    errors::{GrapsusError, GrapsusResult},
22    CircuitBreaker,
23};
24use grapsus_config::{AgentConfig, AgentEvent, FailureMode, LoadBalanceStrategy};
25
26use super::metrics::AgentMetrics;
27
/// Sentinel stored in `last_success_ns` meaning "no successful call recorded yet".
const NO_TIMESTAMP: u64 = 0;
30
/// Protocol v2 agent with connection pooling and bidirectional streaming.
///
/// Owns a v2 [`AgentPool`] in which this agent is registered under its
/// configured ID (see `initialize` / `shutdown`), and tracks call outcomes
/// with atomic counters plus a circuit breaker supplied by the creator.
pub struct AgentV2 {
    /// Agent configuration (ID, transport, subscribed events, timeout, failure mode, pool settings).
    config: AgentConfig,
    /// V2 connection pool; the agent is added in `initialize` and removed in `shutdown`.
    pool: Arc<AgentPool>,
    /// Circuit breaker passed into `new`; updated by the `record_*` methods.
    circuit_breaker: Arc<CircuitBreaker>,
    /// Agent-specific call counters (total, success, failed, timeout, cumulative duration).
    metrics: Arc<AgentMetrics>,
    /// Reference point captured at construction; `last_success_ns` is stored
    /// relative to it so the timestamp fits in an `AtomicU64`.
    base_instant: Instant,
    /// Last successful call (nanoseconds since `base_instant`, `NO_TIMESTAMP` = never).
    last_success_ns: AtomicU64,
    /// Failures/timeouts recorded since the last success (reset by `record_success`).
    consecutive_failures: AtomicU32,
}
48
49impl AgentV2 {
50    /// Create a new v2 agent.
51    pub fn new(config: AgentConfig, circuit_breaker: Arc<CircuitBreaker>) -> Self {
52        trace!(
53            agent_id = %config.id,
54            agent_type = ?config.agent_type,
55            timeout_ms = config.timeout_ms,
56            events = ?config.events,
57            "Creating v2 agent instance"
58        );
59
60        // Convert config pool settings to protocol pool config
61        let pool_config = config
62            .pool
63            .as_ref()
64            .map(|p| ProtocolPoolConfig {
65                connections_per_agent: p.connections_per_agent,
66                load_balance_strategy: convert_lb_strategy(p.load_balance_strategy),
67                connect_timeout: Duration::from_millis(p.connect_timeout_ms),
68                request_timeout: Duration::from_millis(config.timeout_ms),
69                reconnect_interval: Duration::from_millis(p.reconnect_interval_ms),
70                max_reconnect_attempts: p.max_reconnect_attempts,
71                drain_timeout: Duration::from_millis(p.drain_timeout_ms),
72                max_concurrent_per_connection: p.max_concurrent_per_connection,
73                health_check_interval: Duration::from_millis(p.health_check_interval_ms),
74                ..Default::default()
75            })
76            .unwrap_or_default();
77
78        let pool = Arc::new(AgentPool::with_config(pool_config));
79
80        Self {
81            config,
82            pool,
83            circuit_breaker,
84            metrics: Arc::new(AgentMetrics::default()),
85            base_instant: Instant::now(),
86            last_success_ns: AtomicU64::new(NO_TIMESTAMP),
87            consecutive_failures: AtomicU32::new(0),
88        }
89    }
90
91    /// Get the agent ID.
92    pub fn id(&self) -> &str {
93        &self.config.id
94    }
95
96    /// Get the agent's circuit breaker.
97    pub fn circuit_breaker(&self) -> &CircuitBreaker {
98        &self.circuit_breaker
99    }
100
101    /// Get the agent's failure mode.
102    pub fn failure_mode(&self) -> FailureMode {
103        self.config.failure_mode
104    }
105
106    /// Get the agent's timeout in milliseconds.
107    pub fn timeout_ms(&self) -> u64 {
108        self.config.timeout_ms
109    }
110
111    /// Get the agent's metrics.
112    pub fn metrics(&self) -> &AgentMetrics {
113        &self.metrics
114    }
115
116    /// Check if agent handles a specific event type.
117    pub fn handles_event(&self, event_type: EventType) -> bool {
118        self.config.events.iter().any(|e| match (e, event_type) {
119            (AgentEvent::RequestHeaders, EventType::RequestHeaders) => true,
120            (AgentEvent::RequestBody, EventType::RequestBodyChunk) => true,
121            (AgentEvent::ResponseHeaders, EventType::ResponseHeaders) => true,
122            (AgentEvent::ResponseBody, EventType::ResponseBodyChunk) => true,
123            (AgentEvent::Log, EventType::RequestComplete) => true,
124            (AgentEvent::WebSocketFrame, EventType::WebSocketFrame) => true,
125            (AgentEvent::Guardrail, EventType::GuardrailInspect) => true,
126            _ => false,
127        })
128    }
129
130    /// Initialize agent connection(s).
131    pub async fn initialize(&self) -> GrapsusResult<()> {
132        let endpoint = self.get_endpoint()?;
133
134        debug!(
135            agent_id = %self.config.id,
136            endpoint = %endpoint,
137            "Initializing v2 agent pool"
138        );
139
140        let start = Instant::now();
141
142        // Add agent to pool - pool will establish connections
143        self.pool
144            .add_agent(&self.config.id, &endpoint)
145            .await
146            .map_err(|e| {
147                error!(
148                    agent_id = %self.config.id,
149                    endpoint = %endpoint,
150                    error = %e,
151                    "Failed to add agent to v2 pool"
152                );
153                GrapsusError::Agent {
154                    agent: self.config.id.clone(),
155                    message: format!("Failed to initialize v2 agent: {}", e),
156                    event: "initialize".to_string(),
157                    source: None,
158                }
159            })?;
160
161        info!(
162            agent_id = %self.config.id,
163            endpoint = %endpoint,
164            connect_time_ms = start.elapsed().as_millis(),
165            "V2 agent pool initialized"
166        );
167
168        // Send configuration if present
169        if let Some(config_value) = &self.config.config {
170            self.send_configure(config_value.clone()).await?;
171        }
172
173        Ok(())
174    }
175
176    /// Get endpoint from transport config.
177    fn get_endpoint(&self) -> GrapsusResult<String> {
178        use grapsus_config::AgentTransport;
179        match &self.config.transport {
180            AgentTransport::Grpc { address, .. } => Ok(address.clone()),
181            AgentTransport::UnixSocket { path } => {
182                // For UDS, format as unix:path
183                Ok(format!("unix:{}", path.display()))
184            }
185            AgentTransport::Http { url, .. } => {
186                // V2 doesn't support HTTP transport
187                Err(GrapsusError::Agent {
188                    agent: self.config.id.clone(),
189                    message: "HTTP transport not supported for v2 protocol".to_string(),
190                    event: "initialize".to_string(),
191                    source: None,
192                })
193            }
194        }
195    }
196
197    /// Send configuration to the agent via the pool's config push mechanism.
198    async fn send_configure(&self, _config: serde_json::Value) -> GrapsusResult<()> {
199        use grapsus_agent_protocol::v2::ConfigUpdateType;
200
201        if let Some(push_id) = self
202            .pool
203            .push_config_to_agent(&self.config.id, ConfigUpdateType::RequestReload)
204        {
205            info!(
206                agent_id = %self.config.id,
207                push_id = %push_id,
208                "Configuration push sent to agent"
209            );
210            Ok(())
211        } else {
212            debug!(
213                agent_id = %self.config.id,
214                "Agent does not support config push, config will be sent on next connection"
215            );
216            Ok(())
217        }
218    }
219
220    /// Call agent with request headers event.
221    pub async fn call_request_headers(
222        &self,
223        event: &RequestHeadersEvent,
224    ) -> GrapsusResult<AgentResponse> {
225        let call_num = self.metrics.calls_total.fetch_add(1, Ordering::Relaxed) + 1;
226
227        // Get correlation_id from event metadata
228        let correlation_id = &event.metadata.correlation_id;
229
230        trace!(
231            agent_id = %self.config.id,
232            call_num = call_num,
233            correlation_id = %correlation_id,
234            "Sending request headers to v2 agent"
235        );
236
237        self.pool
238            .send_request_headers(&self.config.id, correlation_id, event)
239            .await
240            .map_err(|e| {
241                error!(
242                    agent_id = %self.config.id,
243                    correlation_id = %correlation_id,
244                    error = %e,
245                    "V2 agent request headers call failed"
246                );
247                GrapsusError::Agent {
248                    agent: self.config.id.clone(),
249                    message: e.to_string(),
250                    event: "request_headers".to_string(),
251                    source: None,
252                }
253            })
254    }
255
256    /// Call agent with request body chunk event.
257    ///
258    /// For streaming body inspection, chunks are sent sequentially with
259    /// increasing `chunk_index`. The agent responds after processing each chunk.
260    pub async fn call_request_body_chunk(
261        &self,
262        event: &RequestBodyChunkEvent,
263    ) -> GrapsusResult<AgentResponse> {
264        let correlation_id = &event.correlation_id;
265
266        trace!(
267            agent_id = %self.config.id,
268            correlation_id = %correlation_id,
269            chunk_index = event.chunk_index,
270            is_last = event.is_last,
271            "Sending request body chunk to v2 agent"
272        );
273
274        self.pool
275            .send_request_body_chunk(&self.config.id, correlation_id, event)
276            .await
277            .map_err(|e| {
278                error!(
279                    agent_id = %self.config.id,
280                    correlation_id = %correlation_id,
281                    error = %e,
282                    "V2 agent request body chunk call failed"
283                );
284                GrapsusError::Agent {
285                    agent: self.config.id.clone(),
286                    message: e.to_string(),
287                    event: "request_body_chunk".to_string(),
288                    source: None,
289                }
290            })
291    }
292
293    /// Call agent with response headers event.
294    ///
295    /// Called when upstream response headers are received, allowing the agent
296    /// to inspect/modify response headers before they're sent to the client.
297    pub async fn call_response_headers(
298        &self,
299        event: &ResponseHeadersEvent,
300    ) -> GrapsusResult<AgentResponse> {
301        let correlation_id = &event.correlation_id;
302
303        trace!(
304            agent_id = %self.config.id,
305            correlation_id = %correlation_id,
306            status = event.status,
307            "Sending response headers to v2 agent"
308        );
309
310        self.pool
311            .send_response_headers(&self.config.id, correlation_id, event)
312            .await
313            .map_err(|e| {
314                error!(
315                    agent_id = %self.config.id,
316                    correlation_id = %correlation_id,
317                    error = %e,
318                    "V2 agent response headers call failed"
319                );
320                GrapsusError::Agent {
321                    agent: self.config.id.clone(),
322                    message: e.to_string(),
323                    event: "response_headers".to_string(),
324                    source: None,
325                }
326            })
327    }
328
329    /// Call agent with response body chunk event.
330    ///
331    /// For streaming response body inspection, chunks are sent sequentially.
332    /// The agent can inspect and optionally modify response body data.
333    pub async fn call_response_body_chunk(
334        &self,
335        event: &ResponseBodyChunkEvent,
336    ) -> GrapsusResult<AgentResponse> {
337        let correlation_id = &event.correlation_id;
338
339        trace!(
340            agent_id = %self.config.id,
341            correlation_id = %correlation_id,
342            chunk_index = event.chunk_index,
343            is_last = event.is_last,
344            "Sending response body chunk to v2 agent"
345        );
346
347        self.pool
348            .send_response_body_chunk(&self.config.id, correlation_id, event)
349            .await
350            .map_err(|e| {
351                error!(
352                    agent_id = %self.config.id,
353                    correlation_id = %correlation_id,
354                    error = %e,
355                    "V2 agent response body chunk call failed"
356                );
357                GrapsusError::Agent {
358                    agent: self.config.id.clone(),
359                    message: e.to_string(),
360                    event: "response_body_chunk".to_string(),
361                    source: None,
362                }
363            })
364    }
365
366    /// Call agent with guardrail inspect event.
367    pub async fn call_guardrail_inspect(
368        &self,
369        event: &GuardrailInspectEvent,
370    ) -> GrapsusResult<AgentResponse> {
371        let call_num = self.metrics.calls_total.fetch_add(1, Ordering::Relaxed) + 1;
372
373        let correlation_id = &event.correlation_id;
374
375        trace!(
376            agent_id = %self.config.id,
377            call_num = call_num,
378            correlation_id = %correlation_id,
379            inspection_type = ?event.inspection_type,
380            "Sending guardrail inspect to v2 agent"
381        );
382
383        self.pool
384            .send_guardrail_inspect(&self.config.id, correlation_id, event)
385            .await
386            .map_err(|e| {
387                error!(
388                    agent_id = %self.config.id,
389                    correlation_id = %correlation_id,
390                    error = %e,
391                    "V2 agent guardrail inspect call failed"
392                );
393                GrapsusError::Agent {
394                    agent: self.config.id.clone(),
395                    message: e.to_string(),
396                    event: "guardrail_inspect".to_string(),
397                    source: None,
398                }
399            })
400    }
401
402    /// Call agent with a generic event, dispatching to the appropriate typed method.
403    ///
404    /// The event is serialized and deserialized to convert between the generic
405    /// type and the specific event struct expected by each typed method.
406    pub async fn call_event<T: serde::Serialize>(
407        &self,
408        event_type: EventType,
409        event: &T,
410    ) -> GrapsusResult<AgentResponse> {
411        let json = serde_json::to_value(event).map_err(|e| GrapsusError::Agent {
412            agent: self.config.id.clone(),
413            message: format!("Failed to serialize event: {}", e),
414            event: format!("{:?}", event_type),
415            source: None,
416        })?;
417
418        match event_type {
419            EventType::RequestHeaders => {
420                let typed: RequestHeadersEvent =
421                    serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
422                        agent: self.config.id.clone(),
423                        message: format!("Failed to deserialize RequestHeadersEvent: {}", e),
424                        event: format!("{:?}", event_type),
425                        source: None,
426                    })?;
427                self.call_request_headers(&typed).await
428            }
429            EventType::RequestBodyChunk => {
430                let typed: RequestBodyChunkEvent =
431                    serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
432                        agent: self.config.id.clone(),
433                        message: format!("Failed to deserialize RequestBodyChunkEvent: {}", e),
434                        event: format!("{:?}", event_type),
435                        source: None,
436                    })?;
437                self.call_request_body_chunk(&typed).await
438            }
439            EventType::ResponseHeaders => {
440                let typed: ResponseHeadersEvent =
441                    serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
442                        agent: self.config.id.clone(),
443                        message: format!("Failed to deserialize ResponseHeadersEvent: {}", e),
444                        event: format!("{:?}", event_type),
445                        source: None,
446                    })?;
447                self.call_response_headers(&typed).await
448            }
449            EventType::ResponseBodyChunk => {
450                let typed: ResponseBodyChunkEvent =
451                    serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
452                        agent: self.config.id.clone(),
453                        message: format!("Failed to deserialize ResponseBodyChunkEvent: {}", e),
454                        event: format!("{:?}", event_type),
455                        source: None,
456                    })?;
457                self.call_response_body_chunk(&typed).await
458            }
459            EventType::GuardrailInspect => {
460                let typed: GuardrailInspectEvent =
461                    serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
462                        agent: self.config.id.clone(),
463                        message: format!("Failed to deserialize GuardrailInspectEvent: {}", e),
464                        event: format!("{:?}", event_type),
465                        source: None,
466                    })?;
467                self.call_guardrail_inspect(&typed).await
468            }
469            _ => Err(GrapsusError::Agent {
470                agent: self.config.id.clone(),
471                message: format!("Unsupported event type {:?}", event_type),
472                event: format!("{:?}", event_type),
473                source: None,
474            }),
475        }
476    }
477
478    /// Cancel an in-flight request.
479    pub async fn cancel_request(
480        &self,
481        correlation_id: &str,
482        reason: CancelReason,
483    ) -> GrapsusResult<()> {
484        trace!(
485            agent_id = %self.config.id,
486            correlation_id = %correlation_id,
487            reason = ?reason,
488            "Cancelling request on v2 agent"
489        );
490
491        self.pool
492            .cancel_request(&self.config.id, correlation_id, reason)
493            .await
494            .map_err(|e| {
495                warn!(
496                    agent_id = %self.config.id,
497                    correlation_id = %correlation_id,
498                    error = %e,
499                    "Failed to cancel request on v2 agent"
500                );
501                GrapsusError::Agent {
502                    agent: self.config.id.clone(),
503                    message: format!("Cancel failed: {}", e),
504                    event: "cancel".to_string(),
505                    source: None,
506                }
507            })
508    }
509
510    /// Get agent capabilities.
511    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
512        self.pool.agent_capabilities(&self.config.id).await
513    }
514
515    /// Check if agent is healthy.
516    pub async fn is_healthy(&self) -> bool {
517        self.pool.is_agent_healthy(&self.config.id)
518    }
519
520    /// Record successful call (lock-free).
521    pub fn record_success(&self, duration: Duration) {
522        let success_count = self.metrics.calls_success.fetch_add(1, Ordering::Relaxed) + 1;
523        self.metrics
524            .duration_total_us
525            .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
526        self.consecutive_failures.store(0, Ordering::Relaxed);
527        self.last_success_ns.store(
528            self.base_instant.elapsed().as_nanos() as u64,
529            Ordering::Relaxed,
530        );
531
532        trace!(
533            agent_id = %self.config.id,
534            duration_ms = duration.as_millis(),
535            total_successes = success_count,
536            "Recorded v2 agent call success"
537        );
538
539        self.circuit_breaker.record_success();
540    }
541
542    /// Get the time since last successful call.
543    #[inline]
544    pub fn time_since_last_success(&self) -> Option<Duration> {
545        let last_ns = self.last_success_ns.load(Ordering::Relaxed);
546        if last_ns == NO_TIMESTAMP {
547            return None;
548        }
549        let current_ns = self.base_instant.elapsed().as_nanos() as u64;
550        Some(Duration::from_nanos(current_ns.saturating_sub(last_ns)))
551    }
552
553    /// Record failed call.
554    pub fn record_failure(&self) {
555        let fail_count = self.metrics.calls_failed.fetch_add(1, Ordering::Relaxed) + 1;
556        let consecutive = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
557
558        debug!(
559            agent_id = %self.config.id,
560            total_failures = fail_count,
561            consecutive_failures = consecutive,
562            "Recorded v2 agent call failure"
563        );
564
565        self.circuit_breaker.record_failure();
566    }
567
568    /// Record timeout.
569    pub fn record_timeout(&self) {
570        let timeout_count = self.metrics.calls_timeout.fetch_add(1, Ordering::Relaxed) + 1;
571        let consecutive = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
572
573        debug!(
574            agent_id = %self.config.id,
575            total_timeouts = timeout_count,
576            consecutive_failures = consecutive,
577            timeout_ms = self.config.timeout_ms,
578            "Recorded v2 agent call timeout"
579        );
580
581        self.circuit_breaker.record_failure();
582    }
583
584    /// Get pool statistics.
585    pub async fn pool_stats(&self) -> Option<AgentPoolStats> {
586        self.pool.agent_stats(&self.config.id).await
587    }
588
589    /// Get the pool's metrics collector.
590    ///
591    /// Returns a reference to the shared metrics collector that aggregates
592    /// metrics reports from all agents in this pool.
593    pub fn pool_metrics_collector(&self) -> &MetricsCollector {
594        self.pool.metrics_collector()
595    }
596
597    /// Get an Arc to the pool's metrics collector.
598    ///
599    /// This is useful for registering the collector with a MetricsManager.
600    pub fn pool_metrics_collector_arc(&self) -> Arc<MetricsCollector> {
601        self.pool.metrics_collector_arc()
602    }
603
604    /// Export agent metrics in Prometheus format.
605    ///
606    /// Returns a string containing all metrics collected from agents
607    /// in Prometheus exposition format.
608    pub fn export_prometheus(&self) -> String {
609        self.pool.export_prometheus()
610    }
611
612    /// Get the pool's config pusher.
613    ///
614    /// Returns a reference to the shared config pusher that distributes
615    /// configuration updates to agents.
616    pub fn config_pusher(&self) -> &ConfigPusher {
617        self.pool.config_pusher()
618    }
619
620    /// Push a configuration update to this agent.
621    ///
622    /// Returns the push ID if the agent supports config push, None otherwise.
623    pub fn push_config(&self, update_type: ConfigUpdateType) -> Option<String> {
624        self.pool.push_config_to_agent(&self.config.id, update_type)
625    }
626
627    /// Send a configuration update to this agent via the control stream.
628    ///
629    /// This is a direct config push using the `ConfigureEvent` message.
630    pub async fn send_configuration(&self, config: serde_json::Value) -> GrapsusResult<()> {
631        // Get a connection and send the configure event
632        // For now, we rely on the pool's config push mechanism
633        // which tracks acknowledgments and retries
634        if let Some(push_id) = self.push_config(ConfigUpdateType::RequestReload) {
635            debug!(
636                agent_id = %self.config.id,
637                push_id = %push_id,
638                "Configuration push initiated"
639            );
640            Ok(())
641        } else {
642            warn!(
643                agent_id = %self.config.id,
644                "Agent does not support config push"
645            );
646            Err(GrapsusError::Agent {
647                agent: self.config.id.clone(),
648                message: "Agent does not support config push".to_string(),
649                event: "send_configuration".to_string(),
650                source: None,
651            })
652        }
653    }
654
655    /// Shutdown agent.
656    ///
657    /// This removes the agent from the pool and closes all connections.
658    pub async fn shutdown(&self) {
659        debug!(
660            agent_id = %self.config.id,
661            "Shutting down v2 agent"
662        );
663
664        // Remove from pool - this gracefully closes connections
665        if let Err(e) = self.pool.remove_agent(&self.config.id).await {
666            warn!(
667                agent_id = %self.config.id,
668                error = %e,
669                "Error removing agent from pool during shutdown"
670            );
671        }
672
673        let stats = (
674            self.metrics.calls_total.load(Ordering::Relaxed),
675            self.metrics.calls_success.load(Ordering::Relaxed),
676            self.metrics.calls_failed.load(Ordering::Relaxed),
677            self.metrics.calls_timeout.load(Ordering::Relaxed),
678        );
679
680        info!(
681            agent_id = %self.config.id,
682            total_calls = stats.0,
683            successes = stats.1,
684            failures = stats.2,
685            timeouts = stats.3,
686            "V2 agent shutdown complete"
687        );
688    }
689}
690
691/// Convert config load balance strategy to protocol load balance strategy.
692fn convert_lb_strategy(strategy: LoadBalanceStrategy) -> ProtocolLBStrategy {
693    match strategy {
694        LoadBalanceStrategy::RoundRobin => ProtocolLBStrategy::RoundRobin,
695        LoadBalanceStrategy::LeastConnections => ProtocolLBStrategy::LeastConnections,
696        LoadBalanceStrategy::HealthBased => ProtocolLBStrategy::HealthBased,
697        LoadBalanceStrategy::Random => ProtocolLBStrategy::Random,
698    }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Every config strategy must map to the identically named protocol strategy.
    #[test]
    fn test_convert_lb_strategy() {
        let cases = [
            (LoadBalanceStrategy::RoundRobin, ProtocolLBStrategy::RoundRobin),
            (
                LoadBalanceStrategy::LeastConnections,
                ProtocolLBStrategy::LeastConnections,
            ),
            (LoadBalanceStrategy::HealthBased, ProtocolLBStrategy::HealthBased),
            (LoadBalanceStrategy::Random, ProtocolLBStrategy::Random),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(convert_lb_strategy(*input), *expected);
        }
    }
}