use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::stream::Stream;
use log::{debug, warn};
use tokio::sync::mpsc;
use super::common::{filter_notice, process_decode_result, DecoderContext, ProcessingResult, RoutedItem, SubscriptionItem};
use super::StreamDecoder;
use crate::messages::ResponseMessage;
use crate::transport::{AsyncInternalSubscription, AsyncMessageBus};
use crate::Error;
type CancelFn = Box<dyn Fn(i32, Option<i32>, Option<&DecoderContext>) -> Result<Vec<u8>, Error> + Send + Sync>;
type DecoderFn<T> = Arc<dyn Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync>;
/// An asynchronous stream of items for a single request.
///
/// Items come either from decoding raw messages of an internal subscription or from a
/// channel of already-decoded values. When a message bus and cancel function are
/// configured, calling `cancel` or dropping the subscription sends a cancel message.
#[must_use = "Subscription must be polled (via .next().await or .filter_data()) to receive data; dropping it cancels the request"]
pub struct Subscription<T> {
    /// Where items come from: a decoder over raw messages, or a pre-decoded channel.
    inner: SubscriptionInner<T>,
    /// Context handed to the decoder and to the cancel function.
    context: DecoderContext,
    /// Id used to build the cancel message; preferred over `order_id` when both are set.
    request_id: Option<i32>,
    /// Id used to build the cancel message when `request_id` is `None`.
    order_id: Option<i32>,
    /// Bus the cancel message is sent on; `None` for pre-decoded subscriptions.
    message_bus: Option<Arc<dyn AsyncMessageBus>>,
    /// Builds the cancel message; when `None`, no cancel message is ever sent.
    cancel_fn: Option<Arc<CancelFn>>,
    /// Shared by all clones; set once cancellation has been triggered so later
    /// `cancel`/drop calls skip sending.
    cancelled: Arc<AtomicBool>,
    /// Per-handle flag set once this handle has yielded end-of-stream or a terminal error.
    stream_ended: AtomicBool,
}
/// Source of items for a `Subscription`.
enum SubscriptionInner<T> {
    /// Already-decoded results delivered over an unbounded channel.
    PreDecoded { receiver: mpsc::UnboundedReceiver<Result<T, Error>> },
    /// Raw messages from an internal subscription, decoded on each poll by `decoder`.
    WithDecoder {
        subscription: AsyncInternalSubscription,
        decoder: DecoderFn<T>,
    },
}
impl<T> Clone for SubscriptionInner<T> {
fn clone(&self) -> Self {
match self {
SubscriptionInner::WithDecoder { subscription, decoder } => SubscriptionInner::WithDecoder {
subscription: subscription.clone(),
decoder: decoder.clone(),
},
SubscriptionInner::PreDecoded { .. } => {
panic!("Cannot clone pre-decoded subscriptions");
}
}
}
}
impl<T> Clone for Subscription<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
request_id: self.request_id,
order_id: self.order_id,
context: self.context.clone(),
cancelled: self.cancelled.clone(),
stream_ended: AtomicBool::new(false),
message_bus: self.message_bus.clone(),
cancel_fn: self.cancel_fn.clone(),
}
}
}
impl<T> Subscription<T> {
pub(crate) fn with_decoder<D>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
decoder: D,
request_id: Option<i32>,
order_id: Option<i32>,
context: DecoderContext,
) -> Self
where
D: Fn(&DecoderContext, &mut ResponseMessage) -> Result<T, Error> + Send + Sync + 'static,
{
Self {
inner: SubscriptionInner::WithDecoder {
subscription: internal,
decoder: Arc::new(decoder),
},
request_id,
order_id,
context,
cancelled: Arc::new(AtomicBool::new(false)),
stream_ended: AtomicBool::new(false),
message_bus: Some(message_bus),
cancel_fn: None,
}
}
pub(crate) fn new_from_internal<D>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
request_id: Option<i32>,
order_id: Option<i32>,
context: DecoderContext,
) -> Self
where
D: StreamDecoder<T> + 'static,
T: 'static,
{
let mut sub = Self::with_decoder(internal, message_bus, D::decode, request_id, order_id, context);
sub.cancel_fn = Some(Arc::new(Box::new(D::cancel_message)));
sub
}
pub(crate) fn new_from_internal_simple<D>(
internal: AsyncInternalSubscription,
message_bus: Arc<dyn AsyncMessageBus>,
context: DecoderContext,
) -> Self
where
D: StreamDecoder<T> + 'static,
T: 'static,
{
Self::new_from_internal::<D>(internal, message_bus, None, None, context)
}
pub fn new(receiver: mpsc::UnboundedReceiver<Result<T, Error>>) -> Self {
Self {
inner: SubscriptionInner::PreDecoded { receiver },
request_id: None,
order_id: None,
context: DecoderContext::default(),
cancelled: Arc::new(AtomicBool::new(false)),
stream_ended: AtomicBool::new(false),
message_bus: None,
cancel_fn: None,
}
}
pub fn request_id(&self) -> Option<i32> {
self.request_id
}
}
impl<T: Send + 'static> Stream for Subscription<T> {
type Item = Result<SubscriptionItem<T>, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if this.stream_ended.load(Ordering::Relaxed) {
return Poll::Ready(None);
}
let Subscription {
inner,
context,
stream_ended,
..
} = this;
loop {
match inner {
SubscriptionInner::WithDecoder { subscription, decoder } => {
let routed = match Pin::new(&mut subscription.stream).poll_next(cx) {
Poll::Ready(Some(Ok(item))) => item,
Poll::Ready(Some(Err(_lagged))) => continue, Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => return Poll::Pending,
};
match routed {
RoutedItem::Response(mut message) => {
let result = decoder(context, &mut message);
match process_decode_result(result) {
ProcessingResult::Success(val) => return Poll::Ready(Some(Ok(SubscriptionItem::Data(val)))),
ProcessingResult::EndOfStream => {
stream_ended.store(true, Ordering::Relaxed);
return Poll::Ready(None);
}
ProcessingResult::Skip => {
log::trace!("skipping unexpected message on shared channel");
continue;
}
ProcessingResult::Error(err) => {
stream_ended.store(true, Ordering::Relaxed);
return Poll::Ready(Some(Err(err)));
}
}
}
RoutedItem::Notice(notice) => return Poll::Ready(Some(Ok(SubscriptionItem::Notice(notice)))),
RoutedItem::Error(Error::EndOfStream) => {
stream_ended.store(true, Ordering::Relaxed);
return Poll::Ready(None);
}
RoutedItem::Error(e) => {
stream_ended.store(true, Ordering::Relaxed);
return Poll::Ready(Some(Err(e)));
}
}
}
SubscriptionInner::PreDecoded { receiver } => {
return match receiver.poll_recv(cx) {
Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(SubscriptionItem::Data(t)))),
Poll::Ready(Some(Err(e))) => {
stream_ended.store(true, Ordering::Relaxed);
Poll::Ready(Some(Err(e)))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
};
}
}
}
}
}
impl<T> Subscription<T> {
pub async fn cancel(&self) {
if self.cancelled.load(Ordering::Relaxed) {
return;
}
self.cancelled.store(true, Ordering::Relaxed);
if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
let id = self.request_id.or(self.order_id);
if let Ok(message) = cancel_fn(self.context.server_version, id, Some(&self.context)) {
if let Err(e) = message_bus.send_message(message).await {
warn!("error sending cancel message: {e}")
}
}
}
}
}
impl<T> Drop for Subscription<T> {
fn drop(&mut self) {
debug!("dropping async subscription");
if self.cancelled.load(Ordering::Relaxed) {
return;
}
self.cancelled.store(true, Ordering::Relaxed);
if let (Some(message_bus), Some(cancel_fn)) = (&self.message_bus, &self.cancel_fn) {
let message_bus = message_bus.clone();
let id = self.request_id.or(self.order_id);
let context = self.context.clone();
if let Ok(message) = cancel_fn(context.server_version, id, Some(&context)) {
tokio::spawn(async move {
if let Err(e) = message_bus.send_message(message).await {
warn!("error sending cancel message in drop: {e}");
}
});
}
}
}
}
/// Stream adapter returned by `SubscriptionItemStreamExt::filter_data`.
///
/// Wraps a stream of `Result<SubscriptionItem<T>, Error>` and yields `Result<T, Error>`.
#[must_use = "streams are lazy and do nothing unless polled"]
#[derive(Debug)]
pub struct FilterDataStream<S> {
    /// The wrapped stream being filtered.
    inner: S,
}
impl<S, T> Stream for FilterDataStream<S>
where
    S: Stream<Item = Result<SubscriptionItem<T>, Error>> + Unpin,
{
    type Item = Result<T, Error>;

    /// Polls the wrapped stream and passes each item through `filter_notice`.
    ///
    /// Items for which `filter_notice` returns `None` are skipped and polling continues.
    /// End-of-stream and `Pending` from the wrapped stream are forwarded unchanged.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let source = &mut self.inner;
        loop {
            let item = match Pin::new(&mut *source).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(item)) => item,
            };
            if let Some(kept) = filter_notice(item) {
                return Poll::Ready(Some(kept));
            }
        }
    }
}
/// Extension methods for streams that yield `Result<SubscriptionItem<T>, Error>`.
pub trait SubscriptionItemStreamExt: Stream + Sized {
    /// Wraps the stream in a `FilterDataStream`. That adapter passes each item through
    /// `filter_notice` and yields only the items it keeps (presumably data and errors,
    /// with notices dropped; see `filter_notice`).
    fn filter_data<T>(self) -> FilterDataStream<Self>
    where
        Self: Stream<Item = Result<SubscriptionItem<T>, Error>>,
    {
        let inner = self;
        FilterDataStream { inner }
    }
}

/// Every sized `Stream` gets `SubscriptionItemStreamExt` automatically.
impl<S> SubscriptionItemStreamExt for S where S: Stream + Sized {}
// Unit tests for this module live in async_tests.rs. They are compiled only for test
// builds with the "async" feature enabled.
#[cfg(all(test, feature = "async"))]
#[path = "async_tests.rs"]
mod tests;