use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
use tonic::transport::Channel;
use tracing::{debug, info, trace, warn};
use crate::grpc_v2::{self, agent_service_v2_client::AgentServiceV2Client, ProxyToAgent};
use crate::headers::iter_flat;
use crate::v2::pool::CHANNEL_BUFFER_SIZE;
use crate::v2::{AgentCapabilities, PROTOCOL_VERSION_2};
use crate::{AgentProtocolError, AgentResponse, Decision, EventType, HeaderOp};
/// Reason attached to the cancel message sent to an agent for an in-flight request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CancelReason {
    ClientDisconnect,
    Timeout,
    BlockedByAgent,
    UpstreamError,
    ProxyShutdown,
    Manual,
}

impl CancelReason {
    /// Numeric wire value used for the `reason` field of a gRPC cancel request.
    ///
    /// Values run from 1 to 6 in declaration order; 0 is never produced.
    fn to_grpc(self) -> i32 {
        match self {
            Self::ClientDisconnect => 1,
            Self::Timeout => 2,
            Self::BlockedByAgent => 3,
            Self::UpstreamError => 4,
            Self::ProxyShutdown => 5,
            Self::Manual => 6,
        }
    }
}
/// Flow-control state of the stream to an agent.
///
/// `Paused` is entered when the agent sends a pause flow-control signal; `Draining` is
/// entered locally via `AgentClientV2::send_drain`. Only `Normal` accepts new requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlowState {
    /// Requests may be sent; this is the initial state.
    #[default]
    Normal,
    /// The agent has asked the proxy to stop sending.
    Paused,
    /// The proxy is draining this agent.
    Draining,
}
/// Callback invoked with each metrics report the agent pushes over the stream,
/// after conversion and tagging with the agent's ID.
pub type MetricsCallback = Arc<dyn Fn(crate::v2::MetricsReport) + Send + Sync>;
/// Callback invoked when the agent sends a configuration-update request.
///
/// Receives the agent ID and the decoded request and returns a response.
/// NOTE(review): the response handler in this module currently discards the returned
/// value rather than sending it back to the agent.
pub type ConfigUpdateCallback = Arc<
    dyn Fn(String, crate::v2::ConfigUpdateRequest) -> crate::v2::ConfigUpdateResponse + Send + Sync,
