use core::fmt;
use core::marker::PhantomData;
use crate::snowflake_id::{MixedId32, MixedId64, SnowflakeId32, SnowflakeId64};
/// Primitive integer types that can hold a packed Snowflake ID.
///
/// The trait is sealed through the private `private::Sealed` supertrait, so
/// only the integer types this module implements it for (`u32`, `i32`, `u64`,
/// `i64`) can ever be used as ID carriers.
pub trait IdInt: Copy + private::Sealed {
    /// Width of the integer type in bits; bounds the total layout size.
    const BITS: u8;

    /// Builds a value of this type from a raw 64-bit layout value.
    fn from_raw(raw: u64) -> Self;

    /// Turns this value back into a raw 64-bit layout value.
    fn to_raw(self) -> u64;
}
mod private {
    /// Supertrait of `IdInt`. Because this module is private, code outside
    /// the parent module cannot name `Sealed`, and therefore cannot add new
    /// `IdInt` implementations.
    pub trait Sealed {}

    // Unsigned carriers.
    impl Sealed for u32 {}
    impl Sealed for u64 {}

    // Signed carriers.
    impl Sealed for i32 {}
    impl Sealed for i64 {}
}
/// Implements [`IdInt`] for primitive integers.
///
/// Each entry is `$ty as $unsigned`, where `$unsigned` is the unsigned type of
/// the same width as `$ty`. Every conversion goes through `$unsigned` so that
/// `to_raw` zero-extends. A plain `self as u64` on a negative `i32` would
/// sign-extend and fill bits 32..64 with ones, which then leak into the
/// timestamp that `Snowflake::unpack` reads for layouts using all 32 bits.
/// `BITS` is taken from `$unsigned` so it cannot disagree with the width.
macro_rules! impl_id_int {
    ($($ty:ty as $unsigned:ty),* $(,)?) => {
        $(
            impl IdInt for $ty {
                const BITS: u8 = <$unsigned>::BITS as u8;

                #[inline(always)]
                fn from_raw(v: u64) -> Self {
                    // Truncate to the carrier width, then reinterpret the bits.
                    v as $unsigned as Self
                }

                #[inline(always)]
                fn to_raw(self) -> u64 {
                    // Reinterpret as unsigned first so widening zero-extends.
                    self as $unsigned as u64
                }
            }
        )*
    };
}

// Each carrier paired with its same-width unsigned type.
impl_id_int!(
    u32 as u32,
    i32 as u32,
    u64 as u64,
    i64 as u64,
);
/// Error returned when every sequence number of a tick has already been used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SequenceExhausted {
    /// Tick for which no further IDs could be generated.
    pub tick: u64,
    /// Largest sequence number the layout allows. When produced by
    /// `Snowflake::next`, `max_sequence + 1` IDs were issued for `tick`.
    pub max_sequence: u64,
}

impl fmt::Display for SequenceExhausted {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Widen before adding. The fields are public, so `max_sequence` may be
        // `u64::MAX`, and `+ 1` in u64 would overflow (a panic in debug builds).
        let generated = u128::from(self.max_sequence) + 1;
        write!(
            f,
            "sequence exhausted at tick {}: generated {} IDs in one tick",
            self.tick, generated
        )
    }
}

#[cfg(feature = "std")]
impl std::error::Error for SequenceExhausted {}
/// Allocation-free, Snowflake-style ID generator, generic over the output
/// integer `T` and the bit layout.
///
/// Each ID packs three fields, most significant first:
///
/// ```text
/// [ tick: TS bits ][ worker: WK bits ][ sequence: SQ bits ]
/// ```
///
/// Bits of `T` above `TS + WK + SQ` are left zero. The tick is an opaque value
/// supplied by the caller; the generator never reads a clock itself.
///
/// Only the most recent tick is remembered. IDs are therefore unique only while
/// the ticks passed to [`Snowflake::next`] never decrease: returning to an
/// earlier tick restarts its sequence at 0 and can reissue IDs.
pub struct Snowflake<T: IdInt, const TS: u8, const WK: u8, const SQ: u8> {
    /// Worker id already shifted into its bit position.
    worker_shifted: u64,
    /// Tick of the last issued ID, or `u64::MAX` before the first call. No
    /// in-range tick can equal `u64::MAX`, because the layout check limits
    /// `TS` to at most 63.
    last_tick: u64,
    /// Sequence number of the last issued ID within `last_tick`.
    sequence: u64,
    _marker: PhantomData<T>,
}

