use crate::{p2p::MultiaddrWithPeerId, RepoEvent};
use cid::Cid;
use core::fmt::Debug;
use core::hash::Hash;
use core::pin::Pin;
use futures::channel::mpsc::Sender;
use futures::future::Future;
use libp2p::kad::QueryId;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt;
use std::mem;
use std::sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
};
use std::task::{Context, Poll, Waker};
/// Process-wide counter from which every subscription id is drawn; `create_subscription`
/// increments it once per call, and being a `static` it is shared by all registries.
static GLOBAL_REQ_COUNT: AtomicU64 = AtomicU64::new(0);
/// The operation a subscription waits on. `SubscriptionRegistry` groups subscriptions by
/// this key, and `finish_subscription` resolves every subscription of one kind at once.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum RequestKind {
/// Connecting to the given peer address.
Connect(MultiaddrWithPeerId),
/// Obtaining the block with the given CID.
GetBlock(Cid),
/// A Kademlia (`libp2p::kad`) query identified by its `QueryId`.
KadQuery(QueryId),
/// Test-only kind keyed by a plain number.
#[cfg(test)]
Num(u32),
}
impl From<MultiaddrWithPeerId> for RequestKind {
fn from(addr: MultiaddrWithPeerId) -> Self {
Self::Connect(addr)
}
}
impl From<Cid> for RequestKind {
fn from(cid: Cid) -> Self {
Self::GetBlock(cid)
}
}
impl From<QueryId> for RequestKind {
fn from(id: QueryId) -> Self {
Self::KadQuery(id)
}
}
impl fmt::Display for RequestKind {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Connect(tgt) => write!(fmt, "Connect to {:?}", tgt),
Self::GetBlock(cid) => write!(fmt, "Obtain block {}", cid),
Self::KadQuery(id) => write!(fmt, "Kad request {:?}", id),
#[cfg(test)]
Self::Num(n) => write!(fmt, "A test request for {}", n),
}
}
}
/// Identifier of a single subscription, drawn from the process-wide `GLOBAL_REQ_COUNT`.
type SubscriptionId = u64;
/// Every registered subscription: for each request kind, the subscriptions waiting on it,
/// keyed by their id.
pub type Subscriptions<T, E> = HashMap<RequestKind, HashMap<SubscriptionId, Subscription<T, E>>>;
/// Tracks subscriptions to requests and resolves them with the request's result
/// (`finish_subscription`) or cancels them all (`shutdown`, also run on drop).
pub struct SubscriptionRegistry<T: Debug + Clone + PartialEq, E: Debug + Clone> {
/// Shared with every `SubscriptionFuture` handed out by `create_subscription`.
pub(crate) subscriptions: Arc<Mutex<Subscriptions<T, E>>>,
/// Raised once by `shutdown`; subscriptions created afterwards start out cancelled.
shutting_down: AtomicBool,
}
impl<T: Debug + Clone + PartialEq, E: Debug + Clone> fmt::Debug for SubscriptionRegistry<T, E> {
    /// Formats as `<full type name>(subscriptions: <map>)`, printing the shared map through
    /// the `Arc<Mutex<_>>`'s own `Debug` implementation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let type_name = std::any::type_name::<Self>();
        let shared_map = &self.subscriptions;
        write!(f, "{}(subscriptions: {:?})", type_name, shared_map)
    }
}
impl<T: Debug + Clone + PartialEq, E: Debug + Clone> SubscriptionRegistry<T, E> {
    /// Registers a new subscription to `kind` and returns the future that awaits it.
    ///
    /// Every call receives a fresh id from the process-wide `GLOBAL_REQ_COUNT`, so any
    /// number of futures may subscribe to the same `kind`. `finish_subscription` resolves
    /// all of them at once. If `cancel_notifier` is given, a `RepoEvent` converted from
    /// `kind` is sent through it when this subscription is cancelled as the last one of
    /// its kind.
    ///
    /// If the registry is already shutting down, the subscription is created in the
    /// cancelled state. Its cancel notification, if any, is sent immediately, and the
    /// returned future resolves to `SubscriptionErr::Cancelled`.
    ///
    /// # Panics
    ///
    /// Panics if the subscription mutex is poisoned.
    pub fn create_subscription(
        &self,
        kind: RequestKind,
        cancel_notifier: Option<Sender<RepoEvent>>,
    ) -> SubscriptionFuture<T, E> {
        let id = GLOBAL_REQ_COUNT.fetch_add(1, Ordering::SeqCst);
        debug!("Creating subscription {} to {}", id, kind);

        let mut subscription = Subscription::new(cancel_notifier);

        {
            let mut subscriptions = self.subscriptions.lock().unwrap();

            // Read the flag only while holding the lock. `shutdown` raises the flag before
            // draining the map under this same lock, so either the drain sees (and cancels)
            // this subscription, or this check sees the flag. Checking before locking could
            // insert a pending subscription after the drain, where shutdown never cancels it.
            if self.shutting_down.load(Ordering::SeqCst) {
                subscription.cancel(id, kind.clone(), true);
            }

            subscriptions
                .entry(kind.clone())
                .or_default()
                .insert(id, subscription);
        }

        SubscriptionFuture {
            id,
            kind,
            subscriptions: Arc::clone(&self.subscriptions),
            cleanup_complete: false,
        }
    }

