mod mock_clock;
pub use mock_clock::MockClock;
use core::pin::Pin;
use futures::{Stream, StreamExt};
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{Notify, watch};
use tokio_stream::wrappers::WatchStream;
use tsoracle_consensus::{ConsensusDriver, ConsensusError, LeaderState};
use tsoracle_core::{Epoch, PeerEndpoint};
/// In-process [`ConsensusDriver`] that keeps the high-water mark in a shared
/// cell and publishes leadership changes over a `watch` channel.
///
/// Clones share the same high-water cell and the same leadership channel, so
/// transitions driven through one clone are visible to streams obtained from
/// any other clone.
#[derive(Clone)]
pub struct InMemoryDriver {
    /// Persisted high-water mark; starts at 0.
    state: Arc<Mutex<u64>>,
    /// Publishing half of the leadership channel.
    tx: watch::Sender<LeaderState>,
    /// Receiver held by the driver itself; cloned to build each leadership
    /// stream, and keeps the channel from ever having zero receivers.
    rx: watch::Receiver<LeaderState>,
}

impl Default for InMemoryDriver {
    /// A driver with high-water 0 whose leadership starts as `Unknown`.
    fn default() -> Self {
        let (tx, rx) = watch::channel(LeaderState::Unknown);
        Self {
            state: Arc::default(),
            tx,
            rx,
        }
    }
}
impl InMemoryDriver {
    /// Equivalent to [`InMemoryDriver::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Publishes `LeaderState::Leader` for `epoch` to every leadership stream.
    pub fn become_leader(&self, epoch: Epoch) {
        let leader = LeaderState::Leader { epoch };
        let _ = self.tx.send(leader);
    }

    /// Publishes a follower state with an optional leader hint and no known
    /// leader epoch.
    pub fn become_follower(&self, hint: Option<PeerEndpoint>) {
        self.become_follower_with_epoch(hint, None);
    }

    /// Publishes a follower state carrying an optional leader hint and an
    /// optional leader epoch.
    pub fn become_follower_with_epoch(&self, hint: Option<PeerEndpoint>, epoch: Option<Epoch>) {
        let follower = LeaderState::Follower {
            leader_endpoint: hint,
            leader_epoch: epoch,
        };
        // The send result is ignored; the driver holds its own receiver, so the
        // channel always has at least one subscriber.
        let _ = self.tx.send(follower);
    }

    /// Synchronously reads the current high-water mark.
    pub fn current_high_water(&self) -> u64 {
        let guard = self.state.lock();
        *guard
    }
}
#[async_trait::async_trait]
impl ConsensusDriver for InMemoryDriver {
    /// Streams leadership state from the shared `watch` channel.
    ///
    /// `WatchStream` yields the channel's current value first and then each
    /// subsequent change, so a new subscriber immediately sees the latest state.
    fn leadership_events(&self) -> Pin<Box<dyn Stream<Item = LeaderState> + Send>> {
        // `StreamExt::boxed` already returns a pinned, boxed trait object.
        // Wrapping it in another `Box::pin` only added a second heap allocation
        // and an extra indirection on every poll.
        WatchStream::new(self.rx.clone()).boxed()
    }

    /// Returns the current high-water mark.
    async fn load_high_water(&self) -> Result<u64, ConsensusError> {
        Ok(*self.state.lock())
    }