impl<T: IdInt, const TS: u8, const WK: u8, const SQ: u8> Snowflake<T, TS, WK, SQ> {
    /// Layout check forced by [`Self::new`], so an invalid layout is a
    /// compile-time error rather than a runtime one.
    const _VALIDATE: () = {
        assert!(
            TS as u16 + WK as u16 + SQ as u16 <= T::BITS as u16,
            "layout exceeds integer bits: TS + WK + SQ > T::BITS"
        );
        assert!(TS > 0, "timestamp bits must be > 0");
        assert!(SQ > 0, "sequence bits must be > 0");
    };

    /// Bit offset of the tick field.
    const TS_SHIFT: u8 = WK + SQ;
    /// Bit offset of the worker field.
    const WK_SHIFT: u8 = SQ;

    /// Largest sequence number per tick; one tick yields `SEQUENCE_MAX + 1` IDs.
    pub const SEQUENCE_MAX: u64 = (1u64 << SQ) - 1;
    /// Largest worker id accepted by [`Self::new`] (0 when `WK == 0`).
    pub const WORKER_MAX: u64 = if WK == 0 { 0 } else { (1u64 << WK) - 1 };
    /// Largest tick that fits in the timestamp field.
    pub const TIMESTAMP_MAX: u64 = (1u64 << TS) - 1;

    /// Creates a generator that stamps `worker` into every ID.
    ///
    /// # Panics
    ///
    /// Panics if `worker > Self::WORKER_MAX`.
    pub fn new(worker: u64) -> Self {
        // Referencing the constant forces the layout check for this instantiation.
        let () = Self::_VALIDATE;
        assert!(
            worker <= Self::WORKER_MAX,
            "worker {} exceeds max {}",
            worker,
            Self::WORKER_MAX
        );
        Self {
            worker_shifted: worker << Self::WK_SHIFT,
            last_tick: u64::MAX,
            sequence: 0,
            _marker: PhantomData,
        }
    }

    /// Returns the next ID for `tick`.
    ///
    /// The first ID of a new tick gets sequence 0. Further calls with the same
    /// tick increment the sequence, up to [`Self::SEQUENCE_MAX`].
    ///
    /// `tick` must not exceed [`Self::TIMESTAMP_MAX`], which is checked with
    /// `debug_assert!`. A larger tick does not fit the timestamp field; its
    /// excess bits would be silently dropped or end up outside the field.
    ///
    /// # Errors
    ///
    /// Returns [`SequenceExhausted`] once all `SEQUENCE_MAX + 1` sequence
    /// numbers of `tick` have been handed out. The generator state is left
    /// unchanged, so generation resumes as soon as a different tick is passed.
    pub fn next(&mut self, tick: u64) -> Result<T, SequenceExhausted> {
        debug_assert!(
            tick <= Self::TIMESTAMP_MAX,
            "tick {} exceeds max {}",
            tick,
            Self::TIMESTAMP_MAX
        );
        if tick == self.last_tick {
            // Check before incrementing so `sequence` never leaves its valid
            // range, however many times an exhausted tick is retried.
            if self.sequence >= Self::SEQUENCE_MAX {
                return Err(SequenceExhausted {
                    tick,
                    max_sequence: Self::SEQUENCE_MAX,
                });
            }
            self.sequence += 1;
        } else {
            self.last_tick = tick;
            self.sequence = 0;
        }
        Ok(T::from_raw(
            (tick << Self::TS_SHIFT) | self.worker_shifted | self.sequence,
        ))
    }

    /// Like [`Self::next`], but passes the packed value through
    /// `fibonacci_mix_64` (multiplication by an odd constant modulo 2^64)
    /// before converting back to `T`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::next`].
    pub fn mixed(&mut self, tick: u64) -> Result<T, SequenceExhausted> {
        let raw = self.next(tick)?;
        Ok(T::from_raw(fibonacci_mix_64(raw.to_raw())))
    }

    /// Splits an ID produced by [`Self::next`] into `(tick, worker, sequence)`.
    ///
    /// Each field is masked to its own width, so bits outside the layout are
    /// ignored (for example, sign-extension bits of a negative signed ID).
    /// Unpacking a [`Self::mixed`] ID does not recover its fields.
    pub fn unpack(id: T) -> (u64, u64, u64) {
        let raw = id.to_raw();
        (
            (raw >> Self::TS_SHIFT) & Self::TIMESTAMP_MAX,
            (raw >> Self::WK_SHIFT) & Self::WORKER_MAX,
            raw & Self::SEQUENCE_MAX,
        )
    }

