1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15
16use dashmap::DashMap;
17use tokio::sync::{RwLock, Semaphore};
18use tracing::{debug, info, trace, warn};
19
20use crate::v2::client::{AgentClientV2, CancelReason, ConfigUpdateCallback, MetricsCallback};
21use crate::v2::control::ConfigUpdateType;
22use crate::v2::observability::{ConfigPusher, ConfigUpdateHandler, MetricsCollector};
23use crate::v2::protocol_metrics::ProtocolMetrics;
24use crate::v2::reverse::ReverseConnectionClient;
25use crate::v2::uds::AgentClientV2Uds;
26use crate::v2::AgentCapabilities;
27use crate::{
28 AgentProtocolError, AgentResponse, GuardrailInspectEvent, RequestBodyChunkEvent,
29 RequestHeadersEvent, ResponseBodyChunkEvent, ResponseHeadersEvent,
30};
31
/// Default capacity for internal message channels (see `AgentPoolConfig::channel_buffer_size`).
pub const CHANNEL_BUFFER_SIZE: usize = 64;
42
/// Strategy used to pick one of an agent's pooled connections per request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Cycle through connections in order (the default).
    #[default]
    RoundRobin,
    /// Favor the connection with the fewest active requests.
    LeastConnections,
    /// Favor connections based on their health/error statistics.
    HealthBased,
    /// Pick a connection at random.
    Random,
}
56
/// What to do when an agent has paused request intake (flow control).
///
/// Enforced by `AgentPool::check_flow_control` before each send.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowControlMode {
    /// Reject the request with `FlowControlPaused` (the default).
    #[default]
    FailClosed,

    /// Skip the paused agent and let the request proceed with a
    /// default-allow response.
    FailOpen,

    /// Poll for resumption up to `flow_control_wait_timeout`, then fail
    /// with `FlowControlPaused`.
    WaitAndRetry,
}
84
/// A client session pinned to one pooled connection of a specific agent.
struct StickySession {
    /// Connection all requests for this session are routed to.
    connection: Arc<PooledConnection>,
    /// Agent that owns `connection`.
    agent_id: String,
    /// Creation time; also the epoch for the `last_accessed` offset.
    created_at: Instant,
    /// Milliseconds since `created_at` at the last access, stored as an
    /// offset so it can be updated through `&self` without locking.
    last_accessed: AtomicU64,
}
100
101impl std::fmt::Debug for StickySession {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("StickySession")
104 .field("agent_id", &self.agent_id)
105 .field("created_at", &self.created_at)
106 .finish_non_exhaustive()
107 }
108}
109
/// Tuning knobs for the agent pool.
#[derive(Debug, Clone)]
pub struct AgentPoolConfig {
    /// Number of parallel connections to open per agent.
    pub connections_per_agent: usize,
    /// How a connection is chosen for each request.
    pub load_balance_strategy: LoadBalanceStrategy,
    /// Timeout when establishing a new connection.
    pub connect_timeout: Duration,
    /// Per-request timeout.
    pub request_timeout: Duration,
    /// Minimum wall-clock gap between reconnect attempts.
    pub reconnect_interval: Duration,
    /// Maximum number of reconnect attempts.
    pub max_reconnect_attempts: usize,
    /// Time allowed for in-flight requests to finish when draining.
    pub drain_timeout: Duration,
    /// Concurrency cap (semaphore permits) per pooled connection.
    pub max_concurrent_per_connection: usize,
    /// Interval between background health checks.
    pub health_check_interval: Duration,
    /// Capacity for internal message channels.
    pub channel_buffer_size: usize,
    /// Behavior when an agent has paused intake (see [`FlowControlMode`]).
    pub flow_control_mode: FlowControlMode,
    /// Maximum time `WaitAndRetry` polls for an agent to resume.
    pub flow_control_wait_timeout: Duration,
    /// Idle timeout for sticky sessions; `None` disables expiry.
    pub sticky_session_timeout: Option<Duration>,
}
159
impl Default for AgentPoolConfig {
    /// Conservative defaults: 4 connections per agent, round-robin
    /// selection, fail-closed flow control, 5-minute sticky-session
    /// idle timeout.
    fn default() -> Self {
        Self {
            connections_per_agent: 4,
            load_balance_strategy: LoadBalanceStrategy::RoundRobin,
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(30),
            reconnect_interval: Duration::from_secs(5),
            max_reconnect_attempts: 3,
            drain_timeout: Duration::from_secs(30),
            max_concurrent_per_connection: 100,
            health_check_interval: Duration::from_secs(10),
            channel_buffer_size: CHANNEL_BUFFER_SIZE,
            flow_control_mode: FlowControlMode::FailClosed,
            flow_control_wait_timeout: Duration::from_millis(100),
            sticky_session_timeout: Some(Duration::from_secs(5 * 60)),
        }
    }
}
179
180impl StickySession {
181 fn new(agent_id: String, connection: Arc<PooledConnection>) -> Self {
182 Self {
183 connection,
184 agent_id,
185 created_at: Instant::now(),
186 last_accessed: AtomicU64::new(0),
187 }
188 }
189
190 fn touch(&self) {
191 let offset = self.created_at.elapsed().as_millis() as u64;
192 self.last_accessed.store(offset, Ordering::Relaxed);
193 }
194
195 fn last_accessed(&self) -> Instant {
196 let offset_ms = self.last_accessed.load(Ordering::Relaxed);
197 self.created_at + Duration::from_millis(offset_ms)
198 }
199
200 fn is_expired(&self, timeout: Duration) -> bool {
201 self.last_accessed().elapsed() > timeout
202 }
203}
204
/// Transport abstraction over the three ways a v2 agent can be reached.
pub enum V2Transport {
    /// Outbound gRPC connection.
    Grpc(AgentClientV2),
    /// Unix-domain-socket connection.
    Uds(AgentClientV2Uds),
    /// Agent-initiated (reverse) connection.
    Reverse(ReverseConnectionClient),
}
216
impl V2Transport {
    /// True when the underlying transport is currently connected.
    pub async fn is_connected(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.is_connected().await,
            V2Transport::Uds(client) => client.is_connected().await,
            V2Transport::Reverse(client) => client.is_connected().await,
        }
    }

    /// True when the agent is accepting new requests (flow control not
    /// paused on this transport).
    pub async fn can_accept_requests(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.can_accept_requests().await,
            V2Transport::Uds(client) => client.can_accept_requests().await,
            V2Transport::Reverse(client) => client.can_accept_requests().await,
        }
    }

    /// Capabilities reported by the agent, if known.
    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
        match self {
            V2Transport::Grpc(client) => client.capabilities().await,
            V2Transport::Uds(client) => client.capabilities().await,
            V2Transport::Reverse(client) => client.capabilities().await,
        }
    }

    /// Delegates a request-headers event to the concrete transport.
    pub async fn send_request_headers(
        &self,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_request_headers(correlation_id, event).await
            }
        }
    }

    /// Delegates a request-body-chunk event to the concrete transport.
    pub async fn send_request_body_chunk(
        &self,
        correlation_id: &str,
        event: &RequestBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => {
                client.send_request_body_chunk(correlation_id, event).await
            }
            V2Transport::Uds(client) => client.send_request_body_chunk(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_request_body_chunk(correlation_id, event).await
            }
        }
    }

    /// Delegates a response-headers event to the concrete transport.
    pub async fn send_response_headers(
        &self,
        correlation_id: &str,
        event: &ResponseHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_response_headers(correlation_id, event).await
            }
        }
    }

    /// Delegates a response-body-chunk event to the concrete transport.
    pub async fn send_response_body_chunk(
        &self,
        correlation_id: &str,
        event: &ResponseBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
            V2Transport::Uds(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
            V2Transport::Reverse(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
        }
    }

    /// Delegates a guardrail-inspect event; only the UDS transport
    /// currently implements it — gRPC and reverse return `InvalidMessage`.
    pub async fn send_guardrail_inspect(
        &self,
        correlation_id: &str,
        event: &GuardrailInspectEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(_client) => Err(AgentProtocolError::InvalidMessage(
                "GuardrailInspect events are not yet supported via gRPC".to_string(),
            )),
            V2Transport::Uds(client) => client.send_guardrail_inspect(correlation_id, event).await,
            V2Transport::Reverse(_client) => Err(AgentProtocolError::InvalidMessage(
                "GuardrailInspect events are not yet supported via reverse connections".to_string(),
            )),
        }
    }

    /// Cancels a single in-flight request by correlation id.
    pub async fn cancel_request(
        &self,
        correlation_id: &str,
        reason: CancelReason,
    ) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Uds(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Reverse(client) => client.cancel_request(correlation_id, reason).await,
        }
    }

    /// Cancels every in-flight request; returns how many were cancelled.
    pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_all(reason).await,
            V2Transport::Uds(client) => client.cancel_all(reason).await,
            V2Transport::Reverse(client) => client.cancel_all(reason).await,
        }
    }

    /// Closes the underlying transport.
    pub async fn close(&self) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.close().await,
            V2Transport::Uds(client) => client.close().await,
            V2Transport::Reverse(client) => client.close().await,
        }
    }

    /// Id of the agent this transport talks to.
    pub fn agent_id(&self) -> &str {
        match self {
            V2Transport::Grpc(client) => client.agent_id(),
            V2Transport::Uds(client) => client.agent_id(),
            V2Transport::Reverse(client) => client.agent_id(),
        }
    }
}
370
/// One transport connection plus its usage and health statistics.
struct PooledConnection {
    /// Underlying transport.
    client: V2Transport,
    /// Creation time; epoch for `last_used_offset_ms`.
    created_at: Instant,
    /// Milliseconds since `created_at` at the last use (offset form so it
    /// can be updated through `&self`).
    last_used_offset_ms: AtomicU64,
    /// Requests currently executing on this connection.
    in_flight: AtomicU64,
    /// Lifetime request count.
    request_count: AtomicU64,
    /// Lifetime error count.
    error_count: AtomicU64,
    /// Errors since the last success; reaching 3 marks the connection
    /// unhealthy (see the send paths and `check_and_update_health`).
    consecutive_errors: AtomicU64,
    /// Caps concurrent requests on this connection.
    concurrency_limiter: Semaphore,
    /// Cached health flag; refreshed by `check_and_update_health` and by
    /// the error paths in the pool's send methods.
    healthy_cached: AtomicBool,
}
385
386impl PooledConnection {
387 fn new(client: V2Transport, max_concurrent: usize) -> Self {
388 Self {
389 client,
390 created_at: Instant::now(),
391 last_used_offset_ms: AtomicU64::new(0),
392 in_flight: AtomicU64::new(0),
393 request_count: AtomicU64::new(0),
394 error_count: AtomicU64::new(0),
395 consecutive_errors: AtomicU64::new(0),
396 concurrency_limiter: Semaphore::new(max_concurrent),
397 healthy_cached: AtomicBool::new(true), }
399 }
400
401 fn in_flight(&self) -> u64 {
402 self.in_flight.load(Ordering::Relaxed)
403 }
404
405 fn error_rate(&self) -> f64 {
406 let requests = self.request_count.load(Ordering::Relaxed);
407 let errors = self.error_count.load(Ordering::Relaxed);
408 if requests == 0 {
409 0.0
410 } else {
411 errors as f64 / requests as f64
412 }
413 }
414
415 #[inline]
418 fn is_healthy_cached(&self) -> bool {
419 self.healthy_cached.load(Ordering::Acquire)
420 }
421
422 async fn check_and_update_health(&self) -> bool {
424 let connected = self.client.is_connected().await;
425 let low_errors = self.consecutive_errors.load(Ordering::Relaxed) < 3;
426 let can_accept = self.client.can_accept_requests().await;
427
428 let healthy = connected && low_errors && can_accept;
429 self.healthy_cached.store(healthy, Ordering::Release);
430 healthy
431 }
432
433 #[inline]
435 fn touch(&self) {
436 let offset = self.created_at.elapsed().as_millis() as u64;
437 self.last_used_offset_ms.store(offset, Ordering::Relaxed);
438 }
439
440 fn last_used(&self) -> Instant {
442 let offset_ms = self.last_used_offset_ms.load(Ordering::Relaxed);
443 self.created_at + Duration::from_millis(offset_ms)
444 }
445}
446
/// Point-in-time statistics snapshot for one agent.
#[derive(Debug, Clone)]
pub struct AgentPoolStats {
    /// Agent these statistics describe.
    pub agent_id: String,
    /// Connections currently held in the pool for this agent.
    pub active_connections: usize,
    /// Connections considered healthy.
    pub healthy_connections: usize,
    /// In-flight requests across the agent's connections.
    pub total_in_flight: u64,
    /// Lifetime request count across the agent's connections.
    pub total_requests: u64,
    /// Lifetime error count across the agent's connections.
    pub total_errors: u64,
    /// Error ratio; 0.0 when no requests have completed.
    pub error_rate: f64,
    /// Agent-level health verdict.
    pub is_healthy: bool,
}
467
/// Per-agent state: its connections plus reconnect/health bookkeeping.
struct AgentEntry {
    /// Id of the agent this entry tracks.
    agent_id: String,
    /// Endpoint used to connect; `reverse://<id>` for reverse connections.
    endpoint: String,
    /// Live pooled connections to this agent.
    connections: RwLock<Vec<Arc<PooledConnection>>>,
    /// Capabilities reported by the agent, once known.
    capabilities: RwLock<Option<AgentCapabilities>>,
    /// Cursor for round-robin connection selection.
    round_robin_index: AtomicUsize,
    /// Reconnect attempts made so far.
    reconnect_attempts: AtomicUsize,
    /// Unix-epoch milliseconds of the last reconnect attempt; 0 = never.
    last_reconnect_attempt_ms: AtomicU64,
    /// Cached agent-level health flag.
    healthy: AtomicBool,
}
483
484impl AgentEntry {
485 fn new(agent_id: String, endpoint: String) -> Self {
486 Self {
487 agent_id,
488 endpoint,
489 connections: RwLock::new(Vec::new()),
490 capabilities: RwLock::new(None),
491 round_robin_index: AtomicUsize::new(0),
492 reconnect_attempts: AtomicUsize::new(0),
493 last_reconnect_attempt_ms: AtomicU64::new(0),
494 healthy: AtomicBool::new(true),
495 }
496 }
497
498 fn should_reconnect(&self, interval: Duration) -> bool {
500 let last_ms = self.last_reconnect_attempt_ms.load(Ordering::Relaxed);
501 if last_ms == 0 {
502 return true;
503 }
504 let now_ms = std::time::SystemTime::now()
505 .duration_since(std::time::UNIX_EPOCH)
506 .map(|d| d.as_millis() as u64)
507 .unwrap_or(0);
508 now_ms.saturating_sub(last_ms) > interval.as_millis() as u64
509 }
510
511 fn mark_reconnect_attempt(&self) {
513 let now_ms = std::time::SystemTime::now()
514 .duration_since(std::time::UNIX_EPOCH)
515 .map(|d| d.as_millis() as u64)
516 .unwrap_or(0);
517 self.last_reconnect_attempt_ms
518 .store(now_ms, Ordering::Relaxed);
519 }
520}
521
/// Connection pool multiplexing protocol events across v2 agents.
///
/// Maintains several pooled connections per agent, routes events with
/// optional correlation affinity and sticky sessions, and records
/// protocol-level metrics.
pub struct AgentPool {
    /// Pool-wide tuning parameters.
    config: AgentPoolConfig,
    /// Registered agents, keyed by agent id.
    agents: DashMap<String, Arc<AgentEntry>>,
    /// Lifetime request counter across all agents.
    total_requests: AtomicU64,
    /// Lifetime error counter across all agents.
    total_errors: AtomicU64,
    /// Aggregates metrics reports received from agents.
    metrics_collector: Arc<MetricsCollector>,
    /// Callback that records reports into `metrics_collector`
    /// (built in `with_config`).
    metrics_callback: MetricsCallback,
    /// Pushes configuration updates to agents that support it.
    config_pusher: Arc<ConfigPusher>,
    /// Handles config-update requests originating from agents.
    config_update_handler: Arc<ConfigUpdateHandler>,
    /// Callback that forwards agent requests to `config_update_handler`
    /// (built in `with_config`).
    config_update_callback: ConfigUpdateCallback,
    /// Protocol-level counters and timings.
    protocol_metrics: Arc<ProtocolMetrics>,
    /// Correlation id → connection used for its headers, so body chunks
    /// follow the same connection.
    correlation_affinity: DashMap<String, Arc<PooledConnection>>,
    /// Session id → pinned connection mappings.
    sticky_sessions: DashMap<String, StickySession>,
}
558
559impl AgentPool {
    /// Creates a pool with `AgentPoolConfig::default()`.
    pub fn new() -> Self {
        Self::with_config(AgentPoolConfig::default())
    }
564
565 pub fn with_config(config: AgentPoolConfig) -> Self {
567 let metrics_collector = Arc::new(MetricsCollector::new());
568 let collector_clone = Arc::clone(&metrics_collector);
569
570 let metrics_callback: MetricsCallback = Arc::new(move |report| {
572 collector_clone.record(&report);
573 });
574
575 let config_pusher = Arc::new(ConfigPusher::new());
577 let config_update_handler = Arc::new(ConfigUpdateHandler::new());
578 let handler_clone = Arc::clone(&config_update_handler);
579
580 let config_update_callback: ConfigUpdateCallback = Arc::new(move |agent_id, request| {
582 debug!(
583 agent_id = %agent_id,
584 request_id = %request.request_id,
585 "Processing config update request from agent"
586 );
587 handler_clone.handle(request)
588 });
589
590 Self {
591 config,
592 agents: DashMap::new(),
593 total_requests: AtomicU64::new(0),
594 total_errors: AtomicU64::new(0),
595 metrics_collector,
596 metrics_callback,
597 config_pusher,
598 config_update_handler,
599 config_update_callback,
600 protocol_metrics: Arc::new(ProtocolMetrics::new()),
601 correlation_affinity: DashMap::new(),
602 sticky_sessions: DashMap::new(),
603 }
604 }
605
    /// Protocol-level counters (requests, errors, in-flight, latency).
    pub fn protocol_metrics(&self) -> &ProtocolMetrics {
        &self.protocol_metrics
    }

    /// Shared handle to the protocol metrics, for background tasks.
    pub fn protocol_metrics_arc(&self) -> Arc<ProtocolMetrics> {
        Arc::clone(&self.protocol_metrics)
    }

    /// Collector aggregating metrics reports received from agents.
    pub fn metrics_collector(&self) -> &MetricsCollector {
        &self.metrics_collector
    }

    /// Shared handle to the metrics collector.
    pub fn metrics_collector_arc(&self) -> Arc<MetricsCollector> {
        Arc::clone(&self.metrics_collector)
    }

    /// Renders collected metrics in Prometheus exposition format.
    pub fn export_prometheus(&self) -> String {
        self.metrics_collector.export_prometheus()
    }

    /// Drops the request→connection affinity for a finished correlation id.
    pub fn clear_correlation_affinity(&self, correlation_id: &str) {
        self.correlation_affinity.remove(correlation_id);
    }

    /// Number of correlation ids currently pinned to a connection.
    pub fn correlation_affinity_count(&self) -> usize {
        self.correlation_affinity.len()
    }
648
649 pub fn create_sticky_session(
685 &self,
686 session_id: impl Into<String>,
687 agent_id: &str,
688 ) -> Result<(), AgentProtocolError> {
689 let session_id = session_id.into();
690 let conn = self.select_connection(agent_id)?;
691
692 let session = StickySession::new(agent_id.to_string(), conn);
693 session.touch();
694
695 self.sticky_sessions.insert(session_id.clone(), session);
696
697 debug!(
698 session_id = %session_id,
699 agent_id = %agent_id,
700 "Created sticky session"
701 );
702
703 Ok(())
704 }
705
    /// Looks up the connection pinned to `session_id`, enforcing expiry.
    ///
    /// An expired session is removed lazily here; a live one has its
    /// last-access time refreshed before the connection is returned.
    fn get_sticky_session_conn(&self, session_id: &str) -> Option<Arc<PooledConnection>> {
        let entry = self.sticky_sessions.get(session_id)?;

        if let Some(timeout) = self.config.sticky_session_timeout {
            if entry.is_expired(timeout) {
                // Release the DashMap read guard before removing the key,
                // otherwise `remove` would deadlock against our own `Ref`.
                drop(entry);
                self.sticky_sessions.remove(session_id);
                debug!(session_id = %session_id, "Sticky session expired");
                return None;
            }
        }

        entry.touch();
        Some(Arc::clone(&entry.connection))
    }
726
    /// Refreshes the session's idle timer; returns `false` when the session
    /// no longer exists (or had expired and was lazily removed).
    pub fn refresh_sticky_session(&self, session_id: &str) -> bool {
        self.get_sticky_session_conn(session_id).is_some()
    }
733
734 pub fn has_sticky_session(&self, session_id: &str) -> bool {
736 self.get_sticky_session_conn(session_id).is_some()
737 }
738
    /// Removes a sticky session explicitly (e.g. when the client logs out).
    pub fn clear_sticky_session(&self, session_id: &str) {
        if self.sticky_sessions.remove(session_id).is_some() {
            debug!(session_id = %session_id, "Cleared sticky session");
        }
    }

    /// Number of sticky sessions currently tracked, including any that have
    /// expired but not yet been cleaned up.
    pub fn sticky_session_count(&self) -> usize {
        self.sticky_sessions.len()
    }
754
    /// Agent id a sticky session is pinned to, if the session exists.
    ///
    /// NOTE(review): this does not check expiry — an expired-but-uncleaned
    /// session still reports its agent; confirm callers tolerate that.
    pub fn sticky_session_agent(&self, session_id: &str) -> Option<String> {
        self.sticky_sessions
            .get(session_id)
            .map(|s| s.agent_id.clone())
    }
761
762 pub async fn send_request_headers_with_sticky_session(
772 &self,
773 session_id: &str,
774 agent_id: &str,
775 correlation_id: &str,
776 event: &RequestHeadersEvent,
777 ) -> Result<(AgentResponse, bool), AgentProtocolError> {
778 let start = Instant::now();
779 self.total_requests.fetch_add(1, Ordering::Relaxed);
780 self.protocol_metrics.inc_requests();
781 self.protocol_metrics.inc_in_flight();
782
783 let (conn, used_sticky) =
785 if let Some(sticky_conn) = self.get_sticky_session_conn(session_id) {
786 (sticky_conn, true)
787 } else {
788 (self.select_connection(agent_id)?, false)
789 };
790
791 match self.check_flow_control(&conn, agent_id).await {
793 Ok(true) => {}
794 Ok(false) => {
795 self.protocol_metrics.dec_in_flight();
796 return Ok((AgentResponse::default_allow(), used_sticky));
797 }
798 Err(e) => {
799 self.protocol_metrics.dec_in_flight();
800 return Err(e);
801 }
802 }
803
804 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
806 self.protocol_metrics.dec_in_flight();
807 self.protocol_metrics.inc_connection_errors();
808 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
809 })?;
810
811 conn.in_flight.fetch_add(1, Ordering::Relaxed);
812 conn.touch();
813
814 self.correlation_affinity
816 .insert(correlation_id.to_string(), Arc::clone(&conn));
817
818 let result = conn
819 .client
820 .send_request_headers(correlation_id, event)
821 .await;
822
823 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
824 conn.request_count.fetch_add(1, Ordering::Relaxed);
825 self.protocol_metrics.dec_in_flight();
826 self.protocol_metrics
827 .record_request_duration(start.elapsed());
828
829 match &result {
830 Ok(_) => {
831 conn.consecutive_errors.store(0, Ordering::Relaxed);
832 self.protocol_metrics.inc_responses();
833 }
834 Err(e) => {
835 conn.error_count.fetch_add(1, Ordering::Relaxed);
836 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
837 self.total_errors.fetch_add(1, Ordering::Relaxed);
838
839 match e {
840 AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
841 AgentProtocolError::ConnectionFailed(_)
842 | AgentProtocolError::ConnectionClosed => {
843 self.protocol_metrics.inc_connection_errors();
844 }
845 AgentProtocolError::Serialization(_) => {
846 self.protocol_metrics.inc_serialization_errors();
847 }
848 _ => {}
849 }
850
851 if consecutive >= 3 {
852 conn.healthy_cached.store(false, Ordering::Release);
853 }
854 }
855 }
856
857 result.map(|r| (r, used_sticky))
858 }
859
860 pub fn cleanup_expired_sessions(&self) -> usize {
865 let Some(timeout) = self.config.sticky_session_timeout else {
866 return 0;
867 };
868
869 let mut removed = 0;
870 self.sticky_sessions.retain(|session_id, session| {
871 if session.is_expired(timeout) {
872 debug!(session_id = %session_id, "Removing expired sticky session");
873 removed += 1;
874 false
875 } else {
876 true
877 }
878 });
879
880 if removed > 0 {
881 trace!(removed = removed, "Cleaned up expired sticky sessions");
882 }
883
884 removed
885 }
886
    /// Pusher used to deliver configuration updates to agents.
    pub fn config_pusher(&self) -> &ConfigPusher {
        &self.config_pusher
    }

    /// Handler for config-update requests received from agents.
    pub fn config_update_handler(&self) -> &ConfigUpdateHandler {
        &self.config_update_handler
    }

    /// Pushes a config update to one agent; delegates to
    /// `ConfigPusher::push_to_agent` (returns its push id, if any).
    pub fn push_config_to_agent(
        &self,
        agent_id: &str,
        update_type: ConfigUpdateType,
    ) -> Option<String> {
        self.config_pusher.push_to_agent(agent_id, update_type)
    }

    /// Pushes a config update to every registered agent; returns the
    /// push ids produced by `ConfigPusher::push_to_all`.
    pub fn push_config_to_all(&self, update_type: ConfigUpdateType) -> Vec<String> {
        self.config_pusher.push_to_all(update_type)
    }

    /// Records an agent's acknowledgement of a previously pushed update.
    pub fn acknowledge_config_push(&self, push_id: &str, accepted: bool, error: Option<String>) {
        self.config_pusher.acknowledge(push_id, accepted, error);
    }