    /// Raises the high-water mark to `at_least` if that is larger, and returns
    /// the resulting mark. The mark never decreases. The epoch is ignored.
    async fn persist_high_water(
        &self,
        at_least: u64,
        _epoch: Epoch,
    ) -> Result<u64, ConsensusError> {
        let mut high_water = self.state.lock();
        *high_water = (*high_water).max(at_least);
        Ok(*high_water)
    }
}
/// Wrapper around [`InMemoryDriver`] whose `persist_high_water` calls can be
/// held back on demand.
///
/// Each persist call is numbered in arrival order, starting at 0. Calls whose
/// number is at or above the configured stall threshold wait until they are
/// woken and find their number below the threshold (for example after
/// `release`).
#[derive(Clone)]
pub struct StallableDriver {
    /// Backing driver that stores the high-water mark.
    inner: InMemoryDriver,
    /// First call index that stalls; `u64::MAX` effectively disables stalling.
    stall_threshold: Arc<AtomicU64>,
    /// Wakes stalled persist calls so they re-check the threshold.
    threshold_wake: Arc<Notify>,
    /// Number of `persist_high_water` calls started so far.
    persist_calls: Arc<AtomicU64>,
}

impl Default for StallableDriver {
    /// A driver that does not stall any call until `stall_from` is used.
    fn default() -> Self {
        let inner = InMemoryDriver::new();
        Self {
            inner,
            stall_threshold: Arc::new(AtomicU64::new(u64::MAX)),
            threshold_wake: Arc::default(),
            persist_calls: Arc::default(),
        }
    }
}
impl StallableDriver {
    /// Creates a driver that never stalls until configured otherwise.
    pub fn new() -> Self {
        Self::default()
    }

    /// Publishes `LeaderState::Leader` for `epoch` through the inner driver.
    pub fn become_leader(&self, epoch: Epoch) {
        self.inner.become_leader(epoch);
    }

    /// Makes every persist call whose 0-based arrival index is `>= threshold`
    /// block until it is released.
    ///
    /// Calls that are already stalled are woken so they re-evaluate against the
    /// new threshold. Raising it lets calls below the new value proceed; calls
    /// still at or above it go back to waiting.
    pub fn stall_from(&self, threshold: u64) {
        self.stall_threshold.store(threshold, Ordering::SeqCst);
        // Without this wake-up, a stalled call whose index falls below a newly
        // raised threshold would stay parked until the next `release`.
        self.threshold_wake.notify_waiters();
    }

    /// Lifts the stall entirely and wakes every stalled persist call.
    pub fn release(&self) {
        self.stall_threshold.store(u64::MAX, Ordering::SeqCst);
        self.threshold_wake.notify_waiters();
    }

    /// Number of `persist_high_water` calls that have started. This includes
    /// calls that are currently stalled or were dropped while stalled.
    pub fn persist_call_count(&self) -> u64 {
        self.persist_calls.load(Ordering::SeqCst)
    }
}
#[async_trait::async_trait]
impl ConsensusDriver for StallableDriver {
    // Leadership events come straight from the inner driver.
    fn leadership_events(&self) -> Pin<Box<dyn Stream<Item = LeaderState> + Send>> {
        self.inner.leadership_events()
    }
    // Loading never stalls; it reads the inner driver's mark directly.
    async fn load_high_water(&self) -> Result<u64, ConsensusError> {
        self.inner.load_high_water().await
    }
    // Persists through the inner driver. First, it blocks while this call's
    // arrival index is at or above the stall threshold.
    //
    // A call that had to wait persists
    // `max(at_least, <wall-clock ms since UNIX_EPOCH> + 60_000)` instead of
    // `at_least`. If the system clock reads earlier than the epoch, this falls
    // back to `at_least + 60_000`. Both additions saturate at `u64::MAX`.
    // NOTE(review): presumably this models a stalled write that lands well
    // ahead of what the caller asked for; confirm intent before changing it.
    async fn persist_high_water(&self, at_least: u64, epoch: Epoch) -> Result<u64, ConsensusError> {
        // Number this call in arrival order. The counter is bumped even if the
        // call later stalls or is dropped while stalled.
        let call_idx = self.persist_calls.fetch_add(1, Ordering::SeqCst);
        let mut was_stalled = false;
        loop {
            // Register for a wake-up *before* reading the threshold. That way a
            // `notify_waiters` issued between the check and the await is not lost.
            let notified = self.threshold_wake.notified();
            tokio::pin!(notified);
            notified.as_mut().enable();
            if call_idx < self.stall_threshold.load(Ordering::SeqCst) {
                break;
            }
            was_stalled = true;
            // Any wake-up just re-runs the check; a call still over the
            // threshold waits again.
            notified.as_mut().await;
        }
        let effective_at_least = if was_stalled {
            // `as_millis` is u128; `as u64` truncates, which only matters for
            // clock readings far beyond any realistic date.
            let now_ms = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_millis() as u64)
                .unwrap_or(at_least);
            core::cmp::max(at_least, now_ms.saturating_add(60_000))
        } else {
            at_least
        };
        self.inner
            .persist_high_water(effective_at_least, epoch)
            .await
    }
}
/// Kind of error that [`FaultyDriver`] injects into a `persist_high_water` call.
#[derive(Clone, Copy, Debug)]
pub enum FaultKind {
    /// Injected as `ConsensusError::NotLeader` with no observed leader.
    NotLeader,
    /// Injected as `ConsensusError::TransientDriver`.
    Transient,
    /// Injected as `ConsensusError::PermanentDriver`.
    Permanent,
}

