use std::collections::{HashMap, VecDeque};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tokio::time::{sleep, Instant};
use super::subscriber::{ActiveSubscription, SubscriptionRequest};
use super::{Error, Event, Pubsub, Spec};
use crate::task::spawn;
const STREAM_CONNECTION_BACKOFF: Duration = Duration::from_millis(2_000);
const STREAM_CONNECTION_MAX_BACKOFF: Duration = Duration::from_millis(30_000);
const INTERNAL_POLL_SIZE: usize = 1_000;
const POLL_SLEEP: Duration = Duration::from_millis(2_000);
/// A single remote subscription shared by every local subscriber of one topic.
struct UniqueSubscription<S: Spec> {
    /// Identifier handed to the transport for this topic (produced by `Transport::new_name`).
    name: S::SubscriptionId,
    /// How many local subscriptions currently reference this topic.
    total_subscribers: usize,
}
/// Topic -> the single remote subscription shared by all local subscribers of that topic.
type UniqueSubscriptions<S> = RwLock<HashMap<<S as Spec>::Topic, UniqueSubscription<S>>>;
/// Local subscription name -> the topics that subscription listens to.
type ActiveSubscriptions<S> =
RwLock<HashMap<Arc<<S as Spec>::SubscriptionId>, Vec<<S as Spec>::Topic>>>;
/// Latest event seen per topic; replayed to new subscribers of an already-subscribed topic.
type CacheEvent<S> = HashMap<<<S as Spec>::Event as Event>::Topic, <S as Spec>::Event>;
/// Multiplexes local subscriptions onto a remote [`Transport`].
///
/// Local subscribers of the same topic share one remote subscription; a background task keeps
/// the remote side in sync by streaming and/or polling.
// No `Debug` impl; presumably because `T` is not required to be `Debug`.
#[allow(missing_debug_implementations)]
pub struct Consumer<T: Transport + 'static> {
    /// Backend used to open streams, poll, and name remote subscriptions.
    transport: T,
    /// Local pubsub that fans events out to `RemoteActiveConsumer`s.
    inner_pubsub: Arc<Pubsub<T::Spec>>,
    /// One shared remote subscription per topic, with its local reference count.
    remote_subscriptions: UniqueSubscriptions<T::Spec>,
    /// Topics requested by each local subscription, keyed by subscription name.
    subscriptions: ActiveSubscriptions<T::Spec>,
    /// Control channel of the current streaming session; `None` while no session runs.
    stream_ctrl: RwLock<Option<mpsc::Sender<StreamCtrl<T::Spec>>>>,
    /// Cleared on drop; the background task stops once it observes `false`.
    still_running: AtomicBool,
    /// When set, the background task favours `Transport::poll` over `Transport::stream`.
    prefer_polling: bool,
    /// Latest event per topic, shared with every `InternalRelay` handed to the transport.
    cached_events: Arc<RwLock<CacheEvent<T::Spec>>>,
}
/// Local handle to a subscription created by `Consumer::subscribe`.
///
/// Cached events captured at subscribe time are yielded before live ones, and dropping the
/// handle unsubscribes it from the consumer.
// No `Debug` impl; presumably because `T` is not required to be `Debug`.
#[allow(missing_debug_implementations)]
pub struct RemoteActiveConsumer<T: Transport + 'static> {
    /// Subscription on the consumer's local pubsub.
    inner: ActiveSubscription<T::Spec>,
    /// Cached events replayed before anything read from `inner`.
    previous_messages: VecDeque<<T::Spec as Spec>::Event>,
    /// Owning consumer, kept alive for as long as this handle exists.
    consumer: Arc<Consumer<T>>,
}
impl<T: Transport + 'static> RemoteActiveConsumer<T> {
    /// Waits for the next event.
    ///
    /// Events replayed from the consumer's cache at subscribe time come first; after that this
    /// defers to the underlying local subscription and returns whatever it returns.
    pub async fn recv(&mut self) -> Option<<T::Spec as Spec>::Event> {
        match self.previous_messages.pop_front() {
            Some(replayed) => Some(replayed),
            None => self.inner.recv().await,
        }
    }

    /// Synchronous counterpart of [`Self::recv`]: replayed events first, then `inner.try_recv()`.
    pub fn try_recv(&mut self) -> Option<<T::Spec as Spec>::Event> {
        self.previous_messages
            .pop_front()
            .or_else(|| self.inner.try_recv())
    }

    /// Name of the local subscription this handle belongs to.
    pub fn name(&self) -> &<T::Spec as Spec>::SubscriptionId {
        self.inner.name()
    }
}
impl<T: Transport + 'static> Drop for RemoteActiveConsumer<T> {
    /// Unsubscribes this handle from its consumer.
    ///
    /// The result is discarded because `Drop` has no way to report it.
    fn drop(&mut self) {
        let subscription_name = self.name().clone();
        let _ = self.consumer.unsubscribe(subscription_name);
    }
}
/// Handle given to [`Transport`] implementations to deliver events into a `Consumer`.
// No `Debug` impl; presumably because `S` is not required to be `Debug`.
#[allow(missing_debug_implementations)]
pub struct InternalRelay<S: Spec + 'static> {
    /// Local pubsub the events are published to.
    inner: Arc<Pubsub<S>>,
    /// Per-topic cache of the latest event, shared with the owning consumer.
    cached_events: Arc<RwLock<CacheEvent<S>>>,
}
impl<S: Spec + 'static> InternalRelay<S> {
    /// Records `event` as the latest event for each of its topics, then publishes it locally.
    ///
    /// The cache write lock is held across `publish`, so concurrent `send` calls are
    /// serialised with respect to both the cache update and the delivery.
    pub fn send<X>(&self, event: X)
    where
        X: Into<S::Event>,
    {
        let event: S::Event = event.into();
        let mut cache = self.cached_events.write();
        cache.extend(
            event
                .get_topics()
                .into_iter()
                .map(|topic| (topic, event.clone())),
        );
        self.inner.publish(event);
        drop(cache);
    }
}
impl<T> Consumer<T>
where
T: Transport + 'static,
{
pub fn new(
transport: T,
prefer_polling: bool,
context: <T::Spec as Spec>::Context,
) -> Arc<Self> {
let this = Arc::new(Self {
transport,
prefer_polling,
inner_pubsub: Arc::new(Pubsub::new(T::Spec::new_instance(context))),
subscriptions: Default::default(),
remote_subscriptions: Default::default(),
stream_ctrl: RwLock::new(None),
cached_events: Default::default(),
still_running: true.into(),
});
spawn(Self::stream(this.clone()));
this
}
async fn stream(instance: Arc<Self>) {
let mut stream_supported = true;
let mut poll_supported = true;
let mut backoff = STREAM_CONNECTION_BACKOFF;
let mut retry_at = None;
loop {
if (!stream_supported && !poll_supported)
|| !instance
.still_running
.load(std::sync::atomic::Ordering::Relaxed)
{
break;
}
if instance.remote_subscriptions.read().is_empty() {
sleep(Duration::from_millis(100)).await;
continue;
}
if stream_supported
&& !instance.prefer_polling
&& retry_at
.map(|retry_at| retry_at < Instant::now())
.unwrap_or(true)
{
let (sender, receiver) = mpsc::channel(INTERNAL_POLL_SIZE);
{
*instance.stream_ctrl.write() = Some(sender);
}
let current_subscriptions = {
instance
.remote_subscriptions
.read()
.iter()
.map(|(key, name)| (name.name.clone(), key.clone()))
.collect::<Vec<_>>()
};
if let Err(err) = instance
.transport
.stream(
receiver,
current_subscriptions,
InternalRelay {
inner: instance.inner_pubsub.clone(),
cached_events: instance.cached_events.clone(),
},
)
.await
{
retry_at = Some(Instant::now() + backoff);
backoff =
(backoff + STREAM_CONNECTION_BACKOFF).min(STREAM_CONNECTION_MAX_BACKOFF);
if matches!(err, Error::NotSupported) {
stream_supported = false;
}
tracing::error!("Long connection failed with error {:?}", err);
} else {
backoff = STREAM_CONNECTION_BACKOFF;
}
let _ = instance.stream_ctrl.write().take();
}
if poll_supported {
let current_subscriptions = {
instance
.remote_subscriptions
.read()
.iter()
.map(|(key, name)| (name.name.clone(), key.clone()))
.collect::<Vec<_>>()
};
if let Err(err) = instance
.transport
.poll(
current_subscriptions,
InternalRelay {
inner: instance.inner_pubsub.clone(),
cached_events: instance.cached_events.clone(),
},
)
.await
{
if matches!(err, Error::NotSupported) {
poll_supported = false;
}
tracing::error!("Polling failed with error {:?}", err);
}
sleep(POLL_SLEEP).await;
}
}
}
fn unsubscribe(
self: &Arc<Self>,
subscription_name: <T::Spec as Spec>::SubscriptionId,
) -> Result<(), Error> {
let topics = self
.subscriptions
.write()
.remove(&subscription_name)
.ok_or(Error::NoSubscription)?;
let mut remote_subscriptions = self.remote_subscriptions.write();
for topic in topics {
let mut remote_subscription =
if let Some(remote_subscription) = remote_subscriptions.remove(&topic) {
remote_subscription
} else {
continue;
};
remote_subscription.total_subscribers =
remote_subscription.total_subscribers.saturating_sub(1);
if remote_subscription.total_subscribers == 0 {
let mut cached_events = self.cached_events.write();
cached_events.remove(&topic);
self.message_to_stream(StreamCtrl::Unsubscribe(remote_subscription.name.clone()))?;
} else {
remote_subscriptions.insert(topic, remote_subscription);
}
}
if remote_subscriptions.is_empty() {
self.message_to_stream(StreamCtrl::Stop)?;
}
Ok(())
}
#[inline(always)]
fn message_to_stream(&self, message: StreamCtrl<T::Spec>) -> Result<(), Error> {
let to_stream = self.stream_ctrl.read();
if let Some(to_stream) = to_stream.as_ref() {
Ok(to_stream.try_send(message)?)
} else {
Ok(())
}
}
pub fn subscribe<I>(self: &Arc<Self>, request: I) -> Result<RemoteActiveConsumer<T>, Error>
where
I: SubscriptionRequest<
Topic = <T::Spec as Spec>::Topic,
SubscriptionId = <T::Spec as Spec>::SubscriptionId,
>,
{
let subscription_name = request.subscription_name();
let topics = request.try_get_topics()?;
let mut remote_subscriptions = self.remote_subscriptions.write();
let mut subscriptions = self.subscriptions.write();
if subscriptions.get(&subscription_name).is_some() {
return Err(Error::NoSubscription);
}
let mut previous_messages = Vec::new();
let cached_events = self.cached_events.read();
for topic in topics.iter() {
if let Some(subscription) = remote_subscriptions.get_mut(topic) {
subscription.total_subscribers += 1;
if let Some(v) = cached_events.get(topic).cloned() {
previous_messages.push(v);
}
} else {
let internal_sub_name = self.transport.new_name();
remote_subscriptions.insert(
topic.clone(),
UniqueSubscription {
total_subscribers: 1,
name: internal_sub_name.clone(),
},
);
self.message_to_stream(StreamCtrl::Subscribe((internal_sub_name, topic.clone())))?;
}
}
subscriptions.insert(subscription_name, topics);
drop(subscriptions);
Ok(RemoteActiveConsumer {
inner: self.inner_pubsub.subscribe(request)?,
previous_messages: previous_messages.into(),
consumer: self.clone(),
})
}
}
impl<T> Drop for Consumer<T>
where
T: Transport + 'static,
{
fn drop(&mut self) {
self.still_running
.store(false, std::sync::atomic::Ordering::Release);
if let Some(to_stream) = self.stream_ctrl.read().as_ref() {
let _ = to_stream.try_send(StreamCtrl::Stop).inspect_err(|err| {
tracing::error!("Failed to send message LongPoll::Stop due to {err:?}")
});
}
}
}
/// A remote subscription id paired with the topic it covers, as handed to a [`Transport`].
pub type SubscribeMessage<S> = (<S as Spec>::SubscriptionId, <S as Spec>::Topic);
/// Control messages a `Consumer` sends to an active `Transport::stream` session.
// No `Debug` impl; presumably because `S` is not required to be `Debug`.
#[allow(missing_debug_implementations)]
pub enum StreamCtrl<S: Spec + 'static> {
    /// A new remote subscription `(remote id, topic)` the session should start serving.
    Subscribe(SubscribeMessage<S>),
    /// The remote subscription with this id is no longer needed.
    Unsubscribe(S::SubscriptionId),
    /// End the session; sent when no remote subscriptions remain or the consumer is dropped.
    Stop,
}
// Written by hand: `#[derive(Clone)]` would add an unwanted `S: Clone` bound.
impl<S: Spec + 'static> Clone for StreamCtrl<S> {
    fn clone(&self) -> Self {
        match self {
            Self::Subscribe((id, topic)) => Self::Subscribe((id.clone(), topic.clone())),
            Self::Unsubscribe(id) => Self::Unsubscribe(id.clone()),
            Self::Stop => Self::Stop,
        }
    }
}
/// Remote backend a [`Consumer`] keeps its subscriptions in sync with.
///
/// The consumer's background task calls `stream` (unless polling is preferred) and, when
/// supported, `poll`. Returning `Err(Error::NotSupported)` from either method makes the
/// consumer stop calling that method.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
pub trait Transport: Send + Sync {
/// The pubsub specification (topics, events, ids) this transport serves.
type Spec: Spec;
/// Returns an identifier for a new remote subscription.
///
/// Called once per topic, the first time a local subscriber asks for it.
/// NOTE(review): presumably must be unique per consumer — confirm.
fn new_name(&self) -> <Self::Spec as Spec>::SubscriptionId;
/// Runs one streaming session, delivering events through `reply_to`.
///
/// `topics` are the remote subscriptions active when the session starts. Later changes
/// arrive on `subscribe_changes`. `StreamCtrl::Stop` is sent when no remote subscriptions
/// remain, so implementations are presumably expected to return then.
///
/// On `Err` the consumer retries after a growing backoff; `Ok` resets the backoff.
async fn stream(
&self,
subscribe_changes: mpsc::Receiver<StreamCtrl<Self::Spec>>,
topics: Vec<SubscribeMessage<Self::Spec>>,
reply_to: InternalRelay<Self::Spec>,
) -> Result<(), Error>;
/// Performs one polling round for `topics`, delivering events through `reply_to`.
///
/// The consumer sleeps `POLL_SLEEP` after each call; errors other than `NotSupported`
/// are only logged.
async fn poll(
&self,
topics: Vec<SubscribeMessage<Self::Spec>>,
reply_to: InternalRelay<Self::Spec>,
) -> Result<(), Error>;
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use tokio::sync::{mpsc, Mutex};
    use tokio::time::{timeout, Duration};

