use std::{collections::HashMap, sync::Arc, fmt::Debug};
use async_tungstenite::tungstenite::{Message, self};
use futures::{prelude::*, channel::mpsc::{Sender, self}, stream::{SplitSink, SplitStream}, lock::Mutex};
use lighthouse_protocol::{Authentication, ClientMessage, DirectoryTree, Frame, LaserMetrics, Model, ServerMessage, Value, Verb};
use serde::{Deserialize, Serialize};
use stream_guard::GuardStreamExt;
use tracing::{warn, error, debug, info};
use crate::{Check, Error, Result, Spawner};
pub struct Lighthouse<S> {
ws_sink: SplitSink<S, Message>,
slots: Arc<Mutex<HashMap<i32, Slot<ServerMessage<Value>>>>>,
authentication: Authentication,
request_id: i32,
}
enum Slot<M> {
EarlyMessages(Vec<M>),
WaitForMessages(Sender<M>),
}
impl<S> Lighthouse<S>
where S: Stream<Item = tungstenite::Result<Message>>
+ Sink<Message, Error = tungstenite::Error>
+ Send
+ 'static {
pub fn new<W>(web_socket: S, authentication: Authentication) -> Result<Self> where W: Spawner {
let (ws_sink, ws_stream) = web_socket.split();
let slots = Arc::new(Mutex::new(HashMap::new()));
let lh = Self {
ws_sink,
slots: slots.clone(),
authentication,
request_id: 0,
};
W::spawn(Self::run_receive_loop(ws_stream, slots));
Ok(lh)
}
#[tracing::instrument(skip(ws_stream, slots))]
async fn run_receive_loop(mut ws_stream: SplitStream<S>, slots: Arc<Mutex<HashMap<i32, Slot<ServerMessage<Value>>>>>) {
loop {
match Self::receive_message_from(&mut ws_stream).await {
Ok(msg) => {
let mut slots = slots.lock().await;
if let Some(request_id) = msg.request_id {
if let Some(slot) = slots.get_mut(&request_id) {
match slot {
Slot::EarlyMessages(msgs) => msgs.push(msg),
Slot::WaitForMessages(tx) => {
if let Err(e) = tx.send(msg).await {
if e.is_disconnected() {
info!("Receiver for request id {} disconnected, removing the sender...", request_id);
slots.remove(&request_id);
} else {
warn!("Could not send message for request id {} via channel: {:?}", request_id, e);
}
}
}
}
} else {
slots.insert(request_id, Slot::EarlyMessages(vec![msg]));
}
} else {
warn!("Got message without request id from server: {:?}", msg);
}
},
Err(Error::NoNextMessage) => {
info!("No next message available, closing receive loop");
break
},
Err(e) => error!("Bad message: {:?}", e),
}
}
}
#[tracing::instrument(skip(ws_stream))]
async fn receive_message_from<P>(ws_stream: &mut SplitStream<S>) -> Result<ServerMessage<P>>
where
P: for<'de> Deserialize<'de> {
let bytes = Self::receive_raw_from(ws_stream).await?;
let message = rmp_serde::from_slice(&bytes)?;
Ok(message)
}
#[tracing::instrument(skip(ws_stream))]
async fn receive_raw_from(ws_stream: &mut SplitStream<S>) -> Result<Vec<u8>> {
loop {
let message = ws_stream.next().await.ok_or_else(|| Error::NoNextMessage)??;
match message {
Message::Binary(bytes) => break Ok(bytes),
Message::Ping(_) => {}, Message::Close(_) => break Err(Error::ConnectionClosed),
_ => warn!("Got non-binary message: {:?}", message),
}
}
}
pub async fn put_model(&mut self, frame: Frame) -> Result<ServerMessage<()>> {
let username = self.authentication.username.clone();
self.put(&["user", username.as_str(), "model"], Model::Frame(frame)).await
}
pub async fn stream_model(&mut self) -> Result<impl Stream<Item = Result<ServerMessage<Model>>>> {
let username = self.authentication.username.clone();
self.stream(&["user", username.as_str(), "model"], ()).await
}
pub async fn get_laser_metrics(&mut self) -> Result<ServerMessage<LaserMetrics>> {
self.get(&["metrics", "laser"]).await
}
pub async fn post<P>(&mut self, path: &[&str], payload: P) -> Result<ServerMessage<()>>
where
P: Serialize {
self.perform(&Verb::Post, path, payload).await
}
pub async fn put<P>(&mut self, path: &[&str], payload: P) -> Result<ServerMessage<()>>
where
P: Serialize {
self.perform(&Verb::Put, path, payload).await
}
pub async fn create(&mut self, path: &[&str]) -> Result<ServerMessage<()>> {
self.perform(&Verb::Create, path, ()).await
}
pub async fn delete(&mut self, path: &[&str]) -> Result<ServerMessage<()>> {
self.perform(&Verb::Delete, path, ()).await
}
pub async fn mkdir(&mut self, path: &[&str]) -> Result<ServerMessage<()>> {
self.perform(&Verb::Mkdir, path, ()).await
}
pub async fn list(&mut self, path: &[&str]) -> Result<ServerMessage<DirectoryTree>> {
self.perform(&Verb::List, path, ()).await
}
pub async fn get<R>(&mut self, path: &[&str]) -> Result<ServerMessage<R>>
where
R: for<'de> Deserialize<'de> {
self.perform(&Verb::Get, path, ()).await
}
pub async fn link(&mut self, src_path: &[&str], dest_path: &[&str]) -> Result<ServerMessage<()>> {
self.perform(&Verb::Link, dest_path, src_path).await
}
pub async fn unlink(&mut self, src_path: &[&str], dest_path: &[&str]) -> Result<ServerMessage<()>> {
self.perform(&Verb::Unlink, dest_path, src_path).await
}
pub async fn stop(&mut self, path: &[&str]) -> Result<ServerMessage<()>> {
self.perform(&Verb::Stop, path, ()).await
}
#[tracing::instrument(skip(self, payload))]
pub async fn perform<P, R>(&mut self, verb: &Verb, path: &[&str], payload: P) -> Result<ServerMessage<R>>
where
P: Serialize,
R: for<'de> Deserialize<'de> {
assert_ne!(verb, &Verb::Stream, "Lighthouse::perform may only be used for one-off requests, use Lighthouse::stream for streaming.");
let request_id = self.send_request(verb, path, payload).await?;
let response = self.receive_single(request_id).await?.check()?.decode_payload()?;
Ok(response)
}
#[tracing::instrument(skip(self, payload))]
pub async fn stream<P, R>(&mut self, path: &[&str], payload: P) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
where
P: Serialize,
R: for<'de> Deserialize<'de> {
let request_id = self.send_request(&Verb::Stream, path, payload).await?;
let stream = self.receive_streaming(request_id).await?;
Ok(stream)
}
async fn send_request<P>(&mut self, verb: &Verb, path: &[&str], payload: P) -> Result<i32>
where
P: Serialize {
let path = path.into_iter().map(|s| s.to_string()).collect();
let request_id = self.request_id;
debug! { %request_id, "Sending request" };
self.request_id += 1;
self.send_message(&ClientMessage {
request_id,
authentication: self.authentication.clone(),
path,
meta: HashMap::new(),
verb: verb.clone(),
payload
}).await?;
Ok(request_id)
}
async fn send_message<P>(&mut self, message: &ClientMessage<P>) -> Result<()>
where
P: Serialize {
self.send_raw(rmp_serde::to_vec_named(message)?).await
}
#[tracing::instrument(skip(self))]
async fn receive_single<R>(&self, request_id: i32) -> Result<ServerMessage<R>>
where
R: for<'de> Deserialize<'de> {
let mut rx = self.receive(request_id).await?;
rx.next().await.ok_or_else(|| Error::Custom(format!("No response for {}", request_id)))?
}
#[tracing::instrument(skip(self))]
async fn receive_streaming<R>(&self, request_id: i32) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
where
R: for<'de> Deserialize<'de> {
self.receive(request_id).await
}
async fn receive<R>(&self, request_id: i32) -> Result<impl Stream<Item = Result<ServerMessage<R>>>>
where
R: for<'de> Deserialize<'de> {
let rx = {
let capacity = 4;
let (tx, rx) = {
let mut slots = self.slots.lock().await;
if let Some(Slot::EarlyMessages(msgs)) = slots.get_mut(&request_id) {
let (mut tx, rx) = mpsc::channel(capacity.min(msgs.len()));
for msg in msgs.drain(..) {
tx.feed(msg).await.map_err(|e| Error::Custom(format!("Could not feed tx with early message: {}", e)))?;
}
tx.flush().await.map_err(|e| Error::Custom(format!("Could not flush tx with early messages: {}", e)))?;
(tx, rx)
} else {
mpsc::channel(capacity)
}
};
self.slots.lock().await.insert(request_id, Slot::WaitForMessages(tx));
rx
};
Ok(rx.map(|s| Ok(s.decode_payload()?)).guard({
let slots = self.slots.clone();
move || {
tokio::spawn(async move {
let mut slots = slots.lock().await;
slots.remove(&request_id);
});
}
}))
}
async fn send_raw(&mut self, bytes: impl Into<Vec<u8>> + Debug) -> Result<()> {
Ok(self.ws_sink.send(Message::Binary(bytes.into())).await?)
}
pub fn authentication(&self) -> &Authentication {
&self.authentication
}
pub async fn close(&mut self) -> Result<()> {
Ok(self.ws_sink.close().await?)
}
}