use core::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Default epoch: 2026-01-01T00:00:00Z, expressed in milliseconds since the Unix epoch.
pub const DEFAULT_EPOCH_MS: u64 = 1_767_225_600_000;
/// Width of the per-millisecond sequence field (the lowest bits of an ID).
pub const SEQUENCE_BITS: u32 = 12;
/// Width of the worker-id field (directly above the sequence field).
pub const WORKER_BITS: u32 = 10;
/// Width of the millisecond timestamp-offset field (the highest bits of an ID).
pub const TIMESTAMP_BITS: u32 = 41;
const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1;
const WORKER_MASK: u64 = (1 << WORKER_BITS) - 1;
const TIMESTAMP_MASK: u64 = (1 << TIMESTAMP_BITS) - 1;
const WORKER_SHIFT: u32 = SEQUENCE_BITS;
const TIMESTAMP_SHIFT: u32 = SEQUENCE_BITS + WORKER_BITS;
// The generator state stores the *next* sequence number to hand out, which
// ranges over 0..=SEQUENCE_MASK + 1 (the extra value marks an exhausted
// millisecond). That needs exactly one bit more than the ID's sequence field,
// so derive it instead of hard-coding it.
const STATE_SEQ_BITS: u32 = SEQUENCE_BITS + 1;
const STATE_SEQ_MASK: u64 = (1 << STATE_SEQ_BITS) - 1;
const STATE_SEQ_EXHAUSTED: u64 = SEQUENCE_MASK + 1;

// Compile-time guards for the bit layout above; editing a width so that one
// of these breaks is a build error rather than silent ID corruption.
// IDs occupy at most 63 bits, leaving the top bit clear.
const _: () = assert!(TIMESTAMP_BITS + WORKER_BITS + SEQUENCE_BITS <= 63);
// The packed state word (timestamp offset + next sequence) fits in a u64.
const _: () = assert!(TIMESTAMP_BITS + STATE_SEQ_BITS <= 64);
// Worker ids and sequence numbers are exposed as u16.
const _: () = assert!(WORKER_BITS <= 16 && SEQUENCE_BITS <= 16);
// The exhausted marker is representable in the state's sequence field.
const _: () = assert!(STATE_SEQ_EXHAUSTED <= STATE_SEQ_MASK);
/// Generator of 64-bit snowflake IDs.
///
/// From most to least significant bit, an ID holds the millisecond offset
/// from `epoch_ms` (`TIMESTAMP_BITS` wide), the worker id (`WORKER_BITS`
/// wide) and a per-millisecond sequence number (`SEQUENCE_BITS` wide).
/// All mutable state lives in one atomic word updated by compare-and-swap,
/// so a single instance can be shared between threads.
#[derive(Debug)]
pub struct Snowflake {
    // Worker id, already reduced to its low `WORKER_BITS` bits.
    worker_id: u16,
    // Unix time in milliseconds that timestamp offsets are measured from.
    epoch_ms: u64,
    // Packed as `(last_ms << STATE_SEQ_BITS) | next_seq`: the offset of the
    // most recent ID and the sequence number the next ID in that millisecond
    // will receive (`STATE_SEQ_EXHAUSTED` once the millisecond is used up).
    state: AtomicU64,
}

impl Snowflake {
    /// Creates a generator for `worker_id` whose epoch is [`DEFAULT_EPOCH_MS`].
    ///
    /// Only the low `WORKER_BITS` bits of `worker_id` are kept.
    pub const fn new(worker_id: u16) -> Self {
        Self::with_epoch(worker_id, DEFAULT_EPOCH_MS)
    }

    /// Creates a generator for `worker_id` whose timestamps count
    /// milliseconds since `epoch_ms` (Unix time in milliseconds).
    ///
    /// Only the low `WORKER_BITS` bits of `worker_id` are kept; higher bits
    /// are discarded without error.
    pub const fn with_epoch(worker_id: u16, epoch_ms: u64) -> Self {
        let clamped_worker = (worker_id as u64 & WORKER_MASK) as u16;
        Self {
            worker_id: clamped_worker,
            epoch_ms,
            state: AtomicU64::new(0),
        }
    }

    /// Returns the worker id embedded in every ID from this generator.
    pub const fn worker_id(&self) -> u16 {
        self.worker_id
    }

    /// Returns the epoch (Unix milliseconds) that timestamp offsets are relative to.
    pub const fn epoch_ms(&self) -> u64 {
        self.epoch_ms
    }

