use std::{
cmp::Ordering,
hint::spin_loop,
time::{Instant, SystemTime, UNIX_EPOCH},
};
/// Smallest number of bits that may be allotted to the worker id.
const MIN_BITS: u64 = 1;
/// Default timeout, in milliseconds, that the builder passes on for waiting until the clock
/// advances after a period's sequence numbers are used up.
const TIMEOUT_MILLIS: u128 = 1000;

// `float-safe` layout: the timestamp counts whole seconds and the top 11 bits are never used, so
// IDs occupy at most 53 bits (presumably so they survive a round-trip through an IEEE-754 double,
// e.g. a JavaScript number).
#[cfg(feature = "float-safe")]
const WORKER_ID_BITS: u64 = 4;
/// Default epoch as Unix seconds: 2023-12-31T16:00:00Z (2024-01-01 00:00 at UTC+8).
#[cfg(feature = "float-safe")]
const EPOCH_SECS: u64 = 1704038400;
#[cfg(feature = "float-safe")]
const TIMESTAMP_BITS: u64 = 32;
#[cfg(feature = "float-safe")]
const SAFE_UNUSED_BITS: u64 = 11;
/// Bits shared between the worker id and the sequence.
#[cfg(feature = "float-safe")]
const MAX_ADJUSTABLE_BITS: u64 = 64 - SAFE_UNUSED_BITS - TIMESTAMP_BITS;