>;
pub struct AgentClientV2 {
agent_id: String,
channel: Channel,
timeout: Duration,
capabilities: RwLock<Option<AgentCapabilities>>,
protocol_version: AtomicU64,
pending: Arc<Mutex<HashMap<String, oneshot::Sender<AgentResponse>>>>,
outbound_tx: Mutex<Option<mpsc::Sender<ProxyToAgent>>>,
ping_sequence: AtomicU64,
connected: RwLock<bool>,
flow_state: RwLock<FlowState>,
health_state: RwLock<i32>,
in_flight: AtomicU64,
metrics_callback: Option<MetricsCallback>,
config_update_callback: Option<ConfigUpdateCallback>,
}
impl AgentClientV2 {
pub async fn new(
agent_id: impl Into<String>,
endpoint: impl Into<String>,
timeout: Duration,
) -> Result<Self, AgentProtocolError> {
let agent_id = agent_id.into();
let endpoint = endpoint.into();
debug!(agent_id = %agent_id, endpoint = %endpoint, "Creating v2 client");
let channel = Channel::from_shared(endpoint.clone())
.map_err(|e| AgentProtocolError::ConnectionFailed(format!("Invalid endpoint: {}", e)))?
.connect_timeout(timeout)
.timeout(timeout)
.connect()
.await
.map_err(|e| {
AgentProtocolError::ConnectionFailed(format!("Failed to connect: {}", e))
})?;
Ok(Self {
agent_id,
channel,
timeout,
capabilities: RwLock::new(None),
protocol_version: AtomicU64::new(1), pending: Arc::new(Mutex::new(HashMap::new())),
outbound_tx: Mutex::new(None),
ping_sequence: AtomicU64::new(0),
connected: RwLock::new(false),
flow_state: RwLock::new(FlowState::Normal),
health_state: RwLock::new(1), in_flight: AtomicU64::new(0),
metrics_callback: None,
config_update_callback: None,
})
}
pub fn set_metrics_callback(&mut self, callback: MetricsCallback) {
self.metrics_callback = Some(callback);
}
pub fn set_config_update_callback(&mut self, callback: ConfigUpdateCallback) {
self.config_update_callback = Some(callback);
}
pub async fn connect(&self) -> Result<(), AgentProtocolError> {
let mut client = AgentServiceV2Client::new(self.channel.clone());
let (tx, rx) = mpsc::channel::<ProxyToAgent>(CHANNEL_BUFFER_SIZE);
let rx_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let response_stream = client
.process_stream(rx_stream)
.await
.map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream failed: {}", e)))?;
let mut inbound = response_stream.into_inner();
let handshake = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::Handshake(
grpc_v2::HandshakeRequest {
supported_versions: vec![PROTOCOL_VERSION_2, 1],
proxy_id: "zentinel-proxy".to_string(),
proxy_version: env!("CARGO_PKG_VERSION").to_string(),
config_json: "{}".to_string(),
},
)),
};
tx.send(handshake).await.map_err(|e| {
AgentProtocolError::ConnectionFailed(format!("Failed to send handshake: {}", e))
})?;
let handshake_resp = tokio::time::timeout(self.timeout, inbound.message())
.await
.map_err(|_| AgentProtocolError::Timeout(self.timeout))?
.map_err(|e| AgentProtocolError::ConnectionFailed(format!("Stream error: {}", e)))?
.ok_or_else(|| {
AgentProtocolError::ConnectionFailed("Empty handshake response".to_string())
})?;
if let Some(grpc_v2::agent_to_proxy::Message::Handshake(resp)) = handshake_resp.message {
if !resp.success {
return Err(AgentProtocolError::ConnectionFailed(format!(
"Handshake failed: {}",
resp.error.unwrap_or_default()
)));
}
self.protocol_version
.store(resp.protocol_version as u64, Ordering::SeqCst);
if let Some(caps) = resp.capabilities {
let capabilities = convert_capabilities_from_grpc(caps);
*self.capabilities.write().await = Some(capabilities);
}
info!(
agent_id = %self.agent_id,
protocol_version = resp.protocol_version,
"v2 handshake successful"
);
} else {
return Err(AgentProtocolError::ConnectionFailed(
"Invalid handshake response".to_string(),
));
}
*self.outbound_tx.lock().await = Some(tx);
*self.connected.write().await = true;
let pending = Arc::clone(&self.pending);
let agent_id = self.agent_id.clone();
let flow_state = Arc::new(RwLock::new(FlowState::Normal));
let health_state = Arc::new(RwLock::new(1i32));
let _in_flight = Arc::new(AtomicU64::new(0));
let flow_state_clone = Arc::clone(&flow_state);
let health_state_clone = Arc::clone(&health_state);
let metrics_callback = self.metrics_callback.clone();
let config_update_callback = self.config_update_callback.clone();
tokio::spawn(async move {
while let Ok(Some(msg)) = inbound.message().await {
match msg.message {
Some(grpc_v2::agent_to_proxy::Message::Response(resp)) => {
let correlation_id = resp.correlation_id.clone();
if let Some(sender) = pending.lock().await.remove(&correlation_id) {
let response = convert_response_from_grpc(resp);
let _ = sender.send(response);
} else {
warn!(
agent_id = %agent_id,
correlation_id = %correlation_id,
"Received response for unknown correlation ID"
);
}
}
Some(grpc_v2::agent_to_proxy::Message::Health(health)) => {
trace!(
agent_id = %agent_id,
state = health.state,
"Received health status"
);
*health_state_clone.write().await = health.state;
}
Some(grpc_v2::agent_to_proxy::Message::Metrics(metrics)) => {
trace!(
agent_id = %agent_id,
counters = metrics.counters.len(),
gauges = metrics.gauges.len(),
histograms = metrics.histograms.len(),
"Received metrics report"
);
if let Some(ref callback) = metrics_callback {
let report = convert_metrics_from_grpc(metrics, &agent_id);
callback(report);
}
}
Some(grpc_v2::agent_to_proxy::Message::FlowControl(fc)) => {
let new_state = match fc.action {
1 => FlowState::Paused, 2 => FlowState::Normal, _ => FlowState::Normal,
};
debug!(
agent_id = %agent_id,
action = fc.action,
correlation_id = ?fc.correlation_id,
"Received flow control signal"
);
*flow_state_clone.write().await = new_state;
}
Some(grpc_v2::agent_to_proxy::Message::Pong(pong)) => {
trace!(
agent_id = %agent_id,
sequence = pong.sequence,
latency_ms = pong.timestamp_ms.saturating_sub(pong.ping_timestamp_ms),
"Received pong"
);
}
Some(grpc_v2::agent_to_proxy::Message::ConfigUpdate(update)) => {
debug!(
agent_id = %agent_id,
request_id = %update.request_id,
"Received config update request from agent"
);
if let Some(ref callback) = config_update_callback {
let request = convert_config_update_from_grpc(update);
let _response = callback(agent_id.clone(), request);
}
}
Some(grpc_v2::agent_to_proxy::Message::Log(log_msg)) => {
match log_msg.level {
1 => {
trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent debug log")
}
2 => {
debug!(agent_id = %agent_id, msg = %log_msg.message, "Agent info log")
}
3 => {
warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent warning")
}
4 => warn!(agent_id = %agent_id, msg = %log_msg.message, "Agent error"),
_ => trace!(agent_id = %agent_id, msg = %log_msg.message, "Agent log"),
}
}
_ => {}
}
}
debug!(agent_id = %agent_id, "Response handler ended");
});
Ok(())
}
pub async fn send_request_headers(
&self,
correlation_id: &str,
event: &crate::RequestHeadersEvent,
) -> Result<AgentResponse, AgentProtocolError> {
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
convert_request_headers_to_grpc(event),
)),
};
self.send_and_wait(correlation_id, msg).await
}
pub async fn send_request_body_chunk(
&self,
correlation_id: &str,
event: &crate::RequestBodyChunkEvent,
) -> Result<AgentResponse, AgentProtocolError> {
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::RequestBodyChunk(
convert_body_chunk_to_grpc(event),
)),
};
self.send_and_wait(correlation_id, msg).await
}
pub async fn send_response_headers(
&self,
correlation_id: &str,
event: &crate::ResponseHeadersEvent,
) -> Result<AgentResponse, AgentProtocolError> {
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::ResponseHeaders(
convert_response_headers_to_grpc(event),
)),
};
self.send_and_wait(correlation_id, msg).await
}
pub async fn send_response_body_chunk(
&self,
correlation_id: &str,
event: &crate::ResponseBodyChunkEvent,
) -> Result<AgentResponse, AgentProtocolError> {
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::ResponseBodyChunk(
convert_response_body_chunk_to_grpc(event),
)),
};
self.send_and_wait(correlation_id, msg).await
}
pub async fn send_event<T: serde::Serialize>(
&self,
event_type: EventType,
event: &T,
) -> Result<AgentResponse, AgentProtocolError> {
let correlation_id = extract_correlation_id(event);
let msg = match event_type {
EventType::RequestHeaders => {
if let Ok(e) = serde_json::from_value::<crate::RequestHeadersEvent>(
serde_json::to_value(event).unwrap_or_default(),
) {
ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::RequestHeaders(
convert_request_headers_to_grpc(&e),
)),
}
} else {
return Err(AgentProtocolError::InvalidMessage(
"Failed to convert event".to_string(),
));
}
}
_ => {
return Err(AgentProtocolError::InvalidMessage(format!(
"Event type {:?} not yet supported in v2 streaming mode",
event_type
)));
}
};
self.send_and_wait(&correlation_id, msg).await
}
async fn send_and_wait(
&self,
correlation_id: &str,
msg: ProxyToAgent,
) -> Result<AgentResponse, AgentProtocolError> {
let (tx, rx) = oneshot::channel();
self.pending
.lock()
.await
.insert(correlation_id.to_string(), tx);
{
let outbound = self.outbound_tx.lock().await;
if let Some(sender) = outbound.as_ref() {
sender.send(msg).await.map_err(|e| {
AgentProtocolError::ConnectionFailed(format!("Send failed: {}", e))
})?;
} else {
return Err(AgentProtocolError::ConnectionFailed(
"Not connected".to_string(),
));
}
}
match tokio::time::timeout(self.timeout, rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => {
self.pending.lock().await.remove(correlation_id);
Err(AgentProtocolError::ConnectionFailed(
"Response channel closed".to_string(),
))
}
Err(_) => {
self.pending.lock().await.remove(correlation_id);
Err(AgentProtocolError::Timeout(self.timeout))
}
}
}
pub async fn ping(&self) -> Result<Duration, AgentProtocolError> {
let sequence = self.ping_sequence.fetch_add(1, Ordering::SeqCst);
let timestamp_ms = now_ms();
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::Ping(grpc_v2::Ping {
sequence,
timestamp_ms,
})),
};
let outbound = self.outbound_tx.lock().await;
if let Some(sender) = outbound.as_ref() {
sender
.send(msg)
.await
.map_err(|e| AgentProtocolError::ConnectionFailed(format!("Ping failed: {}", e)))?;
}
Ok(Duration::from_millis(0))
}
pub fn protocol_version(&self) -> u32 {
self.protocol_version.load(Ordering::SeqCst) as u32
}
pub async fn capabilities(&self) -> Option<AgentCapabilities> {
self.capabilities.read().await.clone()
}
pub async fn is_connected(&self) -> bool {
*self.connected.read().await
}
pub async fn close(&self) -> Result<(), AgentProtocolError> {
*self.outbound_tx.lock().await = None;
*self.connected.write().await = false;
Ok(())
}
pub async fn cancel_request(
&self,
correlation_id: &str,
reason: CancelReason,
) -> Result<(), AgentProtocolError> {
self.pending.lock().await.remove(correlation_id);
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::Cancel(
grpc_v2::CancelRequest {
correlation_id: correlation_id.to_string(),
reason: reason.to_grpc(),
timestamp_ms: now_ms(),
blocking_agent_id: None,
manual_reason: None,
},
)),
};
let outbound = self.outbound_tx.lock().await;
if let Some(sender) = outbound.as_ref() {
sender.send(msg).await.map_err(|e| {
AgentProtocolError::ConnectionFailed(format!("Cancel send failed: {}", e))
})?;
}
debug!(
agent_id = %self.agent_id,
correlation_id = %correlation_id,
reason = ?reason,
"Cancelled request"
);
Ok(())
}
pub async fn cancel_all(&self, reason: CancelReason) -> Result<usize, AgentProtocolError> {
let correlation_ids: Vec<String> = {
let pending = self.pending.lock().await;
pending.keys().cloned().collect()
};
let count = correlation_ids.len();
for cid in correlation_ids {
let _ = self.cancel_request(&cid, reason).await;
}
debug!(
agent_id = %self.agent_id,
count = count,
reason = ?reason,
"Cancelled all requests"
);
Ok(count)
}
pub async fn flow_state(&self) -> FlowState {
*self.flow_state.read().await
}
pub async fn can_accept_requests(&self) -> bool {
matches!(*self.flow_state.read().await, FlowState::Normal)
}
pub async fn wait_for_flow_control(&self, timeout: Duration) -> Result<(), AgentProtocolError> {
let deadline = tokio::time::Instant::now() + timeout;
loop {
if self.can_accept_requests().await {
return Ok(());
}
if tokio::time::Instant::now() >= deadline {
return Err(AgentProtocolError::Timeout(timeout));
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
pub async fn health_state(&self) -> i32 {
*self.health_state.read().await
}
pub async fn is_healthy(&self) -> bool {
*self.health_state.read().await == 1
}
pub fn in_flight_count(&self) -> u64 {
self.in_flight.load(Ordering::Relaxed)
}
pub async fn send_configure(
&self,
config: serde_json::Value,
version: Option<String>,
) -> Result<(), AgentProtocolError> {
let msg = ProxyToAgent {
message: Some(grpc_v2::proxy_to_agent::Message::Configure(
grpc_v2::ConfigureEvent {
config_json: serde_json::to_string(&config).unwrap_or_default(),
config_version: version,
is_initial: false,
timestamp_ms: now_ms(),
},
)),
};
let outbound = self.outbound_tx.lock().await;
if let Some(sender) = outbound.as_ref() {
sender.send(msg).await.map_err(|e| {
AgentProtocolError::ConnectionFailed(format!("Configure send failed: {}", e))
})?;
} else {
return Err(AgentProtocolError::ConnectionFailed(
"Not connected".to_string(),
));
}
debug!(agent_id = %self.agent_id, "Sent configuration update");
Ok(())
}
pub async fn send_shutdown(
&self,
reason: ShutdownReason,
grace_period_ms: u64,
) -> Result<(), AgentProtocolError> {
info!(
agent_id = %self.agent_id,
reason = ?reason,
grace_period_ms = grace_period_ms,
"Requesting agent shutdown"
);
let _ = self.cancel_all(CancelReason::ProxyShutdown).await;
self.close().await
}
pub async fn send_drain(
&self,
duration_ms: u64,
reason: DrainReason,
) -> Result<(), AgentProtocolError> {
info!(
agent_id = %self.agent_id,
duration_ms = duration_ms,
reason = ?reason,
"Requesting agent drain"
);
*self.flow_state.write().await = FlowState::Draining;
Ok(())
}
pub fn agent_id(&self) -> &str {
&self.agent_id
}
}
/// Why the proxy is asking an agent to shut down.
///
/// Passed to `AgentClientV2::send_shutdown`, which only logs it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownReason {
    Graceful,
    Immediate,
    ConfigReload,
    Upgrade,
}

/// Why the proxy is draining an agent.
///
/// Passed to `AgentClientV2::send_drain`, which only logs it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrainReason {
    ConfigReload,
    Maintenance,
    HealthCheckFailed,
    Manual,
}
/// Converts the capabilities from a handshake response into the crate's representation.
///
/// Missing feature, limit, and health sections fall back to their `Default` values.
/// Unknown event-type codes are dropped from `supported_events`.
fn convert_capabilities_from_grpc(caps: grpc_v2::AgentCapabilities) -> AgentCapabilities {
    use crate::v2::{AgentFeatures, AgentLimits, HealthConfig};

    let features = match caps.features {
        Some(wire) => AgentFeatures {
            streaming_body: wire.streaming_body,
            websocket: wire.websocket,
            guardrails: wire.guardrails,
            config_push: wire.config_push,
            metrics_export: wire.metrics_export,
            concurrent_requests: wire.concurrent_requests,
            cancellation: wire.cancellation,
            flow_control: wire.flow_control,
            health_reporting: wire.health_reporting,
        },
        None => AgentFeatures::default(),
    };

    let limits = match caps.limits {
        Some(wire) => AgentLimits {
            max_body_size: wire.max_body_size as usize,
            max_concurrency: wire.max_concurrency,
            preferred_chunk_size: wire.preferred_chunk_size as usize,
            max_memory: wire.max_memory.map(|bytes| bytes as usize),
            max_processing_time_ms: wire.max_processing_time_ms,
        },
        None => AgentLimits::default(),
    };

    let health = match caps.health_config {
        Some(wire) => HealthConfig {
            report_interval_ms: wire.report_interval_ms,
            include_load_metrics: wire.include_load_metrics,
            include_resource_metrics: wire.include_resource_metrics,
        },
        None => HealthConfig::default(),
    };

    let supported_events = caps
        .supported_events
        .into_iter()
        .filter_map(i32_to_event_type)
        .collect();

    AgentCapabilities {
        protocol_version: caps.protocol_version,
        agent_id: caps.agent_id,
        name: caps.name,
        version: caps.version,
        supported_events,
        features,
        limits,
        health,
    }
}
/// Maps a wire event-type code (1 through 8) to an [`EventType`]; other codes yield `None`.
fn i32_to_event_type(i: i32) -> Option<EventType> {
    let event_type = match i {
        1 => EventType::RequestHeaders,
        2 => EventType::RequestBodyChunk,
        3 => EventType::ResponseHeaders,
        4 => EventType::ResponseBodyChunk,
        5 => EventType::RequestComplete,
        6 => EventType::WebSocketFrame,
        7 => EventType::GuardrailInspect,
        8 => EventType::Configure,
        _ => return None,
    };
    Some(event_type)
}
/// Converts a request-headers event into its v2 gRPC form.
///
/// Headers are flattened through `iter_flat` (presumably one entry per header value).
/// The metadata timestamp is the time of conversion, not a value taken from the event.
// NOTE(review): `http_version` is always "HTTP/1.1" regardless of the actual request
// version — confirm whether the event can supply the real one.
fn convert_request_headers_to_grpc(
    event: &crate::RequestHeadersEvent,
) -> grpc_v2::RequestHeadersEvent {
    let meta = &event.metadata;
    let headers = iter_flat(&event.headers)
        .map(|(name, value)| grpc_v2::Header {
            name: name.to_string(),
            value: value.to_string(),
        })
        .collect();

    grpc_v2::RequestHeadersEvent {
        metadata: Some(grpc_v2::RequestMetadata {
            correlation_id: meta.correlation_id.clone(),
            request_id: meta.request_id.clone(),
            client_ip: meta.client_ip.clone(),
            client_port: meta.client_port as u32,
            server_name: meta.server_name.clone(),
            protocol: meta.protocol.clone(),
            tls_version: meta.tls_version.clone(),
            route_id: meta.route_id.clone(),
            upstream_id: meta.upstream_id.clone(),
            timestamp_ms: now_ms(),
            traceparent: meta.traceparent.clone(),
        }),
        method: event.method.clone(),
        uri: event.uri.clone(),
        http_version: "HTTP/1.1".to_string(),
        headers,
    }
}
/// Converts a request-body chunk to gRPC by going through its binary representation.
fn convert_body_chunk_to_grpc(event: &crate::RequestBodyChunkEvent) -> grpc_v2::BodyChunkEvent {
    let as_binary: crate::BinaryRequestBodyChunkEvent = event.into();
    convert_binary_body_chunk_to_grpc(&as_binary)
}
fn convert_binary_body_chunk_to_grpc(
event: &crate::BinaryRequestBodyChunkEvent,
) -> grpc_v2::BodyChunkEvent {
grpc_v2::BodyChunkEvent {
correlation_id: event.correlation_id.clone(),
chunk_index: event.chunk_index,
data: event.data.to_vec(), is_last: event.is_last,
total_size: event.total_size.map(|s| s as u64),
bytes_transferred: event.bytes_received as u64,
proxy_buffer_available: 0, timestamp_ms: now_ms(),
}
}
/// Converts a response-headers event into its v2 gRPC form, flattening headers via
/// `iter_flat`.
fn convert_response_headers_to_grpc(
    event: &crate::ResponseHeadersEvent,
) -> grpc_v2::ResponseHeadersEvent {
    let flattened = iter_flat(&event.headers).map(|(name, value)| grpc_v2::Header {
        name: name.to_string(),
        value: value.to_string(),
    });

    grpc_v2::ResponseHeadersEvent {
        correlation_id: event.correlation_id.clone(),
        status_code: event.status as u32,
        headers: flattened.collect(),
    }
}
/// Converts a response-body chunk to gRPC by going through its binary representation.
fn convert_response_body_chunk_to_grpc(
    event: &crate::ResponseBodyChunkEvent,
) -> grpc_v2::BodyChunkEvent {
    let as_binary: crate::BinaryResponseBodyChunkEvent = event.into();
    convert_binary_response_body_chunk_to_grpc(&as_binary)
}
fn convert_binary_response_body_chunk_to_grpc(
event: &crate::BinaryResponseBodyChunkEvent,
) -> grpc_v2::BodyChunkEvent {
grpc_v2::BodyChunkEvent {
correlation_id: event.correlation_id.clone(),
chunk_index: event.chunk_index,
data: event.data.to_vec(), is_last: event.is_last,
total_size: event.total_size.map(|s| s as u64),
bytes_transferred: event.bytes_sent as u64,
proxy_buffer_available: 0,
timestamp_ms: now_ms(),
}
}
fn convert_response_from_grpc(resp: grpc_v2::AgentResponse) -> AgentResponse {
let decision = match resp.decision {
Some(grpc_v2::agent_response::Decision::Allow(_)) => Decision::Allow,
Some(grpc_v2::agent_response::Decision::Block(b)) => Decision::Block {
status: b.status as u16,
body: b.body,
headers: if b.headers.is_empty() {
None
} else {
Some(b.headers.into_iter().map(|h| (h.name, h.value)).collect())
},
},
Some(grpc_v2::agent_response::Decision::Redirect(r)) => Decision::Redirect {
url: r.url,
status: r.status as u16,
},
Some(grpc_v2::agent_response::Decision::Challenge(c)) => Decision::Challenge {
challenge_type: c.challenge_type,
params: c.params,
},
None => Decision::Allow,
};
let request_headers: Vec<HeaderOp> = resp
.request_headers
.into_iter()
.filter_map(convert_header_op_from_grpc)
.collect();
let response_headers: Vec<HeaderOp> = resp
.response_headers
.into_iter()
.filter_map(convert_header_op_from_grpc)
.collect();
let audit = resp
.audit
.map(|a| crate::AuditMetadata {
tags: a.tags,
rule_ids: a.rule_ids,
confidence: a.confidence,
reason_codes: a.reason_codes,
custom: a
.custom
.into_iter()
.map(|(k, v)| (k, serde_json::Value::String(v)))
.collect(),
})
.unwrap_or_default();
AgentResponse {
version: PROTOCOL_VERSION_2,
decision,
request_headers,
response_headers,
routing_metadata: HashMap::new(),
audit,
needs_more: resp.needs_more,
request_body_mutation: None,
response_body_mutation: None,
websocket_decision: None,
}
}
fn convert_header_op_from_grpc(op: grpc_v2::HeaderOp) -> Option<HeaderOp> {
match op.operation {
Some(grpc_v2::header_op::Operation::Set(h)) => Some(HeaderOp::Set {
name: h.name,
value: h.value,
}),
Some(grpc_v2::header_op::Operation::Add(h)) => Some(HeaderOp::Add {
name: h.name,
value: h.value,
}),
Some(grpc_v2::header_op::Operation::Remove(name)) => Some(HeaderOp::Remove { name }),
None => None,
}
}
/// Converts a metrics report pushed by the agent into the crate's representation and
/// tags it with `agent_id`. Empty `help` strings are normalised to `None`.
fn convert_metrics_from_grpc(
    report: grpc_v2::MetricsReport,
    agent_id: &str,
) -> crate::v2::MetricsReport {
    use crate::v2::metrics::{CounterMetric, GaugeMetric, HistogramBucket, HistogramMetric};

    let grpc_v2::MetricsReport {
        counters,
        gauges,
        histograms,
        timestamp_ms,
        interval_ms,
        ..
    } = report;

    crate::v2::MetricsReport {
        agent_id: agent_id.to_string(),
        timestamp_ms,
        interval_ms,
        counters: counters
            .into_iter()
            .map(|counter| CounterMetric {
                name: counter.name,
                help: counter.help.filter(|text| !text.is_empty()),
                labels: counter.labels,
                value: counter.value,
            })
            .collect(),
        gauges: gauges
            .into_iter()
            .map(|gauge| GaugeMetric {
                name: gauge.name,
                help: gauge.help.filter(|text| !text.is_empty()),
                labels: gauge.labels,
                value: gauge.value,
            })
            .collect(),
        histograms: histograms
            .into_iter()
            .map(|histogram| HistogramMetric {
                name: histogram.name,
                help: histogram.help.filter(|text| !text.is_empty()),
                labels: histogram.labels,
                sum: histogram.sum,
                count: histogram.count,
                buckets: histogram
                    .buckets
                    .into_iter()
                    .map(|bucket| HistogramBucket {
                        le: bucket.le,
                        count: bucket.count,
                    })
                    .collect(),
            })
            .collect(),
    }
}
fn convert_config_update_from_grpc(
update: grpc_v2::ConfigUpdateRequest,
) -> crate::v2::ConfigUpdateRequest {
use crate::v2::control::{ConfigUpdateType, RuleDefinition};
let update_type = match update.update_type {
Some(grpc_v2::config_update_request::UpdateType::RequestReload(_)) => {
ConfigUpdateType::RequestReload
}
Some(grpc_v2::config_update_request::UpdateType::RuleUpdate(ru)) => {
ConfigUpdateType::RuleUpdate {
rule_set: ru.rule_set,
rules: ru
.rules
.into_iter()
.map(|r| RuleDefinition {
id: r.id,
priority: r.priority,
definition: serde_json::from_str(&r.definition_json).unwrap_or_default(),
enabled: r.enabled,
description: r.description,
tags: r.tags,
})
.collect(),
remove_rules: ru.remove_rules,
}
}
Some(grpc_v2::config_update_request::UpdateType::ListUpdate(lu)) => {
ConfigUpdateType::ListUpdate {
list_id: lu.list_id,
add: lu.add,
remove: lu.remove,
}
}
Some(grpc_v2::config_update_request::UpdateType::RestartRequired(rr)) => {
ConfigUpdateType::RestartRequired {
reason: rr.reason,
grace_period_ms: rr.grace_period_ms,
}
}
Some(grpc_v2::config_update_request::UpdateType::ConfigError(ce)) => {
ConfigUpdateType::ConfigError {
error: ce.error,
field: ce.field,
}
}
None => ConfigUpdateType::RequestReload, };
crate::v2::ConfigUpdateRequest {
update_type,
request_id: update.request_id,
timestamp_ms: update.timestamp_ms,
}
}
/// Finds the correlation ID carried by a serializable event.
///
/// The lookup order is:
/// 1. a string at `metadata.correlation_id`;
/// 2. a string at top-level `correlation_id`;
/// 3. otherwise, a freshly generated random UUID.
///
/// The event is serialized to a JSON value to inspect it.
fn extract_correlation_id<T: serde::Serialize>(event: &T) -> String {
    let found = serde_json::to_value(event).ok().and_then(|value| {
        value
            .get("metadata")
            .and_then(|metadata| metadata.get("correlation_id"))
            .and_then(|cid| cid.as_str())
            .or_else(|| value.get("correlation_id").and_then(|cid| cid.as_str()))
            .map(str::to_owned)
    });
    found.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
}
/// Milliseconds since the Unix epoch by the system clock, or 0 if the clock is before it.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_type_conversion() {
        assert_eq!(i32_to_event_type(1), Some(EventType::RequestHeaders));
        assert_eq!(i32_to_event_type(2), Some(EventType::RequestBodyChunk));
        assert_eq!(i32_to_event_type(3), Some(EventType::ResponseHeaders));
        assert_eq!(i32_to_event_type(4), Some(EventType::ResponseBodyChunk));
        assert_eq!(i32_to_event_type(5), Some(EventType::RequestComplete));
        assert_eq!(i32_to_event_type(6), Some(EventType::WebSocketFrame));
        assert_eq!(i32_to_event_type(7), Some(EventType::GuardrailInspect));
        assert_eq!(i32_to_event_type(8), Some(EventType::Configure));
        assert_eq!(i32_to_event_type(0), None);
        assert_eq!(i32_to_event_type(-1), None);
        assert_eq!(i32_to_event_type(99), None);
    }

    #[test]
    fn test_extract_correlation_id() {
        #[derive(serde::Serialize)]
        struct TestEvent {
            correlation_id: String,
        }
        let event = TestEvent {
            correlation_id: "test-123".to_string(),
        };
        assert_eq!(extract_correlation_id(&event), "test-123");
    }

    #[test]
    fn test_extract_correlation_id_prefers_metadata() {
        #[derive(serde::Serialize)]
        struct Meta {
            correlation_id: String,
        }
        #[derive(serde::Serialize)]
        struct Nested {
            metadata: Meta,
            correlation_id: String,
        }
        let event = Nested {
            metadata: Meta {
                correlation_id: "from-metadata".to_string(),
            },
            correlation_id: "top-level".to_string(),
        };
        assert_eq!(extract_correlation_id(&event), "from-metadata");
    }

    #[test]
    fn test_extract_correlation_id_falls_back_to_top_level() {
        #[derive(serde::Serialize)]
        struct EmptyMeta {
            other: u32,
        }
        #[derive(serde::Serialize)]
        struct Event {
            metadata: EmptyMeta,
            correlation_id: String,
        }
        let event = Event {
            metadata: EmptyMeta { other: 1 },
            correlation_id: "top-level".to_string(),
        };
        assert_eq!(extract_correlation_id(&event), "top-level");
    }

    #[test]
    fn test_extract_correlation_id_generates_uuid_when_missing() {
        #[derive(serde::Serialize)]
        struct NoId {
            value: u32,
        }
        let id = extract_correlation_id(&NoId { value: 7 });
        assert!(uuid::Uuid::parse_str(&id).is_ok());
    }

    #[test]
    fn test_cancel_reason_wire_values() {
        assert_eq!(CancelReason::ClientDisconnect.to_grpc(), 1);
        assert_eq!(CancelReason::Timeout.to_grpc(), 2);
        assert_eq!(CancelReason::BlockedByAgent.to_grpc(), 3);
        assert_eq!(CancelReason::UpstreamError.to_grpc(), 4);
        assert_eq!(CancelReason::ProxyShutdown.to_grpc(), 5);
        assert_eq!(CancelReason::Manual.to_grpc(), 6);
    }

    #[test]
    fn test_flow_state_default() {
        assert_eq!(FlowState::default(), FlowState::Normal);
    }
}