1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use std::collections::HashMap;

use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};

use crate::{
    error::{Error, Result},
    types::websocket::{
        request::{Channel, Method, Request},
        response::Response,
    },
};

/// Websocket client that holds an optional open stream and the set of
/// channels it has subscribed to through [`Websocket::subscribe`].
pub struct Websocket {
    /// The open stream. `Some` after a successful [`Websocket::connect`];
    /// [`Websocket::is_connected`] simply reports whether this is `Some`.
    pub stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Channels recorded by [`Websocket::subscribe`], keyed by the channel's `id`.
    pub channels: HashMap<u64, Channel>,
    /// Endpoint handed to `connect_async` by [`Websocket::connect`].
    pub url: String,
}

impl Websocket {
    /// Returns `true` if the websocket is connected
    pub async fn is_connected(&self) -> bool {
        self.stream.is_some()
    }

    /// Connect to the websocket
    pub async fn connect(&mut self) -> Result<()> {
        let (stream, _) = connect_async(&self.url).await?;
        self.stream = Some(stream);

        Ok(())
    }

    /// Disconnect from the websocket
    pub async fn disconnect(&mut self) -> Result<()> {
        self.unsubscribe_all().await?;

        self.stream = None;
        Ok(())
    }

    /// Subscribe to the given channels
    /// - `channels` - The channels to subscribe to
    pub async fn subscribe(&mut self, channels: &[Channel]) -> Result<()> {
        self.send(channels, true).await?;

        channels.iter().for_each(|channel| {
            self.channels.insert(channel.id, channel.clone());
        });

        Ok(())
    }

    /// Unsubscribe from the given channels
    /// - `channels` - The channels to unsubscribe from
    pub async fn unsubscribe(&mut self, ids: &[u64]) -> Result<()> {
        let channels = ids
            .iter()
            .map(|id| {
                self.channels
                    .get(id)
                    .ok_or_else(|| Error::NotSubscribed(*id))
                    .map(|channel| channel.clone())
            })
            .collect::<Result<Vec<Channel>>>()?;

        self.send(&channels, false).await?;

        channels.iter().for_each(|channel| {
            self.channels.remove(&channel.id);
        });

        Ok(())
    }

    /// Unsubscribe from all channels
    pub async fn unsubscribe_all(&mut self) -> Result<()> {
        let channels: Vec<Channel> = self.channels.values().cloned().collect();

        self.send(&channels, false).await
    }

    pub async fn next<Callback>(&mut self, handler: Callback) -> Result<Option<bool>>
    where
        Callback: Fn(Response) -> Result<()>,
    {
        if let Some(stream) = &mut self.stream {
            while let Some(message) = stream.next().await {
                let message = message?;

                if let Message::Text(text) = message {
                    if !text.starts_with('{') {
                        continue;
                    }
                    let response = serde_json::from_str(&text)?;

                    (handler)(response)?;
                }
            }
        }

        Ok(None)
    }

    /// Send a message request
    /// - `channels` is a list of subscriptions to send
    /// - `subscribe` is a boolean indicating whether to subscribe or unsubscribe
    async fn send(&mut self, channels: &[Channel], subscribe: bool) -> Result<()> {
        if let Some(stream) = &mut self.stream {
            for channel in channels {
                let method = if subscribe {
                    Method::Subscribe
                } else {
                    Method::Unsubscribe
                };

                let request = Request {
                    method,
                    subscription: channel.sub.clone(),
                };

                let message = Message::Text(serde_json::to_string(&request)?);

                stream.send(message).await?;
            }

            return Ok(());
        }

        Err(Error::NotConnected)
    }
}