use std::sync::Mutex;
use super::{Hlc, SkewError, Timestamp, wall_ms};
/// Durable backing store for the raw clock value kept by [`HlcService`].
///
/// Both methods take `&self`, so an implementation that mutates anything must do
/// so through interior mutability or through the caller-supplied
/// [`Sink`](HlcStorage::Sink).
pub trait HlcStorage {
    /// Failure type reported by [`load`](HlcStorage::load) and
    /// [`save`](HlcStorage::save).
    type Error;

    /// Caller-supplied write context handed to every
    /// [`save`](HlcStorage::save) call. It may be unsized.
    ///
    /// NOTE(review): presumably a transaction or write batch owned by the
    /// caller; confirm against the real implementors.
    type Sink: ?Sized;

    /// Reads the most recently persisted raw clock value.
    ///
    /// `Ok(None)` means nothing is stored; [`HlcService::open`] then seeds the
    /// clock with `0`.
    fn load(&self) -> Result<Option<u64>, Self::Error>;

    /// Persists `raw`, writing through `sink`.
    fn save(&self, sink: &mut Self::Sink, raw: u64) -> Result<(), Self::Error>;
}
/// Error returned by [`HlcService::observe`]: either the clock rejected the
/// observed timestamp, or persisting the advanced state failed.
#[derive(Debug)]
pub enum HlcError<E> {
    /// The in-memory clock refused the observed timestamp; carries the
    /// [`SkewError`] it reported.
    Skew(SkewError),
    /// The storage backend failed while saving the advanced clock state;
    /// carries the backend's own error.
    Storage(E),
}
impl<E: std::fmt::Display> std::fmt::Display for HlcError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HlcError::Skew(e) => write!(f, "{e}"),
HlcError::Storage(e) => write!(f, "hlc storage: {e}"),
}
}
}
impl<E: std::error::Error + 'static> std::error::Error for HlcError<E> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
HlcError::Skew(e) => Some(e),
HlcError::Storage(e) => Some(e),
}
}
}
/// A hybrid logical clock paired with the storage backend that persists it.
///
/// The clock lives behind a [`Mutex`], so every method of the service takes
/// `&self`.
pub struct HlcService<S>
where
    S: HlcStorage,
{
    /// In-memory clock; all reads and updates go through this lock.
    state: Mutex<Hlc>,
    /// Backend that loads the seed on open and receives each saved raw value.
    storage: S,
}
impl<S: HlcStorage> HlcService<S> {
    /// Opens the service, seeding the in-memory clock from `storage`.
    ///
    /// If `storage.load()` reports no persisted value, the clock is seeded
    /// with `0`.
    ///
    /// # Errors
    ///
    /// Returns the storage error if loading the seed fails.
    pub fn open(storage: S) -> Result<Self, S::Error> {
        let seed = storage.load()?.unwrap_or(0);
        Ok(Self {
            state: Mutex::new(Hlc::new(seed)),
            storage,
        })
    }

    /// Locks the clock.
    ///
    /// A poisoned mutex is recovered rather than propagated, so a panic in one
    /// caller does not make the clock permanently unusable.
    fn lock(&self) -> std::sync::MutexGuard<'_, Hlc> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Saves the clock's current state; if the save fails, restores the clock
    /// to `rollback_to` before returning the error.
    ///
    /// Without the rollback a failed save would leave the in-memory clock
    /// ahead of storage. A later `observe` that sees "no change" would then
    /// skip persisting and report success for state that was never written.
    ///
    /// The rollback rebuilds the clock with `Hlc::new(raw)`, the same way
    /// `open` seeds it from a persisted raw value.
    fn persist_or_rollback(
        &self,
        hlc: &mut Hlc,
        rollback_to: Timestamp,
        sink: &mut S::Sink,
    ) -> Result<(), S::Error> {
        let saved = self.storage.save(sink, hlc.state().raw());
        if saved.is_err() {
            // The lock is still held, so no other caller can have seen the
            // advanced state; stepping back is invisible to them.
            *hlc = Hlc::new(rollback_to.raw());
        }
        saved
    }