919
920 pub async fn add_agent(
924 &self,
925 agent_id: impl Into<String>,
926 endpoint: impl Into<String>,
927 ) -> Result<(), AgentProtocolError> {
928 let agent_id = agent_id.into();
929 let endpoint = endpoint.into();
930
931 info!(agent_id = %agent_id, endpoint = %endpoint, "Adding agent to pool");
932
933 let entry = Arc::new(AgentEntry::new(agent_id.clone(), endpoint.clone()));
934
935 let mut connections = Vec::with_capacity(self.config.connections_per_agent);
937 for i in 0..self.config.connections_per_agent {
938 match self.create_connection(&agent_id, &endpoint).await {
939 Ok(conn) => {
940 connections.push(Arc::new(conn));
941 debug!(
942 agent_id = %agent_id,
943 connection = i,
944 "Created connection"
945 );
946 }
947 Err(e) => {
948 warn!(
949 agent_id = %agent_id,
950 connection = i,
951 error = %e,
952 "Failed to create connection"
953 );
954 }
956 }
957 }
958
959 if connections.is_empty() {
960 return Err(AgentProtocolError::ConnectionFailed(format!(
961 "Failed to create any connections to agent {}",
962 agent_id
963 )));
964 }
965
966 if let Some(conn) = connections.first() {
968 if let Some(caps) = conn.client.capabilities().await {
969 let supports_config_push = caps.features.config_push;
971 let agent_name = caps.name.clone();
972 self.config_pusher
973 .register_agent(&agent_id, &agent_name, supports_config_push);
974 debug!(
975 agent_id = %agent_id,
976 supports_config_push = supports_config_push,
977 "Registered agent with ConfigPusher"
978 );
979
980 *entry.capabilities.write().await = Some(caps);
981 }
982 }
983
984 *entry.connections.write().await = connections;
985 self.agents.insert(agent_id.clone(), entry);
986
987 info!(
988 agent_id = %agent_id,
989 connections = self.config.connections_per_agent,
990 "Agent added to pool"
991 );
992
993 Ok(())
994 }
995
996 pub async fn remove_agent(&self, agent_id: &str) -> Result<(), AgentProtocolError> {
1000 info!(agent_id = %agent_id, "Removing agent from pool");
1001
1002 self.config_pusher.unregister_agent(agent_id);
1004
1005 let (_, entry) = self.agents.remove(agent_id).ok_or_else(|| {
1006 AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
1007 })?;
1008
1009 let connections = entry.connections.read().await;
1011 for conn in connections.iter() {
1012 let _ = conn.client.close().await;
1013 }
1014
1015 info!(agent_id = %agent_id, "Agent removed from pool");
1016 Ok(())
1017 }
1018
    /// Registers an agent-initiated (reverse) connection.
    ///
    /// If the agent already exists, the connection is appended subject to
    /// the `connections_per_agent` cap; otherwise a new agent entry is
    /// created from the supplied `capabilities`.
    ///
    /// # Errors
    /// `ConnectionFailed` when an existing agent is already at its
    /// connection limit.
    pub async fn add_reverse_connection(
        &self,
        agent_id: &str,
        client: ReverseConnectionClient,
        capabilities: AgentCapabilities,
    ) -> Result<(), AgentProtocolError> {
        info!(
            agent_id = %agent_id,
            connection_id = %client.connection_id(),
            "Adding reverse connection to pool"
        );

        let transport = V2Transport::Reverse(client);
        let conn = Arc::new(PooledConnection::new(
            transport,
            self.config.max_concurrent_per_connection,
        ));

        if let Some(entry) = self.agents.get(agent_id) {
            let mut connections = entry.connections.write().await;

            // Cap enforcement happens only on this append path; the
            // new-entry branch below always starts with one connection.
            if connections.len() >= self.config.connections_per_agent {
                warn!(
                    agent_id = %agent_id,
                    current = connections.len(),
                    max = self.config.connections_per_agent,
                    "Reverse connection rejected: at connection limit"
                );
                return Err(AgentProtocolError::ConnectionFailed(format!(
                    "Agent {} already has maximum connections ({})",
                    agent_id, self.config.connections_per_agent
                )));
            }

            // NOTE(review): `capabilities` is ignored on this branch — the
            // existing entry presumably captured them when first created;
            // confirm reconnecting agents cannot change capabilities.
            connections.push(conn);
            info!(
                agent_id = %agent_id,
                total_connections = connections.len(),
                "Added reverse connection to existing agent"
            );
        } else {
            // Reverse connections have no dialable endpoint; use a marker URI.
            let entry = Arc::new(AgentEntry::new(
                agent_id.to_string(),
                format!("reverse://{}", agent_id),
            ));

            let supports_config_push = capabilities.features.config_push;
            let agent_name = capabilities.name.clone();
            self.config_pusher
                .register_agent(agent_id, &agent_name, supports_config_push);
            debug!(
                agent_id = %agent_id,
                supports_config_push = supports_config_push,
                "Registered reverse connection agent with ConfigPusher"
            );

            *entry.capabilities.write().await = Some(capabilities);
            *entry.connections.write().await = vec![conn];
            self.agents.insert(agent_id.to_string(), entry);

            info!(
                agent_id = %agent_id,
                "Created new agent entry for reverse connection"
            );
        }

        Ok(())
    }
