use std::fmt;
use std::str::FromStr;
use std::sync::{Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use crate::crypto::ids::IdError;
const DEFAULT_EPOCH_MS: u64 = 1_577_836_800_000;
const WORKER_BITS: u8 = 10;
const SEQUENCE_BITS: u8 = 12;
const MAX_WORKER: u16 = (1 << WORKER_BITS) - 1;
const MAX_SEQUENCE: u16 = (1 << SEQUENCE_BITS) - 1;
#[derive(Debug, Clone)]
pub struct SnowflakeConfig {
pub worker_id: u16,
pub epoch_ms: u64,
}
impl Default for SnowflakeConfig {
fn default() -> Self {
let worker_id: u16 = std::env::var("SNOWFLAKE_WORKER_ID")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0);
Self {
worker_id: worker_id & MAX_WORKER,
epoch_ms: DEFAULT_EPOCH_MS,
}
}
}
struct GeneratorState {
last_ts: i64,
sequence: u16,
}
static STATE: OnceLock<Mutex<GeneratorState>> = OnceLock::new();
fn state() -> &'static Mutex<GeneratorState> {
STATE.get_or_init(|| {
Mutex::new(GeneratorState {
last_ts: -1,
sequence: 0,
})
})
}
fn current_ts(epoch_ms: u64) -> i64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time went backwards")
.as_millis() as i64
- epoch_ms as i64
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Snowflake(i64);
impl Snowflake {
pub fn generate(config: &SnowflakeConfig) -> Self {
let mut guard = state().lock().unwrap();
let mut ts = current_ts(config.epoch_ms);
if ts == guard.last_ts {
guard.sequence = (guard.sequence + 1) & MAX_SEQUENCE;
if guard.sequence == 0 {
while ts <= guard.last_ts {
ts = current_ts(config.epoch_ms);
}
}
} else {
guard.sequence = 0;
}
guard.last_ts = ts;
let seq = guard.sequence as i64;
drop(guard);
let id = (ts << (WORKER_BITS + SEQUENCE_BITS) as i64)
| ((config.worker_id as i64 & MAX_WORKER as i64) << SEQUENCE_BITS as i64)
| seq;
Self(id)
}
pub fn new() -> Self {
static CFG: OnceLock<SnowflakeConfig> = OnceLock::new();
Self::generate(CFG.get_or_init(SnowflakeConfig::default))
}
pub fn value(&self) -> i64 {
self.0
}
pub fn timestamp_ms(&self, config: &SnowflakeConfig) -> u64 {
let ts = self.0 >> (WORKER_BITS + SEQUENCE_BITS) as i64;
(ts as u64) + config.epoch_ms
}
pub fn worker_id(&self) -> u16 {
((self.0 >> SEQUENCE_BITS as i64) & MAX_WORKER as i64) as u16
}
pub fn sequence(&self) -> u16 {
(self.0 & MAX_SEQUENCE as i64) as u16
}
}
impl Default for Snowflake {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for Snowflake {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl FromStr for Snowflake {
type Err = IdError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let n: i64 = s
.parse()
.map_err(|_| IdError::InvalidFormat("snowflake", "expected integer"))?;
if n < 0 {
return Err(IdError::InvalidFormat("snowflake", "must be non-negative"));
}
Ok(Self(n))
}
}
impl From<i64> for Snowflake {
fn from(n: i64) -> Self {
Self(n)
}
}
impl From<Snowflake> for i64 {
fn from(s: Snowflake) -> i64 {
s.0
}
}
#[cfg(feature = "crypto-sqlx")]
mod sqlx_impl {
use super::Snowflake;
use sqlx::{
encode::IsNull,
error::BoxDynError,
postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef},
};
impl sqlx::Type<sqlx::Postgres> for Snowflake {
fn type_info() -> PgTypeInfo {
<i64 as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Snowflake {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<i64 as sqlx::Encode<'q, sqlx::Postgres>>::encode_by_ref(&self.0, buf)
}
}
impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Snowflake {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let n = <i64 as sqlx::Decode<sqlx::Postgres>>::decode(value)?;
Ok(Self(n))
}
}
}