use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_broadcast::Receiver as ActiveReceiver;
use futures_core::stream;
use futures_util::stream::FusedStream;
use ordered_stream::{OrderedStream, PollResult};
use static_assertions::assert_impl_all;
use tracing::warn;
use crate::{
connection::ConnectionInner,
message::{Message, Sequence},
AsyncDrop, Connection, MatchRule, OwnedMatchRule, Result,
};
/// A stream of [`Message`]s (or errors) received on a [`Connection`].
///
/// A stream is created either unfiltered, from a connection via `From<Connection>`, or
/// filtered, via [`MessageStream::for_match_rule`]. In the filtered case, the rule is
/// unregistered from the connection when the stream is dropped.
///
/// NOTE(review): the derived `Clone` copies any registered match rule into the clone,
/// and every copy queues its own removal of that rule when dropped. Confirm that the
/// connection's removal bookkeeping accounts for this.
#[must_use = "streams do nothing unless polled"]
#[derive(Clone, Debug)]
pub struct MessageStream {
    inner: Inner,
}

// Compile-time guarantee that the stream can be moved across threads and polled
// without pinning.
assert_impl_all!(MessageStream: Send, Sync, Unpin);
impl MessageStream {
    /// Create a stream that yields only messages matching `rule`.
    ///
    /// The rule is registered on `conn` through `add_match`, with `max_queued` forwarded
    /// unchanged. The resulting stream keeps the rule so that it is unregistered from
    /// the connection when the stream is dropped.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// - `rule` cannot be converted into an [`OwnedMatchRule`];
    /// - registering the rule on the connection fails.
    pub async fn for_match_rule<R>(
        rule: R,
        conn: &Connection,
        max_queued: Option<usize>,
    ) -> Result<Self>
    where
        R: TryInto<OwnedMatchRule>,
        R::Error: Into<crate::Error>,
    {
        let owned_rule: OwnedMatchRule = rule.try_into().map_err(Into::into)?;
        let receiver = conn.add_match(owned_rule.clone(), max_queued).await?;
        let stream = Self::for_subscription_channel(receiver, Some(owned_rule), conn);
        Ok(stream)
    }

    /// The match rule this stream filters on.
    ///
    /// Returns `None` for unfiltered streams, such as those created from a
    /// [`Connection`] directly.
    pub fn match_rule(&self) -> Option<MatchRule<'_>> {
        let rule = self.inner.match_rule.as_deref()?;
        Some(rule.clone())
    }

    /// Current capacity of the underlying message channel.
    pub fn max_queued(&self) -> usize {
        let receiver = &self.inner.msg_receiver;
        receiver.capacity()
    }

    /// Grow the capacity of the underlying message channel to `max_queued`.
    ///
    /// The channel is never shrunk. If `max_queued` is not strictly greater than the
    /// current [`MessageStream::max_queued`], this call does nothing.
    pub fn set_max_queued(&mut self, max_queued: usize) {
        if max_queued > self.max_queued() {
            self.inner.msg_receiver.set_capacity(max_queued);
        }
    }

    /// Wrap an already-subscribed channel receiver into a stream bound to `conn`.
    ///
    /// If `rule` is `Some`, it is stored so that dropping the stream unregisters it.
    pub(crate) fn for_subscription_channel(
        msg_receiver: ActiveReceiver<Result<Message>>,
        rule: Option<OwnedMatchRule>,
        conn: &Connection,
    ) -> Self {
        Self {
            inner: Inner {
                conn_inner: conn.inner.clone(),
                msg_receiver,
                match_rule: rule,
            },
        }
    }
}
impl stream::Stream for MessageStream {
    type Item = Result<Message>;

    /// Poll the underlying channel receiver for the next message or error.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `MessageStream` is `Unpin`, so the pinned reference can be unwrapped safely.
        let receiver = &mut self.get_mut().inner.msg_receiver;
        stream::Stream::poll_next(Pin::new(receiver), cx)
    }
}
impl OrderedStream for MessageStream {
    type Data = Result<Message>;
    type Ordering = Sequence;

