1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9
10use tracing::{debug, error, info, trace, warn};
11use grapsus_agent_protocol::v2::{
12 AgentCapabilities, AgentPool, AgentPoolConfig as ProtocolPoolConfig, AgentPoolStats,
13 CancelReason, ConfigPusher, ConfigUpdateType, LoadBalanceStrategy as ProtocolLBStrategy,
14 MetricsCollector,
15};
16use grapsus_agent_protocol::{
17 AgentResponse, EventType, GuardrailInspectEvent, RequestBodyChunkEvent, RequestHeadersEvent,
18 ResponseBodyChunkEvent, ResponseHeadersEvent,
19};
20use grapsus_common::{
21 errors::{GrapsusError, GrapsusResult},
22 CircuitBreaker,
23};
24use grapsus_config::{AgentConfig, AgentEvent, FailureMode, LoadBalanceStrategy};
25
26use super::metrics::AgentMetrics;
27
28const NO_TIMESTAMP: u64 = 0;
30
31pub struct AgentV2 {
33 config: AgentConfig,
35 pool: Arc<AgentPool>,
37 circuit_breaker: Arc<CircuitBreaker>,
39 metrics: Arc<AgentMetrics>,
41 base_instant: Instant,
43 last_success_ns: AtomicU64,
45 consecutive_failures: AtomicU32,
47}
48
49impl AgentV2 {
50 pub fn new(config: AgentConfig, circuit_breaker: Arc<CircuitBreaker>) -> Self {
52 trace!(
53 agent_id = %config.id,
54 agent_type = ?config.agent_type,
55 timeout_ms = config.timeout_ms,
56 events = ?config.events,
57 "Creating v2 agent instance"
58 );
59
60 let pool_config = config
62 .pool
63 .as_ref()
64 .map(|p| ProtocolPoolConfig {
65 connections_per_agent: p.connections_per_agent,
66 load_balance_strategy: convert_lb_strategy(p.load_balance_strategy),
67 connect_timeout: Duration::from_millis(p.connect_timeout_ms),
68 request_timeout: Duration::from_millis(config.timeout_ms),
69 reconnect_interval: Duration::from_millis(p.reconnect_interval_ms),
70 max_reconnect_attempts: p.max_reconnect_attempts,
71 drain_timeout: Duration::from_millis(p.drain_timeout_ms),
72 max_concurrent_per_connection: p.max_concurrent_per_connection,
73 health_check_interval: Duration::from_millis(p.health_check_interval_ms),
74 ..Default::default()
75 })
76 .unwrap_or_default();
77
78 let pool = Arc::new(AgentPool::with_config(pool_config));
79
80 Self {
81 config,
82 pool,
83 circuit_breaker,
84 metrics: Arc::new(AgentMetrics::default()),
85 base_instant: Instant::now(),
86 last_success_ns: AtomicU64::new(NO_TIMESTAMP),
87 consecutive_failures: AtomicU32::new(0),
88 }
89 }
90
91 pub fn id(&self) -> &str {
93 &self.config.id
94 }
95
96 pub fn circuit_breaker(&self) -> &CircuitBreaker {
98 &self.circuit_breaker
99 }
100
101 pub fn failure_mode(&self) -> FailureMode {
103 self.config.failure_mode
104 }
105
106 pub fn timeout_ms(&self) -> u64 {
108 self.config.timeout_ms
109 }
110
111 pub fn metrics(&self) -> &AgentMetrics {
113 &self.metrics
114 }
115
116 pub fn handles_event(&self, event_type: EventType) -> bool {
118 self.config.events.iter().any(|e| match (e, event_type) {
119 (AgentEvent::RequestHeaders, EventType::RequestHeaders) => true,
120 (AgentEvent::RequestBody, EventType::RequestBodyChunk) => true,
121 (AgentEvent::ResponseHeaders, EventType::ResponseHeaders) => true,
122 (AgentEvent::ResponseBody, EventType::ResponseBodyChunk) => true,
123 (AgentEvent::Log, EventType::RequestComplete) => true,
124 (AgentEvent::WebSocketFrame, EventType::WebSocketFrame) => true,
125 (AgentEvent::Guardrail, EventType::GuardrailInspect) => true,
126 _ => false,
127 })
128 }
129
130 pub async fn initialize(&self) -> GrapsusResult<()> {
132 let endpoint = self.get_endpoint()?;
133
134 debug!(
135 agent_id = %self.config.id,
136 endpoint = %endpoint,
137 "Initializing v2 agent pool"
138 );
139
140 let start = Instant::now();
141
142 self.pool
144 .add_agent(&self.config.id, &endpoint)
145 .await
146 .map_err(|e| {
147 error!(
148 agent_id = %self.config.id,
149 endpoint = %endpoint,
150 error = %e,
151 "Failed to add agent to v2 pool"
152 );
153 GrapsusError::Agent {
154 agent: self.config.id.clone(),
155 message: format!("Failed to initialize v2 agent: {}", e),
156 event: "initialize".to_string(),
157 source: None,
158 }
159 })?;
160
161 info!(
162 agent_id = %self.config.id,
163 endpoint = %endpoint,
164 connect_time_ms = start.elapsed().as_millis(),
165 "V2 agent pool initialized"
166 );
167
168 if let Some(config_value) = &self.config.config {
170 self.send_configure(config_value.clone()).await?;
171 }
172
173 Ok(())
174 }
175
176 fn get_endpoint(&self) -> GrapsusResult<String> {
178 use grapsus_config::AgentTransport;
179 match &self.config.transport {
180 AgentTransport::Grpc { address, .. } => Ok(address.clone()),
181 AgentTransport::UnixSocket { path } => {
182 Ok(format!("unix:{}", path.display()))
184 }
185 AgentTransport::Http { url, .. } => {
186 Err(GrapsusError::Agent {
188 agent: self.config.id.clone(),
189 message: "HTTP transport not supported for v2 protocol".to_string(),
190 event: "initialize".to_string(),
191 source: None,
192 })
193 }
194 }
195 }
196
197 async fn send_configure(&self, _config: serde_json::Value) -> GrapsusResult<()> {
199 use grapsus_agent_protocol::v2::ConfigUpdateType;
200
201 if let Some(push_id) = self
202 .pool
203 .push_config_to_agent(&self.config.id, ConfigUpdateType::RequestReload)
204 {
205 info!(
206 agent_id = %self.config.id,
207 push_id = %push_id,
208 "Configuration push sent to agent"
209 );
210 Ok(())
211 } else {
212 debug!(
213 agent_id = %self.config.id,
214 "Agent does not support config push, config will be sent on next connection"
215 );
216 Ok(())
217 }
218 }
219
220 pub async fn call_request_headers(
222 &self,
223 event: &RequestHeadersEvent,
224 ) -> GrapsusResult<AgentResponse> {
225 let call_num = self.metrics.calls_total.fetch_add(1, Ordering::Relaxed) + 1;
226
227 let correlation_id = &event.metadata.correlation_id;
229
230 trace!(
231 agent_id = %self.config.id,
232 call_num = call_num,
233 correlation_id = %correlation_id,
234 "Sending request headers to v2 agent"
235 );
236
237 self.pool
238 .send_request_headers(&self.config.id, correlation_id, event)
239 .await
240 .map_err(|e| {
241 error!(
242 agent_id = %self.config.id,
243 correlation_id = %correlation_id,
244 error = %e,
245 "V2 agent request headers call failed"
246 );
247 GrapsusError::Agent {
248 agent: self.config.id.clone(),
249 message: e.to_string(),
250 event: "request_headers".to_string(),
251 source: None,
252 }
253 })
254 }
255
256 pub async fn call_request_body_chunk(
261 &self,
262 event: &RequestBodyChunkEvent,
263 ) -> GrapsusResult<AgentResponse> {
264 let correlation_id = &event.correlation_id;
265
266 trace!(
267 agent_id = %self.config.id,
268 correlation_id = %correlation_id,
269 chunk_index = event.chunk_index,
270 is_last = event.is_last,
271 "Sending request body chunk to v2 agent"
272 );
273
274 self.pool
275 .send_request_body_chunk(&self.config.id, correlation_id, event)
276 .await
277 .map_err(|e| {
278 error!(
279 agent_id = %self.config.id,
280 correlation_id = %correlation_id,
281 error = %e,
282 "V2 agent request body chunk call failed"
283 );
284 GrapsusError::Agent {
285 agent: self.config.id.clone(),
286 message: e.to_string(),
287 event: "request_body_chunk".to_string(),
288 source: None,
289 }
290 })
291 }
292
293 pub async fn call_response_headers(
298 &self,
299 event: &ResponseHeadersEvent,
300 ) -> GrapsusResult<AgentResponse> {
301 let correlation_id = &event.correlation_id;
302
303 trace!(
304 agent_id = %self.config.id,
305 correlation_id = %correlation_id,
306 status = event.status,
307 "Sending response headers to v2 agent"
308 );
309
310 self.pool
311 .send_response_headers(&self.config.id, correlation_id, event)
312 .await
313 .map_err(|e| {
314 error!(
315 agent_id = %self.config.id,
316 correlation_id = %correlation_id,
317 error = %e,
318 "V2 agent response headers call failed"
319 );
320 GrapsusError::Agent {
321 agent: self.config.id.clone(),
322 message: e.to_string(),
323 event: "response_headers".to_string(),
324 source: None,
325 }
326 })
327 }
328
329 pub async fn call_response_body_chunk(
334 &self,
335 event: &ResponseBodyChunkEvent,
336 ) -> GrapsusResult<AgentResponse> {
337 let correlation_id = &event.correlation_id;
338
339 trace!(
340 agent_id = %self.config.id,
341 correlation_id = %correlation_id,
342 chunk_index = event.chunk_index,
343 is_last = event.is_last,
344 "Sending response body chunk to v2 agent"
345 );
346
347 self.pool
348 .send_response_body_chunk(&self.config.id, correlation_id, event)
349 .await
350 .map_err(|e| {
351 error!(
352 agent_id = %self.config.id,
353 correlation_id = %correlation_id,
354 error = %e,
355 "V2 agent response body chunk call failed"
356 );
357 GrapsusError::Agent {
358 agent: self.config.id.clone(),
359 message: e.to_string(),
360 event: "response_body_chunk".to_string(),
361 source: None,
362 }
363 })
364 }
365
366 pub async fn call_guardrail_inspect(
368 &self,
369 event: &GuardrailInspectEvent,
370 ) -> GrapsusResult<AgentResponse> {
371 let call_num = self.metrics.calls_total.fetch_add(1, Ordering::Relaxed) + 1;
372
373 let correlation_id = &event.correlation_id;
374
375 trace!(
376 agent_id = %self.config.id,
377 call_num = call_num,
378 correlation_id = %correlation_id,
379 inspection_type = ?event.inspection_type,
380 "Sending guardrail inspect to v2 agent"
381 );
382
383 self.pool
384 .send_guardrail_inspect(&self.config.id, correlation_id, event)
385 .await
386 .map_err(|e| {
387 error!(
388 agent_id = %self.config.id,
389 correlation_id = %correlation_id,
390 error = %e,
391 "V2 agent guardrail inspect call failed"
392 );
393 GrapsusError::Agent {
394 agent: self.config.id.clone(),
395 message: e.to_string(),
396 event: "guardrail_inspect".to_string(),
397 source: None,
398 }
399 })
400 }
401
402 pub async fn call_event<T: serde::Serialize>(
407 &self,
408 event_type: EventType,
409 event: &T,
410 ) -> GrapsusResult<AgentResponse> {
411 let json = serde_json::to_value(event).map_err(|e| GrapsusError::Agent {
412 agent: self.config.id.clone(),
413 message: format!("Failed to serialize event: {}", e),
414 event: format!("{:?}", event_type),
415 source: None,
416 })?;
417
418 match event_type {
419 EventType::RequestHeaders => {
420 let typed: RequestHeadersEvent =
421 serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
422 agent: self.config.id.clone(),
423 message: format!("Failed to deserialize RequestHeadersEvent: {}", e),
424 event: format!("{:?}", event_type),
425 source: None,
426 })?;
427 self.call_request_headers(&typed).await
428 }
429 EventType::RequestBodyChunk => {
430 let typed: RequestBodyChunkEvent =
431 serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
432 agent: self.config.id.clone(),
433 message: format!("Failed to deserialize RequestBodyChunkEvent: {}", e),
434 event: format!("{:?}", event_type),
435 source: None,
436 })?;
437 self.call_request_body_chunk(&typed).await
438 }
439 EventType::ResponseHeaders => {
440 let typed: ResponseHeadersEvent =
441 serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
442 agent: self.config.id.clone(),
443 message: format!("Failed to deserialize ResponseHeadersEvent: {}", e),
444 event: format!("{:?}", event_type),
445 source: None,
446 })?;
447 self.call_response_headers(&typed).await
448 }
449 EventType::ResponseBodyChunk => {
450 let typed: ResponseBodyChunkEvent =
451 serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
452 agent: self.config.id.clone(),
453 message: format!("Failed to deserialize ResponseBodyChunkEvent: {}", e),
454 event: format!("{:?}", event_type),
455 source: None,
456 })?;
457 self.call_response_body_chunk(&typed).await
458 }
459 EventType::GuardrailInspect => {
460 let typed: GuardrailInspectEvent =
461 serde_json::from_value(json).map_err(|e| GrapsusError::Agent {
462 agent: self.config.id.clone(),
463 message: format!("Failed to deserialize GuardrailInspectEvent: {}", e),
464 event: format!("{:?}", event_type),
465 source: None,
466 })?;
467 self.call_guardrail_inspect(&typed).await
468 }
469 _ => Err(GrapsusError::Agent {
470 agent: self.config.id.clone(),
471 message: format!("Unsupported event type {:?}", event_type),
472 event: format!("{:?}", event_type),
473 source: None,
474 }),
475 }
476 }
477
478 pub async fn cancel_request(
480 &self,
481 correlation_id: &str,
482 reason: CancelReason,
483 ) -> GrapsusResult<()> {
484 trace!(
485 agent_id = %self.config.id,
486 correlation_id = %correlation_id,
487 reason = ?reason,
488 "Cancelling request on v2 agent"
489 );
490
491 self.pool
492 .cancel_request(&self.config.id, correlation_id, reason)
493 .await
494 .map_err(|e| {
495 warn!(
496 agent_id = %self.config.id,
497 correlation_id = %correlation_id,
498 error = %e,
499 "Failed to cancel request on v2 agent"
500 );
501 GrapsusError::Agent {
502 agent: self.config.id.clone(),
503 message: format!("Cancel failed: {}", e),
504 event: "cancel".to_string(),
505 source: None,
506 }
507 })
508 }
509
510 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
512 self.pool.agent_capabilities(&self.config.id).await
513 }
514
515 pub async fn is_healthy(&self) -> bool {
517 self.pool.is_agent_healthy(&self.config.id)
518 }
519
520 pub fn record_success(&self, duration: Duration) {
522 let success_count = self.metrics.calls_success.fetch_add(1, Ordering::Relaxed) + 1;
523 self.metrics
524 .duration_total_us
525 .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
526 self.consecutive_failures.store(0, Ordering::Relaxed);
527 self.last_success_ns.store(
528 self.base_instant.elapsed().as_nanos() as u64,
529 Ordering::Relaxed,
530 );
531
532 trace!(
533 agent_id = %self.config.id,
534 duration_ms = duration.as_millis(),
535 total_successes = success_count,
536 "Recorded v2 agent call success"
537 );
538
539 self.circuit_breaker.record_success();
540 }
541
542 #[inline]
544 pub fn time_since_last_success(&self) -> Option<Duration> {
545 let last_ns = self.last_success_ns.load(Ordering::Relaxed);
546 if last_ns == NO_TIMESTAMP {
547 return None;
548 }
549 let current_ns = self.base_instant.elapsed().as_nanos() as u64;
550 Some(Duration::from_nanos(current_ns.saturating_sub(last_ns)))
551 }
552
553 pub fn record_failure(&self) {
555 let fail_count = self.metrics.calls_failed.fetch_add(1, Ordering::Relaxed) + 1;
556 let consecutive = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
557
558 debug!(
559 agent_id = %self.config.id,
560 total_failures = fail_count,
561 consecutive_failures = consecutive,
562 "Recorded v2 agent call failure"
563 );
564
565 self.circuit_breaker.record_failure();
566 }
567
568 pub fn record_timeout(&self) {
570 let timeout_count = self.metrics.calls_timeout.fetch_add(1, Ordering::Relaxed) + 1;
571 let consecutive = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
572
573 debug!(
574 agent_id = %self.config.id,
575 total_timeouts = timeout_count,
576 consecutive_failures = consecutive,
577 timeout_ms = self.config.timeout_ms,
578 "Recorded v2 agent call timeout"
579 );
580
581 self.circuit_breaker.record_failure();
582 }
583
584 pub async fn pool_stats(&self) -> Option<AgentPoolStats> {
586 self.pool.agent_stats(&self.config.id).await
587 }
588
589 pub fn pool_metrics_collector(&self) -> &MetricsCollector {
594 self.pool.metrics_collector()
595 }
596
597 pub fn pool_metrics_collector_arc(&self) -> Arc<MetricsCollector> {
601 self.pool.metrics_collector_arc()
602 }
603
604 pub fn export_prometheus(&self) -> String {
609 self.pool.export_prometheus()
610 }
611
612 pub fn config_pusher(&self) -> &ConfigPusher {
617 self.pool.config_pusher()
618 }
619
620 pub fn push_config(&self, update_type: ConfigUpdateType) -> Option<String> {
624 self.pool.push_config_to_agent(&self.config.id, update_type)
625 }
626
627 pub async fn send_configuration(&self, config: serde_json::Value) -> GrapsusResult<()> {
631 if let Some(push_id) = self.push_config(ConfigUpdateType::RequestReload) {
635 debug!(
636 agent_id = %self.config.id,
637 push_id = %push_id,
638 "Configuration push initiated"
639 );
640 Ok(())
641 } else {
642 warn!(
643 agent_id = %self.config.id,
644 "Agent does not support config push"
645 );
646 Err(GrapsusError::Agent {
647 agent: self.config.id.clone(),
648 message: "Agent does not support config push".to_string(),
649 event: "send_configuration".to_string(),
650 source: None,
651 })
652 }
653 }
654
655 pub async fn shutdown(&self) {
659 debug!(
660 agent_id = %self.config.id,
661 "Shutting down v2 agent"
662 );
663
664 if let Err(e) = self.pool.remove_agent(&self.config.id).await {
666 warn!(
667 agent_id = %self.config.id,
668 error = %e,
669 "Error removing agent from pool during shutdown"
670 );
671 }
672
673 let stats = (
674 self.metrics.calls_total.load(Ordering::Relaxed),
675 self.metrics.calls_success.load(Ordering::Relaxed),
676 self.metrics.calls_failed.load(Ordering::Relaxed),
677 self.metrics.calls_timeout.load(Ordering::Relaxed),
678 );
679
680 info!(
681 agent_id = %self.config.id,
682 total_calls = stats.0,
683 successes = stats.1,
684 failures = stats.2,
685 timeouts = stats.3,
686 "V2 agent shutdown complete"
687 );
688 }
689}
690
691fn convert_lb_strategy(strategy: LoadBalanceStrategy) -> ProtocolLBStrategy {
693 match strategy {
694 LoadBalanceStrategy::RoundRobin => ProtocolLBStrategy::RoundRobin,
695 LoadBalanceStrategy::LeastConnections => ProtocolLBStrategy::LeastConnections,
696 LoadBalanceStrategy::HealthBased => ProtocolLBStrategy::HealthBased,
697 LoadBalanceStrategy::Random => ProtocolLBStrategy::Random,
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Every configuration-level strategy maps onto the protocol variant of the same name.
    #[test]
    fn test_convert_lb_strategy() {
        let cases = [
            (LoadBalanceStrategy::RoundRobin, ProtocolLBStrategy::RoundRobin),
            (
                LoadBalanceStrategy::LeastConnections,
                ProtocolLBStrategy::LeastConnections,
            ),
            (LoadBalanceStrategy::HealthBased, ProtocolLBStrategy::HealthBased),
            (LoadBalanceStrategy::Random, ProtocolLBStrategy::Random),
        ];

        for (input, expected) in cases {
            assert_eq!(convert_lb_strategy(input), expected);
        }
    }
}