use crate::config::{Config, HashFor, RpcConfigFor};
use crate::error::BackendError;
use futures::{FutureExt, Stream, StreamExt, TryStreamExt};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use subxt_rpcs::methods::chain_head::{ChainHeadRpcMethods, FollowEvent};
/// A stream of [`FollowStreamMsg`]s built on top of `chainHead` follow subscriptions.
///
/// Subscriptions are obtained lazily from `stream_getter`. Every new subscription is announced
/// with [`FollowStreamMsg::Ready`] (carrying its subscription ID), followed by its events.
/// When a subscription ends (a `Stop` event, the inner stream finishing, or an error whose
/// `is_disconnected_will_reconnect()` is true), a [`FollowEvent::Stop`] message is emitted and
/// a fresh subscription is requested. Any other error is yielded once, after which the stream
/// only returns `None`.
pub struct FollowStream<Hash> {
    // Called whenever a (new) subscription is needed.
    stream_getter: FollowEventStreamGetter<Hash>,
    // Current position in the subscribe / receive / restart cycle.
    stream: InnerStreamState<Hash>,
}

impl<Hash> std::fmt::Debug for FollowStream<Hash> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The getter is an opaque boxed closure, so only a placeholder is printed for it.
        let mut out = f.debug_struct("FollowStream");
        out.field("stream_getter", &"..");
        out.field("stream", &self.stream);
        out.finish()
    }
}
/// A pinned, boxed, `Send` stream of follow events (or errors) for a single subscription.
pub type FollowEventStream<Hash> =
    Pin<Box<dyn Stream<Item = Result<FollowEvent<Hash>, BackendError>> + Send + 'static>>;

/// A pinned, boxed, `Send` future resolving to a subscription's event stream paired with its
/// subscription ID, or to an error if subscribing failed.
pub type FollowEventStreamFut<Hash> = Pin<
    Box<dyn Future<Output = Result<(FollowEventStream<Hash>, String), BackendError>> + Send + 'static>,
>;

/// A boxed closure that [`FollowStream`] calls each time it needs a new subscription.
pub type FollowEventStreamGetter<Hash> =
    Box<dyn FnMut() -> FollowEventStreamFut<Hash> + Send + 'static>;
/// A single item produced by [`FollowStream`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FollowStreamMsg<Hash> {
    /// A new subscription has been established; carries its subscription ID.
    /// Emitted once per (re)subscription, before any of that subscription's events.
    Ready(String),
    /// An event from the current subscription, or a synthetic [`FollowEvent::Stop`] emitted
    /// when the subscription ends and a new one is about to be requested.
    Event(FollowEvent<Hash>),
}

impl<Hash> FollowStreamMsg<Hash> {
    /// Unwraps the contained [`FollowEvent`].
    ///
    /// Returns `None` for [`FollowStreamMsg::Ready`], which carries no event.
    pub fn into_event(self) -> Option<FollowEvent<Hash>> {
        if let FollowStreamMsg::Event(event) = self {
            Some(event)
        } else {
            None
        }
    }
}
/// Internal state machine driving [`FollowStream`].
enum InnerStreamState<Hash> {
    /// No subscription yet; the next poll asks the getter for one.
    New,
    /// Waiting for the getter's future to produce a subscription.
    Initializing(FollowEventStreamFut<Hash>),
    /// A subscription is available; its ID is handed out next.
    /// The `Option` exists only so the contents can be moved out with `take`.
    Ready(Option<(FollowEventStream<Hash>, String)>),
    /// Forwarding events from the live subscription.
    ReceivingEvents(FollowEventStream<Hash>),
    /// The subscription has ended; a `Stop` message is emitted before starting over.
    Stopped,
    /// A non-recoverable error was passed on; nothing more will be produced.
    Finished,
}