    use super::{
        InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport,
        INTERNAL_POLL_SIZE,
    };
    use crate::pub_sub::remote_consumer::Consumer;
    use crate::pub_sub::test::{CustomPubSub, IndexTest, Message};
    use crate::pub_sub::{Error, Spec, SubscriptionRequest};

    /// Single-topic request: `(local subscription name, topic index)`.
    #[derive(Clone, Debug)]
    enum SubscriptionReq {
        Foo(String, u64),
        Bar(String, u64),
    }

    impl SubscriptionRequest for SubscriptionReq {
        type Topic = IndexTest;
        type SubscriptionId = String;

        fn try_get_topics(&self) -> Result<Vec<Self::Topic>, Error> {
            Ok(vec![match self {
                SubscriptionReq::Foo(_, n) => IndexTest::Foo(*n),
                SubscriptionReq::Bar(_, n) => IndexTest::Bar(*n),
            }])
        }

        fn subscription_name(&self) -> Arc<Self::SubscriptionId> {
            Arc::new(match self {
                SubscriptionReq::Foo(n, _) => n.to_string(),
                SubscriptionReq::Bar(n, _) => n.to_string(),
            })
        }
    }

    /// In-memory transport.
    ///
    /// Events come from a channel fed by the test. Every control message the transport sees
    /// is echoed to `observe_ctrl_tx`, so tests can assert on it.
    struct TestTransport {
        name_ctr: AtomicUsize,
        observe_ctrl_tx: mpsc::Sender<StreamCtrl<CustomPubSub>>,
        support_long: bool,
        support_poll: bool,
        rx: Mutex<mpsc::Receiver<Message>>,
    }

