Skip to main content

binary_options_tools_core_pre/
middleware.rs

1//! Middleware system for WebSocket client operations.
2//!
3//! This module provides a composable middleware system inspired by Axum's middleware/layer system.
4//! Middleware can be used to observe, modify, or control the flow of WebSocket messages
5//! being sent and received by the client.
6//!
7//! # Key Components
8//!
9//! - [`WebSocketMiddleware`]: The core trait for implementing middleware
10//! - [`MiddlewareStack`]: A composable stack of middleware layers
11//! - [`MiddlewareContext`]: Context passed to middleware with message and client information
12//!
13//! # Example Usage
14//!
15//! ```rust,no_run
16//! use binary_options_tools_core_pre::middleware::{WebSocketMiddleware, MiddlewareContext};
17//! use binary_options_tools_core_pre::traits::AppState;
18//! use binary_options_tools_core_pre::error::CoreResult;
19//! use async_trait::async_trait;
20//! use tokio_tungstenite::tungstenite::Message;
21//! use std::sync::Arc;
22//!
23//! #[derive(Debug)]
24//! struct MyState;
25//! impl AppState for MyState {
26//!     fn clear_temporal_data(&self) {}
27//! }
28//!
29//! // Example statistics middleware
30//! struct StatisticsMiddleware {
31//!     sent_count: Arc<std::sync::atomic::AtomicU64>,
32//!     received_count: Arc<std::sync::atomic::AtomicU64>,
33//! }
34//!
35//! #[async_trait]
36//! impl WebSocketMiddleware<MyState> for StatisticsMiddleware {
37//!     async fn on_send(&self, message: &Message, context: &MiddlewareContext<MyState>) -> CoreResult<()> {
38//!         self.sent_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
39//!         println!("Sending message: {:?}", message);
40//!         Ok(())
41//!     }
42//!
43//!     async fn on_receive(&self, message: &Message, context: &MiddlewareContext<MyState>) -> CoreResult<()> {
44//!         self.received_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
45//!         println!("Received message: {:?}", message);
46//!         Ok(())
47//!     }
48//! }
49//! ```
50
51use crate::error::CoreResult;
52use crate::traits::AppState;
53use async_trait::async_trait;
54use kanal::AsyncSender;
55use std::sync::Arc;
56use tokio_tungstenite::tungstenite::Message;
57use tracing::{error, warn};
58
59/// Context information passed to middleware during message processing.
60///
61/// This struct provides middleware with access to the application state
62/// and the WebSocket sender channel for sending messages.
63#[derive(Debug)]
64pub struct MiddlewareContext<S: AppState> {
65    /// The shared application state
66    pub state: Arc<S>,
67    /// The WebSocket sender for outgoing messages
68    pub ws_sender: AsyncSender<Message>,
69}
70
71impl<S: AppState> MiddlewareContext<S> {
72    /// Creates a new middleware context.
73    pub fn new(state: Arc<S>, ws_sender: AsyncSender<Message>) -> Self {
74        Self { state, ws_sender }
75    }
76}
77
78/// Trait for implementing WebSocket middleware.
79///
80/// Middleware can observe and react to WebSocket messages being sent and received.
81/// This trait provides hooks for both outgoing and incoming messages.
82///
83/// # Type Parameters
84/// - `S`: The application state type that implements [`AppState`]
85///
86/// # Methods
87/// - [`on_send`]: Called before a message is sent to the WebSocket
88/// - [`on_receive`]: Called after a message is received from the WebSocket
89/// - [`on_connect`]: Called when a WebSocket connection is established
90/// - [`on_disconnect`]: Called when a WebSocket connection is lost
91///
92/// # Error Handling
93/// Middleware should be designed to be resilient. If middleware returns an error,
94/// it will be logged but will not prevent the message from being processed or
95/// other middleware from running.
96#[async_trait]
97pub trait WebSocketMiddleware<S: AppState>: Send + Sync + 'static {
98    /// Called before a message is sent to the WebSocket.
99    ///
100    /// # Arguments
101    /// - `message`: The message that will be sent
102    /// - `context`: Context information including state and sender
103    ///
104    /// # Returns
105    /// - `Ok(())` if the middleware processed successfully
106    /// - `Err(_)` if an error occurred (will be logged but not block processing)
107    async fn on_send(&self, message: &Message, context: &MiddlewareContext<S>) -> CoreResult<()> {
108        // Default implementation does nothing
109        let _ = (message, context);
110        Ok(())
111    }
112
113    /// Called after a message is received from the WebSocket.
114    ///
115    /// # Arguments
116    /// - `message`: The message that was received
117    /// - `context`: Context information including state and sender
118    ///
119    /// # Returns
120    /// - `Ok(())` if the middleware processed successfully
121    /// - `Err(_)` if an error occurred (will be logged but not block processing)
122    async fn on_receive(
123        &self,
124        message: &Message,
125        context: &MiddlewareContext<S>,
126    ) -> CoreResult<()> {
127        // Default implementation does nothing
128        let _ = (message, context);
129        Ok(())
130    }
131
132    /// Called when a WebSocket connection is established.
133    ///
134    /// # Arguments
135    /// - `context`: Context information including state and sender
136    ///
137    /// # Returns
138    /// - `Ok(())` if the middleware processed successfully
139    /// - `Err(_)` if an error occurred (will be logged but not block processing)
140    async fn on_connect(&self, context: &MiddlewareContext<S>) -> CoreResult<()> {
141        // Default implementation does nothing
142        let _ = context;
143        Ok(())
144    }
145
146    /// Called when a WebSocket connection is lost.
147    ///
148    /// # Arguments
149    /// - `context`: Context information including state and sender
150    ///
151    /// # Returns
152    /// - `Ok(())` if the middleware processed successfully
153    /// - `Err(_)` if an error occurred (will be logged but not block processing)
154    async fn on_disconnect(&self, context: &MiddlewareContext<S>) -> CoreResult<()> {
155        // Default implementation does nothing
156        let _ = context;
157        Ok(())
158    }
159
160    /// Called when a connection attempt is made (before actual connection)
161    async fn on_connection_attempt(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
162        Ok(())
163    }
164
165    /// Called when a connection attempt fails
166    async fn on_connection_failure(
167        &self,
168        _context: &MiddlewareContext<S>,
169        _reason: Option<String>,
170    ) -> CoreResult<()> {
171        Ok(())
172    }
173}
174
175/// A composable stack of middleware layers.
176///
177/// This struct holds a collection of middleware that will be executed in order.
178/// Middleware are executed in the order they are added to the stack.
179///
180/// # Example
181/// ```rust,no_run
182/// use binary_options_tools_core_pre::middleware::MiddlewareStack;
183/// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
184/// # use binary_options_tools_core_pre::traits::AppState;
185/// # use async_trait::async_trait;
186/// # #[derive(Debug)]
187/// # struct MyState;
188/// # impl AppState for MyState {
189/// #     fn clear_temporal_data(&self) {}
190/// # }
191/// # struct LoggingMiddleware;
192/// # #[async_trait]
193/// # impl WebSocketMiddleware<MyState> for LoggingMiddleware {}
194/// # struct StatisticsMiddleware;
195/// # impl StatisticsMiddleware {
196/// #     fn new() -> Self { Self }
197/// # }
198/// # #[async_trait]
199/// # impl WebSocketMiddleware<MyState> for StatisticsMiddleware {}
200///
201/// let mut stack = MiddlewareStack::new();
202/// stack.add_layer(Box::new(LoggingMiddleware));
203/// stack.add_layer(Box::new(StatisticsMiddleware::new()));
204/// ```
205pub struct MiddlewareStack<S: AppState> {
206    layers: Vec<Box<dyn WebSocketMiddleware<S> + Send + Sync>>,
207}
208
209impl<S: AppState> MiddlewareStack<S> {
210    /// Creates a new empty middleware stack.
211    pub fn new() -> Self {
212        Self { layers: Vec::new() }
213    }
214
215    /// Adds a middleware layer to the stack.
216    ///
217    /// Middleware will be executed in the order they are added.
218    pub fn add_layer(&mut self, middleware: Box<dyn WebSocketMiddleware<S> + Send + Sync>) {
219        self.layers.push(middleware);
220    }
221
222    /// Executes all middleware for an outgoing message.
223    ///
224    /// # Arguments
225    /// - `message`: The message being sent
226    /// - `context`: Context information
227    ///
228    /// # Behavior
229    /// All middleware will be executed even if some fail. Errors are logged but
230    /// do not prevent other middleware from running.
231    pub async fn on_send(&self, message: &Message, context: &MiddlewareContext<S>) {
232        for (index, middleware) in self.layers.iter().enumerate() {
233            if let Err(e) = middleware.on_send(message, context).await {
234                error!(
235                    target: "Middleware",
236                    "Error in middleware layer {} on_send: {:?}",
237                    index, e
238                );
239            }
240        }
241    }
242
243    /// Executes all middleware for an incoming message.
244    ///
245    /// # Arguments
246    /// - `message`: The message that was received
247    /// - `context`: Context information
248    ///
249    /// # Behavior
250    /// All middleware will be executed even if some fail. Errors are logged but
251    /// do not prevent other middleware from running.
252    pub async fn on_receive(&self, message: &Message, context: &MiddlewareContext<S>) {
253        for (index, middleware) in self.layers.iter().enumerate() {
254            if let Err(e) = middleware.on_receive(message, context).await {
255                error!(
256                    target: "Middleware",
257                    "Error in middleware layer {} on_receive: {:?}",
258                    index, e
259                );
260            }
261        }
262    }
263
264    /// Executes all middleware for connection establishment.
265    ///
266    /// # Arguments
267    /// - `context`: Context information
268    ///
269    /// # Behavior
270    /// All middleware will be executed even if some fail. Errors are logged but
271    /// do not prevent other middleware from running.
272    pub async fn on_connect(&self, context: &MiddlewareContext<S>) {
273        for (index, middleware) in self.layers.iter().enumerate() {
274            if let Err(e) = middleware.on_connect(context).await {
275                error!(
276                    target: "Middleware",
277                    "Error in middleware layer {} on_connect: {:?}",
278                    index, e
279                );
280            }
281        }
282    }
283
284    /// Executes all middleware for connection loss.
285    ///
286    /// # Arguments
287    /// - `context`: Context information
288    ///
289    /// # Behavior
290    /// All middleware will be executed even if some fail. Errors are logged but
291    /// do not prevent other middleware from running.
292    pub async fn on_disconnect(&self, context: &MiddlewareContext<S>) {
293        for (index, middleware) in self.layers.iter().enumerate() {
294            if let Err(e) = middleware.on_disconnect(context).await {
295                warn!(
296                    target: "Middleware",
297                    "Error in middleware layer {} on_disconnect: {:?}",
298                    index, e
299                );
300            }
301        }
302    }
303
304    /// Record a connection attempt across all middleware
305    pub async fn record_connection_attempt(&self, context: &MiddlewareContext<S>) {
306        for (index, middleware) in self.layers.iter().enumerate() {
307            if let Err(e) = middleware.on_connection_attempt(context).await {
308                warn!(
309                    target: "Middleware",
310                    "Error in middleware layer {} on_connection_attempt: {:?}",
311                    index, e
312                );
313            }
314        }
315    }
316
317    /// Record a connection failure across all middleware
318    pub async fn record_connection_failure(
319        &self,
320        context: &MiddlewareContext<S>,
321        reason: Option<String>,
322    ) {
323        for (index, middleware) in self.layers.iter().enumerate() {
324            if let Err(e) = middleware
325                .on_connection_failure(context, reason.clone())
326                .await
327            {
328                warn!(
329                    target: "Middleware",
330                    "Error in middleware layer {} on_connection_failure: {:?}",
331                    index, e
332                );
333            }
334        }
335    }
336
337    /// Returns the number of middleware layers in the stack.
338    pub fn len(&self) -> usize {
339        self.layers.len()
340    }
341
342    /// Returns true if the stack is empty.
343    pub fn is_empty(&self) -> bool {
344        self.layers.is_empty()
345    }
346}
347
348impl<S: AppState> Default for MiddlewareStack<S> {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
354/// A builder for creating middleware stacks in a fluent manner.
355///
356/// This provides a convenient way to chain middleware additions.
357///
358/// # Example
359/// ```rust,no_run
360/// use binary_options_tools_core_pre::middleware::MiddlewareStackBuilder;
361/// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
362/// # use binary_options_tools_core_pre::traits::AppState;
363/// # use async_trait::async_trait;
364/// # #[derive(Debug)]
365/// # struct MyState;
366/// # impl AppState for MyState {
367/// #     fn clear_temporal_data(&self) {}
368/// # }
369/// # struct LoggingMiddleware;
370/// # #[async_trait]
371/// # impl WebSocketMiddleware<MyState> for LoggingMiddleware {}
372/// # struct StatisticsMiddleware;
373/// # impl StatisticsMiddleware {
374/// #     fn new() -> Self { Self }
375/// # }
376/// # #[async_trait]
377/// # impl WebSocketMiddleware<MyState> for StatisticsMiddleware {}
378///
379/// let stack = MiddlewareStackBuilder::new()
380///     .layer(Box::new(LoggingMiddleware))
381///     .layer(Box::new(StatisticsMiddleware::new()))
382///     .build();
383/// ```
384pub struct MiddlewareStackBuilder<S: AppState> {
385    stack: MiddlewareStack<S>,
386}
387
388impl<S: AppState> MiddlewareStackBuilder<S> {
389    /// Creates a new middleware stack builder.
390    pub fn new() -> Self {
391        Self {
392            stack: MiddlewareStack::new(),
393        }
394    }
395
396    /// Adds a middleware layer to the stack.
397    pub fn layer(mut self, middleware: Box<dyn WebSocketMiddleware<S>>) -> Self {
398        self.stack.add_layer(middleware);
399        self
400    }
401
402    /// Builds and returns the middleware stack.
403    pub fn build(self) -> MiddlewareStack<S> {
404        self.stack
405    }
406}
407
408impl<S: AppState> Default for MiddlewareStackBuilder<S> {
409    fn default() -> Self {
410        Self::new()
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Mutex;

    #[derive(Debug)]
    struct TestState;

    #[async_trait]
    impl AppState for TestState {
        async fn clear_temporal_data(&self) {}
    }

    /// Shared, ordered log of `"<layer name>:<hook>"` entries written by every layer.
    type CallLog = Arc<Mutex<Vec<String>>>;

    /// Per-layer hook counters. They are shared with the test body through an
    /// `Arc` because the middleware itself is moved into the stack and cannot
    /// be inspected afterwards.
    #[derive(Default)]
    struct Counters {
        send: AtomicU64,
        receive: AtomicU64,
    }

    struct TestMiddleware {
        name: String,
        counters: Arc<Counters>,
        log: CallLog,
    }

    impl TestMiddleware {
        /// Builds a middleware that appends to `log`, returning it together with
        /// a handle to its counters.
        fn new(name: impl Into<String>, log: &CallLog) -> (Self, Arc<Counters>) {
            let counters = Arc::new(Counters::default());
            let middleware = Self {
                name: name.into(),
                counters: Arc::clone(&counters),
                log: Arc::clone(log),
            };
            (middleware, counters)
        }

        fn record(&self, hook: &str) {
            self.log
                .lock()
                .expect("call log mutex poisoned")
                .push(format!("{}:{}", self.name, hook));
        }
    }

    #[async_trait]
    impl WebSocketMiddleware<TestState> for TestMiddleware {
        async fn on_send(
            &self,
            _message: &Message,
            _context: &MiddlewareContext<TestState>,
        ) -> CoreResult<()> {
            self.counters.send.fetch_add(1, Ordering::SeqCst);
            self.record("send");
            Ok(())
        }

        async fn on_receive(
            &self,
            _message: &Message,
            _context: &MiddlewareContext<TestState>,
        ) -> CoreResult<()> {
            self.counters.receive.fetch_add(1, Ordering::SeqCst);
            self.record("receive");
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        let (sender, _receiver) = kanal::bounded_async(10);
        let context = MiddlewareContext::new(Arc::new(TestState), sender);
        let log = CallLog::default();

        let (middleware1, counters1) = TestMiddleware::new("test1", &log);
        let (middleware2, counters2) = TestMiddleware::new("test2", &log);

        let mut stack = MiddlewareStack::new();
        stack.add_layer(Box::new(middleware1));
        stack.add_layer(Box::new(middleware2));

        assert_eq!(stack.len(), 2);
        assert!(!stack.is_empty());

        let message = Message::text("test");
        stack.on_send(&message, &context).await;
        stack.on_receive(&message, &context).await;
        stack.on_receive(&message, &context).await;

        // Every layer must see every message exactly once per dispatch.
        for counters in [&counters1, &counters2] {
            assert_eq!(counters.send.load(Ordering::SeqCst), 1);
            assert_eq!(counters.receive.load(Ordering::SeqCst), 2);
        }

        // Layers must run in insertion order for every hook.
        assert_eq!(
            *log.lock().unwrap(),
            [
                "test1:send",
                "test2:send",
                "test1:receive",
                "test2:receive",
                "test1:receive",
                "test2:receive",
            ]
        );
    }

    #[tokio::test]
    async fn test_empty_stack_is_noop() {
        let (sender, _receiver) = kanal::bounded_async(10);
        let context = MiddlewareContext::new(Arc::new(TestState), sender);
        let stack = MiddlewareStack::<TestState>::default();

        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());

        // With no layers installed, every dispatch should simply complete.
        let message = Message::text("test");
        stack.on_send(&message, &context).await;
        stack.on_receive(&message, &context).await;
        stack.on_connect(&context).await;
        stack.on_disconnect(&context).await;
        stack.record_connection_attempt(&context).await;
        stack
            .record_connection_failure(&context, Some("refused".to_string()))
            .await;
    }

    #[tokio::test]
    async fn test_default_lifecycle_hooks_are_noops() {
        let (sender, _receiver) = kanal::bounded_async(10);
        let context = MiddlewareContext::new(Arc::new(TestState), sender);
        let log = CallLog::default();
        let (middleware, counters) = TestMiddleware::new("test1", &log);

        let mut stack = MiddlewareStack::new();
        stack.add_layer(Box::new(middleware));

        // TestMiddleware does not override the lifecycle hooks, so the trait's
        // defaults run and must not touch the message hooks.
        stack.on_connect(&context).await;
        stack.on_disconnect(&context).await;
        stack.record_connection_attempt(&context).await;
        stack.record_connection_failure(&context, None).await;

        assert_eq!(counters.send.load(Ordering::SeqCst), 0);
        assert_eq!(counters.receive.load(Ordering::SeqCst), 0);
        assert!(log.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_middleware_stack_builder() {
        let (sender, _receiver) = kanal::bounded_async(10);
        let context = MiddlewareContext::new(Arc::new(TestState), sender);
        let log = CallLog::default();

        let (middleware1, counters1) = TestMiddleware::new("test1", &log);
        let (middleware2, counters2) = TestMiddleware::new("test2", &log);

        let stack = MiddlewareStackBuilder::new()
            .layer(Box::new(middleware1))
            .layer(Box::new(middleware2))
            .build();

        assert_eq!(stack.len(), 2);

        stack.on_send(&Message::text("test"), &context).await;

        assert_eq!(counters1.send.load(Ordering::SeqCst), 1);
        assert_eq!(counters2.send.load(Ordering::SeqCst), 1);
        assert_eq!(*log.lock().unwrap(), ["test1:send", "test2:send"]);
    }
}