use std::{sync::Arc, task::Poll};
use crate::{WsIn, router::ResponseRouter};
use futures::{Stream, StreamExt};
use serde::Deserialize;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use tokio_util::sync::CancellationToken;
use super::{Event, RequestId, Result};
/// Consumer half of the event channel; this is what [`EventQueue::into_receiver`] returns.
pub type EventReceiver = mpsc::UnboundedReceiver<Result<Event>>;
/// Producer half of the event channel; owned by the dispatcher task spawned in `init`.
type EventSender = mpsc::UnboundedSender<Result<Event>>;
pub fn init(ws_in: WsIn, router: ResponseRouter, token: CancellationToken) -> EventQueue {
let (events_tx, receiver) = mpsc::unbounded_channel::<Result<Event>>();
tokio::spawn(event_dispatcher_task(ws_in, events_tx, router, token));
EventQueue { receiver }
}
/// Receiving end of the dispatcher's event channel.
///
/// Yields each server message that had no correlation id, plus any transport
/// error the dispatcher observed.
pub struct EventQueue {
    receiver: EventReceiver,
}

impl EventQueue {
    /// Waits for the next event.
    ///
    /// Returns `None` once the dispatcher task has finished (dropping its
    /// sender) and every buffered event has been consumed.
    pub async fn next_event(&mut self) -> Option<Result<Event>> {
        let Self { receiver } = self;
        receiver.recv().await
    }

    /// Consumes the wrapper and hands back the underlying channel receiver.
    pub fn into_receiver(self) -> EventReceiver {
        let Self { receiver } = self;
        receiver
    }
}
async fn event_dispatcher_task(
mut ws_in: WsIn,
mut event_queue: EventSender,
router: ResponseRouter,
token: CancellationToken,
) {
loop {
tokio::select! {
biased;
_ = token.cancelled() => {
tokio::task::yield_now().await;
let mut ws_in = Closed(ws_in);
while let Some(ev) = ws_in.next().await {
match ev {
Ok(msg) => {
process_raw_event(None, &mut event_queue, msg);
}
Err(e) => {
let _ = event_queue.send(Err(Arc::new(e)));
break;
}
}
}
break;
}
ev = ws_in.next() => {
match ev {
Some(Ok(msg)) => {
process_raw_event(Some(&router), &mut event_queue, msg);
}
Some(Err(e)) => {
let e = Arc::new(e);
let _ = event_queue.send(Err(Arc::clone(&e)));
router.shutdown(e);
break;
}
None => unreachable!("Must receive an error before connection drops")
}
}
}
}
log::debug!("Dispatcher task finished");
}
fn process_raw_event(router: Option<&ResponseRouter>, event_queue: &mut EventSender, msg: Message) {
let event = match msg {
Message::Text(utf8bytes) => utf8bytes.to_string(),
unexpected => {
log::warn!("Ignoring event in unexpecetd format: {unexpected:#?}");
return;
}
};
let header: EventHeader = match serde_json::from_str(&event) {
Ok(header) => header,
Err(e) => {
log::error!("Got invalid JSON form the server\n{event:?}\n{e}");
return;
}
};
if let Some(corr_id) = header.corr_id {
let id: RequestId = match corr_id.parse() {
Ok(id) => id,
Err(e) => {
log::error!("Failed to parse corr_id: {corr_id}\n{e}");
return;
}
};
match router {
Some(router) => router.deliver(id, event),
None => {
log::warn!("Dropping response because router task already finished\n{event}");
}
}
} else {
let _ = event_queue.send(Ok(event));
}
}
/// Minimal view of an incoming JSON message; only the correlation id is decoded.
///
/// A missing or `null` `corrId` becomes `None`, which marks the message as an
/// unsolicited event rather than a response.
///
/// NOTE(review): `corr_id` is borrowed straight from the source text. serde_json
/// cannot borrow a string containing escape sequences, so such a `corrId` makes
/// deserialization fail. `Cow<'a, str>` would accept it if that ever matters.
#[derive(Deserialize)]
struct EventHeader<'a> {
    #[serde(rename = "corrId", borrow)]
    corr_id: Option<&'a str>,
}
/// Non-blocking stream adapter: yields whatever the inner stream has ready and
/// reports end-of-stream as soon as the inner stream would block.
///
/// Used to drain already-available items without waiting for new ones. It keeps
/// no state, so polling again after a `None` re-polls the inner stream.
struct Closed<S>(S);

impl<S> Stream for Closed<S>
where
    S: Stream + Unpin,
{
    type Item = S::Item;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `S: Unpin` makes `Closed<S>: Unpin`, so the pin can be unwrapped safely.
        let inner = &mut self.get_mut().0;
        match inner.poll_next_unpin(cx) {
            // "Would block" is reinterpreted as "no more items".
            Poll::Pending => Poll::Ready(None),
            ready => ready,
        }
    }
}