use super::follow_stream_unpin::{BlockRef, FollowStreamMsg, FollowStreamUnpin};
use crate::config::Hash;
use crate::error::{BackendError, RpcError};
use futures::stream::{Stream, StreamExt};
use std::collections::{HashMap, HashSet, VecDeque};
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use subxt_rpcs::methods::chain_head::{FollowEvent, Initialized, RuntimeEvent};
/// Drives a [`FollowStreamUnpin`], handing each message it yields to every
/// subscription created via a [`FollowStreamDriverHandle`].
///
/// Subscriptions only receive new messages while this stream is being polled.
#[derive(Debug)]
pub struct FollowStreamDriver<H: Hash> {
    inner: FollowStreamUnpin<H>,
    shared: Shared<H>,
}

impl<H: Hash> FollowStreamDriver<H> {
    /// Wrap the given follow stream. Nothing is forwarded until the driver is polled.
    pub fn new(follow_unpin: FollowStreamUnpin<H>) -> Self {
        let shared = Shared::default();
        Self {
            shared,
            inner: follow_unpin,
        }
    }

    /// Return a handle that shares this driver's state and can create subscriptions.
    pub fn handle(&self) -> FollowStreamDriverHandle<H> {
        let shared = Shared::clone(&self.shared);
        FollowStreamDriverHandle { shared }
    }
}
impl<H: Hash> Stream for FollowStreamDriver<H> {
type Item = Result<(), BackendError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.poll_next_unpin(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => {
self.shared.done();
Poll::Ready(None)
}
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(item))) => {
self.shared.push_item(item);
Poll::Ready(Some(Ok(())))
}
}
}
}
#[derive(Debug, Clone)]
pub struct FollowStreamDriverHandle<H: Hash> {
shared: Shared<H>,
}
impl<H: Hash> FollowStreamDriverHandle<H> {
pub fn subscribe(&self) -> FollowStreamDriverSubscription<H> {
self.shared.subscribe()
}
}
/// A stream of [`FollowStreamMsg`]s delivered by a [`FollowStreamDriver`].
#[derive(Debug)]
pub struct FollowStreamDriverSubscription<H: Hash> {
/// Key identifying this subscriber in the shared state.
id: usize,
/// Set once the shared state reports that no more items will arrive.
done: bool,
/// State shared with the driver and other subscriptions.
shared: Shared<H>,
/// Items taken from the shared state (or replayed at subscribe time) but not yet yielded.
local_items: VecDeque<FollowStreamMsg<BlockRef<H>>>,
}
impl<H: Hash> Stream for FollowStreamDriverSubscription<H> {
    type Item = FollowStreamMsg<BlockRef<H>>;

    /// Yield locally buffered items first, then refill the buffer from the shared state.
    ///
    /// Returns `Pending` when nothing is buffered; the waker is saved in the shared
    /// state. The stream ends once the shared state reports that no more items
    /// will arrive.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        while !self.done {
            if let Some(next) = self.local_items.pop_front() {
                return Poll::Ready(Some(next));
            }
            match self.shared.take_items_and_save_waker(self.id, cx.waker()) {
                // Nothing left and no more coming: finish on the next loop check.
                None => self.done = true,
                Some(batch) if batch.is_empty() => return Poll::Pending,
                Some(batch) => self.local_items = batch,
            }
        }
        Poll::Ready(None)
    }
}
impl<H: Hash> FollowStreamDriverSubscription<H> {
    /// Resolve to the subscription ID carried by the first `Ready` message this
    /// subscription yields, or `None` if the stream ends before one arrives.
    pub async fn subscription_id(mut self) -> Option<String> {
        while let Some(msg) = self.next().await {
            if let FollowStreamMsg::Ready(sub_id) = msg {
                return Some(sub_id);
            }
        }
        None
    }

