use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
/// Custom epoch for generated ids, in milliseconds since the Unix epoch
/// (1_704_067_200_000 ms = 2024-01-01T00:00:00Z).
pub const NOETL_EPOCH_MS: u64 = 1_704_067_200_000;
/// Width, in bits, of the machine-id field of an id.
const MACHINE_ID_BITS: u8 = 10;
/// Width, in bits, of the per-millisecond sequence field of an id.
const SEQUENCE_BITS: u8 = 12;
/// Largest machine id that fits in `MACHINE_ID_BITS` bits (1023).
pub const MAX_MACHINE_ID: u16 = (1 << MACHINE_ID_BITS) - 1;
/// Mask that keeps the low `SEQUENCE_BITS` bits of a sequence value (0x0FFF).
const SEQUENCE_MASK: u16 = (1 << SEQUENCE_BITS) - 1;
/// Left shift that places the machine id directly above the sequence field.
const MACHINE_ID_SHIFT: u8 = SEQUENCE_BITS;
/// Left shift that places the timestamp above the machine-id and sequence fields.
const TIMESTAMP_SHIFT: u8 = SEQUENCE_BITS + MACHINE_ID_BITS;
/// Mutable bookkeeping of a generator: the timestamp and sequence of the
/// most recently issued id. Both fields start at 0.
#[derive(Debug)]
struct State {
/// Timestamp, in milliseconds since `NOETL_EPOCH_MS`, of the last issued id.
last_timestamp: u64,
/// Sequence number of the last id issued within `last_timestamp`.
sequence: u16,
}
/// Generator of 64-bit ids laid out, from high bits to low bits, as
/// timestamp | 10-bit machine id | 12-bit sequence.
#[derive(Debug)]
pub struct SnowflakeGenerator {
/// Machine id embedded in every generated id; validated against `MAX_MACHINE_ID` in `new`.
machine_id: u16,
/// Last-issued timestamp and sequence, behind a mutex so `generate` can take `&self`.
state: Mutex<State>,
}
impl SnowflakeGenerator {
    /// Creates a generator that stamps every id with `machine_id`.
    ///
    /// # Errors
    ///
    /// Returns [`SnowflakeError::MachineIdOutOfRange`] when `machine_id` is
    /// greater than [`MAX_MACHINE_ID`].
    pub fn new(machine_id: u16) -> Result<Self, SnowflakeError> {
        if machine_id > MAX_MACHINE_ID {
            return Err(SnowflakeError::MachineIdOutOfRange { machine_id });
        }
        Ok(Self {
            machine_id,
            state: Mutex::new(State {
                last_timestamp: 0,
                sequence: 0,
            }),
        })
    }

    /// Generates the next id:
    /// `(timestamp << TIMESTAMP_SHIFT) | (machine_id << MACHINE_ID_SHIFT) | sequence`.
    ///
    /// The timestamp is in milliseconds since [`NOETL_EPOCH_MS`]. If the wall
    /// clock reads earlier than the last issued timestamp, the last timestamp
    /// is reused so ids never go backwards. When every sequence value of the
    /// current millisecond has been used, the call spins — while holding the
    /// state lock — until the clock passes the last issued timestamp.
    ///
    /// The generator state is only updated once every fallible step has
    /// succeeded, so a failed call never disturbs the sequence and cannot
    /// cause a later call to re-issue an id.
    ///
    /// # Errors
    ///
    /// * [`SnowflakeError::StateLockPoisoned`] if the state mutex is poisoned.
    /// * [`SnowflakeError::ClockBeforeEpoch`] if the system clock reads
    ///   earlier than the NoETL epoch.
    pub fn generate(&self) -> Result<i64, SnowflakeError> {
        let mut state = self
            .state
            .lock()
            .map_err(|_| SnowflakeError::StateLockPoisoned)?;

        // A clock that moved backwards is clamped to the last issued timestamp.
        let now = current_noetl_ms()?.max(state.last_timestamp);

        let (timestamp, sequence) = if now == state.last_timestamp {
            let next = (state.sequence + 1) & SEQUENCE_MASK;
            if next == 0 {
                // Sequence space for this millisecond is exhausted: move to a
                // strictly later millisecond and restart the sequence.
                (wait_until_next_ms(state.last_timestamp)?, 0)
            } else {
                (now, next)
            }
        } else {
            (now, 0)
        };

        // Commit only after all fallible calls above succeeded. Writing the
        // wrapped sequence before a failing wait would leave the old
        // timestamp paired with a reset sequence and allow duplicate ids.
        state.last_timestamp = timestamp;
        state.sequence = sequence;

        // NOTE: the timestamp is not range-checked; with a 22-bit shift only
        // values below 2^41 ms keep the resulting i64 non-negative.
        let id = ((timestamp as i64) << TIMESTAMP_SHIFT)
            | (i64::from(self.machine_id) << MACHINE_ID_SHIFT)
            | i64::from(sequence);
        Ok(id)
    }

