use crate::builder::Builder;
use crate::error::*;
use base64::Engine;
use base64::engine::general_purpose;
use chrono::prelude::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// State shared by every clone of a [`Snowflake`] generator.
#[derive(Debug)]
pub struct SharedSnowflake {
    /// Packed generator state: `(elapsed_time << bit_len_sequence) | sequence`.
    pub(crate) state: AtomicU64,
    /// Epoch of the generator, in snowflake time units; elapsed time in each
    /// ID is measured from here.
    pub(crate) start_time: i64,
    /// Data-center identifier embedded in every generated ID.
    pub(crate) data_center_id: u16,
    /// Machine identifier embedded in every generated ID.
    pub(crate) machine_id: u16,
    /// Width in bits of the time field.
    pub(crate) bit_len_time: u8,
    /// Width in bits of the per-tick sequence field.
    pub(crate) bit_len_sequence: u8,
    /// Width in bits of the data-center field.
    pub(crate) bit_len_data_center_id: u8,
    /// Width in bits of the machine field.
    pub(crate) bit_len_machine_id: u8,
}
/// Cheaply clonable handle to a snowflake ID generator; all clones share one
/// [`SharedSnowflake`] state.
#[derive(Debug)]
pub struct Snowflake(pub(crate) Arc<SharedSnowflake>);
impl Snowflake {
pub fn new() -> Result<Self, Error> {
Builder::new().finalize()
}
pub fn builder<'a>() -> Builder<'a> {
Builder::new()
}
pub(crate) fn new_inner(shared: Arc<SharedSnowflake>) -> Self {
Self(shared)
}
pub fn next_id(&self) -> Result<u64, Error> {
let sequence_mask = (1u64 << self.0.bit_len_sequence) - 1;
let time_shift = self.0.bit_len_sequence;
let time_max = (1u64 << self.0.bit_len_time) - 1;
loop {
let current_state = self.0.state.load(Ordering::Relaxed);
let last_time = current_state >> time_shift;
let elapsed_time = current_elapsed_time(self.0.start_time) as u64;
let (next_time, next_sequence) = if elapsed_time == last_time {
let sequence = (current_state & sequence_mask) + 1;
if sequence > sequence_mask {
til_next_millis(self.0.start_time + last_time as i64);
continue; }
(last_time, sequence)
} else {
(elapsed_time, 0)
};
if next_time > time_max {
return Err(Error::OverTimeLimit);
}
let new_state = (next_time << time_shift) | next_sequence;
if self
.0
.state
.compare_exchange_weak(
current_state,
new_state,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
let id = (next_time
<< (self.0.bit_len_data_center_id
+ self.0.bit_len_machine_id
+ self.0.bit_len_sequence))
| ((self.0.data_center_id as u64)
<< (self.0.bit_len_machine_id + self.0.bit_len_sequence))
| ((self.0.machine_id as u64) << self.0.bit_len_sequence)
| next_sequence;
return Ok(id);
}
}
}
pub fn decompose(&self, id: u64) -> DecomposedSnowflake {
DecomposedSnowflake::decompose(
id,
self.0.bit_len_time,
self.0.bit_len_sequence,
self.0.bit_len_data_center_id,
self.0.bit_len_machine_id,
)
}
}
impl Clone for Snowflake {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Length of one snowflake time unit, in nanoseconds (1 ms).
const SNOWFLAKE_TIME_UNIT: i64 = 1_000 * 1_000;
/// Converts a UTC instant to snowflake time units since the Unix epoch.
///
/// The conversion is nanoseconds divided by `SNOWFLAKE_TIME_UNIT`, truncating
/// toward zero. Instants whose nanosecond timestamp does not fit in an `i64`
/// map to `0`.
pub(crate) fn to_snowflake_time(time: DateTime<Utc>) -> i64 {
    time.timestamp_nanos_opt()
        .map_or(0, |nanos| nanos / SNOWFLAKE_TIME_UNIT)
}
/// Returns the number of snowflake time units between `start_time` (already
/// in snowflake units) and now.
///
/// The result is negative if the clock currently reads earlier than `start_time`.
fn current_elapsed_time(start_time: i64) -> i64 {
    let now = to_snowflake_time(Utc::now());
    now - start_time
}
/// Busy-waits until the current snowflake time (see `to_snowflake_time`) is
/// strictly greater than `last_timestamp`.
fn til_next_millis(last_timestamp: i64) {
    while to_snowflake_time(Utc::now()) <= last_timestamp {
        // Signal to the CPU that this is a spin-wait, so it can lower power
        // use or yield execution resources to a sibling hardware thread.
        std::hint::spin_loop();
    }
}
/// The individual fields of a snowflake ID, as produced by
/// [`DecomposedSnowflake::decompose`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DecomposedSnowflake {
    /// The original, undecomposed ID.
    pub id: u64,
    /// Time component (the top bits of the ID).
    pub time: u64,
    /// Per-tick sequence number (the lowest bits of the ID).
    pub sequence: u64,
    /// Data-center component.
    pub data_center_id: u64,
    /// Machine component.
    pub machine_id: u64,
}
impl DecomposedSnowflake {
    /// Splits a raw ID into its fields.
    ///
    /// From most to least significant bit, the layout is
    /// `time | data_center_id | machine_id | sequence`. Each field is as wide
    /// as its `bit_len_*` argument.
    ///
    /// # Panics
    ///
    /// Panics if the four bit lengths do not sum to exactly 63.
    pub fn decompose(
        id: u64,
        bit_len_time: u8,
        bit_len_sequence: u8,
        bit_len_data_center_id: u8,
        bit_len_machine_id: u8,
    ) -> Self {
        let total_bits = bit_len_time as u32
            + bit_len_sequence as u32
            + bit_len_data_center_id as u32
            + bit_len_machine_id as u32;
        assert_eq!(total_bits, 63, "Total bit length must be 63");
        let sequence_shift = 0;
        let machine_id_shift = sequence_shift + bit_len_sequence;
        let data_center_id_shift = machine_id_shift + bit_len_machine_id;
        let time_shift = data_center_id_shift + bit_len_data_center_id;
        let sequence_mask = (1u64 << bit_len_sequence) - 1;
        let machine_id_mask = (1u64 << bit_len_machine_id) - 1;
        let data_center_id_mask = (1u64 << bit_len_data_center_id) - 1;
        Self {
            id,
            time: id >> time_shift,
            data_center_id: (id >> data_center_id_shift) & data_center_id_mask,
            machine_id: (id >> machine_id_shift) & machine_id_mask,
            sequence: (id >> sequence_shift) & sequence_mask,
        }
    }

