price_adapter/services/websocket.rs

1use crate::{
2    error::Error,
3    types::{PriceInfo, Service, Source, WebSocketSource, WebsocketMessage},
4};
5use std::{collections::HashMap, sync::Arc};
6use tokio::{select, sync::Mutex};
7use tokio_util::sync::CancellationToken;
8
9/// A caching object storing prices received from WebSocketSource.
10pub struct WebsocketService<S: WebSocketSource> {
11    socket: Arc<Mutex<S>>,
12    cached_prices: Arc<Mutex<HashMap<String, PriceInfo>>>,
13    cancellation_token: Option<CancellationToken>,
14}
15
16impl<S: WebSocketSource> WebsocketService<S> {
17    /// Creates a new `WebsocketService` with the provided WebSocketSource.
18    pub fn new(socket: S) -> Self {
19        Self {
20            socket: Arc::new(Mutex::new(socket)),
21            cached_prices: Arc::new(Mutex::new(HashMap::new())),
22            cancellation_token: None,
23        }
24    }
25}
26
27#[async_trait::async_trait]
28impl<S: WebSocketSource> Service for WebsocketService<S> {
29    /// Starts the service, connecting to the WebSocket and subscribing to symbols.
30    async fn start(&mut self, symbols: &[&str]) -> Result<(), Error> {
31        if self.is_started().await {
32            return Err(Error::AlreadyStarted);
33        }
34
35        let mut locked_socket = self.socket.lock().await;
36        if !locked_socket.is_connected().await {
37            locked_socket.connect().await?;
38            locked_socket.subscribe(symbols).await?;
39        }
40        drop(locked_socket);
41
42        let token = CancellationToken::new();
43        let cloned_token = token.clone();
44        let cloned_socket = Arc::clone(&self.socket);
45        let cloned_cached_prices = Arc::clone(&self.cached_prices);
46        self.cancellation_token = Some(token);
47
48        tokio::spawn(async move {
49            loop {
50                let mut locked_socket = cloned_socket.lock().await;
51                select! {
52                    _ = cloned_token.cancelled() => {
53                        break;
54                    }
55
56                    result = locked_socket.next() => {
57                        drop(locked_socket);
58
59                        match result {
60                            Some(Ok(WebsocketMessage::PriceInfo(price_info))) => {
61                                let mut locked_cached_prices = cloned_cached_prices.lock().await;
62                                locked_cached_prices.insert(price_info.symbol.to_string(), price_info);
63                            }
64                            Some(Ok(WebsocketMessage::SettingResponse(_response))) => {}
65                            Some(Err(err)) => {
66                                tracing::trace!("cannot get price: {}", err);
67                            }
68                            None => {
69                                tracing::trace!("cannot get price: stream ended");
70                                break;
71                            }
72                        }
73                    }
74                }
75            }
76        });
77
78        Ok(())
79    }
80
81    /// Stops the service, cancelling the WebSocket subscription.
82    async fn stop(&mut self) {
83        if let Some(token) = &self.cancellation_token {
84            token.cancel();
85        }
86        self.cancellation_token = None;
87    }
88
89    // To check if the service is started.
90    async fn is_started(&self) -> bool {
91        self.cancellation_token.is_some()
92    }
93}
94
95#[async_trait::async_trait]
96impl<S: WebSocketSource> Source for WebsocketService<S> {
97    /// Retrieves prices for the specified symbols from the cached prices.
98    async fn get_prices(&self, symbols: &[&str]) -> Vec<Result<PriceInfo, Error>> {
99        let locked_cached_prices = self.cached_prices.lock().await;
100        symbols
101            .iter()
102            .map(|&symbol| {
103                locked_cached_prices
104                    .get(&symbol.to_ascii_uppercase())
105                    .map_or_else(
106                        || Err(Error::NotFound(symbol.to_string())),
107                        |price| Ok(price.clone()),
108                    )
109            })
110            .collect()
111    }
112
113    // Asynchronous function to get price for a symbol.
114    async fn get_price(&self, symbol: &str) -> Result<PriceInfo, Error> {
115        self.get_prices(&[symbol])
116            .await
117            .pop()
118            .ok_or(Error::Unknown)?
119    }
120}