1use std::ops::Deref;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_channel::{Receiver, RecvError};
6use futures_util::future::try_join3;
7use futures_util::stream::{select_all, SplitSink, SplitStream};
8use futures_util::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::task::JoinHandle;
11use tokio::time::sleep;
12use tokio_tungstenite::tungstenite::Message;
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
14use tracing::{debug, error, info, warn};
15
16use crate::constants::MAX_CHANNEL_CAPACITY;
17use crate::error::{BinaryOptionsResult, BinaryOptionsToolsError};
18use crate::general::stream::RecieverStream;
19use crate::general::types::MessageType;
20
21use super::send::SenderMessage;
22use super::traits::{Callback, Connect, Credentials, DataHandler, MessageHandler, MessageTransfer};
23use super::types::Data;
24
/// Maximum number of consecutive failed reconnection attempts before the event loop gives up.
const MAX_ALLOWED_LOOPS: u32 = 8;
/// Delay, in seconds, between two failed reconnection attempts.
const SLEEP_INTERVAL: u64 = 2;
27
28#[derive(Clone)]
29pub struct WebSocketClient<Transfer, Handler, Connector, Creds, T, C>
30where
31 Transfer: MessageTransfer,
32 Handler: MessageHandler,
33 Connector: Connect,
34 Creds: Credentials,
35 T: DataHandler,
36 C: Callback,
37{
38 inner: Arc<WebSocketInnerClient<Transfer, Handler, Connector, Creds, T, C>>,
39}
40
41pub struct WebSocketInnerClient<Transfer, Handler, Connector, Creds, T, C>
42where
43 Transfer: MessageTransfer,
44 Handler: MessageHandler,
45 Connector: Connect,
46 Creds: Credentials,
47 T: DataHandler,
48 C: Callback,
49{
50 pub credentials: Creds,
51 pub connector: Connector,
52 pub handler: Handler,
53 pub data: Data<T, Transfer>,
54 pub sender: SenderMessage,
55 pub reconnect_callback: Option<C>,
56 pub reconnect_time: u64,
57 _event_loop: JoinHandle<BinaryOptionsResult<()>>,
58}
59
60impl<Transfer, Handler, Connector, Creds, T, C> Deref
61 for WebSocketClient<Transfer, Handler, Connector, Creds, T, C>
62where
63 Transfer: MessageTransfer,
64 Handler: MessageHandler,
65 Connector: Connect,
66 Creds: Credentials,
67 T: DataHandler,
68 C: Callback,
69{
70 type Target = WebSocketInnerClient<Transfer, Handler, Connector, Creds, T, C>;
71
72 fn deref(&self) -> &Self::Target {
73 self.inner.as_ref()
74 }
75}
76
77impl<Transfer, Handler, Connector, Creds, T, C>
78 WebSocketClient<Transfer, Handler, Connector, Creds, T, C>
79where
80 Transfer: MessageTransfer + 'static,
81 Handler: MessageHandler<Transfer = Transfer> + 'static,
82 Creds: Credentials + 'static,
83 Connector: Connect<Creds = Creds> + 'static,
84 T: DataHandler<Transfer = Transfer> + 'static,
85 C: Callback<T = T, Transfer = Transfer> + 'static,
86{
87 pub async fn init(
88 credentials: Creds,
89 connector: Connector,
90 data: Data<T, Transfer>,
91 handler: Handler,
92 timeout: Duration,
93 reconnect_callback: Option<C>,
94 reconnect_time: Option<u64>,
95 ) -> BinaryOptionsResult<Self> {
96 let inner = WebSocketInnerClient::init(
97 credentials,
98 connector,
99 data,
100 handler,
101 timeout,
102 reconnect_callback,
103 reconnect_time.unwrap_or_default(),
104 )
105 .await?;
106 Ok(Self {
107 inner: Arc::new(inner),
108 })
109 }
110}
111
112impl<Transfer, Handler, Connector, Creds, T, C>
113 WebSocketInnerClient<Transfer, Handler, Connector, Creds, T, C>
114where
115 Transfer: MessageTransfer + 'static,
116 Handler: MessageHandler<Transfer = Transfer> + 'static,
117 Creds: Credentials + 'static,
118 Connector: Connect<Creds = Creds> + 'static,
119 T: DataHandler<Transfer = Transfer> + 'static,
120 C: Callback<T = T, Transfer = Transfer> + 'static,
121{
122 pub async fn init(
123 credentials: Creds,
124 connector: Connector,
125 data: Data<T, Transfer>,
126 handler: Handler,
127 timeout: Duration,
128 reconnect_callback: Option<C>,
129 reconnect_time: u64,
130 ) -> BinaryOptionsResult<Self> {
131 let _connection = connector.connect(credentials.clone()).await?;
132 let (_event_loop, sender) = Self::start_loops(
133 handler.clone(),
134 credentials.clone(),
135 data.clone(),
136 connector.clone(),
137 reconnect_callback.clone(),
138 reconnect_time,
139 )
140 .await?;
141 info!("Started WebSocketClient");
142 sleep(timeout).await;
143 Ok(Self {
144 credentials,
145 connector,
146 handler,
147 data,
148 sender,
149 reconnect_callback,
150 reconnect_time,
151 _event_loop,
152 })
153 }
154
155 async fn start_loops(
156 handler: Handler,
157 credentials: Creds,
158 data: Data<T, Transfer>,
159 connector: Connector,
160 reconnect_callback: Option<C>,
161 time: u64,
162 ) -> BinaryOptionsResult<(JoinHandle<BinaryOptionsResult<()>>, SenderMessage)> {
163 let (mut write, mut read) = connector.connect(credentials.clone()).await?.split();
164 let (sender, (reciever, reciever_priority)) = SenderMessage::new(MAX_CHANNEL_CAPACITY);
165 let loop_sender = sender.clone();
166 let task = tokio::task::spawn(async move {
167 let previous = None;
168 let mut loops = 0;
169 let mut reconnected = false;
170 loop {
171 let listener_future = WebSocketInnerClient::<
172 Transfer,
173 Handler,
174 Connector,
175 Creds,
176 T,
177 C,
178 >::listener_loop(
179 previous.clone(),
180 &data,
181 handler.clone(),
182 &loop_sender,
183 &mut read,
184 );
185 let sender_future =
186 WebSocketInnerClient::<Transfer, Handler, Connector, Creds, T, C>::sender_loop(
187 &mut write,
188 &reciever,
189 &reciever_priority,
190 time,
191 );
192
193 let callback = WebSocketInnerClient::<Transfer, Handler, Connector, Creds, T, C>::reconnect_callback(reconnect_callback.clone(), data.clone(), loop_sender.clone(), reconnected, time);
199
200 match try_join3(listener_future, sender_future, callback).await {
201 Ok(_) => {
202 if let Ok(websocket) = connector.connect(credentials.clone()).await {
203 (write, read) = websocket.split();
204 info!("Reconnected successfully!");
205 loops = 0;
206 reconnected = true;
207 } else {
208 loops += 1;
209 warn!("Error reconnecting... trying again in {SLEEP_INTERVAL} seconds (try {loops} of {MAX_ALLOWED_LOOPS}");
210 sleep(Duration::from_secs(SLEEP_INTERVAL)).await;
211 if loops >= MAX_ALLOWED_LOOPS {
212 panic!("Too many failed connections");
213 }
214 }
215 }
216 Err(e) => {
217 warn!("Error in event loop, {e}, reconnecting...");
218 if let Ok(websocket) = connector.connect(credentials.clone()).await {
220 (write, read) = websocket.split();
221 info!("Reconnected successfully!");
222 loops = 0;
224 reconnected = true;
225 } else {
226 loops += 1;
227 warn!("Error reconnecting... trying again in {SLEEP_INTERVAL} seconds (try {loops} of {MAX_ALLOWED_LOOPS}");
228 sleep(Duration::from_secs(SLEEP_INTERVAL)).await;
229 if loops >= MAX_ALLOWED_LOOPS {
230 error!("Too many failed connections");
231 break;
232 }
233 }
234 }
235 }
236 }
237 Ok(())
238 });
239 Ok((task, sender))
240 }
241
242 async fn listener_loop(
244 mut previous: Option<<<Handler as MessageHandler>::Transfer as MessageTransfer>::Info>,
245 data: &Data<T, Transfer>,
246 handler: Handler,
247 sender: &SenderMessage,
248 ws: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
249 ) -> BinaryOptionsResult<()> {
250 while let Some(msg) = &ws.next().await {
251 let msg = msg
252 .as_ref()
253 .inspect_err(|e| warn!("Error recieving websocket message, {e}"))
254 .map_err(|e| {
255 BinaryOptionsToolsError::WebsocketRecievingConnectionError(e.to_string())
256 })?;
257 match handler.process_message(msg, &previous, sender).await {
258 Ok((msg, close)) => {
259 if close {
260 info!("Recieved closing frame");
261 return Err(BinaryOptionsToolsError::WebsocketConnectionClosed(
262 "Recieved closing frame".into(),
263 ));
264 }
265 if let Some(msg) = msg {
266 match msg {
267 MessageType::Info(info) => {
268 debug!("Recieved info: {}", info);
269 previous = Some(info);
270 }
271 MessageType::Transfer(transfer) => {
272 debug!("Recieved data of type: {}", transfer.info());
273 if let Some(senders) = data.update_data(transfer.clone()).await? {
274 for sender in senders {
275 sender.send(transfer.clone()).await.map_err(|e| {
276 BinaryOptionsToolsError::ChannelRequestSendingError(
277 e.to_string(),
278 )
279 })?;
280 }
281 }
282 }
283 }
284 }
285 }
286 Err(e) => {
287 debug!("Error processing message, {e}");
288 }
289 }
290 }
291 todo!()
292 }
293
294 async fn sender_loop(
296 ws: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
297 reciever: &Receiver<Message>,
298 reciever_priority: &Receiver<Message>,
299 time: u64,
300 ) -> BinaryOptionsResult<()> {
301 async fn priority_mesages(
302 ws: &mut SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
303 reciever_priority: &Receiver<Message>,
304 ) -> BinaryOptionsResult<()> {
305 while let Ok(msg) = reciever_priority.recv().await {
306 ws.send(msg)
307 .await
308 .inspect_err(|e| warn!("Error sending message to websocket, {e}"))?;
309 ws.flush().await?;
310 debug!("Sent message to websocket!");
311 }
312 Err(BinaryOptionsToolsError::ChannelRequestRecievingError(
313 RecvError,
314 ))
315 }
316
317 tokio::select! {
318 res = priority_mesages(ws, reciever_priority) => res?,
319 _ = sleep(Duration::from_secs(time)) => {}
320 }
321 let stream1 = RecieverStream::new(reciever.to_owned());
322 let stream2 = RecieverStream::new(reciever_priority.to_owned());
323 let mut fused_streams = select_all([stream1.to_stream(), stream2.to_stream()]);
324
325 while let Some(Ok(msg)) = fused_streams.next().await {
326 ws.send(msg)
327 .await
328 .inspect_err(|e| warn!("Error sending message to websocket, {e}"))?;
329 ws.flush().await?;
330 debug!("Sent message to websocket!");
331 }
332 Err(BinaryOptionsToolsError::ChannelRequestRecievingError(
333 RecvError,
334 ))
335 }
336
337 async fn reconnect_callback(
348 reconnect_callback: Option<C>,
349 data: Data<T, Transfer>,
350 sender: SenderMessage,
351 reconnect: bool,
352 reconnect_time: u64,
353 ) -> BinaryOptionsResult<BinaryOptionsResult<()>> {
354 Ok(tokio::spawn(async move {
355 sleep(Duration::from_secs(reconnect_time)).await;
356 if reconnect {
357 if let Some(callback) = &reconnect_callback {
358 callback.call(data.clone(), &sender).await.inspect_err(
359 |e| error!(target: "EventLoop","Error calling callback, {e}"),
360 )?;
361 }
362 }
363 Ok(())
364 })
365 .await?)
366 }
367 pub async fn send_message(
368 &self,
369 msg: Transfer,
370 response_type: Transfer::Info,
371 validator: impl Fn(&Transfer) -> bool + Send + Sync,
372 ) -> BinaryOptionsResult<Transfer> {
373 self.sender
374 .send_message(&self.data, msg, response_type, validator)
375 .await
376 }
377
378 pub async fn send_message_with_timout(
379 &self,
380 timeout: Duration,
381 task: impl ToString,
382 msg: Transfer,
383 response_type: Transfer::Info,
384 validator: impl Fn(&Transfer) -> bool + Send + Sync,
385 ) -> BinaryOptionsResult<Transfer> {
386 self.sender
387 .send_message_with_timout(timeout, task, &self.data, msg, response_type, validator)
388 .await
389 }
390
391 pub async fn send_message_with_timeout_and_retry(
392 &self,
393 timeout: Duration,
394 task: impl ToString,
395 msg: Transfer,
396 response_type: Transfer::Info,
397 validator: impl Fn(&Transfer) -> bool + Send + Sync,
398 ) -> BinaryOptionsResult<Transfer> {
399 self.sender
400 .send_message_with_timeout_and_retry(
401 timeout,
402 task,
403 &self.data,
404 msg,
405 response_type,
406 validator,
407 )
408 .await
409 }
410}
411
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use async_channel::{bounded, Receiver, Sender};
    use futures_util::{
        future::try_join,
        stream::{select_all, unfold},
        Stream, StreamExt,
    };
    use rand::{distributions::Alphanumeric, Rng};
    use tokio::time::{sleep, timeout};
    use tracing::info;

    use crate::utils::tracing::start_tracing;

    /// Test-local stream adapter over an `async_channel` receiver.
    struct RecieverStream<T> {
        inner: Receiver<T>,
    }

    impl<T> RecieverStream<T> {
        fn new(inner: Receiver<T>) -> Self {
            Self { inner }
        }

        async fn receive(&self) -> anyhow::Result<T> {
            Ok(self.inner.recv().await?)
        }

        /// Endless stream of `receive` results: once the channel is closed it keeps yielding
        /// errors rather than ending.
        fn to_stream(&self) -> impl Stream<Item = anyhow::Result<T>> + '_ {
            Box::pin(unfold(self, |state| async move {
                let item = state.receive().await;
                Some((item, state))
            }))
        }
    }

    /// Reads only the priority channel for 5 seconds, then both channels merged.
    async fn recieve_dif(
        reciever: Receiver<String>,
        receiver_priority: Receiver<String>,
    ) -> anyhow::Result<()> {
        async fn receiv(r: &Receiver<String>) -> anyhow::Result<()> {
            while let Ok(t) = r.recv().await {
                info!(target: "High priority", "Recieved: {}", t);
            }
            Ok(())
        }
        tokio::select! {
            err = receiv(&receiver_priority) => err?,
            _ = tokio::time::sleep(Duration::from_secs(5)) => {}
        }
        let receiver = RecieverStream::new(reciever);
        let receiver_priority = RecieverStream::new(receiver_priority);
        let mut fused = select_all([receiver.to_stream(), receiver_priority.to_stream()]);
        while let Some(value) = fused.next().await {
            info!(target: "Fused", "Recieved: {}", value?);
        }

        Ok(())
    }

    /// Same as `recieve_dif`, but the priority phase fails on its third message.
    async fn recieve_dif_err(
        reciever: Receiver<String>,
        receiver_priority: Receiver<String>,
    ) -> anyhow::Result<()> {
        async fn receiv(r: &Receiver<String>) -> anyhow::Result<()> {
            let mut loops = 0;
            while let Ok(t) = r.recv().await {
                if loops == 2 {
                    return Err(anyhow::Error::msg("error receiving message"));
                }
                loops += 1;
                info!(target: "High priority", "Recieved: {}", t);
            }
            Ok(())
        }
        tokio::select! {
            err = receiv(&receiver_priority) => err?,
            _ = tokio::time::sleep(Duration::from_secs(5)) => {}
        }
        let receiver = RecieverStream::new(reciever);
        let receiver_priority = RecieverStream::new(receiver_priority);
        let mut fused = select_all([receiver.to_stream(), receiver_priority.to_stream()]);
        while let Some(value) = fused.next().await {
            info!(target: "Fused", "Recieved: {}", value?);
        }

        Ok(())
    }

    /// Sends one random string to each channel every second until a send fails.
    async fn sender_dif(
        sender: Sender<String>,
        sender_priority: Sender<String>,
    ) -> anyhow::Result<()> {
        loop {
            let s1: String = rand::thread_rng()
                .sample_iter(&Alphanumeric)
                .take(7)
                .map(char::from)
                .collect();
            let s2: String = rand::thread_rng()
                .sample_iter(&Alphanumeric)
                .take(7)
                .map(char::from)
                .collect();
            sender.send(s1).await?;
            sender_priority.send(s2).await?;
            sleep(Duration::from_secs(1)).await;
        }
    }

    #[tokio::test]
    async fn test_multi_priority_reciever_ok() -> anyhow::Result<()> {
        start_tracing(true)?;
        let (s, r) = bounded(8);
        let (sp, rp) = bounded(8);
        // Both futures only stop on an error, so an unbounded run never finishes. Running
        // past both receive phases (5 s priority-only, then fused) without an error and then
        // hitting the timeout counts as success.
        if let Ok(result) = timeout(
            Duration::from_secs(8),
            try_join(sender_dif(s, sp), recieve_dif(r, rp)),
        )
        .await
        {
            result?;
        }
        Ok(())
    }

    #[tokio::test]
    #[should_panic(expected = "error receiving message")]
    async fn test_multi_priority_reciever_err() {
        start_tracing(true).unwrap();
        let (s, r) = bounded(8);
        let (sp, rp) = bounded(8);
        try_join(sender_dif(s, sp), recieve_dif_err(r, rp))
            .await
            .unwrap();
    }
}
558}