    impl TestTransport {
        /// Returns the transport, the sender used to inject events, and the receiver of
        /// observed control messages.
        fn new(
            support_long: bool,
            support_poll: bool,
        ) -> (
            Self,
            mpsc::Sender<Message>,
            mpsc::Receiver<StreamCtrl<CustomPubSub>>,
        ) {
            let (events_tx, rx) = mpsc::channel::<Message>(INTERNAL_POLL_SIZE);
            let (observe_ctrl_tx, observe_ctrl_rx) =
                mpsc::channel::<StreamCtrl<_>>(INTERNAL_POLL_SIZE);
            let t = TestTransport {
                name_ctr: AtomicUsize::new(1),
                rx: Mutex::new(rx),
                observe_ctrl_tx,
                support_long,
                support_poll,
            };
            (t, events_tx, observe_ctrl_rx)
        }
    }

    #[async_trait::async_trait]
    impl Transport for TestTransport {
        type Spec = CustomPubSub;

        fn new_name(&self) -> <Self::Spec as Spec>::SubscriptionId {
            format!("sub-{}", self.name_ctr.fetch_add(1, Ordering::Relaxed))
        }

        async fn stream(
            &self,
            mut subscribe_changes: mpsc::Receiver<StreamCtrl<Self::Spec>>,
            topics: Vec<SubscribeMessage<Self::Spec>>,
            reply_to: InternalRelay<Self::Spec>,
        ) -> Result<(), Error> {
            if !self.support_long {
                return Err(Error::NotSupported);
            }
            let mut rx = self.rx.lock().await;
            let observe = self.observe_ctrl_tx.clone();
            // Report the initial subscriptions as if they had arrived as control messages.
            for topic in topics {
                observe.try_send(StreamCtrl::Subscribe(topic)).unwrap();
            }
            loop {
                tokio::select! {
                    Some(ctrl) = subscribe_changes.recv() => {
                        observe.try_send(ctrl.clone()).unwrap();
                        if matches!(ctrl, StreamCtrl::Stop) {
                            break;
                        }
                    }
                    Some(msg) = rx.recv() => {
                        reply_to.send(msg);
                    }
                }
            }
            Ok(())
        }

