use base64::{engine::general_purpose::STANDARD, Engine as _};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, trace, warn};
use zentinel_agent_protocol::{WebSocketDecision, WebSocketFrameEvent};
use zentinel_common::observability::RequestMetrics;
use super::codec::WebSocketFrame;
use crate::agents::AgentManager;
/// Outcome of inspecting a single WebSocket frame.
///
/// The variants mirror `WebSocketDecision` one-to-one (see the `From` impl).
/// `PartialEq`/`Eq` are derived so decisions can be compared directly with
/// `==` and `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InspectionResult {
    /// The frame is allowed. This is also the result produced when the agent
    /// returns no decision, fails, or times out.
    Allow,
    /// Mirrors `WebSocketDecision::Drop`. Presumably the frame is discarded
    /// without being forwarded; enforcement is up to the caller.
    Drop,
    /// Mirrors `WebSocketDecision::Close`, carrying the close `code` and the
    /// human-readable `reason` supplied with the decision.
    Close { code: u16, reason: String },
}
impl From<WebSocketDecision> for InspectionResult {
fn from(decision: WebSocketDecision) -> Self {
match decision {
WebSocketDecision::Allow => InspectionResult::Allow,
WebSocketDecision::Drop => InspectionResult::Drop,
WebSocketDecision::Close { code, reason } => InspectionResult::Close { code, reason },
}
}
}
/// Inspects WebSocket frames by forwarding them to agents through an
/// `AgentManager` and translating the reply into an `InspectionResult`.
///
/// Construction records one WebSocket connection metric, so presumably one
/// instance exists per connection (confirm against the caller).
pub struct WebSocketInspector {
/// Agent manager whose `process_websocket_frame` is called for every frame.
agent_manager: Arc<AgentManager>,
/// Route identifier; sent with each frame event and used as a metrics label.
route_id: String,
/// Correlation identifier attached to frame events and log records.
correlation_id: String,
/// Client IP string copied into every frame event.
client_ip: String,
/// Count of client frames inspected so far; the pre-increment value is the
/// frame index sent to the agent (starts at 0).
client_frame_index: AtomicU64,
/// Count of server frames inspected so far; indexed like the client counter.
server_frame_index: AtomicU64,
/// Maximum time, in milliseconds, to wait for the agent on each frame.
timeout_ms: u64,
/// Optional metrics sink; when `None`, no metrics are recorded.
metrics: Option<Arc<RequestMetrics>>,
}
impl WebSocketInspector {
pub fn new(
agent_manager: Arc<AgentManager>,
route_id: String,
correlation_id: String,
client_ip: String,
timeout_ms: u64,
) -> Self {
Self::with_metrics(
agent_manager,
route_id,
correlation_id,
client_ip,
timeout_ms,
None,
)
}
pub fn with_metrics(
agent_manager: Arc<AgentManager>,
route_id: String,
correlation_id: String,
client_ip: String,
timeout_ms: u64,
metrics: Option<Arc<RequestMetrics>>,
) -> Self {
debug!(
route_id = %route_id,
correlation_id = %correlation_id,
"Creating WebSocket inspector"
);
if let Some(ref m) = metrics {
m.record_websocket_connection(&route_id);
}
Self {
agent_manager,
route_id,
correlation_id,
client_ip,
client_frame_index: AtomicU64::new(0),
server_frame_index: AtomicU64::new(0),
timeout_ms,
metrics,
}
}
pub async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
let frame_index = self.client_frame_index.fetch_add(1, Ordering::SeqCst);
trace!(
correlation_id = %self.correlation_id,
frame_index = frame_index,
opcode = ?frame.opcode,
"Inspecting client frame"
);
self.inspect_frame(frame, true, frame_index).await
}
pub async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
let frame_index = self.server_frame_index.fetch_add(1, Ordering::SeqCst);
trace!(
correlation_id = %self.correlation_id,
frame_index = frame_index,
opcode = ?frame.opcode,
"Inspecting server frame"
);
self.inspect_frame(frame, false, frame_index).await
}
async fn inspect_frame(
&self,
frame: &WebSocketFrame,
client_to_server: bool,
frame_index: u64,
) -> InspectionResult {
let start = Instant::now();
let direction = if client_to_server { "c2s" } else { "s2c" };
let opcode = frame.opcode.as_str();
if let Some(ref metrics) = self.metrics {
metrics.record_websocket_frame_size(
&self.route_id,
direction,
opcode,
frame.payload.len(),
);
}
let event = WebSocketFrameEvent {
correlation_id: self.correlation_id.clone(),
opcode: opcode.to_string(),
data: STANDARD.encode(&frame.payload),
client_to_server,
frame_index,
fin: frame.fin,
route_id: Some(self.route_id.clone()),
client_ip: self.client_ip.clone(),
};
let result = match tokio::time::timeout(
std::time::Duration::from_millis(self.timeout_ms),
self.agent_manager
.process_websocket_frame(&self.route_id, event),
)
.await
{
Ok(Ok(response)) => {
if let Some(ws_decision) = response.websocket_decision {
let result = InspectionResult::from(ws_decision);
trace!(
correlation_id = %self.correlation_id,
frame_index = frame_index,
decision = ?result,
"Frame inspection complete"
);
result
} else {
InspectionResult::Allow
}
}
Ok(Err(e)) => {
warn!(
correlation_id = %self.correlation_id,
error = %e,
"Agent error during frame inspection, allowing frame"
);
InspectionResult::Allow
}
Err(_) => {
warn!(
correlation_id = %self.correlation_id,
timeout_ms = self.timeout_ms,
"Agent timeout during frame inspection, allowing frame"
);
InspectionResult::Allow
}
};
if let Some(ref metrics) = self.metrics {
let duration = start.elapsed();
metrics.record_websocket_inspection_duration(&self.route_id, duration);
let decision_str = match &result {
InspectionResult::Allow => "allow",
InspectionResult::Drop => "drop",
InspectionResult::Close { .. } => "close",
};
metrics.record_websocket_frame(&self.route_id, direction, opcode, decision_str);
}
result
}
pub fn correlation_id(&self) -> &str {
&self.correlation_id
}
pub fn route_id(&self) -> &str {
&self.route_id
}
}
/// Builder for `WebSocketInspector`.
///
/// `agent_manager`, `route_id`, `correlation_id` and `client_ip` must all be
/// set before `build` is called; otherwise `build` returns `None`.
pub struct WebSocketInspectorBuilder {
/// Required: agent manager handed to the inspector.
agent_manager: Option<Arc<AgentManager>>,
/// Required: route identifier for the inspector.
route_id: Option<String>,
/// Required: correlation identifier for the inspector.
correlation_id: Option<String>,
/// Required: client IP string for the inspector.
client_ip: Option<String>,
/// Agent timeout in milliseconds; the `Default` impl sets this to 100.
timeout_ms: u64,
/// Optional metrics sink passed through to the inspector.
metrics: Option<Arc<RequestMetrics>>,
}
impl Default for WebSocketInspectorBuilder {
    /// Returns a builder with every required field unset, no metrics sink,
    /// and a 100 ms agent timeout.
    fn default() -> Self {
        // Timeout (milliseconds) used when the caller never overrides it.
        const DEFAULT_TIMEOUT_MS: u64 = 100;

        Self {
            timeout_ms: DEFAULT_TIMEOUT_MS,
            metrics: None,
            agent_manager: None,
            route_id: None,
            correlation_id: None,
            client_ip: None,
        }
    }
}
impl WebSocketInspectorBuilder {
    /// Creates a builder holding the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the agent manager (required).
    pub fn agent_manager(self, manager: Arc<AgentManager>) -> Self {
        Self {
            agent_manager: Some(manager),
            ..self
        }
    }

