use std::sync::{
Arc,
Mutex,
};
use std::thread;
use std::time::{
Duration,
SystemTime,
UNIX_EPOCH,
};
use super::time_slice::TimeSlice;
use crate::{
IdError,
IdGenerator,
};
/// Sequence bit length used when a caller passes `0` for the sequence bits.
const DEFAULT_BITS_SEQUENCE: u8 = 8;
/// Machine-id bit length used when a caller passes `0` for the machine bits.
const DEFAULT_BITS_MACHINE: u8 = 16;
/// Default length of one time unit, in nanoseconds (10 ms).
const DEFAULT_TIME_UNIT_NANOS: u128 = 10_000_000;
/// Smallest accepted time unit, in nanoseconds (1 ms); shorter units are rejected.
const MIN_TIME_UNIT_NANOS: u128 = 1_000_000;
/// Default epoch, in milliseconds after the Unix epoch (2025-01-01T00:00:00Z).
const DEFAULT_START_MILLIS: u64 = 1_735_689_600_000;
/// Generator of 63-bit identifiers laid out as
/// `elapsed time | sequence | machine id` (most to least significant bits).
///
/// The bit widths, time unit, epoch and clock are fixed at construction time;
/// only the mutex-protected time slice changes while identifiers are produced.
pub struct SonyflakeGenerator {
/// Width of the elapsed-time field: `63 - bits_sequence - bits_machine`.
bits_time: u8,
/// Width of the per-time-unit sequence field.
bits_sequence: u8,
/// Width of the machine-id field.
bits_machine: u8,
/// Length of one time unit; elapsed time is counted in whole units of it.
time_unit: Duration,
/// Epoch from which elapsed time is measured.
start_time: SystemTime,
/// Machine id embedded in the low bits of every identifier.
machine_id: u64,
/// Source of the current time (injectable; `SystemTime::now` by default).
clock: Arc<dyn Fn() -> SystemTime + Send + Sync>,
/// Last used elapsed-time value and sequence, shared between callers.
state: Mutex<TimeSlice>,
}
impl SonyflakeGenerator {
    /// Builds a generator for `machine_id` using the default bit widths, the
    /// default time unit and the default epoch (`DEFAULT_START_MILLIS`
    /// milliseconds after the Unix epoch).
    ///
    /// # Errors
    ///
    /// Returns whatever validation error [`Self::with_clock`] reports.
    pub fn new(machine_id: u64) -> Result<Self, IdError> {
        let default_epoch = UNIX_EPOCH + Duration::from_millis(DEFAULT_START_MILLIS);
        Self::with_epoch(machine_id, default_epoch)
    }

    /// Builds a generator with a caller-chosen epoch and the default bit
    /// widths and time unit.
    ///
    /// # Errors
    ///
    /// Returns whatever validation error [`Self::with_clock`] reports.
    pub fn with_epoch(machine_id: u64, start_time: SystemTime) -> Result<Self, IdError> {
        let default_unit = Duration::from_nanos(DEFAULT_TIME_UNIT_NANOS as u64);
        Self::with_options(
            machine_id,
            DEFAULT_BITS_SEQUENCE,
            DEFAULT_BITS_MACHINE,
            default_unit,
            start_time,
        )
    }

    /// Builds a fully configured generator that reads time from
    /// `SystemTime::now`.
    ///
    /// # Errors
    ///
    /// Returns whatever validation error [`Self::with_clock`] reports.
    pub fn with_options(
        machine_id: u64,
        bits_sequence: u8,
        bits_machine_id: u8,
        time_unit: Duration,
        start_time: SystemTime,
    ) -> Result<Self, IdError> {
        let system_clock = SystemTime::now;
        Self::with_clock(
            machine_id,
            bits_sequence,
            bits_machine_id,
            time_unit,
            start_time,
            system_clock,
        )
    }

