sentinel_proxy/websocket/
proxy.rs

1//! WebSocket proxy handler for frame-level inspection.
2//!
3//! This module provides WebSocket frame inspection by intercepting data
4//! in Pingora's body filters after a 101 upgrade. Frames are parsed from
5//! the byte stream, sent to agents for inspection, and filtered based on
6//! agent decisions.
7//!
8//! # Architecture
9//!
10//! After a 101 upgrade, Pingora treats the bidirectional data as "body" bytes.
11//! We intercept these in `request_body_filter` (client→server) and
12//! `response_body_filter` (server→client), parse WebSocket frames, and
13//! apply agent decisions.
14//!
15//! ```text
16//! Client → [body_filter] → Frame Parser → Agent → Forward/Drop/Close
17//!                                ↓
18//! Server ← [body_filter] ← Frame Parser ← Agent ← Forward/Drop/Close
19//! ```
20
21use bytes::{Bytes, BytesMut};
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use tracing::{debug, trace, warn};
25
26use super::codec::{WebSocketCodec, WebSocketFrame};
27use super::inspector::{InspectionResult, WebSocketInspector};
28
29/// Trait for WebSocket frame inspection.
30///
31/// This trait abstracts the frame inspection logic, allowing for
32/// easy testing with mock implementations.
33#[async_trait::async_trait]
34pub trait FrameInspector: Send + Sync {
35    /// Inspect a frame from client to server
36    async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult;
37
38    /// Inspect a frame from server to client
39    async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult;
40
41    /// Get the correlation ID for logging
42    fn correlation_id(&self) -> &str;
43}
44
45/// Implementation of FrameInspector that delegates to WebSocketInspector
46#[async_trait::async_trait]
47impl FrameInspector for WebSocketInspector {
48    async fn inspect_client_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
49        WebSocketInspector::inspect_client_frame(self, frame).await
50    }
51
52    async fn inspect_server_frame(&self, frame: &WebSocketFrame) -> InspectionResult {
53        WebSocketInspector::inspect_server_frame(self, frame).await
54    }
55
56    fn correlation_id(&self) -> &str {
57        WebSocketInspector::correlation_id(self)
58    }
59}
60
61/// WebSocket frame handler for body filter integration.
62///
63/// This handler accumulates bytes from Pingora's body filters, parses them
64/// into WebSocket frames, and applies agent inspection decisions.
65pub struct WebSocketHandler<I: FrameInspector = WebSocketInspector> {
66    /// Frame parser/codec
67    codec: WebSocketCodec,
68    /// Frame inspector for agent integration
69    inspector: Arc<I>,
70    /// Buffer for incomplete frames (client → server)
71    client_buffer: Mutex<BytesMut>,
72    /// Buffer for incomplete frames (server → client)
73    server_buffer: Mutex<BytesMut>,
74    /// Whether the connection should be closed
75    should_close: Mutex<Option<CloseReason>>,
76}
77
/// Why, and with which status code, the WebSocket connection is being closed.
///
/// Built from an inspector's `InspectionResult::Close { code, reason }`
/// decision and remembered by the handler, so every later call on the same
/// connection reports the same reason.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare reasons directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CloseReason {
    /// Close status code, copied verbatim from the inspector's decision.
    pub code: u16,
    /// Human-readable close reason, copied verbatim from the inspector's decision.
    pub reason: String,
}
84
85/// Result of processing WebSocket data
86#[derive(Debug)]
87pub enum ProcessResult {
88    /// Forward the (possibly modified) data
89    Forward(Option<Bytes>),
90    /// Close the connection with the given code and reason
91    Close(CloseReason),
92}
93
94impl<I: FrameInspector> WebSocketHandler<I> {
95    /// Create a new WebSocket handler with a custom inspector
96    pub fn with_inspector(inspector: Arc<I>, max_frame_size: usize) -> Self {
97        debug!(
98            correlation_id = %inspector.correlation_id(),
99            max_frame_size = max_frame_size,
100            "Creating WebSocket handler"
101        );
102
103        Self {
104            codec: WebSocketCodec::new(max_frame_size),
105            inspector,
106            client_buffer: Mutex::new(BytesMut::with_capacity(4096)),
107            server_buffer: Mutex::new(BytesMut::with_capacity(4096)),
108            should_close: Mutex::new(None),
109        }
110    }
111}
112
113impl WebSocketHandler<WebSocketInspector> {
114    /// Create a new WebSocket handler with the default WebSocketInspector
115    pub fn new(inspector: Arc<WebSocketInspector>, max_frame_size: usize) -> Self {
116        Self::with_inspector(inspector, max_frame_size)
117    }
118}
119
120impl<I: FrameInspector> WebSocketHandler<I> {
121    /// Process data from client to server (request body)
122    ///
123    /// Returns the data to forward (may be modified or None if all frames were dropped)
124    pub async fn process_client_data(&self, data: Option<Bytes>) -> ProcessResult {
125        // Check if we should close
126        if let Some(reason) = self.should_close.lock().await.clone() {
127            return ProcessResult::Close(reason);
128        }
129
130        let Some(data) = data else {
131            // End of stream
132            return ProcessResult::Forward(None);
133        };
134
135        self.process_data(data, true).await
136    }
137
138    /// Process data from server to client (response body)
139    ///
140    /// Returns the data to forward (may be modified or None if all frames were dropped)
141    pub async fn process_server_data(&self, data: Option<Bytes>) -> ProcessResult {
142        // Check if we should close
143        if let Some(reason) = self.should_close.lock().await.clone() {
144            return ProcessResult::Close(reason);
145        }
146
147        let Some(data) = data else {
148            // End of stream
149            return ProcessResult::Forward(None);
150        };
151
152        self.process_data(data, false).await
153    }
154
155    /// Internal data processing
156    async fn process_data(&self, data: Bytes, client_to_server: bool) -> ProcessResult {
157        let buffer = if client_to_server {
158            &self.client_buffer
159        } else {
160            &self.server_buffer
161        };
162
163        let mut buf = buffer.lock().await;
164        buf.extend_from_slice(&data);
165
166        let mut output = BytesMut::new();
167        let mut frames_processed = 0;
168        let mut frames_dropped = 0;
169
170        // Parse and process frames from the buffer
171        loop {
172            // Try to decode a frame
173            match self.codec.decode_frame(&buf) {
174                Ok(Some((frame, consumed))) => {
175                    frames_processed += 1;
176
177                    // Inspect the frame
178                    let result = if client_to_server {
179                        self.inspector.inspect_client_frame(&frame).await
180                    } else {
181                        self.inspector.inspect_server_frame(&frame).await
182                    };
183
184                    match result {
185                        InspectionResult::Allow => {
186                            // Forward the frame - copy the raw bytes
187                            output.extend_from_slice(&buf[..consumed]);
188                        }
189                        InspectionResult::Drop => {
190                            frames_dropped += 1;
191                            trace!(
192                                correlation_id = %self.inspector.correlation_id(),
193                                opcode = ?frame.opcode,
194                                direction = if client_to_server { "c2s" } else { "s2c" },
195                                "Dropping WebSocket frame"
196                            );
197                            // Don't forward this frame
198                        }
199                        InspectionResult::Close { code, reason } => {
200                            debug!(
201                                correlation_id = %self.inspector.correlation_id(),
202                                code = code,
203                                reason = %reason,
204                                "Agent requested WebSocket close"
205                            );
206
207                            // Store close reason
208                            *self.should_close.lock().await = Some(CloseReason {
209                                code,
210                                reason: reason.clone(),
211                            });
212
213                            // Create and forward a close frame
214                            let close_frame = WebSocketFrame::close(code, &reason);
215                            if let Ok(encoded) =
216                                self.codec.encode_frame(&close_frame, !client_to_server)
217                            {
218                                output.extend_from_slice(&encoded);
219                            }
220
221                            // Remove consumed bytes and return
222                            let _ = buf.split_to(consumed);
223                            return ProcessResult::Close(CloseReason { code, reason });
224                        }
225                    }
226
227                    // Remove consumed bytes from buffer
228                    let _ = buf.split_to(consumed);
229                }
230                Ok(None) => {
231                    // Need more data - incomplete frame
232                    break;
233                }
234                Err(e) => {
235                    warn!(
236                        correlation_id = %self.inspector.correlation_id(),
237                        error = %e,
238                        "WebSocket frame decode error"
239                    );
240                    // On decode error, forward the data as-is and clear buffer
241                    // This allows the connection to continue (fail-open)
242                    output.extend_from_slice(&buf);
243                    buf.clear();
244                    break;
245                }
246            }
247        }
248
249        if frames_processed > 0 {
250            trace!(
251                correlation_id = %self.inspector.correlation_id(),
252                frames_processed = frames_processed,
253                frames_dropped = frames_dropped,
254                output_len = output.len(),
255                buffer_remaining = buf.len(),
256                direction = if client_to_server { "c2s" } else { "s2c" },
257                "Processed WebSocket frames"
258            );
259        }
260
261        if output.is_empty() && frames_dropped > 0 {
262            // All frames were dropped, return empty
263            ProcessResult::Forward(Some(Bytes::new()))
264        } else if output.is_empty() {
265            // No complete frames yet, buffer more data
266            // Return empty bytes to signal "nothing to forward yet"
267            ProcessResult::Forward(Some(Bytes::new()))
268        } else {
269            ProcessResult::Forward(Some(output.freeze()))
270        }
271    }
272
273    /// Check if the connection should be closed
274    pub async fn should_close(&self) -> Option<CloseReason> {
275        self.should_close.lock().await.clone()
276    }
277
278    /// Get the correlation ID
279    pub fn correlation_id(&self) -> &str {
280        self.inspector.correlation_id()
281    }
282}
283
284/// Builder for WebSocketHandler
285pub struct WebSocketHandlerBuilder {
286    inspector: Option<Arc<WebSocketInspector>>,
287    max_frame_size: usize,
288}
289
290impl Default for WebSocketHandlerBuilder {
291    fn default() -> Self {
292        Self {
293            inspector: None,
294            max_frame_size: 1024 * 1024, // 1MB default
295        }
296    }
297}
298
299impl WebSocketHandlerBuilder {
300    /// Create a new builder
301    pub fn new() -> Self {
302        Self::default()
303    }
304
305    /// Set the inspector
306    pub fn inspector(mut self, inspector: Arc<WebSocketInspector>) -> Self {
307        self.inspector = Some(inspector);
308        self
309    }
310
311    /// Set the maximum frame size
312    pub fn max_frame_size(mut self, size: usize) -> Self {
313        self.max_frame_size = size;
314        self
315    }
316
317    /// Build the handler
318    pub fn build(self) -> Option<WebSocketHandler> {
319        Some(WebSocketHandler::new(self.inspector?, self.max_frame_size))
320    }
321}
322
#[cfg(test)]
mod tests {
    // Unit tests for `WebSocketHandler`, driven by a deterministic mock inspector.
    use super::*;
    use crate::websocket::codec::Opcode;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Frame-size limit given to every handler and codec in these tests.
    const MAX_FRAME: usize = 1024 * 1024;