    /// Yield only the [`FollowEvent`]s from this subscription.
    ///
    /// Messages for which `into_event` returns `None` are skipped; presumably these
    /// are the `Ready` messages.
    pub fn events(self) -> impl Stream<Item = FollowEvent<BlockRef<H>>> + Send + Sync {
        self.filter_map(|msg| {
            let maybe_event = msg.into_event();
            std::future::ready(maybe_event)
        })
    }
}
impl<H: Hash> Clone for FollowStreamDriverSubscription<H> {
fn clone(&self) -> Self {
self.shared.subscribe()
}
}
impl<H: Hash> Drop for FollowStreamDriverSubscription<H> {
fn drop(&mut self) {
self.shared.remove_sub(self.id);
}
}
/// Reference-counted handle to the state shared by the driver, its handles and its subscriptions.
#[derive(Debug, Clone)]
struct Shared<H: Hash>(Arc<Mutex<SharedState<H>>>);
/// The mutex-guarded state behind [`Shared`].
#[derive(Debug)]
struct SharedState<H: Hash> {
/// Set once the underlying follow stream has ended.
done: bool,
/// ID handed to the next subscriber.
next_id: usize,
/// Per-subscriber buffers and wakers, keyed by subscriber ID.
subscribers: HashMap<usize, SubscriberDetails<H>>,
/// `NewBlock` / `BestBlockChanged` events replayed to new subscribers after the
/// `Initialized` message; events for finalized or pruned blocks are dropped on finalization.
block_events_for_new_subscriptions: VecDeque<FollowEvent<BlockRef<H>>>,
/// ID from the most recent `Ready` message; cleared on `Stop`.
current_subscription_id: Option<String>,
/// `Initialized` event kept in step with finalization and replayed to new subscribers; cleared on `Stop`.
current_init_message: Option<Initialized<BlockRef<H>>>,
/// Runtime updates announced in `NewBlock` events, keyed by block hash; consulted
/// on finalization to update the runtime in `current_init_message`.
seen_runtime_events: HashMap<H, RuntimeEvent>,
}
impl<H: Hash> Default for Shared<H> {
    /// Empty, not-yet-finished state. Subscriber IDs start at 1.
    fn default() -> Self {
        let state = SharedState {
            done: false,
            next_id: 1,
            subscribers: HashMap::new(),
            block_events_for_new_subscriptions: VecDeque::new(),
            current_subscription_id: None,
            current_init_message: None,
            seen_runtime_events: HashMap::new(),
        };
        Shared(Arc::new(Mutex::new(state)))
    }
}
impl<H: Hash> Shared<H> {
    /// Mark the underlying stream as finished and wake every subscriber that has a
    /// saved waker, so each one can observe the end of the stream.
    pub fn done(&self) {
        let mut shared = self.0.lock().unwrap();
        shared.done = true;
        for details in shared.subscribers.values_mut() {
            if let Some(waker) = details.waker.take() {
                waker.wake();
            }
        }
    }

    /// Stop buffering items for the subscriber with the given ID.
    pub fn remove_sub(&self, sub_id: usize) {
        let mut shared = self.0.lock().unwrap();
        shared.subscribers.remove(&sub_id);
    }

    /// Take every item buffered for subscriber `sub_id`.
    ///
    /// Returns `None` if `sub_id` is not registered, or if the stream is done and
    /// no items remain. Otherwise returns the buffered items (possibly empty).
    /// Unless the stream is done, `waker` is stored so that the next
    /// [`Shared::push_item`] or [`Shared::done`] wakes the subscriber.
    pub fn take_items_and_save_waker(
        &self,
        sub_id: usize,
        waker: &Waker,
    ) -> Option<VecDeque<FollowStreamMsg<BlockRef<H>>>> {
        let mut shared = self.0.lock().unwrap();
        let is_done = shared.done;
        let details = shared.subscribers.get_mut(&sub_id)?;

        if details.items.is_empty() && is_done {
            return None;
        }

        let items = std::mem::take(&mut details.items);
        if !is_done {
            details.waker = Some(waker.clone());
        }
        Some(items)
    }

