use core::pin::Pin;
use futures::{Stream, StreamExt};
use parking_lot::Mutex;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{Notify, watch};
use tokio_stream::wrappers::WatchStream;
use tsoracle_consensus::{ConsensusDriver, ConsensusError, LeaderState};
use tsoracle_core::Epoch;
/// Consensus driver that keeps everything in process memory.
///
/// The high-water mark lives in a mutex-protected `u64` and leadership
/// changes are broadcast over a tokio `watch` channel. Cloning the driver
/// yields a handle to the same high-water mark and the same channel.
#[derive(Clone)]
pub struct InMemoryDriver {
    /// Largest value persisted so far (starts at `0`).
    state: Arc<Mutex<u64>>,
    /// Sending half used to publish leadership state changes.
    tx: watch::Sender<LeaderState>,
    /// Receiving half retained by the driver; `leadership_events` clones it
    /// for every new stream.
    rx: watch::Receiver<LeaderState>,
}
impl Default for InMemoryDriver {
    /// Builds a driver whose high-water mark is `0` and whose leadership
    /// channel initially holds `LeaderState::Unknown`.
    fn default() -> Self {
        let high_water = Arc::new(Mutex::new(0));
        let (sender, receiver) = watch::channel(LeaderState::Unknown);
        Self {
            state: high_water,
            tx: sender,
            rx: receiver,
        }
    }
}
impl InMemoryDriver {
pub fn new() -> Self {
Self::default()
}
pub fn become_leader(&self, epoch: Epoch) {
let _ = self.tx.send(LeaderState::Leader { epoch });
}
pub fn become_follower(&self, hint: Option<String>) {
let _ = self.tx.send(LeaderState::Follower {
leader_endpoint: hint,
});
}
pub fn current_high_water(&self) -> u64 {
*self.state.lock()
}
}
#[async_trait::async_trait]
impl ConsensusDriver for InMemoryDriver {
    /// Returns a stream of leadership states backed by a clone of the
    /// driver's watch receiver.
    ///
    /// NOTE(review): per the tokio-stream documentation, `WatchStream::new`
    /// yields the channel's current value first and then every later change.
    fn leadership_events(&self) -> Pin<Box<dyn Stream<Item = LeaderState> + Send>> {
        // `boxed()` already produces a pinned, boxed `Send` stream; wrapping
        // it in another `Box::pin` only added a second allocation and an
        // extra layer of dynamic dispatch.
        WatchStream::new(self.rx.clone()).boxed()
    }

    /// Returns the stored high-water mark; never fails.
    async fn load_high_water(&self) -> Result<u64, ConsensusError> {
        Ok(self.current_high_water())
    }

    /// Raises the stored high-water mark to `at_least` if it is currently
    /// lower, and returns the value stored afterwards.
    ///
    /// The mark never decreases: a request below the stored value leaves it
    /// untouched and returns the existing value. `_epoch` is ignored.
    async fn persist_high_water(
        &self,
        at_least: u64,
        _epoch: Epoch,
    ) -> Result<u64, ConsensusError> {
        let mut high_water = self.state.lock();
        *high_water = (*high_water).max(at_least);
        Ok(*high_water)
    }
}
/// Wrapper around [`InMemoryDriver`] whose `persist_high_water` calls can be
/// made to block from a chosen call index onward until `release` is called.
///
/// Clones share the wrapped driver, the threshold, the notifier and the call
/// counter.
#[derive(Clone)]
pub struct StallableDriver {
    /// Driver that performs the actual loads, persists and leadership events.
    inner: InMemoryDriver,
    /// Persist calls whose zero-based index is at or above this value block;
    /// `u64::MAX` is the initial and post-`release` value.
    stall_threshold: Arc<AtomicU64>,
    /// Notified by `release` to wake blocked persist calls.
    threshold_wake: Arc<Notify>,
    /// Number of `persist_high_water` calls started so far.
    persist_calls: Arc<AtomicU64>,
}
impl Default for StallableDriver {
    /// Builds a driver around a fresh [`InMemoryDriver`] with the stall
    /// threshold at `u64::MAX` and a persist-call counter of zero.
    fn default() -> Self {
        let never_stall = AtomicU64::new(u64::MAX);
        let no_calls_yet = AtomicU64::new(0);
        Self {
            inner: InMemoryDriver::default(),
            stall_threshold: Arc::new(never_stall),
            threshold_wake: Arc::new(Notify::new()),
            persist_calls: Arc::new(no_calls_yet),
        }
    }
}
impl StallableDriver {
    /// Creates a driver in its default state (see the `Default` impl).
    pub fn new() -> Self {
        Default::default()
    }

    /// Forwards to the wrapped driver's `become_leader`.
    pub fn become_leader(&self, epoch: Epoch) {
        self.inner.become_leader(epoch)
    }