    /// Issues the next timestamp and persists the new clock state before
    /// returning it.
    ///
    /// The lock is held across the save so persisted values are written in the
    /// same order the timestamps are issued.
    ///
    /// # Errors
    ///
    /// Returns the storage error if the save fails. In that case no timestamp
    /// is issued and the in-memory clock is restored to its prior state.
    pub fn now(&self, sink: &mut S::Sink) -> Result<Timestamp, S::Error> {
        let mut hlc = self.lock();
        let before = hlc.state();
        let ts = hlc.tick(wall_ms());
        self.persist_or_rollback(&mut hlc, before, sink)?;
        Ok(ts)
    }

    /// Merges a timestamp received from elsewhere into the clock, persisting
    /// the result only when the clock state actually changed.
    ///
    /// # Errors
    ///
    /// * [`HlcError::Skew`] if the clock rejects `received`; nothing is saved.
    /// * [`HlcError::Storage`] if saving the advanced state fails. The
    ///   in-memory clock is restored to its prior state, so retrying the same
    ///   call attempts the save again instead of silently succeeding.
    pub fn observe(
        &self,
        received: Timestamp,
        local_wall_ms: u64,
        sink: &mut S::Sink,
    ) -> Result<(), HlcError<S::Error>> {
        let mut hlc = self.lock();
        let before = hlc.state();
        hlc.observe(received, local_wall_ms)
            .map_err(HlcError::Skew)?;
        if hlc.state() != before {
            self.persist_or_rollback(&mut hlc, before, sink)
                .map_err(HlcError::Storage)?;
        }
        Ok(())
    }

