use crate::{Epoch, Error, Layout, Result, Snowflake};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::cmp;
use std::hash::{Hash, Hasher};
use std::time::SystemTime;
/// Timestamp-only value that snowflakes and other comparators can be compared against.
///
/// The comparator wraps a single `u64`. The constructors visible in this module fill it with
/// a number of milliseconds counted from the Unix epoch, which makes comparators that were
/// built from snowflakes with different epochs directly comparable.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[repr(transparent)]
pub struct SnowflakeComparator {
    /// Milliseconds since the Unix epoch.
    timestamp: u64,
}
impl SnowflakeComparator {
pub fn from_system_time(time: SystemTime) -> Result<Self> {
let timestamp = time
.duration_since(SystemTime::UNIX_EPOCH)
.map_err(|_| Error::InvalidEpoch)?
.as_millis();
if timestamp > u64::MAX as u128 {
return Err(Error::FatalSnowflakeExhaustion);
}
Ok(Self {
timestamp: timestamp as u64,
})
}
pub fn from_timestamp<E>(timestamp: u64) -> Result<Self>
where
E: Epoch,
{
Ok(Self {
timestamp: Self::convert_epoch_timestamp::<E>(timestamp)?,
})
}
pub fn from_raw_timestamp(timestamp: u64) -> Self {
Self { timestamp }
}
fn convert_epoch_timestamp<E>(timestamp: u64) -> Result<u64>
where
E: Epoch,
{
E::millis_since_unix()
.checked_add(timestamp)
.ok_or(Error::FatalSnowflakeExhaustion)
}
}
/// Converts a snowflake into a comparator that carries the snowflake's timestamp.
///
/// The timestamp is extracted with the layout `L` and then shifted by the epoch `E`, so the
/// conversion fails with [`Error::FatalSnowflakeExhaustion`] if that shift overflows a `u64`.
impl<L, E> TryFrom<Snowflake<L, E>> for SnowflakeComparator
where
    L: Layout,
    E: Epoch,
{
    type Error = Error;

    fn try_from(value: Snowflake<L, E>) -> std::result::Result<Self, Self::Error> {
        // The layout yields a timestamp relative to `E`; `from_timestamp` applies the offset.
        let epoch_relative = L::timestamp(value.inner);
        Self::from_timestamp::<E>(epoch_relative)
    }
}
impl TryFrom<SystemTime> for SnowflakeComparator {
type Error = Error;
fn try_from(value: SystemTime) -> std::result::Result<Self, Self::Error> {
Self::from_system_time(value)
}
}
/// Two comparators are equal exactly when their stored timestamps are equal.
impl PartialEq for SnowflakeComparator {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.timestamp, other.timestamp);
        lhs == rhs
    }
}
/// Compares a comparator with the timestamp of a snowflake.
///
/// The snowflake's timestamp is first shifted by its epoch `E`. If that shift overflows a
/// `u64`, the snowflake cannot match any comparator and the result is `false`.
impl<L, E> PartialEq<Snowflake<L, E>> for SnowflakeComparator
where
    L: Layout,
    E: Epoch,
{
    fn eq(&self, other: &Snowflake<L, E>) -> bool {
        matches!(
            Self::convert_epoch_timestamp::<E>(other.timestamp_raw()),
            Ok(unix_millis) if unix_millis == self.timestamp
        )
    }
}
/// Compares the timestamp of a snowflake with a comparator.
///
/// This mirrors the comparison in the opposite direction: a snowflake whose timestamp
/// overflows a `u64` once its epoch `E` is applied is never equal to a comparator.
impl<L, E> PartialEq<SnowflakeComparator> for Snowflake<L, E>
where
    L: Layout,
    E: Epoch,
{
    fn eq(&self, other: &SnowflakeComparator) -> bool {
        // A failed conversion becomes `None`, which never equals `Some(_)`.
        let unix_millis =
            SnowflakeComparator::convert_epoch_timestamp::<E>(self.timestamp_raw()).ok();
        unix_millis == Some(other.timestamp)
    }
}
// Marker impl: equality between comparators compares the single `u64` timestamp field, so it
// is reflexive, symmetric and transitive.
impl Eq for SnowflakeComparator {}
/// Hashes only the stored timestamp, which keeps hashing consistent with equality.
impl Hash for SnowflakeComparator {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.timestamp, state);
    }
}
/// Comparators are totally ordered, so the partial ordering always yields a result.
impl PartialOrd for SnowflakeComparator {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        // Delegate to `Ord` so that both orderings can never disagree.
        Some(Ord::cmp(self, other))
    }
}
/// Orders a comparator relative to the timestamp of a snowflake.
///
/// The snowflake's timestamp is shifted by its epoch `E` before the comparison. If that shift
/// overflows a `u64`, the snowflake lies beyond every representable comparator, so the
/// comparator is reported as `Less`. The result is always `Some`.
impl<L, E> PartialOrd<Snowflake<L, E>> for SnowflakeComparator
where
    L: Layout,
    E: Epoch,
{
    fn partial_cmp(&self, other: &Snowflake<L, E>) -> Option<cmp::Ordering> {
        let ordering = Self::convert_epoch_timestamp::<E>(other.timestamp_raw())
            .map_or(cmp::Ordering::Less, |unix_millis| {
                self.timestamp.cmp(&unix_millis)
            });
        Some(ordering)
    }
}
/// Orders the timestamp of a snowflake relative to a comparator.
///
/// The snowflake's timestamp is shifted by its epoch `E` before the comparison. If that shift
/// overflows a `u64`, the snowflake lies beyond every representable comparator and is
/// reported as `Greater`. The result is always `Some`.
impl<L, E> PartialOrd<SnowflakeComparator> for Snowflake<L, E>
where
    L: Layout,
    E: Epoch,
{
    fn partial_cmp(&self, other: &SnowflakeComparator) -> Option<cmp::Ordering> {
        Some(
            match SnowflakeComparator::convert_epoch_timestamp::<E>(self.timestamp_raw()) {
                Ok(unix_millis) => unix_millis.cmp(&other.timestamp),
                Err(_) => cmp::Ordering::Greater,
            },
        )
    }
}
/// Comparators are ordered by their stored timestamp.
impl Ord for SnowflakeComparator {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        Ord::cmp(&self.timestamp, &other.timestamp)
    }
}
#[cfg(test)]
mod tests {
    use crate::{ClassicLayout, Epoch, Error, Layout, MachineId, Snowflake, SnowflakeComparator};
    use std::time::{Duration, SystemTime};