    /// Returns the machine id this generator embeds in every id.
    pub fn machine_id(&self) -> u16 {
        self.machine_id
    }
}
/// Derives a machine id from `seed` by hashing its bytes with 64-bit FNV-1a
/// and keeping only the bits selected by [`MAX_MACHINE_ID`].
///
/// The same seed always yields the same id; distinct seeds may collide
/// because the hash is reduced to the machine-id range.
pub fn derive_machine_id(seed: &str) -> u16 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    let digest = seed.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    // The mask leaves at most 10 significant bits, so narrowing is lossless.
    (digest & u64::from(MAX_MACHINE_ID)) as u16
}
/// Errors returned by the snowflake generator and its clock helpers.
#[derive(Debug, thiserror::Error)]
pub enum SnowflakeError {
/// The machine id passed to `SnowflakeGenerator::new` is greater than `MAX_MACHINE_ID`.
#[error("machine_id {machine_id} exceeds 10-bit max {MAX_MACHINE_ID}")]
MachineIdOutOfRange { machine_id: u16 },
/// The system clock reads earlier than `NOETL_EPOCH_MS` (or earlier than the Unix epoch).
#[error("system clock is before NoETL epoch (2024-01-01); fix NTP")]
ClockBeforeEpoch,
/// Locking the generator state failed because the mutex was poisoned.
#[error("snowflake generator state mutex was poisoned")]
StateLockPoisoned,
}
/// Reads the system clock and returns the milliseconds elapsed since
/// [`NOETL_EPOCH_MS`].
///
/// # Errors
///
/// Returns [`SnowflakeError::ClockBeforeEpoch`] when the clock reads earlier
/// than the Unix epoch or earlier than the NoETL epoch.
fn current_noetl_ms() -> Result<u64, SnowflakeError> {
    let since_unix = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_err(|_| SnowflakeError::ClockBeforeEpoch)?;
    let unix_ms = since_unix.as_millis() as u64;

    // `checked_sub` yields `None` exactly when the clock is before the epoch.
    unix_ms
        .checked_sub(NOETL_EPOCH_MS)
        .ok_or(SnowflakeError::ClockBeforeEpoch)
}
/// Busy-waits until the clock reports a millisecond strictly greater than
/// `last`, then returns that reading.
///
/// # Errors
///
/// Propagates any error from `current_noetl_ms`.
fn wait_until_next_ms(last: u64) -> Result<u64, SnowflakeError> {
    let mut reading = current_noetl_ms()?;
    while reading <= last {
        // Tell the CPU we are spinning before sampling the clock again.
        std::hint::spin_loop();
        reading = current_noetl_ms()?;
    }
    Ok(reading)
}
#[cfg(test)]
mod tests {
    // Local bindings are named `generator` rather than `gen`: `gen` is a
    // reserved keyword starting with the Rust 2024 edition.
    use super::*;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// 1024 does not fit in 10 bits, so construction must fail and report it.
    #[test]
    fn rejects_machine_id_above_10_bits() {
        let err = SnowflakeGenerator::new(1024).unwrap_err();
        match err {
            SnowflakeError::MachineIdOutOfRange { machine_id } => {
                assert_eq!(machine_id, 1024);
            }
            other => panic!("expected MachineIdOutOfRange, got {other:?}"),
        }
    }