1097
1098 async fn check_flow_control(
1104 &self,
1105 conn: &PooledConnection,
1106 agent_id: &str,
1107 ) -> Result<bool, AgentProtocolError> {
1108 if conn.client.can_accept_requests().await {
1109 return Ok(true);
1110 }
1111
1112 match self.config.flow_control_mode {
1113 FlowControlMode::FailClosed => {
1114 self.protocol_metrics.record_flow_rejection();
1115 Err(AgentProtocolError::FlowControlPaused {
1116 agent_id: agent_id.to_string(),
1117 })
1118 }
1119 FlowControlMode::FailOpen => {
1120 debug!(agent_id = %agent_id, "Flow control: agent paused, allowing request (fail-open mode)");
1122 self.protocol_metrics.record_flow_rejection();
1123 Ok(false) }
1125 FlowControlMode::WaitAndRetry => {
1126 let deadline = Instant::now() + self.config.flow_control_wait_timeout;
1128 while Instant::now() < deadline {
1129 tokio::time::sleep(Duration::from_millis(10)).await;
1130 if conn.client.can_accept_requests().await {
1131 trace!(agent_id = %agent_id, "Flow control: agent resumed after wait");
1132 return Ok(true);
1133 }
1134 }
1135 self.protocol_metrics.record_flow_rejection();
1137 Err(AgentProtocolError::FlowControlPaused {
1138 agent_id: agent_id.to_string(),
1139 })
1140 }
1141 }
1142 }
1143
1144 pub async fn send_request_headers(
1155 &self,
1156 agent_id: &str,
1157 correlation_id: &str,
1158 event: &RequestHeadersEvent,
1159 ) -> Result<AgentResponse, AgentProtocolError> {
1160 let start = Instant::now();
1161 self.total_requests.fetch_add(1, Ordering::Relaxed);
1162 self.protocol_metrics.inc_requests();
1163 self.protocol_metrics.inc_in_flight();
1164
1165 let conn = self.select_connection(agent_id)?;
1166
1167 match self.check_flow_control(&conn, agent_id).await {
1169 Ok(true) => {} Ok(false) => {
1171 self.protocol_metrics.dec_in_flight();
1173 return Ok(AgentResponse::default_allow());
1174 }
1175 Err(e) => {
1176 self.protocol_metrics.dec_in_flight();
1177 return Err(e);
1178 }
1179 }
1180
1181 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1183 self.protocol_metrics.dec_in_flight();
1184 self.protocol_metrics.inc_connection_errors();
1185 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1186 })?;
1187
1188 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1189 conn.touch(); self.correlation_affinity
1193 .insert(correlation_id.to_string(), Arc::clone(&conn));
1194
1195 let result = conn
1196 .client
1197 .send_request_headers(correlation_id, event)
1198 .await;
1199
1200 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1201 conn.request_count.fetch_add(1, Ordering::Relaxed);
1202 self.protocol_metrics.dec_in_flight();
1203 self.protocol_metrics
1204 .record_request_duration(start.elapsed());
1205
1206 match &result {
1207 Ok(_) => {
1208 conn.consecutive_errors.store(0, Ordering::Relaxed);
1209 self.protocol_metrics.inc_responses();
1210 }
1211 Err(e) => {
1212 conn.error_count.fetch_add(1, Ordering::Relaxed);
1213 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1214 self.total_errors.fetch_add(1, Ordering::Relaxed);
1215
1216 match e {
1218 AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
1219 AgentProtocolError::ConnectionFailed(_)
1220 | AgentProtocolError::ConnectionClosed => {
1221 self.protocol_metrics.inc_connection_errors();
1222 }
1223 AgentProtocolError::Serialization(_) => {
1224 self.protocol_metrics.inc_serialization_errors();
1225 }
1226 _ => {}
1227 }
1228
1229 if consecutive >= 3 {
1231 conn.healthy_cached.store(false, Ordering::Release);
1232 trace!(agent_id = %agent_id, error = %e, "Connection marked unhealthy after consecutive errors");
1233 }
1234 }
1235 }
1236
1237 result
1238 }
1239
    /// Sends a request body chunk, preferring the connection that handled
    /// this correlation's headers (affinity) so chunks stay in order on one
    /// transport.
    pub async fn send_request_body_chunk(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &RequestBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        // Prefer the headers' connection; fall back to normal selection.
        let conn = if let Some(affinity_conn) = self.correlation_affinity.get(correlation_id) {
            Arc::clone(&affinity_conn)
        } else {
            trace!(correlation_id = %correlation_id, "No affinity found for body chunk, using selection");
            self.select_connection(agent_id)?
        };

        match self.check_flow_control(&conn, agent_id).await {
            Ok(true) => {}
            Ok(false) => {
                // Fail-open mode: skip the paused agent.
                return Ok(AgentResponse::default_allow());
            }
            Err(e) => return Err(e),
        }

        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
        })?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        let result = conn
            .client
            .send_request_body_chunk(correlation_id, event)
            .await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);

        match &result {
            Ok(_) => {
                // Any success resets the consecutive-error streak.
                conn.consecutive_errors.store(0, Ordering::Relaxed);
            }
            Err(_) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);
                // Three consecutive failures marks the connection unhealthy.
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result
    }
