use serde::{Deserialize, Serialize};
use std::{
cmp::Ordering,
time::{SystemTime, UNIX_EPOCH},
};
/// A hybrid logical clock timestamp.
///
/// Timestamps are totally ordered by `physical_ms`, then `logical`, then
/// `node_id` (see the `Ord` impl), and are displayed as
/// `physical_ms:logical:node_id`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HlcTimestamp {
    /// Physical component in milliseconds; clocks in this module derive it
    /// from the system time since the Unix epoch.
    pub physical_ms: u64,
    /// Logical counter that orders events sharing the same physical time.
    pub logical: u32,
    /// Id of the node that issued the timestamp; final ordering tie-breaker.
    pub node_id: u32,
}
impl HlcTimestamp {
pub fn new(physical_ms: u64, logical: u32, node_id: u32) -> Self {
Self {
physical_ms,
logical,
node_id,
}
}
pub fn zero() -> Self {
Self {
physical_ms: 0,
logical: 0,
node_id: 0,
}
}
pub fn to_u128(&self) -> u128 {
(u128::from(self.physical_ms) << 64)
| (u128::from(self.logical) << 32)
| u128::from(self.node_id)
}
pub fn from_u128(val: u128) -> Self {
Self {
physical_ms: (val >> 64) as u64,
logical: ((val >> 32) & 0xFFFF_FFFF) as u32,
node_id: (val & 0xFFFF_FFFF) as u32,
}
}
}
impl Ord for HlcTimestamp {
    /// Lexicographic order on `(physical_ms, logical, node_id)`.
    ///
    /// `node_id` breaks ties between distinct nodes, so the order is total.
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs = (self.physical_ms, self.logical, self.node_id);
        let rhs = (other.physical_ms, other.logical, other.node_id);
        lhs.cmp(&rhs)
    }
}
impl PartialOrd for HlcTimestamp {
    /// Delegates to the total order from [`Ord`], so this never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let total = Ord::cmp(self, other);
        Some(total)
    }
}
impl std::fmt::Display for HlcTimestamp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}:{}", self.physical_ms, self.logical, self.node_id)
}
}
/// A hybrid logical clock owned by a single node.
///
/// Combines the local wall clock with a logical counter and merges timestamps
/// received from other nodes. The state sits behind a mutex, so one clock can
/// be shared across threads.
pub struct HybridLogicalClock {
    /// Node id stamped into every timestamp this clock issues.
    node_id: u32,
    /// How far (in ms) a received timestamp's physical time may lead the local
    /// wall clock before `receive` rejects it.
    max_drift_ms: u64,
    /// Last issued `(physical_ms, logical)` pair.
    state: std::sync::Mutex<(u64, u32)>,
}
impl HybridLogicalClock {
    /// Drift tolerance used by [`HybridLogicalClock::new`]: 60 seconds.
    pub const DEFAULT_MAX_DRIFT_MS: u64 = 60_000;

    /// Creates a clock for `node_id` with the default drift tolerance
    /// ([`Self::DEFAULT_MAX_DRIFT_MS`]).
    pub fn new(node_id: u32) -> Self {
        Self::with_max_drift(node_id, Self::DEFAULT_MAX_DRIFT_MS)
    }

    /// Creates a clock for `node_id` that rejects remote timestamps whose
    /// physical time is more than `max_drift_ms` ahead of the local wall clock.
    ///
    /// Passing `u64::MAX` effectively disables the drift check.
    pub fn with_max_drift(node_id: u32, max_drift_ms: u64) -> Self {
        Self {
            node_id,
            state: std::sync::Mutex::new((0, 0)),
            max_drift_ms,
        }
    }

    /// Returns the node id stamped into every timestamp this clock issues.
    pub fn node_id(&self) -> u32 {
        self.node_id
    }

    /// Milliseconds since the Unix epoch according to the system clock.
    ///
    /// A system clock set before the epoch yields `0` instead of an error.
    fn wall_ms() -> u64 {
        // The `as` cast only truncates ~584 million years after the epoch.
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }

