binary_options_tools_core_pre/
builder.rs

1// src/builder.rs
2
3use kanal::{AsyncSender, bounded_async};
4use std::any::type_name;
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::task::JoinSet;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{error, info, warn};
13
14use crate::callback::{ConnectionCallback, ReconnectCallbackStack};
15use crate::client::{Client, ClientRunner, LightweightHandler, Router};
16use crate::connector::Connector;
17use crate::error::{CoreError, CoreResult};
18use crate::middleware::{MiddlewareStack, WebSocketMiddleware};
19use crate::signals::Signals;
20use crate::traits::{ApiModule, AppState, LightweightModule, ReconnectCallback};
21
22type HandlerMap = Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>;
23type HandlersFn<S> = Box<
24    dyn FnOnce(
25            &mut Router<S>,
26            &mut JoinSet<()>,
27            HandlerMap,
28            AsyncSender<Message>,
29            &mut ReconnectCallbackStack<S>,
30        ) + Send
31        + Sync,
32>;
33
34type LightweightHandlersFn<S> = Box<dyn FnOnce(&mut Router<S>, AsyncSender<Message>) + Send + Sync>;
35
36pub struct ClientBuilder<S: AppState> {
37    state: Arc<S>,
38    connector: Arc<dyn Connector<S>>,
39    connection_callback: ConnectionCallback<S>,
40    lightweight_handlers: Vec<LightweightHandler<S>>,
41    // Stores functions that know how to create and register each module.
42    module_factories: Vec<HandlersFn<S>>,
43    lightweight_factories: Vec<LightweightHandlersFn<S>>,
44    // Middleware stack for WebSocket message processing
45    middleware_stack: MiddlewareStack<S>,
46}
47
48impl<S: AppState> ClientBuilder<S> {
49    /// Creates a new builder with the essential components.
50    pub fn new(connector: impl Connector<S> + 'static, state: S) -> Self {
51        Self {
52            state: Arc::new(state),
53            connector: Arc::new(connector),
54            // Provide empty default callbacks.
55            connection_callback: ConnectionCallback {
56                on_connect: Box::new(|_, _| Box::pin(async { Ok(()) })),
57                on_reconnect: ReconnectCallbackStack::default(),
58            },
59            lightweight_handlers: Vec::new(),
60            module_factories: Vec::new(),
61            lightweight_factories: Vec::new(),
62            middleware_stack: MiddlewareStack::new(),
63        }
64    }
65
66    /// Sets the callback for the initial connection.
67    pub fn on_connect(
68        mut self,
69        callback: impl Fn(
70            Arc<S>,
71            &AsyncSender<Message>,
72        ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
73        + Send
74        + Sync
75        + 'static,
76    ) -> Self {
77        self.connection_callback.on_connect = Box::new(callback);
78        self
79    }
80
81    /// Sets the callback for subsequent reconnections.
82    pub fn on_reconnect(
83        mut self,
84        callback: Box<dyn ReconnectCallback<S> + Send + Sync + 'static>,
85    ) -> Self {
86        self.connection_callback.on_reconnect.add_layer(callback);
87        self
88    }
89
90    /// Adds a lightweight handler that receives all messages.
91    pub fn with_lightweight_handler(
92        mut self,
93        handler: impl Fn(
94            Arc<Message>,
95            Arc<S>,
96            &AsyncSender<Message>,
97        ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
98        + Send
99        + Sync
100        + 'static,
101    ) -> Self {
102        self.lightweight_handlers.push(Box::new(handler));
103        self
104    }
105
106    /// Registers a lightweight module
107    pub fn with_lightweight_module<M: LightweightModule<S>>(mut self) -> Self {
108        let factory = |router: &mut Router<S>, to_ws_tx: AsyncSender<Message>| {
109            let (msg_tx, msg_rx) = bounded_async(256);
110
111            let state = router.state.clone();
112            // Spawn the lightweight module task.
113            router.spawn_lightweight_module(async move {
114                let mut failures = 0;
115                // make the first timestamp far enough in the past
116                let mut last_fail = Instant::now().checked_sub(Duration::from_secs(3600)).unwrap_or(Instant::now());
117
118                loop {
119                    // create the module once
120                    let mut module = M::new(state.clone(), to_ws_tx.clone(), msg_rx.clone());
121                    match module.run().await {
122                        Ok(()) => {
123                            info!(target: "LightweightModule", "[Lightweight {}] exited cleanly", type_name::<M>());
124                            break;
125                        }
126                        Err(e) => {
127                            let now = Instant::now();
128                            if now.duration_since(last_fail) < Duration::from_secs(30) {
129                                failures += 1;
130                            } else {
131                                failures = 1;
132                            }
133                            last_fail = now;
134
135                            if failures >= 5 {
136                                error!(target: "LightweightModule",
137                                    "[Lightweight {}] failing {}× rapidly: {:?}, backing off 60s",
138                                    type_name::<M>(),
139                                    failures,
140                                    e
141                                );
142                                tokio::time::sleep(Duration::from_secs(60)).await;
143                            } else {
144                                warn!(target: "LightweightModule", "[Lightweight {}] error: {:?}", type_name::<M>(), e);
145                                tokio::time::sleep(Duration::from_secs(1)).await;
146                            }
147                        }
148                    }
149                }
150            });
151            router.add_lightweight_rule(M::rule(), msg_tx);
152        };
153
154        self.lightweight_factories.push(Box::new(factory));
155        self
156    }
157
158    /// Registers a full API module with the client.
159    pub fn with_module<M: ApiModule<S>>(mut self) -> Self {
160        let factory =
161            |router: &mut Router<S>,
162             join_set: &mut JoinSet<()>,
163             handles: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
164             to_ws_tx: AsyncSender<Message>,
165             reconnect_callback_stack: &mut ReconnectCallbackStack<S>| {
166                let (cmd_tx, cmd_rx) = bounded_async(32);
167                let (cmd_ret_tx, cmd_ret_rx) = bounded_async(32);
168                let (msg_tx, msg_rx) = bounded_async(256);
169
170                let state = router.state.clone();
171                let handle = M::create_handle(cmd_tx, cmd_ret_rx);
172
173                // Must spawn this write to avoid blocking if called from an async context.
174                join_set.spawn(async move {
175                    handles
176                        .write()
177                        .await
178                        .insert(TypeId::of::<M>(), Box::new(handle));
179                });
180
181                let m_temp = M::new(
182                    state.clone(),
183                    cmd_rx.clone(),
184                    cmd_ret_tx.clone(),
185                    msg_rx.clone(),
186                    to_ws_tx.clone(),
187                );
188                match m_temp.callback() {
189                    Ok(Some(callback)) => {
190                        reconnect_callback_stack.add_layer(callback);
191                    }
192                    Ok(None) => {
193                        // No callback needed, continue.
194                    }
195                    Err(e) => {
196                        error!(target: "ApiModule", "Failed to get callback for module {}: {:?}", type_name::<M>(), e);
197                    }
198                }
199                let state_clone = state.clone();
200                router.spawn_module(async move {
201                let mut failures = 0;
202                let mut last_fail = Instant::now().checked_sub(Duration::from_secs(3600)).unwrap_or(Instant::now());
203                loop {
204                    let mut module = M::new(
205                        state.clone(),
206                        cmd_rx.clone(),
207                        cmd_ret_tx.clone(),
208                        msg_rx.clone(),
209                        to_ws_tx.clone(),
210                    );
211                    match module.run().await {
212                        Ok(_) => {
213                          info!(target: "ApiModule", "[Module {}] exited cleanly", type_name::<M>());
214                          break;
215                      },
216                        Err(e) => {
217                            let now = Instant::now();
218                            if now.duration_since(last_fail) < Duration::from_secs(30) {
219                                failures += 1;
220                            } else {
221                                failures = 1;
222                            }
223                            last_fail = now;
224
225                            let wait = if failures >= 5 {
226                                error!(target: "ApiModule", "Module [{}] failed too many times, check module integrity: {:?}", type_name::<M>(), e);
227                                60
228                            } else {
229                                warn!(target: "ApiModule", "[{}] err={:?}", type_name::<M>(), e);
230                                1
231                            };
232                            tokio::time::sleep(Duration::from_secs(wait)).await;
233                        }
234                    }
235                }
236            });
237
238                router.add_module_rule(M::rule(state_clone), msg_tx);
239            };
240
241        self.module_factories.push(Box::new(factory));
242        self
243    }
244
245    /// Adds a middleware layer to the client.
246    ///
247    /// Middleware will be executed in the order they are added.
248    /// They will be called for all WebSocket messages sent and received.
249    ///
250    /// # Example
251    /// ```rust,no_run
252    /// # use binary_options_tools_core_pre::builder::ClientBuilder;
253    /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
254    /// # use binary_options_tools_core_pre::traits::AppState;
255    /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
256    /// # use async_trait::async_trait;
257    /// # use std::sync::Arc;
258    /// # #[derive(Debug)]
259    /// # struct MyState;
260    /// # impl AppState for MyState {
261    /// #     fn clear_temporal_data(&self) {}
262    /// # }
263    /// # struct MyConnector;
264    /// # #[async_trait]
265    /// # impl Connector<MyState> for MyConnector {
266    /// #     async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
267    /// #         unimplemented!()
268    /// #     }
269    /// #     async fn disconnect(&self) -> ConnectorResult<()> {
270    /// #         unimplemented!()
271    /// #     }
272    /// # }
273    /// # struct MyMiddleware;
274    /// # #[async_trait]
275    /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
276    /// let builder = ClientBuilder::new(MyConnector, MyState)
277    ///     .with_middleware(Box::new(MyMiddleware));
278    /// ```
279    pub fn with_middleware(mut self, middleware: Box<dyn WebSocketMiddleware<S>>) -> Self {
280        self.middleware_stack.add_layer(middleware);
281        self
282    }
283
284    /// Adds multiple middleware layers at once.
285    ///
286    /// This is a convenience method for adding multiple middleware layers.
287    ///
288    /// # Example
289    /// ```rust,no_run
290    /// # use binary_options_tools_core_pre::builder::ClientBuilder;
291    /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
292    /// # use binary_options_tools_core_pre::traits::AppState;
293    /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
294    /// # use async_trait::async_trait;
295    /// # use std::sync::Arc;
296    /// # #[derive(Debug)]
297    /// # struct MyState;
298    /// # impl AppState for MyState {
299    /// #     fn clear_temporal_data(&self) {}
300    /// # }
301    /// # struct MyConnector;
302    /// # #[async_trait]
303    /// # impl Connector<MyState> for MyConnector {
304    /// #     async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
305    /// #         unimplemented!()
306    /// #     }
307    /// #     async fn disconnect(&self) -> ConnectorResult<()> {
308    /// #         unimplemented!()
309    /// #     }
310    /// # }
311    /// # struct MyMiddleware;
312    /// # #[async_trait]
313    /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
314    /// let builder = ClientBuilder::new(MyConnector, MyState)
315    ///     .with_middleware_layers(vec![
316    ///         Box::new(MyMiddleware),
317    ///         Box::new(MyMiddleware),
318    ///     ]);
319    /// ```
320    pub fn with_middleware_layers(
321        mut self,
322        middleware: Vec<Box<dyn WebSocketMiddleware<S>>>,
323    ) -> Self {
324        for layer in middleware {
325            self.middleware_stack.add_layer(layer);
326        }
327        self
328    }
329
330    /// Applies a middleware stack to the client.
331    ///
332    /// This replaces any existing middleware with the provided stack.
333    ///
334    /// # Example
335    /// ```rust,no_run
336    /// # use binary_options_tools_core_pre::builder::ClientBuilder;
337    /// # use binary_options_tools_core_pre::middleware::{MiddlewareStack, WebSocketMiddleware};
338    /// # use binary_options_tools_core_pre::traits::AppState;
339    /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
340    /// # use async_trait::async_trait;
341    /// # use std::sync::Arc;
342    /// # #[derive(Debug)]
343    /// # struct MyState;
344    /// # impl AppState for MyState {
345    /// #     fn clear_temporal_data(&self) {}
346    /// # }
347    /// # struct MyConnector;
348    /// # #[async_trait]
349    /// # impl Connector<MyState> for MyConnector {
350    /// #     async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
351    /// #         unimplemented!()
352    /// #     }
353    /// #     async fn disconnect(&self) -> ConnectorResult<()> {
354    /// #         unimplemented!()
355    /// #     }
356    /// # }
357    /// # struct MyMiddleware;
358    /// # #[async_trait]
359    /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
360    /// let mut stack = MiddlewareStack::new();
361    /// stack.add_layer(Box::new(MyMiddleware));
362    ///
363    /// let builder = ClientBuilder::new(MyConnector, MyState)
364    ///     .with_middleware_stack(stack);
365    /// ```
366    pub fn with_middleware_stack(mut self, stack: MiddlewareStack<S>) -> Self {
367        self.middleware_stack = stack;
368        self
369    }
370
371    /// Assembles and returns the final `Client` handle and its `ClientRunner`.
372    pub async fn build(self) -> CoreResult<(Client<S>, ClientRunner<S>)> {
373        let (runner_cmd_tx, runner_cmd_rx) = bounded_async(8);
374        let (to_ws_tx, to_ws_rx) = bounded_async(256);
375        let signals = Signals::default();
376        let client = Client::new(
377            signals.clone(),
378            runner_cmd_tx,
379            self.state.clone(),
380            to_ws_tx.clone(),
381        );
382
383        let mut router = Router::new(self.state.clone());
384        router.lightweight_handlers = self.lightweight_handlers;
385        router.middleware_stack = self.middleware_stack;
386
387        let mut join_set = JoinSet::new();
388        // Execute all the deferred module setup functions.
389        let mut connection_callback = self.connection_callback;
390        for factory in self.module_factories {
391            factory(
392                &mut router,
393                &mut join_set,
394                client.module_handles.clone(),
395                to_ws_tx.clone(),
396                &mut connection_callback.on_reconnect,
397            );
398        }
399
400        for factory in self.lightweight_factories {
401            factory(&mut router, to_ws_tx.clone());
402        }
403
404        // Wait for all the handles to be added to the handles hashmap.
405        while let Some(h) = join_set.join_next().await {
406            match h {
407                Ok(_) => {} // Successfully added the module handle.
408                Err(e) => {
409                    error!("Failed to add module handle: {:?}", e);
410                    return Err(CoreError::from(e));
411                }
412            }
413        }
414
415        let runner = ClientRunner {
416            signal: signals,
417            connector: self.connector,
418            state: self.state,
419            router: Arc::new(router),
420            is_hard_disconnect: true,
421            shutdown_requested: false,
422            to_ws_sender: to_ws_tx,
423            to_ws_receiver: to_ws_rx,
424            runner_command_rx: runner_cmd_rx,
425            connection_callback,
426        };
427
428        Ok((client, runner))
429    }
430}
431
// Compile-time check that `ClientBuilder` is `Send + Sync`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Does nothing at runtime; calling it only type-checks when `T` is both
    /// `Send` and `Sync`.
    fn require_thread_safe<T>()
    where
        T: Send + Sync,
    {
    }

    #[test]
    fn test_client_builder_send_sync() {
        // A regression here shows up as a compile error, not a test failure.
        require_thread_safe::<ClientBuilder<()>>();
    }
}