    /// Buffer `item` for every current subscriber, waking any that saved a waker.
    /// Then update the state that is replayed to subscriptions created later.
    pub fn push_item(&self, item: FollowStreamMsg<BlockRef<H>>) {
        let mut shared = self.0.lock().unwrap();
        let shared = shared.deref_mut();

        for details in shared.subscribers.values_mut() {
            details.items.push_back(item.clone());
            if let Some(waker) = details.waker.take() {
                waker.wake();
            }
        }

        match item {
            FollowStreamMsg::Ready(sub_id) => {
                shared.current_subscription_id = Some(sub_id);
            }
            FollowStreamMsg::Event(FollowEvent::Initialized(ev)) => {
                shared.current_init_message = Some(ev);
                shared.block_events_for_new_subscriptions.clear();
                // Anything remembered so far belongs to a previous follow subscription.
                shared.seen_runtime_events.clear();
            }
            FollowStreamMsg::Event(FollowEvent::Finalized(finalized_ev)) => {
                // Blocks that are now finalized or pruned. Nothing about them needs
                // to be replayed or remembered any more.
                let to_remove: HashSet<H> = finalized_ev
                    .finalized_block_hashes
                    .iter()
                    .chain(finalized_ev.pruned_block_hashes.iter())
                    .map(|h| h.hash())
                    .collect();

                // Keep the init message handed to new subscriptions in line with the
                // latest finalized blocks. If there is no init message, skip this step.
                if let Some(init_message) = &mut shared.current_init_message {
                    // Runtime update from the most recently finalized block that announced one.
                    let newest_runtime = finalized_ev
                        .finalized_block_hashes
                        .iter()
                        .rev()
                        .find_map(|h| shared.seen_runtime_events.get(&h.hash()).cloned());

                    init_message
                        .finalized_block_hashes
                        .clone_from(&finalized_ev.finalized_block_hashes);

                    if let Some(runtime_ev) = newest_runtime {
                        init_message.finalized_block_runtime = Some(runtime_ev);
                    }
                }

                // Only forget runtime updates for finalized/pruned blocks. Updates
                // announced by still-unfinalized blocks must be kept; otherwise they
                // are lost when those blocks are finalized by a later event.
                shared
                    .seen_runtime_events
                    .retain(|hash, _| !to_remove.contains(hash));

                shared
                    .block_events_for_new_subscriptions
                    .retain(|ev| match ev {
                        FollowEvent::NewBlock(new_block_ev) => {
                            !to_remove.contains(&new_block_ev.block_hash.hash())
                        }
                        FollowEvent::BestBlockChanged(best_block_ev) => {
                            !to_remove.contains(&best_block_ev.best_block_hash.hash())
                        }
                        _ => true,
                    });
            }
            FollowStreamMsg::Event(FollowEvent::NewBlock(new_block_ev)) => {
                // Remember runtime updates so the init message can be updated once
                // this block is finalized.
                if let Some(runtime_event) = &new_block_ev.new_runtime {
                    shared
                        .seen_runtime_events
                        .insert(new_block_ev.block_hash.hash(), runtime_event.clone());
                }
                shared
                    .block_events_for_new_subscriptions
                    .push_back(FollowEvent::NewBlock(new_block_ev));
            }
            FollowStreamMsg::Event(ev @ FollowEvent::BestBlockChanged(_)) => {
                shared.block_events_for_new_subscriptions.push_back(ev);
            }
            FollowStreamMsg::Event(FollowEvent::Stop) => {
                // The follow subscription is gone; nothing about it should be replayed.
                shared.block_events_for_new_subscriptions.clear();
                shared.seen_runtime_events.clear();
                shared.current_subscription_id = None;
                shared.current_init_message = None;
            }
            _ => {}
        }
    }