    /// Mock inspector that returns a fixed decision per direction and counts
    /// how many frames it was asked about.
    struct MockInspector {
        /// Decision to return for client frames
        client_decision: InspectionResult,
        /// Decision to return for server frames
        server_decision: InspectionResult,
        /// Count of inspected client frames
        client_frame_count: AtomicUsize,
        /// Count of inspected server frames
        server_frame_count: AtomicUsize,
    }

    impl MockInspector {
        fn new(client_decision: InspectionResult, server_decision: InspectionResult) -> Self {
            Self {
                client_decision,
                server_decision,
                client_frame_count: AtomicUsize::new(0),
                server_frame_count: AtomicUsize::new(0),
            }
        }

        fn allowing() -> Self {
            Self::new(InspectionResult::Allow, InspectionResult::Allow)
        }

        fn dropping_client() -> Self {
            Self::new(InspectionResult::Drop, InspectionResult::Allow)
        }

        fn dropping_server() -> Self {
            Self::new(InspectionResult::Allow, InspectionResult::Drop)
        }

        fn closing_client(code: u16, reason: &str) -> Self {
            Self::new(
                InspectionResult::Close {
                    code,
                    reason: reason.to_string(),
                },
                InspectionResult::Allow,
            )
        }

        fn client_frames_inspected(&self) -> usize {
            self.client_frame_count.load(Ordering::SeqCst)
        }

