binary_options_tools_core_pre/middleware.rs
1//! Middleware system for WebSocket client operations.
2//!
3//! This module provides a composable middleware system inspired by Axum's middleware/layer system.
4//! Middleware can be used to observe, modify, or control the flow of WebSocket messages
5//! being sent and received by the client.
6//!
7//! # Key Components
8//!
9//! - [`WebSocketMiddleware`]: The core trait for implementing middleware
10//! - [`MiddlewareStack`]: A composable stack of middleware layers
//! - [`MiddlewareContext`]: Context passed to middleware with message and client information
//! - [`MiddlewareStackBuilder`]: A fluent builder for assembling a [`MiddlewareStack`]
12//!
13//! # Example Usage
14//!
15//! ```rust,no_run
16//! use binary_options_tools_core_pre::middleware::{WebSocketMiddleware, MiddlewareContext};
17//! use binary_options_tools_core_pre::traits::AppState;
18//! use binary_options_tools_core_pre::error::CoreResult;
19//! use async_trait::async_trait;
20//! use tokio_tungstenite::tungstenite::Message;
21//! use std::sync::Arc;
22//!
23//! #[derive(Debug)]
24//! struct MyState;
25//! impl AppState for MyState {
26//! fn clear_temporal_data(&self) {}
27//! }
28//!
29//! // Example statistics middleware
30//! struct StatisticsMiddleware {
31//! sent_count: Arc<std::sync::atomic::AtomicU64>,
32//! received_count: Arc<std::sync::atomic::AtomicU64>,
33//! }
34//!
35//! #[async_trait]
36//! impl WebSocketMiddleware<MyState> for StatisticsMiddleware {
37//! async fn on_send(&self, message: &Message, context: &MiddlewareContext<MyState>) -> CoreResult<()> {
38//! self.sent_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
39//! println!("Sending message: {:?}", message);
40//! Ok(())
41//! }
42//!
43//! async fn on_receive(&self, message: &Message, context: &MiddlewareContext<MyState>) -> CoreResult<()> {
44//! self.received_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
45//! println!("Received message: {:?}", message);
46//! Ok(())
47//! }
48//! }
49//! ```
50
51use crate::error::CoreResult;
52use crate::traits::AppState;
53use async_trait::async_trait;
54use kanal::AsyncSender;
55use std::sync::Arc;
56use tokio_tungstenite::tungstenite::Message;
57use tracing::{error, warn};
58
59/// Context information passed to middleware during message processing.
60///
61/// This struct provides middleware with access to the application state
62/// and the WebSocket sender channel for sending messages.
63#[derive(Debug)]
64pub struct MiddlewareContext<S: AppState> {
65 /// The shared application state
66 pub state: Arc<S>,
67 /// The WebSocket sender for outgoing messages
68 pub ws_sender: AsyncSender<Message>,
69}
70
71impl<S: AppState> MiddlewareContext<S> {
72 /// Creates a new middleware context.
73 pub fn new(state: Arc<S>, ws_sender: AsyncSender<Message>) -> Self {
74 Self { state, ws_sender }
75 }
76}
77
78/// Trait for implementing WebSocket middleware.
79///
80/// Middleware can observe and react to WebSocket messages being sent and received.
81/// This trait provides hooks for both outgoing and incoming messages.
82///
83/// # Type Parameters
84/// - `S`: The application state type that implements [`AppState`]
85///
86/// # Methods
87/// - [`on_send`]: Called before a message is sent to the WebSocket
88/// - [`on_receive`]: Called after a message is received from the WebSocket
89/// - [`on_connect`]: Called when a WebSocket connection is established
90/// - [`on_disconnect`]: Called when a WebSocket connection is lost
91///
92/// # Error Handling
93/// Middleware should be designed to be resilient. If middleware returns an error,
94/// it will be logged but will not prevent the message from being processed or
95/// other middleware from running.
96#[async_trait]
97pub trait WebSocketMiddleware<S: AppState>: Send + Sync + 'static {
98 /// Called before a message is sent to the WebSocket.
99 ///
100 /// # Arguments
101 /// - `message`: The message that will be sent
102 /// - `context`: Context information including state and sender
103 ///
104 /// # Returns
105 /// - `Ok(())` if the middleware processed successfully
106 /// - `Err(_)` if an error occurred (will be logged but not block processing)
107 async fn on_send(&self, message: &Message, context: &MiddlewareContext<S>) -> CoreResult<()> {
108 // Default implementation does nothing
109 let _ = (message, context);
110 Ok(())
111 }
112
113 /// Called after a message is received from the WebSocket.
114 ///
115 /// # Arguments
116 /// - `message`: The message that was received
117 /// - `context`: Context information including state and sender
118 ///
119 /// # Returns
120 /// - `Ok(())` if the middleware processed successfully
121 /// - `Err(_)` if an error occurred (will be logged but not block processing)
122 async fn on_receive(
123 &self,
124 message: &Message,
125 context: &MiddlewareContext<S>,
126 ) -> CoreResult<()> {
127 // Default implementation does nothing
128 let _ = (message, context);
129 Ok(())
130 }
131
132 /// Called when a WebSocket connection is established.
133 ///
134 /// # Arguments
135 /// - `context`: Context information including state and sender
136 ///
137 /// # Returns
138 /// - `Ok(())` if the middleware processed successfully
139 /// - `Err(_)` if an error occurred (will be logged but not block processing)
140 async fn on_connect(&self, context: &MiddlewareContext<S>) -> CoreResult<()> {
141 // Default implementation does nothing
142 let _ = context;
143 Ok(())
144 }
145
146 /// Called when a WebSocket connection is lost.
147 ///
148 /// # Arguments
149 /// - `context`: Context information including state and sender
150 ///
151 /// # Returns
152 /// - `Ok(())` if the middleware processed successfully
153 /// - `Err(_)` if an error occurred (will be logged but not block processing)
154 async fn on_disconnect(&self, context: &MiddlewareContext<S>) -> CoreResult<()> {
155 // Default implementation does nothing
156 let _ = context;
157 Ok(())
158 }
159
160 /// Called when a connection attempt is made (before actual connection)
161 async fn on_connection_attempt(&self, _context: &MiddlewareContext<S>) -> CoreResult<()> {
162 Ok(())
163 }
164
165 /// Called when a connection attempt fails
166 async fn on_connection_failure(
167 &self,
168 _context: &MiddlewareContext<S>,
169 _reason: Option<String>,
170 ) -> CoreResult<()> {
171 Ok(())
172 }
173}
174
175/// A composable stack of middleware layers.
176///
177/// This struct holds a collection of middleware that will be executed in order.
178/// Middleware are executed in the order they are added to the stack.
179///
180/// # Example
181/// ```rust,no_run
182/// use binary_options_tools_core_pre::middleware::MiddlewareStack;
183/// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
184/// # use binary_options_tools_core_pre::traits::AppState;
185/// # use async_trait::async_trait;
186/// # #[derive(Debug)]
187/// # struct MyState;
188/// # impl AppState for MyState {
189/// # fn clear_temporal_data(&self) {}
190/// # }
191/// # struct LoggingMiddleware;
192/// # #[async_trait]
193/// # impl WebSocketMiddleware<MyState> for LoggingMiddleware {}
194/// # struct StatisticsMiddleware;
195/// # impl StatisticsMiddleware {
196/// # fn new() -> Self { Self }
197/// # }
198/// # #[async_trait]
199/// # impl WebSocketMiddleware<MyState> for StatisticsMiddleware {}
200///
201/// let mut stack = MiddlewareStack::new();
202/// stack.add_layer(Box::new(LoggingMiddleware));
203/// stack.add_layer(Box::new(StatisticsMiddleware::new()));
204/// ```
205pub struct MiddlewareStack<S: AppState> {
206 layers: Vec<Box<dyn WebSocketMiddleware<S> + Send + Sync>>,
207}
208
209impl<S: AppState> MiddlewareStack<S> {
210 /// Creates a new empty middleware stack.
211 pub fn new() -> Self {
212 Self { layers: Vec::new() }
213 }
214
215 /// Adds a middleware layer to the stack.
216 ///
217 /// Middleware will be executed in the order they are added.
218 pub fn add_layer(&mut self, middleware: Box<dyn WebSocketMiddleware<S> + Send + Sync>) {
219 self.layers.push(middleware);
220 }
221
222 /// Executes all middleware for an outgoing message.
223 ///
224 /// # Arguments
225 /// - `message`: The message being sent
226 /// - `context`: Context information
227 ///
228 /// # Behavior
229 /// All middleware will be executed even if some fail. Errors are logged but
230 /// do not prevent other middleware from running.
231 pub async fn on_send(&self, message: &Message, context: &MiddlewareContext<S>) {
232 for (index, middleware) in self.layers.iter().enumerate() {
233 if let Err(e) = middleware.on_send(message, context).await {
234 error!(
235 target: "Middleware",
236 "Error in middleware layer {} on_send: {:?}",
237 index, e
238 );
239 }
240 }
241 }
242
243 /// Executes all middleware for an incoming message.
244 ///
245 /// # Arguments
246 /// - `message`: The message that was received
247 /// - `context`: Context information
248 ///
249 /// # Behavior
250 /// All middleware will be executed even if some fail. Errors are logged but
251 /// do not prevent other middleware from running.
252 pub async fn on_receive(&self, message: &Message, context: &MiddlewareContext<S>) {
253 for (index, middleware) in self.layers.iter().enumerate() {
254 if let Err(e) = middleware.on_receive(message, context).await {
255 error!(
256 target: "Middleware",
257 "Error in middleware layer {} on_receive: {:?}",
258 index, e
259 );
260 }
261 }
262 }
263
264 /// Executes all middleware for connection establishment.
265 ///
266 /// # Arguments
267 /// - `context`: Context information
268 ///
269 /// # Behavior
270 /// All middleware will be executed even if some fail. Errors are logged but
271 /// do not prevent other middleware from running.
272 pub async fn on_connect(&self, context: &MiddlewareContext<S>) {
273 for (index, middleware) in self.layers.iter().enumerate() {
274 if let Err(e) = middleware.on_connect(context).await {
275 error!(
276 target: "Middleware",
277 "Error in middleware layer {} on_connect: {:?}",
278 index, e
279 );
280 }
281 }
282 }
283
284 /// Executes all middleware for connection loss.
285 ///
286 /// # Arguments
287 /// - `context`: Context information
288 ///
289 /// # Behavior
290 /// All middleware will be executed even if some fail. Errors are logged but
291 /// do not prevent other middleware from running.
292 pub async fn on_disconnect(&self, context: &MiddlewareContext<S>) {
293 for (index, middleware) in self.layers.iter().enumerate() {
294 if let Err(e) = middleware.on_disconnect(context).await {
295 warn!(
296 target: "Middleware",
297 "Error in middleware layer {} on_disconnect: {:?}",
298 index, e
299 );
300 }
301 }
302 }
303
304 /// Record a connection attempt across all middleware
305 pub async fn record_connection_attempt(&self, context: &MiddlewareContext<S>) {
306 for (index, middleware) in self.layers.iter().enumerate() {
307 if let Err(e) = middleware.on_connection_attempt(context).await {
308 warn!(
309 target: "Middleware",
310 "Error in middleware layer {} on_connection_attempt: {:?}",
311 index, e
312 );
313 }
314 }
315 }
316
317 /// Record a connection failure across all middleware
318 pub async fn record_connection_failure(
319 &self,
320 context: &MiddlewareContext<S>,
321 reason: Option<String>,
322 ) {
323 for (index, middleware) in self.layers.iter().enumerate() {
324 if let Err(e) = middleware
325 .on_connection_failure(context, reason.clone())
326 .await
327 {
328 warn!(
329 target: "Middleware",
330 "Error in middleware layer {} on_connection_failure: {:?}",
331 index, e
332 );
333 }
334 }
335 }
336
337 /// Returns the number of middleware layers in the stack.
338 pub fn len(&self) -> usize {
339 self.layers.len()
340 }
341
342 /// Returns true if the stack is empty.
343 pub fn is_empty(&self) -> bool {
344 self.layers.is_empty()
345 }
346}
347
348impl<S: AppState> Default for MiddlewareStack<S> {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354/// A builder for creating middleware stacks in a fluent manner.
355///
356/// This provides a convenient way to chain middleware additions.
357///
358/// # Example
359/// ```rust,no_run
360/// use binary_options_tools_core_pre::middleware::MiddlewareStackBuilder;
361/// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
362/// # use binary_options_tools_core_pre::traits::AppState;
363/// # use async_trait::async_trait;
364/// # #[derive(Debug)]
365/// # struct MyState;
366/// # impl AppState for MyState {
367/// # fn clear_temporal_data(&self) {}
368/// # }
369/// # struct LoggingMiddleware;
370/// # #[async_trait]
371/// # impl WebSocketMiddleware<MyState> for LoggingMiddleware {}
372/// # struct StatisticsMiddleware;
373/// # impl StatisticsMiddleware {
374/// # fn new() -> Self { Self }
375/// # }
376/// # #[async_trait]
377/// # impl WebSocketMiddleware<MyState> for StatisticsMiddleware {}
378///
379/// let stack = MiddlewareStackBuilder::new()
380/// .layer(Box::new(LoggingMiddleware))
381/// .layer(Box::new(StatisticsMiddleware::new()))
382/// .build();
383/// ```
384pub struct MiddlewareStackBuilder<S: AppState> {
385 stack: MiddlewareStack<S>,
386}
387
388impl<S: AppState> MiddlewareStackBuilder<S> {
389 /// Creates a new middleware stack builder.
390 pub fn new() -> Self {
391 Self {
392 stack: MiddlewareStack::new(),
393 }
394 }
395
396 /// Adds a middleware layer to the stack.
397 pub fn layer(mut self, middleware: Box<dyn WebSocketMiddleware<S>>) -> Self {
398 self.stack.add_layer(middleware);
399 self
400 }
401
402 /// Builds and returns the middleware stack.
403 pub fn build(self) -> MiddlewareStack<S> {
404 self.stack
405 }
406}
407
408impl<S: AppState> Default for MiddlewareStackBuilder<S> {
409 fn default() -> Self {
410 Self::new()
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Minimal application state used to instantiate the generic types.
    #[derive(Debug)]
    struct TestState;

    #[async_trait]
    impl AppState for TestState {
        async fn clear_temporal_data(&self) {}
    }

    /// Middleware that counts how often its message hooks run.
    ///
    /// The counters sit behind `Arc`s so a test can keep handles to them after
    /// the middleware itself has been boxed and moved into a stack.
    struct TestMiddleware {
        // Kept only so instances can be told apart while debugging.
        #[allow(dead_code)]
        name: String,
        send_count: Arc<AtomicU64>,
        receive_count: Arc<AtomicU64>,
    }

    impl TestMiddleware {
        fn new(name: impl Into<String>) -> Self {
            Self {
                name: name.into(),
                send_count: Arc::new(AtomicU64::new(0)),
                receive_count: Arc::new(AtomicU64::new(0)),
            }
        }

        /// Returns handles to the `(send, receive)` counters.
        fn counters(&self) -> (Arc<AtomicU64>, Arc<AtomicU64>) {
            (
                Arc::clone(&self.send_count),
                Arc::clone(&self.receive_count),
            )
        }
    }

    #[async_trait]
    impl WebSocketMiddleware<TestState> for TestMiddleware {
        async fn on_send(
            &self,
            _message: &Message,
            _context: &MiddlewareContext<TestState>,
        ) -> CoreResult<()> {
            self.send_count.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }

        async fn on_receive(
            &self,
            _message: &Message,
            _context: &MiddlewareContext<TestState>,
        ) -> CoreResult<()> {
            self.receive_count.fetch_add(1, Ordering::Relaxed);
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_middleware_stack() {
        // `_receiver` is kept alive for the duration of the test.
        let (sender, _receiver) = kanal::bounded_async(10);
        let state = Arc::new(TestState);
        let context = MiddlewareContext::new(state, sender);

        let middleware1 = TestMiddleware::new("test1");
        let middleware2 = TestMiddleware::new("test2");
        // Grab the counter handles before the middleware is moved into the stack.
        let counters = [middleware1.counters(), middleware2.counters()];

        let mut stack = MiddlewareStack::new();
        stack.add_layer(Box::new(middleware1));
        stack.add_layer(Box::new(middleware2));

        let message = Message::text("test");

        // on_send must reach every layer exactly once and touch nothing else.
        stack.on_send(&message, &context).await;
        for (send_count, receive_count) in &counters {
            assert_eq!(send_count.load(Ordering::Relaxed), 1);
            assert_eq!(receive_count.load(Ordering::Relaxed), 0);
        }

        // on_receive must reach every layer exactly once as well.
        stack.on_receive(&message, &context).await;
        for (send_count, receive_count) in &counters {
            assert_eq!(send_count.load(Ordering::Relaxed), 1);
            assert_eq!(receive_count.load(Ordering::Relaxed), 1);
        }

        assert_eq!(stack.len(), 2);
        assert!(!stack.is_empty());
    }

    #[test]
    fn test_empty_stack() {
        let stack = MiddlewareStack::<TestState>::default();

        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[tokio::test]
    async fn test_middleware_stack_builder() {
        let stack = MiddlewareStackBuilder::new()
            .layer(Box::new(TestMiddleware::new("test1")))
            .layer(Box::new(TestMiddleware::new("test2")))
            .build();

        assert_eq!(stack.len(), 2);
        assert!(!stack.is_empty());
    }
}