use std::ops::RangeInclusive;
use crate::{
Message, Notification, ScannerError, ScannerMessage,
block_range_scanner::BlockScannerResult,
event_scanner::{filter::EventFilter, listener::EventListener},
types::TryStream,
};
use alloy::{
network::Network,
rpc::types::{Filter, Log},
};
use futures::StreamExt;
use robust_provider::{Error as RobustProviderError, RobustProvider};
use tokio::{
sync::{
broadcast::{self, Sender, error::RecvError},
mpsc,
},
task::JoinSet,
};
use tokio_stream::{Stream, wrappers::ReceiverStream};
/// Consumes a stream of scanned block ranges and drives log processing to
/// completion.
///
/// Implementors decide how fetched logs are delivered to listeners (e.g.
/// forwarded as they arrive, or collected into a batch first). The returned
/// future resolves once the input stream ends and all downstream consumers
/// have finished.
pub trait BlockRangeHandler {
fn handle<S: Stream<Item = BlockScannerResult> + Unpin + Send>(
self,
stream: S,
) -> impl std::future::Future<Output = ()> + Send;
}
/// Block range handler that fetches logs for each scanned range and forwards
/// them to every listener as soon as they are retrieved.
#[derive(Debug)]
pub struct StreamHandler<N: Network> {
// Provider used to fetch logs; cloned into each per-listener task.
provider: RobustProvider<N>,
// Listeners, each with its own event filter and outbound channel.
listeners: Vec<EventListener>,
// Bound on concurrently in-flight `get_logs` calls per listener; also sizes
// the per-listener mpsc buffer.
max_concurrent_fetches: usize,
// Capacity of the broadcast channel fanning block ranges out to listeners.
broadcast_channel_capacity: usize,
}
impl<N: Network> StreamHandler<N> {
    /// Creates a new streaming handler.
    ///
    /// `max_concurrent_fetches` bounds how many `get_logs` calls are in flight
    /// per listener; `broadcast_channel_capacity` sizes the channel that fans
    /// block ranges out to all listeners.
    #[must_use]
    pub fn new(
        provider: RobustProvider<N>,
        listeners: Vec<EventListener>,
        max_concurrent_fetches: usize,
        broadcast_channel_capacity: usize,
    ) -> Self {
        Self { provider, listeners, max_concurrent_fetches, broadcast_channel_capacity }
    }

