Skip to main content

binary_options_tools_core_pre/
client.rs

1use crate::callback::ConnectionCallback;
2use crate::connector::Connector;
3use crate::error::CoreResult;
4use crate::middleware::{MiddlewareContext, MiddlewareStack};
5use crate::signals::Signals;
6use crate::traits::{ApiModule, AppState, ReconnectCallback, Rule, RunnerCommand};
7use futures_util::{stream::StreamExt, SinkExt};
8use kanal::{AsyncReceiver, AsyncSender};
9use rand::Rng;
10use std::any::{Any, TypeId};
11use std::collections::HashMap;
12use std::future::Future;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use tokio::task::JoinSet;
16use tokio_tungstenite::tungstenite::Message;
17use tracing::{debug, error, info, warn};
18
19/// A lightweight handler is a function that can process messages without being tied to a specific module.
20/// It can be used for quick, non-blocking operations that don't require a full module lifecycle
21/// or state management.
22/// It takes a message, the shared application state, and a sender for outgoing messages.
23/// It returns a future that resolves to a `CoreResult<()>`, indicating success or failure.
24/// This is useful for handling messages that need to be processed quickly or in a lightweight manner,
25/// such as logging, simple transformations, or forwarding messages to other parts of the system.
26pub type LightweightHandler<S> = Box<
27    dyn Fn(
28            Arc<Message>,
29            Arc<S>,
30            &AsyncSender<Message>,
31        ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
32        + Send
33        + Sync,
34>;
35
36type RuleTp = (Box<dyn Rule + Send + Sync>, AsyncSender<Arc<Message>>);
37
38// --- Internal Router ---
39pub struct Router<S: AppState> {
40    pub(crate) state: Arc<S>,
41    pub(crate) module_rules: Vec<RuleTp>,
42    pub(crate) module_set: JoinSet<()>,
43    pub(crate) lightweight_rules: Vec<RuleTp>,
44    pub(crate) lightweight_handlers: Vec<LightweightHandler<S>>,
45    pub(crate) lightweight_set: JoinSet<()>,
46    pub(crate) middleware_stack: MiddlewareStack<S>,
47}
48
49impl<S: AppState> Router<S> {
50    pub fn new(state: Arc<S>) -> Self {
51        Self {
52            state,
53            module_rules: Vec::new(),
54            module_set: JoinSet::new(),
55            lightweight_rules: Vec::new(),
56            lightweight_handlers: Vec::new(),
57            lightweight_set: JoinSet::new(),
58            middleware_stack: MiddlewareStack::new(),
59        }
60    }
61
62    pub fn spawn_module<F: Future<Output = ()> + Send + 'static>(&mut self, task: F) {
63        self.module_set.spawn(task);
64    }
65
66    pub fn add_module_rule(
67        &mut self,
68        rule: Box<dyn Rule + Send + Sync>,
69        sender: AsyncSender<Arc<Message>>,
70    ) {
71        self.module_rules.push((rule, sender));
72    }
73
74    pub fn add_lightweight_rule(
75        &mut self,
76        rule: Box<dyn Rule + Send + Sync>,
77        sender: AsyncSender<Arc<Message>>,
78    ) {
79        self.lightweight_rules.push((rule, sender));
80    }
81
82    pub fn add_lightweight_handler(&mut self, handler: LightweightHandler<S>) {
83        self.lightweight_handlers.push(handler);
84    }
85
86    pub fn spawn_lightweight_module<F: Future<Output = ()> + Send + 'static>(&mut self, task: F) {
87        self.lightweight_set.spawn(task);
88    }
89
90    /// Routes incoming WebSocket messages to appropriate handlers and modules.
91    ///
92    /// This method implements the core message routing logic with middleware integration:
93    /// 1. **Middleware on_receive**: Called first for all incoming messages
94    /// 2. **Lightweight handlers**: Processed for quick operations
95    /// 3. **Lightweight modules**: Routed based on routing rules
96    /// 4. **API modules**: Routed to matching modules
97    ///
98    /// # Middleware Integration
99    /// The `on_receive` middleware hook is called at the beginning of message processing,
100    /// allowing middleware to observe, log, or transform incoming messages before they
101    /// reach the application logic.
102    ///
103    /// # Arguments
104    /// - `message`: The incoming WebSocket message wrapped in Arc for sharing
105    /// - `sender`: Channel for sending outgoing messages
106    async fn route(&self, message: Arc<Message>, sender: &AsyncSender<Message>) -> CoreResult<()> {
107        // Route to all lightweight handlers first
108        debug!(target: "Router", "Routing message: {message:?}");
109
110        // Create middleware context
111        let middleware_context = MiddlewareContext::new(Arc::clone(&self.state), sender.clone());
112
113        // 🎯 MIDDLEWARE HOOK: on_receive - called for ALL incoming messages
114        // This is where middleware can observe, log, or process incoming messages
115        self.middleware_stack
116            .on_receive(&message, &middleware_context)
117            .await;
118
119        for handler in &self.lightweight_handlers {
120            if let Err(err) = handler(Arc::clone(&message), Arc::clone(&self.state), sender).await {
121                error!(target: "Router",
122                     "Lightweight handler error: {err:#?}"
123                );
124            }
125        }
126        for (rule, sender) in &self.lightweight_rules {
127            // If the rule matches, send the message to the lightweight handler
128            if rule.call(&message) && sender.send(message.clone()).await.is_err() {
129                error!(target: "Router", "A lightweight module has shut down and its channel is closed.");
130            }
131        }
132
133        // Route to the first matching API module
134        for (rule, sender) in &self.module_rules {
135            if rule.call(&message) && sender.send(message.clone()).await.is_err() {
136                error!(target: "Router", "A module has shut down and its channel is closed.");
137            }
138        }
139        Ok(())
140    }
141}
142
143// --- The Public-Facing Handle ---
144#[derive(Debug)]
145pub struct Client<S: AppState> {
146    pub signal: Signals,
147    /// The shared application state, which can be used by modules and handlers.
148    pub state: Arc<S>,
149    pub module_handles: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
150    pub to_ws_sender: AsyncSender<Message>,
151
152    runner_command_tx: AsyncSender<RunnerCommand>,
153}
154
155impl<S: AppState> Clone for Client<S> {
156    fn clone(&self) -> Self {
157        Self {
158            signal: self.signal.clone(),
159            state: Arc::clone(&self.state),
160            module_handles: Arc::clone(&self.module_handles),
161            runner_command_tx: self.runner_command_tx.clone(),
162            to_ws_sender: self.to_ws_sender.clone(),
163        }
164    }
165}
166
167impl<S: AppState> Client<S> {
168    // In a real implementation, this would be created by the builder.
169    pub fn new(
170        signal: Signals,
171        runner_command_tx: AsyncSender<RunnerCommand>,
172        state: Arc<S>,
173        sender: AsyncSender<Message>,
174    ) -> Self {
175        Self {
176            signal,
177            state,
178            module_handles: Arc::new(RwLock::new(HashMap::new())),
179            runner_command_tx,
180            to_ws_sender: sender,
181        }
182    }
183
184    /// Waits until the client is connected to the WebSocket server.
185    /// This method will block until the connection is established.
186    /// It is useful for ensuring that the client is ready to send and receive messages.
187    pub async fn wait_connected(&self) {
188        self.signal.wait_connected().await
189    }
190
191    /// Checks if the client is connected to the WebSocket server.
192    pub fn is_connected(&self) -> bool {
193        self.signal.is_connected()
194    }
195
196    /// Retrieves a clonable, typed handle to an already-registered module.
197    pub async fn get_handle<M: ApiModule<S>>(&self) -> Option<M::Handle> {
198        let handles = self.module_handles.read().await;
199        handles
200            .get(&TypeId::of::<M>())
201            .and_then(|boxed_handle| boxed_handle.downcast_ref::<M::Handle>())
202            .cloned()
203    }
204
205    /// Commands the runner to disconnect, clear state, and perform a "hard" reconnect.
206    pub async fn disconnect(&self) -> CoreResult<()> {
207        Ok(self
208            .runner_command_tx
209            .send(RunnerCommand::Disconnect)
210            .await?)
211    }
212
213    /// Commands the runner to disconnect, and perform a "soft" reconnect.
214    pub async fn reconnect(&self) -> CoreResult<()> {
215        Ok(self
216            .runner_command_tx
217            .send(RunnerCommand::Reconnect)
218            .await?)
219    }
220
221    /// Commands the runner to shutdown, this action is final as the runner and client will stop working and will be dropped.
222    pub async fn shutdown(self) -> CoreResult<()> {
223        self.runner_command_tx
224            .send(RunnerCommand::Shutdown)
225            .await
226            .inspect_err(|e| {
227                error!(target: "Client", "Failed to send shutdown command: {e}");
228            })?;
229        drop(self);
230        info!(target: "Client", "Runner shutdown command sent.");
231        Ok(())
232    }
233
234    /// Commands the runner to shutdown without consuming the client.
235    pub async fn shutdown_ref(&self) -> CoreResult<()> {
236        self.runner_command_tx
237            .send(RunnerCommand::Shutdown)
238            .await
239            .inspect_err(|e| {
240                error!(target: "Client", "Failed to send shutdown command: {e}");
241            })?;
242        info!(target: "Client", "Runner shutdown command sent (via ref).");
243        Ok(())
244    }
245
246    /// Send a message to the WebSocket
247    pub async fn send_message(&self, message: Message) -> CoreResult<()> {
248        self.to_ws_sender.send(message).await.inspect_err(|e| {
249            error!(target: "Client", "Failed to send message to WebSocket: {e}");
250        })?;
251        Ok(())
252    }
253
254    /// Send a text message to the WebSocket
255    pub async fn send_text(&self, text: String) -> CoreResult<()> {
256        self.send_message(Message::text(text)).await
257    }
258
259    /// Send a binary message to the WebSocket
260    pub async fn send_binary(&self, data: Vec<u8>) -> CoreResult<()> {
261        self.send_message(Message::binary(data)).await
262    }
263}
264
265// --- The Background Worker ---
266/// Implementation of the `ClientRunner` for managing WebSocket client connections and session lifecycle.
267pub struct ClientRunner<S: AppState> {
268    /// Notify the client of connection status changes.
269    pub(crate) signal: Signals,
270    pub(crate) connector: Arc<dyn Connector<S>>,
271    pub(crate) router: Arc<Router<S>>,
272    pub(crate) state: Arc<S>,
273    // Flag to determine if the next connection is a fresh one.
274    pub(crate) is_hard_disconnect: bool,
275    // Flag to terminate the main run loop.
276    pub(crate) shutdown_requested: bool,
277
278    pub(crate) connection_callback: ConnectionCallback<S>,
279    pub(crate) to_ws_sender: AsyncSender<Message>,
280    pub(crate) to_ws_receiver: AsyncReceiver<Message>,
281    pub(crate) runner_command_rx: AsyncReceiver<RunnerCommand>,
282
283    // Track reconnection attempts for exponential backoff
284    pub(crate) reconnect_attempts: u32,
285
286    pub(crate) max_allowed_loops: u32,
287    pub(crate) reconnect_delay: std::time::Duration,
288}
289
290impl<S: AppState> ClientRunner<S> {
291    /// Main client runner loop that manages WebSocket connections and message processing.
292    pub async fn run(&mut self) {
293        // TODO: Add a way to disconnect and keep the connection closed intill specified otherwhise
294        // The outermost loop runs until a shutdown is commanded.
295        while !self.shutdown_requested {
296            // Execute middleware on_connect hook
297            let middleware_context =
298                MiddlewareContext::new(Arc::clone(&self.state), self.to_ws_sender.clone());
299            debug!(target: "Runner", "Starting connection cycle...");
300
301            // Call middleware to record connection attempt
302            self.router
303                .middleware_stack
304                .record_connection_attempt(&middleware_context)
305                .await;
306
307            // Use the correct connection method based on the flag.
308            let stream_result = if self.is_hard_disconnect {
309                self.connector.connect(self.state.clone()).await
310            } else {
311                self.connector.reconnect(self.state.clone()).await
312            };
313
314            let ws_stream = match stream_result {
315                Ok(stream) => stream,
316                Err(e) => {
317                    self.reconnect_attempts += 1;
318
319                    if self.max_allowed_loops > 0
320                        && self.reconnect_attempts >= self.max_allowed_loops
321                    {
322                        error!(target: "Runner", "Maximum reconnection attempts ({}) reached. Shutting down.", self.max_allowed_loops);
323                        self.shutdown_requested = true;
324                        break;
325                    }
326
327                    // Use configured reconnect_delay with exponential backoff if it's > 0, else use a default
328                    let base_delay = if self.reconnect_delay.as_secs() > 0 {
329                        self.reconnect_delay.as_secs()
330                    } else {
331                        5
332                    };
333
334                    let delay_secs = std::cmp::min(
335                        base_delay
336                            .saturating_mul(2u64.saturating_pow(self.reconnect_attempts.min(10))),
337                        300,
338                    );
339                    // Add jitter
340                    let jitter = rand::rng().random_range(0.8..1.2);
341                    let delay = std::time::Duration::from_secs_f64(delay_secs as f64 * jitter);
342
343                    warn!(target: "Runner", "Connection failed (attempt {}/{}): {e}. Retrying in {:?}...",
344                        self.reconnect_attempts,
345                        if self.max_allowed_loops > 0 { self.max_allowed_loops.to_string() } else { "∞".to_string() },
346                        delay);
347                    tokio::time::sleep(delay).await;
348                    // On failure, the next attempt is a reconnect, not a hard connect.
349                    self.is_hard_disconnect = false;
350                    continue; // Restart the connection cycle.
351                }
352            };
353
354            // 🎯 MIDDLEWARE HOOK: on_connect - called after successful connection
355            // Location: After WebSocket connection is established
356            debug!(target: "Runner", "Connection successful.");
357            self.signal.set_connected();
358
359            // Track connection start time to reset attempts only if stable
360            let connection_start = std::time::Instant::now();
361            let mut attempts_reset = false;
362            self.router
363                .middleware_stack
364                .on_connect(&middleware_context)
365                .await;
366
367            // Execute the correct callback.
368            if self.is_hard_disconnect {
369                debug!(target: "Runner", "Executing on_connect callback.");
370                // Handle any error from on_connect
371                if let Err(err) =
372                    (self.connection_callback.on_connect)(self.state.clone(), &self.to_ws_sender)
373                        .await
374                {
375                    warn!(
376                        target: "Runner",
377                        "on_connect callback failed: {err:#?}"
378                    );
379                }
380            } else {
381                debug!(target: "Runner", "Executing on_reconnect callback.");
382                // Handle any error from on_reconnect
383                if let Err(err) = self
384                    .connection_callback
385                    .on_reconnect
386                    .call(self.state.clone(), &self.to_ws_sender)
387                    .await
388                {
389                    warn!(
390                        target: "Runner",
391                        "on_reconnect callback failed: {err:#?}"
392                    );
393                }
394            } // A successful connection means the next one is a "reconnect" unless told otherwise.
395            self.is_hard_disconnect = false;
396
397            let (mut ws_writer, mut ws_reader) = ws_stream.split();
398
399            // 🎯 MIDDLEWARE HOOK: on_send - called in writer task for outgoing messages
400            let writer_task = tokio::spawn({
401                let to_ws_rx = self.to_ws_receiver.clone();
402                let router = Arc::clone(&self.router);
403                let state = Arc::clone(&self.state);
404                let to_ws_sender = self.to_ws_sender.clone();
405                async move {
406                    let middleware_context = MiddlewareContext::new(state, to_ws_sender);
407                    while let Ok(msg) = to_ws_rx.recv().await {
408                        // Execute middleware on_send hook
409                        router
410                            .middleware_stack
411                            .on_send(&msg, &middleware_context)
412                            .await;
413                        if ws_writer.send(msg).await.is_err() {
414                            error!(target: "Runner", "WebSocket writer task failed to send message.");
415                            break;
416                        }
417                    }
418                }
419            });
420
421            let reader_task = tokio::spawn({
422                let to_ws_sender = self.to_ws_sender.clone();
423                let router = Arc::clone(&self.router); // Use Arc for sharing
424                async move {
425                    while let Some(Ok(msg)) = ws_reader.next().await {
426                        if let Err(e) = router.route(Arc::new(msg), &to_ws_sender).await {
427                            warn!(target: "Router", "Error routing message: {:?}", e);
428                        }
429                    }
430                }
431            });
432
433            // --- Active Session Loop ---
434            // This loop runs as long as the connection is stable or no commands are received.
435            let mut writer_task_opt = Some(writer_task);
436            let mut reader_task_opt: Option<tokio::task::JoinHandle<()>> = Some(reader_task);
437
438            let mut session_active = true;
439
440            // Temporal timer so we i can check the duration of a connection
441            // let temporal_timer = std::time::Instant::now();
442            while session_active {
443                // Reset reconnect attempts if connection has been stable for > 10s
444                if !attempts_reset
445                    && connection_start.elapsed() > std::time::Duration::from_secs(10)
446                {
447                    self.reconnect_attempts = 0;
448                    attempts_reset = true;
449                    debug!(target: "Runner", "Connection stable, resetting reconnect attempts.");
450                }
451
452                tokio::select! {
453                    biased;
454
455                    Ok(cmd) = self.runner_command_rx.recv() => {
456                        match cmd {
457                            RunnerCommand::Disconnect => {
458                                // 🎯 MIDDLEWARE HOOK: on_disconnect - manual disconnect
459
460                                debug!(target: "Runner", "Disconnect command received.");
461
462                                // Execute middleware on_disconnect hook
463                                let middleware_context = MiddlewareContext::new(Arc::clone(&self.state), self.to_ws_sender.clone());
464                                self.router.middleware_stack.on_disconnect(&middleware_context).await;
465
466                                // Call connector's disconnect method to properly close the connection
467                                if let Err(e) = self.connector.disconnect().await {
468                                    warn!(target: "Runner", "Connector disconnect failed: {e}");
469                                }
470
471
472                                self.state.clear_temporal_data().await;
473                                self.is_hard_disconnect = true;
474                                if let Some(writer_task) = writer_task_opt.take() {
475                                    writer_task.abort();
476                                }
477                                if let Some(reader_task) = reader_task_opt.take() {
478                                    reader_task.abort();
479                                }
480                                self.signal.set_disconnected();
481                                session_active = false;
482                            },
483                            RunnerCommand::Shutdown => {
484                                // 🎯 MIDDLEWARE HOOK: on_disconnect - shutdown
485
486                                debug!(target: "Runner", "Shutdown command received.");
487
488                                // Execute middleware on_disconnect hook
489                                let middleware_context = MiddlewareContext::new(Arc::clone(&self.state), self.to_ws_sender.clone());
490                                self.router.middleware_stack.on_disconnect(&middleware_context).await;
491
492                                // Call connector's disconnect method to properly close the connection
493                                if let Err(e) = self.connector.disconnect().await {
494                                    warn!(target: "Runner", "Connector disconnect failed: {e}");
495                                }
496
497                                self.shutdown_requested = true;
498                                if let Some(writer_task) = writer_task_opt.take() {
499                                    writer_task.abort();
500                                }
501                                if let Some(reader_task) = reader_task_opt.take() {
502                                    reader_task.abort();
503                                }
504                                self.signal.set_disconnected();
505                                session_active = false;
506                            }
507                            _ => {}
508                        }
509                    },
510                    _ = async {
511                        if let Some(reader_task) = &mut reader_task_opt {
512                            let _ = reader_task.await;
513                        }
514                    } => {
515                        // 🎯 MIDDLEWARE HOOK: on_disconnect - unexpected connection loss
516                        warn!(target: "Runner", "Connection lost unexpectedly.");
517
518                        // Execute middleware on_disconnect hook
519                        let middleware_context = MiddlewareContext::new(Arc::clone(&self.state), self.to_ws_sender.clone());
520                        self.router.middleware_stack.on_disconnect(&middleware_context).await;
521
522                        if let Some(writer_task) = writer_task_opt.take() {
523                            writer_task.abort();
524                        }
525                        if let Some(reader_task) = reader_task_opt.take() {
526                            // Already finished, but abort for completeness
527                            reader_task.abort();
528                        }
529                        self.signal.set_disconnected();
530                        session_active = false;
531                        // panic!("Connection lost unexpectedly, exiting session loop. Duration: {:?}", temporal_timer.elapsed());
532                    }
533                }
534            }
535        }
536
537        debug!(target: "Runner", "Shutdown complete.");
538    }
539}
540
541// A proper builder would be used here to configure and create the Client and ClientRunner