impl FaultKind {
    /// Builds the `ConsensusError` this fault stands for. Driver faults wrap an
    /// `std::io::Error` whose message names the injected fault.
    fn into_error(self) -> ConsensusError {
        let injected = |message: &'static str| Box::new(std::io::Error::other(message));
        match self {
            Self::NotLeader => ConsensusError::NotLeader { observed: None },
            Self::Transient => {
                ConsensusError::TransientDriver(injected("injected transient fault"))
            }
            Self::Permanent => {
                ConsensusError::PermanentDriver(injected("injected permanent fault"))
            }
        }
    }
}
/// Wrapper around [`InMemoryDriver`] that can fail upcoming
/// `persist_high_water` calls with scripted errors.
#[derive(Clone)]
pub struct FaultyDriver {
    /// Backing driver used whenever no fault is queued.
    inner: InMemoryDriver,
    /// Pending faults, consumed front-to-back, one per persist call.
    persist_faults: Arc<Mutex<VecDeque<FaultKind>>>,
    /// Number of `persist_high_water` calls made, whether faulted or not.
    persist_calls: Arc<AtomicU64>,
}

impl Default for FaultyDriver {
    /// A driver with no queued faults and a zeroed call counter.
    fn default() -> Self {
        Self {
            inner: InMemoryDriver::default(),
            persist_faults: Arc::default(),
            persist_calls: Arc::default(),
        }
    }
}
impl FaultyDriver {
    /// Creates a driver with no queued faults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Publishes `LeaderState::Leader` for `epoch` through the inner driver.
    pub fn become_leader(&self, epoch: Epoch) {
        self.inner.become_leader(epoch);
    }

    /// Publishes a follower state with an optional leader hint and no known
    /// leader epoch.
    pub fn become_follower(&self, hint: Option<PeerEndpoint>) {
        self.inner.become_follower(hint);
    }

    /// Publishes a follower state carrying an optional leader hint and an
    /// optional leader epoch. It mirrors
    /// [`InMemoryDriver::become_follower_with_epoch`] so tests using this
    /// driver can script the same transitions.
    pub fn become_follower_with_epoch(&self, hint: Option<PeerEndpoint>, epoch: Option<Epoch>) {
        self.inner.become_follower_with_epoch(hint, epoch);
    }

    /// Current high-water mark of the inner driver.
    pub fn current_high_water(&self) -> u64 {
        self.inner.current_high_water()
    }

    /// Number of `persist_high_water` calls made, including faulted ones.
    pub fn persist_call_count(&self) -> u64 {
        self.persist_calls.load(Ordering::SeqCst)
    }