1302
    /// Sends response headers to `agent_id` over a freshly selected
    /// connection.
    ///
    /// NOTE(review): unlike the other send methods, this one performs no
    /// `check_flow_control` and records no correlation affinity — confirm
    /// whether response headers are intentionally exempt from flow control
    /// (response body chunks are not).
    pub async fn send_response_headers(
        &self,
        agent_id: &str,
        correlation_id: &str,
        event: &ResponseHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        self.total_requests.fetch_add(1, Ordering::Relaxed);

        let conn = self.select_connection(agent_id)?;

        let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
            AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
        })?;

        conn.in_flight.fetch_add(1, Ordering::Relaxed);
        conn.touch();

        let result = conn
            .client
            .send_response_headers(correlation_id, event)
            .await;

        conn.in_flight.fetch_sub(1, Ordering::Relaxed);
        conn.request_count.fetch_add(1, Ordering::Relaxed);

        match &result {
            Ok(_) => {
                // Any success resets the consecutive-error streak.
                conn.consecutive_errors.store(0, Ordering::Relaxed);
            }
            Err(_) => {
                conn.error_count.fetch_add(1, Ordering::Relaxed);
                let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
                self.total_errors.fetch_add(1, Ordering::Relaxed);
                // Three consecutive failures marks the connection unhealthy.
                if consecutive >= 3 {
                    conn.healthy_cached.store(false, Ordering::Release);
                }
            }
        }

        result
    }
