Skip to main content

burn_core/module/param/
id.rs

1use core::hash::{BuildHasher, Hasher};
2
3use alloc::string::String;
4use burn_std::id::IdGenerator;
5use data_encoding::BASE32_DNSSEC;
6
7// Hashbrown changed its default hasher in 0.15, but there are some issues
8// https://github.com/rust-lang/hashbrown/issues/577
9// Also, `param_serde_deserialize_legacy_uuid` doesn't pass with the default hasher.
10type DefaultHashBuilder = core::hash::BuildHasherDefault<ahash::AHasher>;
11
/// Unique ID for a parameter of a module.
///
/// Wraps a single `u64`; ordering, equality and hashing all follow that raw value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ParamId {
    value: u64,
}
17
18impl From<u64> for ParamId {
19    fn from(value: u64) -> Self {
20        Self { value }
21    }
22}
23
24impl Default for ParamId {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl ParamId {
31    /// Create a new parameter ID.
32    pub fn new() -> Self {
33        Self {
34            value: IdGenerator::generate(),
35        }
36    }
37
38    /// Gets the internal value of the id.
39    pub fn val(&self) -> u64 {
40        self.value
41    }
42
43    /// Convert the parameter ID into a string.
44    pub fn serialize(self) -> String {
45        BASE32_DNSSEC.encode(&self.value.to_le_bytes())
46    }
47
48    /// Deserialize a param id.
49    ///
50    /// Preserves compatibility with previous formats (6 bytes, 16-byte uuid).
51    ///
52    /// # Panics
53    /// On invalid id format
54    pub fn deserialize(encoded: &str) -> ParamId {
55        Self::try_deserialize(encoded).expect("Invalid id.")
56    }
57
58    /// Deserialize a param id.
59    ///
60    /// Preserves compatibility with previous formats (6 bytes, 16-byte uuid).
61    ///
62    /// # Returns
63    /// An `Option<ParamId>`
64    pub fn try_deserialize(encoded: &str) -> Option<ParamId> {
65        let u64_id: Option<u64> = match BASE32_DNSSEC.decode(encoded.as_bytes()) {
66            Ok(bytes) => {
67                let mut buffer = [0u8; 8];
68                buffer[..bytes.len()].copy_from_slice(&bytes);
69                Some(u64::from_le_bytes(buffer))
70            }
71            Err(_) => match uuid::Uuid::try_parse(encoded) {
72                // Backward compatibility with uuid parameter identifiers
73                Ok(id) => {
74                    // Hash the 128-bit uuid to 64-bit
75                    // Though not *theoretically* unique, the probability of a collision should be extremely low
76                    let mut hasher = DefaultHashBuilder::default().build_hasher();
77                    // let mut hasher = DefaultHasher::new();
78                    hasher.write(id.as_bytes());
79                    Some(hasher.finish())
80                }
81                Err(_) => None,
82            },
83        };
84        u64_id.map(Self::from)
85    }
86}
87
88impl core::fmt::Display for ParamId {
89    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
90        f.write_str(&self.serialize())
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized id round-trips through `try_deserialize`; garbage yields `None`.
    #[test]
    fn param_serde_try_deserialize() {
        let original = ParamId::from(123456u64);
        let encoded = original.serialize();

        assert_eq!(ParamId::try_deserialize(&encoded), Some(original));
        assert!(ParamId::try_deserialize("invalid_id").is_none());
    }

    /// A serialized id round-trips through the panicking `deserialize`.
    #[test]
    fn param_serde_deserialize() {
        let original = ParamId::from(123456u64);
        assert_eq!(ParamId::deserialize(&original.serialize()), original);
    }

    /// Legacy 6-byte ids land in the low bytes of the u64, with the high bytes zeroed.
    #[test]
    fn param_serde_deserialize_legacy() {
        let legacy_bytes = [45u8; 6];
        let encoded = BASE32_DNSSEC.encode(&legacy_bytes);

        let raw = ParamId::deserialize(&encoded).val().to_le_bytes();
        let (low, high) = raw.split_at(6);

        assert_eq!(low, &legacy_bytes[..]);
        assert_eq!(high, &[0u8, 0][..]);
    }

    /// Deserializing the same legacy uuid twice must produce the same hashed id.
    #[test]
    fn param_serde_deserialize_legacy_uuid() {
        let legacy_id = "30b82c23-788d-4d63-a743-ada258d5f13c";
        let [first, second] = [legacy_id; 2].map(ParamId::deserialize);
        assert_eq!(first, second);
    }

    /// A uuid with a missing group is neither valid base32 nor a valid uuid.
    #[test]
    #[should_panic = "Invalid id."]
    fn param_serde_deserialize_invalid_id() {
        ParamId::deserialize("30b82c23-788d-4d63-ada258d5f13c");
    }
}