    /// Locks the `(physical_ms, logical)` state.
    ///
    /// The state is always replaced with a single assignment, so a guard
    /// recovered from a poisoned mutex still holds a consistent value. One
    /// panicking thread therefore cannot leave the clock permanently unusable.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, (u64, u32)> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns the smallest `(physical_ms, logical)` pair that orders strictly
    /// after the given one.
    ///
    /// Normally that is the same physical time with `logical + 1`. When the
    /// logical counter is already `u32::MAX`, the physical time advances by one
    /// millisecond and the counter restarts at zero. A plain `+ 1` would panic
    /// in debug builds and wrap to a *smaller* timestamp in release builds.
    /// At `(u64::MAX, u32::MAX)` no successor exists and the physical part
    /// saturates.
    fn successor(physical_ms: u64, logical: u32) -> (u64, u32) {
        match logical.checked_add(1) {
            Some(next) => (physical_ms, next),
            None => (physical_ms.saturating_add(1), 0),
        }
    }

    /// Issues a timestamp for a local or send event.
    ///
    /// Barring saturation at `(u64::MAX, u32::MAX)`, the result is strictly
    /// greater than every timestamp this clock has previously issued.
    pub fn now(&self) -> HlcTimestamp {
        // Read the wall clock before locking to keep the critical section short.
        // A stale reading is harmless because it is compared against the state.
        let wall = Self::wall_ms();
        let mut guard = self.lock_state();
        let (last_pt, last_l) = *guard;
        let (new_pt, new_l) = if wall > last_pt {
            (wall, 0)
        } else {
            Self::successor(last_pt, last_l)
        };
        *guard = (new_pt, new_l);
        HlcTimestamp::new(new_pt, new_l, self.node_id)
    }

    /// Merges a timestamp received from another node and issues a timestamp
    /// for the receive event.
    ///
    /// Barring saturation at `(u64::MAX, u32::MAX)`, the result is strictly
    /// greater than both `remote` and every timestamp this clock has
    /// previously issued.
    ///
    /// # Errors
    ///
    /// Returns [`HlcDriftError`], leaving the clock unchanged, if
    /// `remote.physical_ms` is more than `max_drift_ms` ahead of the local
    /// wall clock.
    pub fn receive(&self, remote: &HlcTimestamp) -> Result<HlcTimestamp, HlcDriftError> {
        let wall = Self::wall_ms();
        // Saturate so a very large tolerance (e.g. `u64::MAX`) cannot overflow.
        if remote.physical_ms > wall.saturating_add(self.max_drift_ms) {
            return Err(HlcDriftError {
                remote_ms: remote.physical_ms,
                local_wall_ms: wall,
                max_drift_ms: self.max_drift_ms,
            });
        }
        let mut guard = self.lock_state();
        let (last_pt, last_l) = *guard;
        let merged_pt = wall.max(last_pt).max(remote.physical_ms);
        // The logical counter continues from whichever source(s) supplied the
        // winning physical time. It restarts at zero when the local wall clock
        // is strictly ahead of both the local state and the remote timestamp.
        let (new_pt, new_l) = match (merged_pt == last_pt, merged_pt == remote.physical_ms) {
            (true, true) => Self::successor(merged_pt, last_l.max(remote.logical)),
            (true, false) => Self::successor(merged_pt, last_l),
            (false, true) => Self::successor(merged_pt, remote.logical),
            (false, false) => (merged_pt, 0),
        };
        *guard = (new_pt, new_l);
        Ok(HlcTimestamp::new(new_pt, new_l, self.node_id))
    }

    /// Returns the most recently issued timestamp without advancing the clock.
    ///
    /// Before any timestamp has been issued, this is `0:0:node_id`.
    pub fn current(&self) -> HlcTimestamp {
        let (pt, l) = *self.lock_state();
        HlcTimestamp::new(pt, l, self.node_id)
    }
}
/// Error returned when a received timestamp's physical time is too far ahead
/// of the local wall clock.
#[derive(Clone, Debug)]
pub struct HlcDriftError {
    /// Physical time (ms) carried by the rejected remote timestamp.
    pub remote_ms: u64,
    /// Local wall-clock reading (ms) taken when the timestamp was checked.
    pub local_wall_ms: u64,
    /// Drift tolerance (ms) the clock was configured with.
    pub max_drift_ms: u64,
}
impl std::fmt::Display for HlcDriftError {
    /// Human-readable description listing the remote time, the local wall time
    /// and the configured tolerance (all in milliseconds).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "HLC drift violation: remote={}, local_wall={}, max_drift={}",
            self.remote_ms, self.local_wall_ms, self.max_drift_ms
        ))
    }
}

