use std::sync::{
Arc,
Mutex,
};
use std::thread;
use std::time::{
Duration,
SystemTime,
UNIX_EPOCH,
};
use super::constants::DEFAULT_QUBIT_EPOCH_MILLIS;
use super::time_slice::TimeSlice;
use crate::{
IdError,
IdGenerator,
};
/// Width in bits of the millisecond timestamp, the most significant field of an ID.
const TIMESTAMP_BITS: u8 = 41;
/// Width in bits of the node identifier field, which sits between timestamp and sequence.
const NODE_BITS: u8 = 10;
/// Width in bits of the per-millisecond sequence counter, the least significant field.
const SEQUENCE_BITS: u8 = 12;
/// Largest node identifier that fits in `NODE_BITS` bits (all low `NODE_BITS` bits set).
const MAX_NODE_ID: u64 = !(u64::MAX << NODE_BITS);
/// Generator of 64-bit Snowflake-style IDs that can be shared across threads.
///
/// Each ID packs, from most to least significant: a 41-bit count of
/// milliseconds since `epoch`, a 10-bit node identifier, and a 12-bit
/// per-millisecond sequence number.
pub struct SnowflakeGenerator {
    /// Source of the current time; the system clock by default, or any
    /// closure supplied through `with_clock`.
    clock: Arc<dyn Fn() -> SystemTime + Send + Sync>,
    /// Instant from which ID timestamps are measured.
    epoch: SystemTime,
    /// This generator's node identifier; constructors reject values above `MAX_NODE_ID`.
    node_id: u64,
    /// Timestamp and sequence of the most recently issued ID, behind a mutex so
    /// concurrent `next_id` calls are serialized.
    state: Mutex<TimeSlice>,
}
impl SnowflakeGenerator {
    /// Bit offset of the timestamp field inside a composed ID.
    const TIMESTAMP_SHIFT: u8 = NODE_BITS + SEQUENCE_BITS;

    /// Creates a generator for `node_id` that reads the system clock and uses an
    /// epoch of `DEFAULT_QUBIT_EPOCH_MILLIS` milliseconds after the Unix epoch.
    ///
    /// # Errors
    ///
    /// Returns [`IdError::NodeOutOfRange`] if `node_id` is greater than `MAX_NODE_ID`.
    pub fn new(node_id: u64) -> Result<Self, IdError> {
        let default_epoch = UNIX_EPOCH + Duration::from_millis(DEFAULT_QUBIT_EPOCH_MILLIS);
        Self::with_epoch(node_id, default_epoch)
    }

    /// Creates a generator for `node_id` that measures timestamps from `epoch`
    /// and reads the current time from [`SystemTime::now`].
    ///
    /// # Errors
    ///
    /// Returns [`IdError::NodeOutOfRange`] if `node_id` is greater than `MAX_NODE_ID`.
    pub fn with_epoch(node_id: u64, epoch: SystemTime) -> Result<Self, IdError> {
        Self::with_clock(node_id, epoch, SystemTime::now)
    }

    /// Creates a generator for `node_id` that measures timestamps from `epoch`
    /// and obtains the current time by calling `clock`.
    ///
    /// # Errors
    ///
    /// Returns [`IdError::NodeOutOfRange`] if `node_id` is greater than `MAX_NODE_ID`.
    pub fn with_clock<F>(node_id: u64, epoch: SystemTime, clock: F) -> Result<Self, IdError>
    where
        F: Fn() -> SystemTime + Send + Sync + 'static,
    {
        if node_id <= MAX_NODE_ID {
            Ok(Self {
                state: Mutex::new(TimeSlice::new(0)),
                clock: Arc::new(clock),
                epoch,
                node_id,
            })
        } else {
            Err(IdError::NodeOutOfRange {
                node_id,
                max: MAX_NODE_ID,
            })
        }
    }

    /// Returns the node identifier embedded in every ID this generator composes.
    pub const fn node_id(&self) -> u64 {
        self.node_id
    }

    /// Returns the instant from which ID timestamps are measured.
    pub const fn epoch(&self) -> SystemTime {
        self.epoch
    }

    /// Returns the largest millisecond timestamp representable in the timestamp field.
    pub const fn max_timestamp(&self) -> u64 {
        (1_u64 << TIMESTAMP_BITS) - 1
    }

    /// Returns the largest sequence number representable in the sequence field.
    pub const fn max_sequence(&self) -> u64 {
        (1_u64 << SEQUENCE_BITS) - 1
    }