        async fn poll(
            &self,
            _topics: Vec<SubscribeMessage<Self::Spec>>,
            reply_to: InternalRelay<Self::Spec>,
        ) -> Result<(), Error> {
            if !self.support_poll {
                return Err(Error::NotSupported);
            }
            let mut rx = self.rx.lock().await;
            for _ in 0..32 {
                match rx.try_recv() {
                    Ok(msg) => reply_to.send(msg),
                    Err(mpsc::error::TryRecvError::Empty) => continue,
                    Err(mpsc::error::TryRecvError::Disconnected) => break,
                }
            }
            Ok(())
        }
    }

    /// Receives the next event, or `None` if nothing arrives within `dur_ms`.
    async fn recv_next<T: Transport>(
        sub: &mut RemoteActiveConsumer<T>,
        dur_ms: u64,
    ) -> Option<<T::Spec as Spec>::Event> {
        timeout(Duration::from_millis(dur_ms), sub.recv())
            .await
            .ok()
            .flatten()
    }

    /// Skips observed control messages until one matches `pred`.
    ///
    /// Panics on timeout, or if the channel closes first (instead of spinning on `None`
    /// until the timeout fires).
    async fn expect_ctrl(
        rx: &mut mpsc::Receiver<StreamCtrl<CustomPubSub>>,
        dur_ms: u64,
        pred: impl Fn(&StreamCtrl<CustomPubSub>) -> bool,
    ) -> StreamCtrl<CustomPubSub> {
        timeout(Duration::from_millis(dur_ms), async {
            while let Some(msg) = rx.recv().await {
                if pred(&msg) {
                    return msg;
                }
            }
            panic!("control channel closed while waiting for a control message");
        })
        .await
        .expect("timed out waiting for control message")
    }