    /// Returns the clock's current state without advancing or persisting it.
    pub fn state(&self) -> Timestamp {
        self.lock().state()
    }
}
#[cfg(test)]
mod tests {
    use super::super::{wall_ms, MAX_SKEW_MS};
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};

    /// In-memory storage double that records the last saved value and counts
    /// how many saves happened.
    #[derive(Default)]
    struct MemStorage {
        value: AtomicU64,
        saves: AtomicUsize,
        present: AtomicBool,
    }

    // Implemented for a shared reference so each test keeps ownership of the
    // `MemStorage` and can inspect its counters while the service uses it.
    impl HlcStorage for &MemStorage {
        type Error = std::convert::Infallible;
        type Sink = ();

        fn load(&self) -> Result<Option<u64>, Self::Error> {
            if self.present.load(Ordering::SeqCst) {
                Ok(Some(self.value.load(Ordering::SeqCst)))
            } else {
                Ok(None)
            }
        }

        fn save(&self, _sink: &mut (), raw: u64) -> Result<(), Self::Error> {
            self.value.store(raw, Ordering::SeqCst);
            self.present.store(true, Ordering::SeqCst);
            self.saves.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Each `now` call issues a larger timestamp and saves it.
    #[test]
    fn now_persists_every_tick() {
        let store = MemStorage::default();
        let clock = HlcService::open(&store).unwrap();

        let first = clock.now(&mut ()).unwrap();
        let second = clock.now(&mut ()).unwrap();

        assert!(second > first);
        assert_eq!(store.value.load(Ordering::SeqCst), second.raw());
        assert_eq!(store.saves.load(Ordering::SeqCst), 2);
    }

    /// A service reopened on the same storage issues timestamps past the last
    /// one handed out before it was dropped.
    #[test]
    fn reopen_resumes_past_persisted_state() {
        let store = MemStorage::default();

        let original = HlcService::open(&store).unwrap();
        let issued = original.now(&mut ()).unwrap();
        drop(original);

        let reopened = HlcService::open(&store).unwrap();
        let next = reopened.now(&mut ()).unwrap();
        assert!(next > issued);
    }

    /// `observe` saves when it moves the clock and stays silent when it
    /// does not.
    #[test]
    fn observe_persists_only_on_advance() {
        let store = MemStorage::default();
        let clock = HlcService::open(&store).unwrap();
        let local = wall_ms();

        // Exactly at the skew limit: accepted, and ahead of the fresh clock.
        let ahead = Timestamp::from_parts(local + MAX_SKEW_MS, 7);
        clock.observe(ahead, local, &mut ()).unwrap();
        let saves_after_advance = store.saves.load(Ordering::SeqCst);
        assert_eq!(saves_after_advance, 1, "advancing observe persists");

        // An older timestamp leaves the clock where it is.
        let stale = Timestamp::from_parts(local, 0);
        clock.observe(stale, local, &mut ()).unwrap();
        assert_eq!(store.saves.load(Ordering::SeqCst), saves_after_advance);
        assert_eq!(clock.state(), ahead);
    }

    /// One millisecond past the skew limit is rejected and changes nothing.
    #[test]
    fn observe_beyond_skew_errors_and_persists_nothing() {
        let store = MemStorage::default();
        let clock = HlcService::open(&store).unwrap();
        let local = wall_ms();

        let too_far = Timestamp::from_parts(local + MAX_SKEW_MS + 1, 0);
        let rejection = clock.observe(too_far, local, &mut ()).unwrap_err();

        assert!(matches!(rejection, HlcError::Skew(_)));
        assert_eq!(clock.state(), Timestamp::from_raw(0));
        assert_eq!(store.saves.load(Ordering::SeqCst), 0);
    }

    /// Storage double that loads nothing and fails every save.
    struct FailingStorage;

    impl HlcStorage for FailingStorage {
        type Error = &'static str;
        type Sink = ();

        fn load(&self) -> Result<Option<u64>, Self::Error> {
            Ok(None)
        }

        fn save(&self, _sink: &mut (), _raw: u64) -> Result<(), Self::Error> {
            Err("save failed")
        }
    }

    /// A failed save makes `now` fail.
    #[test]
    fn now_propagates_save_failure() {
        let clock = HlcService::open(FailingStorage).unwrap();
        let outcome = clock.now(&mut ());
        assert!(outcome.is_err());
    }

    /// A failed save during an advancing `observe` is reported as
    /// `HlcError::Storage`.
    #[test]
    fn advancing_observe_surfaces_save_failure_as_storage_error() {
        let clock = HlcService::open(FailingStorage).unwrap();
        let local = wall_ms();
        let ahead = Timestamp::from_parts(local + MAX_SKEW_MS, 1);

        let outcome = clock.observe(ahead, local, &mut ());
        assert!(matches!(outcome, Err(HlcError::Storage(_))));
    }

    /// A panic while holding the clock lock poisons the mutex; the service
    /// must keep issuing increasing timestamps afterwards.
    #[test]
    fn lock_recovers_from_a_poisoned_clock() {
        use std::panic::{catch_unwind, AssertUnwindSafe};

        let store = MemStorage::default();
        let clock = HlcService::open(&store).unwrap();
        let before_poison = clock.now(&mut ()).unwrap();

        let _ = catch_unwind(AssertUnwindSafe(|| {
            let _guard = clock.lock();
            panic!("poison the clock");
        }));

        let after_poison = clock.now(&mut ()).unwrap();
        assert!(after_poison > before_poison);
    }

    /// Eight threads taking one hundred ticks each never receive the same
    /// timestamp twice.
    #[test]
    fn concurrent_ticks_are_all_distinct() {
        let store = MemStorage::default();
        let clock = HlcService::open(&store).unwrap();

        let mut stamps: Vec<Timestamp> = std::thread::scope(|scope| {
            let mut workers = Vec::new();
            for _ in 0..8 {
                workers.push(scope.spawn(|| {
                    let mut issued = Vec::new();
                    for _ in 0..100 {
                        issued.push(clock.now(&mut ()).unwrap());
                    }
                    issued
                }));
            }

            let mut merged = Vec::new();
            for worker in workers {
                merged.extend(worker.join().unwrap());
            }
            merged
        });

        let issued_total = stamps.len();
        stamps.sort();
        stamps.dedup();
        assert_eq!(
            stamps.len(),
            issued_total,
            "every concurrent tick must be unique"
        );
    }
}