use super::error::JournalError;
use super::journal::Journal;
use super::types::{SequencerCommand, SequencerEvent, SequencerResult};
use crate::orderbook::{OrderBook, OrderBookError, OrderBookSnapshot};
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use thiserror::Error;
/// Errors that can abort a journal replay performed by `ReplayEngine`.
#[derive(Debug, Error)]
pub enum ReplayError {
/// The journal reported no last sequence number, so there is nothing to replay.
#[error("journal is empty — nothing to replay")]
EmptyJournal,
/// The requested starting sequence lies beyond the journal's last sequence.
#[error("invalid from_sequence {from_sequence}: journal last sequence is {last_sequence}")]
InvalidSequence {
/// Sequence number the caller asked to start replaying from.
from_sequence: u64,
/// Last sequence number reported by the journal.
last_sequence: u64,
},
/// An entry's sequence number was not the one expected next; replayed entries must be
/// strictly contiguous.
#[error("sequence gap detected: expected {expected}, found {found}")]
SequenceGap {
/// Sequence number that should have come next.
expected: u64,
/// Sequence number actually read from the journal.
found: u64,
},
/// An order book operation failed while re-applying a journaled event.
///
/// NOTE(review): the message interpolates `{source}` and the field is also exposed via
/// `#[source]`, so reporters that print both `Display` and the `source()` chain will show
/// the underlying error twice.
#[error("order book error during replay at sequence {sequence_num}: {source}")]
OrderBookError {
/// Sequence number of the event whose replay failed.
sequence_num: u64,
/// Underlying order book error.
#[source]
source: OrderBookError,
},
/// The replayed state does not match an expected snapshot.
///
/// NOTE(review): not constructed in the code visible here (`ReplayEngine::verify` reports a
/// mismatch as `Ok(false)` instead); confirm whether other callers rely on it.
#[error("snapshot mismatch: replayed state diverges from expected snapshot")]
SnapshotMismatch,
/// Opening or reading the journal failed; created from a `JournalError` via `From`.
///
/// NOTE(review): `#[from]` also exposes the inner error as `source()`, so `{0}` in the
/// message duplicates it in source-chain reports.
#[error("journal error during replay: {0}")]
JournalError(#[from] JournalError),
}
/// Namespace for rebuilding an [`OrderBook`] from a [`Journal`] of sequencer events.
///
/// All replay operations are associated functions (none take `self`), so this type carries
/// no state. `T` is the type parameter shared with `OrderBook<T>` and `SequencerEvent<T>`.
pub struct ReplayEngine<T> {
/// Zero-sized marker tying `T` to the struct; no `T` value is stored.
_phantom: PhantomData<T>,
}
impl<T> ReplayEngine<T>
where
T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + Default + 'static,
{
pub fn replay_from(
journal: &impl Journal<T>,
from_sequence: u64,
symbol: &str,
) -> Result<(OrderBook<T>, u64), ReplayError> {
Self::replay_from_with_progress(journal, from_sequence, symbol, |_, _| {})
}
pub fn replay_from_with_progress(
journal: &impl Journal<T>,
from_sequence: u64,
symbol: &str,
progress: impl Fn(u64, u64),
) -> Result<(OrderBook<T>, u64), ReplayError> {
let last_seq = match journal.last_sequence() {
Some(seq) => seq,
None => return Err(ReplayError::EmptyJournal),
};
if from_sequence > last_seq {
return Err(ReplayError::InvalidSequence {
from_sequence,
last_sequence: last_seq,
});
}
let book = OrderBook::new(symbol);
let mut last_applied_seq = 0u64;
let mut count = 0u64;
let mut expected_seq = from_sequence;
let iter = journal.read_from(from_sequence)?;
for entry_result in iter {
let entry = entry_result?;
let event = &entry.event;
if event.sequence_num != expected_seq {
return Err(ReplayError::SequenceGap {
expected: expected_seq,
found: event.sequence_num,
});
}
Self::apply_event(&book, event)?;
last_applied_seq = event.sequence_num;
count = count.saturating_add(1);
expected_seq = expected_seq.saturating_add(1);
progress(count, last_applied_seq);
}
Ok((book, last_applied_seq))
}
pub fn verify(
journal: &impl Journal<T>,
expected_snapshot: &OrderBookSnapshot,
) -> Result<bool, ReplayError> {
let (book, _) = Self::replay_from(journal, 0, &expected_snapshot.symbol)?;
let actual = book.create_snapshot(usize::MAX);
Ok(snapshots_match(&actual, expected_snapshot))
}
fn apply_event(book: &OrderBook<T>, event: &SequencerEvent<T>) -> Result<(), ReplayError> {
if matches!(event.result, SequencerResult::Rejected { .. }) {
return Ok(());
}
match &event.command {
SequencerCommand::AddOrder(order) => {
book.add_order(order.clone())
.map_err(|e| ReplayError::OrderBookError {
sequence_num: event.sequence_num,
source: e,
})?;
}
SequencerCommand::CancelOrder(id) => {
book.cancel_order(*id)
.map_err(|e| ReplayError::OrderBookError {
sequence_num: event.sequence_num,
source: e,
})?;
}
SequencerCommand::UpdateOrder(update) => {
book.update_order(*update)
.map_err(|e| ReplayError::OrderBookError {
sequence_num: event.sequence_num,
source: e,
})?;
}
SequencerCommand::MarketOrder { id, quantity, side } => {
book.submit_market_order(*id, *quantity, *side)
.map_err(|e| ReplayError::OrderBookError {
sequence_num: event.sequence_num,
source: e,
})?;
}
SequencerCommand::CancelAll => {
let _ = book.cancel_all_orders();
}
SequencerCommand::CancelBySide { side } => {
let _ = book.cancel_orders_by_side(*side);
}
SequencerCommand::CancelByUser { user_id } => {
let _ = book.cancel_orders_by_user(*user_id);
}
SequencerCommand::CancelByPriceRange {
side,
min_price,
max_price,
} => {
let _ = book.cancel_orders_by_price_range(*side, *min_price, *max_price);
}
}
Ok(())
}
}
/// Compares two order book snapshots for equivalent resting liquidity.
///
/// The snapshots match when their symbols are equal and, on each side, both contain the same
/// number of price levels with pairwise-equal `price()` and `visible_quantity()` once the
/// levels are ordered by price. The order of levels inside each snapshot therefore does not
/// matter. No other level attributes are compared.
#[must_use]
pub fn snapshots_match(actual: &OrderBookSnapshot, expected: &OrderBookSnapshot) -> bool {
    actual.symbol == expected.symbol
        && price_levels_match(
            actual.bids.iter(),
            expected.bids.iter(),
            |level| level.price(),
            |level| level.visible_quantity(),
        )
        && price_levels_match(
            actual.asks.iter(),
            expected.asks.iter(),
            |level| level.price(),
            |level| level.visible_quantity(),
        )
}

/// Returns `true` when both sides hold the same number of levels and, after stable-sorting
/// each side by `price`, every pair of levels has equal `price` and `quantity`.
///
/// Sorting ascending is enough for both bids and asks: with a stable sort, the pairwise
/// comparison gives the same answer whichever direction both sides are sorted in.
fn price_levels_match<L, P, Q>(
    actual: impl Iterator<Item = L>,
    expected: impl Iterator<Item = L>,
    price: impl Fn(&L) -> P,
    quantity: impl Fn(&L) -> Q,
) -> bool
where
    P: Ord,
    Q: PartialEq,
{
    let mut actual: Vec<L> = actual.collect();
    let mut expected: Vec<L> = expected.collect();
    // Differing level counts can never match, so skip the sorts entirely.
    if actual.len() != expected.len() {
        return false;
    }
    actual.sort_by_key(|level| price(level));
    expected.sort_by_key(|level| price(level));
    actual
        .iter()
        .zip(&expected)
        .all(|(a, b)| price(a) == price(b) && quantity(a) == quantity(b))
}