    /// Sets the stall threshold: `persist_high_water` calls whose zero-based
    /// call index is `threshold` or higher block until `release` is called.
    pub fn stall_from(&self, threshold: u64) {
        self.stall_threshold.store(threshold, Ordering::SeqCst);
    }

    /// Resets the threshold to `u64::MAX` and wakes every blocked persist call.
    pub fn release(&self) {
        // The threshold must be raised before notifying: woken callers
        // re-read it to decide whether they may proceed.
        self.stall_from(u64::MAX);
        self.threshold_wake.notify_waiters();
    }

    /// Returns how many `persist_high_water` calls have been started,
    /// including calls that are blocked or were dropped while blocked.
    pub fn persist_call_count(&self) -> u64 {
        self.persist_calls.load(Ordering::SeqCst)
    }
}
#[async_trait::async_trait]
impl ConsensusDriver for StallableDriver {
    /// Forwards to the wrapped driver's leadership stream.
    fn leadership_events(&self) -> Pin<Box<dyn Stream<Item = LeaderState> + Send>> {
        self.inner.leadership_events()
    }

    /// Forwards to the wrapped driver's `load_high_water`.
    async fn load_high_water(&self) -> Result<u64, ConsensusError> {
        self.inner.load_high_water().await
    }

    /// Counts the call, blocks while its index is at or above the stall
    /// threshold, then persists through the wrapped driver.
    ///
    /// A call that had to wait at least once persists
    /// `max(at_least, now_ms + 60_000)` instead of `at_least`, where `now_ms`
    /// is the wall-clock time in milliseconds since the Unix epoch.
    /// NOTE(review): the reason for the 60-second bump is not visible in this
    /// file — confirm the intent with the callers.
    async fn persist_high_water(&self, at_least: u64, epoch: Epoch) -> Result<u64, ConsensusError> {
        // `fetch_add` returns the previous counter value, which is this
        // call's zero-based index.
        let ticket = self.persist_calls.fetch_add(1, Ordering::SeqCst);

        let mut stalled = false;
        loop {
            // Register for the wake-up before reading the threshold, so a
            // `release()` landing between the check and the await still
            // wakes this call.
            let wake = self.threshold_wake.notified();
            tokio::pin!(wake);
            wake.as_mut().enable();

            let threshold = self.stall_threshold.load(Ordering::SeqCst);
            if ticket < threshold {
                break;
            }
            stalled = true;
            wake.as_mut().await;
        }

        let target = if stalled {
            let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
            // If the clock reads earlier than the Unix epoch, fall back to
            // the requested value as the "current time".
            let wall_clock_ms = match since_epoch {
                Ok(elapsed) => elapsed.as_millis() as u64,
                Err(_) => at_least,
            };
            at_least.max(wall_clock_ms.saturating_add(60_000))
        } else {
            at_least
        };

        self.inner.persist_high_water(target, epoch).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Persisting a lower value must not lower the stored high-water mark.
    #[tokio::test]
    async fn persist_is_monotonic() {
        let driver = InMemoryDriver::new();
        // (requested, expected value stored afterwards)
        for (requested, expected) in [(100, 100), (50, 100), (200, 200)] {
            let stored = driver.persist_high_water(requested, Epoch(1)).await.unwrap();
            assert_eq!(stored, expected);
        }
        assert_eq!(driver.load_high_water().await.unwrap(), 200);
    }

    /// Calls at or past the stall threshold block until `release` is called.
    #[tokio::test]
    async fn stallable_driver_holds_persist_until_released() {
        use std::time::Duration;
        use tokio::time::timeout;

        let driver = StallableDriver::new();
        driver.stall_from(1);

        // Call index 0 is below the threshold and goes straight through.
        let first = driver.persist_high_water(50, Epoch(1)).await.unwrap();
        assert_eq!(first, 50);

        // Call index 1 must still be pending when the timeout fires.
        let blocked = driver.persist_high_water(100, Epoch(1));
        let outcome = timeout(Duration::from_millis(50), blocked).await;
        assert!(
            outcome.is_err(),
            "stalled persist must not complete before release"
        );

        driver.release();

        let after_release = timeout(
            Duration::from_millis(500),
            driver.persist_high_water(200, Epoch(1)),
        )
        .await;
        let released = after_release
            .expect("released persist must complete")
            .unwrap();
        assert_eq!(released, 200);
        assert!(driver.persist_call_count() >= 3);
    }

    /// A stream created before `become_leader` eventually yields the new state.
    #[tokio::test]
    async fn leadership_events_observe_transitions() {
        let driver = InMemoryDriver::new();
        let mut events = driver.leadership_events();
        driver.become_leader(Epoch(1));

        // Skip any non-leader states until the leader one arrives.
        loop {
            if let LeaderState::Leader { epoch } = events.next().await.unwrap() {
                assert_eq!(epoch, Epoch(1));
                break;
            }
        }
    }
}