    /// Spawns one consumer task per listener, each subscribed to `range_tx`.
    ///
    /// Each consumer fetches logs for the ranges it receives — up to
    /// `max_concurrent_fetches` ranges concurrently, preserving order via
    /// `buffered` — and forwards non-empty results to the listener's channel.
    fn spawn(self, range_tx: &Sender<BlockScannerResult>) -> JoinSet<()> {
        let mut join_set = JoinSet::new();
        for listener in self.listeners {
            let max_concurrent_fetches = self.max_concurrent_fetches;
            let provider = self.provider.clone();
            let mut range_rx = range_tx.subscribe();
            join_set.spawn(async move {
                // Bridge the broadcast receiver into an mpsc channel so the
                // fetch pipeline can apply per-listener backpressure.
                let (tx, rx) = mpsc::channel::<BlockScannerResult>(max_concurrent_fetches);
                tokio::spawn(async move {
                    let mut stream = ReceiverStream::new(rx)
                        .map(async |message| match message {
                            Ok(ScannerMessage::Data(range)) => {
                                get_logs(range, &listener.filter, &provider)
                                    .await
                                    .map(Message::from)
                                    .map_err(ScannerError::from)
                            }
                            Ok(ScannerMessage::Notification(notification)) => {
                                Ok(notification.into())
                            }
                            Err(e) => Err(e),
                        })
                        .buffered(max_concurrent_fetches);
                    while let Some(result) = stream.next().await {
                        // Skip ranges that matched no logs.
                        if let Ok(ScannerMessage::Data(logs)) = result.as_ref() &&
                            logs.is_empty()
                        {
                            continue;
                        }
                        if listener.sender.try_stream(result).await.is_closed() {
                            return;
                        }
                    }
                });
                loop {
                    match range_rx.recv().await {
                        Ok(message) => {
                            // The fetch task exits (dropping `rx`) when the
                            // listener disconnects, so a failed send signals
                            // normal shutdown, not a bug — stop forwarding
                            // instead of panicking.
                            if tx.send(message).await.is_err() {
                                break;
                            }
                        }
                        Err(RecvError::Closed) => {
                            trace!("Block range stream closed");
                            break;
                        }
                        Err(RecvError::Lagged(skipped)) => {
                            // Surface the lag so the listener knows ranges
                            // were dropped; bail out if it already went away.
                            if tx.send(Err(ScannerError::Lagged(skipped))).await.is_err() {
                                break;
                            }
                        }
                    }
                }
            });
        }
        join_set
    }
}
impl<N: Network> BlockRangeHandler for StreamHandler<N> {
    /// Fans `stream` out to all listeners and waits for them to finish.
    async fn handle<S: Stream<Item = BlockScannerResult> + Unpin + Send>(self, stream: S) {
        // NOTE: `max_concurrent_fetches` was previously recorded twice in this
        // event (copy-paste duplication); record each field once.
        debug!(
            listener_count = self.listeners.len(),
            max_concurrent_fetches = self.max_concurrent_fetches,
            broadcast_channel_capacity = self.broadcast_channel_capacity,
            "Starting block range handler that forwards logs as they are received"
        );
        let (range_tx, _) =
            broadcast::channel::<BlockScannerResult>(self.broadcast_channel_capacity);
        let consumers = self.spawn(&range_tx);
        broadcast_stream(stream, range_tx, consumers).await;
    }
}
/// Block range handler that collects up to `count` logs across the scanned
/// ranges and delivers them to each listener as a single batch once the range
/// stream ends, handling reorgs by discarding logs from orphaned blocks.
#[derive(Debug)]
pub struct LatestEventsHandler<N: Network> {
// Provider used to fetch logs; cloned into each per-listener task.
provider: RobustProvider<N>,
// Listeners, each with its own event filter and outbound channel.
listeners: Vec<EventListener>,
// Bound on concurrently in-flight `get_logs` calls per listener; also sizes
// the per-listener mpsc buffer.
max_concurrent_fetches: usize,
// Maximum number of logs to collect before streaming the batch.
count: usize,
// Capacity of the broadcast channel fanning block ranges out to listeners.
broadcast_channel_capacity: usize,
}
impl<N: Network> LatestEventsHandler<N> {
    /// Creates a new latest-events handler that collects up to `count` logs.
    #[must_use]
    pub fn new(
        provider: RobustProvider<N>,
        listeners: Vec<EventListener>,
        max_concurrent_fetches: usize,
        count: usize,
        broadcast_channel_capacity: usize,
    ) -> Self {
        Self { provider, listeners, max_concurrent_fetches, count, broadcast_channel_capacity }
    }