// Default layout: the timestamp counts milliseconds and the most significant bit (the sign bit)
// is reserved.
#[cfg(not(feature = "float-safe"))]
const WORKER_ID_BITS: u64 = 10;
/// Default epoch as Unix milliseconds: 2023-12-31T16:00:00Z (2024-01-01 00:00 at UTC+8).
#[cfg(not(feature = "float-safe"))]
const EPOCH_MILLIS: u64 = 1704038400000;
#[cfg(not(feature = "float-safe"))]
const TIMESTAMP_BITS: u64 = 41;
#[cfg(not(feature = "float-safe"))]
const SIGN_BITS: u64 = 1;
/// Bits shared between the worker id and the sequence.
#[cfg(not(feature = "float-safe"))]
const MAX_ADJUSTABLE_BITS: u64 = 64 - SIGN_BITS - TIMESTAMP_BITS;
/// State of a Snowflake-style ID generator for one worker.
///
/// The bit layout (shifts and sequence mask) is fixed when the generator is built; the
/// `last_timestamp` / `sequence` pair records the most recently issued ID.
#[derive(Debug)]
pub struct Snowflake {
    /// Reference point subtracted from the wall clock, in the generator's time unit
    /// (seconds with the `float-safe` feature, milliseconds otherwise).
    epoch: u64,
    /// Timestamp (relative to `epoch`) embedded in the most recently issued ID; starts at 0.
    last_timestamp: u64,
    /// Worker id embedded in every ID.
    worker_id: u64,
    /// Sequence number of the most recently issued ID within `last_timestamp`.
    sequence: u64,
    /// Upper bound on spinning for the next period; `None` waits without a limit.
    timeout_millis: Option<u128>,
    /// Largest sequence value, also used as the mask that wraps the sequence to 0.
    max_sequence: u64,
    /// Left shift applied to the timestamp when packing an ID.
    timestamp_shift: u64,
    /// Left shift applied to the worker id when packing an ID.
    worker_id_shift: u64,
}
#[derive(thiserror::Error, Debug, Clone, PartialEq)]
pub enum SnowflakeError {
#[error("argument error: {0}")]
ArgumentError(String),
#[error("clock move backwards")]
ClockMoveBackwards,
#[error("wait for next period timeout")]
WaitForNextPeriodTimeout,
#[error("epoch must be greater than the current time")]
InvalidEpoch,
#[error("failed to convert timestamp to milliseconds")]
FailedConvertToMillis,
}
impl Snowflake {
pub fn new(worker_id: u64) -> Result<Self, SnowflakeError> { Self::builder().with_worker_id(worker_id).build() }
pub fn builder() -> SnowflakeBuilder {
SnowflakeBuilder {
worker_id: 0,
worker_id_bits: Some(WORKER_ID_BITS),
timeout_millis: Some(TIMEOUT_MILLIS),
#[cfg(feature = "float-safe")]
epoch: Some(EPOCH_SECS),
#[cfg(not(feature = "float-safe"))]
epoch: Some(EPOCH_MILLIS),
}
}
fn with_config(
worker_id: u64,
worker_id_bits: Option<u64>,
timeout_millis: Option<u128>,
epoch: Option<u64>,
) -> Result<Self, SnowflakeError> {
let worker_id_bits = worker_id_bits.unwrap_or(WORKER_ID_BITS);
if !(MIN_BITS .. MAX_ADJUSTABLE_BITS).contains(&worker_id_bits) {
return Err(SnowflakeError::ArgumentError(
format!(
"invalid worker id bits(={worker_id_bits}), expected worker id bits ∈ [{MIN_BITS},{MAX_ADJUSTABLE_BITS})"
))
);
}
let sequence_bits = MAX_ADJUSTABLE_BITS - worker_id_bits;
let max_worker_id = (1 << worker_id_bits) - 1;
let max_sequence = (1 << sequence_bits) - 1;
let worker_id_shift = sequence_bits;
let timestamp_shift = worker_id_bits + sequence_bits;
if worker_id > max_worker_id {
return Err(SnowflakeError::ArgumentError(format!(
"invalid worker id(={worker_id}), expected worker id ∈ [0,{max_worker_id}]",
)));
}
#[cfg(feature = "float-safe")]
let epoch = epoch.unwrap_or(EPOCH_SECS);
#[cfg(not(feature = "float-safe"))]
let epoch = epoch.unwrap_or(EPOCH_MILLIS);
#[cfg(feature = "float-safe")]
if epoch >= Self::timestamp()? {
return Err(SnowflakeError::InvalidEpoch);
}
#[cfg(not(feature = "float-safe"))]
if epoch >= Self::timestamp_millis()? {
return Err(SnowflakeError::InvalidEpoch);
}
Ok(Self {
epoch,
last_timestamp: 0,
worker_id,
sequence: 0,
timeout_millis,
max_sequence,
timestamp_shift,
worker_id_shift,
})
}
pub fn generate(&mut self) -> Result<u64, SnowflakeError> {
#[cfg(feature = "float-safe")]
let mut now = self.current_timestamp_since_epoch()?;
#[cfg(not(feature = "float-safe"))]
let mut now = self.current_timestamp_millis_since_epoch()?;
match now.cmp(&self.last_timestamp) {
Ordering::Less => {
let possible_sequence = (self.sequence + 1) & self.max_sequence;
if possible_sequence > 0 {
self.sequence = possible_sequence;
return Ok((self.last_timestamp << self.timestamp_shift)
| (self.worker_id << self.worker_id_shift)
| (self.sequence));
}
return Err(SnowflakeError::ClockMoveBackwards);
}
Ordering::Equal => {
self.sequence = (self.sequence + 1) & self.max_sequence;
if self.sequence == 0 {
let timeout_start = Instant::now();
while now <= self.last_timestamp {
if let Some(timeout_millis) = self.timeout_millis {
if Instant::now().duration_since(timeout_start).as_millis() > timeout_millis {
return Err(SnowflakeError::WaitForNextPeriodTimeout);
}
}
#[cfg(feature = "float-safe")]
if let Ok(latest_timestamp) = self.current_timestamp_since_epoch() {
now = latest_timestamp;
}
#[cfg(not(feature = "float-safe"))]
if let Ok(latest_timestamp_millis) = self.current_timestamp_millis_since_epoch() {
now = latest_timestamp_millis;
}
spin_loop();
}
}
}
Ordering::Greater => {
self.sequence = 0;
}
}
self.last_timestamp = now;
Ok((now << self.timestamp_shift) | (self.worker_id << self.worker_id_shift) | (self.sequence))
}
#[cfg(feature = "float-safe")]
fn timestamp() -> Result<u64, SnowflakeError> {
Ok(SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| SnowflakeError::ClockMoveBackwards)?
.as_secs())
}
#[cfg(not(feature = "float-safe"))]
fn timestamp_millis() -> Result<u64, SnowflakeError> {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|_| SnowflakeError::ClockMoveBackwards)?
.as_millis()
.try_into()
.map_err(|_| SnowflakeError::FailedConvertToMillis)
}
#[cfg(feature = "float-safe")]
fn current_timestamp_since_epoch(&self) -> Result<u64, SnowflakeError> {
let now = Self::timestamp()?;
match now.cmp(&self.epoch) {
Ordering::Less => Err(SnowflakeError::ClockMoveBackwards),
_ => Ok(now - self.epoch),
}
}
#[cfg(not(feature = "float-safe"))]
fn current_timestamp_millis_since_epoch(&self) -> Result<u64, SnowflakeError> {
let now = Self::timestamp_millis()?;
match now.cmp(&self.epoch) {
Ordering::Less => Err(SnowflakeError::ClockMoveBackwards),
_ => Ok(now - self.epoch),
}
}
}
/// Builder for [`Snowflake`]; obtain one from `Snowflake::builder()` and finish with
/// `build()`.
///
/// Derives `Debug` and `Clone` so that, like `Snowflake`, it can be inspected and so that a
/// base configuration can be reused.
#[derive(Debug, Clone)]
pub struct SnowflakeBuilder {
    /// Worker id embedded in every generated ID.
    worker_id: u64,
    /// Bits allotted to the worker id; `None` selects the crate default.
    worker_id_bits: Option<u64>,
    /// Wait timeout in milliseconds; `None` waits without a limit.
    timeout_millis: Option<u128>,
    /// Epoch in the generator's time unit; `None` selects the crate default.
    epoch: Option<u64>,
}
impl SnowflakeBuilder {
    /// Sets the worker id embedded in every generated ID.
    ///
    /// It must fit in the configured number of worker-id bits; `build` checks this.
    #[must_use]
    pub fn with_worker_id(mut self, worker_id: u64) -> Self {
        self.worker_id = worker_id;
        self
    }

