Skip to main content

binary_options_tools_core_pre/
builder.rs

1// src/builder.rs
2
3use kanal::{bounded_async, AsyncSender};
4use std::any::type_name;
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::task::JoinSet;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{error, info, warn};
13
14use crate::callback::{ConnectionCallback, ReconnectCallbackStack};
15use crate::client::{Client, ClientRunner, LightweightHandler, Router};
16use crate::connector::Connector;
17use crate::error::{CoreError, CoreResult};
18use crate::middleware::{MiddlewareStack, WebSocketMiddleware};
19use crate::signals::Signals;
20use crate::traits::{ApiModule, AppState, LightweightModule, ReconnectCallback, RunnerCommand};
21
22type HandlerMap = Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>;
23type HandlersFn<S> = Box<
24    dyn FnOnce(
25            &mut Router<S>,
26            &mut JoinSet<()>,
27            HandlerMap,
28            AsyncSender<Message>,
29            AsyncSender<RunnerCommand>,
30            &mut ReconnectCallbackStack<S>,
31        ) + Send
32        + Sync,
33>;
34
35type LightweightHandlersFn<S> =
36    Box<dyn FnOnce(&mut Router<S>, AsyncSender<Message>, AsyncSender<RunnerCommand>) + Send + Sync>;
37
38pub struct ClientBuilder<S: AppState> {
39    state: Arc<S>,
40    connector: Arc<dyn Connector<S>>,
41    connection_callback: ConnectionCallback<S>,
42    lightweight_handlers: Vec<LightweightHandler<S>>,
43    // Stores functions that know how to create and register each module.
44    module_factories: Vec<HandlersFn<S>>,
45    lightweight_factories: Vec<LightweightHandlersFn<S>>,
46    // Middleware stack for WebSocket message processing
47    middleware_stack: MiddlewareStack<S>,
48
49    max_allowed_loops: u32,
50    reconnect_delay: Duration,
51}
52
53impl<S: AppState> ClientBuilder<S> {
54    /// Creates a new builder with the essential components.
55    pub fn new(connector: impl Connector<S> + 'static, state: S) -> Self {
56        Self {
57            state: Arc::new(state),
58            connector: Arc::new(connector),
59            // Provide empty default callbacks.
60            connection_callback: ConnectionCallback {
61                on_connect: Box::new(|_, _| Box::pin(async { Ok(()) })),
62                on_reconnect: ReconnectCallbackStack::default(),
63            },
64            lightweight_handlers: Vec::new(),
65            module_factories: Vec::new(),
66            lightweight_factories: Vec::new(),
67            middleware_stack: MiddlewareStack::new(),
68            max_allowed_loops: 0,
69            reconnect_delay: Duration::from_secs(5),
70        }
71    }
72
73    /// Sets the callback for the initial connection.
74    pub fn on_connect(
75        mut self,
76        callback: impl Fn(
77                Arc<S>,
78                &AsyncSender<Message>,
79            ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
80            + Send
81            + Sync
82            + 'static,
83    ) -> Self {
84        self.connection_callback.on_connect = Box::new(callback);
85        self
86    }
87
88    /// Sets the callback for subsequent reconnections.
89    pub fn on_reconnect(
90        mut self,
91        callback: Box<dyn ReconnectCallback<S> + Send + Sync + 'static>,
92    ) -> Self {
93        self.connection_callback.on_reconnect.add_layer(callback);
94        self
95    }
96
97    /// Adds a lightweight handler that receives all messages.
98    pub fn with_lightweight_handler(
99        mut self,
100        handler: impl Fn(
101                Arc<Message>,
102                Arc<S>,
103                &AsyncSender<Message>,
104            ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
105            + Send
106            + Sync
107            + 'static,
108    ) -> Self {
109        self.lightweight_handlers.push(Box::new(handler));
110        self
111    }
112
113    /// Registers a lightweight module
114    pub fn with_lightweight_module<M: LightweightModule<S>>(mut self) -> Self {
115        let factory = |router: &mut Router<S>,
116                       to_ws_tx: AsyncSender<Message>,
117                       runner_tx: AsyncSender<RunnerCommand>| {
118            let (msg_tx, msg_rx) = bounded_async(256);
119
120            let state = router.state.clone();
121            // Spawn the lightweight module task.
122            router.spawn_lightweight_module(async move {
123                    let mut failures = 0;
124                    // make the first timestamp far enough in the past
125                    let mut last_fail = Instant::now()
126                        .checked_sub(Duration::from_secs(3600))
127                        .unwrap_or(Instant::now());
128
129                    loop {
130                        // create the module once
131                        let mut module = M::new(
132                            state.clone(),
133                            to_ws_tx.clone(),
134                            msg_rx.clone(),
135                            runner_tx.clone(),
136                        );
137                        match module.run().await {
138                            Ok(()) => {
139                                info!(target: "LightweightModule", "[Lightweight {}] exited cleanly", type_name::<M>());
140                                break;
141                            }
142                            Err(e) => {
143                                let now = Instant::now();
144                                if now.duration_since(last_fail) < Duration::from_secs(30) {
145                                    failures += 1;
146                                } else {
147                                    failures = 1;
148                                }
149                                last_fail = now;
150
151                                if failures >= 5 {
152                                    error!(target: "LightweightModule",
153                                    "[Lightweight {}] failing {}× rapidly: {:?}, backing off 60s",
154                                    type_name::<M>(),
155                                    failures,
156                                    e
157                                );
158                                    tokio::time::sleep(Duration::from_secs(60)).await;
159                                } else {
160                                    warn!(target: "LightweightModule", "[Lightweight {}] error: {:?}", type_name::<M>(), e);
161                                    tokio::time::sleep(Duration::from_secs(1)).await;
162                                }
163                            }
164                        }
165                    }
166                });
167            router.add_lightweight_rule(M::rule(), msg_tx);
168        };
169
170        self.lightweight_factories.push(Box::new(factory));
171        self
172    }
173
174    /// Registers a full API module with the client.
175    pub fn with_module<M: ApiModule<S>>(mut self) -> Self {
176        let factory =
177            |router: &mut Router<S>,
178             join_set: &mut JoinSet<()>,
179             handles: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
180             to_ws_tx: AsyncSender<Message>,
181             runner_tx: AsyncSender<RunnerCommand>,
182             reconnect_callback_stack: &mut ReconnectCallbackStack<S>| {
183                let (cmd_tx, cmd_rx) = bounded_async(32);
184                let (cmd_ret_tx, cmd_ret_rx) = bounded_async(32);
185                let (msg_tx, msg_rx) = bounded_async(256);
186
187                let state = router.state.clone();
188                let handle = M::create_handle(cmd_tx, cmd_ret_rx);
189
190                // Must spawn this write to avoid blocking if called from an async context.
191                join_set.spawn(async move {
192                    handles
193                        .write()
194                        .await
195                        .insert(TypeId::of::<M>(), Box::new(handle));
196                });
197
198                match M::callback(
199                    state.clone(),
200                    cmd_rx.clone(),
201                    cmd_ret_tx.clone(),
202                    msg_rx.clone(),
203                    to_ws_tx.clone(),
204                ) {
205                    Ok(Some(callback)) => {
206                        reconnect_callback_stack.add_layer(callback);
207                    }
208                    Ok(None) => {
209                        // No callback needed, continue.
210                    }
211                    Err(e) => {
212                        error!(target: "ApiModule", "Failed to get callback for module {}: {:?}", type_name::<M>(), e);
213                    }
214                }
215                let state_clone = state.clone();
216                router.spawn_module(async move {
217                let mut failures = 0;
218                let mut last_fail = Instant::now()
219                    .checked_sub(Duration::from_secs(3600))
220                    .unwrap_or(Instant::now());
221                loop {
222                    let mut module = M::new(
223                        state.clone(),
224                        cmd_rx.clone(),
225                        cmd_ret_tx.clone(),
226                        msg_rx.clone(),
227                        to_ws_tx.clone(),
228                        runner_tx.clone(),
229                    );
230                    match module.run().await {
231                        Ok(_) => {
232                            info!(target: "ApiModule", "[Module {}] exited cleanly", type_name::<M>());
233                            break;
234                        }
235                        Err(e) => {
236                            let now = Instant::now();
237                            if now.duration_since(last_fail) < Duration::from_secs(30) {
238                                failures += 1;
239                            } else {
240                                failures = 1;
241                            }
242                            last_fail = now;
243
244                            let wait = if failures >= 5 {
245                                error!(target: "ApiModule", "Module [{}] failed too many times, check module integrity: {:?}", type_name::<M>(), e);
246                                60
247                            } else {
248                                warn!(target: "ApiModule", "[{}] err={:?}", type_name::<M>(), e);
249                                1
250                            };
251                            tokio::time::sleep(Duration::from_secs(wait)).await;
252                        }
253                    }
254                }
255            });
256
257                router.add_module_rule(M::rule(state_clone), msg_tx);
258            };
259
260        self.module_factories.push(Box::new(factory));
261        self
262    }
263
264    /// Adds a middleware layer to the client.
265    ///
266    /// Middleware will be executed in the order they are added.
267    /// They will be called for all WebSocket messages sent and received.
268    ///
269    /// # Example
270    /// ```rust,no_run
271    /// # use binary_options_tools_core_pre::builder::ClientBuilder;
272    /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
273    /// # use binary_options_tools_core_pre::traits::AppState;
274    /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
275    /// # use async_trait::async_trait;
276    /// # use std::sync::Arc;
277    /// # #[derive(Debug)]
278    /// # struct MyState;
279    /// # impl AppState for MyState {
280    /// #     fn clear_temporal_data(&self) {}
281    /// # }
282    /// # struct MyConnector;
283    /// # #[async_trait]
284    /// # impl Connector<MyState> for MyConnector {
285    /// #     async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
286    /// #         unimplemented!()
287    /// #     }
288    /// #     async fn disconnect(&self) -> ConnectorResult<()> {
289    /// #         unimplemented!()
290    /// #     }
291    /// # }
292    /// # struct MyMiddleware;
293    /// # #[async_trait]
294    /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
295    /// let builder = ClientBuilder::new(MyConnector, MyState)
296    ///     .with_middleware(Box::new(MyMiddleware));
297    /// ```
298    pub fn with_middleware(mut self, middleware: Box<dyn WebSocketMiddleware<S>>) -> Self {
299        self.middleware_stack.add_layer(middleware);
300        self
301    }
302
303    /// Adds multiple middleware layers at once.
304    ///
305    /// This is a convenience method for adding multiple middleware layers.
306    ///
307    /// # Example
308    /// ```rust,no_run
309    /// # use binary_options_tools_core_pre::builder::ClientBuilder;
310    /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
311    /// # use binary_options_tools_core_pre::traits::AppState;
312    /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
313    /// # use async_trait::async_trait;
314    /// # use std::sync::Arc;
315    /// # #[derive(Debug)]
316    /// # struct MyState;
317    /// # impl AppState for MyState {
318    /// #     fn clear_temporal_data(&self) {}
319    /// # }
320    /// # struct MyConnector;
321    /// # #[async_trait]
322    /// # impl Connector<MyState> for MyConnector {
323    /// #     async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
324    /// #         unimplemented!()
325    /// #     }
326    /// #     async fn disconnect(&self) -> ConnectorResult<()> {
327    /// #         unimplemented!()
328    /// #     }
329    /// # }
330    /// # struct MyMiddleware;
331    /// # #[async_trait]
332    /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
333    /// let builder = ClientBuilder::new(MyConnector, MyState)
334    ///     .with_middleware_layers(vec![
335    ///         Box::new(MyMiddleware),
336    ///         Box::new(MyMiddleware),
337    ///     ]);
338    /// ```
339    pub fn with_middleware_layers(
340        mut self,
341        middleware: Vec<Box<dyn WebSocketMiddleware<S>>>,
342    ) -> Self {
343        for layer in middleware {
344            self.middleware_stack.add_layer(layer);
345        }
346        self
347    }
348
349    /// Applies a middleware stack to the client.
350    ///
351    /// This replaces any existing middleware with the provided stack.
352    ///
353    /// # Example
354    /// ```rust,no_run
355    /// # use binary_options_tools_core_pre::builder::ClientBuilder;
356    /// # use binary_options_tools_core_pre::middleware::{MiddlewareStack, WebSocketMiddleware};
357    /// # use binary_options_tools_core_pre::traits::AppState;
358    /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
359    /// # use async_trait::async_trait;
360    /// # use std::sync::Arc;
361    /// # #[derive(Debug)]
362    /// # struct MyState;
363    /// # impl AppState for MyState {
364    /// #     fn clear_temporal_data(&self) {}
365    /// # }
366    /// # struct MyConnector;
367    /// # #[async_trait]
368    /// # impl Connector<MyState> for MyConnector {
369    /// #     async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
370    /// #         unimplemented!()
371    /// #     }
372    /// #     async fn disconnect(&self) -> ConnectorResult<()> {
373    /// #         unimplemented!()
374    /// #     }
375    /// # }
376    /// # struct MyMiddleware;
377    /// # #[async_trait]
378    /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
379    /// let mut stack = MiddlewareStack::new();
380    /// stack.add_layer(Box::new(MyMiddleware));
381    ///
382    /// let builder = ClientBuilder::new(MyConnector, MyState)
383    ///     .with_middleware_stack(stack);
384    /// ```
385    pub fn with_middleware_stack(mut self, stack: MiddlewareStack<S>) -> Self {
386        self.middleware_stack = stack;
387        self
388    }
389
390    /// Sets the maximum number of reconnection attempts.
391    /// 0 means infinite attempts.
392    pub fn with_max_allowed_loops(mut self, max_allowed_loops: u32) -> Self {
393        self.max_allowed_loops = max_allowed_loops;
394        self
395    }
396
397    /// Sets the base delay for reconnection attempts.
398    pub fn with_reconnect_delay(mut self, reconnect_delay: Duration) -> Self {
399        self.reconnect_delay = reconnect_delay;
400        self
401    }
402
403    /// Assembles and returns the final `Client` handle and its `ClientRunner`.
404    pub async fn build(self) -> CoreResult<(Client<S>, ClientRunner<S>)> {
405        let (runner_cmd_tx, runner_cmd_rx) = bounded_async(8);
406        let (to_ws_tx, to_ws_rx) = bounded_async(256);
407        let signals = Signals::default();
408        let client = Client::new(
409            signals.clone(),
410            runner_cmd_tx.clone(),
411            self.state.clone(),
412            to_ws_tx.clone(),
413        );
414
415        let mut router = Router::new(self.state.clone());
416        router.lightweight_handlers = self.lightweight_handlers;
417        router.middleware_stack = self.middleware_stack;
418
419        let mut join_set = JoinSet::new();
420        // Execute all the deferred module setup functions.
421        let mut connection_callback = self.connection_callback;
422        for factory in self.module_factories {
423            factory(
424                &mut router,
425                &mut join_set,
426                client.module_handles.clone(),
427                to_ws_tx.clone(),
428                runner_cmd_tx.clone(),
429                &mut connection_callback.on_reconnect,
430            );
431        }
432
433        for factory in self.lightweight_factories {
434            factory(&mut router, to_ws_tx.clone(), runner_cmd_tx.clone());
435        }
436
437        // Wait for all the handles to be added to the handles hashmap.
438        while let Some(h) = join_set.join_next().await {
439            match h {
440                Ok(_) => {} // Successfully added the module handle.
441                Err(e) => {
442                    error!("Failed to add module handle: {:?}", e);
443                    return Err(CoreError::from(e));
444                }
445            }
446        }
447
448        let runner = ClientRunner {
449            signal: signals,
450            connector: self.connector,
451            state: self.state,
452            router: Arc::new(router),
453            is_hard_disconnect: true,
454            shutdown_requested: false,
455            to_ws_sender: to_ws_tx,
456            to_ws_receiver: to_ws_rx,
457            runner_command_rx: runner_cmd_rx,
458            connection_callback,
459            reconnect_attempts: 0,
460            max_allowed_loops: self.max_allowed_loops,
461            reconnect_delay: self.reconnect_delay,
462        };
463
464        Ok((client, runner))
465    }
466}
467
// Compile-time guarantees about the builder's auto traits.
#[cfg(test)]
mod tests {
    use super::*;

    /// Only type-checks when `T` is both `Send` and `Sync`; calling it does nothing at runtime.
    fn require_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_client_builder_send_sync() {
        // The builder must stay `Send + Sync` (so every stored factory and callback must be too);
        // otherwise this line stops compiling.
        require_send_sync::<ClientBuilder<()>>();
    }
}
480}