        fn server_frames_inspected(&self) -> usize {
            self.server_frame_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl FrameInspector for MockInspector {
        async fn inspect_client_frame(&self, _frame: &WebSocketFrame) -> InspectionResult {
            self.client_frame_count.fetch_add(1, Ordering::SeqCst);
            self.client_decision.clone()
        }

        async fn inspect_server_frame(&self, _frame: &WebSocketFrame) -> InspectionResult {
            self.server_frame_count.fetch_add(1, Ordering::SeqCst);
            self.server_decision.clone()
        }

        fn correlation_id(&self) -> &str {
            "test-correlation-id"
        }
    }

    /// Helper to create a text frame as bytes
    fn make_text_frame(text: &str, masked: bool) -> Bytes {
        let codec = WebSocketCodec::new(MAX_FRAME);
        let frame = WebSocketFrame::new(Opcode::Text, text.as_bytes().to_vec());
        Bytes::from(codec.encode_frame(&frame, masked).unwrap())
    }

    /// Concatenate encoded frames into one chunk, the way a single body-filter
    /// callback might deliver them.
    fn concat_frames(frames: &[Bytes]) -> Bytes {
        let mut chunk = Vec::new();
        for frame in frames {
            chunk.extend_from_slice(frame);
        }
        Bytes::from(chunk)
    }

    /// Unwrap `Forward(Some(_))`, panicking with `context` on any other result.
    fn expect_forwarded(result: ProcessResult, context: &str) -> Bytes {
        match result {
            ProcessResult::Forward(Some(data)) => data,
            other => panic!("{context}: expected Forward(Some(_)), got {other:?}"),
        }
    }

    #[test]
    fn test_close_reason() {
        let reason = CloseReason {
            code: 1000,
            reason: "Normal closure".to_string(),
        };
        assert_eq!(reason.code, 1000);
        assert_eq!(reason.reason, "Normal closure");
    }

    #[test]
    fn test_builder_defaults() {
        let builder = WebSocketHandlerBuilder::new();
        assert_eq!(builder.max_frame_size, 1024 * 1024);
    }

    #[tokio::test]
    async fn test_frame_allow() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let frame_data = make_text_frame("Hello", false);
        let data = expect_forwarded(
            handler.process_client_data(Some(frame_data.clone())).await,
            "allowed frame",
        );
        // Frame should be forwarded as-is
        assert_eq!(data, frame_data);

        assert_eq!(inspector.client_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_drop_client() {
        let inspector = Arc::new(MockInspector::dropping_client());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let frame_data = make_text_frame("Hello", false);
        let data = expect_forwarded(
            handler.process_client_data(Some(frame_data)).await,
            "dropped client frame",
        );
        assert!(data.is_empty(), "Dropped frame should produce empty output");

        assert_eq!(inspector.client_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_frame_drop_server() {
        let inspector = Arc::new(MockInspector::dropping_server());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let frame_data = make_text_frame("Server message", false);
        let data = expect_forwarded(
            handler.process_server_data(Some(frame_data)).await,
            "dropped server frame",
        );
        assert!(data.is_empty(), "Dropped frame should produce empty output");

        assert_eq!(inspector.server_frames_inspected(), 1);
    }

    #[tokio::test]
    async fn test_drop_every_frame_in_chunk() {
        let inspector = Arc::new(MockInspector::dropping_client());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let chunk = concat_frames(&[
            make_text_frame("first", false),
            make_text_frame("second", false),
        ]);
        let data = expect_forwarded(
            handler.process_client_data(Some(chunk)).await,
            "chunk of dropped frames",
        );
        assert!(data.is_empty(), "Every frame was dropped, nothing to forward");
        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_frame_close() {
        let inspector = Arc::new(MockInspector::closing_client(1008, "Policy violation"));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let frame_data = make_text_frame("Malicious content", false);
        let result = handler.process_client_data(Some(frame_data)).await;

        match result {
            ProcessResult::Close(reason) => {
                assert_eq!(reason.code, 1008);
                assert_eq!(reason.reason, "Policy violation");
            }
            other => panic!("Expected Close result, got {other:?}"),
        }

        assert_eq!(inspector.client_frames_inspected(), 1);

        // Subsequent calls should also return Close
        let result = handler
            .process_client_data(Some(make_text_frame("More data", false)))
            .await;
        assert!(
            matches!(result, ProcessResult::Close(_)),
            "Expected Close result on subsequent call"
        );
    }

    #[tokio::test]
    async fn test_close_is_sticky_across_directions() {
        let inspector = Arc::new(MockInspector::closing_client(1008, "Policy violation"));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);
        assert!(handler.should_close().await.is_none());

        let result = handler
            .process_client_data(Some(make_text_frame("bad", false)))
            .await;
        assert!(matches!(result, ProcessResult::Close(_)));

        let recorded = handler
            .should_close()
            .await
            .expect("close reason should be recorded");
        assert_eq!(recorded.code, 1008);
        assert_eq!(recorded.reason, "Policy violation");

        // The server direction is short-circuited without inspecting anything...
        let result = handler
            .process_server_data(Some(make_text_frame("reply", false)))
            .await;
        assert!(matches!(result, ProcessResult::Close(_)));
        assert_eq!(inspector.server_frames_inspected(), 0);

        // ...and so is end of stream.
        let result = handler.process_server_data(None).await;
        assert!(matches!(result, ProcessResult::Close(_)));
    }

    #[tokio::test]
    async fn test_multiple_frames_across_calls() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let frame1 = make_text_frame("Frame 1", false);
        let data = expect_forwarded(
            handler.process_client_data(Some(frame1.clone())).await,
            "first frame",
        );
        assert_eq!(data, frame1);

        let frame2 = make_text_frame("Frame 2", false);
        let data = expect_forwarded(
            handler.process_client_data(Some(frame2.clone())).await,
            "second frame",
        );
        assert_eq!(data, frame2);

        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_multiple_frames_single_chunk() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let chunk = concat_frames(&[
            make_text_frame("Frame 1", false),
            make_text_frame("Frame 2", false),
        ]);
        let data = expect_forwarded(
            handler.process_client_data(Some(chunk.clone())).await,
            "two-frame chunk",
        );
        assert_eq!(data, chunk, "Both allowed frames should be forwarded in order");
        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_end_of_stream() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector, MAX_FRAME);

        // Send None to indicate end of stream, in both directions
        assert!(
            matches!(
                handler.process_client_data(None).await,
                ProcessResult::Forward(None)
            ),
            "Expected Forward(None) for client end of stream"
        );
        assert!(
            matches!(
                handler.process_server_data(None).await,
                ProcessResult::Forward(None)
            ),
            "Expected Forward(None) for server end of stream"
        );
    }

    #[tokio::test]
    async fn test_partial_frame_buffering() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        // Create a frame and split it
        let full_frame = make_text_frame("Hello World", false);
        let (part1, part2) = full_frame.split_at(full_frame.len() / 2);

        // Send first part - should return empty (buffering)
        let data = expect_forwarded(
            handler
                .process_client_data(Some(Bytes::from(part1.to_vec())))
                .await,
            "partial frame",
        );
        assert!(data.is_empty(), "Partial frame should not produce output");
        assert_eq!(
            inspector.client_frames_inspected(),
            0,
            "Partial frame should not be inspected"
        );

        // Send second part - should return complete frame
        let data = expect_forwarded(
            handler
                .process_client_data(Some(Bytes::from(part2.to_vec())))
                .await,
            "completed frame",
        );
        assert_eq!(data, full_frame, "Complete frame should be forwarded");
        assert_eq!(
            inspector.client_frames_inspected(),
            1,
            "Complete frame should be inspected"
        );
    }

    #[tokio::test]
    async fn test_complete_frame_then_partial_in_one_chunk() {
        let inspector = Arc::new(MockInspector::allowing());
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        let frame1 = make_text_frame("First", false);
        let frame2 = make_text_frame("Second frame", false);
        let (head, tail) = frame2.split_at(frame2.len() / 2);

        let chunk = concat_frames(&[frame1.clone(), Bytes::from(head.to_vec())]);
        let data = expect_forwarded(
            handler.process_client_data(Some(chunk)).await,
            "complete + partial chunk",
        );
        assert_eq!(data, frame1, "Only the complete frame should be forwarded");
        assert_eq!(inspector.client_frames_inspected(), 1);

        let data = expect_forwarded(
            handler
                .process_client_data(Some(Bytes::from(tail.to_vec())))
                .await,
            "remainder of second frame",
        );
        assert_eq!(data, frame2, "Buffered head plus tail should form the frame");
        assert_eq!(inspector.client_frames_inspected(), 2);
    }

    #[tokio::test]
    async fn test_bidirectional_independence() {
        // Client drops, server allows
        let inspector = Arc::new(MockInspector::new(
            InspectionResult::Drop,
            InspectionResult::Allow,
        ));
        let handler = WebSocketHandler::with_inspector(inspector.clone(), MAX_FRAME);

        // Client frame should be dropped
        let client_frame = make_text_frame("Client", false);
        let data = expect_forwarded(
            handler.process_client_data(Some(client_frame)).await,
            "dropped client frame",
        );
        assert!(data.is_empty());

        // Server frame should be allowed
        let server_frame = make_text_frame("Server", false);
        let original_len = server_frame.len();
        let data = expect_forwarded(
            handler.process_server_data(Some(server_frame)).await,
            "allowed server frame",
        );
        assert_eq!(data.len(), original_len);

        assert_eq!(inspector.client_frames_inspected(), 1);
        assert_eq!(inspector.server_frames_inspected(), 1);
    }
}