use std::collections::HashMap;
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
use crate::{
error::{Error, Result},
types::websocket::{
request::{Channel, Method, Request},
response::Response,
},
};
/// WebSocket client state: the connection handle (if any), the channels
/// recorded as subscribed, and the endpoint URL.
pub struct Websocket {
/// Underlying WebSocket stream; set to `Some` by `connect` and reset to
/// `None` by `disconnect`. Methods that need the socket check this first.
pub stream: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
/// Channels recorded by `subscribe`, keyed by the channel's `id`.
pub channels: HashMap<u64, Channel>,
/// URL passed to `connect_async` when `connect` is called.
pub url: String,
}
impl Websocket {
pub async fn is_connected(&self) -> bool {
self.stream.is_some()
}
pub async fn connect(&mut self) -> Result<()> {
let (stream, _) = connect_async(&self.url).await?;
self.stream = Some(stream);
Ok(())
}
pub async fn disconnect(&mut self) -> Result<()> {
self.unsubscribe_all().await?;
self.stream = None;
Ok(())
}
pub async fn subscribe(&mut self, channels: &[Channel]) -> Result<()> {
self.send(channels, true).await?;
channels.iter().for_each(|channel| {
self.channels.insert(channel.id, channel.clone());
});
Ok(())
}
pub async fn unsubscribe(&mut self, ids: &[u64]) -> Result<()> {
let channels = ids
.iter()
.map(|id| {
self.channels
.get(id)
.ok_or_else(|| Error::NotSubscribed(*id))
.map(|channel| channel.clone())
})
.collect::<Result<Vec<Channel>>>()?;
self.send(&channels, false).await?;
channels.iter().for_each(|channel| {
self.channels.remove(&channel.id);
});
Ok(())
}
pub async fn unsubscribe_all(&mut self) -> Result<()> {
let channels: Vec<Channel> = self.channels.values().cloned().collect();
self.send(&channels, false).await
}
pub async fn next<Callback>(&mut self, handler: Callback) -> Result<Option<bool>>
where
Callback: Fn(Response) -> Result<()>,
{
if let Some(stream) = &mut self.stream {
while let Some(message) = stream.next().await {
let message = message?;
if let Message::Text(text) = message {
if !text.starts_with('{') {
continue;
}
let response = serde_json::from_str(&text)?;
(handler)(response)?;
}
}
}
Ok(None)
}
async fn send(&mut self, channels: &[Channel], subscribe: bool) -> Result<()> {
if let Some(stream) = &mut self.stream {
for channel in channels {
let method = if subscribe {
Method::Subscribe
} else {
Method::Unsubscribe
};
let request = Request {
method,
subscription: channel.sub.clone(),
};
let message = Message::Text(serde_json::to_string(&request)?);
stream.send(message).await?;
}
return Ok(());
}
Err(Error::NotConnected)
}
}