    /// Sets the route ID (required).
    pub fn route_id(self, id: impl Into<String>) -> Self {
        Self {
            route_id: Some(id.into()),
            ..self
        }
    }

    /// Sets the correlation ID (required).
    pub fn correlation_id(self, id: impl Into<String>) -> Self {
        Self {
            correlation_id: Some(id.into()),
            ..self
        }
    }

    /// Sets the client IP (required).
    pub fn client_ip(self, ip: impl Into<String>) -> Self {
        Self {
            client_ip: Some(ip.into()),
            ..self
        }
    }

    /// Overrides the agent timeout, in milliseconds.
    pub fn timeout_ms(self, ms: u64) -> Self {
        Self {
            timeout_ms: ms,
            ..self
        }
    }

    /// Attaches a metrics sink (optional).
    pub fn metrics(self, metrics: Arc<RequestMetrics>) -> Self {
        Self {
            metrics: Some(metrics),
            ..self
        }
    }

    /// Builds the inspector.
    ///
    /// Returns `None` if the agent manager, route ID, correlation ID or
    /// client IP was never set. Otherwise constructs the inspector through
    /// `WebSocketInspector::with_metrics`.
    pub fn build(self) -> Option<WebSocketInspector> {
        let Self {
            agent_manager,
            route_id,
            correlation_id,
            client_ip,
            timeout_ms,
            metrics,
        } = self;

        let inspector = WebSocketInspector::with_metrics(
            agent_manager?,
            route_id?,
            correlation_id?,
            client_ip?,
            timeout_ms,
            metrics,
        );
        Some(inspector)
    }
}