    /// Queues `count` faults of `kind`, appended after any faults already
    /// queued. Each subsequent persist call consumes one fault.
    pub fn fail_next_persists(&self, count: usize, kind: FaultKind) {
        self.persist_faults
            .lock()
            .extend(std::iter::repeat(kind).take(count));
    }
}
#[async_trait::async_trait]
impl ConsensusDriver for FaultyDriver {
    /// Leadership events come straight from the inner driver.
    fn leadership_events(&self) -> Pin<Box<dyn Stream<Item = LeaderState> + Send>> {
        self.inner.leadership_events()
    }

    /// Loading is never faulted; it reads the inner driver's mark.
    async fn load_high_water(&self) -> Result<u64, ConsensusError> {
        self.inner.load_high_water().await
    }

    /// Counts the call, then either fails with the next queued fault or
    /// forwards the request to the inner driver.
    async fn persist_high_water(&self, at_least: u64, epoch: Epoch) -> Result<u64, ConsensusError> {
        self.persist_calls.fetch_add(1, Ordering::SeqCst);
        // Pop in a statement of its own so the lock guard is released before
        // the inner persist is awaited.
        let queued = self.persist_faults.lock().pop_front();
        match queued {
            Some(fault) => Err(fault.into_error()),
            None => self.inner.persist_high_water(at_least, epoch).await,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn persist_is_monotonic() {
        let driver = InMemoryDriver::new();
        assert_eq!(driver.persist_high_water(100, Epoch(1)).await.unwrap(), 100);
        // A lower request must not move the mark backwards.
        assert_eq!(driver.persist_high_water(50, Epoch(1)).await.unwrap(), 100);
        assert_eq!(driver.persist_high_water(200, Epoch(1)).await.unwrap(), 200);
        assert_eq!(driver.load_high_water().await.unwrap(), 200);
    }

    #[tokio::test]
    async fn stallable_driver_holds_persist_until_released() {
        use std::time::Duration;
        use tokio::time::timeout;
        let driver = StallableDriver::new();
        driver.stall_from(1);
        // Call index 0 is below the threshold and completes immediately.
        let first = driver.persist_high_water(50, Epoch(1)).await.unwrap();
        assert_eq!(first, 50);
        // Call index 1 is at the threshold and must block.
        let stalled = driver.persist_high_water(100, Epoch(1));
        assert!(
            timeout(Duration::from_millis(50), stalled).await.is_err(),
            "stalled persist must not complete before release"
        );
        driver.release();
        let released = timeout(
            Duration::from_millis(500),
            driver.persist_high_water(200, Epoch(1)),
        )
        .await
        .expect("released persist must complete")
        .unwrap();
        assert_eq!(released, 200);
        assert!(driver.persist_call_count() >= 3);
    }

    #[tokio::test]
    async fn leadership_events_observe_transitions() {
        let driver = InMemoryDriver::new();
        let mut events = driver.leadership_events();
        driver.become_leader(Epoch(1));
        // The stream may first yield the initial `Unknown` state; skip ahead
        // until the leader transition shows up.
        loop {
            if let LeaderState::Leader { epoch } = events.next().await.unwrap() {
                assert_eq!(epoch, Epoch(1));
                break;
            }
        }
    }

    #[tokio::test]
    async fn faulty_driver_injects_queued_faults_then_recovers() {
        let driver = FaultyDriver::new();
        driver.fail_next_persists(1, FaultKind::Transient);
        driver.fail_next_persists(1, FaultKind::NotLeader);
        assert!(matches!(
            driver.persist_high_water(10, Epoch(1)).await,
            Err(ConsensusError::TransientDriver(_))
        ));
        assert!(matches!(
            driver.persist_high_water(10, Epoch(1)).await,
            Err(ConsensusError::NotLeader { .. })
        ));
        // Faulted calls never reach the inner driver.
        assert_eq!(driver.current_high_water(), 0);
        assert_eq!(driver.persist_high_water(10, Epoch(1)).await.unwrap(), 10);
        assert_eq!(driver.current_high_water(), 10);
        assert_eq!(driver.persist_call_count(), 3);
    }
}