Skip to main content

reinhardt_websockets/
middleware.rs

1//! WebSocket middleware integration
2//!
3//! This module provides middleware support for WebSocket connections,
4//! allowing pre-processing and post-processing of connections and messages.
5
6use crate::connection::{Message, WebSocketConnection};
7use async_trait::async_trait;
8use std::sync::Arc;
9
10/// WebSocket middleware result
11pub type MiddlewareResult<T> = Result<T, MiddlewareError>;
12
13/// Middleware errors
14#[derive(Debug, thiserror::Error)]
15pub enum MiddlewareError {
16	#[error("Connection rejected")]
17	ConnectionRejected(String),
18	#[error("Message rejected")]
19	MessageRejected(String),
20	#[error("Middleware error")]
21	Error(String),
22}
23
/// WebSocket connection context for middleware
///
/// Describes a connection attempt. It is passed mutably to
/// [`ConnectionMiddleware::on_connect`], so middlewares may read it and also
/// enrich it (for example by inserting `metadata` entries).
#[derive(Debug, Clone)]
pub struct ConnectionContext {
	/// Client IP address
	pub ip: String,
	/// Connection headers (if available)
	///
	/// Keys are stored exactly as inserted, so lookups are case-sensitive.
	pub headers: std::collections::HashMap<String, String>,
	/// Custom metadata
	pub metadata: std::collections::HashMap<String, String>,
}

impl ConnectionContext {
	/// Create a new connection context with no headers and no metadata
	///
	/// # Examples
	///
	/// ```
	/// use reinhardt_websockets::middleware::ConnectionContext;
	///
	/// let context = ConnectionContext::new("192.168.1.1".to_string());
	/// assert_eq!(context.ip, "192.168.1.1");
	/// ```
	pub fn new(ip: String) -> Self {
		Self {
			ip,
			headers: std::collections::HashMap::new(),
			metadata: std::collections::HashMap::new(),
		}
	}

	/// Add a header to the context
	///
	/// Replaces any value already stored under the same `key`. Consumes and
	/// returns `self`, so the result must be used.
	#[must_use]
	pub fn with_header(mut self, key: String, value: String) -> Self {
		self.headers.insert(key, value);
		self
	}

	/// Add metadata to the context
	///
	/// Replaces any value already stored under the same `key`. Consumes and
	/// returns `self`, so the result must be used.
	#[must_use]
	pub fn with_metadata(mut self, key: String, value: String) -> Self {
		self.metadata.insert(key, value);
		self
	}
}
65
66/// WebSocket connection middleware trait
67///
68/// Implementors can intercept WebSocket connections before they are established.
69#[async_trait]
70pub trait ConnectionMiddleware: Send + Sync {
71	/// Process a connection before it is established
72	///
73	/// # Arguments
74	///
75	/// * `context` - Connection context with client information
76	///
77	/// # Returns
78	///
79	/// Returns `Ok(())` to allow the connection, or an error to reject it.
80	async fn on_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()>;
81
82	/// Process a connection after it is closed
83	///
84	/// # Arguments
85	///
86	/// * `connection` - The closed connection
87	async fn on_disconnect(&self, connection: &Arc<WebSocketConnection>) -> MiddlewareResult<()>;
88}
89
90/// WebSocket message middleware trait
91///
92/// Implementors can intercept and modify messages before they are processed.
93#[async_trait]
94pub trait MessageMiddleware: Send + Sync {
95	/// Process a message before it is handled
96	///
97	/// # Arguments
98	///
99	/// * `connection` - The connection that sent the message
100	/// * `message` - The message to process
101	///
102	/// # Returns
103	///
104	/// Returns the processed message, or an error to reject it.
105	async fn on_message(
106		&self,
107		connection: &Arc<WebSocketConnection>,
108		message: Message,
109	) -> MiddlewareResult<Message>;
110}
111
/// Logging middleware for WebSocket connections
///
/// Prints one line per connection or message event to standard output, each
/// tagged with the configured prefix in square brackets.
///
/// # Examples
///
/// ```
/// use reinhardt_websockets::middleware::{LoggingMiddleware, ConnectionMiddleware, ConnectionContext};
///
/// # tokio_test::block_on(async {
/// let middleware = LoggingMiddleware::new("WebSocket".to_string());
/// let mut context = ConnectionContext::new("192.168.1.1".to_string());
///
/// assert!(middleware.on_connect(&mut context).await.is_ok());
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct LoggingMiddleware {
	/// Tag printed as `[prefix]` at the start of every log line
	prefix: String,
}

