use crate::types::MachineId;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::atomic::{AtomicI32, AtomicI64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Errors that can be returned while generating a [`Snowflake`].
#[derive(Debug, thiserror::Error)]
pub enum SnowflakeError {
/// The wall clock reads more than `max_drift_ms` milliseconds earlier than
/// the timestamp of the most recently issued ID (`drift_ms` is the observed gap).
#[error(
"system clock jumped backward by {drift_ms}ms (>{max_drift_ms}ms max) — check NTP configuration"
)]
ClockDriftExceeded { drift_ms: i64, max_drift_ms: i64 },
}
/// Start of the ID epoch, in milliseconds since the Unix epoch
/// (2025-01-01T00:00:00Z). Subtracted from the wall-clock time before the
/// timestamp is packed into an ID.
const CUSTOM_EPOCH_MS: i64 = 1_735_689_600_000;
/// Largest backward clock jump, in milliseconds, that generation waits out;
/// a larger jump is reported as `SnowflakeError::ClockDriftExceeded`.
const MAX_CLOCK_DRIFT_MS: i64 = 5_000;
/// Number of bits an ID reserves for the machine ID.
const MACHINE_ID_BITS: u32 = 10;
/// Number of bits an ID reserves for the per-millisecond sequence number.
const SEQUENCE_BITS: u32 = 12;
/// Left shift that places the machine ID directly above the sequence bits.
const MACHINE_ID_SHIFT: u32 = SEQUENCE_BITS;
/// Left shift that places the timestamp above the machine ID and sequence bits.
const TIMESTAMP_SHIFT: u32 = MACHINE_ID_BITS + SEQUENCE_BITS;
/// Mask selecting the low `SEQUENCE_BITS` bits; also the largest valid sequence number.
const SEQUENCE_MASK: i64 = (1 << SEQUENCE_BITS) - 1;
/// A 64-bit ID wrapping a raw `i64`.
///
/// IDs produced by `SnowflakeGenerator` are laid out, from most to least
/// significant bits, as: milliseconds since `CUSTOM_EPOCH_MS`, the machine ID
/// (`MACHINE_ID_BITS` bits), and the sequence number (`SEQUENCE_BITS` bits).
/// The derived ordering and equality compare the raw integer.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct Snowflake(pub i64);
impl fmt::Display for Snowflake {
    /// Formats the ID as the decimal representation of its raw `i64` value.
    ///
    /// The call is forwarded to the inner integer's `Display` implementation
    /// rather than going through `write!`, so formatter flags supplied by the
    /// caller (width, fill, alignment, zero-padding, e.g. `{:>20}` or
    /// `{:020}`) are honoured instead of being silently dropped. Plain `{}`
    /// output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
/// The three fields decoded from a [`Snowflake`] by `Snowflake::parts`.
#[derive(Debug, Clone, Copy)]
pub struct SnowflakeParts {
/// Timestamp in milliseconds since the Unix epoch (the custom epoch offset has been added back).
pub timestamp: i64,
/// Machine ID stored in the ID.
pub machine_id: MachineId,
/// Sequence number stored in the low `SEQUENCE_BITS` bits of the ID.
pub sequence: i32,
}
impl Snowflake {
    /// Decodes this ID into its timestamp, machine ID and sequence number.
    ///
    /// The returned timestamp is expressed in milliseconds since the Unix
    /// epoch: the value stored in the ID is relative to `CUSTOM_EPOCH_MS`,
    /// which is added back here.
    pub fn parts(&self) -> SnowflakeParts {
        let raw = self.0;
        let machine_id_mask: i64 = (1 << MACHINE_ID_BITS) - 1;

        // Peel the three fields off the raw value, most significant first.
        let elapsed_ms = raw >> TIMESTAMP_SHIFT;
        let machine_bits = (raw >> MACHINE_ID_SHIFT) & machine_id_mask;
        let sequence_bits = raw & SEQUENCE_MASK;

        SnowflakeParts {
            timestamp: elapsed_ms + CUSTOM_EPOCH_MS,
            // Both values were masked to at most 12 bits above, so the
            // narrowing casts cannot truncate.
            machine_id: MachineId::new_unchecked(machine_bits as i32),
            sequence: sequence_bits as i32,
        }
    }
}
/// Generator of [`Snowflake`] IDs.
///
/// All state is held in atomics, which is why the generating methods take
/// `&self` and advance the state with a compare-exchange.
pub struct SnowflakeGenerator {
/// Machine ID embedded in every generated ID; starts at 0 until `set_machine_id` is called.
machine_id: AtomicI32,
/// Last claimed slot, packed as `(timestamp << SEQUENCE_BITS) | sequence`, where the
/// timestamp is in milliseconds since the Unix epoch (`-1` before the first ID is issued).
ts_seq: AtomicI64,
}
/// Packs a timestamp and a sequence number into a single `i64`.
///
/// The timestamp occupies the bits above the low `SEQUENCE_BITS` bits and the
/// sequence is OR-ed into those low bits. The sequence is not masked, so the
/// caller must pass a value that fits in `SEQUENCE_BITS` bits.
fn pack_ts_seq(timestamp: i64, sequence: i64) -> i64 {
    let high_bits = timestamp << SEQUENCE_BITS;
    high_bits | sequence
}
/// Extracts the timestamp from a value produced by `pack_ts_seq`.
///
/// `>>` on `i64` is an arithmetic shift, so a negative timestamp (such as the
/// `-1` used for a generator that has not issued an ID yet) is recovered
/// with its sign intact.
fn unpack_timestamp(ts_seq: i64) -> i64 {
    let sequence_width = SEQUENCE_BITS;
    ts_seq >> sequence_width
}
/// Extracts the sequence number (the low `SEQUENCE_BITS` bits) from a value
/// produced by `pack_ts_seq`.
fn unpack_sequence(ts_seq: i64) -> i64 {
    let low_bits_mask = SEQUENCE_MASK;
    low_bits_mask & ts_seq
}
impl SnowflakeGenerator {
pub fn new() -> Self {
Self {
machine_id: AtomicI32::new(0),
ts_seq: AtomicI64::new(pack_ts_seq(-1, 0)),
}
}
pub fn set_machine_id(&self, id: MachineId) {
assert!(
id.value() >= 0 && id.value() <= crate::types::MAX_MACHINE_ID,
"machine ID {} is out of range (valid: 0..=1023)",
id.value()
);
self.machine_id.store(id.value(), Ordering::Release);
}
fn current_timestamp() -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system clock before Unix epoch")
.as_millis() as i64
}
pub async fn next_async(&self) -> Result<Snowflake, SnowflakeError> {
loop {
let timestamp = Self::current_timestamp();
let current = self.ts_seq.load(Ordering::Acquire);
let last_ts = unpack_timestamp(current);
if timestamp < last_ts {
let drift_ms = last_ts - timestamp;
if drift_ms > MAX_CLOCK_DRIFT_MS {
return Err(SnowflakeError::ClockDriftExceeded {
drift_ms,
max_drift_ms: MAX_CLOCK_DRIFT_MS,
});
}
if drift_ms > 100 {
tracing::warn!(
drift_ms,
"snowflake: system clock jumped backward, waiting for clock to catch up"
);
}
tokio::task::yield_now().await;
continue;
}
let (new_val, seq) = if timestamp == last_ts {
let seq = unpack_sequence(current) + 1;
if seq > SEQUENCE_MASK {
tokio::task::yield_now().await;
continue;
}
(pack_ts_seq(timestamp, seq), seq)
} else {
(pack_ts_seq(timestamp, 0), 0)
};
let machine_id = self.machine_id.load(Ordering::Acquire);
if self
.ts_seq
.compare_exchange(current, new_val, Ordering::AcqRel, Ordering::Relaxed)
.is_err()
{
continue;
}
let id = ((timestamp - CUSTOM_EPOCH_MS) << TIMESTAMP_SHIFT)
| ((machine_id as i64) << MACHINE_ID_SHIFT)
| seq;
return Ok(Snowflake(id));
}
}
pub fn next(&self) -> Result<Snowflake, SnowflakeError> {
loop {
let timestamp = Self::current_timestamp();
let current = self.ts_seq.load(Ordering::Acquire);
let last_ts = unpack_timestamp(current);
if timestamp < last_ts {
let drift_ms = last_ts - timestamp;
if drift_ms > MAX_CLOCK_DRIFT_MS {
return Err(SnowflakeError::ClockDriftExceeded {
drift_ms,
max_drift_ms: MAX_CLOCK_DRIFT_MS,
});
}
if drift_ms > 100 {
tracing::warn!(
drift_ms,
"snowflake: system clock jumped backward, waiting for clock to catch up"
);
}
std::thread::yield_now();
continue;
}
let (new_val, seq) = if timestamp == last_ts {
let seq = unpack_sequence(current) + 1;
if seq > SEQUENCE_MASK {
std::thread::yield_now();
continue;
}
(pack_ts_seq(timestamp, seq), seq)
} else {
(pack_ts_seq(timestamp, 0), 0)
};
let machine_id = self.machine_id.load(Ordering::Acquire);
if self
.ts_seq
.compare_exchange(current, new_val, Ordering::AcqRel, Ordering::Relaxed)
.is_err()
{
continue; }
let id = ((timestamp - CUSTOM_EPOCH_MS) << TIMESTAMP_SHIFT)
| ((machine_id as i64) << MACHINE_ID_SHIFT)
| seq;
return Ok(Snowflake(id));
}
}
}
impl Default for SnowflakeGenerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // Locals are named `generator` rather than `gen`: `gen` is a reserved
    // keyword from the 2024 edition onwards.
    use super::*;
    use std::collections::HashSet;

    /// IDs generated back-to-back on one thread never repeat.
    #[test]
    fn uniqueness() {
        let generator = SnowflakeGenerator::new();
        let ids: Vec<Snowflake> = (0..10_000).map(|_| generator.next().unwrap()).collect();
        let unique: HashSet<_> = ids.iter().collect();
        assert_eq!(unique.len(), ids.len(), "all IDs must be unique");
    }

    /// Each ID compares greater than the one generated before it.
    #[test]
    fn monotonicity() {
        let generator = SnowflakeGenerator::new();
        let mut prev = generator.next().unwrap();
        for _ in 0..1_000 {
            let next = generator.next().unwrap();
            assert!(next > prev, "IDs must be strictly increasing");
            prev = next;
        }
    }

    /// `parts` recovers the machine ID, a plausible timestamp and an in-range sequence.
    #[test]
    fn parts_round_trip() {
        let generator = SnowflakeGenerator::new();
        generator.set_machine_id(MachineId::new_unchecked(42));
        let id = generator.next().unwrap();
        let parts = id.parts();
        assert_eq!(parts.machine_id, MachineId::new_unchecked(42));
        assert!(parts.timestamp > CUSTOM_EPOCH_MS);
        assert!((0..=SEQUENCE_MASK).contains(&i64::from(parts.sequence)));
    }

    /// A machine ID set after the first ID shows up in later IDs.
    #[test]
    fn machine_id_update() {
        let generator = SnowflakeGenerator::new();
        let id1 = generator.next().unwrap();
        assert_eq!(id1.parts().machine_id, MachineId::new_unchecked(0));
        generator.set_machine_id(MachineId::new_unchecked(7));
        // NOTE(review): this sleep looks unnecessary, since the generator
        // reads the machine ID on every call. Confirm before removing.
        std::thread::sleep(std::time::Duration::from_millis(2));
        let id2 = generator.next().unwrap();
        assert_eq!(id2.parts().machine_id, MachineId::new_unchecked(7));
    }

    /// A `Snowflake` survives MessagePack and JSON round trips unchanged.
    #[test]
    fn serde_round_trip() {
        let sf = Snowflake(123456789);
        let bytes = rmp_serde::to_vec(&sf).unwrap();
        let decoded: Snowflake = rmp_serde::from_slice(&bytes).unwrap();
        assert_eq!(sf, decoded);
        let json = serde_json::to_string(&sf).unwrap();
        let decoded: Snowflake = serde_json::from_str(&json).unwrap();
        assert_eq!(sf, decoded);
    }

    /// A machine ID one past the maximum is rejected with a panic.
    #[test]
    #[should_panic(expected = "out of range")]
    fn set_machine_id_rejects_overflow() {
        let generator = SnowflakeGenerator::new();
        generator.set_machine_id(MachineId::new_unchecked(1024));
    }

    /// A negative machine ID is rejected with a panic.
    #[test]
    #[should_panic(expected = "out of range")]
    fn set_machine_id_rejects_negative() {
        let generator = SnowflakeGenerator::new();
        generator.set_machine_id(MachineId::new_unchecked(-1));
    }

    /// The largest valid machine ID is accepted and round-trips through an ID.
    #[test]
    fn set_machine_id_accepts_max_valid() {
        let generator = SnowflakeGenerator::new();
        generator.set_machine_id(MachineId::new_unchecked(1023));
        // NOTE(review): this sleep looks unnecessary, since the generator
        // reads the machine ID on every call. Confirm before removing.
        std::thread::sleep(std::time::Duration::from_millis(2));
        let id = generator.next().unwrap();
        assert_eq!(id.parts().machine_id, MachineId::new_unchecked(1023));
    }

    /// Four threads sharing one generator never produce the same ID twice.
    #[test]
    fn concurrent_uniqueness() {
        use std::sync::Arc;
        let generator = Arc::new(SnowflakeGenerator::new());
        let mut handles = vec![];
        for _ in 0..4 {
            let shared = Arc::clone(&generator);
            handles.push(std::thread::spawn(move || {
                (0..2_500)
                    .map(|_| shared.next().unwrap())
                    .collect::<Vec<_>>()
            }));
        }
        let mut all_ids = HashSet::new();
        for handle in handles {
            for id in handle.join().unwrap() {
                assert!(all_ids.insert(id), "duplicate ID found in concurrent test");
            }
        }
        assert_eq!(all_ids.len(), 10_000);
    }

    /// IDs generated through `next_async` never repeat.
    #[tokio::test]
    async fn next_async_uniqueness() {
        let generator = SnowflakeGenerator::new();
        let mut ids = HashSet::new();
        for _ in 0..1_000 {
            let id = generator.next_async().await.unwrap();
            assert!(ids.insert(id), "duplicate ID from next_async");
        }
        assert_eq!(ids.len(), 1_000);
    }

    /// Each ID from `next_async` compares greater than the previous one.
    #[tokio::test]
    async fn next_async_monotonicity() {
        let generator = SnowflakeGenerator::new();
        let mut prev = generator.next_async().await.unwrap();
        for _ in 0..100 {
            let next = generator.next_async().await.unwrap();
            assert!(next > prev, "next_async IDs must be strictly increasing");
            prev = next;
        }
    }
}