contact_tracing/
tkey.rs

1use std::fmt;
2
3use derive_more::{Display, Error};
4use rand::{thread_rng, RngCore};
5
6use crate::utils::Base64DebugFmtHelper;
7
/// Compact, fixed-size representation of contact numbers: 32 opaque bytes.
///
/// Equality, ordering and hashing all operate on the raw bytes.
#[derive(Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TracingKey {
    /// The raw key material.
    bytes: [u8; 32],
}
13
14impl fmt::Debug for TracingKey {
15    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
16        f.debug_tuple("TracingKey")
17            .field(&Base64DebugFmtHelper(self))
18            .finish()
19    }
20}
21
22impl TracingKey {
23    /// Returns a new unique tracing key.
24    pub fn unique() -> TracingKey {
25        let mut bytes = [0u8; 32];
26        let mut rng = thread_rng();
27        rng.fill_bytes(&mut bytes[..]);
28        TracingKey::from_bytes(&bytes[..]).unwrap()
29    }
30
31    /// loads a tracing key from raw bytes.
32    pub fn from_bytes(b: &[u8]) -> Result<TracingKey, InvalidTracingKey> {
33        if b.len() != 32 {
34            return Err(InvalidTracingKey);
35        }
36        let mut bytes = [0u8; 32];
37        bytes.copy_from_slice(b);
38        Ok(TracingKey { bytes })
39    }
40
41    /// Returns the bytes behind the tracing key.
42    pub fn as_bytes(&self) -> &[u8] {
43        &self.bytes
44    }
45}
46
47/// Raised if a tracing key is invalid.
48#[derive(Error, Display, Debug)]
49#[display(fmt = "invalid tracing key")]
50pub struct InvalidTracingKey;
51
#[cfg(feature = "base64")]
mod base64_impl {
    use super::*;
    use std::{fmt, str};

    /// Length of a 32-byte key encoded as unpadded URL-safe base64:
    /// `ceil(32 * 4 / 3) = 43` characters.
    const ENCODED_LEN: usize = 43;

    impl str::FromStr for TracingKey {
        type Err = InvalidTracingKey;

        /// Parses a key from its unpadded URL-safe base64 form.
        ///
        /// Fails with [`InvalidTracingKey`] unless the input is exactly
        /// `ENCODED_LEN` characters and decodes to exactly 32 bytes.
        fn from_str(value: &str) -> Result<TracingKey, InvalidTracingKey> {
            // Rejecting every other length up front also keeps the decoder's
            // output within the 32-byte buffer below.
            if value.len() != ENCODED_LEN {
                return Err(InvalidTracingKey);
            }
            let mut bytes = [0u8; 32];
            let written =
                base64_::decode_config_slice(value, base64_::URL_SAFE_NO_PAD, &mut bytes[..])
                    .map_err(|_| InvalidTracingKey)?;
            // The decoder reports how many bytes it actually wrote. If that is
            // short (possible, depending on the base64 version, when the input
            // contains `=` padding), the tail would silently stay zero, so reject it.
            if written != bytes.len() {
                return Err(InvalidTracingKey);
            }
            Ok(TracingKey { bytes })
        }
    }

    impl fmt::Display for TracingKey {
        /// Formats the key as unpadded URL-safe base64.
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            let mut buf = [0u8; ENCODED_LEN];
            let len =
                base64_::encode_config_slice(self.bytes, base64_::URL_SAFE_NO_PAD, &mut buf);
            // base64 output is ASCII, so this never fails. The check is free next
            // to the encoding itself and avoids an undocumented `unsafe` block.
            let encoded = std::str::from_utf8(&buf[..len]).map_err(|_| fmt::Error)?;
            f.write_str(encoded)
        }
    }
}
79
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;

    use std::fmt;

    use serde_::de::{self, Deserializer, SeqAccess, Visitor};
    use serde_::ser::Serializer;
    use serde_::{Deserialize, Serialize};

    impl Serialize for TracingKey {
        /// Serializes as the key's `Display` string for human-readable formats
        /// and as a raw byte string otherwise.
        ///
        /// NOTE(review): the `Display` impl lives behind the `base64` feature;
        /// confirm that enabling `serde` also enables `base64`.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            if serializer.is_human_readable() {
                // `collect_str` lets serializers consume the `Display` output
                // directly instead of going through a temporary `String`.
                serializer.collect_str(self)
            } else {
                serializer.serialize_bytes(self.as_bytes())
            }
        }
    }

    /// Builds a [`TracingKey`] from a string, a byte string, or a sequence of bytes.
    ///
    /// Accepting byte strings matters because `Serialize` emits
    /// `serialize_bytes`. Deserializing through `Vec<u8>` would instead request a
    /// sequence, which self-describing binary formats reject for byte data.
    /// Sequences are still accepted for formats that encode bytes that way.
    struct TracingKeyVisitor;

    impl<'de> Visitor<'de> for TracingKeyVisitor {
        type Value = TracingKey;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("a tracing key as a string or 32 bytes")
        }

        fn visit_str<E>(self, value: &str) -> Result<TracingKey, E>
        where
            E: de::Error,
        {
            value.parse().map_err(E::custom)
        }

        fn visit_bytes<E>(self, value: &[u8]) -> Result<TracingKey, E>
        where
            E: de::Error,
        {
            TracingKey::from_bytes(value).map_err(E::custom)
        }

        fn visit_seq<A>(self, mut seq: A) -> Result<TracingKey, A::Error>
        where
            A: SeqAccess<'de>,
        {
            let mut bytes = [0u8; 32];
            let mut len = 0;
            while let Some(byte) = seq.next_element::<u8>()? {
                // Bail out as soon as the input is too long instead of buffering it.
                if len == bytes.len() {
                    return Err(de::Error::custom(InvalidTracingKey));
                }
                bytes[len] = byte;
                len += 1;
            }
            if len != bytes.len() {
                return Err(de::Error::custom(InvalidTracingKey));
            }
            Ok(TracingKey { bytes })
        }
    }

    impl<'de> Deserialize<'de> for TracingKey {
        /// Mirrors `Serialize`: a string for human-readable formats, raw bytes
        /// otherwise. Errors from the deserializer propagate unchanged.
        fn deserialize<D>(deserializer: D) -> Result<TracingKey, D::Error>
        where
            D: Deserializer<'de>,
        {
            if deserializer.is_human_readable() {
                deserializer.deserialize_str(TracingKeyVisitor)
            } else {
                deserializer.deserialize_bytes(TracingKeyVisitor)
            }
        }
    }
}