    /// Register a new subscriber and return its subscription.
    ///
    /// The subscription is seeded with the latest `Ready` message, the current
    /// `Initialized` message (each if present) and the buffered block events, so a
    /// late subscriber starts from the current view of the chain.
    pub fn subscribe(&self) -> FollowStreamDriverSubscription<H> {
        let mut shared = self.0.lock().unwrap();

        let id = shared.next_id;
        shared.next_id += 1;

        shared.subscribers.insert(
            id,
            SubscriberDetails {
                items: VecDeque::new(),
                waker: None,
            },
        );

        let mut local_items = VecDeque::new();
        if let Some(sub_id) = &shared.current_subscription_id {
            local_items.push_back(FollowStreamMsg::Ready(sub_id.clone()));
        }
        if let Some(init_msg) = &shared.current_init_message {
            local_items.push_back(FollowStreamMsg::Event(FollowEvent::Initialized(
                init_msg.clone(),
            )));
        }
        for ev in &shared.block_events_for_new_subscriptions {
            local_items.push_back(FollowStreamMsg::Event(ev.clone()));
        }

        // Release the lock before building the subscription (which clones `self`).
        drop(shared);

        FollowStreamDriverSubscription {
            id,
            done: false,
            shared: self.clone(),
            local_items,
        }
    }
}
/// Per-subscriber bookkeeping held in the shared state.
#[derive(Debug)]
struct SubscriberDetails<H: Hash> {
/// Items pushed by the driver but not yet taken by the subscriber.
items: VecDeque<FollowStreamMsg<BlockRef<H>>>,
/// Waker saved when the subscriber last polled (unless the stream was done);
/// taken and woken on the next pushed item or when the stream is marked done.
waker: Option<Waker>,
}
/// Adapts a [`FollowStreamDriverSubscription`] into a stream of
/// `(subscription_id, block_refs)` pairs. The block refs are produced by applying
/// `f` to each follow event; empty results are skipped.
#[derive(Debug)]
pub struct FollowStreamFinalizedHeads<H: Hash, F> {
/// Source of follow messages.
stream: FollowStreamDriverSubscription<H>,
/// ID from the most recent `Ready` message.
sub_id: Option<String>,
/// Last finalized block seen; used to detect missed blocks when the follow
/// subscription restarts with a new `Initialized` event.
last_seen_block: Option<BlockRef<H>>,
/// Maps each follow event to the block refs to emit.
f: F,
/// Set once the inner stream has ended.
is_done: bool,
}
// No field relies on being pinned (`poll_next` only accesses fields through `&mut`),
// so this type is `Unpin` regardless of `F`.
impl<H: Hash, F> Unpin for FollowStreamFinalizedHeads<H, F> {}
impl<H, F> FollowStreamFinalizedHeads<H, F>
where
    H: Hash,
    F: Fn(FollowEvent<BlockRef<H>>) -> Vec<BlockRef<H>>,
{
    /// Wrap `stream`, mapping each follow event to block refs via `f`.
    pub fn new(stream: FollowStreamDriverSubscription<H>, f: F) -> Self {
        Self {
            is_done: false,
            last_seen_block: None,
            sub_id: None,
            stream,
            f,
        }
    }
}
impl<H, F> Stream for FollowStreamFinalizedHeads<H, F>
where
H: Hash,
F: Fn(FollowEvent<BlockRef<H>>) -> Vec<BlockRef<H>>,
{
type Item = Result<(String, Vec<BlockRef<H>>), BackendError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.is_done {
return Poll::Ready(None);
}
loop {
let Some(ev) = futures::ready!(self.stream.poll_next_unpin(cx)) else {
self.is_done = true;
return Poll::Ready(None);
};
let block_refs = match ev {
FollowStreamMsg::Ready(sub_id) => {
self.sub_id = Some(sub_id);
continue;
}
FollowStreamMsg::Event(FollowEvent::Finalized(finalized)) => {
self.last_seen_block = finalized.finalized_block_hashes.last().cloned();
(self.f)(FollowEvent::Finalized(finalized))
}
FollowStreamMsg::Event(FollowEvent::Initialized(mut init)) => {
let prev = self.last_seen_block.take();
self.last_seen_block = init.finalized_block_hashes.last().cloned();
if let Some(p) = prev {
let Some(pos) = init
.finalized_block_hashes
.iter()
.position(|b| b.hash() == p.hash())
else {
return Poll::Ready(Some(Err(RpcError::ClientError(
subxt_rpcs::Error::DisconnectedWillReconnect(
"Missed at least one block when the connection was lost"
.to_owned(),
),
)
.into())));
};
init.finalized_block_hashes.drain(0..=pos);
}
(self.f)(FollowEvent::Initialized(init))
}
FollowStreamMsg::Event(ev) => (self.f)(ev),
};
if block_refs.is_empty() {
continue;
}
let sub_id = self
.sub_id
.clone()
.expect("Ready is always emitted before any other event");
return Poll::Ready(Some(Ok((sub_id, block_refs))));
}
}
}
#[cfg(test)]
mod test_utils {
    use super::super::follow_stream_unpin::test_utils::test_unpin_stream_getter;
    use super::*;