1348
1349 pub async fn send_response_body_chunk(
1354 &self,
1355 agent_id: &str,
1356 correlation_id: &str,
1357 event: &ResponseBodyChunkEvent,
1358 ) -> Result<AgentResponse, AgentProtocolError> {
1359 self.total_requests.fetch_add(1, Ordering::Relaxed);
1360
1361 let conn = self.select_connection(agent_id)?;
1362
1363 match self.check_flow_control(&conn, agent_id).await {
1365 Ok(true) => {} Ok(false) => {
1367 return Ok(AgentResponse::default_allow());
1369 }
1370 Err(e) => return Err(e),
1371 }
1372
1373 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1374 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1375 })?;
1376
1377 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1378 conn.touch();
1379
1380 let result = conn
1381 .client
1382 .send_response_body_chunk(correlation_id, event)
1383 .await;
1384
1385 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1386 conn.request_count.fetch_add(1, Ordering::Relaxed);
1387
1388 match &result {
1389 Ok(_) => {
1390 conn.consecutive_errors.store(0, Ordering::Relaxed);
1391 }
1392 Err(_) => {
1393 conn.error_count.fetch_add(1, Ordering::Relaxed);
1394 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1395 self.total_errors.fetch_add(1, Ordering::Relaxed);
1396 if consecutive >= 3 {
1397 conn.healthy_cached.store(false, Ordering::Release);
1398 }
1399 }
1400 }
1401
1402 result
1403 }
1404
1405 pub async fn send_guardrail_inspect(
1411 &self,
1412 agent_id: &str,
1413 correlation_id: &str,
1414 event: &GuardrailInspectEvent,
1415 ) -> Result<AgentResponse, AgentProtocolError> {
1416 self.total_requests.fetch_add(1, Ordering::Relaxed);
1417 self.protocol_metrics.inc_requests();
1418 self.protocol_metrics.inc_in_flight();
1419
1420 let conn = self.select_connection(agent_id)?;
1421
1422 match self.check_flow_control(&conn, agent_id).await {
1423 Ok(true) => {}
1424 Ok(false) => {
1425 self.protocol_metrics.dec_in_flight();
1426 return Ok(AgentResponse::default_allow());
1427 }
1428 Err(e) => {
1429 self.protocol_metrics.dec_in_flight();
1430 return Err(e);
1431 }
1432 }
1433
1434 let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
1435 self.protocol_metrics.dec_in_flight();
1436 self.protocol_metrics.inc_connection_errors();
1437 AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
1438 })?;
1439
1440 conn.in_flight.fetch_add(1, Ordering::Relaxed);
1441 conn.touch();
1442
1443 let result = conn
1444 .client
1445 .send_guardrail_inspect(correlation_id, event)
1446 .await;
1447
1448 conn.in_flight.fetch_sub(1, Ordering::Relaxed);
1449 conn.request_count.fetch_add(1, Ordering::Relaxed);
1450 self.protocol_metrics.dec_in_flight();
1451
1452 match &result {
1453 Ok(_) => {
1454 conn.consecutive_errors.store(0, Ordering::Relaxed);
1455 self.protocol_metrics.inc_responses();
1456 }
1457 Err(e) => {
1458 conn.error_count.fetch_add(1, Ordering::Relaxed);
1459 let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
1460 self.total_errors.fetch_add(1, Ordering::Relaxed);
1461
1462 match e {
1463 AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
1464 AgentProtocolError::ConnectionFailed(_)
1465 | AgentProtocolError::ConnectionClosed => {
1466 self.protocol_metrics.inc_connection_errors();
1467 }
1468 AgentProtocolError::Serialization(_) => {
1469 self.protocol_metrics.inc_serialization_errors();
1470 }
1471 _ => {}
1472 }
1473
1474 if consecutive >= 3 {
1475 conn.healthy_cached.store(false, Ordering::Release);
1476 }
1477 }
1478 }
1479
1480 result
1481 }
1482
1483 pub async fn cancel_request(
1485 &self,
1486 agent_id: &str,
1487 correlation_id: &str,
1488 reason: CancelReason,
1489 ) -> Result<(), AgentProtocolError> {
1490 let entry = self.agents.get(agent_id).ok_or_else(|| {
1491 AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
1492 })?;
1493
1494 let connections = entry.connections.read().await;
1495 for conn in connections.iter() {
1496 let _ = conn.client.cancel_request(correlation_id, reason).await;
1497 }
1498
1499 Ok(())
1500 }
1501
1502 pub async fn stats(&self) -> Vec<AgentPoolStats> {
1504 let mut stats = Vec::with_capacity(self.agents.len());
1505
1506 for entry_ref in self.agents.iter() {
1507 let agent_id = entry_ref.key().clone();
1508 let entry = entry_ref.value();
1509
1510 let connections = entry.connections.read().await;
1511 let mut healthy_count = 0;
1512 let mut total_in_flight = 0;
1513 let mut total_requests = 0;
1514 let mut total_errors = 0;
1515
1516 for conn in connections.iter() {
1517 if conn.is_healthy_cached() {
1519 healthy_count += 1;
1520 }
1521 total_in_flight += conn.in_flight();
1522 total_requests += conn.request_count.load(Ordering::Relaxed);
1523 total_errors += conn.error_count.load(Ordering::Relaxed);
1524 }
1525
1526 let error_rate = if total_requests == 0 {
1527 0.0
1528 } else {
1529 total_errors as f64 / total_requests as f64
1530 };
1531
1532 stats.push(AgentPoolStats {
1533 agent_id,
1534 active_connections: connections.len(),
1535 healthy_connections: healthy_count,
1536 total_in_flight,
1537 total_requests,
1538 total_errors,
1539 error_rate,
1540 is_healthy: entry.healthy.load(Ordering::Acquire),
1541 });
1542 }
1543
1544 stats
1545 }
1546
1547 pub async fn agent_stats(&self, agent_id: &str) -> Option<AgentPoolStats> {
1549 self.stats()
1550 .await
1551 .into_iter()
1552 .find(|s| s.agent_id == agent_id)
1553 }
1554
1555 pub async fn agent_capabilities(&self, agent_id: &str) -> Option<AgentCapabilities> {
1557 let entry = match self.agents.get(agent_id) {
1559 Some(entry_ref) => Arc::clone(&*entry_ref),
1560 None => return None,
1561 };
1562 let result = entry.capabilities.read().await.clone();
1564 result
1565 }
1566
1567 pub fn is_agent_healthy(&self, agent_id: &str) -> bool {
1571 self.agents
1572 .get(agent_id)
1573 .map(|e| e.healthy.load(Ordering::Acquire))
1574 .unwrap_or(false)
1575 }
1576
1577 pub fn agent_ids(&self) -> Vec<String> {
1579 self.agents.iter().map(|e| e.key().clone()).collect()
1580 }
1581
    /// Gracefully shuts down the pool: cancels outstanding work on every
    /// agent, waits up to `drain_timeout` for in-flight requests to
    /// finish, then closes all connections.
    ///
    /// Agents are removed from the map before draining, so no new
    /// requests can select their connections mid-shutdown.
    pub async fn shutdown(&self) -> Result<(), AgentProtocolError> {
        info!("Shutting down agent pool");

        // Snapshot the ids first instead of removing while iterating.
        let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

        for agent_id in agent_ids {
            if let Some((_, entry)) = self.agents.remove(&agent_id) {
                debug!(agent_id = %agent_id, "Draining agent connections");

                // Ask every connection to cancel outstanding requests
                // (best-effort; failures are ignored).
                let connections = entry.connections.read().await;
                for conn in connections.iter() {
                    let _ = conn.client.cancel_all(CancelReason::ProxyShutdown).await;
                }

                // Poll the in-flight counters until fully drained or the
                // deadline passes.
                let drain_deadline = Instant::now() + self.config.drain_timeout;
                loop {
                    let total_in_flight: u64 = connections.iter().map(|c| c.in_flight()).sum();
                    if total_in_flight == 0 {
                        break;
                    }
                    if Instant::now() > drain_deadline {
                        warn!(
                            agent_id = %agent_id,
                            in_flight = total_in_flight,
                            "Drain timeout, forcing close"
                        );
                        break;
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }

                // Close every connection whether or not the drain completed.
                for conn in connections.iter() {
                    let _ = conn.client.close().await;
                }
            }
        }

        info!("Agent pool shutdown complete");
        Ok(())
    }
1629
    /// Background maintenance loop: on every tick it expires sticky
    /// sessions, re-probes per-connection health, updates each agent's
    /// aggregate health flag, and tops up connections for agents running
    /// below `connections_per_agent`.
    ///
    /// Never returns; intended to be spawned as a long-lived task.
    pub async fn run_maintenance(&self) {
        let mut interval = tokio::time::interval(self.config.health_check_interval);

        loop {
            interval.tick().await;

            // Drop sticky sessions that have outlived their timeout.
            self.cleanup_expired_sessions();

            // Snapshot ids so we never hold a map iterator across awaits.
            let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();

            for agent_id in agent_ids {
                // Agent may have been removed since the snapshot.
                let Some(entry_ref) = self.agents.get(&agent_id) else {
                    continue; };
                let entry = entry_ref.value().clone();
                // Release the DashMap shard guard before the awaits below.
                drop(entry_ref); let connections = entry.connections.read().await;
                let mut healthy_count = 0;

                for conn in connections.iter() {
                    if conn.check_and_update_health().await {
                        healthy_count += 1;
                    }
                }

                // The agent is healthy while at least one connection is.
                let was_healthy = entry.healthy.load(Ordering::Acquire);
                let is_healthy = healthy_count > 0;
                entry.healthy.store(is_healthy, Ordering::Release);

                // Log only on health transitions, not steady state.
                if was_healthy && !is_healthy {
                    warn!(agent_id = %agent_id, "Agent marked unhealthy");
                } else if !was_healthy && is_healthy {
                    info!(agent_id = %agent_id, "Agent recovered");
                }

                // Reconnect when under-provisioned, rate-limited by
                // `reconnect_interval` via should_reconnect().
                if healthy_count < self.config.connections_per_agent
                    && entry.should_reconnect(self.config.reconnect_interval)
                {
                    // Must drop the read lock first: reconnect_agent takes
                    // the write lock on the same connections list.
                    drop(connections); if let Err(e) = self.reconnect_agent(&agent_id, &entry).await {
                        trace!(agent_id = %agent_id, error = %e, "Reconnect failed");
                    }
                }
            }
        }
    }
1694
1695 async fn create_connection(
1700 &self,
1701 agent_id: &str,
1702 endpoint: &str,
1703 ) -> Result<PooledConnection, AgentProtocolError> {
1704 let transport = if is_uds_endpoint(endpoint) {
1706 let socket_path = endpoint.strip_prefix("unix:").unwrap_or(endpoint);
1708
1709 let mut client =
1710 AgentClientV2Uds::new(agent_id, socket_path, self.config.request_timeout).await?;
1711
1712 client.set_metrics_callback(Arc::clone(&self.metrics_callback));
1714 client.set_config_update_callback(Arc::clone(&self.config_update_callback));
1715
1716 client.connect().await?;
1717 V2Transport::Uds(client)
1718 } else {
1719 let mut client =
1721 AgentClientV2::new(agent_id, endpoint, self.config.request_timeout).await?;
1722
1723 client.set_metrics_callback(Arc::clone(&self.metrics_callback));
1725 client.set_config_update_callback(Arc::clone(&self.config_update_callback));
1726
1727 client.connect().await?;
1728 V2Transport::Grpc(client)
1729 };
1730
1731 Ok(PooledConnection::new(
1732 transport,
1733 self.config.max_concurrent_per_connection,
1734 ))
1735 }
1736
    /// Picks a healthy connection for `agent_id` according to the
    /// configured load-balancing strategy.
    ///
    /// Kept synchronous so the hot send paths avoid an extra await; in
    /// the common case the connections list is uncontended and
    /// `try_read` succeeds immediately.
    ///
    /// NOTE(review): the contended fallback calls
    /// `futures::executor::block_on` on an async `RwLock` read from what
    /// is typically an async context. On a current-thread tokio runtime
    /// this blocks the executor and can deadlock against the writer —
    /// confirm the pool only runs on a multi-threaded runtime.
    ///
    /// # Errors
    ///
    /// `InvalidMessage` for an unknown agent id; `ConnectionFailed` when
    /// the agent has no connections or no healthy ones.
    fn select_connection(
        &self,
        agent_id: &str,
    ) -> Result<Arc<PooledConnection>, AgentProtocolError> {
        let entry = self.agents.get(agent_id).ok_or_else(|| {
            AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
        })?;

        // Fast path: uncontended read. Slow path: block until readable.
        let connections_guard = match entry.connections.try_read() {
            Ok(guard) => guard,
            Err(_) => {
                trace!(agent_id = %agent_id, "select_connection: blocking on connections lock");
                futures::executor::block_on(entry.connections.read())
            }
        };

        if connections_guard.is_empty() {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "No connections available for agent {}",
                agent_id
            )));
        }

        // Only connections whose cached health flag is set are candidates.
        let healthy: Vec<_> = connections_guard
            .iter()
            .filter(|c| c.is_healthy_cached())
            .cloned()
            .collect();

        if healthy.is_empty() {
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "No healthy connections for agent {}",
                agent_id
            )));
        }

        // `healthy` is non-empty from here on, so the unwraps below on
        // min_by_key/min_by cannot fail.
        let selected = match self.config.load_balance_strategy {
            LoadBalanceStrategy::RoundRobin => {
                // Monotonic counter modulo the current healthy-set size.
                let idx = entry.round_robin_index.fetch_add(1, Ordering::Relaxed);
                healthy[idx % healthy.len()].clone()
            }
            LoadBalanceStrategy::LeastConnections => healthy
                .iter()
                .min_by_key(|c| c.in_flight())
                .cloned()
                .unwrap(),
            LoadBalanceStrategy::HealthBased => {
                // Lowest observed error rate wins; NaN comparisons fall
                // back to Equal so partial_cmp cannot cause a panic.
                healthy
                    .iter()
                    .min_by(|a, b| {
                        a.error_rate()
                            .partial_cmp(&b.error_rate())
                            .unwrap_or(std::cmp::Ordering::Equal)
                    })
                    .cloned()
                    .unwrap()
            }
            LoadBalanceStrategy::Random => {
                // Cheap pseudo-randomness from a fresh RandomState hasher;
                // avoids pulling in a rand dependency.
                use std::collections::hash_map::RandomState;
                use std::hash::{BuildHasher, Hasher};
                let idx = RandomState::new().build_hasher().finish() as usize % healthy.len();
                healthy[idx].clone()
            }
        };

        Ok(selected)
    }