    /// Completes every pending subscription to `req_kind` with a clone of `result` and
    /// wakes the tasks waiting on them.
    ///
    /// Subscriptions of that kind that are already `Ready` or `Cancelled` are left as they
    /// are. Nothing happens when there are no subscriptions to `req_kind`.
    ///
    /// # Panics
    ///
    /// Panics if the subscription mutex is poisoned. In debug builds, also panics if
    /// subscriptions to `req_kind` exist but none of them is pending.
    pub fn finish_subscription(&self, req_kind: RequestKind, result: Result<T, E>) {
        let mut subscriptions = self.subscriptions.lock().unwrap();
        let related_subs = subscriptions.get_mut(&req_kind);

        if let Some(related_subs) = related_subs {
            debug!("Finishing the subscription to {}", req_kind);

            let mut awoken = 0;
            for sub in related_subs.values_mut() {
                // Only `Pending` subscriptions still have a future waiting for a result.
                if let Subscription::Pending { .. } = sub {
                    sub.wake(result.clone());
                    awoken += 1;
                }
            }

            // Finishing a kind is expected to resolve at least one waiting future.
            debug_assert!(
                awoken != 0,
                "no subscriptions to be awoken! subs: {:?}; req_kind: {:?}",
                subscriptions,
                req_kind
            );

            trace!("Woke {} related subscription(s)", awoken);
        }
    }

    /// Cancels every registered subscription and makes later ones start out cancelled.
    ///
    /// Only the first call has an effect. All subscriptions are removed from the shared
    /// map, so their futures resolve to `SubscriptionErr::Cancelled`. Every subscription
    /// that was still pending is cancelled as if it were the last of its kind, so each one
    /// with a cancel notifier sends its notification.
    ///
    /// # Panics
    ///
    /// Panics if the subscription mutex is poisoned.
    pub fn shutdown(&self) {
        if self.shutting_down.swap(true, Ordering::SeqCst) {
            return;
        }

        trace!("Shutting down {:?}", self);

        let mut cancelled = 0;
        // `mem::take` leaves an empty map behind. The guard is a temporary of this
        // statement, so the lock is released before any subscription is cancelled.
        let mut subscriptions = mem::take(&mut *self.subscriptions.lock().unwrap());

        for (kind, subs) in subscriptions.iter_mut() {
            for (id, sub) in subs.iter_mut() {
                sub.cancel(*id, kind.clone(), true);
                cancelled += 1;
            }
        }

        if cancelled > 0 {
            trace!("Cancelled {} subscriptions", cancelled);
        }
    }
}
impl<T: Debug + Clone + PartialEq, E: Debug + Clone> Default for SubscriptionRegistry<T, E> {
fn default() -> Self {
Self {
subscriptions: Default::default(),
shutting_down: Default::default(),
}
}
}
impl<T: Debug + Clone + PartialEq, E: Debug + Clone> Drop for SubscriptionRegistry<T, E> {
    /// Shuts the registry down, cancelling every outstanding subscription so its future
    /// resolves instead of waiting on a registry that no longer exists.
    fn drop(&mut self) {
        Self::shutdown(self)
    }
}
/// Error produced by a `SubscriptionFuture`.
#[derive(Debug, PartialEq)]
pub enum SubscriptionErr<E: Debug + PartialEq> {
    /// The subscription was cancelled, or its entry was no longer registered.
    Cancelled,
    /// The subscription was finished with the contained error.
    Failed(E),
}