    /// The largest 10-bit value is still a valid machine id.
    #[test]
    fn accepts_machine_id_at_max() {
        SnowflakeGenerator::new(MAX_MACHINE_ID).expect("max machine id is valid");
    }

    /// Ids must never set the sign bit of the `i64`.
    #[test]
    fn generated_id_is_non_negative_i64() {
        let generator = SnowflakeGenerator::new(1).unwrap();
        for _ in 0..100 {
            let id = generator.generate().unwrap();
            assert!(id >= 0, "id {id} negative — sign bit leaked");
        }
    }

    /// Consecutive ids from one thread are strictly increasing.
    #[test]
    fn ids_are_monotonic_within_a_single_thread() {
        let generator = SnowflakeGenerator::new(7).unwrap();
        let mut prev = generator.generate().unwrap();
        for _ in 0..10_000 {
            let id = generator.generate().unwrap();
            assert!(
                id > prev,
                "ids not monotonic: prev={prev} current={id}"
            );
            prev = id;
        }
    }

    /// The machine-id field can be read back out of a generated id.
    #[test]
    fn machine_id_is_preserved_in_generated_ids() {
        let generator = SnowflakeGenerator::new(42).unwrap();
        let id = generator.generate().unwrap();
        let extracted = ((id >> MACHINE_ID_SHIFT) as u16) & MAX_MACHINE_ID;
        assert_eq!(extracted, 42);
    }

    /// Generating more ids than one millisecond's sequence space holds must
    /// not repeat any id.
    #[test]
    fn sequence_rolls_over_after_4096_within_one_ms() {
        let generator = SnowflakeGenerator::new(3).unwrap();
        let mut seen = HashSet::with_capacity(10_000);
        for _ in 0..10_000 {
            let id = generator.generate().unwrap();
            assert!(seen.insert(id), "id {id} repeated — sequence overflow not handled");
        }
    }

    /// Eight threads sharing one generator produce 8000 distinct ids.
    #[test]
    fn concurrent_generators_produce_unique_ids() {
        let generator = Arc::new(SnowflakeGenerator::new(5).unwrap());
        let mut handles = Vec::new();
        for _ in 0..8 {
            let shared = Arc::clone(&generator);
            handles.push(std::thread::spawn(move || {
                let mut local = Vec::with_capacity(1000);
                for _ in 0..1000 {
                    local.push(shared.generate().unwrap());
                }
                local
            }));
        }
        let mut all = HashSet::new();
        for handle in handles {
            for id in handle.join().unwrap() {
                assert!(all.insert(id), "duplicate id {id} from concurrent generators");
            }
        }
        assert_eq!(all.len(), 8000);
    }

    /// Hashing the same seed twice gives the same machine id.
    #[test]
    fn derive_machine_id_is_stable_for_same_input() {
        let a = derive_machine_id("noetl-server-pod-0");
        let b = derive_machine_id("noetl-server-pod-0");
        assert_eq!(a, b);
    }

    /// These two specific seeds map to different machine ids.
    #[test]
    fn derive_machine_id_differs_for_different_inputs() {
        let a = derive_machine_id("pod-0");
        let b = derive_machine_id("pod-1");
        assert_ne!(a, b);
    }

    /// Derived ids never exceed the 10-bit maximum, including for the empty seed.
    #[test]
    fn derive_machine_id_stays_within_10_bits() {
        for s in &["", "a", "noetl-server", "very-long-hostname-with-pid-12345"] {
            assert!(derive_machine_id(s) <= MAX_MACHINE_ID);
        }
    }
}