impl LoggingMiddleware {
	/// Create a new logging middleware whose output lines start with `[prefix]`
	pub fn new(prefix: String) -> Self {
		Self { prefix }
	}
}
136
137#[async_trait]
138impl ConnectionMiddleware for LoggingMiddleware {
139	async fn on_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()> {
140		println!(
141			"[{}] Connection established from {}",
142			self.prefix, context.ip
143		);
144		Ok(())
145	}
146
147	async fn on_disconnect(&self, connection: &Arc<WebSocketConnection>) -> MiddlewareResult<()> {
148		println!("[{}] Connection closed: {}", self.prefix, connection.id());
149		Ok(())
150	}
151}
152
153#[async_trait]
154impl MessageMiddleware for LoggingMiddleware {
155	async fn on_message(
156		&self,
157		connection: &Arc<WebSocketConnection>,
158		message: Message,
159	) -> MiddlewareResult<Message> {
160		match &message {
161			Message::Text { data } => {
162				println!(
163					"[{}] Text message from {}: {}",
164					self.prefix,
165					connection.id(),
166					data
167				);
168			}
169			Message::Binary { data } => {
170				println!(
171					"[{}] Binary message from {}: {} bytes",
172					self.prefix,
173					connection.id(),
174					data.len()
175				);
176			}
177			Message::Ping => {
178				println!("[{}] Ping from {}", self.prefix, connection.id());
179			}
180			Message::Pong => {
181				println!("[{}] Pong from {}", self.prefix, connection.id());
182			}
183			Message::Close { .. } => {
184				println!("[{}] Close from {}", self.prefix, connection.id());
185			}
186		}
187		Ok(message)
188	}
189}
190
191/// IP filtering middleware
192///
193/// # Examples
194///
195/// ```
196/// use reinhardt_websockets::middleware::{IpFilterMiddleware, ConnectionMiddleware, ConnectionContext};
197///
198/// # tokio_test::block_on(async {
199/// let middleware = IpFilterMiddleware::whitelist(vec!["192.168.1.1".to_string()]);
200/// let mut context = ConnectionContext::new("192.168.1.1".to_string());
201///
202/// assert!(middleware.on_connect(&mut context).await.is_ok());
203///
204/// let mut blocked_context = ConnectionContext::new("10.0.0.1".to_string());
205/// assert!(middleware.on_connect(&mut blocked_context).await.is_err());
206/// # });
207/// ```
208pub struct IpFilterMiddleware {
209	allowed_ips: Vec<String>,
210	blocked_ips: Vec<String>,
211	mode: IpFilterMode,
212}
213
214#[derive(Debug, Clone, Copy)]
215enum IpFilterMode {
216	Whitelist,
217	Blacklist,
218}
219
220impl IpFilterMiddleware {
221	/// Create a whitelist-based filter
222	pub fn whitelist(allowed_ips: Vec<String>) -> Self {
223		Self {
224			allowed_ips,
225			blocked_ips: Vec::new(),
226			mode: IpFilterMode::Whitelist,
227		}
228	}
229
230	/// Create a blacklist-based filter
231	pub fn blacklist(blocked_ips: Vec<String>) -> Self {
232		Self {
233			allowed_ips: Vec::new(),
234			blocked_ips,
235			mode: IpFilterMode::Blacklist,
236		}
237	}
238}
239
240#[async_trait]
241impl ConnectionMiddleware for IpFilterMiddleware {
242	async fn on_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()> {
243		match self.mode {
244			IpFilterMode::Whitelist => {
245				if self.allowed_ips.contains(&context.ip) {
246					Ok(())
247				} else {
248					Err(MiddlewareError::ConnectionRejected(format!(
249						"IP not in whitelist: {}",
250						context.ip
251					)))
252				}
253			}
254			IpFilterMode::Blacklist => {
255				if self.blocked_ips.contains(&context.ip) {
256					Err(MiddlewareError::ConnectionRejected(format!(
257						"IP is blacklisted: {}",
258						context.ip
259					)))
260				} else {
261					Ok(())
262				}
263			}
264		}
265	}
266
267	async fn on_disconnect(&self, _connection: &Arc<WebSocketConnection>) -> MiddlewareResult<()> {
268		Ok(())
269	}
270}
271
/// WebSocket close code for "Message Too Big" as defined in RFC 6455 Section 7.4.1
///
/// Used by `MessageSizeLimitMiddleware` as the code of the close frame it sends
/// before rejecting an oversized message.
const CLOSE_CODE_MESSAGE_TOO_BIG: u16 = 1009;
274
/// Message size limit middleware
///
/// Rejects text and binary messages whose payload is larger than a
/// configurable number of bytes. By default, uses a 1 MB limit matching the
/// protocol-level default. When an oversized message is detected, the
/// connection is closed with status code 1009 (Message Too Big) as per RFC 6455.
///
/// This middleware inspects messages that have already been received, so it
/// bounds what reaches downstream handlers. It does not by itself limit the
/// memory used while a frame is being read.
///
/// # Examples
///
/// ```
/// use reinhardt_websockets::middleware::{MessageSizeLimitMiddleware, MessageMiddleware};
/// use reinhardt_websockets::{Message, WebSocketConnection};
/// use tokio::sync::mpsc;
/// use std::sync::Arc;
///
/// # tokio_test::block_on(async {
/// // Use default 1 MB limit
/// let middleware = MessageSizeLimitMiddleware::default();
///
/// let (tx, _rx) = mpsc::unbounded_channel();
/// let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
///
/// let small_msg = Message::text("Small".to_string());
/// assert!(middleware.on_message(&conn, small_msg).await.is_ok());
///
/// // Custom limit
/// let strict = MessageSizeLimitMiddleware::new(100);
/// let large_msg = Message::text("x".repeat(200));
/// assert!(strict.on_message(&conn, large_msg).await.is_err());
/// # });
/// ```
#[derive(Debug, Clone)]
pub struct MessageSizeLimitMiddleware {
	/// Largest accepted payload, in bytes
	max_size: usize,
}

