simploxide-ws-core 0.1.0

SimpleX-Chat raw websocket client
Documentation
//! Event dispatcher task.

use std::{sync::Arc, task::Poll};

use crate::{WsIn, router::ResponseRouter};
use futures::{Stream, StreamExt};
use serde::Deserialize;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;

use super::{Event, RequestId, Result};

/// Sending half of the unbounded event channel; carries either an event or an error.
type EventSender = mpsc::UnboundedSender<Result<Event>>;
/// Receiving half of the unbounded event channel, as exposed by [`EventQueue::into_receiver`].
pub type EventReceiver = mpsc::UnboundedReceiver<Result<Event>>;

pub fn init(ws_in: WsIn, router: ResponseRouter, token: CancellationToken) -> EventQueue {
    let (events_tx, receiver) = mpsc::unbounded_channel::<Result<Event>>();
    tokio::spawn(event_dispatcher_task(ws_in, events_tx, router, token));

    EventQueue { receiver }
}

/// Receiving end of the event stream returned by [`init`].
///
/// The queue is backed by an unbounded channel, so events keep accumulating while you are not
/// actively processing them. It is therefore recommended to drop the queue as soon as it is no
/// longer needed.
pub struct EventQueue {
    receiver: EventReceiver,
}

impl EventQueue {
    /// Waits for the next item of the queue.
    ///
    /// Yields either a SimpleX event or a [`tokio_tungstenite::tungstenite::Error`] if the
    /// connection was dropped due to a web socket failure. SimpleX events can themselves
    /// represent SimpleX errors, but recognizing and handling them is a task of the upstream
    /// code.
    ///
    /// Returns `None` once the underlying channel is closed and drained.
    pub async fn next_event(&mut self) -> Option<Result<Event>> {
        let Self { receiver } = self;
        receiver.recv().await
    }

    /// Consumes the queue and hands out the underlying tokio unbounded receiver, which enables
    /// more complicated use cases.
    pub fn into_receiver(self) -> EventReceiver {
        let Self { receiver } = self;
        receiver
    }
}

/// Reads frames from `ws_in` until the connection ends or `token` is cancelled.
///
/// Every received frame is handed to [`process_raw_event`], which either delivers it through
/// `router` (responses carrying a `corrId`) or pushes it into `event_queue` (plain events).
///
/// The task finishes in one of two ways:
///
/// * **Connection failure or end of stream** — the error is sent to `event_queue` and `router`
///   is shut down with the same error. A stream that ends without an error (which
///   tokio-tungstenite does after a clean close handshake) is reported as
///   `tungstenite::Error::ConnectionClosed` instead of panicking.
/// * **Cancellation** — frames that are already buffered are drained into `event_queue` only
///   (responses are dropped because the router is no longer used), then the task exits.
async fn event_dispatcher_task(
    mut ws_in: WsIn,
    mut event_queue: EventSender,
    router: ResponseRouter,
    token: CancellationToken,
) {
    loop {
        tokio::select! {
            // Biased so the cancellation token is always checked before the next frame. Without
            // this, a frame arriving in the same poll cycle as the cancellation could be
            // processed via `process_raw_event` after the routing task has already exited and
            // dropped its response receiver (a race between tasks running on different
            // threads), causing `ResponseRouter::deliver` to panic.
            biased;

            _ = token.cancelled() => {
                // Yielding here gives tokio a last chance to move any OS-buffered frames into
                // the stream. This minimizes the chance of silently dropping events.
                tokio::task::yield_now().await;

                // `Closed` ends the stream as soon as the socket has nothing ready, so this loop
                // only drains what is already available and never waits for new frames.
                let mut ws_in = Closed(ws_in);
                while let Some(ev) = ws_in.next().await {
                    match ev {
                        Ok(msg) => {
                            process_raw_event(None, &mut event_queue, msg);
                        }
                        Err(e) => {
                            let _ = event_queue.send(Err(Arc::new(e)));
                            break;
                        }
                    }
                }

                break;
            }

            ev = ws_in.next() => {
                let failure = match ev {
                    Some(Ok(msg)) => {
                        process_raw_event(Some(&router), &mut event_queue, msg);
                        None
                    }
                    Some(Err(e)) => Some(e),
                    // The stream can end without yielding an error first: tokio-tungstenite
                    // turns `ConnectionClosed`/`AlreadyClosed` into the end of the stream.
                    // Report it as a closed connection so that both the event consumer and the
                    // pending requests get notified.
                    None => Some(tokio_tungstenite::tungstenite::Error::ConnectionClosed),
                };

                if let Some(e) = failure {
                    let e = Arc::new(e);
                    // The receiver may already be gone; that is not an error for the dispatcher.
                    let _ = event_queue.send(Err(Arc::clone(&e)));
                    router.shutdown(e);

                    break;
                }
            }
        }
    }

    log::debug!("Dispatcher task finished");
}

/// Parse the top level JSON and either route event to the `event_queue` or deliver a response by
/// `corrId` via the `router`.
///
/// TODO: `Option<&Router>` was added to reuse code in a branch that handles the interruption
/// event. In this case all buffered events can only be sent to the `event_queue`. This could be
/// refactored to look less hacky.
fn process_raw_event(router: Option<&ResponseRouter>, event_queue: &mut EventSender, msg: Message) {
    let event = match msg {
        Message::Text(utf8bytes) => utf8bytes.to_string(),
        unexpected => {
            log::warn!("Ignoring event in unexpecetd format: {unexpected:#?}");
            return;
        }
    };

    let header: EventHeader = match serde_json::from_str(&event) {
        Ok(header) => header,
        Err(e) => {
            log::error!("Got invalid JSON form the server\n{event:?}\n{e}");
            return;
        }
    };

    if let Some(corr_id) = header.corr_id {
        let id: RequestId = match corr_id.parse() {
            Ok(id) => id,
            Err(e) => {
                log::error!("Failed to parse corr_id: {corr_id}\n{e}");
                return;
            }
        };

        match router {
            Some(router) => router.deliver(id, event),
            None => {
                log::warn!("Dropping response because router task already finished\n{event}");
            }
        }
    } else {
        let _ = event_queue.send(Ok(event));
    }
}

/// Minimal view of an incoming JSON message used only to detect a correlation ID.
///
/// `corr_id` is read from the `corrId` key and borrowed straight from the input string; it is
/// `None` when the key is absent.
#[derive(Deserialize)]
struct EventHeader<'a> {
    #[serde(rename = "corrId", borrow)]
    corr_id: Option<&'a str>,
}

/// A stream adapter that yields only the items the wrapped stream has ready right now.
///
/// It forwards every ready item and reports the end of the stream (`None`) as soon as the
/// wrapped stream would otherwise make the caller wait, which allows draining buffered items
/// without blocking on new ones.
struct Closed<S>(S);

impl<S> Stream for Closed<S>
where
    S: Stream + Unpin,
{
    type Item = S::Item;

    /// Polls the wrapped stream once; a `Pending` result is translated into the end of the
    /// stream, so this implementation never returns `Poll::Pending` itself.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let inner = &mut self.0;

        Poll::Ready(match inner.poll_next_unpin(cx) {
            Poll::Ready(item) => item,
            Poll::Pending => None,
        })
    }
}