impl<E: Debug + PartialEq> fmt::Display for SubscriptionErr<E> {
    /// Uses the `Debug` representation as the human-readable message.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}

impl<E> std::error::Error for SubscriptionErr<E> where E: Debug + PartialEq {}
/// The state of one subscription, as stored in the registry's shared map.
pub enum Subscription<T, E> {
/// Finished with this result; the entry stays in the map until its future polls it out.
Ready(Result<T, E>),
/// Still waiting for a result.
Pending {
/// Waker of the task that last polled the associated future, if it has been polled.
waker: Option<Waker>,
/// Receives a `RepoEvent` built from the request kind when this subscription is
/// cancelled as the last one of its kind.
cancel_notifier: Option<Sender<RepoEvent>>,
},
/// Cancelled; the associated future resolves to `SubscriptionErr::Cancelled`.
Cancelled,
}
impl<T: Clone, E: Clone> Clone for Subscription<T, E> {
fn clone(&self) -> Self {
match self {
Self::Ready(res) => Self::Ready(res.clone()),
Self::Pending {
waker,
cancel_notifier,
} => Self::Pending {
waker: waker.clone(),
cancel_notifier: cancel_notifier.clone(),
},
Self::Cancelled => Self::Cancelled,
}
}
}
impl<T, E> fmt::Debug for Subscription<T, E> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
use Subscription::*;
match self {
Ready(_) => write!(fmt, "Ready"),
Pending {
waker: Some(_),
cancel_notifier: Some(_),
} => write!(fmt, "Pending {{ waker: Some, cancel_notifier: Some }}"),
Pending {
waker: None,
cancel_notifier: Some(_),
} => write!(fmt, "Pending {{ waker: None, cancel_notifier: Some }}"),
Pending {
waker: Some(_),
cancel_notifier: None,
} => write!(fmt, "Pending {{ waker: Some, cancel_notifier: None }}"),
Pending {
waker: None,
cancel_notifier: None,
} => write!(fmt, "Pendnig {{ waker: None, cancel_notifier: None }}"),
Cancelled => write!(fmt, "Cancelled"),
}
}
}
impl<T, E> Subscription<T, E> {
    /// Creates a `Pending` subscription whose future has not been polled yet (no waker).
    fn new(cancel_notifier: Option<Sender<RepoEvent>>) -> Self {
        Subscription::Pending {
            waker: None,
            cancel_notifier,
        }
    }

    /// Stores `result`, unconditionally turning this subscription into `Ready`.
    ///
    /// If it was `Pending` with a registered waker, that task is woken.
    fn wake(&mut self, result: Result<T, E>) {
        let previous = mem::replace(self, Subscription::Ready(result));
        if let Subscription::Pending { waker: Some(w), .. } = previous {
            w.wake();
        }
    }