    #[tokio::test]
    async fn stream_delivery_and_unsubscribe_on_drop() {
        let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true);
        let consumer = Consumer::new(transport, false, ());
        let mut sub = consumer
            .subscribe(SubscriptionReq::Foo("t".to_owned(), 7))
            .expect("subscribe ok");
        let ctrl = expect_ctrl(
            &mut ctrl_rx,
            1000,
            |m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(7)),
        )
        .await;
        match ctrl {
            StreamCtrl::Subscribe((name, idx)) => {
                // The remote id comes from the transport, not from the local name.
                assert_ne!(name, "t".to_owned());
                assert_eq!(idx, IndexTest::Foo(7));
            }
            _ => unreachable!(),
        }
        events_tx.send(Message { foo: 7, bar: 1 }).await.unwrap();
        let got = recv_next::<TestTransport>(&mut sub, 1000)
            .await
            .expect("got event");
        assert_eq!(got, Message { foo: 7, bar: 1 });
        drop(sub);
        let _ctrl = expect_ctrl(&mut ctrl_rx, 1000, |m| {
            matches!(m, StreamCtrl::Unsubscribe(_))
        })
        .await;
        drop(consumer);
        let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| matches!(m, StreamCtrl::Stop)).await;
    }

    /// Later subscribers of a live topic get the cached event replayed; once the topic is
    /// fully released, the cache entry is gone.
    #[tokio::test]
    async fn test_cache_and_invalation() {
        let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true);
        let consumer = Consumer::new(transport, false, ());
        let mut sub_1 = consumer
            .subscribe(SubscriptionReq::Foo("t".to_owned(), 7))
            .expect("subscribe ok");
        let ctrl = expect_ctrl(
            &mut ctrl_rx,
            1000,
            |m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(7)),
        )
        .await;
        match ctrl {
            StreamCtrl::Subscribe((name, idx)) => {
                // Compare against the actual local name ("t"); "t1" was never used here,
                // which made the assertion vacuous.
                assert_ne!(name, "t".to_owned());
                assert_eq!(idx, IndexTest::Foo(7));
            }
            _ => unreachable!(),
        }
        events_tx.send(Message { foo: 7, bar: 1 }).await.unwrap();
        let got = recv_next::<TestTransport>(&mut sub_1, 1000)
            .await
            .expect("got event");
        assert_eq!(got, Message { foo: 7, bar: 1 });
        let mut sub_2 = consumer
            .subscribe(SubscriptionReq::Foo("t2".to_owned(), 7))
            .expect("subscribe ok");
        let got = recv_next::<TestTransport>(&mut sub_2, 1000)
            .await
            .expect("got event");
        assert_eq!(got, Message { foo: 7, bar: 1 });
        drop(sub_1);
        let mut sub_3 = consumer
            .subscribe(SubscriptionReq::Foo("t3".to_owned(), 7))
            .expect("subscribe ok");
        let got = recv_next::<TestTransport>(&mut sub_3, 1000)
            .await
            .expect("got event");
        assert_eq!(got, Message { foo: 7, bar: 1 });
        events_tx.send(Message { foo: 7, bar: 2 }).await.unwrap();
        let got = recv_next::<TestTransport>(&mut sub_2, 1000)
            .await
            .expect("got event");
        assert_eq!(got, Message { foo: 7, bar: 2 });
        let got = recv_next::<TestTransport>(&mut sub_3, 1000)
            .await
            .expect("got event");
        assert_eq!(got, Message { foo: 7, bar: 2 });
        drop(sub_2);
        drop(sub_3);
        let _ctrl = expect_ctrl(&mut ctrl_rx, 1000, |m| {
            matches!(m, StreamCtrl::Unsubscribe(_))
        })
        .await;
        let mut sub_4 = consumer
            .subscribe(SubscriptionReq::Foo("t4".to_owned(), 7))
            .expect("subscribe ok");
        assert!(
            recv_next::<TestTransport>(&mut sub_4, 1000).await.is_none(),
            "Should have not receive any update"
        );
        drop(sub_4);
        let _ = expect_ctrl(&mut ctrl_rx, 2000, |m| matches!(m, StreamCtrl::Stop)).await;
    }

    #[tokio::test]
    async fn falls_back_to_poll_when_stream_not_supported() {
        let (transport, events_tx, _) = TestTransport::new(false, true);
        let consumer = Consumer::new(transport, true, ());
        let mut sub = consumer
            .subscribe(SubscriptionReq::Bar("t".to_owned(), 5))
            .expect("subscribe ok");
        events_tx.send(Message { foo: 9, bar: 5 }).await.unwrap();
        let got = recv_next::<TestTransport>(&mut sub, 1500)
            .await
            .expect("event relayed via polling");
        assert_eq!(got, Message { foo: 9, bar: 5 });
    }

    #[tokio::test]
    async fn multiple_subscribers_share_single_remote_subscription() {
        let (transport, events_tx, mut ctrl_rx) = TestTransport::new(true, true);
        let consumer = Consumer::new(transport, false, ());
        let mut a = consumer
            .subscribe(SubscriptionReq::Foo("t".to_owned(), 1))
            .expect("subscribe A");
        let _ = expect_ctrl(
            &mut ctrl_rx,
            1000,
            |m| matches!(m, StreamCtrl::Subscribe((_, idx)) if *idx == IndexTest::Foo(1)),
        )
        .await;
        let mut b = consumer
            .subscribe(SubscriptionReq::Foo("b".to_owned(), 1))
            .expect("subscribe B");
        if let Ok(Some(StreamCtrl::Subscribe((_, idx)))) =
            timeout(Duration::from_millis(400), ctrl_rx.recv()).await
        {
            assert_ne!(idx, IndexTest::Foo(1), "should not resubscribe same topic");
        }
        events_tx.send(Message { foo: 1, bar: 42 }).await.unwrap();
        let got_a = recv_next::<TestTransport>(&mut a, 1000)
            .await
            .expect("A got");
        let got_b = recv_next::<TestTransport>(&mut b, 1000)
            .await
            .expect("B got");
        assert_eq!(got_a, Message { foo: 1, bar: 42 });
        assert_eq!(got_b, Message { foo: 1, bar: 42 });
        drop(b);
        if let Ok(Some(StreamCtrl::Unsubscribe(_))) =
            timeout(Duration::from_millis(400), ctrl_rx.recv()).await
        {
            panic!("Should NOT unsubscribe while another local subscriber exists");
        }
        drop(a);
        let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| {
            matches!(m, StreamCtrl::Unsubscribe(_))
        })
        .await;
        let _ = expect_ctrl(&mut ctrl_rx, 1000, |m| matches!(m, StreamCtrl::Stop)).await;
    }
}