impl MessageSizeLimitMiddleware {
	/// Create a new message size limit middleware accepting payloads of at most
	/// `max_size` bytes
	pub fn new(max_size: usize) -> Self {
		Self { max_size }
	}

	/// Get the configured maximum message size, in bytes
	pub fn max_size(&self) -> usize {
		self.max_size
	}
}
321
322impl Default for MessageSizeLimitMiddleware {
323	/// Create a message size limit middleware with the default 1 MB limit
324	fn default() -> Self {
325		Self {
326			max_size: crate::protocol::DEFAULT_MAX_MESSAGE_SIZE,
327		}
328	}
329}
330
331#[async_trait]
332impl MessageMiddleware for MessageSizeLimitMiddleware {
333	async fn on_message(
334		&self,
335		connection: &Arc<WebSocketConnection>,
336		message: Message,
337	) -> MiddlewareResult<Message> {
338		let size = match &message {
339			Message::Text { data } => data.len(),
340			Message::Binary { data } => data.len(),
341			_ => 0,
342		};
343
344		if size > self.max_size {
345			// Send close frame with 1009 (Message Too Big) before rejecting
346			let reason = format!(
347				"Message size {} bytes exceeds limit of {} bytes",
348				size, self.max_size
349			);
350			let _ = connection
351				.close_with_reason(CLOSE_CODE_MESSAGE_TOO_BIG, reason.clone())
352				.await;
353
354			Err(MiddlewareError::MessageRejected(format!(
355				"Message size {} exceeds limit {}",
356				size, self.max_size
357			)))
358		} else {
359			Ok(message)
360		}
361	}
362}
363
364/// Middleware chain for composing multiple middlewares
365///
366/// # Examples
367///
368/// ```
369/// use reinhardt_websockets::middleware::{
370///     MiddlewareChain, LoggingMiddleware, ConnectionContext, ConnectionMiddleware
371/// };
372///
373/// # tokio_test::block_on(async {
374/// let mut chain = MiddlewareChain::new();
375/// chain.add_connection_middleware(Box::new(LoggingMiddleware::new("WS".to_string())));
376///
377/// let mut context = ConnectionContext::new("192.168.1.1".to_string());
378/// assert!(chain.process_connect(&mut context).await.is_ok());
379/// # });
380/// ```
381pub struct MiddlewareChain {
382	connection_middlewares: Vec<Box<dyn ConnectionMiddleware>>,
383	message_middlewares: Vec<Box<dyn MessageMiddleware>>,
384}
385
386impl MiddlewareChain {
387	/// Create a new middleware chain
388	pub fn new() -> Self {
389		Self {
390			connection_middlewares: Vec::new(),
391			message_middlewares: Vec::new(),
392		}
393	}
394
395	/// Add a connection middleware to the chain
396	pub fn add_connection_middleware(&mut self, middleware: Box<dyn ConnectionMiddleware>) {
397		self.connection_middlewares.push(middleware);
398	}
399
400	/// Add a message middleware to the chain
401	pub fn add_message_middleware(&mut self, middleware: Box<dyn MessageMiddleware>) {
402		self.message_middlewares.push(middleware);
403	}
404
405	/// Process connection through all middlewares
406	pub async fn process_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()> {
407		for middleware in &self.connection_middlewares {
408			middleware.on_connect(context).await?;
409		}
410		Ok(())
411	}
412
413	/// Process disconnection through all middlewares
414	pub async fn process_disconnect(
415		&self,
416		connection: &Arc<WebSocketConnection>,
417	) -> MiddlewareResult<()> {
418		for middleware in &self.connection_middlewares {
419			middleware.on_disconnect(connection).await?;
420		}
421		Ok(())
422	}
423
424	/// Process message through all middlewares
425	pub async fn process_message(
426		&self,
427		connection: &Arc<WebSocketConnection>,
428		mut message: Message,
429	) -> MiddlewareResult<Message> {
430		for middleware in &self.message_middlewares {
431			message = middleware.on_message(connection, message).await?;
432		}
433		Ok(message)
434	}
435}
436
437impl Default for MiddlewareChain {
438	fn default() -> Self {
439		Self::new()
440	}
441}
442
// Unit tests for `ConnectionContext`, the built-in middlewares and
// `MiddlewareChain`. Message-level tests build a `WebSocketConnection` over an
// unbounded mpsc channel; keeping the receiver lets a test inspect frames the
// middleware sends back (see
// `test_message_size_limit_sends_close_frame_on_rejection`).
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;
	use tokio::sync::mpsc;

	#[rstest]
	#[tokio::test]
	async fn test_connection_context() {
		// Arrange & Act
		let context = ConnectionContext::new("192.168.1.1".to_string())
			.with_header("User-Agent".to_string(), "Test".to_string())
			.with_metadata("session_id".to_string(), "abc123".to_string());

		// Assert
		assert_eq!(context.ip, "192.168.1.1");
		assert_eq!(context.headers.get("User-Agent").unwrap(), "Test");
		assert_eq!(context.metadata.get("session_id").unwrap(), "abc123");
	}

	#[rstest]
	#[tokio::test]
	async fn test_logging_middleware_connect() {
		// Arrange
		let middleware = LoggingMiddleware::new("Test".to_string());
		let mut context = ConnectionContext::new("192.168.1.1".to_string());

		// Act
		let result = middleware.on_connect(&mut context).await;

		// Assert
		assert!(result.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_logging_middleware_message() {
		// Arrange
		let middleware = LoggingMiddleware::new("Test".to_string());
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		let msg = Message::text("Hello".to_string());

		// Act
		let result = middleware.on_message(&conn, msg).await;

		// Assert
		assert!(result.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_ip_filter_whitelist_allowed() {
		// Arrange
		let middleware = IpFilterMiddleware::whitelist(vec!["192.168.1.1".to_string()]);
		let mut context = ConnectionContext::new("192.168.1.1".to_string());

		// Act & Assert
		assert!(middleware.on_connect(&mut context).await.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_ip_filter_whitelist_blocked() {
		// Arrange
		let middleware = IpFilterMiddleware::whitelist(vec!["192.168.1.1".to_string()]);
		let mut context = ConnectionContext::new("10.0.0.1".to_string());

		// Act
		let result = middleware.on_connect(&mut context).await;

		// Assert
		assert!(result.is_err());
		assert!(matches!(
			result.unwrap_err(),
			MiddlewareError::ConnectionRejected(_)
		));
	}

	#[rstest]
	#[tokio::test]
	async fn test_ip_filter_blacklist_allowed() {
		// Arrange
		let middleware = IpFilterMiddleware::blacklist(vec!["10.0.0.1".to_string()]);
		let mut context = ConnectionContext::new("192.168.1.1".to_string());

		// Act & Assert
		assert!(middleware.on_connect(&mut context).await.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_ip_filter_blacklist_blocked() {
		// Arrange
		let middleware = IpFilterMiddleware::blacklist(vec!["10.0.0.1".to_string()]);
		let mut context = ConnectionContext::new("10.0.0.1".to_string());

		// Act
		let result = middleware.on_connect(&mut context).await;

		// Assert
		assert!(result.is_err());
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_within_limit() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::new(100);
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		let msg = Message::text("Small message".to_string());

		// Act & Assert
		assert!(middleware.on_message(&conn, msg).await.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_exceeds_limit() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::new(10);
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		let msg = Message::text("This is a very long message".to_string());

		// Act
		let result = middleware.on_message(&conn, msg).await;

		// Assert
		assert!(result.is_err());
		assert!(matches!(
			result.unwrap_err(),
			MiddlewareError::MessageRejected(_)
		));
	}

	#[rstest]
	fn test_message_size_limit_default_is_1mb() {
		// Arrange & Act
		let middleware = MessageSizeLimitMiddleware::default();

		// Assert
		assert_eq!(middleware.max_size(), 1_048_576);
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_default_accepts_normal_messages() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::default();
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		// 10 KB message - well within 1 MB limit
		let msg = Message::text("x".repeat(10_000));

		// Act
		let result = middleware.on_message(&conn, msg).await;

		// Assert
		assert!(result.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_default_rejects_oversized_messages() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::default();
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		// 2 MB message - exceeds 1 MB limit
		let msg = Message::text("x".repeat(2 * 1024 * 1024));

		// Act
		let result = middleware.on_message(&conn, msg).await;

		// Assert
		assert!(result.is_err());
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_sends_close_frame_on_rejection() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::new(10);
		let (tx, mut rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		let msg = Message::text("This exceeds the limit".to_string());

		// Act
		let result = middleware.on_message(&conn, msg).await;

		// Assert
		assert!(result.is_err());

		// Verify close frame was sent with code 1009 (Message Too Big)
		let close_msg = rx.recv().await.unwrap();
		match close_msg {
			Message::Close { code, reason } => {
				assert_eq!(code, 1009);
				assert!(reason.contains("exceeds limit"));
			}
			_ => panic!("Expected close message with code 1009"),
		}
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_binary_messages() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::new(100);
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

		// Act - within limit
		let small_binary = Message::binary(vec![0u8; 50]);
		assert!(middleware.on_message(&conn, small_binary).await.is_ok());

		// Act - exceeds limit
		let large_binary = Message::binary(vec![0u8; 200]);
		let result = middleware.on_message(&conn, large_binary).await;

		// Assert
		assert!(result.is_err());
	}

	#[rstest]
	#[tokio::test]
	async fn test_message_size_limit_control_frames_always_pass() {
		// Arrange
		let middleware = MessageSizeLimitMiddleware::new(1);
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

		// Act & Assert - control frames (Ping, Pong) should always pass
		assert!(middleware.on_message(&conn, Message::Ping).await.is_ok());
		assert!(middleware.on_message(&conn, Message::Pong).await.is_ok());
	}

	#[rstest]
	fn test_message_size_limit_custom_configuration() {
		// Arrange & Act
		let middleware = MessageSizeLimitMiddleware::new(512 * 1024); // 512 KB

		// Assert
		assert_eq!(middleware.max_size(), 512 * 1024);
	}

	#[rstest]
	#[tokio::test]
	async fn test_middleware_chain_connect() {
		// Arrange
		let mut chain = MiddlewareChain::new();
		chain.add_connection_middleware(Box::new(LoggingMiddleware::new("WS".to_string())));
		let mut context = ConnectionContext::new("192.168.1.1".to_string());

		// Act & Assert
		assert!(chain.process_connect(&mut context).await.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_middleware_chain_message() {
		// Arrange
		let mut chain = MiddlewareChain::new();
		chain.add_message_middleware(Box::new(MessageSizeLimitMiddleware::new(100)));
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
		let msg = Message::text("Test".to_string());

		// Act & Assert
		assert!(chain.process_message(&conn, msg).await.is_ok());
	}

	#[rstest]
	#[tokio::test]
	async fn test_middleware_chain_rejection() {
		// Arrange
		let mut chain = MiddlewareChain::new();
		chain.add_connection_middleware(Box::new(IpFilterMiddleware::whitelist(vec![
			"192.168.1.1".to_string(),
		])));
		let mut context = ConnectionContext::new("10.0.0.1".to_string());

		// Act & Assert
		assert!(chain.process_connect(&mut context).await.is_err());
	}

	#[rstest]
	#[tokio::test]
	async fn test_middleware_chain_with_default_size_limit() {
		// Arrange
		let mut chain = MiddlewareChain::new();
		chain.add_message_middleware(Box::new(MessageSizeLimitMiddleware::default()));
		let (tx, _rx) = mpsc::unbounded_channel();
		let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

		// Act - normal message should pass through default chain
		let msg = Message::text("Normal message".to_string());
		let result = chain.process_message(&conn, msg).await;

		// Assert
		assert!(result.is_ok());
	}
}