    /// Worker id this generator was created with.
    #[inline]
    pub const fn worker(&self) -> u64 {
        self.worker_shifted >> Self::WK_SHIFT
    }

    /// Sequence number of the most recently issued ID (never above
    /// [`Self::SEQUENCE_MAX`]).
    #[inline]
    pub const fn sequence(&self) -> u64 {
        self.sequence
    }

    /// Tick of the most recently issued ID, or `u64::MAX` if none was issued yet.
    #[inline]
    pub const fn last_tick(&self) -> u64 {
        self.last_tick
    }
}
/// [`Snowflake`] generator producing unsigned 32-bit IDs.
pub type Snowflake32<const TS: u8, const WK: u8, const SQ: u8> = Snowflake<u32, TS, WK, SQ>;
/// [`Snowflake`] generator producing signed 32-bit IDs.
pub type SnowflakeSigned32<const TS: u8, const WK: u8, const SQ: u8> = Snowflake<i32, TS, WK, SQ>;

/// [`Snowflake`] generator producing unsigned 64-bit IDs.
pub type Snowflake64<const TS: u8, const WK: u8, const SQ: u8> = Snowflake<u64, TS, WK, SQ>;
/// [`Snowflake`] generator producing signed 64-bit IDs.
pub type SnowflakeSigned64<const TS: u8, const WK: u8, const SQ: u8> = Snowflake<i64, TS, WK, SQ>;
/// Scrambles `x` by multiplying it, modulo 2^64, with the 64-bit golden-ratio
/// constant (≈ 2^64 / φ), as in Fibonacci hashing.
///
/// The multiplier is odd, so the map is a bijection on `u64`: distinct inputs
/// always give distinct outputs.
#[inline(always)]
const fn fibonacci_mix_64(x: u64) -> u64 {
    const GOLDEN_RATIO_64: u64 = 0x9E37_79B9_7F4A_7C15;
    GOLDEN_RATIO_64.wrapping_mul(x)
}
impl<const TS: u8, const WK: u8, const SQ: u8> Snowflake<u64, TS, WK, SQ> {
    /// Generates the next ID for `tick`, wrapped in the typed [`SnowflakeId64`].
    ///
    /// # Errors
    ///
    /// Propagates [`SequenceExhausted`] from `Snowflake::next`.
    #[inline]
    pub fn next_id(&mut self, tick: u64) -> Result<SnowflakeId64<TS, WK, SQ>, SequenceExhausted> {
        let raw = self.next(tick)?;
        Ok(SnowflakeId64::from_raw(raw))
    }

    /// Generates the next typed ID for `tick` and converts it with its
    /// `mixed()` method into a [`MixedId64`].
    ///
    /// # Errors
    ///
    /// Propagates [`SequenceExhausted`] from `Snowflake::next`.
    #[inline]
    pub fn next_mixed(&mut self, tick: u64) -> Result<MixedId64<TS, WK, SQ>, SequenceExhausted> {
        let id = self.next_id(tick)?;
        Ok(id.mixed())
    }
}
impl<const TS: u8, const WK: u8, const SQ: u8> Snowflake<u32, TS, WK, SQ> {
    /// Generates the next ID for `tick`, wrapped in the typed [`SnowflakeId32`].
    ///
    /// # Errors
    ///
    /// Propagates [`SequenceExhausted`] from `Snowflake::next`.
    #[inline]
    pub fn next_id(&mut self, tick: u64) -> Result<SnowflakeId32<TS, WK, SQ>, SequenceExhausted> {
        let raw = self.next(tick)?;
        Ok(SnowflakeId32::from_raw(raw))
    }

