pykrete_jsonwebkey/
byte_array.rs

1use generic_array::{ArrayLength, GenericArray};
2use serde::{
3    de::{self, Deserializer},
4    Deserialize, Serialize,
5};
6use zeroize::{Zeroize, Zeroizing};
7
8/// A zeroizing-on-drop container for a `[u8; N]` that deserializes from base64.
9#[derive(Clone, PartialEq, Eq, Serialize)]
10#[serde(transparent)]
11pub struct ByteArray<N: ArrayLength<u8>>(
12    #[serde(serialize_with = "crate::utils::serde_base64::serialize")] GenericArray<u8, N>,
13);
14
15impl<N: ArrayLength<u8>> std::fmt::Debug for ByteArray<N> {
16    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17        f.write_str(&crate::utils::base64_encode(&self.0))
18    }
19}
20
21impl<N: ArrayLength<u8>, T: Into<GenericArray<u8, N>>> From<T> for ByteArray<N> {
22    fn from(arr: T) -> Self {
23        Self(arr.into())
24    }
25}
26
27impl<N: ArrayLength<u8>> Drop for ByteArray<N> {
28    fn drop(&mut self) {
29        Zeroize::zeroize(self.0.as_mut_slice())
30    }
31}
32
33impl<N: ArrayLength<u8>> AsRef<[u8]> for ByteArray<N> {
34    fn as_ref(&self) -> &[u8] {
35        &self.0
36    }
37}
38
39impl<N: ArrayLength<u8>> std::ops::Deref for ByteArray<N> {
40    type Target = [u8];
41    fn deref(&self) -> &Self::Target {
42        &self.0
43    }
44}
45
46impl<N: ArrayLength<u8>> ByteArray<N> {
47    /// An unwrapping version of `try_from_slice`.
48    pub fn from_slice(bytes: impl AsRef<[u8]>) -> Self {
49        Self::try_from_slice(bytes).unwrap()
50    }
51
52    pub fn try_from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, String> {
53        let bytes = bytes.as_ref();
54        if bytes.len() != N::USIZE {
55            Err(format!(
56                "expected {} bytes but got {}",
57                N::USIZE,
58                bytes.len()
59            ))
60        } else {
61            Ok(Self(bytes.iter().copied().collect()))
62        }
63    }
64}
65
66impl<'de, N: ArrayLength<u8>> Deserialize<'de> for ByteArray<N> {
67    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
68        let bytes = Zeroizing::new(crate::utils::serde_base64::deserialize(d)?);
69        Self::try_from_slice(&*bytes).map_err(|_| {
70            de::Error::invalid_length(
71                bytes.len(),
72                &format!("{} base64-encoded bytes", N::USIZE).as_str(),
73            )
74        })
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    use generic_array::typenum::*;

    // Seven bytes and their unpadded base64 JSON string form.
    static BYTES: &[u8] = &[1, 2, 3, 4, 5, 6, 7];
    static BASE64_JSON: &str = "\"AQIDBAUGBw\"";

    /// Fresh JSON deserializer over `BASE64_JSON`.
    fn get_de() -> serde_json::Deserializer<serde_json::de::StrRead<'static>> {
        serde_json::Deserializer::from_str(BASE64_JSON)
    }

    #[test]
    fn test_serde_byte_array_good() {
        let arr = ByteArray::<U7>::try_from_slice(BYTES).unwrap();
        let b64 = serde_json::to_string(&arr).unwrap();
        assert_eq!(b64, BASE64_JSON);
        let bytes: ByteArray<U7> = serde_json::from_str(&b64).unwrap();
        assert_eq!(bytes.as_ref(), BYTES);
    }

    #[test]
    fn test_serde_deserialize_byte_array_invalid() {
        let mut de = serde_json::Deserializer::from_str("\"Z\"");
        ByteArray::<U0>::deserialize(&mut de).unwrap_err();
    }

    #[test]
    fn test_serde_base64_deserialize_array_long() {
        ByteArray::<U6>::deserialize(&mut get_de()).unwrap_err();
    }

    #[test]
    fn test_serde_base64_deserialize_array_short() {
        ByteArray::<U8>::deserialize(&mut get_de()).unwrap_err();
    }

    #[test]
    fn test_try_from_slice_wrong_length_reports_lengths() {
        let err = ByteArray::<U6>::try_from_slice(BYTES).unwrap_err();
        assert_eq!(err, "expected 6 bytes but got 7");
    }

    #[test]
    fn test_from_slice_exact_length() {
        let arr = ByteArray::<U7>::from_slice(BYTES);
        assert_eq!(&*arr, BYTES);
    }

    #[test]
    #[should_panic]
    fn test_from_slice_wrong_length_panics() {
        let _ = ByteArray::<U8>::from_slice(BYTES);
    }
}