    /// Spawns one consumer task per listener, each subscribed to `range_tx`.
    ///
    /// Each consumer fetches logs for incoming ranges (concurrently, ordered
    /// via `buffered`), accumulates up to `count` logs — prepending during
    /// reorg recovery — and finally streams the collected batch (reversed to
    /// oldest-first) to the listener.
    #[allow(clippy::too_many_lines)]
    fn spawn(self, range_tx: &Sender<BlockScannerResult>) -> JoinSet<()> {
        let mut join_set = JoinSet::new();
        for listener in self.listeners {
            let max_concurrent_fetches = self.max_concurrent_fetches;
            let count = self.count;
            let provider = self.provider.clone();
            let mut range_rx = range_tx.subscribe();
            join_set.spawn(async move {
                let (tx, rx) = mpsc::channel::<BlockScannerResult>(max_concurrent_fetches);
                tokio::spawn(async move {
                    let mut stream = ReceiverStream::new(rx)
                        .map(async |message| match message {
                            Ok(ScannerMessage::Data(range)) => {
                                get_logs(range, &listener.filter, &provider)
                                    .await
                                    .map(Message::from)
                                    .map_err(ScannerError::from)
                            }
                            Ok(ScannerMessage::Notification(notification)) => {
                                Ok(notification.into())
                            }
                            Err(e) => Err(e),
                        })
                        .buffered(max_concurrent_fetches);
                    let mut collected = Vec::with_capacity(count);
                    // While `Some`, we are re-scanning post-reorg blocks and
                    // new logs are prepended instead of appended.
                    let mut reorg_ancestor: Option<u64> = None;
                    while let Some(result) = stream.next().await {
                        match result {
                            Ok(ScannerMessage::Data(logs)) => {
                                if logs.is_empty() {
                                    continue;
                                }
                                let last_log_block_num = logs
                                    .last()
                                    .expect("logs already confirmed not empty")
                                    .block_number
                                    .expect("pending blocks not supported");
                                // Back at or below the common ancestor: the
                                // reorg recovery is over.
                                if reorg_ancestor.is_some_and(|a| last_log_block_num <= a) {
                                    trace!(
                                        ancestor = reorg_ancestor,
                                        "Reorg recovery complete, resuming normal log collection"
                                    );
                                    reorg_ancestor = None;
                                }
                                let should_prepend = reorg_ancestor.is_some();
                                if collect_logs(&mut collected, logs, count, should_prepend) {
                                    // Collected `count` logs — stop consuming.
                                    break;
                                }
                            }
                            Ok(ScannerMessage::Notification(Notification::ReorgDetected {
                                common_ancestor,
                            })) => {
                                trace!(
                                    common_ancestor = common_ancestor,
                                    "Reorg detected, rescanning new canonical blocks"
                                );
                                reorg_ancestor = Some(common_ancestor);
                                collected =
                                    discard_logs_from_orphaned_blocks(collected, common_ancestor);
                            }
                            Ok(ScannerMessage::Notification(notification)) => {
                                if listener.sender.try_stream(notification).await.is_closed() {
                                    return;
                                }
                            }
                            Err(e) => {
                                if listener.sender.try_stream(e).await.is_closed() {
                                    return;
                                }
                            }
                        }
                    }
                    if collected.is_empty() {
                        trace!("No logs found");
                        _ = listener.sender.try_stream(Notification::NoPastLogsFound).await;
                        return;
                    }
                    trace!(count = collected.len(), "Logs found");
                    // Collection order is newest-first; deliver oldest-first.
                    collected.reverse();
                    _ = listener.sender.try_stream(collected).await;
                });
                loop {
                    match range_rx.recv().await {
                        Ok(message) => {
                            if tx.send(message).await.is_err() {
                                break;
                            }
                        }
                        Err(RecvError::Closed) => {
                            trace!("Block range stream closed");
                            break;
                        }
                        Err(RecvError::Lagged(skipped)) => {
                            // The fetch task may have already exited (e.g.
                            // after collecting `count` logs, or because the
                            // listener closed), dropping `rx`; a failed send
                            // means shutdown, not a bug — don't panic.
                            if tx.send(Err(ScannerError::Lagged(skipped))).await.is_err() {
                                break;
                            }
                        }
                    }
                }
            });
        }
        join_set
    }
}
impl<N: Network> BlockRangeHandler for LatestEventsHandler<N> {
    /// Fans `stream` out to all listeners and waits for them to finish.
    async fn handle<S: Stream<Item = BlockScannerResult> + Unpin + Send>(self, stream: S) {
        // NOTE: `max_concurrent_fetches` was previously recorded twice in this
        // event (copy-paste duplication); record each field once.
        debug!(
            listener_count = self.listeners.len(),
            max_concurrent_fetches = self.max_concurrent_fetches,
            broadcast_channel_capacity = self.broadcast_channel_capacity,
            count = self.count,
            "Starting block range handler that collects logs before streaming them, as required by the latest events mode"
        );
        let (range_tx, _) =
            broadcast::channel::<BlockScannerResult>(self.broadcast_channel_capacity);
        let consumers = self.spawn(&range_tx);
        broadcast_stream(stream, range_tx, consumers).await;
    }
}
/// Pumps `stream` into the broadcast channel until either the stream ends or
/// every subscriber has dropped, then waits for all consumer tasks to finish.
///
/// `range_tx` is dropped before joining so consumers observe channel closure
/// and can shut down.
async fn broadcast_stream<S: Stream<Item = BlockScannerResult> + Unpin + Send>(
    mut stream: S,
    range_tx: Sender<Result<ScannerMessage<RangeInclusive<u64>>, ScannerError>>,
    consumers: JoinSet<()>,
) {
    loop {
        let Some(message) = stream.next().await else {
            break;
        };
        // `send` errors only when no subscribers remain.
        if range_tx.send(message).is_err() {
            debug!("All consumers dropped, stopping stream handler");
            break;
        }
    }
    debug!("Block range stream ended, waiting for consumers");
    drop(range_tx);
    consumers.join_all().await;
    debug!("All event consumers finished");
}
/// Drops the leading run of logs whose block number exceeds `common_ancestor`.
///
/// Only the prefix is removed (matching the previous `skip_while` semantics):
/// `collected` is expected to hold newest logs at the front — presumably as
/// built by `collect_logs` — so logs invalidated by a reorg sit at the start.
fn discard_logs_from_orphaned_blocks(mut collected: Vec<Log>, common_ancestor: u64) -> Vec<Log> {
    let before_count = collected.len();
    // First index whose log is NOT from an orphaned block; everything before
    // it gets discarded.
    let keep_from = collected
        .iter()
        .position(|log| !log.block_number.is_some_and(|n| n > common_ancestor))
        .unwrap_or(before_count);
    collected.drain(..keep_from);
    let removed_count = before_count - collected.len();
    if removed_count > 0 {
        trace!(
            removed_count = removed_count,
            remaining_count = collected.len(),
            "Invalidated logs from orphaned blocks"
        );
    }
    collected
}
/// Folds a range's logs into `collected` (held newest-first), capped at
/// `count` items. Returns `true` once `collected` holds `count` logs.
///
/// * `prepend == false` — normal collection: append the reversed logs until
///   the cap is reached; extra logs are ignored.
/// * `prepend == true` — reorg recovery: the incoming logs are newer than
///   everything collected so far, so they go in front and older entries are
///   truncated to make room.
fn collect_logs<T>(collected: &mut Vec<T>, logs: Vec<T>, count: usize, prepend: bool) -> bool {
    if !prepend {
        let room = count.saturating_sub(collected.len());
        if room == 0 {
            // Already at capacity; new (older) logs are irrelevant.
            return true;
        }
        collected.extend(logs.into_iter().rev().take(room));
        return collected.len() >= count;
    }
    // Reorg recovery: fresh logs take priority over previously collected ones.
    let fresh = logs.into_iter().rev().take(count);
    collected.truncate(count.saturating_sub(fresh.len()));
    collected.splice(..0, fresh);
    collected.len() >= count
}
async fn get_logs<N: Network>(
range: RangeInclusive<u64>,
event_filter: &EventFilter,
provider: &RobustProvider<N>,
) -> Result<Vec<Log>, RobustProviderError> {
let log_filter = Filter::from(event_filter).from_block(*range.start()).to_block(*range.end());
trace!(from_block = *range.start(), to_block = *range.end(), "Fetching logs for block range");
match provider.get_logs(&log_filter).await {
Ok(logs) => {
if !logs.is_empty() {
debug!(
from_block = *range.start(),
to_block = *range.end(),
log_count = logs.len(),
"Found logs in block range"
);
}
Ok(logs)
}
Err(e) => {
error!(
from_block = *range.start(),
to_block = *range.end(),
"Failed to get logs for block range"
);
Err(e)
}
}
}
#[cfg(test)]
mod tests {
use alloy::{
network::Ethereum,
providers::{RootProvider, mock::Asserter},
rpc::client::RpcClient,
};
use robust_provider::RobustProviderBuilder;
use super::*;
// Append mode: incoming logs are reversed so the newest ends up first.
#[test]
fn collect_logs_appends_in_reverse_order() {
let mut collected = vec![];
let new_logs = vec![10, 11, 12];
let done = collect_logs(&mut collected, new_logs, 5, false);
assert!(!done);
assert_eq!(collected, vec![12, 11, 10]);
}
// Prepend mode behaves identically to append when starting from empty.
#[test]
fn collect_logs_prepends_in_reverse_order() {
let mut collected = vec![];
let new_logs = vec![10, 11, 12];
let done = collect_logs(&mut collected, new_logs, 5, true);
assert!(!done);
assert_eq!(collected, vec![12, 11, 10]);
}
// Appending stops exactly at `count` and reports completion.
#[test]
fn collect_logs_stops_at_count() {
let mut collected = vec![15, 14];
let new_logs = vec![10, 11, 12, 13];
let done = collect_logs(&mut collected, new_logs, 5, false);
assert!(done);
assert_eq!(collected, vec![15, 14, 13, 12, 11]);
}
// During reorg recovery, newer logs go in front of the existing collection.
#[test]
fn collect_logs_prepends_during_reorg_recovery() {
let mut collected = vec![75, 70];
let new_logs = vec![85, 90];
let done = collect_logs(&mut collected, new_logs, 5, true);
assert!(!done);
assert_eq!(collected, vec![90, 85, 75, 70]);
}
// When prepending would exceed `count`, older collected entries are truncated
// in favor of the fresh (newer) logs.
#[test]
fn collect_logs_prioritizes_prepended_logs_when_truncating() {
let mut collected = vec![75, 70, 65, 60];
let new_logs = vec![85, 90, 95];
let done = collect_logs(&mut collected, new_logs, 5, true);
assert!(done);
assert_eq!(collected, vec![95, 90, 85, 75, 70]);
let mut collected = vec![75, 70, 65, 60];
let new_logs = vec![85, 90, 95, 100, 105];
let done = collect_logs(&mut collected, new_logs, 5, true);
assert!(done);
assert_eq!(collected, vec![105, 100, 95, 90, 85]);
}
// A full collection ignores further (older) logs in append mode.
#[test]
fn collect_logs_ignores_new_logs_for_appending_when_already_at_count() {
let mut collected = vec![100, 99, 98];
let new_logs = vec![90];
let done = collect_logs(&mut collected, new_logs, 3, false);
assert!(done);
assert_eq!(collected, vec![100, 99, 98]);
}
// Even the fresh logs themselves are capped at `count` when prepending.
#[test]
fn collect_logs_prepend_respects_count_limit() {
let mut collected = vec![70];
let new_logs = vec![80, 85, 90, 95];
let done = collect_logs(&mut collected, new_logs, 3, true);
assert!(done);
assert_eq!(collected, vec![95, 90, 85]);
}
// Two back-to-back sends into a capacity-1 broadcast channel before the
// spawned consumer gets to poll (current-thread test runtime) force the
// subscriber to lag, which should surface as `ScannerError::Lagged` on the
// listener channel. NOTE(review): relies on single-threaded scheduling —
// spawned tasks don't run until an await point.
#[tokio::test]
async fn stream_handler_streams_lagged_error() -> anyhow::Result<()> {
let provider = RootProvider::<Ethereum>::new(RpcClient::mocked(Asserter::new()));
let provider = RobustProviderBuilder::fragile(provider).build().await?;
let (sender, mut receiver) = mpsc::channel(1);
let stream_handler = StreamHandler {
provider,
listeners: vec![EventListener { filter: EventFilter::new(), sender }],
max_concurrent_fetches: 1,
broadcast_channel_capacity: 1,
};
let (range_tx, _) = tokio::sync::broadcast::channel::<BlockScannerResult>(
stream_handler.broadcast_channel_capacity,
);
let _set = stream_handler.spawn(&range_tx);
range_tx.send(Ok(ScannerMessage::Data(0..=1)))?;
range_tx.send(Ok(ScannerMessage::Data(2..=3)))?;
assert!(matches!(receiver.recv().await.unwrap(), Err(ScannerError::Lagged(1))));
Ok(())
}
// Same lag scenario for the collection-mode handler: the lagged error must be
// streamed to the listener rather than swallowed or panicking.
#[tokio::test]
async fn spawn_log_consumers_in_collection_mode_streams_lagged_error() -> anyhow::Result<()> {
let provider = RootProvider::<Ethereum>::new(RpcClient::mocked(Asserter::new()));
let provider = RobustProviderBuilder::fragile(provider).build().await?;
let (sender, mut receiver) = mpsc::channel(1);
let handler = LatestEventsHandler {
provider,
listeners: vec![EventListener { filter: EventFilter::new(), sender }],
max_concurrent_fetches: 1,
count: 5,
broadcast_channel_capacity: 1,
};
let (range_tx, _) = tokio::sync::broadcast::channel::<BlockScannerResult>(
handler.broadcast_channel_capacity,
);
let _set = handler.spawn(&range_tx);
range_tx.send(Ok(ScannerMessage::Data(2..=3)))?;
range_tx.send(Ok(ScannerMessage::Data(0..=1)))?;
assert!(matches!(receiver.recv().await.unwrap(), Err(ScannerError::Lagged(1))));
Ok(())
}
}