1821
1822 async fn reconnect_agent(
1823 &self,
1824 agent_id: &str,
1825 entry: &AgentEntry,
1826 ) -> Result<(), AgentProtocolError> {
1827 entry.mark_reconnect_attempt();
1828 let attempts = entry.reconnect_attempts.fetch_add(1, Ordering::Relaxed);
1829
1830 if attempts >= self.config.max_reconnect_attempts {
1831 debug!(
1832 agent_id = %agent_id,
1833 attempts = attempts,
1834 "Max reconnect attempts reached"
1835 );
1836 return Ok(());
1837 }
1838
1839 debug!(agent_id = %agent_id, attempt = attempts + 1, "Attempting reconnect");
1840
1841 match self.create_connection(agent_id, &entry.endpoint).await {
1842 Ok(conn) => {
1843 let mut connections = entry.connections.write().await;
1844 connections.push(Arc::new(conn));
1845 entry.reconnect_attempts.store(0, Ordering::Relaxed);
1846 info!(agent_id = %agent_id, "Reconnected successfully");
1847 Ok(())
1848 }
1849 Err(e) => {
1850 debug!(agent_id = %agent_id, error = %e, "Reconnect failed");
1851 Err(e)
1852 }
1853 }
1854 }
1855}
1856
1857impl Default for AgentPool {
1858 fn default() -> Self {
1859 Self::new()
1860 }
1861}
1862
1863impl std::fmt::Debug for AgentPool {
1864 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1865 f.debug_struct("AgentPool")
1866 .field("config", &self.config)
1867 .field(
1868 "total_requests",
1869 &self.total_requests.load(Ordering::Relaxed),
1870 )
1871 .field("total_errors", &self.total_errors.load(Ordering::Relaxed))
1872 .finish()
1873 }
1874}
1875
/// Heuristically classifies `endpoint` as a Unix-domain-socket address.
///
/// Matches an explicit `unix:` scheme, an absolute path, or anything
/// ending in `.sock`; everything else is treated as a network address.
fn is_uds_endpoint(endpoint: &str) -> bool {
    if endpoint.starts_with("unix:") {
        return true;
    }
    endpoint.starts_with('/') || endpoint.ends_with(".sock")
}
1885
#[cfg(test)]
mod tests {
    //! Unit tests covering configuration defaults, endpoint
    //! classification, and sticky-session bookkeeping on an empty pool.

