askar_crypto/buffer/
array.rs

1use core::{
2    array::TryFromSliceError,
3    fmt::{self, Debug, Formatter},
4    hash,
5    marker::{PhantomData, PhantomPinned},
6    ops::Deref,
7};
8
9use crate::generic_array::{ArrayLength, GenericArray};
10use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
11use subtle::{Choice, ConstantTimeEq};
12use zeroize::Zeroize;
13
14use super::HexRepr;
15use crate::{
16    error::Error,
17    kdf::{FromKeyDerivation, KeyDerivation},
18    random::KeyMaterial,
19};
20
21/// A secure representation for fixed-length keys
22#[derive(Clone)]
23#[repr(transparent)]
24pub struct ArrayKey<L: ArrayLength<u8>>(
25    GenericArray<u8, L>,
26    // ensure that the type does not implement Unpin
27    PhantomPinned,
28);
29
30impl<L: ArrayLength<u8>> ArrayKey<L> {
31    /// The array length in bytes
32    pub const SIZE: usize = L::USIZE;
33
34    /// Create a new buffer from a random data source
35    #[inline]
36    pub fn generate(mut rng: impl KeyMaterial) -> Self {
37        Self::new_with(|buf| rng.read_okm(buf))
38    }
39
40    /// Create a new buffer using an initializer for the data
41    pub fn new_with(f: impl FnOnce(&mut [u8])) -> Self {
42        let mut slf = Self::default();
43        f(slf.0.as_mut());
44        slf
45    }
46
47    /// Create a new buffer using a fallible initializer for the data
48    pub fn try_new_with<E>(f: impl FnOnce(&mut [u8]) -> Result<(), E>) -> Result<Self, E> {
49        let mut slf = Self::default();
50        f(slf.0.as_mut())?;
51        Ok(slf)
52    }
53
54    /// Temporarily allocate and use a key
55    pub fn temp<R>(f: impl FnOnce(&mut GenericArray<u8, L>) -> R) -> R {
56        let mut slf = Self::default();
57        f(&mut slf.0)
58    }
59
60    /// Convert this array to a non-zeroing GenericArray instance
61    #[inline]
62    pub fn extract(self) -> GenericArray<u8, L> {
63        self.0.clone()
64    }
65
66    /// Create a new array instance from a slice of bytes.
67    /// Like <&GenericArray>::from_slice, panics if the length of the slice
68    /// is incorrect.
69    #[inline]
70    pub fn from_slice(data: &[u8]) -> Self {
71        Self::from(GenericArray::from_slice(data))
72    }
73
74    /// Get the length of the array
75    #[inline]
76    pub fn len() -> usize {
77        Self::SIZE
78    }
79
80    /// Create a new array of random bytes
81    #[cfg(feature = "getrandom")]
82    #[inline]
83    pub fn random() -> Self {
84        Self::generate(crate::random::default_rng())
85    }
86
87    /// Get a hex formatter for the key data
88    pub fn as_hex(&self) -> HexRepr<&[u8]> {
89        HexRepr(self.0.as_ref())
90    }
91}
92
93impl<L: ArrayLength<u8>> AsRef<GenericArray<u8, L>> for ArrayKey<L> {
94    #[inline(always)]
95    fn as_ref(&self) -> &GenericArray<u8, L> {
96        &self.0
97    }
98}
99
100impl<L: ArrayLength<u8>> Deref for ArrayKey<L> {
101    type Target = [u8];
102
103    #[inline(always)]
104    fn deref(&self) -> &[u8] {
105        self.0.as_ref()
106    }
107}
108
109impl<L: ArrayLength<u8>> Default for ArrayKey<L> {
110    #[inline(always)]
111    fn default() -> Self {
112        Self(GenericArray::default(), PhantomPinned)
113    }
114}
115
116impl<L: ArrayLength<u8>> From<&GenericArray<u8, L>> for ArrayKey<L> {
117    #[inline(always)]
118    fn from(key: &GenericArray<u8, L>) -> Self {
119        Self(key.clone(), PhantomPinned)
120    }
121}
122
123impl<L: ArrayLength<u8>> From<GenericArray<u8, L>> for ArrayKey<L> {
124    #[inline(always)]
125    fn from(key: GenericArray<u8, L>) -> Self {
126        Self(key, PhantomPinned)
127    }
128}
129
130impl<'a, L: ArrayLength<u8>, const N: usize> TryFrom<&'a ArrayKey<L>> for &'a [u8; N] {
131    type Error = TryFromSliceError;
132
133    #[inline(always)]
134    fn try_from(key: &ArrayKey<L>) -> Result<&[u8; N], TryFromSliceError> {
135        key.0.as_slice().try_into()
136    }
137}
138
139impl<L: ArrayLength<u8>> Debug for ArrayKey<L> {
140    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
141        if cfg!(test) {
142            f.debug_tuple("ArrayKey").field(&self.0).finish()
143        } else {
144            f.debug_tuple("ArrayKey").field(&"<secret>").finish()
145        }
146    }
147}
148
149impl<L: ArrayLength<u8>> ConstantTimeEq for ArrayKey<L> {
150    fn ct_eq(&self, other: &Self) -> Choice {
151        ConstantTimeEq::ct_eq(self.0.as_ref(), other.0.as_ref())
152    }
153}
154
155impl<L: ArrayLength<u8>> PartialEq for ArrayKey<L> {
156    #[inline]
157    fn eq(&self, other: &Self) -> bool {
158        self.ct_eq(other).into()
159    }
160}
161impl<L: ArrayLength<u8>> Eq for ArrayKey<L> {}
162
163impl<L: ArrayLength<u8>> hash::Hash for ArrayKey<L> {
164    fn hash<H: hash::Hasher>(&self, state: &mut H) {
165        self.0.hash(state);
166    }
167}
168
169impl<L: ArrayLength<u8>> Serialize for ArrayKey<L> {
170    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
171    where
172        S: Serializer,
173    {
174        serializer.serialize_bytes(self.as_ref())
175    }
176}
177
178impl<'de, L: ArrayLength<u8>> Deserialize<'de> for ArrayKey<L> {
179    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
180    where
181        D: Deserializer<'de>,
182    {
183        deserializer.deserialize_bytes(KeyVisitor { _pd: PhantomData })
184    }
185}
186
187impl<L: ArrayLength<u8>> Zeroize for ArrayKey<L> {
188    fn zeroize(&mut self) {
189        self.0.zeroize();
190    }
191}
192
193impl<L: ArrayLength<u8>> Drop for ArrayKey<L> {
194    fn drop(&mut self) {
195        self.zeroize();
196    }
197}
198
199struct KeyVisitor<L: ArrayLength<u8>> {
200    _pd: PhantomData<L>,
201}
202
203impl<L: ArrayLength<u8>> de::Visitor<'_> for KeyVisitor<L> {
204    type Value = ArrayKey<L>;
205
206    fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
207        formatter.write_str("ArrayKey")
208    }
209
210    fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
211    where
212        E: de::Error,
213    {
214        if value.len() != L::USIZE {
215            return Err(E::invalid_length(value.len(), &self));
216        }
217        Ok(ArrayKey::from_slice(value))
218    }
219}
220
221impl<L: ArrayLength<u8>> FromKeyDerivation for ArrayKey<L> {
222    fn from_key_derivation<D: KeyDerivation>(mut derive: D) -> Result<Self, Error>
223    where
224        Self: Sized,
225    {
226        Self::try_new_with(|buf| derive.derive_key_bytes(buf))
227    }
228}