use std::cell::RefCell;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
use crate::error::SnowflakeError;
use crate::generator::SnowflakeIdGenerator;
use crate::layout::BitLayout;
use std::time::{SystemTime, UNIX_EPOCH};
/// Flag set by `init_with_epoch`; read by `is_initialized`, which gates ID
/// generation through `ensure_init`.
static INITIALIZED: AtomicBool = AtomicBool::new(false);
/// Machine id supplied to `init_with_epoch`; passed to every generator built by
/// `create_generator`.
static MACHINE_ID: AtomicI64 = AtomicI64::new(0);
/// Next node id to hand out. `create_generator` claims one value per generator it
/// builds and refuses once the counter exceeds the layout's `max_node_id()`.
static NODE_COUNTER: AtomicI64 = AtomicI64::new(0);
/// Bit layout supplied to `init_with_epoch`; `create_generator` falls back to
/// `BitLayout::default()` if it was never set.
static LAYOUT: OnceLock<BitLayout> = OnceLock::new();
/// Custom epoch supplied to `init_with_epoch`; `create_generator` falls back to
/// `UNIX_EPOCH` if it was never set.
static EPOCH: OnceLock<SystemTime> = OnceLock::new();
thread_local! {
// Per-thread generator: created lazily the first time `next_id` /
// `real_time_next_id` runs on a thread, then reused for later calls on that thread.
static LOCAL_GEN: RefCell<Option<SnowflakeIdGenerator>> = const { RefCell::new(None) };
}
pub fn init(machine_id: i64, layout: BitLayout) -> Result<(), SnowflakeError> {
init_with_epoch(machine_id, layout, UNIX_EPOCH)
}
pub fn init_with_epoch(
machine_id: i64,
layout: BitLayout,
epoch: SystemTime,
) -> Result<(), SnowflakeError> {
if INITIALIZED.swap(true, Ordering::SeqCst) {
return Err(SnowflakeError::AlreadyInitialized);
}
MACHINE_ID.store(machine_id, Ordering::SeqCst);
LAYOUT.set(layout).ok();
EPOCH.set(epoch).ok();
Ok(())
}
/// Returns whether the global generator has been configured via [`init`] or
/// [`init_with_epoch`].
pub fn is_initialized() -> bool {
    // Acquire (not Relaxed) so that writes `init_with_epoch` made before setting
    // the flag are visible to a caller that observes `true`. `create_generator`
    // relies on this: it reads `MACHINE_ID` with Relaxed ordering after
    // `ensure_init` has checked this flag.
    INITIALIZED.load(Ordering::Acquire)
}
pub fn next_id() -> Result<i64, SnowflakeError> {
ensure_init();
LOCAL_GEN.with(|cell| {
let mut opt = cell.borrow_mut();
if opt.is_none() {
*opt = Some(create_generator()?);
}
opt.as_mut().unwrap().generate()
})
}
pub fn real_time_next_id() -> Result<i64, SnowflakeError> {
ensure_init();
LOCAL_GEN.with(|cell| {
let mut opt = cell.borrow_mut();
if opt.is_none() {
*opt = Some(create_generator()?);
}
opt.as_mut().unwrap().real_time_generate()
})
}
/// Builds a fresh generator for the calling thread.
///
/// Reads the global machine id, layout and epoch, falling back to
/// `BitLayout::default()` and `UNIX_EPOCH` if they were never set. Claims the
/// next unused node id from `NODE_COUNTER`. Nothing in this module returns a
/// node id when a thread exits, so each thread that generates IDs keeps one for
/// the life of the process.
///
/// # Errors
///
/// * [`SnowflakeError::NodeIdExhausted`] once every node id in
///   `0..=layout.max_node_id()` has been handed out.
/// * Any error from `SnowflakeIdGenerator::with_layout_and_epoch`. In that case
///   the claimed node id is handed back, unless another thread has already
///   claimed a later one.
fn create_generator() -> Result<SnowflakeIdGenerator, SnowflakeError> {
    let machine_id = MACHINE_ID.load(Ordering::Relaxed);
    let layout = LAYOUT.get().copied().unwrap_or_default();
    let max = layout.max_node_id();

    // Atomically take the current counter value as our node id and advance the
    // counter. Refuse once it has passed `max`, so it stops growing.
    let node_id = NODE_COUNTER
        .fetch_update(Ordering::SeqCst, Ordering::Relaxed, |current| {
            (current <= max).then_some(current + 1)
        })
        .map_err(|_| SnowflakeError::NodeIdExhausted { max })?;

    let epoch = EPOCH.get().copied().unwrap_or(UNIX_EPOCH);
    let result = SnowflakeIdGenerator::with_layout_and_epoch(machine_id, node_id, layout, epoch);
    if result.is_err() {
        // The id never backed a live generator. Hand it back so a persistent
        // construction error does not drain the pool one id per retry and end
        // up masked as `NodeIdExhausted`. Best effort: if another thread has
        // claimed a later id meanwhile, the CAS fails and the counter is left
        // alone, so no id can ever be issued twice.
        let _ = NODE_COUNTER.compare_exchange(
            node_id + 1,
            node_id,
            Ordering::SeqCst,
            Ordering::Relaxed,
        );
    }
    result
}
/// Guards every ID-producing entry point.
///
/// # Panics
///
/// Panics with a descriptive message when `is_initialized()` reports `false`,
/// i.e. before `init` / `init_with_epoch` has succeeded.
#[inline]
fn ensure_init() {
    assert!(
        is_initialized(),
        "Snowflake system must be initialized with `snowflake_gen::init(...)` before use."
    );
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Configuration is process-global and tests run in parallel, so every test
    // calls `init` and ignores the outcome: whichever test runs first performs
    // the real initialization, and the rest get `AlreadyInitialized`.

    #[test]
    fn test_global_initialization() {
        let _ = init(1, BitLayout::default());
        assert!(is_initialized());
    }

    #[test]
    fn test_reinit_returns_already_initialized() {
        let _ = init(1, BitLayout::default());
        let second_attempt = init(2, BitLayout::default());
        assert!(matches!(
            second_attempt,
            Err(SnowflakeError::AlreadyInitialized)
        ));
    }

    #[test]
    fn test_next_id_concurrently() {
        let _ = init(1, BitLayout::default());

        // Spawn all workers before joining any, so they run concurrently.
        let workers: Vec<thread::JoinHandle<i64>> = (0..10)
            .map(|_| {
                thread::spawn(|| {
                    let id = next_id().expect("should generate ID");
                    assert!(id > 0);
                    id
                })
            })
            .collect();
        let ids: Vec<i64> = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .collect();

        let mut distinct = ids.clone();
        distinct.sort_unstable();
        distinct.dedup();
        assert_eq!(
            ids.len(),
            distinct.len(),
            "IDs should be unique across threads"
        );
    }

    #[test]
    fn test_real_time_next_id() {
        let _ = init(1, BitLayout::default());
        let id = real_time_next_id().expect("should generate ID");
        assert!(id > 0);
    }
}