    /// Packs `timestamp`, this generator's node ID, and `sequence` into one ID.
    ///
    /// # Errors
    ///
    /// - [`IdError::TimestampOverflow`] if `timestamp` exceeds [`Self::max_timestamp`].
    /// - [`IdError::SequenceOverflow`] if `sequence` exceeds [`Self::max_sequence`].
    pub fn compose(&self, timestamp: u64, sequence: u64) -> Result<u64, IdError> {
        let timestamp_limit = self.max_timestamp();
        if timestamp > timestamp_limit {
            return Err(IdError::TimestampOverflow {
                timestamp,
                max: timestamp_limit,
            });
        }
        let sequence_limit = self.max_sequence();
        if sequence > sequence_limit {
            return Err(IdError::SequenceOverflow {
                sequence,
                max: sequence_limit,
            });
        }
        let timestamp_field = timestamp << Self::TIMESTAMP_SHIFT;
        let node_field = self.node_id << SEQUENCE_BITS;
        Ok(timestamp_field | node_field | sequence)
    }

    /// Returns the timestamp field (milliseconds since the epoch) of `id`.
    pub const fn extract_timestamp(&self, id: u64) -> u64 {
        id >> Self::TIMESTAMP_SHIFT
    }

    /// Returns the node identifier field of `id`.
    pub const fn extract_node_id(&self, id: u64) -> u64 {
        (id >> SEQUENCE_BITS) & MAX_NODE_ID
    }

    /// Returns the sequence field of `id`.
    pub const fn extract_sequence(&self, id: u64) -> u64 {
        id & self.max_sequence()
    }

    /// Converts `time` into whole milliseconds elapsed since the epoch.
    ///
    /// Fails with `TimeBeforeEpoch` if `time` precedes the epoch, and with
    /// `TimestampOverflow` if the result exceeds [`Self::max_timestamp`]. In the
    /// overflow case the reported value saturates at `u64::MAX`.
    fn timestamp_for(&self, time: SystemTime) -> Result<u64, IdError> {
        let limit = self.max_timestamp();
        let millis = match time.duration_since(self.epoch) {
            Ok(elapsed) => elapsed.as_millis(),
            Err(_) => return Err(IdError::TimeBeforeEpoch),
        };
        match u64::try_from(millis) {
            Ok(value) if value <= limit => Ok(value),
            Ok(value) => Err(IdError::TimestampOverflow {
                timestamp: value,
                max: limit,
            }),
            Err(_) => Err(IdError::TimestampOverflow {
                timestamp: u64::MAX,
                max: limit,
            }),
        }
    }

    /// Reads the configured clock and converts the reading with `timestamp_for`.
    fn current_timestamp(&self) -> Result<u64, IdError> {
        let now = (self.clock)();
        self.timestamp_for(now)
    }

    /// Polls the clock, sleeping 1 ms between reads, until it yields a
    /// timestamp strictly greater than `last_timestamp`.
    ///
    /// Any clock-read error ends the wait and is returned.
    fn wait_for_next_timestamp(&self, last_timestamp: u64) -> Result<u64, IdError> {
        loop {
            let candidate = self.current_timestamp()?;
            if candidate > last_timestamp {
                return Ok(candidate);
            }
            thread::sleep(Duration::from_millis(1));
        }
    }
}
impl IdGenerator<u64> for SnowflakeGenerator {
    type Error = IdError;

    /// Returns the next ID for this generator's node.
    ///
    /// IDs issued within the same millisecond share a timestamp and receive an
    /// increasing sequence number. When the sequence space for the current
    /// millisecond is exhausted, the call sleeps until the clock reaches a later
    /// millisecond and issues that millisecond's sequence 0.
    ///
    /// # Errors
    ///
    /// - [`IdError::ClockMovedBackwards`] if the clock reads earlier than the
    ///   timestamp of the previously issued ID.
    /// - Any error produced while reading the clock (for example `TimeBeforeEpoch`
    ///   or `TimestampOverflow`).
    fn next_id(&self) -> Result<u64, Self::Error> {
        // `state` is only modified by the two plain field writes at the end of
        // this function, so it stays consistent even if an earlier caller
        // panicked while holding the lock (e.g. inside a user-supplied clock).
        // Recover the guard rather than panicking on every later call.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut timestamp = self.current_timestamp()?;
        if state.timestamp > timestamp {
            return Err(IdError::ClockMovedBackwards {
                last_timestamp: state.timestamp,
                current_timestamp: timestamp,
                skew_millis: state.timestamp - timestamp,
                max_skew_millis: 0,
            });
        }
        let sequence = if timestamp == state.timestamp {
            let next_sequence = state.sequence + 1;
            if next_sequence > self.max_sequence() {
                // Sequence space for this millisecond is exhausted. Wait for the
                // next millisecond while STILL holding the lock: releasing it
                // would let another thread claim that millisecond first, after
                // which this thread would reuse sequence 0 (or rewind `state`)
                // and hand out a duplicate ID.
                timestamp = self.wait_for_next_timestamp(timestamp)?;
                0
            } else {
                next_sequence
            }
        } else {
            0
        };
        state.timestamp = timestamp;
        state.sequence = sequence;
        self.compose(timestamp, sequence)
    }
}