// kona_derive/stages/batch/batch_provider.rs
use super::NextBatchProvider;
use crate::{
errors::PipelineError,
stages::{BatchQueue, BatchValidator},
traits::{AttributesProvider, L2ChainProvider, OriginAdvancer, OriginProvider, SignalReceiver},
types::{PipelineResult, Signal},
};
use alloc::{boxed::Box, sync::Arc};
use async_trait::async_trait;
use core::fmt::Debug;
use op_alloy_genesis::RollupConfig;
use op_alloy_protocol::{BlockInfo, L2BlockInfo, SingleBatch};
/// A stage that wraps either a [`BatchQueue`] or a [`BatchValidator`] and forwards calls to
/// whichever one is currently held.
///
/// The choice between the two is made by `attempt_update`, based on whether Holocene is active
/// (per the [`RollupConfig`]) at the timestamp of the current origin block.
#[derive(Debug)]
pub struct BatchProvider<P, F>
where
    P: NextBatchProvider + OriginAdvancer + OriginProvider + SignalReceiver + Debug,
    F: L2ChainProvider + Clone + Debug,
{
    /// Rollup configuration, consulted for Holocene activation.
    cfg: Arc<RollupConfig>,
    /// L2 chain provider; a clone is handed to every [`BatchQueue`] that gets constructed.
    provider: F,
    /// The previous stage. Held here only until the first successful `attempt_update`, at which
    /// point it is moved into the inner stage that is built.
    prev: Option<P>,
    /// The wrapped batch queue, when that is the held stage.
    batch_queue: Option<BatchQueue<P, F>>,
    /// The wrapped batch validator, when that is the held stage.
    batch_validator: Option<BatchValidator<P>>,
}
impl<P, F> BatchProvider<P, F>
where
    P: NextBatchProvider + OriginAdvancer + OriginProvider + SignalReceiver + Debug,
    F: L2ChainProvider + Clone + Debug,
{
    /// Builds a new [`BatchProvider`] that holds on to `prev` and has no inner stage yet.
    ///
    /// The inner stage is only created on the first call to `attempt_update`.
    pub const fn new(cfg: Arc<RollupConfig>, prev: P, provider: F) -> Self {
        Self {
            cfg,
            provider,
            prev: Some(prev),
            batch_queue: None,
            batch_validator: None,
        }
    }

    /// Makes sure the held inner stage matches the Holocene status of the current origin.
    ///
    /// * While the previous stage is still held directly (first call), it is moved into a new
    ///   [`BatchValidator`] if Holocene is active at the origin's timestamp, or into a new
    ///   [`BatchQueue`] otherwise.
    /// * If a [`BatchQueue`] is held and Holocene is active, it is replaced by a
    ///   [`BatchValidator`] that takes over the queue's `prev` and `l1_blocks`.
    /// * If a [`BatchValidator`] is held and Holocene is not active (for example because the
    ///   origin moved back to an earlier timestamp), it is replaced by a [`BatchQueue`] that
    ///   takes over the validator's `prev` and `l1_blocks`.
    ///
    /// # Errors
    ///
    /// Returns `PipelineError::MissingOrigin` (via `.crit()`) when no origin is available.
    pub(crate) fn attempt_update(&mut self) -> PipelineResult<()> {
        let origin = self.origin().ok_or_else(|| PipelineError::MissingOrigin.crit())?;
        let holocene_active = self.cfg.is_holocene_active(origin.timestamp);

        // First call: the previous stage has not been handed to an inner stage yet.
        if let Some(prev) = self.prev.take() {
            if holocene_active {
                self.batch_validator = Some(BatchValidator::new(Arc::clone(&self.cfg), prev));
            } else {
                let queue = BatchQueue::new(Arc::clone(&self.cfg), prev, self.provider.clone());
                self.batch_queue = Some(queue);
            }
            return Ok(());
        }

        if holocene_active {
            // Holocene is active: a held queue must give way to a validator.
            if let Some(queue) = self.batch_queue.take() {
                let mut validator = BatchValidator::new(Arc::clone(&self.cfg), queue.prev);
                validator.l1_blocks = queue.l1_blocks;
                self.batch_validator = Some(validator);
            }
        } else if let Some(validator) = self.batch_validator.take() {
            // Holocene is not active: a held validator must give way to a queue.
            let mut queue =
                BatchQueue::new(Arc::clone(&self.cfg), validator.prev, self.provider.clone());
            queue.l1_blocks = validator.l1_blocks;
            self.batch_queue = Some(queue);
        }

        Ok(())
    }
}
#[async_trait]
impl<P, F> OriginAdvancer for BatchProvider<P, F>
where
    P: NextBatchProvider + OriginAdvancer + OriginProvider + SignalReceiver + Send + Debug,
    F: L2ChainProvider + Clone + Send + Debug,
{
    /// Refreshes the inner stage selection, then advances the origin of the held stage.
    ///
    /// The batch validator takes precedence over the batch queue. If neither is present,
    /// `PipelineError::NotEnoughData` is returned (via `.temp()`).
    async fn advance_origin(&mut self) -> PipelineResult<()> {
        self.attempt_update()?;

        if let Some(validator) = self.batch_validator.as_mut() {
            return validator.advance_origin().await;
        }

        match self.batch_queue.as_mut() {
            Some(queue) => queue.advance_origin().await,
            None => Err(PipelineError::NotEnoughData.temp()),
        }
    }
}
impl<P, F> OriginProvider for BatchProvider<P, F>
where
    P: NextBatchProvider + OriginAdvancer + OriginProvider + SignalReceiver + Debug,
    F: L2ChainProvider + Clone + Debug,
{
    /// Returns the origin reported by the first available source, checked in this order: the
    /// batch validator, the batch queue, and finally the not-yet-consumed previous stage.
    ///
    /// Yields `None` when none of the three is present, or when the chosen one has no origin.
    fn origin(&self) -> Option<BlockInfo> {
        if let Some(validator) = &self.batch_validator {
            return validator.origin();
        }
        if let Some(queue) = &self.batch_queue {
            return queue.origin();
        }
        self.prev.as_ref()?.origin()
    }
}
#[async_trait]
impl<P, F> SignalReceiver for BatchProvider<P, F>
where
    P: NextBatchProvider + OriginAdvancer + OriginProvider + SignalReceiver + Send + Debug,
    F: L2ChainProvider + Clone + Send + Debug,
{
    /// Refreshes the inner stage selection, then forwards `signal` to the held stage.
    ///
    /// The batch validator takes precedence over the batch queue. If neither is present,
    /// `PipelineError::NotEnoughData` is returned (via `.temp()`).
    async fn signal(&mut self, signal: Signal) -> PipelineResult<()> {
        self.attempt_update()?;

        if let Some(validator) = self.batch_validator.as_mut() {
            return validator.signal(signal).await;
        }

        match self.batch_queue.as_mut() {
            Some(queue) => queue.signal(signal).await,
            None => Err(PipelineError::NotEnoughData.temp()),
        }
    }
}
#[async_trait]
impl<P, F> AttributesProvider for BatchProvider<P, F>
where
    P: NextBatchProvider + OriginAdvancer + OriginProvider + SignalReceiver + Debug + Send,
    F: L2ChainProvider + Clone + Send + Debug,
{
    /// Asks the held stage whether it is on the last batch of a span.
    ///
    /// The batch validator takes precedence over the batch queue; with neither present the
    /// answer is `false`.
    fn is_last_in_span(&self) -> bool {
        match (&self.batch_validator, &self.batch_queue) {
            (Some(validator), _) => validator.is_last_in_span(),
            (None, Some(queue)) => queue.is_last_in_span(),
            (None, None) => false,
        }
    }

    /// Refreshes the inner stage selection, then requests the next batch on top of `parent`
    /// from the held stage.
    ///
    /// The batch validator takes precedence over the batch queue. If neither is present,
    /// `PipelineError::NotEnoughData` is returned (via `.temp()`).
    async fn next_batch(&mut self, parent: L2BlockInfo) -> PipelineResult<SingleBatch> {
        self.attempt_update()?;

        if let Some(validator) = self.batch_validator.as_mut() {
            return validator.next_batch(parent).await;
        }

        match self.batch_queue.as_mut() {
            Some(queue) => queue.next_batch(parent).await,
            None => Err(PipelineError::NotEnoughData.temp()),
        }
    }
}
#[cfg(test)]
mod test {
    use super::BatchProvider;
    use crate::{
        test_utils::{TestL2ChainProvider, TestNextBatchProvider},
        traits::{OriginProvider, SignalReceiver},
        types::ResetSignal,
    };
    use alloc::{sync::Arc, vec};
    use op_alloy_genesis::RollupConfig;
    use op_alloy_protocol::BlockInfo;

