price_adapter/services/
websocket.rs1use crate::{
2 error::Error,
3 types::{PriceInfo, Service, Source, WebSocketSource, WebsocketMessage},
4};
5use std::{collections::HashMap, sync::Arc};
6use tokio::{select, sync::Mutex};
7use tokio_util::sync::CancellationToken;
8
9pub struct WebsocketService<S: WebSocketSource> {
11 socket: Arc<Mutex<S>>,
12 cached_prices: Arc<Mutex<HashMap<String, PriceInfo>>>,
13 cancellation_token: Option<CancellationToken>,
14}
15
16impl<S: WebSocketSource> WebsocketService<S> {
17 pub fn new(socket: S) -> Self {
19 Self {
20 socket: Arc::new(Mutex::new(socket)),
21 cached_prices: Arc::new(Mutex::new(HashMap::new())),
22 cancellation_token: None,
23 }
24 }
25}
26
27#[async_trait::async_trait]
28impl<S: WebSocketSource> Service for WebsocketService<S> {
29 async fn start(&mut self, symbols: &[&str]) -> Result<(), Error> {
31 if self.is_started().await {
32 return Err(Error::AlreadyStarted);
33 }
34
35 let mut locked_socket = self.socket.lock().await;
36 if !locked_socket.is_connected().await {
37 locked_socket.connect().await?;
38 locked_socket.subscribe(symbols).await?;
39 }
40 drop(locked_socket);
41
42 let token = CancellationToken::new();
43 let cloned_token = token.clone();
44 let cloned_socket = Arc::clone(&self.socket);
45 let cloned_cached_prices = Arc::clone(&self.cached_prices);
46 self.cancellation_token = Some(token);
47
48 tokio::spawn(async move {
49 loop {
50 let mut locked_socket = cloned_socket.lock().await;
51 select! {
52 _ = cloned_token.cancelled() => {
53 break;
54 }
55
56 result = locked_socket.next() => {
57 drop(locked_socket);
58
59 match result {
60 Some(Ok(WebsocketMessage::PriceInfo(price_info))) => {
61 let mut locked_cached_prices = cloned_cached_prices.lock().await;
62 locked_cached_prices.insert(price_info.symbol.to_string(), price_info);
63 }
64 Some(Ok(WebsocketMessage::SettingResponse(_response))) => {}
65 Some(Err(err)) => {
66 tracing::trace!("cannot get price: {}", err);
67 }
68 None => {
69 tracing::trace!("cannot get price: stream ended");
70 break;
71 }
72 }
73 }
74 }
75 }
76 });
77
78 Ok(())
79 }
80
81 async fn stop(&mut self) {
83 if let Some(token) = &self.cancellation_token {
84 token.cancel();
85 }
86 self.cancellation_token = None;
87 }
88
89 async fn is_started(&self) -> bool {
91 self.cancellation_token.is_some()
92 }
93}
94
95#[async_trait::async_trait]
96impl<S: WebSocketSource> Source for WebsocketService<S> {
97 async fn get_prices(&self, symbols: &[&str]) -> Vec<Result<PriceInfo, Error>> {
99 let locked_cached_prices = self.cached_prices.lock().await;
100 symbols
101 .iter()
102 .map(|&symbol| {
103 locked_cached_prices
104 .get(&symbol.to_ascii_uppercase())
105 .map_or_else(
106 || Err(Error::NotFound(symbol.to_string())),
107 |price| Ok(price.clone()),
108 )
109 })
110 .collect()
111 }
112
113 async fn get_price(&self, symbol: &str) -> Result<PriceInfo, Error> {
115 self.get_prices(&[symbol])
116 .await
117 .pop()
118 .ok_or(Error::Unknown)?
119 }
120}