    /// Build a [`FollowStreamDriver`] over a test follow stream.
    ///
    /// The stream is created from `events` and `max_life` by `test_unpin_stream_getter`.
    /// The second value that helper returns is discarded.
    pub fn test_follow_stream_driver_getter<H, F, I>(
        events: F,
        max_life: usize,
    ) -> FollowStreamDriver<H>
    where
        H: Hash + 'static,
        F: Fn() -> I + Send + 'static,
        I: IntoIterator<Item = Result<FollowEvent<H>, BackendError>>,
    {
        let (unpin_stream, _) = test_unpin_stream_getter(events, max_life);
        FollowStreamDriver::new(unpin_stream)
    }
}
// Tests for the driver, its subscriptions, and FollowStreamFinalizedHeads.
#[cfg(test)]
mod test {
use futures::TryStreamExt;
use primitive_types::H256;
use super::super::follow_stream::test_utils::{
ev_best_block, ev_finalized, ev_initialized, ev_new_block,
};
use super::super::follow_stream_unpin::test_utils::{
ev_best_block_ref, ev_finalized_ref, ev_initialized_ref, ev_new_block_ref,
};
use super::test_utils::test_follow_stream_driver_getter;
use super::*;
// Compile-time check that the driver is `Send + 'static`, so it can be moved into a spawned task.
#[test]
fn follow_stream_driver_is_sendable() {
fn assert_send<T: Send + 'static>(_: T) {}
let stream_getter = test_follow_stream_driver_getter(|| [Ok(ev_initialized(1))], 10);
assert_send(stream_getter);
}
// Subscribers created before the driver runs each see every forwarded message in
// order. Each subscription then ends (the `collect`s complete) without yielding the error.
#[tokio::test]
async fn subscribers_all_receive_events_and_finish_gracefully_on_error() {
let mut driver = test_follow_stream_driver_getter(
|| {
[
Ok(ev_initialized(0)),
Ok(ev_new_block(0, 1)),
Ok(ev_best_block(1)),
Ok(ev_finalized([1], [])),
Err(BackendError::other("ended")),
]
},
10,
);
let handle = driver.handle();
let a = handle.subscribe();
let b = handle.subscribe();
let c = handle.subscribe();
tokio::spawn(async move { while driver.next().await.is_some() {} });
let a_vec: Vec<_> = a.collect().await;
let b_vec: Vec<_> = b.collect().await;
let c_vec: Vec<_> = c.collect().await;
let expected = vec![
FollowStreamMsg::Ready("sub_id_0".into()),
FollowStreamMsg::Event(ev_initialized_ref(0)),
FollowStreamMsg::Event(ev_new_block_ref(0, 1)),
FollowStreamMsg::Event(ev_best_block_ref(1)),
FollowStreamMsg::Event(ev_finalized_ref([1])),
];
assert_eq!(a_vec, expected);
assert_eq!(b_vec, expected);
assert_eq!(c_vec, expected);
}
// A late subscriber is replayed the Ready message, the current Initialized
// message and the not-yet-finalized block events. After a finalization, the
// replayed Initialized message reflects the newly finalized block, and events
// for that block are no longer replayed.
#[tokio::test]
async fn subscribers_receive_block_events_from_last_finalised() {
let mut driver = test_follow_stream_driver_getter(
|| {
[
Ok(ev_initialized(0)),
Ok(ev_new_block(0, 1)),
Ok(ev_best_block(1)),
Ok(ev_finalized([1], [])),
Ok(ev_new_block(1, 2)),
Ok(ev_new_block(2, 3)),
Err(BackendError::other("ended")),
]
},
10,
);
// Forward Ready, Initialized(0), NewBlock(1) and BestBlock(1).
let _r = driver.next().await.unwrap();
let _i0 = driver.next().await.unwrap();
let _n1 = driver.next().await.unwrap();
let _b1 = driver.next().await.unwrap();
let evs: Vec<_> = driver.handle().subscribe().take(4).collect().await;
let expected = vec![
FollowStreamMsg::Ready("sub_id_0".into()),
FollowStreamMsg::Event(ev_initialized_ref(0)),
FollowStreamMsg::Event(ev_new_block_ref(0, 1)),
FollowStreamMsg::Event(ev_best_block_ref(1)),
];
assert_eq!(evs, expected);
// Forward Finalized(1), NewBlock(2) and NewBlock(3).
let _f1 = driver.next().await.unwrap();
let _n2 = driver.next().await.unwrap();
let _n3 = driver.next().await.unwrap();
let evs: Vec<_> = driver.handle().subscribe().take(4).collect().await;
let expected = vec![
FollowStreamMsg::Ready("sub_id_0".into()),
FollowStreamMsg::Event(ev_initialized_ref(1)),
FollowStreamMsg::Event(ev_new_block_ref(1, 2)),
FollowStreamMsg::Event(ev_new_block_ref(2, 3)),
];
assert_eq!(evs, expected);
}
// Block events announced before a finalization, for blocks that remain
// unfinalized, are still replayed to a subscriber created after it.
#[tokio::test]
async fn subscribers_receive_new_blocks_before_subscribing() {
let mut driver = test_follow_stream_driver_getter(
|| {
[
Ok(ev_initialized(0)),
Ok(ev_new_block(0, 1)),
Ok(ev_best_block(1)),
Ok(ev_new_block(1, 2)),
Ok(ev_new_block(2, 3)),
Ok(ev_finalized([1], [])),
Err(BackendError::other("ended")),
]
},
10,
);
let _r = driver.next().await.unwrap();
let _i0 = driver.next().await.unwrap();
let _n1 = driver.next().await.unwrap();
let _b1 = driver.next().await.unwrap();
let _n2 = driver.next().await.unwrap();
let _n3 = driver.next().await.unwrap();
let _f1 = driver.next().await.unwrap();
let evs: Vec<_> = driver.handle().subscribe().take(4).collect().await;
let expected = vec![
FollowStreamMsg::Ready("sub_id_0".into()),
FollowStreamMsg::Event(ev_initialized_ref(1)),
FollowStreamMsg::Event(ev_new_block_ref(1, 2)),
FollowStreamMsg::Event(ev_new_block_ref(2, 3)),
];
assert_eq!(evs, expected);
}
// FollowStreamFinalizedHeads reports finalized blocks from both Initialized and
// Finalized events, and continues across a Stop/restart. Blocks it already
// reported before the restart are not reported again.
#[tokio::test]
async fn subscribe_finalized_blocks_restart_works() {
let mut driver = test_follow_stream_driver_getter(
|| {
[
Ok(ev_initialized(0)),
Ok(ev_new_block(0, 1)),
Ok(ev_best_block(1)),
Ok(ev_finalized([1], [])),
Ok(FollowEvent::Stop),
Ok(ev_initialized(1)),
Ok(ev_finalized([2], [])),
Err(BackendError::other("ended")),
]
},
10,
);
let handle = driver.handle();
tokio::spawn(async move { while driver.next().await.is_some() {} });
let f = |ev| match ev {
FollowEvent::Finalized(ev) => ev.finalized_block_hashes,
FollowEvent::Initialized(ev) => ev.finalized_block_hashes,
_ => vec![],
};
let stream = FollowStreamFinalizedHeads::new(handle.subscribe(), f);
let evs: Vec<_> = stream.try_collect().await.unwrap();
let expected = vec![
(
"sub_id_0".to_string(),
vec![BlockRef::new(H256::from_low_u64_le(0))],
),
(
"sub_id_0".to_string(),
vec![BlockRef::new(H256::from_low_u64_le(1))],
),
(
"sub_id_5".to_string(),
vec![BlockRef::new(H256::from_low_u64_le(2))],
),
];
assert_eq!(evs, expected);
}
// If the restarted subscription's Initialized event does not contain the last
// finalized block already reported, an error is yielded. The stream then
// carries on with events from the new subscription.
#[tokio::test]
async fn subscribe_finalized_blocks_restart_with_missed_blocks() {
let mut driver = test_follow_stream_driver_getter(
|| {
[
Ok(ev_initialized(0)),
Ok(FollowEvent::Stop),
Ok(ev_initialized(13)),
Ok(ev_finalized([14], [])),
Err(BackendError::other("ended")),
]
},
10,
);
let handle = driver.handle();
tokio::spawn(async move { while driver.next().await.is_some() {} });
let f = |ev| match ev {
FollowEvent::Finalized(ev) => ev.finalized_block_hashes,
FollowEvent::Initialized(ev) => ev.finalized_block_hashes,
_ => vec![],
};
let evs: Vec<_> = FollowStreamFinalizedHeads::new(handle.subscribe(), f)
.collect()
.await;
assert_eq!(
evs[0].as_ref().unwrap(),
&(
"sub_id_0".to_string(),
vec![BlockRef::new(H256::from_low_u64_le(0))]
)
);
assert!(
matches!(&evs[1], Err(BackendError::Rpc(RpcError::ClientError(subxt_rpcs::Error::DisconnectedWillReconnect(e)))) if e.contains("Missed at least one block when the connection was lost"))
);
assert_eq!(
evs[2].as_ref().unwrap(),
&(
"sub_id_2".to_string(),
vec![BlockRef::new(H256::from_low_u64_le(14))]
)
);
}
}