binary_options_tools_core_pre/builder.rs
1// src/builder.rs
2
3use kanal::{bounded_async, AsyncSender};
4use std::any::type_name;
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::task::JoinSet;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{error, info, warn};
13
14use crate::callback::{ConnectionCallback, ReconnectCallbackStack};
15use crate::client::{Client, ClientRunner, LightweightHandler, Router};
16use crate::connector::Connector;
17use crate::error::{CoreError, CoreResult};
18use crate::middleware::{MiddlewareStack, WebSocketMiddleware};
19use crate::signals::Signals;
20use crate::traits::{ApiModule, AppState, LightweightModule, ReconnectCallback, RunnerCommand};
21
22type HandlerMap = Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>;
23type HandlersFn<S> = Box<
24 dyn FnOnce(
25 &mut Router<S>,
26 &mut JoinSet<()>,
27 HandlerMap,
28 AsyncSender<Message>,
29 AsyncSender<RunnerCommand>,
30 &mut ReconnectCallbackStack<S>,
31 ) + Send
32 + Sync,
33>;
34
35type LightweightHandlersFn<S> =
36 Box<dyn FnOnce(&mut Router<S>, AsyncSender<Message>, AsyncSender<RunnerCommand>) + Send + Sync>;
37
38pub struct ClientBuilder<S: AppState> {
39 state: Arc<S>,
40 connector: Arc<dyn Connector<S>>,
41 connection_callback: ConnectionCallback<S>,
42 lightweight_handlers: Vec<LightweightHandler<S>>,
43 // Stores functions that know how to create and register each module.
44 module_factories: Vec<HandlersFn<S>>,
45 lightweight_factories: Vec<LightweightHandlersFn<S>>,
46 // Middleware stack for WebSocket message processing
47 middleware_stack: MiddlewareStack<S>,
48
49 max_allowed_loops: u32,
50 reconnect_delay: Duration,
51}
52
53impl<S: AppState> ClientBuilder<S> {
54 /// Creates a new builder with the essential components.
55 pub fn new(connector: impl Connector<S> + 'static, state: S) -> Self {
56 Self {
57 state: Arc::new(state),
58 connector: Arc::new(connector),
59 // Provide empty default callbacks.
60 connection_callback: ConnectionCallback {
61 on_connect: Box::new(|_, _| Box::pin(async { Ok(()) })),
62 on_reconnect: ReconnectCallbackStack::default(),
63 },
64 lightweight_handlers: Vec::new(),
65 module_factories: Vec::new(),
66 lightweight_factories: Vec::new(),
67 middleware_stack: MiddlewareStack::new(),
68 max_allowed_loops: 0,
69 reconnect_delay: Duration::from_secs(5),
70 }
71 }
72
73 /// Sets the callback for the initial connection.
74 pub fn on_connect(
75 mut self,
76 callback: impl Fn(
77 Arc<S>,
78 &AsyncSender<Message>,
79 ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
80 + Send
81 + Sync
82 + 'static,
83 ) -> Self {
84 self.connection_callback.on_connect = Box::new(callback);
85 self
86 }
87
88 /// Sets the callback for subsequent reconnections.
89 pub fn on_reconnect(
90 mut self,
91 callback: Box<dyn ReconnectCallback<S> + Send + Sync + 'static>,
92 ) -> Self {
93 self.connection_callback.on_reconnect.add_layer(callback);
94 self
95 }
96
97 /// Adds a lightweight handler that receives all messages.
98 pub fn with_lightweight_handler(
99 mut self,
100 handler: impl Fn(
101 Arc<Message>,
102 Arc<S>,
103 &AsyncSender<Message>,
104 ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
105 + Send
106 + Sync
107 + 'static,
108 ) -> Self {
109 self.lightweight_handlers.push(Box::new(handler));
110 self
111 }
112
113 /// Registers a lightweight module
114 pub fn with_lightweight_module<M: LightweightModule<S>>(mut self) -> Self {
115 let factory = |router: &mut Router<S>,
116 to_ws_tx: AsyncSender<Message>,
117 runner_tx: AsyncSender<RunnerCommand>| {
118 let (msg_tx, msg_rx) = bounded_async(256);
119
120 let state = router.state.clone();
121 // Spawn the lightweight module task.
122 router.spawn_lightweight_module(async move {
123 let mut failures = 0;
124 // make the first timestamp far enough in the past
125 let mut last_fail = Instant::now()
126 .checked_sub(Duration::from_secs(3600))
127 .unwrap_or(Instant::now());
128
129 loop {
130 // create the module once
131 let mut module = M::new(
132 state.clone(),
133 to_ws_tx.clone(),
134 msg_rx.clone(),
135 runner_tx.clone(),
136 );
137 match module.run().await {
138 Ok(()) => {
139 info!(target: "LightweightModule", "[Lightweight {}] exited cleanly", type_name::<M>());
140 break;
141 }
142 Err(e) => {
143 let now = Instant::now();
144 if now.duration_since(last_fail) < Duration::from_secs(30) {
145 failures += 1;
146 } else {
147 failures = 1;
148 }
149 last_fail = now;
150
151 if failures >= 5 {
152 error!(target: "LightweightModule",
153 "[Lightweight {}] failing {}× rapidly: {:?}, backing off 60s",
154 type_name::<M>(),
155 failures,
156 e
157 );
158 tokio::time::sleep(Duration::from_secs(60)).await;
159 } else {
160 warn!(target: "LightweightModule", "[Lightweight {}] error: {:?}", type_name::<M>(), e);
161 tokio::time::sleep(Duration::from_secs(1)).await;
162 }
163 }
164 }
165 }
166 });
167 router.add_lightweight_rule(M::rule(), msg_tx);
168 };
169
170 self.lightweight_factories.push(Box::new(factory));
171 self
172 }
173
174 /// Registers a full API module with the client.
175 pub fn with_module<M: ApiModule<S>>(mut self) -> Self {
176 let factory =
177 |router: &mut Router<S>,
178 join_set: &mut JoinSet<()>,
179 handles: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
180 to_ws_tx: AsyncSender<Message>,
181 runner_tx: AsyncSender<RunnerCommand>,
182 reconnect_callback_stack: &mut ReconnectCallbackStack<S>| {
183 let (cmd_tx, cmd_rx) = bounded_async(32);
184 let (cmd_ret_tx, cmd_ret_rx) = bounded_async(32);
185 let (msg_tx, msg_rx) = bounded_async(256);
186
187 let state = router.state.clone();
188 let handle = M::create_handle(cmd_tx, cmd_ret_rx);
189
190 // Must spawn this write to avoid blocking if called from an async context.
191 join_set.spawn(async move {
192 handles
193 .write()
194 .await
195 .insert(TypeId::of::<M>(), Box::new(handle));
196 });
197
198 match M::callback(
199 state.clone(),
200 cmd_rx.clone(),
201 cmd_ret_tx.clone(),
202 msg_rx.clone(),
203 to_ws_tx.clone(),
204 ) {
205 Ok(Some(callback)) => {
206 reconnect_callback_stack.add_layer(callback);
207 }
208 Ok(None) => {
209 // No callback needed, continue.
210 }
211 Err(e) => {
212 error!(target: "ApiModule", "Failed to get callback for module {}: {:?}", type_name::<M>(), e);
213 }
214 }
215 let state_clone = state.clone();
216 router.spawn_module(async move {
217 let mut failures = 0;
218 let mut last_fail = Instant::now()
219 .checked_sub(Duration::from_secs(3600))
220 .unwrap_or(Instant::now());
221 loop {
222 let mut module = M::new(
223 state.clone(),
224 cmd_rx.clone(),
225 cmd_ret_tx.clone(),
226 msg_rx.clone(),
227 to_ws_tx.clone(),
228 runner_tx.clone(),
229 );
230 match module.run().await {
231 Ok(_) => {
232 info!(target: "ApiModule", "[Module {}] exited cleanly", type_name::<M>());
233 break;
234 }
235 Err(e) => {
236 let now = Instant::now();
237 if now.duration_since(last_fail) < Duration::from_secs(30) {
238 failures += 1;
239 } else {
240 failures = 1;
241 }
242 last_fail = now;
243
244 let wait = if failures >= 5 {
245 error!(target: "ApiModule", "Module [{}] failed too many times, check module integrity: {:?}", type_name::<M>(), e);
246 60
247 } else {
248 warn!(target: "ApiModule", "[{}] err={:?}", type_name::<M>(), e);
249 1
250 };
251 tokio::time::sleep(Duration::from_secs(wait)).await;
252 }
253 }
254 }
255 });
256
257 router.add_module_rule(M::rule(state_clone), msg_tx);
258 };
259
260 self.module_factories.push(Box::new(factory));
261 self
262 }
263
264 /// Adds a middleware layer to the client.
265 ///
266 /// Middleware will be executed in the order they are added.
267 /// They will be called for all WebSocket messages sent and received.
268 ///
269 /// # Example
270 /// ```rust,no_run
271 /// # use binary_options_tools_core_pre::builder::ClientBuilder;
272 /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
273 /// # use binary_options_tools_core_pre::traits::AppState;
274 /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
275 /// # use async_trait::async_trait;
276 /// # use std::sync::Arc;
277 /// # #[derive(Debug)]
278 /// # struct MyState;
279 /// # impl AppState for MyState {
280 /// # fn clear_temporal_data(&self) {}
281 /// # }
282 /// # struct MyConnector;
283 /// # #[async_trait]
284 /// # impl Connector<MyState> for MyConnector {
285 /// # async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
286 /// # unimplemented!()
287 /// # }
288 /// # async fn disconnect(&self) -> ConnectorResult<()> {
289 /// # unimplemented!()
290 /// # }
291 /// # }
292 /// # struct MyMiddleware;
293 /// # #[async_trait]
294 /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
295 /// let builder = ClientBuilder::new(MyConnector, MyState)
296 /// .with_middleware(Box::new(MyMiddleware));
297 /// ```
298 pub fn with_middleware(mut self, middleware: Box<dyn WebSocketMiddleware<S>>) -> Self {
299 self.middleware_stack.add_layer(middleware);
300 self
301 }
302
303 /// Adds multiple middleware layers at once.
304 ///
305 /// This is a convenience method for adding multiple middleware layers.
306 ///
307 /// # Example
308 /// ```rust,no_run
309 /// # use binary_options_tools_core_pre::builder::ClientBuilder;
310 /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
311 /// # use binary_options_tools_core_pre::traits::AppState;
312 /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
313 /// # use async_trait::async_trait;
314 /// # use std::sync::Arc;
315 /// # #[derive(Debug)]
316 /// # struct MyState;
317 /// # impl AppState for MyState {
318 /// # fn clear_temporal_data(&self) {}
319 /// # }
320 /// # struct MyConnector;
321 /// # #[async_trait]
322 /// # impl Connector<MyState> for MyConnector {
323 /// # async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
324 /// # unimplemented!()
325 /// # }
326 /// # async fn disconnect(&self) -> ConnectorResult<()> {
327 /// # unimplemented!()
328 /// # }
329 /// # }
330 /// # struct MyMiddleware;
331 /// # #[async_trait]
332 /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
333 /// let builder = ClientBuilder::new(MyConnector, MyState)
334 /// .with_middleware_layers(vec![
335 /// Box::new(MyMiddleware),
336 /// Box::new(MyMiddleware),
337 /// ]);
338 /// ```
339 pub fn with_middleware_layers(
340 mut self,
341 middleware: Vec<Box<dyn WebSocketMiddleware<S>>>,
342 ) -> Self {
343 for layer in middleware {
344 self.middleware_stack.add_layer(layer);
345 }
346 self
347 }
348
349 /// Applies a middleware stack to the client.
350 ///
351 /// This replaces any existing middleware with the provided stack.
352 ///
353 /// # Example
354 /// ```rust,no_run
355 /// # use binary_options_tools_core_pre::builder::ClientBuilder;
356 /// # use binary_options_tools_core_pre::middleware::{MiddlewareStack, WebSocketMiddleware};
357 /// # use binary_options_tools_core_pre::traits::AppState;
358 /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
359 /// # use async_trait::async_trait;
360 /// # use std::sync::Arc;
361 /// # #[derive(Debug)]
362 /// # struct MyState;
363 /// # impl AppState for MyState {
364 /// # fn clear_temporal_data(&self) {}
365 /// # }
366 /// # struct MyConnector;
367 /// # #[async_trait]
368 /// # impl Connector<MyState> for MyConnector {
369 /// # async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
370 /// # unimplemented!()
371 /// # }
372 /// # async fn disconnect(&self) -> ConnectorResult<()> {
373 /// # unimplemented!()
374 /// # }
375 /// # }
376 /// # struct MyMiddleware;
377 /// # #[async_trait]
378 /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
379 /// let mut stack = MiddlewareStack::new();
380 /// stack.add_layer(Box::new(MyMiddleware));
381 ///
382 /// let builder = ClientBuilder::new(MyConnector, MyState)
383 /// .with_middleware_stack(stack);
384 /// ```
385 pub fn with_middleware_stack(mut self, stack: MiddlewareStack<S>) -> Self {
386 self.middleware_stack = stack;
387 self
388 }
389
390 /// Sets the maximum number of reconnection attempts.
391 /// 0 means infinite attempts.
392 pub fn with_max_allowed_loops(mut self, max_allowed_loops: u32) -> Self {
393 self.max_allowed_loops = max_allowed_loops;
394 self
395 }
396
397 /// Sets the base delay for reconnection attempts.
398 pub fn with_reconnect_delay(mut self, reconnect_delay: Duration) -> Self {
399 self.reconnect_delay = reconnect_delay;
400 self
401 }
402
403 /// Assembles and returns the final `Client` handle and its `ClientRunner`.
404 pub async fn build(self) -> CoreResult<(Client<S>, ClientRunner<S>)> {
405 let (runner_cmd_tx, runner_cmd_rx) = bounded_async(8);
406 let (to_ws_tx, to_ws_rx) = bounded_async(256);
407 let signals = Signals::default();
408 let client = Client::new(
409 signals.clone(),
410 runner_cmd_tx.clone(),
411 self.state.clone(),
412 to_ws_tx.clone(),
413 );
414
415 let mut router = Router::new(self.state.clone());
416 router.lightweight_handlers = self.lightweight_handlers;
417 router.middleware_stack = self.middleware_stack;
418
419 let mut join_set = JoinSet::new();
420 // Execute all the deferred module setup functions.
421 let mut connection_callback = self.connection_callback;
422 for factory in self.module_factories {
423 factory(
424 &mut router,
425 &mut join_set,
426 client.module_handles.clone(),
427 to_ws_tx.clone(),
428 runner_cmd_tx.clone(),
429 &mut connection_callback.on_reconnect,
430 );
431 }
432
433 for factory in self.lightweight_factories {
434 factory(&mut router, to_ws_tx.clone(), runner_cmd_tx.clone());
435 }
436
437 // Wait for all the handles to be added to the handles hashmap.
438 while let Some(h) = join_set.join_next().await {
439 match h {
440 Ok(_) => {} // Successfully added the module handle.
441 Err(e) => {
442 error!("Failed to add module handle: {:?}", e);
443 return Err(CoreError::from(e));
444 }
445 }
446 }
447
448 let runner = ClientRunner {
449 signal: signals,
450 connector: self.connector,
451 state: self.state,
452 router: Arc::new(router),
453 is_hard_disconnect: true,
454 shutdown_requested: false,
455 to_ws_sender: to_ws_tx,
456 to_ws_receiver: to_ws_rx,
457 runner_command_rx: runner_cmd_rx,
458 connection_callback,
459 reconnect_attempts: 0,
460 max_allowed_loops: self.max_allowed_loops,
461 reconnect_delay: self.reconnect_delay,
462 };
463
464 Ok((client, runner))
465 }
466}
467
// Compile-time thread-safety checks for the builder.
#[cfg(test)]
mod tests {
    use super::*;

    /// Instantiating this for `T` only type-checks when `T: Send + Sync`; it does
    /// nothing at runtime.
    fn require_thread_safe<T>()
    where
        T: Send + Sync,
    {
    }

    #[test]
    fn test_client_builder_send_sync() {
        // Compilation fails here if `ClientBuilder` ever stops being `Send + Sync`.
        require_thread_safe::<ClientBuilder<()>>();
    }
}
480}