    /// Epoch located exactly at the Unix epoch, so the conversion adds no offset.
    struct SimpleEpoch;

    impl Epoch for SimpleEpoch {
        fn millis_since_unix() -> u64 {
            0
        }
    }

    /// Reference timestamp used by the tests: 1000 milliseconds.
    const SECOND_SECOND: u64 = 1000;

    /// Covers the `SystemTime` constructors: a valid instant, an instant before the Unix
    /// epoch, and an instant whose millisecond count exceeds `u64::MAX`.
    #[test]
    fn from_system_time() {
        let valid_instant = SystemTime::UNIX_EPOCH + Duration::from_millis(SECOND_SECOND);
        let comparator = SnowflakeComparator::from_system_time(valid_instant).unwrap();
        verify_comparator(comparator, SECOND_SECOND);

        // Instants before the Unix epoch have no non-negative millisecond representation.
        let before_epoch = SystemTime::UNIX_EPOCH - Duration::from_millis(1);
        assert_eq!(
            Error::InvalidEpoch,
            SnowflakeComparator::from_system_time(before_epoch).unwrap_err(),
            "snowflake comparator didn't detect a \"negative\" timestamp"
        );
        assert_eq!(
            Error::InvalidEpoch,
            SnowflakeComparator::try_from(before_epoch).unwrap_err(),
            "snowflake comparator didn't detect a \"negative\" timestamp"
        );

        // One millisecond more than a `u64` can hold.
        let beyond_u64 =
            SystemTime::UNIX_EPOCH + Duration::from_millis(u64::MAX) + Duration::from_millis(1);
        assert_eq!(
            Error::FatalSnowflakeExhaustion,
            SnowflakeComparator::from_system_time(beyond_u64).unwrap_err(),
            "snowflake comparator accepted timestamp that exceeds its data type"
        );
        assert_eq!(
            Error::FatalSnowflakeExhaustion,
            SnowflakeComparator::try_from(beyond_u64).unwrap_err(),
            "snowflake comparator accepted timestamp that exceeds its data type"
        );
    }

    /// Covers epoch-relative construction with a zero offset, a non-zero offset, and an
    /// offset that overflows.
    #[test]
    fn from_timestamp() {
        /// Epoch located one reference timestamp after the Unix epoch.
        struct OtherEpoch;

        impl Epoch for OtherEpoch {
            fn millis_since_unix() -> u64 {
                SECOND_SECOND
            }
        }

        let without_offset =
            SnowflakeComparator::from_timestamp::<SimpleEpoch>(SECOND_SECOND).unwrap();
        verify_comparator(without_offset, SECOND_SECOND);

        let with_offset = SnowflakeComparator::from_timestamp::<OtherEpoch>(SECOND_SECOND).unwrap();
        verify_comparator(with_offset, SECOND_SECOND * 2);

        assert_eq!(
            Error::FatalSnowflakeExhaustion,
            SnowflakeComparator::from_timestamp::<OtherEpoch>(u64::MAX).unwrap_err()
        );
    }

