// burn_core/module/param/id.rs
use alloc::string::String;
use core::hash::{BuildHasher, Hasher};

use burn_common::id::IdGenerator;
use data_encoding::BASE32_DNSSEC;
/// Hasher builder used by [`ParamId::deserialize`] to fold a legacy UUID into a `u64` id.
// NOTE(review): whether `AHasher::default()` yields the same hash across separate processes
// depends on ahash's configuration/features — confirm if ids derived from legacy UUIDs must
// be stable between runs (the tests only check stability within one process).
type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
/// Identifier of a module parameter, stored as a raw 64-bit value.
///
/// The id is `Copy`, comparable and hashable, so it can be used directly as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ParamId {
    /// Raw identifier value.
    value: u64,
}
impl From<u64> for ParamId {
fn from(value: u64) -> Self {
Self { value }
}
}
impl Default for ParamId {
fn default() -> Self {
Self::new()
}
}
impl ParamId {
    /// Creates a new id whose value comes from [`IdGenerator::generate`].
    pub fn new() -> Self {
        Self {
            value: IdGenerator::generate(),
        }
    }

    /// Returns the raw `u64` value of this id.
    pub fn val(&self) -> u64 {
        self.value
    }

    /// Encodes the id as text: the little-endian bytes of the value, encoded with
    /// [`BASE32_DNSSEC`]. [`ParamId::deserialize`] reverses this encoding.
    pub fn serialize(self) -> String {
        BASE32_DNSSEC.encode(&self.value.to_le_bytes())
    }

    /// Decodes an id produced by [`ParamId::serialize`], also accepting legacy formats.
    ///
    /// Accepted inputs:
    /// - a [`BASE32_DNSSEC`] payload of at most 8 bytes. Shorter payloads (e.g. older
    ///   6-byte ids) are zero-extended into the high bytes of the little-endian `u64`.
    /// - a UUID string in any form accepted by `uuid::Uuid::try_parse`. Its 16 bytes are
    ///   hashed into a `u64` with `DefaultHashBuilder`.
    ///
    /// # Panics
    ///
    /// Panics with a message starting with `"Invalid id."` when `encoded` is neither a
    /// base32 payload of at most 8 bytes nor a valid UUID.
    pub fn deserialize(encoded: &str) -> ParamId {
        let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
            // Only payloads that fit in a u64 take the base32 path. Longer ones would
            // overflow `buffer` below.
            Ok(bytes) if bytes.len() <= 8 => {
                let mut buffer = [0u8; 8];
                buffer[..bytes.len()].copy_from_slice(&bytes);
                u64::from_le_bytes(buffer)
            }
            // Either not base32 at all (e.g. a hyphenated UUID), or base32 that decodes to
            // more than 8 bytes. The latter includes a hyphen-less UUID, whose hex digits
            // are all valid base32 symbols. Try the legacy UUID format before giving up.
            decoded => match uuid::Uuid::try_parse(encoded) {
                Ok(id) => {
                    let mut hasher = DefaultHashBuilder::default().build_hasher();
                    hasher.write(id.as_bytes());
                    hasher.finish()
                }
                Err(_) => match decoded {
                    Err(err) => panic!("Invalid id. {err}"),
                    Ok(bytes) => panic!(
                        "Invalid id. Decoded {} bytes, expected at most 8.",
                        bytes.len()
                    ),
                },
            },
        };
        ParamId::from(u64_id)
    }
}
/// Displays the id as its serialized text, the same string [`ParamId::serialize`] returns.
impl core::fmt::Display for ParamId {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(formatter, "{}", self.serialize())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized id must deserialize back to the same id.
    #[test]
    fn param_serde_deserialize() {
        let original = ParamId::from(123456u64);
        let encoded = original.serialize();
        let round_tripped = ParamId::deserialize(&encoded);
        assert_eq!(original, round_tripped);
    }

    /// Legacy 6-byte base32 ids fill the low bytes and leave the high bytes zeroed.
    #[test]
    fn param_serde_deserialize_legacy() {
        let legacy_bytes = [45u8; 6];
        let encoded = BASE32_DNSSEC.encode(&legacy_bytes);
        let raw = ParamId::deserialize(&encoded).val().to_le_bytes();
        assert_eq!(raw[0..6], legacy_bytes);
        assert_eq!(raw[6..], [0, 0]);
    }

    /// Deserializing the same legacy UUID twice yields the same id.
    #[test]
    fn param_serde_deserialize_legacy_uuid() {
        let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
        let (first, second) = (
            ParamId::deserialize(legacy_id),
            ParamId::deserialize(legacy_id),
        );
        assert_eq!(first, second);
    }

    /// A malformed UUID that is also not base32 must be rejected.
    #[test]
    #[should_panic = "Invalid id."]
    fn param_serde_deserialize_invalid_id() {
        let malformed = "30b82c23-788d-4d63-ada258d5f13c";
        ParamId::deserialize(malformed);
    }
}