    /// Builds a fully configured generator that reads time from `clock`.
    ///
    /// A bit width of `0` selects the corresponding default. The time field
    /// receives the bits left over from 63 after the sequence and machine
    /// fields are allocated.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// * [`IdError::InvalidBitLength`] if the sequence or machine width is 31
    ///   or more, or if fewer than 32 bits remain for the time field;
    /// * [`IdError::InvalidTimeUnit`] if `time_unit` is shorter than
    ///   `MIN_TIME_UNIT_NANOS` nanoseconds;
    /// * [`IdError::StartTimeAhead`] if `start_time` is later than the time
    ///   returned by `clock`;
    /// * [`IdError::MachineIdOutOfRange`] if `machine_id` does not fit in the
    ///   machine field.
    pub fn with_clock<F>(
        machine_id: u64,
        bits_sequence: u8,
        bits_machine_id: u8,
        time_unit: Duration,
        start_time: SystemTime,
        clock: F,
    ) -> Result<Self, IdError>
    where
        F: Fn() -> SystemTime + Send + Sync + 'static,
    {
        let bits_sequence =
            Self::normalize_bits("sequence", bits_sequence, DEFAULT_BITS_SEQUENCE)?;
        let bits_machine =
            Self::normalize_bits("machine", bits_machine_id, DEFAULT_BITS_MACHINE)?;

        // Whatever is left of the 63 usable bits belongs to the time field.
        let leftover = 63_u8
            .checked_sub(bits_sequence)
            .and_then(|rest| rest.checked_sub(bits_machine));
        let bits_time = match leftover {
            Some(remaining) if remaining >= 32 => remaining,
            Some(remaining) => {
                return Err(IdError::InvalidBitLength {
                    name: "time",
                    bits: remaining,
                    reason: "time bit length must be at least 32",
                });
            }
            None => {
                return Err(IdError::InvalidBitLength {
                    name: "time",
                    bits: 0,
                    reason: "63 - sequence bits - machine bits must be at least 32",
                });
            }
        };

        let unit_nanos = time_unit.as_nanos();
        if unit_nanos < MIN_TIME_UNIT_NANOS {
            return Err(IdError::InvalidTimeUnit {
                nanos: unit_nanos,
                min_nanos: MIN_TIME_UNIT_NANOS,
            });
        }

        // The clock is consulted exactly once, and only after the cheaper
        // checks above have passed.
        if clock() < start_time {
            return Err(IdError::StartTimeAhead);
        }

        let machine_limit = Self::low_bits_mask(bits_machine);
        if machine_id > machine_limit {
            return Err(IdError::MachineIdOutOfRange {
                machine_id,
                max: machine_limit,
            });
        }

        // The initial time slice is built from elapsed time 0 and the largest
        // sequence value, exactly as passed to `TimeSlice::with_sequence`.
        let initial_slice = TimeSlice::with_sequence(0, Self::low_bits_mask(bits_sequence));
        Ok(Self {
            bits_time,
            bits_sequence,
            bits_machine,
            time_unit,
            start_time,
            machine_id,
            clock: Arc::new(clock),
            state: Mutex::new(initial_slice),
        })
    }

    /// Returns a value with the lowest `bits` bits set.
    const fn low_bits_mask(bits: u8) -> u64 {
        (1_u64 << bits) - 1
    }

    /// Resolves a requested bit width: `0` means "use `default_bits`".
    ///
    /// # Errors
    ///
    /// Returns [`IdError::InvalidBitLength`] (tagged with `name`) when the
    /// resolved width is 31 or more.
    fn normalize_bits(name: &'static str, bits: u8, default_bits: u8) -> Result<u8, IdError> {
        let effective = match bits {
            0 => default_bits,
            explicit => explicit,
        };
        if effective < 31 {
            return Ok(effective);
        }
        Err(IdError::InvalidBitLength {
            name,
            bits: effective,
            reason: "bit length must be less than 31",
        })
    }

    /// Width of the elapsed-time field in bits.
    pub const fn bits_time(&self) -> u8 {
        self.bits_time
    }

    /// Width of the sequence field in bits.
    pub const fn bits_sequence(&self) -> u8 {
        self.bits_sequence
    }

    /// Width of the machine-id field in bits.
    pub const fn bits_machine(&self) -> u8 {
        self.bits_machine
    }

    /// Largest elapsed-time value that fits in the time field.
    pub const fn max_elapsed_time(&self) -> u64 {
        Self::low_bits_mask(self.bits_time)
    }

    /// Largest sequence value that fits in the sequence field.
    pub const fn max_sequence(&self) -> u64 {
        Self::low_bits_mask(self.bits_sequence)
    }

    /// Largest machine id that fits in the machine field.
    pub const fn max_machine_id(&self) -> u64 {
        Self::low_bits_mask(self.bits_machine)
    }

    /// Packs the three components into one identifier: elapsed time in the
    /// highest bits, then the sequence, then the machine id in the lowest bits.
    ///
    /// # Errors
    ///
    /// Checked in this order: [`IdError::TimestampOverflow`],
    /// [`IdError::SequenceOverflow`], [`IdError::MachineIdOutOfRange`] — each
    /// raised when the corresponding component exceeds its field's maximum.
    pub fn compose(
        &self,
        elapsed_time: u64,
        sequence: u64,
        machine_id: u64,
    ) -> Result<u64, IdError> {
        let time_limit = self.max_elapsed_time();
        if elapsed_time > time_limit {
            return Err(IdError::TimestampOverflow {
                timestamp: elapsed_time,
                max: time_limit,
            });
        }

        let sequence_limit = self.max_sequence();
        if sequence > sequence_limit {
            return Err(IdError::SequenceOverflow {
                sequence,
                max: sequence_limit,
            });
        }

        let machine_limit = self.max_machine_id();
        if machine_id > machine_limit {
            return Err(IdError::MachineIdOutOfRange {
                machine_id,
                max: machine_limit,
            });
        }

        let time_shift = self.bits_sequence + self.bits_machine;
        let time_part = elapsed_time << time_shift;
        let sequence_part = sequence << self.bits_machine;
        Ok(time_part | sequence_part | machine_id)
    }

