burn_core/module/param/
id.rs

1use core::hash::{BuildHasher, Hasher};
2
3use alloc::string::String;
4use burn_common::id::IdGenerator;
5use data_encoding::BASE32_DNSSEC;
6
7// Hashbrown changed its default hasher in 0.15, but there are some issues
8// https://github.com/rust-lang/hashbrown/issues/577
9// Also, `param_serde_deserialize_legacy_uuid` doesn't pass with the default hasher.
10type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
11
/// Parameter ID.
///
/// A thin wrapper around a single `u64`; build one with [`ParamId::new`] or from a raw
/// value through `From<u64>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ParamId {
    value: u64,
}
17
18impl From<u64> for ParamId {
19    fn from(value: u64) -> Self {
20        Self { value }
21    }
22}
23
24impl Default for ParamId {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl ParamId {
31    /// Create a new parameter ID.
32    pub fn new() -> Self {
33        Self {
34            value: IdGenerator::generate(),
35        }
36    }
37
38    /// Gets the internal value of the id.
39    pub fn val(&self) -> u64 {
40        self.value
41    }
42
43    /// Convert the parameter ID into a string.
44    pub fn serialize(self) -> String {
45        BASE32_DNSSEC.encode(&self.value.to_le_bytes())
46    }
47
48    /// Deserialize a param id.
49    ///
50    /// Preserves compatibility with previous formats (6 bytes, 16-byte uuid).
51    pub fn deserialize(encoded: &str) -> ParamId {
52        let u64_id = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
53            Ok(bytes) => {
54                let mut buffer = [0u8; 8];
55                buffer[..bytes.len()].copy_from_slice(&bytes);
56                u64::from_le_bytes(buffer)
57            }
58            Err(err) => match uuid::Uuid::try_parse(encoded) {
59                // Backward compatibility with uuid parameter identifiers
60                Ok(id) => {
61                    // Hash the 128-bit uuid to 64-bit
62                    // Though not *theoretically* unique, the probability of a collision should be extremely low
63                    let mut hasher = DefaultHashBuilder::default().build_hasher();
64                    // let mut hasher = DefaultHasher::new();
65                    hasher.write(id.as_bytes());
66                    hasher.finish()
67                }
68                Err(_) => panic!("Invalid id. {err}"),
69            },
70        };
71
72        ParamId::from(u64_id)
73    }
74}
75
76impl core::fmt::Display for ParamId {
77    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78        f.write_str(&self.serialize())
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    // Current 8-byte format: serialize followed by deserialize must give back the same id.
    #[test]
    fn param_serde_deserialize() {
        let original = ParamId::from(123456u64);
        let encoded = original.serialize();
        let restored = ParamId::deserialize(&encoded);
        assert_eq!(original, restored);
    }

    // Legacy 6-byte ids land in the low bytes; the two high bytes stay zero.
    #[test]
    fn param_serde_deserialize_legacy() {
        let legacy_bytes = [45u8; 6];
        let encoded = BASE32_DNSSEC.encode(&legacy_bytes);
        let le_bytes = ParamId::deserialize(&encoded).val().to_le_bytes();
        assert_eq!(le_bytes[..6], legacy_bytes);
        assert_eq!(le_bytes[6..], [0, 0]);
    }

    // Legacy uuid ids are hashed to 64 bits; repeated calls must agree.
    #[test]
    fn param_serde_deserialize_legacy_uuid() {
        let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
        let first = ParamId::deserialize(legacy_id);
        let second = ParamId::deserialize(legacy_id);
        assert_eq!(first, second);
    }

    // A malformed uuid string must be rejected with an "Invalid id." panic.
    #[test]
    #[should_panic(expected = "Invalid id.")]
    fn param_serde_deserialize_invalid_id() {
        let invalid_uuid = "30b82c23-788d-4d63-ada258d5f13c";
        let _ = ParamId::deserialize(invalid_uuid);
    }
}
116}