/// Standard error impl; there is no underlying source error.
impl std::error::Error for HlcDriftError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hlc_now_monotonic() {
        let hlc = HybridLogicalClock::new(1);
        let t1 = hlc.now();
        let t2 = hlc.now();
        let t3 = hlc.now();
        assert!(t1 < t2);
        assert!(t2 < t3);
    }

    #[test]
    fn test_hlc_now_uses_wall_clock() {
        let hlc = HybridLogicalClock::new(1);
        let before = HybridLogicalClock::wall_ms();
        let ts = hlc.now();
        let after = HybridLogicalClock::wall_ms();
        assert!(ts.physical_ms >= before);
        // Allow 1 ms of slack on the upper bound.
        assert!(ts.physical_ms <= after + 1);
    }

    #[test]
    fn test_hlc_receive_advances_past_remote() {
        let hlc = HybridLogicalClock::new(1);
        let _ = hlc.now();
        let remote = HlcTimestamp::new(HybridLogicalClock::wall_ms() + 100, 5, 2);
        let ts = hlc.receive(&remote).unwrap();
        assert!(ts > remote);
    }

    #[test]
    fn test_hlc_receive_drift_rejected() {
        let hlc = HybridLogicalClock::with_max_drift(1, 1000);
        let remote = HlcTimestamp::new(HybridLogicalClock::wall_ms() + 5000, 0, 2);
        assert!(hlc.receive(&remote).is_err());
    }

    #[test]
    fn test_hlc_receive_same_physical_time() {
        let hlc = HybridLogicalClock::new(1);
        // Use a physical time slightly ahead of the local wall clock, well
        // within the default 60 s drift tolerance, so the remote physical time
        // deterministically wins the merge. Using the current wall time made
        // this test flaky: if the millisecond ticked over before `receive` read
        // the clock, the local wall time won and the logical counter reset to 0.
        let pt = HybridLogicalClock::wall_ms() + 1_000;
        let remote1 = HlcTimestamp::new(pt, 5, 2);
        let ts1 = hlc.receive(&remote1).unwrap();
        assert_eq!(ts1.physical_ms, pt);
        assert!(ts1.logical > 5);
        let remote2 = HlcTimestamp::new(pt, 10, 3);
        let ts2 = hlc.receive(&remote2).unwrap();
        assert!(ts2 > ts1);
        assert!(ts2 > remote2);
    }

    #[test]
    fn test_hlc_timestamp_ordering() {
        let a = HlcTimestamp::new(100, 0, 1);
        let b = HlcTimestamp::new(100, 1, 1);
        let c = HlcTimestamp::new(101, 0, 1);
        let d = HlcTimestamp::new(100, 0, 2);
        assert!(a < b);
        assert!(b < c);
        // Same physical time and counter: the node id breaks the tie.
        assert!(a < d);
    }

    #[test]
    fn test_hlc_timestamp_u128_roundtrip() {
        let ts = HlcTimestamp::new(1_700_000_000_000, 42, 7);
        let encoded = ts.to_u128();
        let decoded = HlcTimestamp::from_u128(encoded);
        assert_eq!(ts, decoded);
    }

    #[test]
    fn test_hlc_timestamp_display() {
        let ts = HlcTimestamp::new(1000, 5, 3);
        assert_eq!(ts.to_string(), "1000:5:3");
    }

    #[test]
    fn test_hlc_zero() {
        let z = HlcTimestamp::zero();
        assert_eq!(z.physical_ms, 0);
        assert_eq!(z.logical, 0);
        assert_eq!(z.node_id, 0);
    }

    #[test]
    fn test_hlc_concurrent_access() {
        use std::sync::Arc;
        let hlc = Arc::new(HybridLogicalClock::new(1));
        let mut handles = vec![];
        for _ in 0..10 {
            let hlc = Arc::clone(&hlc);
            handles.push(std::thread::spawn(move || {
                let mut timestamps = Vec::new();
                for _ in 0..100 {
                    timestamps.push(hlc.now());
                }
                timestamps
            }));
        }
        let mut all: Vec<HlcTimestamp> = vec![];
        for h in handles {
            all.extend(h.join().unwrap());
        }
        let mut sorted = all.clone();
        sorted.sort();
        sorted.dedup();
        assert_eq!(
            sorted.len(),
            all.len(),
            "HLC must produce unique timestamps under contention"
        );
    }
}