burn_core/module/param/
id.rs1use core::hash::{BuildHasher, Hasher};
2
3use alloc::string::String;
4use burn_common::id::IdGenerator;
5use data_encoding::BASE32_DNSSEC;
6
7type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
11
/// Identifier of a module parameter, backed by a raw `u64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ParamId {
    /// Raw id value, exposed read-only through `ParamId::val`.
    value: u64,
}
17
18impl From<u64> for ParamId {
19 fn from(value: u64) -> Self {
20 Self { value }
21 }
22}
23
24impl Default for ParamId {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl ParamId {
31 pub fn new() -> Self {
33 Self {
34 value: IdGenerator::generate(),
35 }
36 }
37
38 pub fn val(&self) -> u64 {
40 self.value
41 }
42
43 pub fn serialize(self) -> String {
45 BASE32_DNSSEC.encode(&self.value.to_le_bytes())
46 }
47
48 pub fn deserialize(encoded: &str) -> ParamId {
52 let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
53 Ok(bytes) => {
54 let mut buffer = [0u8; 8];
55 buffer[..bytes.len()].copy_from_slice(&bytes);
56 u64::from_le_bytes(buffer)
57 }
58 Err(err) => match uuid::Uuid::try_parse(encoded) {
59 Ok(id) => {
61 let mut hasher = DefaultHashBuilder::default().build_hasher();
64 hasher.write(id.as_bytes());
66 hasher.finish()
67 }
68 Err(_) => panic!("Invalid id. {err}"),
69 },
70 };
71
72 ParamId::from(u64_id)
73 }
74}
75
76impl core::fmt::Display for ParamId {
77 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78 f.write_str(&self.serialize())
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn param_serde_deserialize() {
        let original = ParamId::from(123456u64);
        let encoded = original.serialize();
        assert_eq!(original, ParamId::deserialize(&encoded));
    }

    #[test]
    fn param_serde_deserialize_legacy() {
        // Legacy ids were 6 raw bytes; they must occupy the low (first little-endian)
        // bytes of the u64, with the remaining bytes zeroed.
        let legacy_bytes = [45u8; 6];
        let encoded = BASE32_DNSSEC.encode(&legacy_bytes);
        let le_bytes = ParamId::deserialize(&encoded).val().to_le_bytes();
        let (low, high) = le_bytes.split_at(6);
        assert_eq!(low, &legacy_bytes[..]);
        assert_eq!(high, &[0u8, 0u8][..]);
    }

    #[test]
    fn param_serde_deserialize_legacy_uuid() {
        // Deserializing the same legacy UUID twice must yield the same id.
        let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
        let first = ParamId::deserialize(legacy_id);
        let second = ParamId::deserialize(legacy_id);
        assert_eq!(first, second);
    }

    #[test]
    #[should_panic(expected = "Invalid id.")]
    fn param_serde_deserialize_invalid_id() {
        // A UUID with a missing group: neither valid base32 nor a parsable UUID.
        let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c";
        ParamId::deserialize(invalid_uuid);
    }
}