use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::Stream;
use tokio::sync::broadcast::error::TryRecvError;
use tokio::sync::broadcast::{Receiver, error::RecvError};
use super::BidiEvent;
use super::command::RawEvent;
use super::transport::ws::BidiTransport;
#[derive(Debug)]
pub struct EventStream<T> {
transport: BidiTransport,
rx: Receiver<RawEvent>,
method: &'static str,
_marker: std::marker::PhantomData<fn() -> T>,
}
impl<T> EventStream<T> {
pub(crate) fn new(
transport: BidiTransport,
rx: Receiver<RawEvent>,
method: &'static str,
) -> Self {
Self {
transport,
rx,
method,
_marker: std::marker::PhantomData,
}
}
fn matches(&self, raw: &RawEvent) -> bool {
raw.method == self.method
}
}
impl<T> Drop for EventStream<T> {
fn drop(&mut self) {
if let Ok(handle) = tokio::runtime::Handle::try_current() {
let transport = self.transport.clone();
let method = self.method;
handle.spawn(async move {
transport.release_subscription(method).await;
});
}
}
}
impl<T: BidiEvent> Stream for EventStream<T> {
type Item = T;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
match this.rx.try_recv() {
Ok(raw) => {
if this.matches(&raw) {
match serde_json::from_value::<T>(raw.params.clone()) {
Ok(parsed) => return Poll::Ready(Some(parsed)),
Err(e) => warn_parse_failure::<T>(this.method, &raw, &e),
}
}
}
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Lagged(_)) => continue,
Err(TryRecvError::Closed) => return Poll::Ready(None),
}
}
let polled = {
let recv = this.rx.recv();
tokio::pin!(recv);
recv.poll(cx)
};
match polled {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(raw)) => {
if this.matches(&raw) {
match serde_json::from_value::<T>(raw.params.clone()) {
Ok(parsed) => return Poll::Ready(Some(parsed)),
Err(e) => warn_parse_failure::<T>(this.method, &raw, &e),
}
}
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Ready(Err(RecvError::Lagged(_))) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Ready(Err(RecvError::Closed)) => Poll::Ready(None),
}
}
}
/// Logs a warning that an event named `method` could not be deserialised as `T`.
///
/// The warning is emitted on the `thirtyfour::bidi` target and includes the method, the
/// deserialisation error, the Rust type name of `T`, and a preview of the event params
/// limited to 200 bytes.
fn warn_parse_failure<T>(method: &str, raw: &RawEvent, err: &serde_json::Error) {
    const MAX_PREVIEW_BYTES: usize = 200;

    let rendered = raw.params.to_string();
    let preview = if rendered.len() > MAX_PREVIEW_BYTES {
        // Slicing at a fixed byte offset panics when it lands inside a multi-byte
        // character, so back off to the nearest char boundary. Index 0 is always a
        // boundary, which guarantees termination.
        let mut end = MAX_PREVIEW_BYTES;
        while !rendered.is_char_boundary(end) {
            end -= 1;
        }
        &rendered[..end]
    } else {
        rendered.as_str()
    };
    tracing::warn!(
        target: "thirtyfour::bidi",
        method = %method,
        error = %err,
        wire_type = std::any::type_name::<T>(),
        "BiDi event {method} did not deserialise as the requested typed event; skipping. \
Switch to subscribe_raw if you need access to events with this wire shape. \
Params (truncated): {preview}",
    );
}
#[derive(Debug)]
pub struct RawEventStream {
rx: Receiver<RawEvent>,
}
impl RawEventStream {
pub(crate) fn new(rx: Receiver<RawEvent>) -> Self {
Self {
rx,
}
}
}
impl Stream for RawEventStream {
type Item = RawEvent;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
loop {
match this.rx.try_recv() {
Ok(raw) => return Poll::Ready(Some(raw)),
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Lagged(_)) => continue,
Err(TryRecvError::Closed) => return Poll::Ready(None),
}
}
let polled = {
let recv = this.rx.recv();
tokio::pin!(recv);
recv.poll(cx)
};
match polled {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(raw)) => Poll::Ready(Some(raw)),
Poll::Ready(Err(RecvError::Lagged(_))) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Poll::Ready(Err(RecvError::Closed)) => Poll::Ready(None),
}
}
}