    /// Poll for the next item, tagged with an ordering key.
    ///
    /// - A successful message is keyed by its `recv_position()`.
    /// - An error is keyed by `Sequence::LAST`, so it sorts after every message.
    /// - End of stream is reported as `Terminated`.
    /// - When nothing is ready and the caller passed a `before` bound, this reports
    ///   `NoneBefore` rather than `Pending`.
    fn poll_next_before(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        before: Option<&Self::Ordering>,
    ) -> Poll<PollResult<Self::Ordering, Self::Data>> {
        let result = match stream::Stream::poll_next(self, cx) {
            Poll::Ready(None) => PollResult::Terminated,
            Poll::Ready(Some(Err(e))) => PollResult::Item {
                ordering: Sequence::LAST,
                data: Err(e),
            },
            Poll::Ready(Some(Ok(msg))) => {
                let ordering = msg.recv_position();
                PollResult::Item {
                    ordering,
                    data: Ok(msg),
                }
            }
            // NOTE(review): answering `NoneBefore` assumes `before` came from a message
            // received on this same connection. Under that assumption, anything ordered
            // earlier would already be queued and would have produced `Ready` above.
            // Confirm that callers uphold this.
            Poll::Pending if before.is_some() => PollResult::NoneBefore,
            Poll::Pending => return Poll::Pending,
        };
        Poll::Ready(result)
    }
}
impl FusedStream for MessageStream {
fn is_terminated(&self) -> bool {
self.inner.msg_receiver.is_terminated()
}
}
impl From<Connection> for MessageStream {
fn from(conn: Connection) -> Self {
let conn_inner = conn.inner;
let msg_receiver = conn_inner.msg_receiver.activate_cloned();
Self {
inner: Inner {
conn_inner,
msg_receiver,
match_rule: None,
},
}
}
}
impl From<&Connection> for MessageStream {
    /// Same as `From<Connection>`, but clones the connection handle first.
    fn from(conn: &Connection) -> Self {
        let owned = conn.clone();
        owned.into()
    }
}
impl From<MessageStream> for Connection {
fn from(stream: MessageStream) -> Connection {
Connection::from(&stream)
}
}
impl From<&MessageStream> for Connection {
fn from(stream: &MessageStream) -> Connection {
Connection {
inner: stream.inner.conn_inner.clone(),
}
}
}
/// State behind a [`MessageStream`].
///
/// Field declaration order is significant: it fixes the derived `Debug` output and the
/// order in which fields are dropped.
///
/// NOTE(review): the derived `Clone` duplicates `match_rule`, so each clone queues its
/// own removal of the same rule in `Drop`. Confirm that the connection's removal
/// bookkeeping expects that.
#[derive(Clone, Debug)]
struct Inner {
    // Shared connection state; also used to rebuild a `Connection` handle when needed.
    conn_inner: Arc<ConnectionInner>,
    // Channel the stream's items (messages or errors) are read from.
    msg_receiver: ActiveReceiver<Result<Message>>,
    // Rule registered for this stream, if any; taken and unregistered on drop.
    match_rule: Option<OwnedMatchRule>,
}
impl Drop for Inner {
    /// Queue removal of this stream's match rule, if it has one.
    ///
    /// `Drop` cannot await, so the non-blocking `queue_remove_match` is used here. The
    /// `AsyncDrop` path awaits the removal instead, and takes the rule beforehand so this
    /// method then finds nothing to do.
    fn drop(&mut self) {
        // Build a connection handle only when there is a rule to unregister. Unfiltered
        // streams, such as those created from a `Connection`, carry no rule and so avoid
        // a pointless `Arc` clone here.
        if let Some(rule) = self.match_rule.take() {
            let conn = Connection {
                inner: self.conn_inner.clone(),
            };
            conn.queue_remove_match(rule);
        }
    }
}
#[async_trait::async_trait]
impl AsyncDrop for MessageStream {
    /// Unregister the stream's match rule (if any) and wait for that to finish, then
    /// drop the stream.
    ///
    /// A failed removal is logged as a warning, not propagated. The rule is taken out
    /// before removal, so the synchronous `Drop` that runs when `self` goes out of scope
    /// does not queue a second removal.
    async fn async_drop(mut self) {
        // Only build a connection handle when there is a rule to remove.
        if let Some(rule) = self.inner.match_rule.take() {
            let conn = Connection {
                inner: self.inner.conn_inner.clone(),
            };
            if let Err(e) = conn.remove_match(rule).await {
                warn!("Failed to remove match rule: {}", e);
            }
        }
    }
}