    /// Unconditionally marks the subscription `Cancelled`; `id` is used only for logging.
    ///
    /// Nothing else happens unless the subscription was `Pending`. In that case, if
    /// `is_last` is set and a cancel notifier is present, a `RepoEvent` built from `kind` is
    /// sent with `try_send` (a failed send is ignored). Then the registered waker, if any,
    /// is woken.
    ///
    /// # Panics
    ///
    /// Panics if `RepoEvent::try_from(kind)` fails while a notification is being sent.
    /// NOTE(review): presumably only convertible kinds are paired with a notifier — confirm.
    fn cancel(&mut self, id: SubscriptionId, kind: RequestKind, is_last: bool) {
        trace!("Cancelling subscription {} to {}", id, kind);

        let (waker, cancel_notifier) = match mem::replace(self, Subscription::Cancelled) {
            Subscription::Pending {
                waker,
                cancel_notifier,
            } => (waker, cancel_notifier),
            _ => return,
        };

        if is_last {
            if let Some(mut sender) = cancel_notifier {
                trace!("Last related subscription cancelled, sending a cancel notification");
                let _ = sender.try_send(RepoEvent::try_from(kind).unwrap());
            }
        }

        if let Some(w) = waker {
            w.wake();
        }
    }
}
/// Future returned by `SubscriptionRegistry::create_subscription`. It resolves to the
/// subscription's result, or to `SubscriptionErr::Cancelled`.
pub struct SubscriptionFuture<T: Debug + PartialEq, E: Debug> {
/// This subscription's id within the map for `kind`.
id: u64,
/// The request kind this future is subscribed to.
kind: RequestKind,
/// The registry's shared subscription map.
subscriptions: Arc<Mutex<Subscriptions<T, E>>>,
/// Set by `poll` when it removed the whole `kind` entry, so `Drop` can skip its cleanup.
cleanup_complete: bool,
}
/// Resolves once this future's subscription leaves the `Pending` state.
///
/// A `Ready` result yields `Ok(_)` or `Err(SubscriptionErr::Failed(_))`. A cancelled
/// subscription yields `Err(SubscriptionErr::Cancelled)`, and so does a missing entry, e.g.
/// after `SubscriptionRegistry::shutdown` has drained the map.
impl<T: Debug + PartialEq, E: Debug + PartialEq> Future for SubscriptionFuture<T, E> {
type Output = Result<T, SubscriptionErr<E>>;
fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
use std::collections::hash_map::Entry::*;
let mut subscriptions = self.subscriptions.lock().unwrap();
if let Some(related_subs) = subscriptions.get_mut(&self.kind) {
let (became_empty, ret) = match related_subs.entry(self.id) {
// Other subscriptions of this kind exist, but ours is no longer registered.
Vacant(_) => return Poll::Ready(Err(SubscriptionErr::Cancelled)),
Occupied(mut oe) => {
let unwrapped = match oe.get_mut() {
Subscription::Pending { ref mut waker, .. } => {
// Keep the most recent waker so finish/cancel wakes the right task.
*waker = Some(context.waker().clone());
return Poll::Pending;
}
Subscription::Cancelled => {
oe.remove();
Err(SubscriptionErr::Cancelled)
}
// Only `Ready` remains; removing the entry moves the result out of the map.
_ => match oe.remove() {
Subscription::Ready(result) => result.map_err(SubscriptionErr::Failed),
_ => unreachable!("already matched"),
},
};
(related_subs.is_empty(), unwrapped)
}
};
// Don't leave an empty per-kind map behind.
if became_empty {
subscriptions.remove(&self.kind);
}
// The guard borrows `self`; release it (and the lock) before mutating `self` below.
drop(subscriptions);
// When other subscriptions of this kind remain this stays false, and `Drop` re-checks
// the map (where our entry has already been removed).
self.cleanup_complete = became_empty;
Poll::Ready(ret)
} else {
// No subscriptions of this kind are registered at all.
Poll::Ready(Err(SubscriptionErr::Cancelled))
}
}
}
impl<T: Debug + PartialEq, E: Debug> Drop for SubscriptionFuture<T, E> {
    /// Cleans up after this future unless `poll` already removed the whole `kind` entry.
    ///
    /// Removes this future's subscription from the shared map, and removes the kind's map
    /// too if it becomes empty. If the removed subscription was still `Pending`, it is
    /// cancelled after the lock is released. The cancel notification is sent only when it
    /// was the last subscription of its kind.
    fn drop(&mut self) {
        trace!("Dropping subscription future {} to {}", self.id, self.kind);

        if self.cleanup_complete {
            return;
        }

        let mut guard = self.subscriptions.lock().unwrap();
        let (removed, is_last) = match guard.get_mut(&self.kind) {
            None => (None, false),
            Some(related) => {
                let removed = related.remove(&self.id);
                let now_empty = related.is_empty();
                if now_empty {
                    guard.remove(&self.kind);
                }
                (removed, now_empty)
            }
        };
        // Cancelling may wake a task and send a notification; do it without the lock.
        drop(guard);

        if let Some(mut pending @ Subscription::Pending { .. }) = removed {
            pending.cancel(self.id, self.kind.clone(), is_last);
        }
    }
}
impl<T: Debug + PartialEq, E: Debug> fmt::Debug for SubscriptionFuture<T, E> {
    /// Shows only the future's output type, e.g.
    /// `SubscriptionFuture<Output = Result<u32, ()>>`; the id and kind are not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (ok_type, err_type) = (std::any::type_name::<T>(), std::any::type_name::<E>());
        write!(
            f,
            "SubscriptionFuture<Output = Result<{}, {}>>",
            ok_type, err_type
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    impl From<u32> for RequestKind {
        fn from(n: u32) -> Self {
            Self::Num(n)
        }
    }

    #[tokio::test(max_threads = 1)]
    async fn subscription_basics() {
        let registry = SubscriptionRegistry::<u32, ()>::default();
        let s1 = registry.create_subscription(0.into(), None);
        let s2 = registry.create_subscription(0.into(), None);
        let s3 = registry.create_subscription(0.into(), None);
        registry.finish_subscription(0.into(), Ok(10));
        assert_eq!(s1.await.unwrap(), 10);
        assert_eq!(s2.await.unwrap(), 10);
        assert_eq!(s3.await.unwrap(), 10);
    }

    #[tokio::test(max_threads = 1)]
    async fn subscription_cancelled_on_dropping_registry() {
        let registry = SubscriptionRegistry::<u32, ()>::default();
        let s1 = registry.create_subscription(0.into(), None);
        drop(registry);
        assert_eq!(s1.await, Err(SubscriptionErr::Cancelled));
    }

    #[tokio::test(max_threads = 1)]
    async fn subscription_cancelled_on_shutdown() {
        let registry = SubscriptionRegistry::<u32, ()>::default();
        let s1 = registry.create_subscription(0.into(), None);
        registry.shutdown();
        assert_eq!(s1.await, Err(SubscriptionErr::Cancelled));
    }

    #[tokio::test(max_threads = 1)]
    async fn new_subscriptions_cancelled_after_shutdown() {
        let registry = SubscriptionRegistry::<u32, ()>::default();
        registry.shutdown();
        let s1 = registry.create_subscription(0.into(), None);
        assert_eq!(s1.await, Err(SubscriptionErr::Cancelled));
    }

    #[tokio::test(max_threads = 1)]
    async fn dropping_subscription_future_after_registering() {
        use std::time::Duration;
        use tokio::time::timeout;

        let registry = SubscriptionRegistry::<u32, ()>::default();
        let s1 = timeout(
            Duration::from_millis(1),
            registry.create_subscription(0.into(), None),
        );
        let s2 = registry.create_subscription(0.into(), None);

        // the timeout elapses and drops the first subscription future
        s1.await.unwrap_err();

        registry.finish_subscription(0.into(), Ok(0));
        assert_eq!(s2.await.unwrap(), 0);
    }

    #[tokio::test(max_threads = 1)]
    #[ignore]
    async fn subscription_stress_test() {
        use futures::future::{abortable, AbortHandle};
        use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
        use std::time::Duration;
        use tokio::{task, time::delay_for};

        tracing_subscriber::fmt::init();

        const KIND_COUNT: u32 = 100;
        const KIND_SUB_COUNT: u32 = 10;
        const CREATE_WAIT_TIME: u64 = 50;
        const FINISH_WAIT_TIME: u64 = 100;
        const CANCEL_WAIT_TIME: u64 = 200;

        let registry = Arc::new(SubscriptionRegistry::<u32, ()>::default());
        // Abort handles of the spawned subscription futures. Dropping a tokio `JoinHandle`
        // only detaches its task, so aborting is what actually drops a `SubscriptionFuture`
        // and exercises its cancellation path.
        let subs: Arc<Mutex<Vec<AbortHandle>>> = Arc::new(Mutex::new(Vec::with_capacity(1024)));
        let mut rng = StdRng::from_entropy();

        // create subscriptions at random intervals
        let reg_clone = Arc::clone(&registry);
        let subs_clone = Arc::clone(&subs);
        let mut rng_clone = rng.clone();
        let create_task = task::spawn(async move {
            let (mut kind, mut count);
            loop {
                delay_for(Duration::from_millis(CREATE_WAIT_TIME)).await;
                kind = rng_clone.gen_range(0, KIND_COUNT);
                count = rng_clone.gen_range(0, KIND_SUB_COUNT);
                if count > 0 {
                    let mut subs = subs_clone.lock().unwrap();
                    for _ in 0..count {
                        let (sub, handle) =
                            abortable(reg_clone.create_subscription(kind.into(), None));
                        task::spawn(sub);
                        subs.push(handle);
                    }
                }
            }
        });

        // finish subscriptions of randomly chosen kinds at random intervals
        let reg_clone = Arc::clone(&registry);
        let mut rng_clone = rng.clone();
        let finish_task = task::spawn(async move {
            let (mut kinds, mut count);
            loop {
                delay_for(Duration::from_millis(FINISH_WAIT_TIME)).await;
                kinds = reg_clone
                    .subscriptions
                    .lock()
                    .unwrap()
                    .keys()
                    .cloned()
                    .collect::<Vec<_>>();
                // `gen_range(0, 0)` panics on an empty range
                if kinds.is_empty() {
                    continue;
                }
                count = rng_clone.gen_range(0, kinds.len());
                for kind in kinds.choose_multiple(&mut rng_clone, count) {
                    reg_clone.finish_subscription(kind.to_owned(), Ok(0));
                }
            }
        });

        // cancel (abort) random subscription futures at random intervals
        let cancel_task = task::spawn(async move {
            let (mut count, mut idx);
            loop {
                delay_for(Duration::from_millis(CANCEL_WAIT_TIME)).await;
                let subs_unlocked = &mut *subs.lock().unwrap();
                // `gen_range(0, 0)` panics on an empty range
                if subs_unlocked.is_empty() {
                    continue;
                }
                count = rng.gen_range(0, subs_unlocked.len());
                for _ in 0..count {
                    idx = rng.gen_range(0, subs_unlocked.len());
                    subs_unlocked.remove(idx).abort();
                }
            }
        });

        let _ = futures::join!(create_task, finish_task, cancel_task);
    }
}