impl<Hash> std::fmt::Debug for InnerStreamState<Hash> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Payloads are boxed futures/streams with no useful `Debug`, so elide them.
        let label = match self {
            InnerStreamState::New => "New",
            InnerStreamState::Initializing(_) => "Initializing(..)",
            InnerStreamState::Ready(_) => "Ready(..)",
            InnerStreamState::ReceivingEvents(_) => "ReceivingEvents(..)",
            InnerStreamState::Stopped => "Stopped",
            InnerStreamState::Finished => "Finished",
        };
        f.write_str(label)
    }
}
impl<Hash> FollowStream<Hash> {
pub fn new(stream_getter: FollowEventStreamGetter<Hash>) -> Self {
Self {
stream_getter,
stream: InnerStreamState::New,
}
}
pub fn from_methods<T: Config>(
methods: ChainHeadRpcMethods<RpcConfigFor<T>>,
) -> FollowStream<HashFor<T>> {
FollowStream {
stream_getter: Box::new(move || {
let methods = methods.clone();
Box::pin(async move {
let stream = methods.chainhead_v1_follow(true).await?;
let Some(sub_id) = stream.subscription_id().map(ToOwned::to_owned) else {
return Err(BackendError::other(
"Subscription ID expected for chainHead_follow response, but not given",
));
};
let stream = stream.map_err(|e| e.into());
let stream: FollowEventStream<HashFor<T>> = Box::pin(stream);
Ok((stream, sub_id))
})
}),
stream: InnerStreamState::New,
}
}
}
// Every field is a `Box`/`Pin<Box<_>>` or plain owned data, so moving a `FollowStream` is
// harmless. `poll_next` relies on this via `Pin::get_mut`; state it explicitly for any `Hash`.
impl<Hash> Unpin for FollowStream<Hash> {}
impl<Hash> Stream for FollowStream<Hash> {
    type Item = Result<FollowStreamMsg<Hash>, BackendError>;

    /// Advances the subscription state machine.
    ///
    /// Internal transitions are looped through until there is an item to yield or the inner
    /// future/stream is pending:
    /// - a `Stop` event, the inner stream ending, or an error whose
    ///   `is_disconnected_will_reconnect()` is true all lead to a `FollowEvent::Stop` message,
    ///   after which a new subscription is requested (and announced with `Ready`);
    /// - any other error is yielded once, after which only `None` is returned.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let me = self.get_mut();
        loop {
            match &mut me.stream {
                InnerStreamState::New => {
                    // Ask for a (new) subscription.
                    me.stream = InnerStreamState::Initializing((me.stream_getter)());
                }
                InnerStreamState::Initializing(pending) => match pending.poll_unpin(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Ok(sub_and_id)) => {
                        me.stream = InnerStreamState::Ready(Some(sub_and_id));
                    }
                    // Reconnecting backend: treat as a stop and try again.
                    Poll::Ready(Err(err)) if err.is_disconnected_will_reconnect() => {
                        me.stream = InnerStreamState::Stopped;
                    }
                    // Anything else is fatal: pass it on and finish for good.
                    Poll::Ready(Err(err)) => {
                        me.stream = InnerStreamState::Finished;
                        return Poll::Ready(Some(Err(err)));
                    }
                },
                InnerStreamState::Ready(slot) => {
                    // The slot is only ever `None` transiently, right here.
                    let (events, sub_id) = slot.take().expect("should always be Some");
                    me.stream = InnerStreamState::ReceivingEvents(events);
                    return Poll::Ready(Some(Ok(FollowStreamMsg::Ready(sub_id))));
                }
                InnerStreamState::ReceivingEvents(events) => match events.poll_next_unpin(cx) {
                    Poll::Pending => return Poll::Pending,
                    // The subscription is over (cleanly or via `Stop`); restart after
                    // announcing a `Stop`.
                    Poll::Ready(None) | Poll::Ready(Some(Ok(FollowEvent::Stop))) => {
                        me.stream = InnerStreamState::Stopped;
                    }
                    Poll::Ready(Some(Ok(event))) => {
                        return Poll::Ready(Some(Ok(FollowStreamMsg::Event(event))));
                    }
                    // Reconnecting backend: treat as a stop and try again.
                    Poll::Ready(Some(Err(err))) if err.is_disconnected_will_reconnect() => {
                        me.stream = InnerStreamState::Stopped;
                    }
                    // Anything else is fatal: pass it on and finish for good.
                    Poll::Ready(Some(Err(err))) => {
                        me.stream = InnerStreamState::Finished;
                        return Poll::Ready(Some(Err(err)));
                    }
                },
                InnerStreamState::Stopped => {
                    me.stream = InnerStreamState::New;
                    return Poll::Ready(Some(Ok(FollowStreamMsg::Event(FollowEvent::Stop))));
                }
                InnerStreamState::Finished => return Poll::Ready(None),
            }
        }
    }
}
#[cfg(test)]
pub(super) mod test_utils {
    //! Helpers for driving [`FollowStream`] with scripted events in tests.
    use super::*;
    use crate::config::substrate::H256;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use subxt_rpcs::methods::chain_head::{BestBlockChanged, Finalized, Initialized, NewBlock};

