1use crate::connection::{Message, WebSocketConnection};
7use async_trait::async_trait;
8use std::sync::Arc;
9
10pub type MiddlewareResult<T> = Result<T, MiddlewareError>;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MiddlewareError {
16 #[error("Connection rejected")]
17 ConnectionRejected(String),
18 #[error("Message rejected")]
19 MessageRejected(String),
20 #[error("Middleware error")]
21 Error(String),
22}
23
/// Information about an incoming connection, handed to
/// `ConnectionMiddleware::on_connect` hooks.
///
/// `Debug` and `Clone` are derived so the context can be logged and copied
/// like any other plain data type.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
    /// Peer address as a string; filters compare it textually.
    pub ip: String,
    /// Request headers keyed by name. Keys are compared exactly
    /// (case-sensitive), since this is a plain `HashMap`.
    pub headers: std::collections::HashMap<String, String>,
    /// Free-form key/value data that middlewares may attach to the connection.
    pub metadata: std::collections::HashMap<String, String>,
}

impl ConnectionContext {
    /// Creates a context for a peer at `ip` with no headers and no metadata.
    pub fn new(ip: String) -> Self {
        Self {
            ip,
            headers: std::collections::HashMap::new(),
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Builder-style setter that adds a header, replacing any previous value
    /// stored under the same `key`.
    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.headers.insert(key, value);
        self
    }

    /// Builder-style setter that adds a metadata entry, replacing any previous
    /// value stored under the same `key`.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}
65
/// Hooks that observe, and may veto, WebSocket connections.
///
/// Implementations are stored as `Box<dyn ConnectionMiddleware>` inside a
/// [`MiddlewareChain`]; the trait requires `Send + Sync`.
#[async_trait]
pub trait ConnectionMiddleware: Send + Sync {
    /// Inspects an incoming connection described by `context`.
    ///
    /// `Ok(())` lets processing continue. An error (conventionally
    /// [`MiddlewareError::ConnectionRejected`]) signals that the connection
    /// should be refused. The context is passed mutably, so a middleware may
    /// add headers or metadata for the middlewares that run after it.
    async fn on_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()>;

    /// Notifies the middleware that `connection` has ended.
    ///
    /// Middlewares with nothing to clean up can simply return `Ok(())`.
    async fn on_disconnect(&self, connection: &Arc<WebSocketConnection>) -> MiddlewareResult<()>;
}
89
/// A per-message hook that can inspect, replace, or reject messages.
///
/// The message is taken by value, and the returned message is what the caller
/// continues with. Inside a [`MiddlewareChain`], that returned message is
/// handed to the next middleware. An error (conventionally
/// [`MiddlewareError::MessageRejected`]) stops processing of that message.
#[async_trait]
pub trait MessageMiddleware: Send + Sync {
    /// Processes `message` received from `connection`, returning the message
    /// to pass on.
    async fn on_message(
        &self,
        connection: &Arc<WebSocketConnection>,
        message: Message,
    ) -> MiddlewareResult<Message>;
}
111
/// Middleware that prints connection lifecycle events and incoming messages
/// to standard output. It never rejects anything.
pub struct LoggingMiddleware {
    /// Tag printed in square brackets at the start of each log line.
    prefix: String,
}

impl LoggingMiddleware {
    /// Builds a logger whose lines are tagged with `prefix`.
    pub fn new(prefix: String) -> Self {
        LoggingMiddleware { prefix }
    }
}
136
137#[async_trait]
138impl ConnectionMiddleware for LoggingMiddleware {
139 async fn on_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()> {
140 println!(
141 "[{}] Connection established from {}",
142 self.prefix, context.ip
143 );
144 Ok(())
145 }
146
147 async fn on_disconnect(&self, connection: &Arc<WebSocketConnection>) -> MiddlewareResult<()> {
148 println!("[{}] Connection closed: {}", self.prefix, connection.id());
149 Ok(())
150 }
151}
152
153#[async_trait]
154impl MessageMiddleware for LoggingMiddleware {
155 async fn on_message(
156 &self,
157 connection: &Arc<WebSocketConnection>,
158 message: Message,
159 ) -> MiddlewareResult<Message> {
160 match &message {
161 Message::Text { data } => {
162 println!(
163 "[{}] Text message from {}: {}",
164 self.prefix,
165 connection.id(),
166 data
167 );
168 }
169 Message::Binary { data } => {
170 println!(
171 "[{}] Binary message from {}: {} bytes",
172 self.prefix,
173 connection.id(),
174 data.len()
175 );
176 }
177 Message::Ping => {
178 println!("[{}] Ping from {}", self.prefix, connection.id());
179 }
180 Message::Pong => {
181 println!("[{}] Pong from {}", self.prefix, connection.id());
182 }
183 Message::Close { .. } => {
184 println!("[{}] Close from {}", self.prefix, connection.id());
185 }
186 }
187 Ok(message)
188 }
189}
190
191pub struct IpFilterMiddleware {
209 allowed_ips: Vec<String>,
210 blocked_ips: Vec<String>,
211 mode: IpFilterMode,
212}
213
214#[derive(Debug, Clone, Copy)]
215enum IpFilterMode {
216 Whitelist,
217 Blacklist,
218}
219
220impl IpFilterMiddleware {
221 pub fn whitelist(allowed_ips: Vec<String>) -> Self {
223 Self {
224 allowed_ips,
225 blocked_ips: Vec::new(),
226 mode: IpFilterMode::Whitelist,
227 }
228 }
229
230 pub fn blacklist(blocked_ips: Vec<String>) -> Self {
232 Self {
233 allowed_ips: Vec::new(),
234 blocked_ips,
235 mode: IpFilterMode::Blacklist,
236 }
237 }
238}
239
240#[async_trait]
241impl ConnectionMiddleware for IpFilterMiddleware {
242 async fn on_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()> {
243 match self.mode {
244 IpFilterMode::Whitelist => {
245 if self.allowed_ips.contains(&context.ip) {
246 Ok(())
247 } else {
248 Err(MiddlewareError::ConnectionRejected(format!(
249 "IP not in whitelist: {}",
250 context.ip
251 )))
252 }
253 }
254 IpFilterMode::Blacklist => {
255 if self.blocked_ips.contains(&context.ip) {
256 Err(MiddlewareError::ConnectionRejected(format!(
257 "IP is blacklisted: {}",
258 context.ip
259 )))
260 } else {
261 Ok(())
262 }
263 }
264 }
265 }
266
267 async fn on_disconnect(&self, _connection: &Arc<WebSocketConnection>) -> MiddlewareResult<()> {
268 Ok(())
269 }
270}
271
/// WebSocket close status code 1009, "Message Too Big" (RFC 6455 §7.4.1).
/// `MessageSizeLimitMiddleware` sends it when it rejects an oversized message.
const CLOSE_CODE_MESSAGE_TOO_BIG: u16 = 1009;
274
/// Message middleware that rejects `Text`/`Binary` payloads longer than a
/// configured number of bytes.
pub struct MessageSizeLimitMiddleware {
    /// Largest accepted payload length in bytes. A payload of exactly this
    /// length is still accepted.
    max_size: usize,
}

impl MessageSizeLimitMiddleware {
    /// Creates a limiter that accepts payloads of at most `max_size` bytes.
    pub fn new(max_size: usize) -> Self {
        MessageSizeLimitMiddleware { max_size }
    }

    /// Returns the configured limit, in bytes.
    pub fn max_size(&self) -> usize {
        let Self { max_size } = self;
        *max_size
    }
}
321
322impl Default for MessageSizeLimitMiddleware {
323 fn default() -> Self {
325 Self {
326 max_size: crate::protocol::DEFAULT_MAX_MESSAGE_SIZE,
327 }
328 }
329}
330
331#[async_trait]
332impl MessageMiddleware for MessageSizeLimitMiddleware {
333 async fn on_message(
334 &self,
335 connection: &Arc<WebSocketConnection>,
336 message: Message,
337 ) -> MiddlewareResult<Message> {
338 let size = match &message {
339 Message::Text { data } => data.len(),
340 Message::Binary { data } => data.len(),
341 _ => 0,
342 };
343
344 if size > self.max_size {
345 let reason = format!(
347 "Message size {} bytes exceeds limit of {} bytes",
348 size, self.max_size
349 );
350 let _ = connection
351 .close_with_reason(CLOSE_CODE_MESSAGE_TOO_BIG, reason.clone())
352 .await;
353
354 Err(MiddlewareError::MessageRejected(format!(
355 "Message size {} exceeds limit {}",
356 size, self.max_size
357 )))
358 } else {
359 Ok(message)
360 }
361 }
362}
363
364pub struct MiddlewareChain {
382 connection_middlewares: Vec<Box<dyn ConnectionMiddleware>>,
383 message_middlewares: Vec<Box<dyn MessageMiddleware>>,
384}
385
386impl MiddlewareChain {
387 pub fn new() -> Self {
389 Self {
390 connection_middlewares: Vec::new(),
391 message_middlewares: Vec::new(),
392 }
393 }
394
395 pub fn add_connection_middleware(&mut self, middleware: Box<dyn ConnectionMiddleware>) {
397 self.connection_middlewares.push(middleware);
398 }
399
400 pub fn add_message_middleware(&mut self, middleware: Box<dyn MessageMiddleware>) {
402 self.message_middlewares.push(middleware);
403 }
404
405 pub async fn process_connect(&self, context: &mut ConnectionContext) -> MiddlewareResult<()> {
407 for middleware in &self.connection_middlewares {
408 middleware.on_connect(context).await?;
409 }
410 Ok(())
411 }
412
413 pub async fn process_disconnect(
415 &self,
416 connection: &Arc<WebSocketConnection>,
417 ) -> MiddlewareResult<()> {
418 for middleware in &self.connection_middlewares {
419 middleware.on_disconnect(connection).await?;
420 }
421 Ok(())
422 }
423
424 pub async fn process_message(
426 &self,
427 connection: &Arc<WebSocketConnection>,
428 mut message: Message,
429 ) -> MiddlewareResult<Message> {
430 for middleware in &self.message_middlewares {
431 message = middleware.on_message(connection, message).await?;
432 }
433 Ok(message)
434 }
435}
436
437impl Default for MiddlewareChain {
438 fn default() -> Self {
439 Self::new()
440 }
441}
442
#[cfg(test)]
mod tests {
    //! Unit tests for the built-in middlewares and the chain that runs them.

    use super::*;
    use rstest::rstest;
    use tokio::sync::mpsc;

    #[rstest]
    #[tokio::test]
    async fn test_connection_context() {
        let context = ConnectionContext::new("192.168.1.1".to_string())
            .with_header("User-Agent".to_string(), "Test".to_string())
            .with_metadata("session_id".to_string(), "abc123".to_string());

        assert_eq!(context.ip, "192.168.1.1");
        assert_eq!(context.headers.get("User-Agent").unwrap(), "Test");
        assert_eq!(context.metadata.get("session_id").unwrap(), "abc123");
    }

    #[rstest]
    #[tokio::test]
    async fn test_logging_middleware_connect() {
        let middleware = LoggingMiddleware::new("Test".to_string());
        let mut context = ConnectionContext::new("192.168.1.1".to_string());

        let result = middleware.on_connect(&mut context).await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_logging_middleware_message() {
        let middleware = LoggingMiddleware::new("Test".to_string());
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("Hello".to_string());

        let result = middleware.on_message(&conn, msg).await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_ip_filter_whitelist_allowed() {
        let middleware = IpFilterMiddleware::whitelist(vec!["192.168.1.1".to_string()]);
        let mut context = ConnectionContext::new("192.168.1.1".to_string());

        assert!(middleware.on_connect(&mut context).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_ip_filter_whitelist_blocked() {
        let middleware = IpFilterMiddleware::whitelist(vec!["192.168.1.1".to_string()]);
        let mut context = ConnectionContext::new("10.0.0.1".to_string());

        let result = middleware.on_connect(&mut context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MiddlewareError::ConnectionRejected(_)
        ));
    }

    #[rstest]
    #[tokio::test]
    async fn test_ip_filter_empty_whitelist_rejects_everyone() {
        let middleware = IpFilterMiddleware::whitelist(Vec::new());
        let mut context = ConnectionContext::new("192.168.1.1".to_string());

        let result = middleware.on_connect(&mut context).await;

        assert!(matches!(result, Err(MiddlewareError::ConnectionRejected(_))));
    }

    #[rstest]
    #[tokio::test]
    async fn test_ip_filter_blacklist_allowed() {
        let middleware = IpFilterMiddleware::blacklist(vec!["10.0.0.1".to_string()]);
        let mut context = ConnectionContext::new("192.168.1.1".to_string());

        assert!(middleware.on_connect(&mut context).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_ip_filter_blacklist_blocked() {
        let middleware = IpFilterMiddleware::blacklist(vec!["10.0.0.1".to_string()]);
        let mut context = ConnectionContext::new("10.0.0.1".to_string());

        let result = middleware.on_connect(&mut context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MiddlewareError::ConnectionRejected(_)
        ));
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_within_limit() {
        let middleware = MessageSizeLimitMiddleware::new(100);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("Small message".to_string());

        assert!(middleware.on_message(&conn, msg).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_exact_limit_is_accepted() {
        // "Hello" is exactly 5 bytes: the limit is inclusive.
        let middleware = MessageSizeLimitMiddleware::new(5);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("Hello".to_string());

        assert!(middleware.on_message(&conn, msg).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_exceeds_limit() {
        let middleware = MessageSizeLimitMiddleware::new(10);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("This is a very long message".to_string());

        let result = middleware.on_message(&conn, msg).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MiddlewareError::MessageRejected(_)
        ));
    }

    #[rstest]
    fn test_message_size_limit_default_is_1mb() {
        let middleware = MessageSizeLimitMiddleware::default();

        assert_eq!(middleware.max_size(), 1_048_576);
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_default_accepts_normal_messages() {
        let middleware = MessageSizeLimitMiddleware::default();
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("x".repeat(10_000));

        let result = middleware.on_message(&conn, msg).await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_default_rejects_oversized_messages() {
        let middleware = MessageSizeLimitMiddleware::default();
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("x".repeat(2 * 1024 * 1024));

        let result = middleware.on_message(&conn, msg).await;

        assert!(matches!(result, Err(MiddlewareError::MessageRejected(_))));
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_sends_close_frame_on_rejection() {
        let middleware = MessageSizeLimitMiddleware::new(10);
        let (tx, mut rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("This exceeds the limit".to_string());

        let result = middleware.on_message(&conn, msg).await;

        assert!(result.is_err());

        let close_msg = rx.recv().await.unwrap();
        match close_msg {
            Message::Close { code, reason } => {
                assert_eq!(code, 1009);
                assert!(reason.contains("exceeds limit"));
            }
            _ => panic!("Expected close message with code 1009"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_binary_messages() {
        let middleware = MessageSizeLimitMiddleware::new(100);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

        let small_binary = Message::binary(vec![0u8; 50]);
        assert!(middleware.on_message(&conn, small_binary).await.is_ok());

        let large_binary = Message::binary(vec![0u8; 200]);
        let result = middleware.on_message(&conn, large_binary).await;

        assert!(matches!(result, Err(MiddlewareError::MessageRejected(_))));
    }

    #[rstest]
    #[tokio::test]
    async fn test_message_size_limit_control_frames_always_pass() {
        let middleware = MessageSizeLimitMiddleware::new(1);
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

        assert!(middleware.on_message(&conn, Message::Ping).await.is_ok());
        assert!(middleware.on_message(&conn, Message::Pong).await.is_ok());
    }

    #[rstest]
    fn test_message_size_limit_custom_configuration() {
        let middleware = MessageSizeLimitMiddleware::new(512 * 1024);

        assert_eq!(middleware.max_size(), 512 * 1024);
    }

    #[rstest]
    #[tokio::test]
    async fn test_middleware_chain_connect() {
        let mut chain = MiddlewareChain::new();
        chain.add_connection_middleware(Box::new(LoggingMiddleware::new("WS".to_string())));
        let mut context = ConnectionContext::new("192.168.1.1".to_string());

        assert!(chain.process_connect(&mut context).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_middleware_chain_disconnect() {
        let mut chain = MiddlewareChain::new();
        chain.add_connection_middleware(Box::new(LoggingMiddleware::new("WS".to_string())));
        chain.add_connection_middleware(Box::new(IpFilterMiddleware::blacklist(Vec::new())));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

        assert!(chain.process_disconnect(&conn).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_middleware_chain_message() {
        let mut chain = MiddlewareChain::new();
        chain.add_message_middleware(Box::new(MessageSizeLimitMiddleware::new(100)));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));
        let msg = Message::text("Test".to_string());

        assert!(chain.process_message(&conn, msg).await.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_middleware_chain_rejection() {
        let mut chain = MiddlewareChain::new();
        chain.add_connection_middleware(Box::new(IpFilterMiddleware::whitelist(vec![
            "192.168.1.1".to_string(),
        ])));
        let mut context = ConnectionContext::new("10.0.0.1".to_string());

        let result = chain.process_connect(&mut context).await;

        assert!(matches!(result, Err(MiddlewareError::ConnectionRejected(_))));
    }

    #[rstest]
    #[tokio::test]
    async fn test_middleware_chain_with_default_size_limit() {
        let mut chain = MiddlewareChain::new();
        chain.add_message_middleware(Box::new(MessageSizeLimitMiddleware::default()));
        let (tx, _rx) = mpsc::unbounded_channel();
        let conn = Arc::new(WebSocketConnection::new("test".to_string(), tx));

        let msg = Message::text("Normal message".to_string());
        let result = chain.process_message(&conn, msg).await;

        assert!(result.is_ok());
    }
}