use super::{StreamOf, StreamOfResults};
use crate::error::BackendError;
use futures::{future::BoxFuture, FutureExt, Stream, StreamExt};
use std::{future::Future, pin::Pin, task::Poll};
/// Boxed, sendable callback that returns a new [`ResubscribeFuture`] every time it is called.
type ResubscribeGetter<T> = Box<dyn FnMut() -> ResubscribeFuture<T> + Send>;
/// Pinned, boxed, sendable future that resolves to either a `StreamOfResults<T>` or a
/// `BackendError`.
type ResubscribeFuture<T> =
Pin<Box<dyn Future<Output = Result<StreamOfResults<T>, BackendError>> + Send>>;
/// The two states a retrying subscription can be in: waiting for a stream to be
/// (re)established, or holding an established stream.
pub(crate) enum PendingOrStream<T> {
/// A boxed future that, once resolved, yields the stream or the error that prevented
/// it from being created.
Pending(BoxFuture<'static, Result<StreamOfResults<T>, BackendError>>),
/// An established stream of results that can be polled for items.
Stream(StreamOfResults<T>),
}
impl<T> std::fmt::Debug for PendingOrStream<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PendingOrStream::Pending(_) => write!(f, "Pending"),
PendingOrStream::Stream(_) => write!(f, "Stream"),
}
}
}
/// Stream that wraps a subscription and uses `resubscribe` to obtain a replacement
/// stream after the current one reports an error for which
/// `is_disconnected_will_reconnect()` is true.
struct RetrySubscription<T> {
/// Callback producing the future that resolves to the replacement stream.
resubscribe: ResubscribeGetter<T>,
/// Current state. `None` means the subscription has terminated: polling then
/// yields `Poll::Ready(None)`.
state: Option<PendingOrStream<T>>,
}
// `poll_next` mutates the fields directly through `Pin<&mut Self>`, which is only
// possible when `Self: Unpin`; this impl states that for every `T`.
// NOTE(review): the fields already look `Unpin` on their own (a `Box`, a pinned box and a
// stream polled via `poll_next_unpin`), so this may be redundant — confirm before removing.
impl<T> std::marker::Unpin for RetrySubscription<T> {}
impl<T> RetrySubscription<T> {
    /// Turns `err` into the next stream item.
    ///
    /// When the error reports `is_disconnected_will_reconnect()`, a new resubscription
    /// future is stored in `state` first, so the following poll drives the
    /// resubscription. For any other error `state` is not touched; `poll_next` has
    /// already emptied it at that point, so the stream ends on the following poll.
    fn surface_error(&mut self, err: BackendError) -> Poll<Option<Result<T, BackendError>>> {
        if err.is_disconnected_will_reconnect() {
            let attempt = (self.resubscribe)();
            self.state = Some(PendingOrStream::Pending(attempt));
        }
        Poll::Ready(Some(Err(err)))
    }
}

/// Drives the wrapped subscription, transparently switching between "waiting for a
/// (re)subscription" and "forwarding items from the established stream".
impl<T> Stream for RetrySubscription<T> {
    type Item = Result<T, BackendError>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // The state is moved out on every pass and only put back on the paths that keep
        // the subscription alive. An empty state therefore means "terminated".
        while let Some(current) = self.state.take() {
            match current {
                // A (re)subscription is in flight: drive it.
                PendingOrStream::Pending(mut attempt) => match attempt.poll_unpin(cx) {
                    Poll::Pending => {
                        self.state = Some(PendingOrStream::Pending(attempt));
                        return Poll::Pending;
                    },
                    Poll::Ready(Ok(stream)) => {
                        // Subscribed: go around the loop again to poll the new stream
                        // within this same call.
                        self.state = Some(PendingOrStream::Stream(stream));
                    },
                    Poll::Ready(Err(err)) => return self.surface_error(err),
                },
                // A stream is established: forward whatever it produces.
                PendingOrStream::Stream(mut stream) => match stream.poll_next_unpin(cx) {
                    Poll::Pending => {
                        self.state = Some(PendingOrStream::Stream(stream));
                        return Poll::Pending;
                    },
                    Poll::Ready(Some(Ok(item))) => {
                        self.state = Some(PendingOrStream::Stream(stream));
                        return Poll::Ready(Some(Ok(item)));
                    },
                    Poll::Ready(Some(Err(err))) => return self.surface_error(err),
                    // Inner stream finished; the state stays empty.
                    Poll::Ready(None) => return Poll::Ready(None),
                },
            }
        }

