use super::messaging;
use crate::err;
use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt, TryStreamExt,
};
use std::{collections::HashMap, sync::Arc, time::Duration};
use tokio::{
net::TcpStream,
sync::{oneshot, Mutex},
time,
};
use tokio_tungstenite::{
connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
};
/// Concrete stream type returned by `connect_async`: a WebSocket running over
/// a TCP stream that may or may not be wrapped in TLS.
type WebSocketConnection = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Write half of a split [`WebSocketConnection`].
type Outbound = SplitSink<WebSocketConnection, Message>;

/// Read half of a split [`WebSocketConnection`].
type Inbound = SplitStream<WebSocketConnection>;

/// Requests awaiting a response, keyed by request id. Each value is the
/// one-shot channel end used to hand the matching response to the caller.
type PendingPool = HashMap<String, oneshot::Sender<messaging::Response>>;

/// Request/response client over a single WebSocket connection.
///
/// Every field sits behind an `Arc<Mutex<..>>`, so clones of a `Client` share
/// the same connection halves and the same table of pending requests.
#[derive(Clone)]
pub struct Client {
    /// Write half of the connection; outgoing requests are sent through it.
    sender: Arc<Mutex<Outbound>>,
    /// Read half of the connection; incoming messages are read from it.
    receiver: Arc<Mutex<Inbound>>,
    /// Pending requests waiting for the response carrying the same id.
    pool: Arc<Mutex<PendingPool>>,
}
impl Client {
    /// Opens a WebSocket connection to `{host}/rpc` and builds a client around
    /// the two halves of the resulting stream.
    ///
    /// The returned client does not read incoming messages by itself; run
    /// [`Client::start`] or [`Client::listen`] so responses get dispatched.
    ///
    /// # Errors
    ///
    /// Returns the error built by `err!(client, WebSocketConnectionFail)` when
    /// the connection attempt fails. The underlying cause is discarded.
    pub async fn connect(host: &str) -> super::Result<Self> {
        let (connection, _) = connect_async(format!("{host}/rpc"))
            .await
            .map_err(|_| err!(client, WebSocketConnectionFail))?;
        let (sender, receiver) = connection.split();
        Ok(Self {
            sender: Arc::new(Mutex::new(sender)),
            receiver: Arc::new(Mutex::new(receiver)),
            pool: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Spawns a background task that runs [`Client::listen`] on a clone of
    /// this client, then returns without waiting for that task.
    pub async fn start(&self) {
        let clone = self.clone();
        tokio::spawn(async move {
            clone.listen().await;
        });
    }

    /// Reads messages from the connection and hands each decoded response to
    /// the pending request registered under the same id.
    ///
    /// Non-text messages, text messages that do not decode as a
    /// `messaging::Response`, and responses without a pending request are
    /// ignored.
    ///
    /// The loop stops when the stream ends or yields an error. On exit every
    /// still-pending request is dropped, which makes its receiver fail instead
    /// of waiting forever for a response that can no longer arrive.
    pub async fn listen(&self) {
        loop {
            // The lock guard is a temporary of this statement, so the receiver
            // is unlocked again before the message is processed.
            let next = self.receiver.lock().await.try_next().await;
            // `Ok(None)` means the stream is finished. Treating it as "keep
            // going" would poll an ended stream in a tight loop forever.
            let Ok(Some(message)) = next else {
                break;
            };
            let Message::Text(text) = message else {
                continue;
            };
            let Ok(decoded) = serde_json::from_str::<messaging::Response>(&text) else {
                continue;
            };
            // Dispatch inline (not in a spawned task) so that every response
            // received before the stream ended is delivered before the pool
            // is cleared below.
            if let Some(sender) = self.pool.lock().await.remove(&decoded.id) {
                // The caller may have stopped waiting; that is not an error.
                sender.send(decoded).ok();
            }
        }
        // Dropping the remaining senders wakes every waiting receiver with an
        // error.
        self.pool.lock().await.clear();
    }

    /// Sends `request` and returns a receiver that resolves with the response
    /// carrying the same id.
    ///
    /// # Errors
    ///
    /// Returns the error built by `err!(client, WebSocketSendFail)` when the
    /// request cannot be serialized or cannot be written to the connection.
    /// In both cases no pending entry is left behind.
    pub async fn cast(
        &self,
        request: messaging::Request,
    ) -> super::Result<oneshot::Receiver<messaging::Response>> {
        // Serialize before registering anything: a request that cannot be
        // encoded must fail here rather than go out as an empty message that
        // will never be answered.
        let raw_request =
            serde_json::to_string(&request).map_err(|_| err!(client, WebSocketSendFail))?;
        let (sender, receiver) = oneshot::channel::<messaging::Response>();
        // Register before sending so a fast response cannot arrive ahead of
        // its pending entry.
        self.pool.lock().await.insert(request.id.clone(), sender);
        let result = self
            .sender
            .lock()
            .await
            .send(Message::Text(raw_request.into()))
            .await;
        if result.is_err() {
            // Nothing was sent, so no response will ever claim this entry.
            self.pool.lock().await.remove(&request.id);
            return Err(err!(client, WebSocketSendFail));
        }
        Ok(receiver)
    }

    /// Sends `request` and waits, without a time limit, for its response.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Client::cast`], or returns the error built
    /// by `err!(client, WebSocketRecvFail)` when the pending entry is dropped
    /// before a response is delivered.
    pub async fn call(&self, request: messaging::Request) -> super::Result<messaging::Response> {
        self.cast(request)
            .await?
            .await
            .map_err(|_| err!(client, WebSocketRecvFail))
    }

    /// Like [`Client::call`], but gives up after `timeout`.
    ///
    /// # Errors
    ///
    /// Returns the error built by `err!(client, WebSocketRecvTimeout)` when
    /// `timeout` elapses first; otherwise returns whatever [`Client::call`]
    /// returns.
    pub async fn call_in(
        &self,
        request: messaging::Request,
        timeout: Duration,
    ) -> super::Result<messaging::Response> {
        // Keep the id so the pending entry can be removed if we give up.
        let id = request.id.clone();
        match time::timeout(timeout, self.call(request)).await {
            Ok(outcome) => outcome,
            Err(_) => {
                // Without this, an entry whose response never arrives would
                // stay in the pool for the lifetime of the client.
                self.pool.lock().await.remove(&id);
                Err(err!(client, WebSocketRecvTimeout))
            }
        }
    }
}