use bytes::{Bytes, BytesMut};
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::{debug, trace, warn};
use super::codec::{WebSocketCodec, WebSocketFrame};
use super::inspector::{InspectionResult, WebSocketInspector};
/// Per-frame policy hook consulted by [`WebSocketHandler`].
///
/// Every decoded frame is handed to exactly one of the `inspect_*` methods,
/// chosen by the direction it is travelling in; the returned
/// [`InspectionResult`] decides whether it is forwarded, dropped, or causes
/// the connection to be closed.
#[async_trait::async_trait]
pub trait FrameInspector: Send + Sync {
    /// Identifier emitted as the `correlation_id` field on this module's log events.
    fn correlation_id(&self) -> &str;

    /// Judges a frame travelling from the client towards the server.
    async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult;

    /// Judges a frame travelling from the server towards the client.
    async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult;
}
/// Adapts the concrete [`WebSocketInspector`] to the [`FrameInspector`] trait.
///
/// Each method forwards to the `WebSocketInspector` method of the same name.
/// The calls are written as type-qualified paths; path resolution prefers
/// inherent associated items over trait items, so these presumably reach the
/// inherent methods rather than recursing into this impl.
#[async_trait::async_trait]
impl FrameInspector for WebSocketInspector {
    fn correlation_id(&self) -> &str {
        WebSocketInspector::correlation_id(self)
    }

    async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
        WebSocketInspector::inspect_client_frame(self, frame).await
    }

    async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
        WebSocketInspector::inspect_server_frame(self, frame).await
    }
}
/// Inspects WebSocket traffic frame by frame in both directions.
///
/// Raw bytes are buffered per direction until complete frames can be decoded;
/// each decoded frame is then judged by the inspector `I`.
pub struct WebSocketHandler<I = WebSocketInspector>
where
    I: FrameInspector,
{
    /// Frame decoder/encoder.
    codec: WebSocketCodec,
    /// Inspector consulted for every decoded frame.
    inspector: Arc<I>,
    /// Client-to-server bytes not yet consumed as complete frames.
    client_buffer: Mutex<BytesMut>,
    /// Server-to-client bytes not yet consumed as complete frames.
    server_buffer: Mutex<BytesMut>,
    /// Close reason requested by the inspector, if any.
    should_close: Mutex<Option<CloseReason>>,
}
/// Why a WebSocket connection should be closed, as requested by an inspector.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare reasons directly
/// instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseReason {
    /// Close status code (for example `1000` or `1008`).
    pub code: u16,
    /// Human-readable text accompanying the close code.
    pub reason: String,
}
/// Outcome of feeding one chunk of bytes (or end-of-stream) to a [`WebSocketHandler`].
#[derive(Debug)]
pub enum ProcessResult {
    /// Bytes to pass on towards the peer.
    ///
    /// `None` signals end of stream. `Some` may hold an empty buffer, for
    /// example when every frame was dropped or only a partial frame is buffered.
    Forward(Option<Bytes>),
    /// The connection should be closed for the given reason.
    Close(CloseReason),
}
impl<I> WebSocketHandler<I>
where
    I: FrameInspector,
{
    /// Creates a handler around any [`FrameInspector`] implementation.
    ///
    /// `max_frame_size` is passed to [`WebSocketCodec::new`]. Both direction
    /// buffers start with 4096 bytes of capacity, and no close is pending.
    pub fn with_inspector(inspector: Arc<I>, max_frame_size: usize) -> Self {
        const INITIAL_BUFFER_CAPACITY: usize = 4096;

        debug!(
            correlation_id = %inspector.correlation_id(),
            max_frame_size,
            "Creating WebSocket handler"
        );

        let codec = WebSocketCodec::new(max_frame_size);
        let client_buffer = Mutex::new(BytesMut::with_capacity(INITIAL_BUFFER_CAPACITY));
        let server_buffer = Mutex::new(BytesMut::with_capacity(INITIAL_BUFFER_CAPACITY));

        Self {
            codec,
            inspector,
            client_buffer,
            server_buffer,
            should_close: Mutex::new(None),
        }
    }
}
impl WebSocketHandler<WebSocketInspector> {
pub fn new(inspector: Arc<WebSocketInspector>, max_frame_size: usize) -> Self {
Self::with_inspector(inspector, max_frame_size)
}
}
impl<I: FrameInspector> WebSocketHandler<I> {
    /// Feeds bytes received from the client (client-to-server direction).
    ///
    /// * If a close has already been recorded (by either direction), returns
    ///   `ProcessResult::Close` without looking at `data`.
    /// * `None` (end of stream) is passed through as `ProcessResult::Forward(None)`.
    /// * Otherwise the bytes are appended to the client buffer, and every
    ///   complete frame is decoded and sent to `FrameInspector::inspect_client_frame`.
    pub async fn process_client_data(&self, data: Option<Bytes>) -> ProcessResult {
        self.process_direction(data, true).await
    }