    /// Builds a [`FollowEventStreamGetter`] that replays a scripted list of events.
    ///
    /// `events` is called on every (re)subscription. Events pulled by earlier subscriptions
    /// are skipped, so each new subscription resumes where the previous one left off. The
    /// subscription ID is `sub_id_{N}`, where `N` is how many events had been pulled so far.
    pub fn test_stream_getter<Hash, F, I>(events: F) -> FollowEventStreamGetter<Hash>
    where
        Hash: Send + 'static,
        F: Fn() -> I + Send + 'static,
        I: IntoIterator<Item = Result<FollowEvent<Hash>, BackendError>>,
    {
        let consumed = Arc::new(AtomicUsize::new(0));
        Box::new(move || {
            let already_seen = consumed.load(Ordering::Relaxed);
            let remaining: Vec<_> = events().into_iter().skip(already_seen).collect();
            let counter = Arc::clone(&consumed);
            Box::pin(async move {
                // Count every event pulled so the next subscription can skip past it.
                let replay = futures::stream::iter(remaining).inspect(move |_| {
                    counter.fetch_add(1, Ordering::Relaxed);
                });
                let replay: FollowEventStream<Hash> = Box::pin(replay);
                Ok((replay, format!("sub_id_{already_seen}")))
            })
        })
    }

    /// Test hash standing in for block number `n`.
    fn hash_of(n: u64) -> H256 {
        H256::from_low_u64_le(n)
    }

    /// An `Initialized` event with block `n` as the only finalized block and no runtime.
    pub fn ev_initialized(n: u64) -> FollowEvent<H256> {
        let finalized_block_hashes = vec![hash_of(n)];
        FollowEvent::Initialized(Initialized {
            finalized_block_hashes,
            finalized_block_runtime: None,
        })
    }

    /// A `NewBlock` event for block `n` whose parent is `parent_n`, without a new runtime.
    pub fn ev_new_block(parent_n: u64, n: u64) -> FollowEvent<H256> {
        let parent_block_hash = hash_of(parent_n);
        let block_hash = hash_of(n);
        FollowEvent::NewBlock(NewBlock {
            parent_block_hash,
            block_hash,
            new_runtime: None,
        })
    }

    /// A `BestBlockChanged` event pointing at block `n`.
    pub fn ev_best_block(n: u64) -> FollowEvent<H256> {
        let best_block_hash = hash_of(n);
        FollowEvent::BestBlockChanged(BestBlockChanged { best_block_hash })
    }

    /// A `Finalized` event listing the given finalized and pruned block numbers.
    pub fn ev_finalized(
        finalized_ns: impl IntoIterator<Item = u64>,
        pruned_ns: impl IntoIterator<Item = u64>,
    ) -> FollowEvent<H256> {
        let finalized_block_hashes = finalized_ns.into_iter().map(hash_of).collect();
        let pruned_block_hashes = pruned_ns.into_iter().map(hash_of).collect();
        FollowEvent::Finalized(Finalized {
            finalized_block_hashes,
            pruned_block_hashes,
        })
    }
}
#[cfg(test)]
pub mod test {
    use super::*;
    use test_utils::{ev_initialized, ev_new_block, test_stream_getter};

    /// Each `Stop` ends the current subscription and triggers a resubscribe (announced by a
    /// fresh `Ready`). The first non-reconnect error ends the stream, so nothing after it
    /// is seen.
    #[tokio::test]
    async fn follow_stream_provides_messages_until_error() {
        let stream_getter = test_stream_getter(|| {
            [
                Ok(ev_initialized(1)),
                // Each of these ends a subscription and causes a resubscribe:
                Ok(FollowEvent::Stop),
                Ok(FollowEvent::Stop),
                Ok(ev_new_block(1, 2)),
                // Passed on once, then the stream finishes:
                Err(BackendError::other("ended")),
                // ...so this one is never delivered:
                Ok(ev_new_block(2, 3)),
            ]
        });

        let follow_stream = FollowStream::new(stream_getter);
        // Keep only successfully delivered messages; the single error item is dropped.
        let received: Vec<_> = follow_stream
            .filter_map(|item| async move { item.ok() })
            .collect()
            .await;

        // IDs reflect how many events earlier subscriptions had consumed.
        let expected = vec![
            FollowStreamMsg::Ready("sub_id_0".to_owned()),
            FollowStreamMsg::Event(ev_initialized(1)),
            FollowStreamMsg::Event(FollowEvent::Stop),
            FollowStreamMsg::Ready("sub_id_2".to_owned()),
            FollowStreamMsg::Event(FollowEvent::Stop),
            FollowStreamMsg::Ready("sub_id_3".to_owned()),
            FollowStreamMsg::Event(ev_new_block(1, 2)),
        ];
        assert_eq!(received, expected);
    }
}