    /// Issues the next unique ID.
    ///
    /// If every sequence number of the current millisecond has already been
    /// issued, this sleeps until the clock moves on and then retries.
    ///
    /// # Errors
    ///
    /// Returns [`ClockSkew`] when the clock reads an earlier millisecond
    /// offset than the one recorded in the generator state.
    pub fn try_next_id(&self) -> Result<u64, ClockSkew> {
        loop {
            let snapshot = self.state.load(Ordering::Acquire);
            let last_ms = snapshot >> STATE_SEQ_BITS;
            let pending_seq = snapshot & STATE_SEQ_MASK;
            let now_ms = current_offset_ms(self.epoch_ms);

            if now_ms < last_ms {
                return Err(ClockSkew { last_ms, now_ms });
            }

            // Pick the (millisecond, sequence) pair this call will try to claim.
            let (ms, seq) = if now_ms > last_ms {
                // A fresh millisecond always starts at sequence 0.
                (now_ms, 0)
            } else if pending_seq < STATE_SEQ_EXHAUSTED {
                (last_ms, pending_seq)
            } else {
                // This millisecond is used up: wait for the clock, then re-read.
                sleep_until_after(self.epoch_ms, last_ms);
                continue;
            };

            // Publish that `seq` is taken; whoever comes next gets `seq + 1`.
            let successor = (ms << STATE_SEQ_BITS) | (seq + 1);
            let won = self
                .state
                .compare_exchange(snapshot, successor, Ordering::AcqRel, Ordering::Acquire)
                .is_ok();
            if won {
                let timestamp_part = ms << TIMESTAMP_SHIFT;
                let worker_part = u64::from(self.worker_id) << WORKER_SHIFT;
                return Ok(timestamp_part | worker_part | seq);
            }
            // Another thread updated the state first; retry from a new snapshot.
        }
    }

    /// Issues the next unique ID.
    ///
    /// # Panics
    ///
    /// Panics when [`Snowflake::try_next_id`] reports a [`ClockSkew`].
    pub fn next_id(&self) -> u64 {
        self.try_next_id()
            .unwrap_or_else(|e| panic!("snowflake: clock moved backward ({e})"))
    }

    /// Splits `id` into `(timestamp_offset_ms, worker_id, sequence)`.
    ///
    /// The timestamp is relative to the generator's epoch; add
    /// [`Snowflake::epoch_ms`] to get Unix milliseconds.
    pub const fn parts(id: u64) -> (u64, u16, u16) {
        let sequence = (id & SEQUENCE_MASK) as u16;
        let worker = ((id >> WORKER_SHIFT) & WORKER_MASK) as u16;
        let timestamp_offset = (id >> TIMESTAMP_SHIFT) & TIMESTAMP_MASK;
        (timestamp_offset, worker, sequence)
    }
}
/// Error from [`Snowflake::try_next_id`]: the clock reads an earlier
/// millisecond than the one recorded for the last issued ID.
///
/// Both fields are millisecond offsets from the generator's epoch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClockSkew {
    /// Offset recorded in the generator state (the last issued millisecond).
    pub last_ms: u64,
    /// Offset read from the clock when the skew was detected.
    pub now_ms: u64,
}

impl fmt::Display for ClockSkew {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ClockSkew { last_ms, now_ms } = self;
        write!(
            f,
            "clock moved backward: last issued at offset {} ms, now at offset {} ms",
            last_ms, now_ms
        )
    }
}

impl std::error::Error for ClockSkew {}
/// Returns the current wall-clock time as a millisecond offset from
/// `epoch_ms`, reduced modulo 2^`TIMESTAMP_BITS` by masking.
///
/// A system time before the Unix epoch is treated as 0 ms, and any time
/// before `epoch_ms` saturates to offset 0.
fn current_offset_ms(epoch_ms: u64) -> u64 {
    let unix_ms = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_unix) => since_unix.as_millis() as u64,
        Err(_) => 0,
    };
    let offset = unix_ms.saturating_sub(epoch_ms);
    offset & TIMESTAMP_MASK
}
/// Sleeps in 100 µs steps until the clock's offset from `epoch_ms` no longer
/// equals `last_ms`.
///
/// Used when every sequence number of `last_ms` has been issued. This
/// returns as soon as the clock leaves `last_ms` in *either* direction.
/// `SystemTime` is not monotonic, so the clock may have stepped backward.
/// In that case the caller's next check sees `now < last_ms` and reports
/// [`ClockSkew`] right away, matching every other path, instead of this
/// function silently blocking for however long the clock takes to catch up.
///
/// NOTE: if the clock stays pinned at `last_ms` (for instance a clock set
/// before `epoch_ms`, which `current_offset_ms` saturates to offset 0),
/// this keeps waiting until it changes.
fn sleep_until_after(epoch_ms: u64, last_ms: u64) {
    while current_offset_ms(epoch_ms) == last_ms {
        std::thread::sleep(Duration::from_micros(100));
    }
}
#[cfg(test)]
mod tests {
    // Generators are bound to `snowflake` rather than `gen`: `gen` is a
    // reserved keyword from the Rust 2024 edition onward.
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::thread;
    use std::time::Instant;

    #[test]
    fn next_id_produces_value() {
        let snowflake = Snowflake::new(1);
        assert!(snowflake.next_id() > 0);
    }

    #[test]
    fn worker_id_clamped() {
        let snowflake = Snowflake::new(0xffff);
        assert_eq!(snowflake.worker_id(), 0x3ff);
        let id = snowflake.next_id();
        assert_eq!(Snowflake::parts(id).1, 0x3ff);
    }