    /// Sets how many of the adjustable bits go to the worker id; the rest form the sequence.
    #[must_use]
    pub fn with_worker_id_bits(mut self, worker_id_bits: u64) -> Self {
        self.worker_id_bits = Some(worker_id_bits);
        self
    }

    /// Sets how long, in milliseconds, ID generation may spin waiting for the clock to advance
    /// once the current period's sequence numbers are exhausted.
    #[must_use]
    pub fn with_timeout_millis(mut self, timeout_millis: u128) -> Self {
        self.timeout_millis = Some(timeout_millis);
        self
    }

    /// Sets the epoch that timestamps are measured from.
    ///
    /// The unit is Unix seconds with the `float-safe` feature and Unix milliseconds otherwise.
    /// The epoch must be earlier than the current time.
    #[must_use]
    pub fn with_epoch(mut self, epoch: u64) -> Self {
        self.epoch = Some(epoch);
        self
    }

    /// Validates the settings and creates the generator.
    ///
    /// # Errors
    ///
    /// * [`SnowflakeError::ArgumentError`] if the worker-id bit count or the worker id is out of
    ///   range.
    /// * [`SnowflakeError::InvalidEpoch`] if the epoch is not earlier than the current time.
    /// * Errors from reading the system clock.
    pub fn build(self) -> Result<Snowflake, SnowflakeError> {
        Snowflake::with_config(self.worker_id, self.worker_id_bits, self.timeout_millis, self.epoch)
    }
}