    /// The concrete provider type exercised by every test in this module.
    type TestBatchProvider = BatchProvider<TestNextBatchProvider, TestL2ChainProvider>;

    /// Builds a provider over an empty batch source and a default L2 chain provider.
    fn provider_with(cfg: RollupConfig) -> TestBatchProvider {
        BatchProvider::new(
            Arc::new(cfg),
            TestNextBatchProvider::new(vec![]),
            TestL2ChainProvider::default(),
        )
    }

    /// Starts from a freshly updated provider that holds a batch queue, moves the origin to a
    /// block with timestamp 2, and checks that the next update swaps in the batch validator.
    fn switch_queue_to_validator(provider: &mut TestBatchProvider) {
        let Some(queue) = provider.batch_queue.as_mut() else {
            panic!("Expected BatchQueue");
        };
        queue.prev.origin = Some(BlockInfo { number: 1, timestamp: 2, ..Default::default() });

        provider.attempt_update().unwrap();
        assert!(provider.batch_queue.is_none());
        assert!(provider.batch_validator.is_some());
    }

    #[test]
    fn test_batch_provider_validator_active() {
        let mut provider =
            provider_with(RollupConfig { holocene_time: Some(0), ..Default::default() });

        assert!(provider.attempt_update().is_ok());
        assert!(provider.prev.is_none());
        assert!(provider.batch_queue.is_none());
        assert!(provider.batch_validator.is_some());
    }

