binary_options_tools_core_pre/builder.rs
1// src/builder.rs
2
3use kanal::{AsyncSender, bounded_async};
4use std::any::type_name;
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10use tokio::task::JoinSet;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{error, info, warn};
13
14use crate::callback::{ConnectionCallback, ReconnectCallbackStack};
15use crate::client::{Client, ClientRunner, LightweightHandler, Router};
16use crate::connector::Connector;
17use crate::error::{CoreError, CoreResult};
18use crate::middleware::{MiddlewareStack, WebSocketMiddleware};
19use crate::signals::Signals;
20use crate::traits::{ApiModule, AppState, LightweightModule, ReconnectCallback};
21
22type HandlerMap = Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>;
23type HandlersFn<S> = Box<
24 dyn FnOnce(
25 &mut Router<S>,
26 &mut JoinSet<()>,
27 HandlerMap,
28 AsyncSender<Message>,
29 &mut ReconnectCallbackStack<S>,
30 ) + Send
31 + Sync,
32>;
33
34type LightweightHandlersFn<S> = Box<dyn FnOnce(&mut Router<S>, AsyncSender<Message>) + Send + Sync>;
35
36pub struct ClientBuilder<S: AppState> {
37 state: Arc<S>,
38 connector: Arc<dyn Connector<S>>,
39 connection_callback: ConnectionCallback<S>,
40 lightweight_handlers: Vec<LightweightHandler<S>>,
41 // Stores functions that know how to create and register each module.
42 module_factories: Vec<HandlersFn<S>>,
43 lightweight_factories: Vec<LightweightHandlersFn<S>>,
44 // Middleware stack for WebSocket message processing
45 middleware_stack: MiddlewareStack<S>,
46}
47
48impl<S: AppState> ClientBuilder<S> {
49 /// Creates a new builder with the essential components.
50 pub fn new(connector: impl Connector<S> + 'static, state: S) -> Self {
51 Self {
52 state: Arc::new(state),
53 connector: Arc::new(connector),
54 // Provide empty default callbacks.
55 connection_callback: ConnectionCallback {
56 on_connect: Box::new(|_, _| Box::pin(async { Ok(()) })),
57 on_reconnect: ReconnectCallbackStack::default(),
58 },
59 lightweight_handlers: Vec::new(),
60 module_factories: Vec::new(),
61 lightweight_factories: Vec::new(),
62 middleware_stack: MiddlewareStack::new(),
63 }
64 }
65
66 /// Sets the callback for the initial connection.
67 pub fn on_connect(
68 mut self,
69 callback: impl Fn(
70 Arc<S>,
71 &AsyncSender<Message>,
72 ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
73 + Send
74 + Sync
75 + 'static,
76 ) -> Self {
77 self.connection_callback.on_connect = Box::new(callback);
78 self
79 }
80
81 /// Sets the callback for subsequent reconnections.
82 pub fn on_reconnect(
83 mut self,
84 callback: Box<dyn ReconnectCallback<S> + Send + Sync + 'static>,
85 ) -> Self {
86 self.connection_callback.on_reconnect.add_layer(callback);
87 self
88 }
89
90 /// Adds a lightweight handler that receives all messages.
91 pub fn with_lightweight_handler(
92 mut self,
93 handler: impl Fn(
94 Arc<Message>,
95 Arc<S>,
96 &AsyncSender<Message>,
97 ) -> futures_util::future::BoxFuture<'static, CoreResult<()>>
98 + Send
99 + Sync
100 + 'static,
101 ) -> Self {
102 self.lightweight_handlers.push(Box::new(handler));
103 self
104 }
105
106 /// Registers a lightweight module
107 pub fn with_lightweight_module<M: LightweightModule<S>>(mut self) -> Self {
108 let factory = |router: &mut Router<S>, to_ws_tx: AsyncSender<Message>| {
109 let (msg_tx, msg_rx) = bounded_async(256);
110
111 let state = router.state.clone();
112 // Spawn the lightweight module task.
113 router.spawn_lightweight_module(async move {
114 let mut failures = 0;
115 // make the first timestamp far enough in the past
116 let mut last_fail = Instant::now().checked_sub(Duration::from_secs(3600)).unwrap_or(Instant::now());
117
118 loop {
119 // create the module once
120 let mut module = M::new(state.clone(), to_ws_tx.clone(), msg_rx.clone());
121 match module.run().await {
122 Ok(()) => {
123 info!(target: "LightweightModule", "[Lightweight {}] exited cleanly", type_name::<M>());
124 break;
125 }
126 Err(e) => {
127 let now = Instant::now();
128 if now.duration_since(last_fail) < Duration::from_secs(30) {
129 failures += 1;
130 } else {
131 failures = 1;
132 }
133 last_fail = now;
134
135 if failures >= 5 {
136 error!(target: "LightweightModule",
137 "[Lightweight {}] failing {}× rapidly: {:?}, backing off 60s",
138 type_name::<M>(),
139 failures,
140 e
141 );
142 tokio::time::sleep(Duration::from_secs(60)).await;
143 } else {
144 warn!(target: "LightweightModule", "[Lightweight {}] error: {:?}", type_name::<M>(), e);
145 tokio::time::sleep(Duration::from_secs(1)).await;
146 }
147 }
148 }
149 }
150 });
151 router.add_lightweight_rule(M::rule(), msg_tx);
152 };
153
154 self.lightweight_factories.push(Box::new(factory));
155 self
156 }
157
158 /// Registers a full API module with the client.
159 pub fn with_module<M: ApiModule<S>>(mut self) -> Self {
160 let factory =
161 |router: &mut Router<S>,
162 join_set: &mut JoinSet<()>,
163 handles: Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>,
164 to_ws_tx: AsyncSender<Message>,
165 reconnect_callback_stack: &mut ReconnectCallbackStack<S>| {
166 let (cmd_tx, cmd_rx) = bounded_async(32);
167 let (cmd_ret_tx, cmd_ret_rx) = bounded_async(32);
168 let (msg_tx, msg_rx) = bounded_async(256);
169
170 let state = router.state.clone();
171 let handle = M::create_handle(cmd_tx, cmd_ret_rx);
172
173 // Must spawn this write to avoid blocking if called from an async context.
174 join_set.spawn(async move {
175 handles
176 .write()
177 .await
178 .insert(TypeId::of::<M>(), Box::new(handle));
179 });
180
181 let m_temp = M::new(
182 state.clone(),
183 cmd_rx.clone(),
184 cmd_ret_tx.clone(),
185 msg_rx.clone(),
186 to_ws_tx.clone(),
187 );
188 match m_temp.callback() {
189 Ok(Some(callback)) => {
190 reconnect_callback_stack.add_layer(callback);
191 }
192 Ok(None) => {
193 // No callback needed, continue.
194 }
195 Err(e) => {
196 error!(target: "ApiModule", "Failed to get callback for module {}: {:?}", type_name::<M>(), e);
197 }
198 }
199 let state_clone = state.clone();
200 router.spawn_module(async move {
201 let mut failures = 0;
202 let mut last_fail = Instant::now().checked_sub(Duration::from_secs(3600)).unwrap_or(Instant::now());
203 loop {
204 let mut module = M::new(
205 state.clone(),
206 cmd_rx.clone(),
207 cmd_ret_tx.clone(),
208 msg_rx.clone(),
209 to_ws_tx.clone(),
210 );
211 match module.run().await {
212 Ok(_) => {
213 info!(target: "ApiModule", "[Module {}] exited cleanly", type_name::<M>());
214 break;
215 },
216 Err(e) => {
217 let now = Instant::now();
218 if now.duration_since(last_fail) < Duration::from_secs(30) {
219 failures += 1;
220 } else {
221 failures = 1;
222 }
223 last_fail = now;
224
225 let wait = if failures >= 5 {
226 error!(target: "ApiModule", "Module [{}] failed too many times, check module integrity: {:?}", type_name::<M>(), e);
227 60
228 } else {
229 warn!(target: "ApiModule", "[{}] err={:?}", type_name::<M>(), e);
230 1
231 };
232 tokio::time::sleep(Duration::from_secs(wait)).await;
233 }
234 }
235 }
236 });
237
238 router.add_module_rule(M::rule(state_clone), msg_tx);
239 };
240
241 self.module_factories.push(Box::new(factory));
242 self
243 }
244
245 /// Adds a middleware layer to the client.
246 ///
247 /// Middleware will be executed in the order they are added.
248 /// They will be called for all WebSocket messages sent and received.
249 ///
250 /// # Example
251 /// ```rust,no_run
252 /// # use binary_options_tools_core_pre::builder::ClientBuilder;
253 /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
254 /// # use binary_options_tools_core_pre::traits::AppState;
255 /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
256 /// # use async_trait::async_trait;
257 /// # use std::sync::Arc;
258 /// # #[derive(Debug)]
259 /// # struct MyState;
260 /// # impl AppState for MyState {
261 /// # fn clear_temporal_data(&self) {}
262 /// # }
263 /// # struct MyConnector;
264 /// # #[async_trait]
265 /// # impl Connector<MyState> for MyConnector {
266 /// # async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
267 /// # unimplemented!()
268 /// # }
269 /// # async fn disconnect(&self) -> ConnectorResult<()> {
270 /// # unimplemented!()
271 /// # }
272 /// # }
273 /// # struct MyMiddleware;
274 /// # #[async_trait]
275 /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
276 /// let builder = ClientBuilder::new(MyConnector, MyState)
277 /// .with_middleware(Box::new(MyMiddleware));
278 /// ```
279 pub fn with_middleware(mut self, middleware: Box<dyn WebSocketMiddleware<S>>) -> Self {
280 self.middleware_stack.add_layer(middleware);
281 self
282 }
283
284 /// Adds multiple middleware layers at once.
285 ///
286 /// This is a convenience method for adding multiple middleware layers.
287 ///
288 /// # Example
289 /// ```rust,no_run
290 /// # use binary_options_tools_core_pre::builder::ClientBuilder;
291 /// # use binary_options_tools_core_pre::middleware::WebSocketMiddleware;
292 /// # use binary_options_tools_core_pre::traits::AppState;
293 /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
294 /// # use async_trait::async_trait;
295 /// # use std::sync::Arc;
296 /// # #[derive(Debug)]
297 /// # struct MyState;
298 /// # impl AppState for MyState {
299 /// # fn clear_temporal_data(&self) {}
300 /// # }
301 /// # struct MyConnector;
302 /// # #[async_trait]
303 /// # impl Connector<MyState> for MyConnector {
304 /// # async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
305 /// # unimplemented!()
306 /// # }
307 /// # async fn disconnect(&self) -> ConnectorResult<()> {
308 /// # unimplemented!()
309 /// # }
310 /// # }
311 /// # struct MyMiddleware;
312 /// # #[async_trait]
313 /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
314 /// let builder = ClientBuilder::new(MyConnector, MyState)
315 /// .with_middleware_layers(vec![
316 /// Box::new(MyMiddleware),
317 /// Box::new(MyMiddleware),
318 /// ]);
319 /// ```
320 pub fn with_middleware_layers(
321 mut self,
322 middleware: Vec<Box<dyn WebSocketMiddleware<S>>>,
323 ) -> Self {
324 for layer in middleware {
325 self.middleware_stack.add_layer(layer);
326 }
327 self
328 }
329
330 /// Applies a middleware stack to the client.
331 ///
332 /// This replaces any existing middleware with the provided stack.
333 ///
334 /// # Example
335 /// ```rust,no_run
336 /// # use binary_options_tools_core_pre::builder::ClientBuilder;
337 /// # use binary_options_tools_core_pre::middleware::{MiddlewareStack, WebSocketMiddleware};
338 /// # use binary_options_tools_core_pre::traits::AppState;
339 /// # use binary_options_tools_core_pre::connector::{Connector, ConnectorResult, WsStream};
340 /// # use async_trait::async_trait;
341 /// # use std::sync::Arc;
342 /// # #[derive(Debug)]
343 /// # struct MyState;
344 /// # impl AppState for MyState {
345 /// # fn clear_temporal_data(&self) {}
346 /// # }
347 /// # struct MyConnector;
348 /// # #[async_trait]
349 /// # impl Connector<MyState> for MyConnector {
350 /// # async fn connect(&self, _state: Arc<MyState>) -> ConnectorResult<WsStream> {
351 /// # unimplemented!()
352 /// # }
353 /// # async fn disconnect(&self) -> ConnectorResult<()> {
354 /// # unimplemented!()
355 /// # }
356 /// # }
357 /// # struct MyMiddleware;
358 /// # #[async_trait]
359 /// # impl WebSocketMiddleware<MyState> for MyMiddleware {}
360 /// let mut stack = MiddlewareStack::new();
361 /// stack.add_layer(Box::new(MyMiddleware));
362 ///
363 /// let builder = ClientBuilder::new(MyConnector, MyState)
364 /// .with_middleware_stack(stack);
365 /// ```
366 pub fn with_middleware_stack(mut self, stack: MiddlewareStack<S>) -> Self {
367 self.middleware_stack = stack;
368 self
369 }
370
371 /// Assembles and returns the final `Client` handle and its `ClientRunner`.
372 pub async fn build(self) -> CoreResult<(Client<S>, ClientRunner<S>)> {
373 let (runner_cmd_tx, runner_cmd_rx) = bounded_async(8);
374 let (to_ws_tx, to_ws_rx) = bounded_async(256);
375 let signals = Signals::default();
376 let client = Client::new(
377 signals.clone(),
378 runner_cmd_tx,
379 self.state.clone(),
380 to_ws_tx.clone(),
381 );
382
383 let mut router = Router::new(self.state.clone());
384 router.lightweight_handlers = self.lightweight_handlers;
385 router.middleware_stack = self.middleware_stack;
386
387 let mut join_set = JoinSet::new();
388 // Execute all the deferred module setup functions.
389 let mut connection_callback = self.connection_callback;
390 for factory in self.module_factories {
391 factory(
392 &mut router,
393 &mut join_set,
394 client.module_handles.clone(),
395 to_ws_tx.clone(),
396 &mut connection_callback.on_reconnect,
397 );
398 }
399
400 for factory in self.lightweight_factories {
401 factory(&mut router, to_ws_tx.clone());
402 }
403
404 // Wait for all the handles to be added to the handles hashmap.
405 while let Some(h) = join_set.join_next().await {
406 match h {
407 Ok(_) => {} // Successfully added the module handle.
408 Err(e) => {
409 error!("Failed to add module handle: {:?}", e);
410 return Err(CoreError::from(e));
411 }
412 }
413 }
414
415 let runner = ClientRunner {
416 signal: signals,
417 connector: self.connector,
418 state: self.state,
419 router: Arc::new(router),
420 is_hard_disconnect: true,
421 shutdown_requested: false,
422 to_ws_sender: to_ws_tx,
423 to_ws_receiver: to_ws_rx,
424 runner_command_rx: runner_cmd_rx,
425 connection_callback,
426 };
427
428 Ok((client, runner))
429 }
430}
431
// Compile-time guarantee that the builder can be sent to and shared across threads.
#[cfg(test)]
mod tests {
    use super::*;

    /// Only type-checks when `T` implements both `Send` and `Sync`.
    fn require_send_and_sync<T: Send + Sync>() {}

    #[test]
    fn test_client_builder_send_sync() {
        // Losing either auto-trait on `ClientBuilder` turns this into a compile error.
        require_send_and_sync::<ClientBuilder<()>>();
    }
}