    /// Feeds bytes received from the server (server-to-client direction).
    ///
    /// Same contract as [`Self::process_client_data`], but uses the server
    /// buffer and `FrameInspector::inspect_server_frame`.
    pub async fn process_server_data(&self, data: Option<Bytes>) -> ProcessResult {
        self.process_direction(data, false).await
    }

    /// Shared front half of the two public entry points.
    ///
    /// It honours a pending close, passes end-of-stream through, and otherwise
    /// decodes and inspects the new bytes.
    async fn process_direction(
        &self,
        data: Option<Bytes>,
        client_to_server: bool,
    ) -> ProcessResult {
        // `should_close()` clones the value and releases its lock before the
        // direction buffer is locked below.
        if let Some(reason) = self.should_close().await {
            return ProcessResult::Close(reason);
        }
        match data {
            Some(data) => self.process_data(data, client_to_server).await,
            None => ProcessResult::Forward(None),
        }
    }

    /// Appends `data` to the buffer for one direction, then decodes and
    /// inspects as many complete frames as that buffer now holds.
    ///
    /// Returns `Forward(Some(bytes))` with the original encoded bytes of every
    /// allowed frame. That buffer may be empty, for example when all frames
    /// were dropped or only a partial frame is buffered. Returns `Close` as
    /// soon as the inspector asks for it. Bytes of an incomplete trailing
    /// frame stay buffered for the next call.
    async fn process_data(&self, data: Bytes, client_to_server: bool) -> ProcessResult {
        let direction = if client_to_server { "c2s" } else { "s2c" };
        let buffer = if client_to_server {
            &self.client_buffer
        } else {
            &self.server_buffer
        };
        // Held for the whole call, so chunks for one direction are decoded in order.
        let mut buf = buffer.lock().await;
        buf.extend_from_slice(&data);

        let mut output = BytesMut::new();
        let mut frames_processed = 0;
        let mut frames_dropped = 0;

        loop {
            let (frame, consumed) = match self.codec.decode_frame(&buf) {
                Ok(Some(decoded)) => decoded,
                // Not enough bytes for another complete frame; wait for more.
                Ok(None) => break,
                Err(e) => {
                    // NOTE(review): this fails open. Everything still buffered,
                    // including the bytes the codec rejected, is forwarded
                    // without inspection, and the buffer is reset. If the error
                    // happened mid-frame, later chunks are decoded from an
                    // arbitrary offset. Confirm this is the intended policy.
                    warn!(
                        correlation_id = %self.inspector.correlation_id(),
                        error = %e,
                        "WebSocket frame decode error"
                    );
                    output.extend_from_slice(&buf);
                    buf.clear();
                    break;
                }
            };
            frames_processed += 1;

            let result = if client_to_server {
                self.inspector.inspect_client_frame(&frame).await
            } else {
                self.inspector.inspect_server_frame(&frame).await
            };

            match result {
                InspectionResult::Allow => {
                    // Forward the frame's original wire bytes, not a re-encoding.
                    output.extend_from_slice(&buf[..consumed]);
                }
                InspectionResult::Drop => {
                    frames_dropped += 1;
                    trace!(
                        correlation_id = %self.inspector.correlation_id(),
                        opcode = ?frame.opcode,
                        direction,
                        "Dropping WebSocket frame"
                    );
                }
                InspectionResult::Close { code, reason } => {
                    debug!(
                        correlation_id = %self.inspector.correlation_id(),
                        code,
                        reason = %reason,
                        "Agent requested WebSocket close"
                    );
                    // Recorded before returning, so every later call in either
                    // direction short-circuits to `Close`.
                    let close = CloseReason { code, reason };
                    *self.should_close.lock().await = Some(close.clone());
                    let _ = buf.split_to(consumed);
                    // NOTE(review): `ProcessResult::Close` carries no payload.
                    // Frames already allowed earlier in this chunk are therefore
                    // not forwarded, and no close frame is emitted from here.
                    // Presumably the caller closes both sides; confirm.
                    return ProcessResult::Close(close);
                }
            }
            let _ = buf.split_to(consumed);
        }

        if frames_processed > 0 {
            trace!(
                correlation_id = %self.inspector.correlation_id(),
                frames_processed,
                frames_dropped,
                output_len = output.len(),
                buffer_remaining = buf.len(),
                direction,
                "Processed WebSocket frames"
            );
        }

        // Always `Some`, even when empty: `None` is reserved for end of stream.
        ProcessResult::Forward(Some(output.freeze()))
    }

