use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, info, trace, warn};
use crate::v2::client::{AgentClientV2, CancelReason, ConfigUpdateCallback, MetricsCallback};
use crate::v2::control::ConfigUpdateType;
use crate::v2::observability::{ConfigPusher, ConfigUpdateHandler, MetricsCollector};
use crate::v2::protocol_metrics::ProtocolMetrics;
use crate::v2::reverse::ReverseConnectionClient;
use crate::v2::uds::AgentClientV2Uds;
use crate::v2::AgentCapabilities;
use crate::{
AgentProtocolError, AgentResponse, GuardrailInspectEvent, RequestBodyChunkEvent,
RequestHeadersEvent, ResponseBodyChunkEvent, ResponseHeadersEvent,
};
/// Default buffer capacity for the bounded channels used by pool transports
/// (see `AgentPoolConfig::channel_buffer_size`).
pub const CHANNEL_BUFFER_SIZE: usize = 64;
/// Strategy for picking one of an agent's pooled connections.
///
/// The selection logic itself (`select_connection`) is defined outside this
/// chunk; variants describe the intended policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Rotate through connections in order (the default).
    #[default]
    RoundRobin,
    /// Prefer the connection with the fewest in-flight requests.
    LeastConnections,
    /// Prefer connections currently considered healthy.
    HealthBased,
    /// Pick a connection at random.
    Random,
}
/// What to do when an agent signals it cannot accept requests
/// (see `AgentPool::check_flow_control`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowControlMode {
    /// Reject the request with `AgentProtocolError::FlowControlPaused` (default).
    #[default]
    FailClosed,
    /// Skip the agent and let callers substitute a default-allow response.
    FailOpen,
    /// Poll briefly for the agent to resume; reject if it never does.
    WaitAndRetry,
}
/// Pins a logical client session to one pooled connection so requests that
/// share a `session_id` keep hitting the same agent connection.
struct StickySession {
    // The connection this session is pinned to.
    connection: Arc<PooledConnection>,
    // Agent the pinned connection belongs to.
    agent_id: String,
    // Creation time; also the base instant for the `last_accessed` offset.
    created_at: Instant,
    // Milliseconds elapsed since `created_at` at the last access. Stored as
    // an atomic offset so `touch` can run through a shared reference.
    last_accessed: AtomicU64,
}
impl std::fmt::Debug for StickySession {
    /// Manual `Debug`: only `agent_id` and `created_at` are printed; the
    /// connection handle and the raw atomic offset are omitted (hence
    /// `finish_non_exhaustive`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StickySession")
            .field("agent_id", &self.agent_id)
            .field("created_at", &self.created_at)
            .finish_non_exhaustive()
    }
}
/// Tuning knobs for [`AgentPool`].
#[derive(Debug, Clone)]
pub struct AgentPoolConfig {
    /// Connections opened per agent by `add_agent`; also the cap enforced on
    /// incoming reverse connections.
    pub connections_per_agent: usize,
    /// Policy used when selecting among an agent's connections.
    pub load_balance_strategy: LoadBalanceStrategy,
    /// Connection-establishment timeout — consumed by `create_connection`
    /// (defined outside this chunk) — TODO confirm.
    pub connect_timeout: Duration,
    /// Per-request timeout — presumably enforced by the transport clients;
    /// not referenced in this chunk.
    pub request_timeout: Duration,
    /// Minimum spacing between reconnect attempts
    /// (see `AgentEntry::should_reconnect`).
    pub reconnect_interval: Duration,
    /// Reconnect attempt budget; enforcement is outside this chunk.
    pub max_reconnect_attempts: usize,
    /// How long `shutdown` waits for in-flight requests to drain before
    /// force-closing connections.
    pub drain_timeout: Duration,
    /// Semaphore permits per pooled connection (per-connection concurrency cap).
    pub max_concurrent_per_connection: usize,
    /// Tick period of `run_maintenance` (health checks + session cleanup).
    pub health_check_interval: Duration,
    /// Buffer size for transport channels; defaults to [`CHANNEL_BUFFER_SIZE`].
    pub channel_buffer_size: usize,
    /// Behavior when an agent is not accepting requests.
    pub flow_control_mode: FlowControlMode,
    /// Maximum time `FlowControlMode::WaitAndRetry` polls for resumption.
    pub flow_control_wait_timeout: Duration,
    /// Idle TTL for sticky sessions; `None` disables expiry.
    pub sticky_session_timeout: Option<Duration>,
}
impl Default for AgentPoolConfig {
    /// Conservative defaults: four connections per agent, round-robin
    /// balancing, fail-closed flow control, 100 in-flight requests per
    /// connection, and a five-minute sticky-session idle timeout.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(300);
        Self {
            connections_per_agent: 4,
            load_balance_strategy: LoadBalanceStrategy::default(),
            connect_timeout: Duration::from_secs(5),
            request_timeout: Duration::from_secs(30),
            reconnect_interval: Duration::from_secs(5),
            max_reconnect_attempts: 3,
            drain_timeout: Duration::from_secs(30),
            max_concurrent_per_connection: 100,
            health_check_interval: Duration::from_secs(10),
            channel_buffer_size: CHANNEL_BUFFER_SIZE,
            flow_control_mode: FlowControlMode::default(),
            flow_control_wait_timeout: Duration::from_millis(100),
            sticky_session_timeout: Some(five_minutes),
        }
    }
}
impl StickySession {
    /// Creates a session pinned to `connection`. The access offset starts at
    /// 0 (i.e. `last_accessed() == created_at`); callers typically call
    /// `touch()` right after construction.
    fn new(agent_id: String, connection: Arc<PooledConnection>) -> Self {
        Self {
            connection,
            agent_id,
            created_at: Instant::now(),
            last_accessed: AtomicU64::new(0),
        }
    }
    /// Records "now" as the last access. Stored as whole milliseconds since
    /// `created_at` so it fits in an atomic — no lock on the hot path.
    fn touch(&self) {
        let offset = self.created_at.elapsed().as_millis() as u64;
        self.last_accessed.store(offset, Ordering::Relaxed);
    }
    /// Reconstructs the last-access `Instant` from the stored ms offset.
    fn last_accessed(&self) -> Instant {
        let offset_ms = self.last_accessed.load(Ordering::Relaxed);
        self.created_at + Duration::from_millis(offset_ms)
    }
    /// True once more than `timeout` has elapsed since the last access.
    fn is_expired(&self, timeout: Duration) -> bool {
        self.last_accessed().elapsed() > timeout
    }
}
/// Unifies the three v2 transport client flavors behind one dispatch enum:
/// proxy-dialed gRPC, Unix domain sockets, and agent-initiated (reverse)
/// connections.
pub enum V2Transport {
    Grpc(AgentClientV2),
    Uds(AgentClientV2Uds),
    Reverse(ReverseConnectionClient),
}
impl V2Transport {
    /// True if the underlying transport currently has a live connection.
    pub async fn is_connected(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.is_connected().await,
            V2Transport::Uds(client) => client.is_connected().await,
            V2Transport::Reverse(client) => client.is_connected().await,
        }
    }
    /// True if the agent is currently accepting requests (flow control).
    pub async fn can_accept_requests(&self) -> bool {
        match self {
            V2Transport::Grpc(client) => client.can_accept_requests().await,
            V2Transport::Uds(client) => client.can_accept_requests().await,
            V2Transport::Reverse(client) => client.can_accept_requests().await,
        }
    }
    /// Capabilities advertised by the agent, if the handshake has completed.
    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
        match self {
            V2Transport::Grpc(client) => client.capabilities().await,
            V2Transport::Uds(client) => client.capabilities().await,
            V2Transport::Reverse(client) => client.capabilities().await,
        }
    }
    /// Delegates a request-headers event to the concrete transport.
    pub async fn send_request_headers(
        &self,
        correlation_id: &str,
        event: &RequestHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_request_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_request_headers(correlation_id, event).await
            }
        }
    }
    /// Delegates a request-body-chunk event to the concrete transport.
    pub async fn send_request_body_chunk(
        &self,
        correlation_id: &str,
        event: &RequestBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => {
                client.send_request_body_chunk(correlation_id, event).await
            }
            V2Transport::Uds(client) => client.send_request_body_chunk(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_request_body_chunk(correlation_id, event).await
            }
        }
    }
    /// Delegates a response-headers event to the concrete transport.
    pub async fn send_response_headers(
        &self,
        correlation_id: &str,
        event: &ResponseHeadersEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Uds(client) => client.send_response_headers(correlation_id, event).await,
            V2Transport::Reverse(client) => {
                client.send_response_headers(correlation_id, event).await
            }
        }
    }
    /// Delegates a response-body-chunk event to the concrete transport.
    pub async fn send_response_body_chunk(
        &self,
        correlation_id: &str,
        event: &ResponseBodyChunkEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
            V2Transport::Uds(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
            V2Transport::Reverse(client) => {
                client.send_response_body_chunk(correlation_id, event).await
            }
        }
    }
    /// Delegates a guardrail-inspect event. Only the UDS transport supports
    /// these today; gRPC and reverse connections return `InvalidMessage`.
    pub async fn send_guardrail_inspect(
        &self,
        correlation_id: &str,
        event: &GuardrailInspectEvent,
    ) -> Result<AgentResponse, AgentProtocolError> {
        match self {
            V2Transport::Grpc(_client) => Err(AgentProtocolError::InvalidMessage(
                "GuardrailInspect events are not yet supported via gRPC".to_string(),
            )),
            V2Transport::Uds(client) => client.send_guardrail_inspect(correlation_id, event).await,
            V2Transport::Reverse(_client) => Err(AgentProtocolError::InvalidMessage(
                "GuardrailInspect events are not yet supported via reverse connections".to_string(),
            )),
        }
    }
    /// Cancels a single in-flight request identified by `correlation_id`.
    pub async fn cancel_request(
        &self,
        correlation_id: &str,
        reason: CancelReason,
    ) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Uds(client) => client.cancel_request(correlation_id, reason).await,
            V2Transport::Reverse(client) => client.cancel_request(correlation_id, reason).await,
        }
    }
    /// Cancels every in-flight request; returns how many were cancelled.
    pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.cancel_all(reason).await,
            V2Transport::Uds(client) => client.cancel_all(reason).await,
            V2Transport::Reverse(client) => client.cancel_all(reason).await,
        }
    }
    /// Closes the underlying connection.
    pub async fn close(&self) -> Result<(), AgentProtocolError> {
        match self {
            V2Transport::Grpc(client) => client.close().await,
            V2Transport::Uds(client) => client.close().await,
            V2Transport::Reverse(client) => client.close().await,
        }
    }
    /// Identifier of the agent this transport talks to.
    pub fn agent_id(&self) -> &str {
        match self {
            V2Transport::Grpc(client) => client.agent_id(),
            V2Transport::Uds(client) => client.agent_id(),
            V2Transport::Reverse(client) => client.agent_id(),
        }
    }
}
/// One pooled transport plus its bookkeeping: usage counters, error streaks,
/// a per-connection concurrency semaphore, and a cached health flag.
struct PooledConnection {
    client: V2Transport,
    // Creation time; base instant for the last-used ms offset below.
    created_at: Instant,
    // Milliseconds since `created_at` at the last use (atomic, lock-free).
    last_used_offset_ms: AtomicU64,
    // Requests currently being processed on this connection.
    in_flight: AtomicU64,
    // Lifetime totals.
    request_count: AtomicU64,
    error_count: AtomicU64,
    // Errors since the last success; >= 3 marks the connection unhealthy.
    consecutive_errors: AtomicU64,
    // Caps concurrent requests per connection
    // (`AgentPoolConfig::max_concurrent_per_connection` permits).
    concurrency_limiter: Semaphore,
    // Health flag written by `check_and_update_health` and by request paths
    // on repeated errors; read cheaply via `is_healthy_cached`.
    healthy_cached: AtomicBool,
}
impl PooledConnection {
    /// Wraps a transport with fresh counters; `max_concurrent` seeds the
    /// per-connection semaphore. New connections start out marked healthy.
    fn new(client: V2Transport, max_concurrent: usize) -> Self {
        Self {
            client,
            created_at: Instant::now(),
            last_used_offset_ms: AtomicU64::new(0),
            in_flight: AtomicU64::new(0),
            request_count: AtomicU64::new(0),
            error_count: AtomicU64::new(0),
            consecutive_errors: AtomicU64::new(0),
            concurrency_limiter: Semaphore::new(max_concurrent),
            healthy_cached: AtomicBool::new(true),
        }
    }
    /// Requests currently in flight on this connection.
    fn in_flight(&self) -> u64 {
        self.in_flight.load(Ordering::Relaxed)
    }
    /// Lifetime error ratio; 0.0 when no requests have completed yet.
    fn error_rate(&self) -> f64 {
        let requests = self.request_count.load(Ordering::Relaxed);
        let errors = self.error_count.load(Ordering::Relaxed);
        if requests == 0 {
            0.0
        } else {
            errors as f64 / requests as f64
        }
    }
    /// Cheap health read from the cached flag — pairs (Acquire) with the
    /// Release stores in `check_and_update_health` and the request paths.
    #[inline]
    fn is_healthy_cached(&self) -> bool {
        self.healthy_cached.load(Ordering::Acquire)
    }
    /// Recomputes health — connected, fewer than 3 consecutive errors, and
    /// accepting requests — then publishes it with Release ordering.
    async fn check_and_update_health(&self) -> bool {
        let connected = self.client.is_connected().await;
        let low_errors = self.consecutive_errors.load(Ordering::Relaxed) < 3;
        let can_accept = self.client.can_accept_requests().await;
        let healthy = connected && low_errors && can_accept;
        self.healthy_cached.store(healthy, Ordering::Release);
        healthy
    }
    /// Marks the connection as used "now" (ms offset from `created_at`).
    #[inline]
    fn touch(&self) {
        let offset = self.created_at.elapsed().as_millis() as u64;
        self.last_used_offset_ms.store(offset, Ordering::Relaxed);
    }
    /// Reconstructs the last-use `Instant` from the stored ms offset.
    fn last_used(&self) -> Instant {
        let offset_ms = self.last_used_offset_ms.load(Ordering::Relaxed);
        self.created_at + Duration::from_millis(offset_ms)
    }
}
/// Point-in-time statistics for one agent, as produced by [`AgentPool::stats`].
#[derive(Debug, Clone)]
pub struct AgentPoolStats {
    pub agent_id: String,
    /// Connections currently held for this agent.
    pub active_connections: usize,
    /// How many of those pass the cached health check.
    pub healthy_connections: usize,
    /// Sum of in-flight requests across the agent's connections.
    pub total_in_flight: u64,
    /// Lifetime request/error totals summed across connections.
    pub total_requests: u64,
    pub total_errors: u64,
    /// `total_errors / total_requests` (0.0 when no requests yet).
    pub error_rate: f64,
    /// Aggregate flag maintained by the maintenance loop.
    pub is_healthy: bool,
}
/// Per-agent state: its pooled connections plus health/reconnect bookkeeping.
struct AgentEntry {
    agent_id: String,
    // Dial target; reverse connections use the synthetic "reverse://<id>".
    endpoint: String,
    // Pooled connections, guarded by an async RwLock.
    connections: RwLock<Vec<Arc<PooledConnection>>>,
    // Capabilities reported by the agent, if any.
    capabilities: RwLock<Option<AgentCapabilities>>,
    // Cursor presumably used by round-robin selection (selection logic is
    // outside this chunk) — TODO confirm.
    round_robin_index: AtomicUsize,
    // Reconnect attempt counter; budget enforcement is outside this chunk.
    reconnect_attempts: AtomicUsize,
    // Wall-clock ms since UNIX epoch of the last reconnect attempt
    // (0 = never attempted).
    last_reconnect_attempt_ms: AtomicU64,
    // Aggregate health: true while at least one connection is healthy.
    healthy: AtomicBool,
}
impl AgentEntry {
    /// Fresh entry with no connections. Starts out marked healthy so the
    /// first maintenance sweep does not log a spurious "recovered".
    fn new(agent_id: String, endpoint: String) -> Self {
        Self {
            agent_id,
            endpoint,
            connections: RwLock::new(Vec::new()),
            capabilities: RwLock::new(None),
            round_robin_index: AtomicUsize::new(0),
            reconnect_attempts: AtomicUsize::new(0),
            last_reconnect_attempt_ms: AtomicU64::new(0),
            healthy: AtomicBool::new(true),
        }
    }
    /// True when more than `interval` has passed since the last reconnect
    /// attempt, or when none was ever made (0 is the "never" sentinel).
    /// Uses wall-clock time, so large clock jumps can skew the spacing.
    fn should_reconnect(&self, interval: Duration) -> bool {
        let last_ms = self.last_reconnect_attempt_ms.load(Ordering::Relaxed);
        if last_ms == 0 {
            return true;
        }
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        now_ms.saturating_sub(last_ms) > interval.as_millis() as u64
    }
    /// Stamps "now" (ms since UNIX epoch) as the last reconnect attempt.
    /// Falls back to 0 (= "never") if the system clock reads pre-epoch.
    fn mark_reconnect_attempt(&self) {
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);
        self.last_reconnect_attempt_ms
            .store(now_ms, Ordering::Relaxed);
    }
}
/// Connection pool multiplexing v2 agent transports (gRPC, UDS, reverse)
/// with per-agent health tracking, flow control, correlation affinity and
/// sticky sessions.
pub struct AgentPool {
    // Tuning knobs, immutable after construction.
    config: AgentPoolConfig,
    // agent_id -> per-agent connections and health state.
    agents: DashMap<String, Arc<AgentEntry>>,
    // Pool-wide lifetime counters.
    total_requests: AtomicU64,
    total_errors: AtomicU64,
    // Aggregates agent metric reports; `metrics_callback` feeds it.
    metrics_collector: Arc<MetricsCollector>,
    // Handed to transports during connection setup (outside this chunk) —
    // TODO confirm.
    metrics_callback: MetricsCallback,
    // Outbound config pushes to agents and their acknowledgements.
    config_pusher: Arc<ConfigPusher>,
    // Handles agent-initiated config update requests; fed by
    // `config_update_callback`.
    config_update_handler: Arc<ConfigUpdateHandler>,
    config_update_callback: ConfigUpdateCallback,
    // Low-level protocol counters (requests, responses, in-flight, errors).
    protocol_metrics: Arc<ProtocolMetrics>,
    // correlation_id -> connection that handled the request headers, so
    // subsequent body chunks follow the same connection.
    correlation_affinity: DashMap<String, Arc<PooledConnection>>,
    // session_id -> connection pinned for sticky routing.
    sticky_sessions: DashMap<String, StickySession>,
}
impl AgentPool {
/// Creates a pool with default configuration.
pub fn new() -> Self {
    let config = AgentPoolConfig::default();
    Self::with_config(config)
}
/// Builds a pool with the given configuration, wiring up:
/// - a `MetricsCallback` that records agent metric reports into the shared
///   [`MetricsCollector`], and
/// - a `ConfigUpdateCallback` that logs agent-initiated config update
///   requests and forwards them to the shared [`ConfigUpdateHandler`].
pub fn with_config(config: AgentPoolConfig) -> Self {
    let metrics_collector = Arc::new(MetricsCollector::new());
    let collector_clone = Arc::clone(&metrics_collector);
    let metrics_callback: MetricsCallback = Arc::new(move |report| {
        collector_clone.record(&report);
    });
    let config_pusher = Arc::new(ConfigPusher::new());
    let config_update_handler = Arc::new(ConfigUpdateHandler::new());
    let handler_clone = Arc::clone(&config_update_handler);
    let config_update_callback: ConfigUpdateCallback = Arc::new(move |agent_id, request| {
        debug!(
            agent_id = %agent_id,
            request_id = %request.request_id,
            "Processing config update request from agent"
        );
        handler_clone.handle(request)
    });
    Self {
        config,
        agents: DashMap::new(),
        total_requests: AtomicU64::new(0),
        total_errors: AtomicU64::new(0),
        metrics_collector,
        metrics_callback,
        config_pusher,
        config_update_handler,
        config_update_callback,
        protocol_metrics: Arc::new(ProtocolMetrics::new()),
        correlation_affinity: DashMap::new(),
        sticky_sessions: DashMap::new(),
    }
}
/// Borrows the pool's protocol-level metrics.
pub fn protocol_metrics(&self) -> &ProtocolMetrics {
    &self.protocol_metrics
}
/// Clones a shared handle to the protocol metrics (for background tasks).
pub fn protocol_metrics_arc(&self) -> Arc<ProtocolMetrics> {
    Arc::clone(&self.protocol_metrics)
}
/// Borrows the collector that aggregates agent metric reports.
pub fn metrics_collector(&self) -> &MetricsCollector {
    &self.metrics_collector
}
/// Clones a shared handle to the metrics collector.
pub fn metrics_collector_arc(&self) -> Arc<MetricsCollector> {
    Arc::clone(&self.metrics_collector)
}
/// Renders collected metrics in Prometheus text exposition format.
pub fn export_prometheus(&self) -> String {
    self.metrics_collector.export_prometheus()
}
/// Drops the connection affinity recorded for `correlation_id`. Callers
/// should invoke this when a request finishes, since affinity entries are
/// not evicted automatically.
pub fn clear_correlation_affinity(&self, correlation_id: &str) {
    self.correlation_affinity.remove(correlation_id);
}
/// Number of correlation-affinity entries currently tracked.
pub fn correlation_affinity_count(&self) -> usize {
    self.correlation_affinity.len()
}
/// Pins `session_id` to a connection of `agent_id`, chosen by
/// `select_connection` (defined outside this chunk). Replaces any existing
/// session with the same id.
///
/// # Errors
/// Propagates the error from `select_connection` when no connection can be
/// selected for `agent_id`.
pub fn create_sticky_session(
    &self,
    session_id: impl Into<String>,
    agent_id: &str,
) -> Result<(), AgentProtocolError> {
    let session_id = session_id.into();
    let conn = self.select_connection(agent_id)?;
    let session = StickySession::new(agent_id.to_string(), conn);
    // Start the idle clock at creation time.
    session.touch();
    self.sticky_sessions.insert(session_id.clone(), session);
    debug!(
        session_id = %session_id,
        agent_id = %agent_id,
        "Created sticky session"
    );
    Ok(())
}
/// Looks up the connection pinned to `session_id`, lazily evicting the
/// session if its idle timeout has elapsed. A hit counts as an access and
/// refreshes the idle clock.
fn get_sticky_session_conn(&self, session_id: &str) -> Option<Arc<PooledConnection>> {
    let entry = self.sticky_sessions.get(session_id)?;
    if let Some(timeout) = self.config.sticky_session_timeout {
        if entry.is_expired(timeout) {
            // Drop the DashMap read guard before removing, otherwise the
            // remove would deadlock on the same shard.
            drop(entry);
            self.sticky_sessions.remove(session_id);
            debug!(session_id = %session_id, "Sticky session expired");
            return None;
        }
    }
    entry.touch();
    Some(Arc::clone(&entry.connection))
}
/// Extends a session's idle TTL (the lookup itself `touch`es it).
/// Returns `false` if the session is missing or already expired.
pub fn refresh_sticky_session(&self, session_id: &str) -> bool {
    self.get_sticky_session_conn(session_id).is_some()
}
/// Returns `true` if a live (non-expired) sticky session exists for
/// `session_id`.
///
/// Unlike [`Self::refresh_sticky_session`], this is a pure query: it does
/// NOT touch the session, so polling it no longer extends the idle TTL.
/// (Previously it went through the refreshing lookup, meaning a monitoring
/// loop could keep sessions alive forever.) An expired-but-not-yet-evicted
/// entry simply reports `false`; eviction is left to the lookup path and
/// `cleanup_expired_sessions`.
pub fn has_sticky_session(&self, session_id: &str) -> bool {
    match self.sticky_sessions.get(session_id) {
        Some(session) => self
            .config
            .sticky_session_timeout
            .map_or(true, |timeout| !session.is_expired(timeout)),
        None => false,
    }
}
/// Removes the sticky session for `session_id`, if one exists.
pub fn clear_sticky_session(&self, session_id: &str) {
    let removed = self.sticky_sessions.remove(session_id);
    if removed.is_some() {
        debug!(session_id = %session_id, "Cleared sticky session");
    }
}
/// Number of sticky sessions currently tracked (may include entries that
/// have expired but not yet been lazily evicted).
pub fn sticky_session_count(&self) -> usize {
    self.sticky_sessions.len()
}
/// Agent id the given sticky session is pinned to, if the session exists.
pub fn sticky_session_agent(&self, session_id: &str) -> Option<String> {
    let session = self.sticky_sessions.get(session_id)?;
    Some(session.agent_id.clone())
}
/// Sends request headers, preferring the connection pinned to `session_id`
/// when a live sticky session exists, else falling back to normal selection
/// on `agent_id`. Returns the response plus a flag saying whether the
/// sticky connection was used.
///
/// NOTE(review): if the sticky session was created for a different agent
/// than `agent_id`, its connection is still used — confirm callers always
/// pass a matching agent id.
pub async fn send_request_headers_with_sticky_session(
    &self,
    session_id: &str,
    agent_id: &str,
    correlation_id: &str,
    event: &RequestHeadersEvent,
) -> Result<(AgentResponse, bool), AgentProtocolError> {
    let start = Instant::now();
    self.total_requests.fetch_add(1, Ordering::Relaxed);
    self.protocol_metrics.inc_requests();
    self.protocol_metrics.inc_in_flight();
    let (conn, used_sticky) =
        if let Some(sticky_conn) = self.get_sticky_session_conn(session_id) {
            (sticky_conn, true)
        } else {
            (self.select_connection(agent_id)?, false)
        };
    match self.check_flow_control(&conn, agent_id).await {
        Ok(true) => {}
        // Fail-open: skip the agent and report "allow".
        Ok(false) => {
            self.protocol_metrics.dec_in_flight();
            return Ok((AgentResponse::default_allow(), used_sticky));
        }
        Err(e) => {
            self.protocol_metrics.dec_in_flight();
            return Err(e);
        }
    }
    // Per-connection concurrency cap. `acquire` only errors if the
    // semaphore has been closed, which this pool never does explicitly;
    // the map_err side effects run only in that case.
    let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
        self.protocol_metrics.dec_in_flight();
        self.protocol_metrics.inc_connection_errors();
        AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
    })?;
    conn.in_flight.fetch_add(1, Ordering::Relaxed);
    conn.touch();
    // Remember which connection handled the headers so body chunks with the
    // same correlation id follow it.
    self.correlation_affinity
        .insert(correlation_id.to_string(), Arc::clone(&conn));
    let result = conn
        .client
        .send_request_headers(correlation_id, event)
        .await;
    conn.in_flight.fetch_sub(1, Ordering::Relaxed);
    conn.request_count.fetch_add(1, Ordering::Relaxed);
    self.protocol_metrics.dec_in_flight();
    self.protocol_metrics
        .record_request_duration(start.elapsed());
    match &result {
        Ok(_) => {
            // Success resets the consecutive-error streak.
            conn.consecutive_errors.store(0, Ordering::Relaxed);
            self.protocol_metrics.inc_responses();
        }
        Err(e) => {
            conn.error_count.fetch_add(1, Ordering::Relaxed);
            let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
            self.total_errors.fetch_add(1, Ordering::Relaxed);
            // Classify the failure for protocol metrics.
            match e {
                AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
                AgentProtocolError::ConnectionFailed(_)
                | AgentProtocolError::ConnectionClosed => {
                    self.protocol_metrics.inc_connection_errors();
                }
                AgentProtocolError::Serialization(_) => {
                    self.protocol_metrics.inc_serialization_errors();
                }
                _ => {}
            }
            // Three consecutive errors flips the cached health flag.
            if consecutive >= 3 {
                conn.healthy_cached.store(false, Ordering::Release);
            }
        }
    }
    result.map(|r| (r, used_sticky))
}
/// Evicts sticky sessions idle longer than the configured timeout and
/// returns how many were removed. No-op (returns 0) when expiry is
/// disabled (`sticky_session_timeout == None`).
pub fn cleanup_expired_sessions(&self) -> usize {
    let Some(timeout) = self.config.sticky_session_timeout else {
        return 0;
    };
    let mut removed = 0;
    self.sticky_sessions.retain(|session_id, session| {
        if session.is_expired(timeout) {
            debug!(session_id = %session_id, "Removing expired sticky session");
            removed += 1;
            false
        } else {
            true
        }
    });
    if removed > 0 {
        trace!(removed = removed, "Cleaned up expired sticky sessions");
    }
    removed
}
/// Borrows the pusher used to send config updates to agents.
pub fn config_pusher(&self) -> &ConfigPusher {
    &self.config_pusher
}
/// Borrows the handler for agent-initiated config update requests.
pub fn config_update_handler(&self) -> &ConfigUpdateHandler {
    &self.config_update_handler
}
/// Pushes a config update to one agent; returns the push id when the agent
/// is registered and supports config push (per `ConfigPusher` semantics).
pub fn push_config_to_agent(
    &self,
    agent_id: &str,
    update_type: ConfigUpdateType,
) -> Option<String> {
    self.config_pusher.push_to_agent(agent_id, update_type)
}
/// Pushes a config update to every registered agent; returns the push ids.
pub fn push_config_to_all(&self, update_type: ConfigUpdateType) -> Vec<String> {
    self.config_pusher.push_to_all(update_type)
}
/// Records an agent's accept/reject acknowledgement for a prior push.
pub fn acknowledge_config_push(&self, push_id: &str, accepted: bool, error: Option<String>) {
    self.config_pusher.acknowledge(push_id, accepted, error);
}
/// Dials `endpoint` and registers `agent_id` with up to
/// `connections_per_agent` connections. Individual connection failures are
/// logged and tolerated; only a total failure is an error. Capabilities are
/// read from the first established connection and, when present, the agent
/// is registered with the [`ConfigPusher`].
///
/// Fix: the final log line now reports how many connections were actually
/// established instead of always echoing the configured count (which lied
/// whenever some dials failed).
///
/// # Errors
/// `ConnectionFailed` when not a single connection could be established.
pub async fn add_agent(
    &self,
    agent_id: impl Into<String>,
    endpoint: impl Into<String>,
) -> Result<(), AgentProtocolError> {
    let agent_id = agent_id.into();
    let endpoint = endpoint.into();
    info!(agent_id = %agent_id, endpoint = %endpoint, "Adding agent to pool");
    let entry = Arc::new(AgentEntry::new(agent_id.clone(), endpoint.clone()));
    let mut connections = Vec::with_capacity(self.config.connections_per_agent);
    for i in 0..self.config.connections_per_agent {
        match self.create_connection(&agent_id, &endpoint).await {
            Ok(conn) => {
                connections.push(Arc::new(conn));
                debug!(
                    agent_id = %agent_id,
                    connection = i,
                    "Created connection"
                );
            }
            Err(e) => {
                // Best-effort: keep going with the connections that worked.
                warn!(
                    agent_id = %agent_id,
                    connection = i,
                    error = %e,
                    "Failed to create connection"
                );
            }
        }
    }
    if connections.is_empty() {
        return Err(AgentProtocolError::ConnectionFailed(format!(
            "Failed to create any connections to agent {}",
            agent_id
        )));
    }
    // Capabilities are assumed identical across connections; read the first.
    if let Some(conn) = connections.first() {
        if let Some(caps) = conn.client.capabilities().await {
            let supports_config_push = caps.features.config_push;
            let agent_name = caps.name.clone();
            self.config_pusher
                .register_agent(&agent_id, &agent_name, supports_config_push);
            debug!(
                agent_id = %agent_id,
                supports_config_push = supports_config_push,
                "Registered agent with ConfigPusher"
            );
            *entry.capabilities.write().await = Some(caps);
        }
    }
    // Capture the real count before the Vec is moved into the entry.
    let established = connections.len();
    *entry.connections.write().await = connections;
    self.agents.insert(agent_id.clone(), entry);
    info!(
        agent_id = %agent_id,
        connections = established,
        "Agent added to pool"
    );
    Ok(())
}
/// Unregisters `agent_id` from the pool and the ConfigPusher, then closes
/// each of its connections (close failures are ignored).
///
/// # Errors
/// `InvalidMessage` when the agent is not in the pool.
pub async fn remove_agent(&self, agent_id: &str) -> Result<(), AgentProtocolError> {
    info!(agent_id = %agent_id, "Removing agent from pool");
    self.config_pusher.unregister_agent(agent_id);
    let Some((_, entry)) = self.agents.remove(agent_id) else {
        return Err(AgentProtocolError::InvalidMessage(format!(
            "Agent {} not found",
            agent_id
        )));
    };
    // Best-effort close; the entry is already out of the map so no new
    // requests can select these connections.
    for conn in entry.connections.read().await.iter() {
        let _ = conn.client.close().await;
    }
    info!(agent_id = %agent_id, "Agent removed from pool");
    Ok(())
}
/// Registers an agent-initiated (reverse) connection. For a known agent the
/// connection is appended, capped at `connections_per_agent`; for a new
/// agent a fresh entry with the synthetic endpoint `reverse://<agent_id>`
/// is created and registered with the ConfigPusher.
///
/// NOTE(review): on rejection at the connection limit the wrapped client is
/// dropped without an explicit `close()` — confirm that dropping alone
/// tears the reverse connection down cleanly.
///
/// # Errors
/// `ConnectionFailed` when the agent is already at its connection limit.
pub async fn add_reverse_connection(
    &self,
    agent_id: &str,
    client: ReverseConnectionClient,
    capabilities: AgentCapabilities,
) -> Result<(), AgentProtocolError> {
    info!(
        agent_id = %agent_id,
        connection_id = %client.connection_id(),
        "Adding reverse connection to pool"
    );
    let transport = V2Transport::Reverse(client);
    let conn = Arc::new(PooledConnection::new(
        transport,
        self.config.max_concurrent_per_connection,
    ));
    if let Some(entry) = self.agents.get(agent_id) {
        let mut connections = entry.connections.write().await;
        if connections.len() >= self.config.connections_per_agent {
            warn!(
                agent_id = %agent_id,
                current = connections.len(),
                max = self.config.connections_per_agent,
                "Reverse connection rejected: at connection limit"
            );
            return Err(AgentProtocolError::ConnectionFailed(format!(
                "Agent {} already has maximum connections ({})",
                agent_id, self.config.connections_per_agent
            )));
        }
        connections.push(conn);
        info!(
            agent_id = %agent_id,
            total_connections = connections.len(),
            "Added reverse connection to existing agent"
        );
    } else {
        let entry = Arc::new(AgentEntry::new(
            agent_id.to_string(),
            format!("reverse://{}", agent_id),
        ));
        let supports_config_push = capabilities.features.config_push;
        let agent_name = capabilities.name.clone();
        self.config_pusher
            .register_agent(agent_id, &agent_name, supports_config_push);
        debug!(
            agent_id = %agent_id,
            supports_config_push = supports_config_push,
            "Registered reverse connection agent with ConfigPusher"
        );
        *entry.capabilities.write().await = Some(capabilities);
        *entry.connections.write().await = vec![conn];
        self.agents.insert(agent_id.to_string(), entry);
        info!(
            agent_id = %agent_id,
            "Created new agent entry for reverse connection"
        );
    }
    Ok(())
}
/// Applies the configured flow-control policy when `conn`'s agent reports
/// it cannot accept requests.
///
/// Returns:
/// - `Ok(true)`  — proceed and send to the agent;
/// - `Ok(false)` — skip the agent (fail-open); callers substitute
///   `AgentResponse::default_allow()`;
/// - `Err(FlowControlPaused)` — reject (fail-closed, or wait timed out).
async fn check_flow_control(
    &self,
    conn: &PooledConnection,
    agent_id: &str,
) -> Result<bool, AgentProtocolError> {
    if conn.client.can_accept_requests().await {
        return Ok(true);
    }
    match self.config.flow_control_mode {
        FlowControlMode::FailClosed => {
            self.protocol_metrics.record_flow_rejection();
            Err(AgentProtocolError::FlowControlPaused {
                agent_id: agent_id.to_string(),
            })
        }
        FlowControlMode::FailOpen => {
            debug!(agent_id = %agent_id, "Flow control: agent paused, allowing request (fail-open mode)");
            // Still counted as a rejection: the agent did not see the request.
            self.protocol_metrics.record_flow_rejection();
            Ok(false)
        }
        FlowControlMode::WaitAndRetry => {
            // Poll every 10ms until the agent resumes or the wait budget
            // (`flow_control_wait_timeout`) is exhausted.
            let deadline = Instant::now() + self.config.flow_control_wait_timeout;
            while Instant::now() < deadline {
                tokio::time::sleep(Duration::from_millis(10)).await;
                if conn.client.can_accept_requests().await {
                    trace!(agent_id = %agent_id, "Flow control: agent resumed after wait");
                    return Ok(true);
                }
            }
            self.protocol_metrics.record_flow_rejection();
            Err(AgentProtocolError::FlowControlPaused {
                agent_id: agent_id.to_string(),
            })
        }
    }
}
/// Sends request headers to `agent_id` on a freshly selected connection,
/// recording full protocol metrics and establishing correlation affinity so
/// later body chunks reuse the same connection.
pub async fn send_request_headers(
    &self,
    agent_id: &str,
    correlation_id: &str,
    event: &RequestHeadersEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    let start = Instant::now();
    self.total_requests.fetch_add(1, Ordering::Relaxed);
    self.protocol_metrics.inc_requests();
    self.protocol_metrics.inc_in_flight();
    let conn = self.select_connection(agent_id)?;
    match self.check_flow_control(&conn, agent_id).await {
        Ok(true) => {}
        // Fail-open: skip the agent and report "allow".
        Ok(false) => {
            self.protocol_metrics.dec_in_flight();
            return Ok(AgentResponse::default_allow());
        }
        Err(e) => {
            self.protocol_metrics.dec_in_flight();
            return Err(e);
        }
    }
    // Per-connection concurrency cap; `acquire` only errors if the
    // semaphore has been closed, which this pool never does explicitly.
    let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
        self.protocol_metrics.dec_in_flight();
        self.protocol_metrics.inc_connection_errors();
        AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
    })?;
    conn.in_flight.fetch_add(1, Ordering::Relaxed);
    conn.touch();
    // Pin body chunks with the same correlation id to this connection.
    self.correlation_affinity
        .insert(correlation_id.to_string(), Arc::clone(&conn));
    let result = conn
        .client
        .send_request_headers(correlation_id, event)
        .await;
    conn.in_flight.fetch_sub(1, Ordering::Relaxed);
    conn.request_count.fetch_add(1, Ordering::Relaxed);
    self.protocol_metrics.dec_in_flight();
    self.protocol_metrics
        .record_request_duration(start.elapsed());
    match &result {
        Ok(_) => {
            // Success resets the consecutive-error streak.
            conn.consecutive_errors.store(0, Ordering::Relaxed);
            self.protocol_metrics.inc_responses();
        }
        Err(e) => {
            conn.error_count.fetch_add(1, Ordering::Relaxed);
            let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
            self.total_errors.fetch_add(1, Ordering::Relaxed);
            // Classify the failure for protocol metrics.
            match e {
                AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
                AgentProtocolError::ConnectionFailed(_)
                | AgentProtocolError::ConnectionClosed => {
                    self.protocol_metrics.inc_connection_errors();
                }
                AgentProtocolError::Serialization(_) => {
                    self.protocol_metrics.inc_serialization_errors();
                }
                _ => {}
            }
            // Three consecutive errors flips the cached health flag.
            if consecutive >= 3 {
                conn.healthy_cached.store(false, Ordering::Release);
                trace!(agent_id = %agent_id, error = %e, "Connection marked unhealthy after consecutive errors");
            }
        }
    }
    result
}
/// Sends a request body chunk, preferring the connection that handled the
/// request headers (correlation affinity) so chunked bodies stay on one
/// stream; falls back to fresh selection when no affinity exists.
///
/// Unlike the headers path, this updates only the pool-level counters and
/// per-connection stats — `protocol_metrics` is not touched here.
pub async fn send_request_body_chunk(
    &self,
    agent_id: &str,
    correlation_id: &str,
    event: &RequestBodyChunkEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    self.total_requests.fetch_add(1, Ordering::Relaxed);
    let conn = if let Some(affinity_conn) = self.correlation_affinity.get(correlation_id) {
        Arc::clone(&affinity_conn)
    } else {
        trace!(correlation_id = %correlation_id, "No affinity found for body chunk, using selection");
        self.select_connection(agent_id)?
    };
    match self.check_flow_control(&conn, agent_id).await {
        Ok(true) => {}
        // Fail-open: skip the agent and report "allow".
        Ok(false) => {
            return Ok(AgentResponse::default_allow());
        }
        Err(e) => return Err(e),
    }
    let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
        AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
    })?;
    conn.in_flight.fetch_add(1, Ordering::Relaxed);
    conn.touch();
    let result = conn
        .client
        .send_request_body_chunk(correlation_id, event)
        .await;
    conn.in_flight.fetch_sub(1, Ordering::Relaxed);
    conn.request_count.fetch_add(1, Ordering::Relaxed);
    match &result {
        Ok(_) => {
            conn.consecutive_errors.store(0, Ordering::Relaxed);
        }
        Err(_) => {
            conn.error_count.fetch_add(1, Ordering::Relaxed);
            let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
            self.total_errors.fetch_add(1, Ordering::Relaxed);
            if consecutive >= 3 {
                conn.healthy_cached.store(false, Ordering::Release);
            }
        }
    }
    result
}
/// Sends response headers on a freshly selected connection for `agent_id`.
///
/// NOTE(review): unlike the request and response-body paths, this method
/// performs no `check_flow_control`, does not consult correlation affinity,
/// and does not update `protocol_metrics` — confirm whether this asymmetry
/// is intentional.
pub async fn send_response_headers(
    &self,
    agent_id: &str,
    correlation_id: &str,
    event: &ResponseHeadersEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    self.total_requests.fetch_add(1, Ordering::Relaxed);
    let conn = self.select_connection(agent_id)?;
    let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
        AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
    })?;
    conn.in_flight.fetch_add(1, Ordering::Relaxed);
    conn.touch();
    let result = conn
        .client
        .send_response_headers(correlation_id, event)
        .await;
    conn.in_flight.fetch_sub(1, Ordering::Relaxed);
    conn.request_count.fetch_add(1, Ordering::Relaxed);
    match &result {
        Ok(_) => {
            conn.consecutive_errors.store(0, Ordering::Relaxed);
        }
        Err(_) => {
            conn.error_count.fetch_add(1, Ordering::Relaxed);
            let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
            self.total_errors.fetch_add(1, Ordering::Relaxed);
            if consecutive >= 3 {
                conn.healthy_cached.store(false, Ordering::Release);
            }
        }
    }
    result
}
/// Sends a response body chunk on a freshly selected connection.
///
/// NOTE(review): this path does not consult `correlation_affinity`, unlike
/// the request-body path — confirm whether response chunks are allowed to
/// land on a different connection than the response headers.
pub async fn send_response_body_chunk(
    &self,
    agent_id: &str,
    correlation_id: &str,
    event: &ResponseBodyChunkEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    self.total_requests.fetch_add(1, Ordering::Relaxed);
    let conn = self.select_connection(agent_id)?;
    match self.check_flow_control(&conn, agent_id).await {
        Ok(true) => {}
        // Fail-open: skip the agent and report "allow".
        Ok(false) => {
            return Ok(AgentResponse::default_allow());
        }
        Err(e) => return Err(e),
    }
    let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
        AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
    })?;
    conn.in_flight.fetch_add(1, Ordering::Relaxed);
    conn.touch();
    let result = conn
        .client
        .send_response_body_chunk(correlation_id, event)
        .await;
    conn.in_flight.fetch_sub(1, Ordering::Relaxed);
    conn.request_count.fetch_add(1, Ordering::Relaxed);
    match &result {
        Ok(_) => {
            conn.consecutive_errors.store(0, Ordering::Relaxed);
        }
        Err(_) => {
            conn.error_count.fetch_add(1, Ordering::Relaxed);
            let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
            self.total_errors.fetch_add(1, Ordering::Relaxed);
            if consecutive >= 3 {
                conn.healthy_cached.store(false, Ordering::Release);
            }
        }
    }
    result
}
/// Sends a guardrail-inspect event (UDS transports only — other transports
/// return `InvalidMessage` from `V2Transport::send_guardrail_inspect`).
/// Records full protocol metrics except request duration, which — unlike
/// the headers paths — is not measured here.
pub async fn send_guardrail_inspect(
    &self,
    agent_id: &str,
    correlation_id: &str,
    event: &GuardrailInspectEvent,
) -> Result<AgentResponse, AgentProtocolError> {
    self.total_requests.fetch_add(1, Ordering::Relaxed);
    self.protocol_metrics.inc_requests();
    self.protocol_metrics.inc_in_flight();
    let conn = self.select_connection(agent_id)?;
    match self.check_flow_control(&conn, agent_id).await {
        Ok(true) => {}
        // Fail-open: skip the agent and report "allow".
        Ok(false) => {
            self.protocol_metrics.dec_in_flight();
            return Ok(AgentResponse::default_allow());
        }
        Err(e) => {
            self.protocol_metrics.dec_in_flight();
            return Err(e);
        }
    }
    let _permit = conn.concurrency_limiter.acquire().await.map_err(|_| {
        self.protocol_metrics.dec_in_flight();
        self.protocol_metrics.inc_connection_errors();
        AgentProtocolError::ConnectionFailed("Concurrency limit reached".to_string())
    })?;
    conn.in_flight.fetch_add(1, Ordering::Relaxed);
    conn.touch();
    let result = conn
        .client
        .send_guardrail_inspect(correlation_id, event)
        .await;
    conn.in_flight.fetch_sub(1, Ordering::Relaxed);
    conn.request_count.fetch_add(1, Ordering::Relaxed);
    self.protocol_metrics.dec_in_flight();
    match &result {
        Ok(_) => {
            conn.consecutive_errors.store(0, Ordering::Relaxed);
            self.protocol_metrics.inc_responses();
        }
        Err(e) => {
            conn.error_count.fetch_add(1, Ordering::Relaxed);
            let consecutive = conn.consecutive_errors.fetch_add(1, Ordering::Relaxed) + 1;
            self.total_errors.fetch_add(1, Ordering::Relaxed);
            // Classify the failure for protocol metrics.
            match e {
                AgentProtocolError::Timeout(_) => self.protocol_metrics.inc_timeouts(),
                AgentProtocolError::ConnectionFailed(_)
                | AgentProtocolError::ConnectionClosed => {
                    self.protocol_metrics.inc_connection_errors();
                }
                AgentProtocolError::Serialization(_) => {
                    self.protocol_metrics.inc_serialization_errors();
                }
                _ => {}
            }
            // Three consecutive errors flips the cached health flag.
            if consecutive >= 3 {
                conn.healthy_cached.store(false, Ordering::Release);
            }
        }
    }
    result
}
pub async fn cancel_request(
&self,
agent_id: &str,
correlation_id: &str,
reason: CancelReason,
) -> Result<(), AgentProtocolError> {
let entry = self.agents.get(agent_id).ok_or_else(|| {
AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
})?;
let connections = entry.connections.read().await;
for conn in connections.iter() {
let _ = conn.client.cancel_request(correlation_id, reason).await;
}
Ok(())
}
pub async fn stats(&self) -> Vec<AgentPoolStats> {
let mut stats = Vec::with_capacity(self.agents.len());
for entry_ref in self.agents.iter() {
let agent_id = entry_ref.key().clone();
let entry = entry_ref.value();
let connections = entry.connections.read().await;
let mut healthy_count = 0;
let mut total_in_flight = 0;
let mut total_requests = 0;
let mut total_errors = 0;
for conn in connections.iter() {
if conn.is_healthy_cached() {
healthy_count += 1;
}
total_in_flight += conn.in_flight();
total_requests += conn.request_count.load(Ordering::Relaxed);
total_errors += conn.error_count.load(Ordering::Relaxed);
}
let error_rate = if total_requests == 0 {
0.0
} else {
total_errors as f64 / total_requests as f64
};
stats.push(AgentPoolStats {
agent_id,
active_connections: connections.len(),
healthy_connections: healthy_count,
total_in_flight,
total_requests,
total_errors,
error_rate,
is_healthy: entry.healthy.load(Ordering::Acquire),
});
}
stats
}
/// Statistics for a single agent, or `None` if it is not in the pool.
pub async fn agent_stats(&self, agent_id: &str) -> Option<AgentPoolStats> {
    for stat in self.stats().await {
        if stat.agent_id == agent_id {
            return Some(stat);
        }
    }
    None
}
/// Capabilities reported by `agent_id`, if known.
pub async fn agent_capabilities(&self, agent_id: &str) -> Option<AgentCapabilities> {
    // Clone the Arc so the DashMap guard is dropped before awaiting the
    // capabilities lock.
    let entry = Arc::clone(self.agents.get(agent_id)?.value());
    entry.capabilities.read().await.clone()
}
/// True when the agent exists and its maintenance-maintained health flag is
/// set; unknown agents report unhealthy.
pub fn is_agent_healthy(&self, agent_id: &str) -> bool {
    matches!(
        self.agents.get(agent_id),
        Some(entry) if entry.healthy.load(Ordering::Acquire)
    )
}
/// Ids of all agents currently registered in the pool.
pub fn agent_ids(&self) -> Vec<String> {
    let mut ids = Vec::with_capacity(self.agents.len());
    for entry in self.agents.iter() {
        ids.push(entry.key().clone());
    }
    ids
}
/// Gracefully drains and closes every agent:
/// 1. removes each agent from the map so no new requests select it,
/// 2. broadcasts `CancelReason::ProxyShutdown` to its connections,
/// 3. polls every 100ms until in-flight reaches zero or `drain_timeout`
///    elapses (then force-closes with a warning),
/// 4. closes all connections (close failures ignored).
pub async fn shutdown(&self) -> Result<(), AgentProtocolError> {
    info!("Shutting down agent pool");
    // Collect ids first so we never hold a DashMap iteration guard while
    // removing or awaiting.
    let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();
    for agent_id in agent_ids {
        if let Some((_, entry)) = self.agents.remove(&agent_id) {
            debug!(agent_id = %agent_id, "Draining agent connections");
            let connections = entry.connections.read().await;
            for conn in connections.iter() {
                let _ = conn.client.cancel_all(CancelReason::ProxyShutdown).await;
            }
            let drain_deadline = Instant::now() + self.config.drain_timeout;
            loop {
                let total_in_flight: u64 = connections.iter().map(|c| c.in_flight()).sum();
                if total_in_flight == 0 {
                    break;
                }
                if Instant::now() > drain_deadline {
                    warn!(
                        agent_id = %agent_id,
                        in_flight = total_in_flight,
                        "Drain timeout, forcing close"
                    );
                    break;
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            for conn in connections.iter() {
                let _ = conn.client.close().await;
            }
        }
    }
    info!("Agent pool shutdown complete");
    Ok(())
}
/// Long-running maintenance loop; never returns.
///
/// Once per `health_check_interval` tick this:
/// 1. evicts expired sticky sessions,
/// 2. re-checks every connection's health, updates the cached
///    per-agent health flag, and logs health transitions, and
/// 3. attempts a reconnect when an agent has fewer healthy connections
///    than `connections_per_agent` and `should_reconnect` permits it.
pub async fn run_maintenance(&self) {
    let mut interval = tokio::time::interval(self.config.health_check_interval);
    loop {
        interval.tick().await;
        self.cleanup_expired_sessions();
        // Snapshot the ids so the DashMap is not iterated (shard
        // locked) across the awaits below.
        let agent_ids: Vec<String> = self.agents.iter().map(|e| e.key().clone()).collect();
        for agent_id in agent_ids {
            // Agent may have been removed since the snapshot.
            let Some(entry_ref) = self.agents.get(&agent_id) else {
                continue;
            };
            // Clone the Arc and drop the DashMap shard guard before
            // any `.await`, to avoid blocking writers on that shard.
            let entry = entry_ref.value().clone();
            drop(entry_ref);
            let connections = entry.connections.read().await;
            let mut healthy_count = 0;
            for conn in connections.iter() {
                if conn.check_and_update_health().await {
                    healthy_count += 1;
                }
            }
            // The agent counts as healthy while at least one
            // connection is healthy; log only on transitions.
            let was_healthy = entry.healthy.load(Ordering::Acquire);
            let is_healthy = healthy_count > 0;
            entry.healthy.store(is_healthy, Ordering::Release);
            if was_healthy && !is_healthy {
                warn!(agent_id = %agent_id, "Agent marked unhealthy");
            } else if !was_healthy && is_healthy {
                info!(agent_id = %agent_id, "Agent recovered");
            }
            if healthy_count < self.config.connections_per_agent
                && entry.should_reconnect(self.config.reconnect_interval)
            {
                // Release the read guard first: reconnect_agent takes
                // the write lock on the same connection list.
                drop(connections);
                if let Err(e) = self.reconnect_agent(&agent_id, &entry).await {
                    trace!(agent_id = %agent_id, error = %e, "Reconnect failed");
                }
            }
        }
    }
}
/// Dials one new connection to `endpoint` and wraps it in a
/// [`PooledConnection`] with the configured concurrency cap.
///
/// Endpoints classified by [`is_uds_endpoint`] use the Unix-socket
/// transport; everything else goes over gRPC. Both paths install the
/// pool's metrics and config-update callbacks before connecting.
async fn create_connection(
    &self,
    agent_id: &str,
    endpoint: &str,
) -> Result<PooledConnection, AgentProtocolError> {
    let transport = if is_uds_endpoint(endpoint) {
        // Accept both "unix:/path" and bare "/path" forms.
        let path = endpoint.strip_prefix("unix:").unwrap_or(endpoint);
        let mut uds = AgentClientV2Uds::new(agent_id, path, self.config.request_timeout).await?;
        uds.set_metrics_callback(Arc::clone(&self.metrics_callback));
        uds.set_config_update_callback(Arc::clone(&self.config_update_callback));
        uds.connect().await?;
        V2Transport::Uds(uds)
    } else {
        let mut grpc = AgentClientV2::new(agent_id, endpoint, self.config.request_timeout).await?;
        grpc.set_metrics_callback(Arc::clone(&self.metrics_callback));
        grpc.set_config_update_callback(Arc::clone(&self.config_update_callback));
        grpc.connect().await?;
        V2Transport::Grpc(grpc)
    };
    Ok(PooledConnection::new(
        transport,
        self.config.max_concurrent_per_connection,
    ))
}
/// Picks a healthy connection for `agent_id` according to the
/// configured load-balancing strategy.
///
/// This function is synchronous, so it must never park the thread on
/// the async `RwLock`. The previous implementation fell back to
/// `futures::executor::block_on(connections.read())` under contention,
/// which can deadlock: on a current-thread runtime the writer task that
/// holds the lock can never be polled while this thread is blocked, and
/// the block happened while still holding the DashMap shard guard.
/// Writers (reconnects) only hold the lock long enough to push one
/// connection, so a short bounded `try_read` retry is sufficient; if
/// the lock stays contended we fail with `ConnectionFailed` instead of
/// risking a deadlock.
///
/// # Errors
/// - the agent is not registered
/// - the agent has no connections / no healthy connections
/// - the connection list lock remained contended after bounded retries
fn select_connection(
    &self,
    agent_id: &str,
) -> Result<Arc<PooledConnection>, AgentProtocolError> {
    let entry = self.agents.get(agent_id).ok_or_else(|| {
        AgentProtocolError::InvalidMessage(format!("Agent {} not found", agent_id))
    })?;
    // Bounded, non-blocking acquisition of the connection list.
    let mut guard = None;
    for _ in 0..64 {
        match entry.connections.try_read() {
            Ok(g) => {
                guard = Some(g);
                break;
            }
            // Give the writer a chance to finish without parking on
            // the async lock from sync code.
            Err(_) => std::thread::yield_now(),
        }
    }
    let Some(connections) = guard else {
        trace!(agent_id = %agent_id, "select_connection: connections lock contended");
        return Err(AgentProtocolError::ConnectionFailed(format!(
            "Connection list busy for agent {}",
            agent_id
        )));
    };
    if connections.is_empty() {
        return Err(AgentProtocolError::ConnectionFailed(format!(
            "No connections available for agent {}",
            agent_id
        )));
    }
    let healthy: Vec<_> = connections
        .iter()
        .filter(|c| c.is_healthy_cached())
        .cloned()
        .collect();
    if healthy.is_empty() {
        return Err(AgentProtocolError::ConnectionFailed(format!(
            "No healthy connections for agent {}",
            agent_id
        )));
    }
    // `healthy` is non-empty from here on, so the unwraps below on
    // min_by/min_by_key cannot fail.
    let selected = match self.config.load_balance_strategy {
        LoadBalanceStrategy::RoundRobin => {
            let idx = entry.round_robin_index.fetch_add(1, Ordering::Relaxed);
            healthy[idx % healthy.len()].clone()
        }
        LoadBalanceStrategy::LeastConnections => healthy
            .iter()
            .min_by_key(|c| c.in_flight())
            .cloned()
            .unwrap(),
        LoadBalanceStrategy::HealthBased => healthy
            .iter()
            .min_by(|a, b| {
                a.error_rate()
                    .partial_cmp(&b.error_rate())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .cloned()
            .unwrap(),
        LoadBalanceStrategy::Random => {
            // Cheap pseudo-random pick without a `rand` dependency: a
            // freshly seeded RandomState hasher yields a random-ish u64.
            use std::collections::hash_map::RandomState;
            use std::hash::{BuildHasher, Hasher};
            let idx = RandomState::new().build_hasher().finish() as usize % healthy.len();
            healthy[idx].clone()
        }
    };
    Ok(selected)
}
/// Tries to add one new connection for an agent that is below its
/// target connection count.
///
/// Returns `Ok(())` without dialing once `max_reconnect_attempts`
/// failures have accumulated; the attempt counter is reset to zero on
/// any successful reconnect.
async fn reconnect_agent(
    &self,
    agent_id: &str,
    entry: &AgentEntry,
) -> Result<(), AgentProtocolError> {
    // Record the attempt timestamp and bump the counter before the
    // threshold check, mirroring the bookkeeping order callers expect.
    entry.mark_reconnect_attempt();
    let prior = entry.reconnect_attempts.fetch_add(1, Ordering::Relaxed);
    if prior >= self.config.max_reconnect_attempts {
        debug!(
            agent_id = %agent_id,
            attempts = prior,
            "Max reconnect attempts reached"
        );
        return Ok(());
    }
    debug!(agent_id = %agent_id, attempt = prior + 1, "Attempting reconnect");
    let conn = match self.create_connection(agent_id, &entry.endpoint).await {
        Ok(conn) => conn,
        Err(e) => {
            debug!(agent_id = %agent_id, error = %e, "Reconnect failed");
            return Err(e);
        }
    };
    let mut connections = entry.connections.write().await;
    connections.push(Arc::new(conn));
    entry.reconnect_attempts.store(0, Ordering::Relaxed);
    info!(agent_id = %agent_id, "Reconnected successfully");
    Ok(())
}
}
impl Default for AgentPool {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for AgentPool {
    /// Debug output shows the pool config plus the two global
    /// counters; per-agent state is omitted (it lives behind locks).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let requests = self.total_requests.load(Ordering::Relaxed);
        let errors = self.total_errors.load(Ordering::Relaxed);
        f.debug_struct("AgentPool")
            .field("config", &self.config)
            .field("total_requests", &requests)
            .field("total_errors", &errors)
            .finish()
    }
}
/// Heuristically classifies an endpoint string as a Unix domain socket.
///
/// Matches an explicit `unix:` scheme, an absolute path, or a `.sock`
/// suffix; anything else (host:port pairs, URLs) is treated as a
/// network endpoint.
fn is_uds_endpoint(endpoint: &str) -> bool {
    if endpoint.starts_with("unix:") {
        return true;
    }
    endpoint.starts_with('/') || endpoint.ends_with(".sock")
}
// Unit tests for pool configuration defaults, construction, and the
// stateless helpers. Connection-level behavior needs a live agent and
// is not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    // --- configuration defaults ---

    #[test]
    fn test_pool_config_default() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.connections_per_agent, 4);
        assert_eq!(
            config.load_balance_strategy,
            LoadBalanceStrategy::RoundRobin
        );
    }

    #[test]
    fn test_load_balance_strategy() {
        assert_eq!(
            LoadBalanceStrategy::default(),
            LoadBalanceStrategy::RoundRobin
        );
    }

    // --- pool construction ---

    #[test]
    fn test_pool_creation() {
        // A fresh pool starts with zeroed global counters.
        let pool = AgentPool::new();
        assert_eq!(pool.total_requests.load(Ordering::Relaxed), 0);
        assert_eq!(pool.total_errors.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_pool_with_config() {
        let config = AgentPoolConfig {
            connections_per_agent: 8,
            load_balance_strategy: LoadBalanceStrategy::LeastConnections,
            ..Default::default()
        };
        let pool = AgentPool::with_config(config.clone());
        assert_eq!(pool.config.connections_per_agent, 8);
    }

    // --- queries against an empty pool ---

    #[test]
    fn test_agent_ids_empty() {
        let pool = AgentPool::new();
        assert!(pool.agent_ids().is_empty());
    }

    #[test]
    fn test_is_agent_healthy_not_found() {
        // Unknown agents must report unhealthy, not panic.
        let pool = AgentPool::new();
        assert!(!pool.is_agent_healthy("nonexistent"));
    }

    #[tokio::test]
    async fn test_stats_empty() {
        let pool = AgentPool::new();
        assert!(pool.stats().await.is_empty());
    }

    // --- endpoint classification ---

    #[test]
    fn test_is_uds_endpoint() {
        assert!(is_uds_endpoint("unix:/var/run/agent.sock"));
        assert!(is_uds_endpoint("unix:agent.sock"));
        assert!(is_uds_endpoint("/var/run/agent.sock"));
        assert!(is_uds_endpoint("/tmp/test.sock"));
        // Relative paths count as UDS when they carry a .sock suffix.
        assert!(is_uds_endpoint("agent.sock"));
        assert!(!is_uds_endpoint("http://localhost:8080"));
        assert!(!is_uds_endpoint("localhost:50051"));
        assert!(!is_uds_endpoint("127.0.0.1:8080"));
    }

    // --- flow-control configuration ---

    #[test]
    fn test_flow_control_mode_default() {
        assert_eq!(FlowControlMode::default(), FlowControlMode::FailClosed);
    }

    #[test]
    fn test_pool_config_flow_control_defaults() {
        let config = AgentPoolConfig::default();
        assert_eq!(config.channel_buffer_size, CHANNEL_BUFFER_SIZE);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailClosed);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(100));
    }

    #[test]
    fn test_pool_config_custom_flow_control() {
        let config = AgentPoolConfig {
            channel_buffer_size: 128,
            flow_control_mode: FlowControlMode::FailOpen,
            flow_control_wait_timeout: Duration::from_millis(500),
            ..Default::default()
        };
        assert_eq!(config.channel_buffer_size, 128);
        assert_eq!(config.flow_control_mode, FlowControlMode::FailOpen);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(500));
    }

    #[test]
    fn test_pool_config_wait_and_retry() {
        let config = AgentPoolConfig {
            flow_control_mode: FlowControlMode::WaitAndRetry,
            flow_control_wait_timeout: Duration::from_millis(250),
            ..Default::default()
        };
        assert_eq!(config.flow_control_mode, FlowControlMode::WaitAndRetry);
        assert_eq!(config.flow_control_wait_timeout, Duration::from_millis(250));
    }

    // --- sticky-session configuration & lookups ---

    #[test]
    fn test_pool_config_sticky_session_default() {
        // Default sticky-session TTL is five minutes.
        let config = AgentPoolConfig::default();
        assert_eq!(
            config.sticky_session_timeout,
            Some(Duration::from_secs(5 * 60))
        );
    }

    #[test]
    fn test_pool_config_sticky_session_custom() {
        let config = AgentPoolConfig {
            sticky_session_timeout: Some(Duration::from_secs(60)),
            ..Default::default()
        };
        assert_eq!(config.sticky_session_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_pool_config_sticky_session_disabled() {
        // None disables sticky sessions entirely.
        let config = AgentPoolConfig {
            sticky_session_timeout: None,
            ..Default::default()
        };
        assert!(config.sticky_session_timeout.is_none());
    }

    #[test]
    fn test_sticky_session_count_empty() {
        let pool = AgentPool::new();
        assert_eq!(pool.sticky_session_count(), 0);
    }

    #[test]
    fn test_sticky_session_has_nonexistent() {
        let pool = AgentPool::new();
        assert!(!pool.has_sticky_session("nonexistent"));
    }

    #[test]
    fn test_sticky_session_clear_nonexistent() {
        // Clearing an unknown session must be a no-op, not a panic.
        let pool = AgentPool::new();
        pool.clear_sticky_session("nonexistent");
    }

    #[test]
    fn test_cleanup_expired_sessions_empty() {
        let pool = AgentPool::new();
        let removed = pool.cleanup_expired_sessions();
        assert_eq!(removed, 0);
    }
}