use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use transformable::{utils::*, Transformable};
/// A logical timestamp issued by a [`LamportClock`].
///
/// This is a transparent newtype over `u64`: it has the same layout
/// (`#[repr(transparent)]`), and equality, ordering, hashing and the
/// `Default` value (`0`) all come straight from the wrapped integer.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
#[repr(transparent)]
pub struct LamportTime(pub(crate) u64);

impl core::fmt::Display for LamportTime {
    /// Formats the time exactly as the underlying `u64` would be formatted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Delegate to `u64`'s `Display` using the caller's formatter rather than
        // `write!(f, "{}", self.0)`: `write!` starts a fresh format spec, which
        // silently drops width/fill/alignment/sign flags such as `{:>8}`.
        core::fmt::Display::fmt(&self.0, f)
    }
}

impl From<u64> for LamportTime {
    /// Wraps a raw counter value; same as [`LamportTime::new`].
    #[inline]
    fn from(raw: u64) -> Self {
        Self::new(raw)
    }
}

impl From<LamportTime> for u64 {
    /// Unwraps the raw counter value.
    #[inline]
    fn from(LamportTime(raw): LamportTime) -> Self {
        raw
    }
}

impl LamportTime {
    /// Time zero; equal to `LamportTime::default()`.
    pub const ZERO: Self = Self(0);

    /// Wraps a raw counter value.
    #[inline]
    pub const fn new(time: u64) -> Self {
        Self(time)
    }

    /// Returns the counter as an 8-byte big-endian array.
    #[inline]
    pub const fn to_be_bytes(self) -> [u8; 8] {
        self.0.to_be_bytes()
    }

    /// Returns the counter as an 8-byte little-endian array.
    #[inline]
    pub const fn to_le_bytes(self) -> [u8; 8] {
        self.0.to_le_bytes()
    }

    /// Rebuilds a time from the output of [`LamportTime::to_be_bytes`].
    #[inline]
    pub const fn from_be_bytes(bytes: [u8; 8]) -> Self {
        Self(u64::from_be_bytes(bytes))
    }

    /// Rebuilds a time from the output of [`LamportTime::to_le_bytes`].
    #[inline]
    pub const fn from_le_bytes(bytes: [u8; 8]) -> Self {
        Self(u64::from_le_bytes(bytes))
    }
}

impl core::ops::Add<Self> for LamportTime {
    type Output = Self;

    /// Adds the raw counters.
    ///
    /// # Panics
    ///
    /// Follows plain `u64` `+`: panics on overflow when overflow checks are
    /// enabled (the default in debug builds) and wraps otherwise.
    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        Self(self.0 + rhs.0)
    }
}

impl core::ops::Sub<Self> for LamportTime {
    type Output = Self;

    /// Subtracts the raw counters.
    ///
    /// # Panics
    ///
    /// Follows plain `u64` `-`: panics on underflow (`rhs > self`) when
    /// overflow checks are enabled and wraps otherwise.
    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        Self(self.0 - rhs.0)
    }
}

impl core::ops::Rem<Self> for LamportTime {
    type Output = Self;