    use super::*;

    // --- configuration defaults ---

    #[test]
    fn test_pool_config_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.connections_per_agent, 4);
        assert_eq!(
            config.load_balance_strategy,
            LoadBalanceStrategy::RoundRobin
        );
    }

    #[test]
    fn test_load_balance_strategy() {
        assert_eq!(
            LoadBalanceStrategy::default(),
            LoadBalanceStrategy::RoundRobin
        );
    }

    // --- pool construction ---

    #[test]
    fn test_pool_creation() {
        // A fresh pool starts with zeroed counters.
        let pool = AgentPool::new();
        assert_eq!(pool.total_requests.load(Ordering::Relaxed), 0);
        assert_eq!(pool.total_errors.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_pool_with_config() {
        let config = AgentPoolConfig {
            connections_per_agent: 8,
            load_balance_strategy: LoadBalanceStrategy::LeastConnections,
            ..Default::default()
        };
        let pool = AgentPool::with_config(config.clone());
        assert_eq!(pool.config.connections_per_agent, 8);
    }

    // --- empty-pool queries ---

    #[test]
    fn test_agent_ids_empty() {
        let pool = AgentPool::new();
        assert!(pool.agent_ids().is_empty());
    }

    #[test]
    fn test_is_agent_healthy_not_found() {
        // Unknown agents report as unhealthy rather than panicking.
        let pool = AgentPool::new();
        assert!(!pool.is_agent_healthy("nonexistent"));
    }

    #[tokio::test]
    async fn test_stats_empty() {
        let pool = AgentPool::new();
        assert!(pool.stats().await.is_empty());
    }

    // --- endpoint classification ---

    #[test]
    fn test_is_uds_endpoint() {
        // Explicit unix: scheme.
        assert!(is_uds_endpoint("unix:/var/run/agent.sock"));
        assert!(is_uds_endpoint("unix:agent.sock"));

        // Absolute filesystem paths.
        assert!(is_uds_endpoint("/var/run/agent.sock"));
        assert!(is_uds_endpoint("/tmp/test.sock"));

        // Bare .sock suffix.
        assert!(is_uds_endpoint("agent.sock"));

        // Network addresses are not UDS.
        assert!(!is_uds_endpoint("http://localhost:8080"));
        assert!(!is_uds_endpoint("localhost:50051"));
        assert!(!is_uds_endpoint("127.0.0.1:8080"));
    }

    // --- flow-control configuration ---

    #[test]
    fn test_flow_control_mode_default() {
        assert_eq!(FlowControlMode::default(), FlowControlMode::FailClosed);
    }

    #[test]
    fn test_pool_config_flow_control_defaults() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.channel_buffer_size, CHANNEL_BUFFER_SIZE);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailClosed);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(100));
    }

    #[test]
    fn test_pool_config_custom_flow_control() {
        let config = AgentPoolConfig {
            channel_buffer_size: 128,
            flow_control_mode: FlowControlMode::FailOpen,
            flow_control_wait_timeout: Duration::from_millis(500),
            ..Default::default()
        };
        assert_eq!(config.channel_buffer_size, 128);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailOpen);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(500));
    }

    #[test]
    fn test_pool_config_wait_and_retry() {
        let config = AgentPoolConfig {
            flow_control_mode: FlowControlMode::WaitAndRetry,
            flow_control_wait_timeout: Duration::from_millis(250),
            ..Default::default()
        };
        assert_eq!(config.flow_control_mode, FlowControlMode::WaitAndRetry);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(250));
    }

    // --- sticky-session configuration ---

    #[test]
    fn test_pool_config_sticky_session_default() {
        // Default session timeout is five minutes.
        let config = AgentPoolConfig::default();
        assert_eq!(
            config.sticky_session_timeout,
            Some(Duration::from_secs(5 * 60))
        );
    }

    #[test]
    fn test_pool_config_sticky_session_custom() {
        let config = AgentPoolConfig {
            sticky_session_timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        assert_eq!(config.sticky_session_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_pool_config_sticky_session_disabled() {
        // `None` disables sticky sessions entirely.
        let config = AgentPoolConfig {
            sticky_session_timeout: None,
            ..Default::default()
        };
        assert!(config.sticky_session_timeout.is_none());
    }

    // --- sticky-session bookkeeping on an empty pool ---

    #[test]
    fn test_sticky_session_count_empty() {
        let pool = AgentPool::new();
        assert_eq!(pool.sticky_session_count(), 0);
    }

    #[test]
    fn test_sticky_session_has_nonexistent() {
        let pool = AgentPool::new();
        assert!(!pool.has_sticky_session("nonexistent"));
    }

    #[test]
    fn test_sticky_session_clear_nonexistent() {
        // Clearing an unknown session must be a no-op, not a panic.
        let pool = AgentPool::new();
        pool.clear_sticky_session("nonexistent");
    }

    #[test]
    fn test_cleanup_expired_sessions_empty() {
        let pool = AgentPool::new();
        let removed = pool.cleanup_expired_sessions();
        assert_eq!(removed, 0);
    }
}