    /// Returns the time component scaled to nanoseconds
    /// (`time * SNOWFLAKE_TIME_UNIT`).
    ///
    /// For IDs from [`Snowflake::next_id`], this is measured from the
    /// generator's start time, not from the Unix epoch.
    pub fn nanos_time(&self) -> i64 {
        (self.time as i64) * SNOWFLAKE_TIME_UNIT
    }

    /// Returns the ID reinterpreted as a signed 64-bit integer.
    pub fn int64(&self) -> i64 {
        self.id as i64
    }

    /// Returns the ID in decimal.
    pub fn string(&self) -> String {
        self.id.to_string()
    }

    /// Returns the ID in binary, without leading zeros.
    pub fn base2(&self) -> String {
        format!("{:b}", self.id)
    }

    /// Returns the ID in base 32 using the alphabet
    /// `ybndrfg8ejkmcpqxot1uwisza345h769`.
    pub fn base32(&self) -> String {
        const ENCODE_BASE32_MAP: &[u8; 32] = b"ybndrfg8ejkmcpqxot1uwisza345h769";
        Self::encode_radix(self.id, ENCODE_BASE32_MAP)
    }

    /// Returns the ID in base 36 using lowercase digits `0-9a-z`.
    ///
    /// Earlier versions returned hexadecimal here. Use [`Self::base16`] to
    /// get that output.
    pub fn base36(&self) -> String {
        const ENCODE_BASE36_MAP: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";
        Self::encode_radix(self.id, ENCODE_BASE36_MAP)
    }

    /// Returns the ID in lowercase hexadecimal, without leading zeros.
    pub fn base16(&self) -> String {
        format!("{:x}", self.id)
    }

    /// Returns the ID in base 58 using the alphabet
    /// `123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ`.
    pub fn base58(&self) -> String {
        const ENCODE_BASE58_MAP: &[u8; 58] =
            b"123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ";
        Self::encode_radix(self.id, ENCODE_BASE58_MAP)
    }

    /// Returns the big-endian bytes of the ID, encoded as standard
    /// (padded) base64.
    pub fn base64(&self) -> String {
        general_purpose::STANDARD.encode(self.id.to_be_bytes())
    }

    /// Returns the UTF-8 bytes of the decimal representation of the ID.
    pub fn bytes(&self) -> Vec<u8> {
        self.id.to_string().into_bytes()
    }

    /// Returns the ID as 8 big-endian bytes.
    pub fn int_bytes(&self) -> [u8; 8] {
        self.id.to_be_bytes()
    }

    /// Returns the time component as an `i64`.
    pub fn time(&self) -> i64 {
        self.time as i64
    }

    /// Encodes `value` in the positional numeral system whose digits are the
    /// ASCII bytes of `alphabet` (radix = `alphabet.len()`). The most
    /// significant digit comes first, and zero encodes as `alphabet[0]`.
    fn encode_radix(mut value: u64, alphabet: &[u8]) -> String {
        debug_assert!(alphabet.len() >= 2, "radix must be at least 2");
        let radix = alphabet.len() as u64;
        let mut digits = Vec::new();
        loop {
            digits.push(alphabet[(value % radix) as usize]);
            value /= radix;
            if value == 0 {
                break;
            }
        }
        // Digits were produced least-significant first.
        digits.iter().rev().map(|&b| char::from(b)).collect()
    }
}