roblib-client 0.1.0

A client library for a dank engine
Documentation
use super::{SubscribableAsync, TransportAsync};
use anyhow::Result;
use async_trait::async_trait;
use futures::{executor::block_on, SinkExt, TryStreamExt};
use roblib::{
    cmd::{self, has_return},
    event::{ConcreteType, Event},
    text_format,
};
use serde::Deserialize;
use std::{collections::HashMap, io::Cursor, sync::Arc};
use tokio::{
    net::TcpStream,
    sync::{
        broadcast,
        mpsc::{self, unbounded_channel, UnboundedReceiver, UnboundedSender},
        Mutex,
    },
    task::JoinHandle,
};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};

type WsConn = WebSocketStream<MaybeTlsStream<TcpStream>>;
type D =
    bincode::Deserializer<bincode::de::read::IoReader<Cursor<Vec<u8>>>, bincode::DefaultOptions>;
type Handler = mpsc::Sender<D>;

#[derive(Default)]
struct WsInner {
    handlers: Mutex<HashMap<u32, Handler>>,
    events: Mutex<HashMap<roblib::event::ConcreteType, u32>>,
}
pub struct Ws {
    inner: Arc<WsInner>,
    handle: Option<JoinHandle<Result<()>>>,
    id: Mutex<u32>,
    sender: UnboundedSender<Message>,
}

enum Action {
    Recv(Message),
    Send(Message),
}
impl Ws {
    pub async fn connect(addr: &str) -> Result<Self> {
        let url = format!("ws://{addr}/ws");
        let (ws, _) = connect_async(url).await?;
        let inner = Arc::new(WsInner::default());

        let (tx, rx) = unbounded_channel();
        let handle = tokio::spawn(Self::worker(ws, rx, inner.clone()));
        Ok(Self {
            inner,
            handle: Some(handle),
            id: super::ID_START.into(),
            sender: tx,
        })
    }

    async fn worker(
        mut ws: WsConn,
        mut rx: UnboundedReceiver<Message>,
        inner: Arc<WsInner>,
    ) -> Result<()> {
        let bin = bincode::options();
        loop {
            let action = tokio::select! {
                Ok(Some(msg)) = ws.try_next() => Action::Recv(msg),
                Some(msg) = rx.recv() => Action::Send(msg)
            };
            match action {
                Action::Recv(msg) => match msg {
                    Message::Text(s) => {
                        let _de = text_format::de::new_deserializer(&s);
                        todo!()
                    }
                    Message::Binary(b) => {
                        let mut c = Cursor::new(b);
                        let id: u32 = bincode::Options::deserialize_from(bin, &mut c)?;

                        let mut handlers = inner.handlers.lock().await;
                        let Some(handler) = handlers.get_mut(&id) else {
                            return Err(anyhow::Error::msg("received response for unknown id"));
                        };

                        handler
                            .send(bincode::Deserializer::with_reader(c, bin))
                            .await?;
                    }
                    Message::Ping(p) => ws.send(Message::Pong(p)).await?,
                    Message::Close(close) => {
                        if let Some(close) = close {
                            log::debug!("ws close, reason: {}", close.reason);
                        }
                        continue;
                    }
                    _ => continue,
                },
                Action::Send(msg) => {
                    ws.send(msg).await?;
                }
            }
        }
    }

    async fn incr_id(&self) -> u32 {
        let mut id_handle = self.id.lock().await;
        let id = *id_handle;
        *id_handle = id + 1;
        id
    }

    async fn send<C: cmd::Command>(&self, id: u32, cmd: C) -> Result<C::Return> {
        let cmd: cmd::Concrete = cmd.into();
        let data = bincode::Options::serialize(bincode::options(), &(id, cmd))?;
        self.sender.send(Message::Binary(data))?;

        if has_return::<C>() {
            let (tx, mut rx) = mpsc::channel(1);
            self.inner.handlers.lock().await.insert(id, tx);
            let mut de = rx.recv().await.unwrap();
            let re = C::Return::deserialize(&mut de)?;
            Ok(re)
        } else {
            unsafe { std::mem::zeroed() }
        }
    }
}

#[async_trait]
impl TransportAsync for Ws {
    async fn cmd<C: cmd::Command>(&self, cmd: C) -> Result<C::Return> {
        let id = self.incr_id().await;
        self.send(id, cmd).await
    }
}

#[async_trait]
impl SubscribableAsync for Ws {
    async fn subscribe<E: Event>(&self, ev: E) -> Result<broadcast::Receiver<E::Item>> {
        let id = self.incr_id().await;
        let ev: ConcreteType = ev.into();

        let (tx, mut worker_rx) = mpsc::channel(1);
        self.inner.handlers.lock().await.insert(id, tx);
        self.inner.events.lock().await.insert(ev.clone(), id);
        self.send(id, cmd::Subscribe(ev)).await?;

        let (client_tx, client_rx) = broadcast::channel(128);
        tokio::spawn(async move {
            while let Some(mut de) = worker_rx.recv().await {
                let item = E::Item::deserialize(&mut de)?;
                if client_tx.send(item).is_err() {
                    log::error!("no receiver for active subscription");
                };
            }
            anyhow::Ok(())
        });

        Ok(client_rx)
    }

    async fn unsubscribe<E>(&self, ev: E) -> Result<()>
    where
        E: Event,
    {
        let ev = ev.into();

        let mut lock = self.inner.events.lock().await;
        match lock.entry(ev.clone()) {
            std::collections::hash_map::Entry::Occupied(v) => {
                let id = v.remove();
                self.send(id, cmd::Unsubscribe(ev)).await?;
                self.inner.handlers.lock().await.remove(&id);
            }
            std::collections::hash_map::Entry::Vacant(_) => anyhow::bail!("Subscription not found"),
        };
        Ok(())
    }
}

impl Drop for Ws {
    fn drop(&mut self) {
        let _ = self.sender.send(Message::Close(None));
        let _ = block_on(self.handle.take().unwrap());
    }
}