        Poll::Ready(None)
    }
}
/// Calls `retry_future` and awaits the future it returns, repeating until the call
/// succeeds or fails with an error that is not retried.
///
/// Retry policy:
/// - errors reporting `is_disconnected_will_reconnect()` are retried without a limit;
/// - errors reporting `is_rpc_limit_reached()` are retried at most
///   `REJECTED_MAX_RETRIES` (10) times in total for one call of this function;
/// - every other error is returned right away.
///
/// # Errors
///
/// Returns the first error that falls outside the retry policy above.
pub async fn retry<T, F, R>(mut retry_future: F) -> Result<R, BackendError>
where
    F: FnMut() -> T,
    T: Future<Output = Result<R, BackendError>>,
{
    const REJECTED_MAX_RETRIES: usize = 10;
    let mut rejected_retries = 0;

    loop {
        let err = match retry_future().await {
            Ok(value) => return Ok(value),
            Err(err) => err,
        };

        // Disconnections that will be followed by a reconnect are always retried.
        if err.is_disconnected_will_reconnect() {
            continue;
        }

        // "Limit reached" rejections draw from a bounded budget; anything else,
        // or an exhausted budget, ends the loop with the error.
        let limit_reached = err.is_rpc_limit_reached();
        if !limit_reached || rejected_retries >= REJECTED_MAX_RETRIES {
            return Err(err);
        }
        rejected_retries += 1;
    }
}
pub async fn retry_stream<F, R>(sub_stream: F) -> Result<StreamOfResults<R>, BackendError>
where
F: FnMut() -> ResubscribeFuture<R> + Send + 'static + Clone,
R: Send + 'static,
{
let stream = retry(sub_stream.clone()).await?;
let resubscribe = Box::new(move || {
let sub_stream = sub_stream.clone();
async move { retry(sub_stream).await }.boxed()
});
Ok(StreamOf::new(Box::pin(RetrySubscription {
state: Some(PendingOrStream::Stream(stream)),
resubscribe,
})))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::StreamOf;

    /// Error for which `is_disconnected_will_reconnect()` is expected to be true.
    fn disconnect_err() -> BackendError {
        let rpc_err = pezkuwi_subxt_rpcs::Error::DisconnectedWillReconnect(String::new());
        BackendError::Rpc(rpc_err.into())
    }

    /// Error that is not part of any retry policy.
    fn custom_err() -> BackendError {
        BackendError::Other(String::new())
    }

    /// After the disconnect error the subscription starts over from the first item.
    #[tokio::test]
    async fn retry_stream_works() {
        let stream = retry_stream(|| {
            async {
                Ok(StreamOf::new(Box::pin(futures::stream::iter([
                    Ok(1),
                    Ok(2),
                    Ok(3),
                    Err(disconnect_err()),
                ]))))
            }
            .boxed()
        })
        .await
        .unwrap();

        let result: Vec<Result<usize, BackendError>> = stream.take(5).collect().await;

        assert!(matches!(
            result.as_slice(),
            [Ok(1), Ok(2), Ok(3), Err(e), Ok(1), ..] if e.is_disconnected_will_reconnect()
        ));
    }

    /// The disconnect error is surfaced, then items come from the resubscribed stream.
    #[tokio::test]
    async fn retry_sub_works() {
        let first_stream = futures::stream::iter([Ok(1), Err(disconnect_err())]);
        let resubscribe = Box::new(move || {
            async move { Ok(StreamOf::new(Box::pin(futures::stream::iter([Ok(2)])))) }.boxed()
        });
        let subscription = RetrySubscription {
            resubscribe,
            state: Some(PendingOrStream::Stream(StreamOf::new(Box::pin(first_stream)))),
        };

        let result: Vec<_> = subscription.collect().await;

        assert!(matches!(
            result.as_slice(),
            [Ok(1), Err(e), Ok(2), ..] if e.is_disconnected_will_reconnect()
        ));
    }

    /// A stream that simply ends never triggers the resubscribe callback.
    #[tokio::test]
    async fn retry_sub_err_terminates_stream() {
        let first_stream = futures::stream::iter([Ok(1)]);
        let resubscribe = Box::new(|| async move { Err(custom_err()) }.boxed());
        let subscription = RetrySubscription {
            resubscribe,
            state: Some(PendingOrStream::Stream(StreamOf::new(Box::pin(first_stream)))),
        };

        let yielded = subscription.count().await;

        assert_eq!(yielded, 1);
    }

    /// A failing resubscription surfaces its own error as the next item.
    #[tokio::test]
    async fn retry_sub_resubscribe_err() {
        let first_stream = futures::stream::iter([Ok(1), Err(disconnect_err())]);
        let resubscribe = Box::new(|| async move { Err(custom_err()) }.boxed());
        let subscription = RetrySubscription {
            resubscribe,
            state: Some(PendingOrStream::Stream(StreamOf::new(Box::pin(first_stream)))),
        };

        let result: Vec<_> = subscription.collect().await;

        assert!(matches!(
            result.as_slice(),
            [Ok(1), Err(e), Err(BackendError::Other(_)), ..]
                if e.is_disconnected_will_reconnect()
        ));
    }
}