use std::{
fmt::{self, Display},
hash::Hash,
};
use ahash::RandomState;
/// Identifier that is either numeric or textual.
///
/// The numeric form wraps a `u32`; the textual form owns a `String`
/// (this file's tests use values such as `"2023 AB1"`).
///
/// NOTE(review): the derived `Hash` feeds `stable_hash`, whose values are
/// pinned by tests in this file — presumably reordering or adding variants
/// would change those values; confirm before touching the layout.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
// Serialization support is only compiled in when the `serde` feature is on.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum TrajId {
/// Numeric identifier.
Int(u32),
/// Textual identifier.
Str(String),
}
impl TrajId {
    /// Returns a 64-bit hash of this identifier computed with fixed seeds.
    ///
    /// The hasher state is built from the constant seeds `(1, 2, 3, 4)`
    /// instead of per-process random keys, so repeated calls on equal values
    /// yield the same number.
    ///
    /// NOTE(review): `ahash` does not promise identical output across crate
    /// versions or CPU feature sets — confirm before persisting these values.
    pub fn stable_hash(&self) -> u64 {
        let seeded_state = RandomState::with_seeds(1, 2, 3, 4);
        seeded_state.hash_one(self)
    }
}
/// Renders the identifier as its bare inner value: the decimal number for
/// `Int`, the string itself for `Str`.
impl Display for TrajId {
    /// Writes the inner value to `f`.
    ///
    /// Formatting is delegated to the inner value's own `Display`
    /// implementation so that formatter flags supplied by the caller (width,
    /// fill, alignment, ...) are honoured, e.g. `format!("{:>8}", id)` pads
    /// the output. Going through `write!(f, "{}", ..)` would start a fresh
    /// format spec and silently drop those flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TrajId::Int(n) => Display::fmt(n, f),
            TrajId::Str(s) => Display::fmt(s, f),
        }
    }
}
/// Produces an owned `TrajId` from a borrowed one by cloning it.
impl From<&TrajId> for TrajId {
    #[inline]
    fn from(id: &TrajId) -> Self {
        id.to_owned()
    }
}
impl From<u32> for TrajId {
#[inline]
fn from(n: u32) -> Self {
TrajId::Int(n)
}
}
impl From<u16> for TrajId {
#[inline]
fn from(n: u16) -> Self {
TrajId::Int(n as u32)
}
}
impl From<u8> for TrajId {
#[inline]
fn from(n: u8) -> Self {
TrajId::Int(n as u32)
}
}
impl From<&u32> for TrajId {
#[inline]
fn from(n: &u32) -> Self {
TrajId::Int(*n)
}
}
impl From<&u16> for TrajId {
#[inline]
fn from(n: &u16) -> Self {
TrajId::Int(*n as u32)
}
}
impl From<&u8> for TrajId {
#[inline]
fn from(n: &u8) -> Self {
TrajId::Int(*n as u32)
}
}
impl From<String> for TrajId {
#[inline]
fn from(s: String) -> Self {
TrajId::Str(s)
}
}
impl From<&String> for TrajId {
#[inline]
fn from(s: &String) -> Self {
TrajId::Str(s.clone())
}
}
impl From<&str> for TrajId {
#[inline]
fn from(s: &str) -> Self {
TrajId::Str(s.to_string())
}
}
impl TryFrom<usize> for TrajId {
type Error = std::num::TryFromIntError;
#[inline]
fn try_from(n: usize) -> Result<Self, Self::Error> {
Ok(TrajId::Int(u32::try_from(n)?))
}
}
impl TryFrom<u64> for TrajId {
type Error = std::num::TryFromIntError;
#[inline]
fn try_from(n: u64) -> Result<Self, Self::Error> {
Ok(TrajId::Int(u32::try_from(n)?))
}
}
impl TryFrom<i64> for TrajId {
type Error = &'static str;
#[inline]
fn try_from(n: i64) -> Result<Self, Self::Error> {
if n < 0 {
return Err("negative value is not a valid TrajId::Int");
}
let n = u64::try_from(n).map_err(|_| "conversion failed")?;
let n = u32::try_from(n).map_err(|_| "value exceeds u32 range")?;
Ok(TrajId::Int(n))
}
}
/// Parses text into a `TrajId`, preferring the numeric variant.
impl std::str::FromStr for TrajId {
    type Err = std::num::ParseIntError;

    /// Interprets `s` as a `u32` when possible, otherwise as text.
    ///
    /// * Anything `u32`'s parser accepts becomes `TrajId::Int`.
    /// * Input containing at least one character that is not an ASCII digit
    ///   becomes `TrajId::Str` holding a copy of `s`.
    ///
    /// # Errors
    ///
    /// Returns the `u32` parse error when `s` has no non-digit character yet
    /// still fails to parse: the empty string, or a digit run too large for
    /// `u32`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        s.parse::<u32>().map(TrajId::Int).or_else(|parse_err| {
            // Digit-only input that failed to parse is a genuine numeric
            // error rather than a textual id.
            let digits_only = s.chars().all(|c| c.is_ascii_digit());
            if digits_only {
                Err(parse_err)
            } else {
                Ok(TrajId::Str(s.to_owned()))
            }
        })
    }
}
#[cfg(test)]
mod traj_id_tests {
    use super::*;

    /// Hashing the same value twice gives the same result.
    #[test]
    fn test_stable_hash_deterministic() {
        let samples = [TrajId::Int(42), TrajId::Str("2023 AB1".to_string())];
        for sample in &samples {
            assert_eq!(sample.stable_hash(), sample.stable_hash());
        }
    }

    /// Different values (including `Int(0)` vs `Str("0")`) hash differently.
    #[test]
    fn test_stable_hash_distinct_inputs() {
        let zero_hash = TrajId::Int(0).stable_hash();
        let one_hash = TrajId::Int(1).stable_hash();
        let text_zero_hash = TrajId::Str("0".to_string()).stable_hash();
        assert_ne!(zero_hash, one_hash);
        assert_ne!(zero_hash, text_zero_hash);
    }

    /// Pins the hash of two reference values to previously recorded numbers.
    #[test]
    fn test_stable_hash_cross_run_stability() {
        let int_hash = TrajId::Int(42).stable_hash();
        let str_hash = TrajId::Str("2023 AB1".to_string()).stable_hash();
        assert_eq!(14966747408011497582, int_hash);
        assert_eq!(16188224256132921782, str_hash);
    }

    /// XOR-ing a fixed base seed with the hash never yields zero for the
    /// sampled ids.
    #[test]
    fn test_stable_hash_xor_seed_nonzero() {
        let base_seed: u64 = 0xdeadbeefcafebabe;
        let samples = [
            TrajId::Int(0),
            TrajId::Int(1),
            TrajId::Int(u32::MAX),
            TrajId::Str(String::new()),
            TrajId::Str("test".to_string()),
        ];
        samples.iter().for_each(|sample| {
            let mixed = base_seed ^ sample.stable_hash();
            assert_ne!(mixed, 0);
        });
    }

    /// A clone hashes to the same value as its source.
    #[test]
    fn test_stable_hash_clone_equality() {
        let original = TrajId::Str("C/2024 X1".to_string());
        let duplicate = original.clone();
        assert_eq!(original.stable_hash(), duplicate.stable_hash());
    }
}