    #[test]
    fn worker_field_extracts() {
        let snowflake = Snowflake::new(42);
        let id = snowflake.next_id();
        let (_, worker, _) = Snowflake::parts(id);
        assert_eq!(worker, 42);
    }

    #[test]
    fn monotonic_in_burst() {
        let snowflake = Snowflake::new(1);
        let mut prev = snowflake.next_id();
        for _ in 0..10_000 {
            let cur = snowflake.next_id();
            assert!(cur > prev, "expected {cur} > {prev}");
            prev = cur;
        }
    }

    #[test]
    fn all_unique_in_burst() {
        let snowflake = Snowflake::new(1);
        let mut set = HashSet::new();
        for _ in 0..50_000 {
            let id = snowflake.next_id();
            assert!(set.insert(id));
        }
    }

    #[test]
    fn parts_round_trip() {
        let snowflake = Snowflake::with_epoch(7, DEFAULT_EPOCH_MS);
        let id = snowflake.next_id();
        let (ts, worker, seq) = Snowflake::parts(id);
        assert_eq!(worker, 7);
        let reassembled = (ts << TIMESTAMP_SHIFT) | ((worker as u64) << WORKER_SHIFT) | seq as u64;
        assert_eq!(reassembled, id);
    }

    #[test]
    fn sequence_resets_each_ms() {
        let snowflake = Snowflake::new(1);
        let _ = snowflake.next_id();
        thread::sleep(Duration::from_millis(3));
        let id_after_sleep = snowflake.next_id();
        let (_, _, seq) = Snowflake::parts(id_after_sleep);
        assert_eq!(seq, 0, "first ID of a fresh ms must have sequence 0");
    }

    #[test]
    fn sequence_exhaustion_blocks_until_next_ms() {
        let snowflake = Snowflake::new(1);
        let now = current_offset_ms(snowflake.epoch_ms);
        // Mark every sequence number of the current millisecond as used.
        let exhausted_state = (now << STATE_SEQ_BITS) | STATE_SEQ_EXHAUSTED;
        snowflake.state.store(exhausted_state, Ordering::Release);
        // Measure with the monotonic `Instant`: a wall-clock adjustment
        // mid-test cannot make the elapsed time fail or go negative.
        let start = Instant::now();
        let id = snowflake.next_id();
        let elapsed = start.elapsed();
        let (ts, _, seq) = Snowflake::parts(id);
        assert!(ts > now, "new ID must be in a later millisecond");
        assert_eq!(seq, 0);
        assert!(
            elapsed < Duration::from_millis(50),
            "block should be roughly one ms, got {elapsed:?}"
        );
    }

    #[test]
    fn clock_skew_reported_via_result() {
        let snowflake = Snowflake::new(1);
        let future_ms = current_offset_ms(snowflake.epoch_ms) + 5_000;
        snowflake
            .state
            .store(future_ms << STATE_SEQ_BITS, Ordering::Release);
        match snowflake.try_next_id() {
            Err(ClockSkew { last_ms, now_ms }) => {
                assert_eq!(last_ms, future_ms);
                assert!(now_ms < last_ms);
            }
            Ok(id) => panic!("expected ClockSkew, got id {id}"),
        }
    }

    #[test]
    #[should_panic(expected = "clock moved backward")]
    fn next_id_panics_on_clock_skew() {
        let snowflake = Snowflake::new(1);
        let future_ms = current_offset_ms(snowflake.epoch_ms) + 5_000;
        snowflake
            .state
            .store(future_ms << STATE_SEQ_BITS, Ordering::Release);
        let _ = snowflake.next_id();
    }

    #[test]
    fn multi_thread_all_unique() {
        let shared = Arc::new(Snowflake::new(3));
        let mut handles = Vec::new();
        for _ in 0..8 {
            let g = Arc::clone(&shared);
            handles.push(thread::spawn(move || {
                let mut local = Vec::with_capacity(2000);
                for _ in 0..2000 {
                    local.push(g.next_id());
                }
                local
            }));
        }
        let mut all = HashSet::new();
        for h in handles {
            for id in h.join().unwrap() {
                assert!(all.insert(id), "duplicate id under thread contention");
            }
        }
        assert_eq!(all.len(), 8 * 2000);
    }

    #[test]
    fn custom_epoch_round_trip() {
        let epoch = 1_700_000_000_000_u64;
        let snowflake = Snowflake::with_epoch(9, epoch);
        let id = snowflake.next_id();
        let (ts_offset, worker, _) = Snowflake::parts(id);
        assert_eq!(worker, 9);
        assert_eq!(snowflake.epoch_ms(), epoch);
        let wall = ts_offset + epoch;
        assert!(wall > epoch);
    }

    #[test]
    fn parts_extracts_each_field() {
        let ts: u64 = 12_345;
        let worker: u64 = 700;
        let seq: u64 = 4000;
        let id = (ts << TIMESTAMP_SHIFT) | (worker << WORKER_SHIFT) | seq;
        let (got_ts, got_w, got_s) = Snowflake::parts(id);
        assert_eq!(got_ts, ts);
        assert_eq!(got_w as u64, worker);
        assert_eq!(got_s as u64, seq);
    }
}