    /// Returns the close reason recorded by the inspector, if any.
    ///
    /// This handler only ever sets the value, never clears it. So once it is
    /// `Some`, every later `process_*_data` call returns `ProcessResult::Close`.
    pub async fn should_close(&self) -> Option<CloseReason> {
        self.should_close.lock().await.clone()
    }

    /// Correlation id reported by the underlying inspector.
    pub fn correlation_id(&self) -> &str {
        self.inspector.correlation_id()
    }
}
/// Fluent builder for a [`WebSocketHandler`] backed by [`WebSocketInspector`].
pub struct WebSocketHandlerBuilder {
    /// Inspector to install. `build` yields `None` while this is unset.
    inspector: Option<Arc<WebSocketInspector>>,
    /// Maximum frame size passed to `WebSocketHandler::new`.
    max_frame_size: usize,
}
impl Default for WebSocketHandlerBuilder {
    /// No inspector, and a maximum frame size of 1 MiB (`1024 * 1024` bytes).
    fn default() -> Self {
        const DEFAULT_MAX_FRAME_SIZE: usize = 1 << 20;

        Self {
            max_frame_size: DEFAULT_MAX_FRAME_SIZE,
            inspector: None,
        }
    }
}
impl WebSocketHandlerBuilder {
    /// Creates a builder with the same settings as `WebSocketHandlerBuilder::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the inspector consulted for every frame.
    pub fn inspector(self, inspector: Arc<WebSocketInspector>) -> Self {
        Self {
            inspector: Some(inspector),
            ..self
        }
    }

    /// Overrides the maximum frame size handed to the handler.
    pub fn max_frame_size(self, size: usize) -> Self {
        Self {
            max_frame_size: size,
            ..self
        }
    }