    /// Generates the next typed ID for `tick` and converts it with its
    /// `mixed()` method into a [`MixedId32`].
    ///
    /// # Errors
    ///
    /// Propagates [`SequenceExhausted`] from `Snowflake::next`.
    #[inline]
    pub fn next_mixed(&mut self, tick: u64) -> Result<MixedId32<TS, WK, SQ>, SequenceExhausted> {
        let id = self.next_id(tick)?;
        Ok(id.mixed())
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// 42-bit tick, 6-bit worker, 16-bit sequence: exactly fills a `u64`.
    type TestId = Snowflake64<42, 6, 16>;

    #[test]
    fn basic_generation() {
        let mut id_gen = TestId::new(5);
        let id = id_gen.next(0).unwrap();
        assert_eq!(TestId::unpack(id), (0, 5, 0));
    }

    #[test]
    fn sequence_increments_same_ts() {
        let mut id_gen = TestId::new(5);
        for expected in 0..3 {
            let (_, _, seq) = TestId::unpack(id_gen.next(0).unwrap());
            assert_eq!(seq, expected);
        }
    }

    #[test]
    fn sequence_resets_new_ts() {
        let mut id_gen = TestId::new(5);
        id_gen.next(0).unwrap();
        id_gen.next(0).unwrap();
        let (ts, _, seq) = TestId::unpack(id_gen.next(1).unwrap());
        assert_eq!(ts, 1);
        assert_eq!(seq, 0);
    }

    #[test]
    fn worker_encoded_correctly() {
        for worker in [0, 1, 31, 63] {
            let mut id_gen = TestId::new(worker);
            let (_, w, _) = TestId::unpack(id_gen.next(0).unwrap());
            assert_eq!(w, worker);
        }
    }

    #[test]
    fn ids_are_unique() {
        let mut id_gen = TestId::new(5);
        let ids: Vec<u64> = (0..1000u64)
            .map(|i| id_gen.next(i / 100).unwrap())
            .collect();
        // Non-decreasing ticks give strictly increasing, hence unique, IDs.
        assert!(ids.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn sequence_exhaustion() {
        type TinySeq = Snowflake64<42, 6, 4>;
        let mut id_gen = TinySeq::new(5);
        for _ in 0..16 {
            id_gen.next(0).unwrap();
        }
        let err = id_gen.next(0).unwrap_err();
        assert_eq!(
            err,
            SequenceExhausted {
                tick: 0,
                max_sequence: 15
            }
        );
        // Exhaustion persists for the same tick...
        assert!(id_gen.next(0).is_err());
        // ...and clears as soon as the tick advances.
        let (ts, _, seq) = TinySeq::unpack(id_gen.next(1).unwrap());
        assert_eq!((ts, seq), (1, 0));
    }

    #[test]
    fn sequence_exhausted_display() {
        let err = SequenceExhausted {
            tick: 3,
            max_sequence: 15,
        };
        assert_eq!(
            err.to_string(),
            "sequence exhausted at tick 3: generated 16 IDs in one tick"
        );
    }

    #[test]
    #[should_panic(expected = "worker 100 exceeds max 63")]
    fn worker_overflow_panics() {
        let _id_gen = TestId::new(100);
    }

    #[test]
    fn signed_output() {
        type SignedId = SnowflakeSigned64<42, 6, 16>;
        let mut id_gen = SignedId::new(5);
        let id: i64 = id_gen.next(0).unwrap();
        assert_eq!(SignedId::unpack(id), (0, 5, 0));
    }

    #[test]
    fn small_layout_32bit() {
        type SmallId = Snowflake32<20, 4, 8>;
        let mut id_gen = SmallId::new(7);
        let id: u32 = id_gen.next(0).unwrap();
        assert_eq!(SmallId::unpack(id), (0, 7, 0));
    }

    #[test]
    fn zero_worker_bits() {
        type SingleWorker = Snowflake64<48, 0, 16>;
        let mut id_gen = SingleWorker::new(0);
        let id = id_gen.next(0).unwrap();
        assert_eq!(SingleWorker::unpack(id), (0, 0, 0));
    }

    #[test]
    fn non_time_timestamp() {
        type BlockId = Snowflake64<32, 8, 24>;
        let mut id_gen = BlockId::new(1);
        let id1 = id_gen.next(1000).unwrap();
        let id2 = id_gen.next(1000).unwrap();
        let id3 = id_gen.next(1001).unwrap();
        let (ts1, _, seq1) = BlockId::unpack(id1);
        let (ts2, _, seq2) = BlockId::unpack(id2);
        let (ts3, _, seq3) = BlockId::unpack(id3);
        assert_eq!((ts1, seq1), (1000, 0));
        assert_eq!((ts2, seq2), (1000, 1));
        assert_eq!((ts3, seq3), (1001, 0));
    }

    #[test]
    fn mixed_is_fibonacci_mix_of_next() {
        let mut plain = TestId::new(5);
        let mut mixing = TestId::new(5);
        for tick in [9, 9, 10] {
            let expected = fibonacci_mix_64(plain.next(tick).unwrap());
            assert_eq!(mixing.mixed(tick).unwrap(), expected);
        }
    }
}