pykrete_jsonwebkey/
byte_array.rs1use generic_array::{ArrayLength, GenericArray};
2use serde::{
3 de::{self, Deserializer},
4 Deserialize, Serialize,
5};
6use zeroize::{Zeroize, Zeroizing};
7
8#[derive(Clone, PartialEq, Eq, Serialize)]
10#[serde(transparent)]
11pub struct ByteArray<N: ArrayLength<u8>>(
12 #[serde(serialize_with = "crate::utils::serde_base64::serialize")] GenericArray<u8, N>,
13);
14
15impl<N: ArrayLength<u8>> std::fmt::Debug for ByteArray<N> {
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 f.write_str(&crate::utils::base64_encode(&self.0))
18 }
19}
20
21impl<N: ArrayLength<u8>, T: Into<GenericArray<u8, N>>> From<T> for ByteArray<N> {
22 fn from(arr: T) -> Self {
23 Self(arr.into())
24 }
25}
26
27impl<N: ArrayLength<u8>> Drop for ByteArray<N> {
28 fn drop(&mut self) {
29 Zeroize::zeroize(self.0.as_mut_slice())
30 }
31}
32
33impl<N: ArrayLength<u8>> AsRef<[u8]> for ByteArray<N> {
34 fn as_ref(&self) -> &[u8] {
35 &self.0
36 }
37}
38
39impl<N: ArrayLength<u8>> std::ops::Deref for ByteArray<N> {
40 type Target = [u8];
41 fn deref(&self) -> &Self::Target {
42 &self.0
43 }
44}
45
46impl<N: ArrayLength<u8>> ByteArray<N> {
47 pub fn from_slice(bytes: impl AsRef<[u8]>) -> Self {
49 Self::try_from_slice(bytes).unwrap()
50 }
51
52 pub fn try_from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, String> {
53 let bytes = bytes.as_ref();
54 if bytes.len() != N::USIZE {
55 Err(format!(
56 "expected {} bytes but got {}",
57 N::USIZE,
58 bytes.len()
59 ))
60 } else {
61 Ok(Self(bytes.iter().copied().collect()))
62 }
63 }
64}
65
66impl<'de, N: ArrayLength<u8>> Deserialize<'de> for ByteArray<N> {
67 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
68 let bytes = Zeroizing::new(crate::utils::serde_base64::deserialize(d)?);
69 Self::try_from_slice(&*bytes).map_err(|_| {
70 de::Error::invalid_length(
71 bytes.len(),
72 &format!("{} base64-encoded bytes", N::USIZE).as_str(),
73 )
74 })
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    use generic_array::typenum::*;

    static BYTES: &[u8] = &[1, 2, 3, 4, 5, 6, 7];
    static BASE64_JSON: &str = "\"AQIDBAUGBw\"";

    /// Returns a JSON deserializer over `BASE64_JSON`, which encodes the 7 bytes in `BYTES`.
    fn get_de() -> serde_json::Deserializer<serde_json::de::StrRead<'static>> {
        serde_json::Deserializer::from_str(BASE64_JSON)
    }

    #[test]
    fn test_serde_byte_array_good() {
        let arr = ByteArray::<U7>::try_from_slice(BYTES).unwrap();
        let b64 = serde_json::to_string(&arr).unwrap();
        assert_eq!(b64, BASE64_JSON);
        let bytes: ByteArray<U7> = serde_json::from_str(&b64).unwrap();
        assert_eq!(bytes.as_ref(), BYTES);
    }

    #[test]
    fn test_serde_deserialize_byte_array_invalid() {
        let mut de = serde_json::Deserializer::from_str("\"Z\"");
        ByteArray::<U0>::deserialize(&mut de).unwrap_err();
    }

    #[test]
    fn test_serde_base64_deserialize_array_long() {
        ByteArray::<U6>::deserialize(&mut get_de()).unwrap_err();
    }

    #[test]
    fn test_serde_base64_deserialize_array_short() {
        ByteArray::<U8>::deserialize(&mut get_de()).unwrap_err();
    }

    // The non-serde constructors: pin the error text and the panic.
    #[test]
    fn test_try_from_slice_wrong_length() {
        let err = ByteArray::<U8>::try_from_slice(BYTES).unwrap_err();
        assert_eq!(err, "expected 8 bytes but got 7");
    }

    #[test]
    #[should_panic(expected = "expected 6 bytes but got 7")]
    fn test_from_slice_wrong_length_panics() {
        ByteArray::<U6>::from_slice(BYTES);
    }

    #[test]
    fn test_from_slice_and_deref() {
        let arr = ByteArray::<U7>::from_slice(BYTES);
        assert_eq!(&*arr, BYTES);
        assert_eq!(arr.len(), 7);
    }
}