    /// Builds the handler, or returns `None` if no inspector was supplied.
    pub fn build(self) -> Option<WebSocketHandler> {
        let Self {
            inspector,
            max_frame_size,
        } = self;
        inspector.map(|inspector| WebSocketHandler::new(inspector, max_frame_size))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::websocket::codec::Opcode;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Inspector stub that returns one fixed decision per direction and counts
    /// how many frames it was asked to judge.
    struct MockInspector {
        client_decision: InspectionResult,
        server_decision: InspectionResult,
        client_frame_count: AtomicUsize,
        server_frame_count: AtomicUsize,
    }

    impl MockInspector {
        fn new(client_decision: InspectionResult, server_decision: InspectionResult) -> Self {
            Self {
                client_decision,
                server_decision,
                client_frame_count: AtomicUsize::new(0),
                server_frame_count: AtomicUsize::new(0),
            }
        }

        fn allowing() -> Self {
            Self::new(InspectionResult::Allow, InspectionResult::Allow)
        }

        fn dropping_client() -> Self {
            Self::new(InspectionResult::Drop, InspectionResult::Allow)
        }

        fn dropping_server() -> Self {
            Self::new(InspectionResult::Allow, InspectionResult::Drop)
        }

        fn closing_client(code: u16, reason: &str) -> Self {
            Self::new(
                InspectionResult::Close {
                    code,
                    reason: reason.to_string(),
                },
                InspectionResult::Allow,
            )
        }

        fn client_frames_inspected(&self) -> usize {
            self.client_frame_count.load(Ordering::SeqCst)
        }

        fn server_frames_inspected(&self) -> usize {
            self.server_frame_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl FrameInspector for MockInspector {
        async fn inspect_client_frame(&self, _frame: &WebSocketFrame) -> InspectionResult {
            self.client_frame_count.fetch_add(1, Ordering::SeqCst);
            self.client_decision.clone()
        }

        async fn inspect_server_frame(&self, _frame: &WebSocketFrame) -> InspectionResult {
            self.server_frame_count.fetch_add(1, Ordering::SeqCst);
            self.server_decision.clone()
        }

        fn correlation_id(&self) -> &str {
            "test-correlation-id"
        }
    }

    /// Encodes one text frame carrying `text`.
    fn make_text_frame(text: &str, masked: bool) -> Bytes {
        let codec = WebSocketCodec::new(1024 * 1024);
        let frame = WebSocketFrame::new(Opcode::Text, text.as_bytes().to_vec());
        Bytes::from(codec.encode_frame(&frame, masked).unwrap())
    }

    #[test]
    fn test_close_reason() {
        let reason = CloseReason {
            code: 1000,
            reason: "Normal closure".to_string(),
        };
        assert_eq!(reason.code, 1000);
        assert_eq!(reason.reason, "Normal closure");
    }

    #[test]
    fn test_builder_defaults() {
        let builder = WebSocketHandlerBuilder::new();
        assert_eq!(builder.max_frame_size, 1024 * 1024);
    }

    #[tokio::test]
    async fn test_frame_allow() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let frame_data = make_text_frame("Hello", false);
        let result = handler.process_client_data(Some(frame_data.clone())).await;
        match result {
            ProcessResult::Forward(Some(data)) => {
                assert_eq!(data, frame_data);
            }
            _ => panic!("Expected Forward result"),
        }
        assert_eq!(inspector.client_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_drop_client() {
        let inspector = Arc::new(MockInspector::dropping_client());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let frame_data = make_text_frame("Hello", false);
        let result = handler.process_client_data(Some(frame_data)).await;
        match result {
            ProcessResult::Forward(Some(data)) => {
                assert!(data.is_empty(), "Dropped frame should produce empty output");
            }
            _ => panic!("Expected Forward with empty data"),
        }
        assert_eq!(inspector.client_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_drop_server() {
        let inspector = Arc::new(MockInspector::dropping_server());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let frame_data = make_text_frame("Server message", false);
        let result = handler.process_server_data(Some(frame_data)).await;
        match result {
            ProcessResult::Forward(Some(data)) => {
                assert!(data.is_empty(), "Dropped frame should produce empty output");
            }
            _ => panic!("Expected Forward with empty data"),
        }
        assert_eq!(inspector.server_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_close() {
        let inspector = Arc::new(MockInspector::closing_client(1008, "Policy violation"));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let frame_data = make_text_frame("Malicious content", false);
        let result = handler.process_client_data(Some(frame_data)).await;
        match result {
            ProcessResult::Close(reason) => {
                assert_eq!(reason.code, 1008);
                assert_eq!(reason.reason, "Policy violation");
            }
            _ => panic!("Expected Close result"),
        }
        assert_eq!(inspector.client_frames_inspected(), 1);
        let result = handler
            .process_client_data(Some(make_text_frame("More data", false)))
            .await;
        match result {
            ProcessResult::Close(_) => {}
            _ => panic!("Expected Close result on subsequent call"),
        }
    }

    // The mock returns one fixed decision per direction, so this exercises
    // consecutive allowed chunks. Mixed decisions across directions are
    // covered by `test_bidirectional_independence`.
    #[tokio::test]
    async fn test_multiple_frames_mixed_decisions() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let frame1 = make_text_frame("Frame 1", false);
        let result = handler.process_client_data(Some(frame1)).await;
        assert!(matches!(result, ProcessResult::Forward(Some(_))));
        let frame2 = make_text_frame("Frame 2", false);
        let result = handler.process_client_data(Some(frame2)).await;
        assert!(matches!(result, ProcessResult::Forward(Some(_))));
        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_multiple_frames_in_single_chunk() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let mut chunk = BytesMut::new();
        chunk.extend_from_slice(&make_text_frame("Frame 1", false));
        chunk.extend_from_slice(&make_text_frame("Frame 2", false));
        let chunk = chunk.freeze();
        match handler.process_client_data(Some(chunk.clone())).await {
            ProcessResult::Forward(Some(data)) => {
                assert_eq!(data, chunk, "Both allowed frames should be forwarded in order");
            }
            other => panic!("Expected Forward with both frames, got {other:?}"),
        }
        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_close_is_sticky_across_directions() {
        let inspector = Arc::new(MockInspector::closing_client(1008, "Policy violation"));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        assert!(handler.should_close().await.is_none());

        let result = handler
            .process_client_data(Some(make_text_frame("Bad", false)))
            .await;
        assert!(matches!(result, ProcessResult::Close(_)));

        let pending = handler
            .should_close()
            .await
            .expect("close reason should be recorded");
        assert_eq!(pending.code, 1008);
        assert_eq!(pending.reason, "Policy violation");

        match handler
            .process_server_data(Some(make_text_frame("Reply", false)))
            .await
        {
            ProcessResult::Close(reason) => assert_eq!(reason.code, 1008),
            other => panic!("Expected Close for server data after close, got {other:?}"),
        }
        assert_eq!(
            inspector.server_frames_inspected(),
            0,
            "No frames should be inspected once a close is pending"
        );
    }

    #[tokio::test]
    async fn test_end_of_stream() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector, 1024 * 1024);
        let result = handler.process_client_data(None).await;
        match result {
            ProcessResult::Forward(None) => {}
            _ => panic!("Expected Forward(None) for end of stream"),
        }
    }

    #[tokio::test]
    async fn test_partial_frame_buffering() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let full_frame = make_text_frame("Hello World", false);
        let (part1, part2) = full_frame.split_at(full_frame.len() / 2);
        let result = handler
            .process_client_data(Some(Bytes::from(part1.to_vec())))
            .await;
        match result {
            ProcessResult::Forward(Some(data)) => {
                assert!(data.is_empty(), "Partial frame should not produce output");
            }
            _ => panic!("Expected Forward with empty data for partial frame"),
        }
        assert_eq!(
            inspector.client_frames_inspected(),
            0,
            "Partial frame should not be inspected"
        );
        let result = handler
            .process_client_data(Some(Bytes::from(part2.to_vec())))
            .await;
        match result {
            ProcessResult::Forward(Some(data)) => {
                assert_eq!(data, full_frame, "Complete frame should be forwarded");
            }
            _ => panic!("Expected Forward with complete frame"),
        }
        assert_eq!(
            inspector.client_frames_inspected(),
            1,
            "Complete frame should be inspected"
        );
    }

    #[tokio::test]
    async fn test_bidirectional_independence() {
        let inspector = Arc::new(MockInspector::new(
            InspectionResult::Drop,
            InspectionResult::Allow,
        ));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), 1024 * 1024);
        let client_frame = make_text_frame("Client", false);
        let result = handler.process_client_data(Some(client_frame)).await;
        match result {
            ProcessResult::Forward(Some(data)) => assert!(data.is_empty()),
            _ => panic!("Expected empty forward for dropped client frame"),
        }
        let server_frame = make_text_frame("Server", false);
        let original_len = server_frame.len();
        let result = handler.process_server_data(Some(server_frame)).await;
        match result {
            ProcessResult::Forward(Some(data)) => assert_eq!(data.len(), original_len),
            _ => panic!("Expected forward for allowed server frame"),
        }
        assert_eq!(inspector.client_frames_inspected(), 1);
        assert_eq!(inspector.server_frames_inspected(), 1);
    }
}