    /// Reads the elapsed-time component back out of `id`.
    pub fn extract_elapsed_time(&self, id: u64) -> u64 {
        let time_shift = self.bits_sequence + self.bits_machine;
        id >> time_shift
    }

    /// Reads the sequence component back out of `id`.
    pub fn extract_sequence(&self, id: u64) -> u64 {
        // Shift the machine bits away first, then keep only the sequence bits.
        (id >> self.bits_machine) & self.max_sequence()
    }

    /// Reads the machine-id component back out of `id`.
    pub fn extract_machine_id(&self, id: u64) -> u64 {
        id & self.max_machine_id()
    }

    /// Converts `time` into a count of whole time units since the epoch.
    ///
    /// # Errors
    ///
    /// Returns [`IdError::TimeBeforeEpoch`] when `time` precedes the epoch and
    /// [`IdError::TimestampOverflow`] when the count exceeds
    /// [`Self::max_elapsed_time`] (the reported timestamp saturates at
    /// `u64::MAX`).
    fn elapsed_time_for(&self, time: SystemTime) -> Result<u64, IdError> {
        let since_epoch = match time.duration_since(self.start_time) {
            Ok(duration) => duration,
            Err(_) => return Err(IdError::TimeBeforeEpoch),
        };
        let units = since_epoch.as_nanos() / self.time_unit.as_nanos();
        let limit = self.max_elapsed_time();
        match u64::try_from(units) {
            Ok(value) if value <= limit => Ok(value),
            out_of_range => Err(IdError::TimestampOverflow {
                timestamp: out_of_range.unwrap_or(u64::MAX),
                max: limit,
            }),
        }
    }

    /// Reads the configured clock and converts the result into elapsed time
    /// units.
    fn current_elapsed_time(&self) -> Result<u64, IdError> {
        let now = (self.clock)();
        self.elapsed_time_for(now)
    }
}
impl IdGenerator<u64> for SonyflakeGenerator {
    type Error = IdError;

    /// Produces the next identifier.
    ///
    /// The shared time slice is advanced while the state mutex is held:
    /// * if the clock reading is newer than the stored timestamp, the stored
    ///   timestamp jumps to it and the sequence restarts at 0;
    /// * otherwise the sequence is incremented; when it wraps to 0 the stored
    ///   timestamp moves one unit ahead and this call sleeps for the distance
    ///   between that timestamp and the clock reading.
    ///
    /// The `(timestamp, sequence)` pair reserved by this call is captured
    /// before the lock is released. The sleep happens without the lock, so
    /// other callers are not blocked, and the identifier is built from the
    /// captured pair rather than from the shared state, which other threads
    /// may have advanced in the meantime.
    ///
    /// # Errors
    ///
    /// Propagates the error from reading the clock
    /// (`IdError::TimeBeforeEpoch` / `IdError::TimestampOverflow`) and any
    /// error returned by `compose`.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    fn next_id(&self) -> Result<u64, Self::Error> {
        let (timestamp, sequence, overtime) = {
            let mut state = self
                .state
                .lock()
                .expect("generator state mutex should not be poisoned");
            let current = self.current_elapsed_time()?;
            let mut overtime = 0_u64;
            if state.timestamp < current {
                state.timestamp = current;
                state.sequence = 0;
            } else {
                state.sequence = (state.sequence + 1) & self.max_sequence();
                if state.sequence == 0 {
                    // Sequence space for this unit is exhausted: borrow the
                    // next unit and remember how far ahead of the clock we are.
                    state.timestamp += 1;
                    overtime = state.timestamp.saturating_sub(current);
                }
            }
            // Snapshot the pair reserved for THIS call; re-reading the shared
            // state after the sleep could observe another thread's pair and
            // hand out a duplicate identifier.
            (state.timestamp, state.sequence, overtime)
        };

        if overtime > 0 {
            // Saturate instead of truncating so an extreme time unit cannot
            // wrap the sleep length around to a short duration.
            let nanos = u128::from(overtime).saturating_mul(self.time_unit.as_nanos());
            thread::sleep(Duration::from_nanos(
                u64::try_from(nanos).unwrap_or(u64::MAX),
            ));
        }

        self.compose(timestamp, sequence, self.machine_id)
    }
}