    #[test]
    fn test_batch_provider_batch_queue_active() {
        let mut provider = provider_with(RollupConfig::default());

        assert!(provider.attempt_update().is_ok());
        assert!(provider.prev.is_none());
        assert!(provider.batch_queue.is_some());
        assert!(provider.batch_validator.is_none());
    }

    #[test]
    fn test_batch_provider_transition_stage() {
        let mut provider =
            provider_with(RollupConfig { holocene_time: Some(2), ..Default::default() });
        provider.attempt_update().unwrap();

        switch_queue_to_validator(&mut provider);

        assert_eq!(provider.origin().unwrap().number, 1);
    }

    #[test]
    fn test_batch_provider_transition_stage_backwards() {
        let mut provider =
            provider_with(RollupConfig { holocene_time: Some(2), ..Default::default() });
        provider.attempt_update().unwrap();

        switch_queue_to_validator(&mut provider);

        // Move the origin back to the default block; the next update must restore the queue.
        let Some(validator) = provider.batch_validator.as_mut() else {
            panic!("Expected BatchValidator");
        };
        validator.prev.origin = Some(BlockInfo::default());

        provider.attempt_update().unwrap();
        assert!(provider.batch_queue.is_some());
        assert!(provider.batch_validator.is_none());
    }

    #[tokio::test]
    async fn test_batch_provider_reset_bq() {
        let mut provider = provider_with(RollupConfig::default());

        provider.signal(ResetSignal::default().signal()).await.unwrap();

        let Some(queue) = provider.batch_queue else {
            panic!("Expected BatchQueue");
        };
        assert_eq!(queue.l1_blocks.len(), 1);
    }

    #[tokio::test]
    async fn test_batch_provider_reset_validator() {
        let mut provider =
            provider_with(RollupConfig { holocene_time: Some(0), ..Default::default() });

        provider.signal(ResetSignal::default().signal()).await.unwrap();

        let Some(validator) = provider.batch_validator else {
            panic!("Expected BatchValidator");
        };
        assert_eq!(validator.l1_blocks.len(), 1);
    }
}