use std::{collections::HashMap, sync::Arc, fmt::Debug};
use async_tungstenite::tungstenite::{Message, self};
use futures::{prelude::*, channel::mpsc::{Sender, self, Receiver}, stream::{SplitSink, SplitStream}, lock::Mutex};
use tracing::{warn, error, debug, info};
use crate::{Authentication, Result, Frame, ClientMessage, Payload, Error, ServerMessage, Spawner};
/// Registry of in-flight requests: maps a request id to the sender half of the
/// channel through which server messages for that request are delivered.
type ResponseSenders = Arc<Mutex<HashMap<i32, Sender<ServerMessage>>>>;

/// Client handle that talks to the server over a WebSocket of type `S`.
pub struct Lighthouse<S> {
    /// Sending half of the WebSocket; serialized requests are written here.
    ws_sink: SplitSink<S, Message>,
    /// Response channels keyed by request id, shared with the background receive loop.
    txs: ResponseSenders,
    /// Credentials cloned into every outgoing request.
    authentication: Authentication,
    /// Id that will be assigned to the next outgoing request.
    request_id: i32,
}
impl<S> Lighthouse<S>
where
    S: Stream<Item = tungstenite::Result<Message>>
        + Sink<Message, Error = tungstenite::Error>
        + Unpin
        + Send
        + 'static,
{
    /// Creates a client from an already-connected WebSocket.
    ///
    /// The socket is split into its sending and receiving halves. The receiving half is
    /// handed to a background task, started via `W::spawn`, which routes incoming server
    /// messages to the channel registered for their request id.
    ///
    /// Currently always returns `Ok`.
    pub fn new<W>(web_socket: S, authentication: Authentication) -> Result<Self>
    where
        W: Spawner,
    {
        let (ws_sink, ws_stream) = web_socket.split();
        let txs = Arc::new(Mutex::new(HashMap::new()));
        let lh = Self {
            ws_sink,
            txs: txs.clone(),
            authentication,
            request_id: 0,
        };
        W::spawn(Self::run_receive_loop(ws_stream, txs));
        Ok(lh)
    }

    /// Background task: reads server messages and forwards each one to the channel
    /// registered for its request id.
    ///
    /// Runs until the WebSocket stream ends. Undecodable messages and stream errors are
    /// logged and skipped. On exit, every registered sender is dropped. This closes the
    /// corresponding receivers, so callers still awaiting a response see their channel
    /// end instead of waiting forever.
    #[tracing::instrument(skip(ws_stream, txs))]
    async fn run_receive_loop(
        mut ws_stream: SplitStream<S>,
        txs: Arc<Mutex<HashMap<i32, Sender<ServerMessage>>>>,
    ) {
        loop {
            match Self::receive_message_from(&mut ws_stream).await {
                Ok(Some(msg)) => Self::dispatch(&txs, msg).await,
                // Stop once the stream has ended: polling a finished stream again would
                // spin forever (and is unspecified behavior for non-fused streams).
                Ok(None) => {
                    info!("WebSocket stream ended, stopping receive loop");
                    break;
                }
                Err(e) => error!("Bad message: {:?}", e),
            }
        }
        txs.lock().await.clear();
    }

    /// Forwards `msg` to the channel registered for its request id.
    ///
    /// Messages without a request id, or with no registered channel, are logged and
    /// dropped. A sender whose receiver has been dropped is removed from the map.
    async fn dispatch(txs: &Mutex<HashMap<i32, Sender<ServerMessage>>>, msg: ServerMessage) {
        let request_id = match msg.request_id {
            Some(request_id) => request_id,
            None => {
                warn!("Got message without request id from server: {:?}", msg);
                return;
            }
        };
        // NOTE(review): the lock is held while `send` waits for capacity in the bounded
        // channel, so a receiver that is not being polled stalls this loop and blocks
        // registration of new requests until it drains.
        let mut txs = txs.lock().await;
        match txs.get_mut(&request_id) {
            Some(tx) => {
                if let Err(e) = tx.send(msg).await {
                    if e.is_disconnected() {
                        info!("Receiver for request id {} disconnected, removing the sender...", request_id);
                        txs.remove(&request_id);
                    } else {
                        warn!("Could not send message for request id {} via channel: {:?}", request_id, e);
                    }
                }
            }
            None => warn!("No channel registered for request id {}", request_id),
        }
    }

    /// Reads the next binary frame and decodes it as a MessagePack `ServerMessage`.
    ///
    /// Returns `Ok(None)` once the stream has ended.
    #[tracing::instrument(skip(ws_stream))]
    async fn receive_message_from(ws_stream: &mut SplitStream<S>) -> Result<Option<ServerMessage>> {
        match Self::receive_raw_from(ws_stream).await? {
            Some(bytes) => Ok(Some(rmp_serde::from_slice(&bytes)?)),
            None => Ok(None),
        }
    }

    /// Returns the payload of the next binary WebSocket message.
    ///
    /// Ping messages are skipped silently; other non-binary messages are logged and skipped.
    /// Returns `Ok(None)` once the stream has ended.
    #[tracing::instrument(skip(ws_stream))]
    async fn receive_raw_from(ws_stream: &mut SplitStream<S>) -> Result<Option<Vec<u8>>> {
        while let Some(message) = ws_stream.next().await {
            match message? {
                Message::Binary(bytes) => return Ok(Some(bytes)),
                Message::Ping(_) => {}
                message => warn!("Got non-binary message: {:?}", message),
            }
        }
        Ok(None)
    }

    /// Sends `frame` with a `PUT` to `["user", <username>, "model"]`.
    pub async fn put_model(&mut self, frame: Frame) -> Result<()> {
        let username = self.authentication.username.clone();
        self.put(["user", username.as_str(), "model"], Payload::Frame(frame)).await
    }

    /// Opens a `STREAM` on `["user", <username>, "model"]` and returns the channel
    /// receiving the server's messages for it.
    pub async fn stream_model(&mut self) -> Result<Receiver<ServerMessage>> {
        let username = self.authentication.username.clone();
        self.stream(["user", username.as_str(), "model"], Payload::Empty).await
    }

    /// Performs a one-off `PUT` request on `path`.
    pub async fn put(&mut self, path: impl IntoIterator<Item=&str> + Debug, payload: Payload) -> Result<()> {
        self.perform("PUT", path, payload).await
    }

    /// Sends a one-off request and waits for its single response.
    ///
    /// # Panics
    /// If `verb` is `"STREAM"`; use [`Lighthouse::stream`] for streaming requests.
    ///
    /// # Errors
    /// If sending fails, the receive loop stops before a response arrives, or the
    /// response fails `ServerMessage::check`.
    #[tracing::instrument(skip(self, payload))]
    pub async fn perform(&mut self, verb: &str, path: impl IntoIterator<Item=&str> + Debug, payload: Payload) -> Result<()> {
        assert_ne!(verb, "STREAM", "Lighthouse::perform may only be used for one-off requests, use Lighthouse::stream for streaming.");
        let (request_id, rx) = self.send_request(verb, path, payload).await?;
        let response = self.receive_single(request_id, rx).await?;
        response.check()?;
        Ok(())
    }

    /// Sends a `STREAM` request and returns the channel receiving the server's messages
    /// for it.
    ///
    /// The channel stays registered until its receiver is dropped and a further message
    /// for it arrives.
    #[tracing::instrument(skip(self, payload))]
    pub async fn stream(&mut self, path: impl IntoIterator<Item=&str> + Debug, payload: Payload) -> Result<Receiver<ServerMessage>> {
        let (_request_id, rx) = self.send_request("STREAM", path, payload).await?;
        Ok(rx)
    }

    /// Allocates a request id, registers a response channel for it, then sends the request.
    ///
    /// The channel is registered *before* the request goes out, so a fast response
    /// cannot arrive while no channel exists. If sending fails, the registration is undone.
    /// Returns the request id together with the receiver of its responses.
    async fn send_request(
        &mut self,
        verb: &str,
        path: impl IntoIterator<Item=&str> + Debug,
        payload: Payload,
    ) -> Result<(i32, Receiver<ServerMessage>)> {
        let path = path.into_iter().map(|s| s.to_owned()).collect();
        let request_id = self.request_id;
        // Wrap around instead of overflowing, which would panic in debug builds.
        self.request_id = self.request_id.wrapping_add(1);
        debug! { %request_id, "Sending request" };
        let rx = self.receive(request_id).await?;
        let message = ClientMessage {
            request_id,
            authentication: self.authentication.clone(),
            path,
            meta: HashMap::new(),
            verb: verb.to_owned(),
            payload,
        };
        if let Err(e) = self.send_message(&message).await {
            self.txs.lock().await.remove(&request_id);
            return Err(e);
        }
        Ok((request_id, rx))
    }

    /// Serializes `message` as named-field MessagePack and sends it over the socket.
    async fn send_message(&mut self, message: &ClientMessage) -> Result<()> {
        self.send_raw(rmp_serde::to_vec_named(message)?).await
    }

    /// Waits for the single response to a one-off request, then unregisters its channel
    /// so the sender does not linger in the map.
    #[tracing::instrument(skip(self, rx))]
    async fn receive_single(&self, request_id: i32, mut rx: Receiver<ServerMessage>) -> Result<ServerMessage> {
        let response = rx.next().await;
        self.txs.lock().await.remove(&request_id);
        response.ok_or_else(|| Error::Custom(format!("No response for {}", request_id)))
    }

    /// Registers a new bounded channel (buffer 4) for `request_id` and returns its receiver.
    ///
    /// Any channel previously registered under the same id is replaced.
    async fn receive(&self, request_id: i32) -> Result<Receiver<ServerMessage>> {
        let (tx, rx) = mpsc::channel(4);
        self.txs.lock().await.insert(request_id, tx);
        Ok(rx)
    }

    /// Sends `bytes` as a single binary WebSocket message.
    async fn send_raw(&mut self, bytes: impl Into<Vec<u8>> + Debug) -> Result<()> {
        Ok(self.ws_sink.send(Message::Binary(bytes.into())).await?)
    }
}