burn_core/module/param/
id.rs1use core::hash::{BuildHasher, Hasher};
2
3use alloc::string::String;
4use burn_std::id::IdGenerator;
5use data_encoding::BASE32_DNSSEC;
6
// Hasher builder used by `ParamId::try_deserialize` to fold legacy UUID identifiers into a
// `u64`.
// NOTE(review): `AHasher::default()` seeding depends on which `ahash` features are enabled;
// confirm the resulting hash is stable across processes before relying on it for IDs that are
// persisted and reloaded later.
type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
11
/// Identifier of a module parameter, backed by a single `u64`.
///
/// Equality, ordering and hashing are all derived from the wrapped integer.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Hash)]
pub struct ParamId {
    /// Raw numeric value of the identifier.
    value: u64,
}
17
18impl From<u64> for ParamId {
19 fn from(value: u64) -> Self {
20 Self { value }
21 }
22}
23
24impl Default for ParamId {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl ParamId {
31 pub fn new() -> Self {
33 Self {
34 value: IdGenerator::generate(),
35 }
36 }
37
38 pub fn val(&self) -> u64 {
40 self.value
41 }
42
43 pub fn serialize(self) -> String {
45 BASE32_DNSSEC.encode(&self.value.to_le_bytes())
46 }
47
48 pub fn deserialize(encoded: &str) -> ParamId {
55 Self::try_deserialize(encoded).expect("Invalid id.")
56 }
57
58 pub fn try_deserialize(encoded: &str) -> Option<ParamId> {
65 let u64_id: Option<u64> = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
66 Ok(bytes) => {
67 let mut buffer = [0u8; 8];
68 buffer[..bytes.len()].copy_from_slice(&bytes);
69 Some(u64::from_le_bytes(buffer))
70 }
71 Err(_) => match uuid::Uuid::try_parse(encoded) {
72 Ok(id) => {
74 let mut hasher = DefaultHashBuilder::default().build_hasher();
77 hasher.write(id.as_bytes());
79 Some(hasher.finish())
80 }
81 Err(_) => None,
82 },
83 };
84 u64_id.map(Self::from)
85 }
86}
87
88impl core::fmt::Display for ParamId {
89 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
90 f.write_str(&self.serialize())
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn param_serde_try_deserialize() {
        let original = ParamId::from(123456u64);
        let encoded = original.serialize();
        assert_eq!(ParamId::try_deserialize(&encoded), Some(original));

        // Neither a base32-encoded ID nor a UUID.
        assert_eq!(ParamId::try_deserialize("invalid_id"), None);
    }

    #[test]
    fn param_serde_deserialize() {
        let original = ParamId::from(123456u64);
        let round_tripped = ParamId::deserialize(&original.serialize());
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn param_serde_deserialize_legacy() {
        // A 6-byte legacy payload must land in the low bytes, with the high bytes zeroed.
        let legacy_bytes = [45u8; 6];
        let param_id = ParamId::deserialize(&BASE32_DNSSEC.encode(&legacy_bytes));
        let le_bytes = param_id.val().to_le_bytes();
        let (low, high) = le_bytes.split_at(legacy_bytes.len());
        assert_eq!(low, legacy_bytes);
        assert_eq!(high, [0u8, 0]);
    }

    #[test]
    fn param_serde_deserialize_legacy_uuid() {
        // The same UUID must map to the same ID within one process.
        let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
        let first = ParamId::deserialize(legacy_id);
        let second = ParamId::deserialize(legacy_id);
        assert_eq!(first, second);
    }

    #[test]
    #[should_panic(expected = "Invalid id.")]
    fn param_serde_deserialize_invalid_id() {
        // One UUID group is missing, so this is not a valid UUID.
        let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c";
        let _ = ParamId::deserialize(invalid_uuid);
    }
}
137}