    /// Covers construction from a timestamp that is stored without any conversion.
    #[test]
    fn from_raw_timestamp() {
        verify_comparator(
            SnowflakeComparator::from_raw_timestamp(SECOND_SECOND),
            SECOND_SECOND,
        );
    }

    /// Verifies that `comparator` behaves like the timestamp `timestamp` when compared with
    /// other comparators and with snowflakes, in both directions.
    // NOTE(review): the lints are presumably silenced for the explicit boundary assertions
    // inside this function — confirm before removing the attribute.
    #[allow(clippy::nonminimal_bool, clippy::eq_op)]
    fn verify_comparator(comparator: SnowflakeComparator, timestamp: u64) {
        use crate::snowflake_tests::{validate_ord, validate_partial_ord};

        /// Layout with the timestamp in the upper 32 bits and the sequence number in the
        /// lower 32 bits; it also serves as an epoch without offset.
        #[derive(Debug)]
        struct SimpleParams;

        impl Layout for SimpleParams {
            fn construct_snowflake(timestamp: u64, sequence_number: u64) -> u64 {
                assert!(!Self::exceeds_timestamp(timestamp) && !Self::exceeds_sequence_number(sequence_number));
                (timestamp << 32) | sequence_number
            }

            fn timestamp(input: u64) -> u64 {
                input >> 32
            }

            fn exceeds_timestamp(input: u64) -> bool {
                input > u64::from(u32::MAX)
            }

            fn sequence_number(input: u64) -> u64 {
                input & u64::from(u32::MAX)
            }

            fn exceeds_sequence_number(input: u64) -> bool {
                input > u64::from(u32::MAX)
            }

            fn is_valid_snowflake(_input: u64) -> bool {
                true
            }
        }

        impl Epoch for SimpleParams {
            fn millis_since_unix() -> u64 {
                0
            }
        }

        type SimpleSnowflake = Snowflake<SimpleParams, SimpleParams>;

        // Both neighbours (`timestamp - 1` and `timestamp + 1`) must exist.
        assert!(timestamp > u64::MIN && timestamp < u64::MAX);

        // Comparator against comparators.
        let (cmp_less, cmp_equal, cmp_greater) = (
            SnowflakeComparator::from_raw_timestamp(timestamp - 1),
            SnowflakeComparator::from_raw_timestamp(timestamp),
            SnowflakeComparator::from_raw_timestamp(timestamp + 1),
        );
        validate_partial_ord(comparator, cmp_less, cmp_equal, cmp_greater);
        validate_ord(comparator, cmp_less, cmp_equal, cmp_greater);

        // Comparator against snowflakes. The sequence numbers run opposite to the timestamps
        // so that only the timestamp part can produce the expected ordering.
        let raw_snowflake = |millis: u64, sequence: u64| (millis << 32) | sequence;
        let (flake_less, flake_equal, flake_greater) = (
            SimpleSnowflake::from_raw(raw_snowflake(timestamp - 1, 3)).unwrap(),
            SimpleSnowflake::from_raw(raw_snowflake(timestamp, 2)).unwrap(),
            SimpleSnowflake::from_raw(raw_snowflake(timestamp + 1, 1)).unwrap(),
        );
        validate_partial_ord(comparator, flake_less, flake_equal, flake_greater);

        // Snowflake against comparators.
        let snowflake = SimpleSnowflake::from_raw(raw_snowflake(timestamp, 0)).unwrap();
        validate_partial_ord(snowflake, cmp_less, cmp_equal, cmp_greater);
    }

    /// With an epoch at `u64::MAX`, converting a non-zero snowflake timestamp overflows, so
    /// such snowflakes must compare greater than, and unequal to, every comparator.
    #[test]
    fn extreme_epoch() {
        /// Epoch at the very end of the `u64` range; also provides a machine ID of zero.
        struct ExtremeParams;

        impl Epoch for ExtremeParams {
            fn millis_since_unix() -> u64 {
                u64::MAX
            }
        }

        impl MachineId for ExtremeParams {
            fn machine_id() -> u64 {
                0
            }
        }

        type ExtremeSnowflake = Snowflake<ClassicLayout<ExtremeParams>, ExtremeParams>;

        let (small_comparator, large_comparator) = (
            SnowflakeComparator::from_raw_timestamp(0),
            SnowflakeComparator::from_raw_timestamp(u64::MAX),
        );

        // NOTE(review): presumably the largest and the smallest non-zero timestamp of the
        // classic layout — confirm against `ClassicLayout`.
        for raw in [(u64::MAX << 23) >> 1, 1 << 22] {
            let extreme = ExtremeSnowflake::from_raw(raw).unwrap();
            assert!(small_comparator < extreme);
            assert!(large_comparator < extreme);
            assert!(extreme > small_comparator);
            assert!(extreme > large_comparator);
            assert!(extreme != small_comparator && extreme != large_comparator);
            assert!(small_comparator != extreme && large_comparator != extreme);
        }
    }
}