use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, error, info, trace, warn};
use zentinel_agent_protocol::v2::{
AgentCapabilities, AgentPool, AgentPoolConfig as ProtocolPoolConfig, AgentPoolStats,
CancelReason, ConfigPusher, ConfigUpdateType, LoadBalanceStrategy as ProtocolLBStrategy,
MetricsCollector,
};
use zentinel_agent_protocol::{
AgentResponse, EventType, GuardrailInspectEvent, RequestBodyChunkEvent, RequestHeadersEvent,
ResponseBodyChunkEvent, ResponseHeadersEvent,
};
use zentinel_common::{
errors::{ZentinelError, ZentinelResult},
CircuitBreaker,
};
use zentinel_config::{AgentConfig, AgentEvent, FailureMode, LoadBalanceStrategy};
use super::metrics::AgentMetrics;
/// Sentinel stored in `AgentV2::last_success_ns` while no successful call has
/// been recorded; `time_since_last_success` returns `None` when it reads this value.
const NO_TIMESTAMP: u64 = 0;
/// A single configured agent reached through the v2 protocol connection pool.
///
/// Bundles the agent's configuration with its pool, circuit breaker and call
/// metrics, and tracks the time of the last recorded success and the current
/// run of consecutive failures using atomics.
pub struct AgentV2 {
/// Agent configuration (id, transport, timeout, subscribed events, failure mode, ...).
config: AgentConfig,
/// Connection pool through which every call to the agent is sent.
pool: Arc<AgentPool>,
/// Circuit breaker notified by `record_success`, `record_failure` and `record_timeout`.
circuit_breaker: Arc<CircuitBreaker>,
/// Call counters and cumulative call duration for this agent.
metrics: Arc<AgentMetrics>,
/// Reference point captured at construction; `last_success_ns` is measured from it.
base_instant: Instant,
/// Nanoseconds after `base_instant` of the most recent recorded success,
/// or `NO_TIMESTAMP` if none has been recorded.
last_success_ns: AtomicU64,
/// Failures and timeouts recorded since the last recorded success.
consecutive_failures: AtomicU32,
}
impl AgentV2 {
pub fn new(config: AgentConfig, circuit_breaker: Arc<CircuitBreaker>) -> Self {
trace!(
agent_id = %config.id,
agent_type = ?config.agent_type,
timeout_ms = config.timeout_ms,
events = ?config.events,
"Creating v2 agent instance"
);
let pool_config = config
.pool
.as_ref()
.map(|p| ProtocolPoolConfig {
connections_per_agent: p.connections_per_agent,
load_balance_strategy: convert_lb_strategy(p.load_balance_strategy),
connect_timeout: Duration::from_millis(p.connect_timeout_ms),
request_timeout: Duration::from_millis(config.timeout_ms),
reconnect_interval: Duration::from_millis(p.reconnect_interval_ms),
max_reconnect_attempts: p.max_reconnect_attempts,
drain_timeout: Duration::from_millis(p.drain_timeout_ms),
max_concurrent_per_connection: p.max_concurrent_per_connection,
health_check_interval: Duration::from_millis(p.health_check_interval_ms),
..Default::default()
})
.unwrap_or_default();
let pool = Arc::new(AgentPool::with_config(pool_config));
Self {
config,
pool,
circuit_breaker,
metrics: Arc::new(AgentMetrics::default()),
base_instant: Instant::now(),
last_success_ns: AtomicU64::new(NO_TIMESTAMP),
consecutive_failures: AtomicU32::new(0),
}
}
/// Returns the agent identifier taken from the configuration.
pub fn id(&self) -> &str {
&self.config.id
}
/// Returns a reference to the circuit breaker associated with this agent.
pub fn circuit_breaker(&self) -> &CircuitBreaker {
&self.circuit_breaker
}
/// Returns the failure mode set in the agent configuration.
pub fn failure_mode(&self) -> FailureMode {
self.config.failure_mode
}
/// Returns the configured call timeout in milliseconds.
pub fn timeout_ms(&self) -> u64 {
self.config.timeout_ms
}
/// Returns the call metrics (counters and cumulative duration) for this agent.
pub fn metrics(&self) -> &AgentMetrics {
&self.metrics
}
/// Reports whether this agent is configured to receive `event_type`.
///
/// Each configured `AgentEvent` is compared against the protocol event type.
/// Most pairs match by name; note that `AgentEvent::Log` corresponds to
/// `EventType::RequestComplete` and the body events correspond to the
/// `*BodyChunk` protocol events. Any other combination is not handled.
pub fn handles_event(&self, event_type: EventType) -> bool {
    self.config.events.iter().any(|subscribed| {
        matches!(
            (subscribed, event_type),
            (AgentEvent::RequestHeaders, EventType::RequestHeaders)
                | (AgentEvent::RequestBody, EventType::RequestBodyChunk)
                | (AgentEvent::ResponseHeaders, EventType::ResponseHeaders)
                | (AgentEvent::ResponseBody, EventType::ResponseBodyChunk)
                | (AgentEvent::Log, EventType::RequestComplete)
                | (AgentEvent::WebSocketFrame, EventType::WebSocketFrame)
                | (AgentEvent::Guardrail, EventType::GuardrailInspect)
        )
    })
}
/// Registers this agent with the connection pool and pushes its configuration.
///
/// Resolves the endpoint from the configured transport, adds the agent to the
/// pool (logging how long that took), and, if the agent configuration carries
/// a `config` value, forwards it through `send_configure`.
///
/// # Errors
///
/// Returns `ZentinelError::Agent` with event `"initialize"` when the transport
/// has no usable endpoint or the pool fails to add the agent, and propagates
/// any error from `send_configure`.
pub async fn initialize(&self) -> ZentinelResult<()> {
    let endpoint = self.get_endpoint()?;
    let agent_id = &self.config.id;
    debug!(
        agent_id = %agent_id,
        endpoint = %endpoint,
        "Initializing v2 agent pool"
    );

    let connect_started = Instant::now();
    if let Err(e) = self.pool.add_agent(agent_id, &endpoint).await {
        error!(
            agent_id = %agent_id,
            endpoint = %endpoint,
            error = %e,
            "Failed to add agent to v2 pool"
        );
        return Err(ZentinelError::Agent {
            agent: agent_id.clone(),
            message: format!("Failed to initialize v2 agent: {}", e),
            event: "initialize".to_string(),
            source: None,
        });
    }
    info!(
        agent_id = %agent_id,
        endpoint = %endpoint,
        connect_time_ms = connect_started.elapsed().as_millis(),
        "V2 agent pool initialized"
    );

    // Only agents that were given a configuration payload get a configure push.
    match &self.config.config {
        Some(config_value) => self.send_configure(config_value.clone()).await,
        None => Ok(()),
    }
}
/// Derives the pool endpoint string from the configured transport.
///
/// * gRPC: the configured address, unchanged.
/// * Unix socket: the socket path prefixed with `unix:`.
/// * HTTP: not available for the v2 protocol.
///
/// # Errors
///
/// Returns `ZentinelError::Agent` with event `"initialize"` when the agent is
/// configured with the HTTP transport.
fn get_endpoint(&self) -> ZentinelResult<String> {
    use zentinel_config::AgentTransport;
    match &self.config.transport {
        AgentTransport::Grpc { address, .. } => Ok(address.clone()),
        AgentTransport::UnixSocket { path } => Ok(format!("unix:{}", path.display())),
        // No field of the HTTP variant is needed to build the error, so none is bound.
        AgentTransport::Http { .. } => Err(ZentinelError::Agent {
            agent: self.config.id.clone(),
            message: "HTTP transport not supported for v2 protocol".to_string(),
            event: "initialize".to_string(),
            source: None,
        }),
    }
}
/// Asks the pool to push a `RequestReload` configuration update to this agent.
///
/// The `_config` payload is currently not used by this method. The outcome of
/// the push is only logged: this method returns `Ok(())` whether or not the
/// pool produced a push id.
async fn send_configure(&self, _config: serde_json::Value) -> ZentinelResult<()> {
    use zentinel_agent_protocol::v2::ConfigUpdateType;

    let pushed = self
        .pool
        .push_config_to_agent(&self.config.id, ConfigUpdateType::RequestReload);
    match pushed {
        Some(push_id) => info!(
            agent_id = %self.config.id,
            push_id = %push_id,
            "Configuration push sent to agent"
        ),
        None => debug!(
            agent_id = %self.config.id,
            "Agent does not support config push, config will be sent on next connection"
        ),
    }
    Ok(())
}
/// Sends a request-headers event to the agent and returns its response.
///
/// Increments the `calls_total` metric before sending. Success and failure
/// counters are not updated here.
///
/// # Errors
///
/// A pool error is logged and returned as `ZentinelError::Agent` with event
/// `"request_headers"` and the pool error's text as message.
pub async fn call_request_headers(
    &self,
    event: &RequestHeadersEvent,
) -> ZentinelResult<AgentResponse> {
    let seq = self.metrics.calls_total.fetch_add(1, Ordering::Relaxed) + 1;
    let corr_id = &event.metadata.correlation_id;
    trace!(
        agent_id = %self.config.id,
        call_num = seq,
        correlation_id = %corr_id,
        "Sending request headers to v2 agent"
    );

    let outcome = self
        .pool
        .send_request_headers(&self.config.id, corr_id, event)
        .await;
    match outcome {
        Ok(response) => Ok(response),
        Err(e) => {
            error!(
                agent_id = %self.config.id,
                correlation_id = %corr_id,
                error = %e,
                "V2 agent request headers call failed"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: e.to_string(),
                event: "request_headers".to_string(),
                source: None,
            })
        }
    }
}
/// Sends one request body chunk to the agent and returns its response.
///
/// No call counter is updated by this method.
///
/// # Errors
///
/// A pool error is logged and returned as `ZentinelError::Agent` with event
/// `"request_body_chunk"` and the pool error's text as message.
pub async fn call_request_body_chunk(
    &self,
    event: &RequestBodyChunkEvent,
) -> ZentinelResult<AgentResponse> {
    let corr_id = &event.correlation_id;
    trace!(
        agent_id = %self.config.id,
        correlation_id = %corr_id,
        chunk_index = event.chunk_index,
        is_last = event.is_last,
        "Sending request body chunk to v2 agent"
    );

    let outcome = self
        .pool
        .send_request_body_chunk(&self.config.id, corr_id, event)
        .await;
    match outcome {
        Ok(response) => Ok(response),
        Err(e) => {
            error!(
                agent_id = %self.config.id,
                correlation_id = %corr_id,
                error = %e,
                "V2 agent request body chunk call failed"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: e.to_string(),
                event: "request_body_chunk".to_string(),
                source: None,
            })
        }
    }
}
/// Sends a response-headers event to the agent and returns its response.
///
/// No call counter is updated by this method.
///
/// # Errors
///
/// A pool error is logged and returned as `ZentinelError::Agent` with event
/// `"response_headers"` and the pool error's text as message.
pub async fn call_response_headers(
    &self,
    event: &ResponseHeadersEvent,
) -> ZentinelResult<AgentResponse> {
    let corr_id = &event.correlation_id;
    trace!(
        agent_id = %self.config.id,
        correlation_id = %corr_id,
        status = event.status,
        "Sending response headers to v2 agent"
    );

    let outcome = self
        .pool
        .send_response_headers(&self.config.id, corr_id, event)
        .await;
    match outcome {
        Ok(response) => Ok(response),
        Err(e) => {
            error!(
                agent_id = %self.config.id,
                correlation_id = %corr_id,
                error = %e,
                "V2 agent response headers call failed"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: e.to_string(),
                event: "response_headers".to_string(),
                source: None,
            })
        }
    }
}
/// Sends one response body chunk to the agent and returns its response.
///
/// No call counter is updated by this method.
///
/// # Errors
///
/// A pool error is logged and returned as `ZentinelError::Agent` with event
/// `"response_body_chunk"` and the pool error's text as message.
pub async fn call_response_body_chunk(
    &self,
    event: &ResponseBodyChunkEvent,
) -> ZentinelResult<AgentResponse> {
    let corr_id = &event.correlation_id;
    trace!(
        agent_id = %self.config.id,
        correlation_id = %corr_id,
        chunk_index = event.chunk_index,
        is_last = event.is_last,
        "Sending response body chunk to v2 agent"
    );

    let outcome = self
        .pool
        .send_response_body_chunk(&self.config.id, corr_id, event)
        .await;
    match outcome {
        Ok(response) => Ok(response),
        Err(e) => {
            error!(
                agent_id = %self.config.id,
                correlation_id = %corr_id,
                error = %e,
                "V2 agent response body chunk call failed"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: e.to_string(),
                event: "response_body_chunk".to_string(),
                source: None,
            })
        }
    }
}
/// Sends a guardrail inspection event to the agent and returns its response.
///
/// Increments the `calls_total` metric before sending. Success and failure
/// counters are not updated here.
///
/// # Errors
///
/// A pool error is logged and returned as `ZentinelError::Agent` with event
/// `"guardrail_inspect"` and the pool error's text as message.
pub async fn call_guardrail_inspect(
    &self,
    event: &GuardrailInspectEvent,
) -> ZentinelResult<AgentResponse> {
    let seq = self.metrics.calls_total.fetch_add(1, Ordering::Relaxed) + 1;
    let corr_id = &event.correlation_id;
    trace!(
        agent_id = %self.config.id,
        call_num = seq,
        correlation_id = %corr_id,
        inspection_type = ?event.inspection_type,
        "Sending guardrail inspect to v2 agent"
    );

    let outcome = self
        .pool
        .send_guardrail_inspect(&self.config.id, corr_id, event)
        .await;
    match outcome {
        Ok(response) => Ok(response),
        Err(e) => {
            error!(
                agent_id = %self.config.id,
                correlation_id = %corr_id,
                error = %e,
                "V2 agent guardrail inspect call failed"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: e.to_string(),
                event: "guardrail_inspect".to_string(),
                source: None,
            })
        }
    }
}
/// Dispatches a generically-typed event to the matching typed call method.
///
/// The event is serialized to a JSON value and then deserialized into the
/// concrete event struct that corresponds to `event_type`, after which the
/// typed `call_*` method is invoked. Supported types are request/response
/// headers, request/response body chunks and guardrail inspection.
///
/// # Errors
///
/// Returns `ZentinelError::Agent` (with the debug form of `event_type` as the
/// event name) when serialization or deserialization fails or when
/// `event_type` is not one of the supported types; otherwise propagates the
/// result of the typed call.
pub async fn call_event<T: serde::Serialize>(
    &self,
    event_type: EventType,
    event: &T,
) -> ZentinelResult<AgentResponse> {
    // Every error produced here differs only in its message text.
    let agent_error = |message: String| ZentinelError::Agent {
        agent: self.config.id.clone(),
        message,
        event: format!("{:?}", event_type),
        source: None,
    };

    let json = serde_json::to_value(event)
        .map_err(|e| agent_error(format!("Failed to serialize event: {}", e)))?;

    match event_type {
        EventType::RequestHeaders => {
            let typed = serde_json::from_value::<RequestHeadersEvent>(json).map_err(|e| {
                agent_error(format!("Failed to deserialize RequestHeadersEvent: {}", e))
            })?;
            self.call_request_headers(&typed).await
        }
        EventType::RequestBodyChunk => {
            let typed = serde_json::from_value::<RequestBodyChunkEvent>(json).map_err(|e| {
                agent_error(format!("Failed to deserialize RequestBodyChunkEvent: {}", e))
            })?;
            self.call_request_body_chunk(&typed).await
        }
        EventType::ResponseHeaders => {
            let typed = serde_json::from_value::<ResponseHeadersEvent>(json).map_err(|e| {
                agent_error(format!("Failed to deserialize ResponseHeadersEvent: {}", e))
            })?;
            self.call_response_headers(&typed).await
        }
        EventType::ResponseBodyChunk => {
            let typed = serde_json::from_value::<ResponseBodyChunkEvent>(json).map_err(|e| {
                agent_error(format!("Failed to deserialize ResponseBodyChunkEvent: {}", e))
            })?;
            self.call_response_body_chunk(&typed).await
        }
        EventType::GuardrailInspect => {
            let typed = serde_json::from_value::<GuardrailInspectEvent>(json).map_err(|e| {
                agent_error(format!("Failed to deserialize GuardrailInspectEvent: {}", e))
            })?;
            self.call_guardrail_inspect(&typed).await
        }
        _ => Err(agent_error(format!("Unsupported event type {:?}", event_type))),
    }
}
/// Asks the agent, via the pool, to cancel the request identified by
/// `correlation_id` for the given `reason`.
///
/// # Errors
///
/// A pool error is logged at warn level and returned as `ZentinelError::Agent`
/// with event `"cancel"` and a message prefixed with `Cancel failed:`.
pub async fn cancel_request(
    &self,
    correlation_id: &str,
    reason: CancelReason,
) -> ZentinelResult<()> {
    trace!(
        agent_id = %self.config.id,
        correlation_id = %correlation_id,
        reason = ?reason,
        "Cancelling request on v2 agent"
    );

    let outcome = self
        .pool
        .cancel_request(&self.config.id, correlation_id, reason)
        .await;
    match outcome {
        Ok(done) => Ok(done),
        Err(e) => {
            warn!(
                agent_id = %self.config.id,
                correlation_id = %correlation_id,
                error = %e,
                "Failed to cancel request on v2 agent"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: format!("Cancel failed: {}", e),
                event: "cancel".to_string(),
                source: None,
            })
        }
    }
}
/// Returns the capabilities the pool reports for this agent, or `None` when
/// the pool has none to report for this agent id.
pub async fn capabilities(&self) -> Option<AgentCapabilities> {
self.pool.agent_capabilities(&self.config.id).await
}
/// Returns whether the pool currently considers this agent healthy.
///
/// The underlying pool query is synchronous; nothing is awaited here even
/// though the method is declared `async`.
pub async fn is_healthy(&self) -> bool {
self.pool.is_agent_healthy(&self.config.id)
}
/// Records a successful agent call that took `duration`.
///
/// Increments the success counter, adds the duration (in microseconds) to the
/// cumulative total, resets the consecutive-failure count, stores the time of
/// this success relative to `base_instant`, and notifies the circuit breaker.
pub fn record_success(&self, duration: Duration) {
    let success_count = self.metrics.calls_success.fetch_add(1, Ordering::Relaxed) + 1;
    self.metrics
        .duration_total_us
        .fetch_add(duration.as_micros() as u64, Ordering::Relaxed);
    self.consecutive_failures.store(0, Ordering::Relaxed);

    // A stored value of 0 is read back by `time_since_last_success` as "no
    // success recorded yet". Clamp to 1 ns so a success observed with a zero
    // elapsed reading (coarse clock) is never mistaken for that sentinel.
    let elapsed_ns = (self.base_instant.elapsed().as_nanos() as u64).max(1);
    self.last_success_ns.store(elapsed_ns, Ordering::Relaxed);

    trace!(
        agent_id = %self.config.id,
        duration_ms = duration.as_millis(),
        total_successes = success_count,
        "Recorded v2 agent call success"
    );
    self.circuit_breaker.record_success();
}
/// Returns how long ago the last success was recorded, or `None` if no
/// success has been recorded yet.
///
/// Both the stored timestamp and the current time are measured in nanoseconds
/// relative to `base_instant`; the difference saturates at zero.
#[inline]
pub fn time_since_last_success(&self) -> Option<Duration> {
    let recorded_ns = self.last_success_ns.load(Ordering::Relaxed);
    if recorded_ns != NO_TIMESTAMP {
        let now_ns = self.base_instant.elapsed().as_nanos() as u64;
        Some(Duration::from_nanos(now_ns.saturating_sub(recorded_ns)))
    } else {
        None
    }
}
/// Records a failed agent call.
///
/// Increments the failed-call counter and the consecutive-failure count, logs
/// both, and notifies the circuit breaker of the failure.
pub fn record_failure(&self) {
    let total_failures = self.metrics.calls_failed.fetch_add(1, Ordering::Relaxed) + 1;
    let streak = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
    debug!(
        agent_id = %self.config.id,
        total_failures,
        consecutive_failures = streak,
        "Recorded v2 agent call failure"
    );
    self.circuit_breaker.record_failure();
}
/// Records an agent call that timed out.
///
/// Increments the timeout counter (not the failed-call counter) and the
/// consecutive-failure count, logs them together with the configured timeout,
/// and reports a failure to the circuit breaker.
pub fn record_timeout(&self) {
    let total_timeouts = self.metrics.calls_timeout.fetch_add(1, Ordering::Relaxed) + 1;
    let streak = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
    debug!(
        agent_id = %self.config.id,
        total_timeouts,
        consecutive_failures = streak,
        timeout_ms = self.config.timeout_ms,
        "Recorded v2 agent call timeout"
    );
    self.circuit_breaker.record_failure();
}
/// Returns the pool's statistics for this agent, or `None` when the pool has
/// none to report for this agent id.
pub async fn pool_stats(&self) -> Option<AgentPoolStats> {
self.pool.agent_stats(&self.config.id).await
}
/// Returns a borrowed reference to the pool's metrics collector.
pub fn pool_metrics_collector(&self) -> &MetricsCollector {
self.pool.metrics_collector()
}
/// Returns a shared (`Arc`) handle to the pool's metrics collector, for
/// callers that need to hold it independently of this agent.
pub fn pool_metrics_collector_arc(&self) -> Arc<MetricsCollector> {
self.pool.metrics_collector_arc()
}
/// Returns the string produced by the pool's `export_prometheus`.
// NOTE(review): presumably Prometheus text exposition format — confirm against the pool implementation.
pub fn export_prometheus(&self) -> String {
self.pool.export_prometheus()
}
/// Returns a reference to the pool's configuration pusher.
pub fn config_pusher(&self) -> &ConfigPusher {
self.pool.config_pusher()
}
/// Asks the pool to push a configuration update of `update_type` to this agent.
///
/// Returns the value produced by the pool: `Some` push id when a push was
/// issued, `None` otherwise (callers in this file treat `None` as "agent does
/// not support config push").
pub fn push_config(&self, update_type: ConfigUpdateType) -> Option<String> {
self.pool.push_config_to_agent(&self.config.id, update_type)
}
/// Requests a configuration reload on the agent.
///
/// Issues a `RequestReload` push through `push_config`. The `config` payload
/// is accepted for API compatibility but is not transmitted by this method;
/// only the reload request is sent.
///
/// # Errors
///
/// Returns `ZentinelError::Agent` with event `"send_configuration"` when the
/// push yields no push id (the agent does not support config push).
pub async fn send_configuration(&self, config: serde_json::Value) -> ZentinelResult<()> {
    // Explicitly discard the payload: only a reload request is pushed below.
    let _ = config;

    match self.push_config(ConfigUpdateType::RequestReload) {
        Some(push_id) => {
            debug!(
                agent_id = %self.config.id,
                push_id = %push_id,
                "Configuration push initiated"
            );
            Ok(())
        }
        None => {
            warn!(
                agent_id = %self.config.id,
                "Agent does not support config push"
            );
            Err(ZentinelError::Agent {
                agent: self.config.id.clone(),
                message: "Agent does not support config push".to_string(),
                event: "send_configuration".to_string(),
                source: None,
            })
        }
    }
}
/// Removes the agent from the pool and logs its final call statistics.
///
/// A failure to remove the agent is logged at warn level and otherwise
/// ignored, so shutdown always runs to completion.
pub async fn shutdown(&self) {
    debug!(
        agent_id = %self.config.id,
        "Shutting down v2 agent"
    );

    let removal = self.pool.remove_agent(&self.config.id).await;
    if let Err(e) = removal {
        warn!(
            agent_id = %self.config.id,
            error = %e,
            "Error removing agent from pool during shutdown"
        );
    }

    // Snapshot the counters in the same order they are reported.
    let total = self.metrics.calls_total.load(Ordering::Relaxed);
    let succeeded = self.metrics.calls_success.load(Ordering::Relaxed);
    let failed = self.metrics.calls_failed.load(Ordering::Relaxed);
    let timed_out = self.metrics.calls_timeout.load(Ordering::Relaxed);
    info!(
        agent_id = %self.config.id,
        total_calls = total,
        successes = succeeded,
        failures = failed,
        timeouts = timed_out,
        "V2 agent shutdown complete"
    );
}
}
fn convert_lb_strategy(strategy: LoadBalanceStrategy) -> ProtocolLBStrategy {
match strategy {
LoadBalanceStrategy::RoundRobin => ProtocolLBStrategy::RoundRobin,
LoadBalanceStrategy::LeastConnections => ProtocolLBStrategy::LeastConnections,
LoadBalanceStrategy::HealthBased => ProtocolLBStrategy::HealthBased,
LoadBalanceStrategy::Random => ProtocolLBStrategy::Random,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every configuration-level strategy must convert to the protocol
    /// strategy with the same name.
    #[test]
    fn test_convert_lb_strategy() {
        let check = |input: LoadBalanceStrategy, expected: ProtocolLBStrategy| {
            assert_eq!(convert_lb_strategy(input), expected);
        };

        check(
            LoadBalanceStrategy::RoundRobin,
            ProtocolLBStrategy::RoundRobin,
        );
        check(
            LoadBalanceStrategy::LeastConnections,
            ProtocolLBStrategy::LeastConnections,
        );
        check(
            LoadBalanceStrategy::HealthBased,
            ProtocolLBStrategy::HealthBased,
        );
        check(LoadBalanceStrategy::Random, ProtocolLBStrategy::Random);
    }
}