    /// Remainder of the raw counters.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` is zero, like `u64` `%`.
    #[inline]
    fn rem(self, rhs: Self) -> Self::Output {
        Self(self.0 % rhs.0)
    }
}
/// Error type for the `Transformable` implementation on [`LamportTime`].
///
/// Both variants are `#[error(transparent)]`, so their `Display` output and
/// `source()` are forwarded unchanged from the wrapped error. The `#[from]`
/// attributes provide the `From` conversions used when mapping the varint
/// helpers' errors into this type.
#[derive(thiserror::Error, Debug)]
pub enum LamportTimeTransformError {
    /// Failure while varint-encoding a time; presumably the destination
    /// buffer was too small for the encoded value — confirm against the
    /// `transformable` crate's `InsufficientBuffer` docs.
    #[error(transparent)]
    Encode(#[from] InsufficientBuffer),
    /// Failure while varint-decoding a time from the source bytes.
    #[error(transparent)]
    Decode(#[from] DecodeVarintError),
}
impl Transformable for LamportTime {
type Error = LamportTimeTransformError;
fn encode(&self, dst: &mut [u8]) -> Result<usize, Self::Error> {
encode_u64_varint(self.0, dst).map_err(Into::into)
}
fn encoded_len(&self) -> usize {
encoded_u64_varint_len(self.0)
}
fn decode(src: &[u8]) -> Result<(usize, Self), Self::Error>
where
Self: Sized,
{
decode_u64_varint(src)
.map(|(n, time)| (n, Self(time)))
.map_err(Into::into)
}
}
/// A thread-safe Lamport logical clock.
///
/// The counter lives behind an `Arc<AtomicU64>`, so cloning a `LamportClock`
/// is cheap and every clone shares, observes and advances the same time.
/// All operations use `SeqCst` ordering.
#[derive(Debug, Clone)]
pub struct LamportClock(Arc<AtomicU64>);

impl Default for LamportClock {
    /// Same as [`LamportClock::new`]: a clock whose current time is 0.
    fn default() -> Self {
        Self::new()
    }
}

impl LamportClock {
    /// Creates a new clock whose current time is 0.
    #[inline]
    pub fn new() -> Self {
        Self(Arc::new(AtomicU64::new(0)))
    }

    /// Returns the current time without advancing the clock.
    #[inline]
    pub fn time(&self) -> LamportTime {
        LamportTime(self.0.load(Ordering::SeqCst))
    }

    /// Advances the clock by one and returns the new time.
    ///
    /// The counter saturates at `u64::MAX` rather than wrapping back to 0,
    /// so the clock never moves backwards.
    #[inline]
    pub fn increment(&self) -> LamportTime {
        // `fetch_add` would wrap the shared counter from `u64::MAX` to 0 (and
        // `+ 1` on its result would overflow). The closure never returns
        // `None`, so `fetch_update` always yields `Ok`; both arms carry the
        // previous value regardless.
        let previous = match self.0.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |t| {
            Some(t.saturating_add(1))
        }) {
            Ok(prev) | Err(prev) => prev,
        };
        LamportTime(previous.saturating_add(1))
    }

    /// Merges an externally observed time into this clock.
    ///
    /// If `time` is not older than the local clock, the clock moves to
    /// `time + 1`, so later local events order after it. Otherwise the clock
    /// is left unchanged. The result saturates at `u64::MAX`, so an extreme
    /// `time` can neither panic nor wrap the clock back to 0.
    #[inline]
    pub fn witness(&self, time: LamportTime) {
        // `fetch_max(time + 1)` is the atomic form of "if time >= current,
        // set current = time + 1": when time < current, then time + 1 <= current
        // and the max leaves the clock untouched. This replaces the manual
        // load + compare_exchange retry loop.
        self.0.fetch_max(time.0.saturating_add(1), Ordering::SeqCst);
    }
}
#[cfg(test)]
impl LamportTime {
    /// Test helper: draws a time from the half-open range `0..u64::MAX`
    /// using the thread-local RNG; `u64::MAX` itself is never produced.
    pub(crate) fn random() -> Self {
        use rand::Rng;

        let mut rng = rand::thread_rng();
        let raw: u64 = rng.gen_range(0..u64::MAX);
        LamportTime(raw)
    }
}
#[test]
fn test_lamport_clock() {
    let clock = LamportClock::new();

    // A fresh clock starts at zero; `increment` returns the advanced time.
    assert_eq!(clock.time(), LamportTime::new(0));
    assert_eq!(clock.increment(), LamportTime::new(1));
    assert_eq!(clock.time(), LamportTime::new(1));

    // Witnessing a newer time jumps one past it; re-witnessing the same time
    // or witnessing an older one leaves the clock where it is.
    let cases: [(u64, u64); 3] = [(41, 42), (41, 42), (30, 42)];
    for &(seen, expected) in cases.iter() {
        clock.witness(LamportTime::new(seen));
        assert_eq!(clock.time(), LamportTime::new(expected));
    }
}