askar_crypto/buffer/
array.rs1use core::{
2 array::TryFromSliceError,
3 fmt::{self, Debug, Formatter},
4 hash,
5 marker::{PhantomData, PhantomPinned},
6 ops::Deref,
7};
8
9use crate::generic_array::{ArrayLength, GenericArray};
10use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
11use subtle::{Choice, ConstantTimeEq};
12use zeroize::Zeroize;
13
14use super::HexRepr;
15use crate::{
16 error::Error,
17 kdf::{FromKeyDerivation, KeyDerivation},
18 random::KeyMaterial,
19};
20
21#[derive(Clone)]
23#[repr(transparent)]
24pub struct ArrayKey<L: ArrayLength<u8>>(
25 GenericArray<u8, L>,
26 PhantomPinned,
28);
29
30impl<L: ArrayLength<u8>> ArrayKey<L> {
31 pub const SIZE: usize = L::USIZE;
33
34 #[inline]
36 pub fn generate(mut rng: impl KeyMaterial) -> Self {
37 Self::new_with(|buf| rng.read_okm(buf))
38 }
39
40 pub fn new_with(f: impl FnOnce(&mut [u8])) -> Self {
42 let mut slf = Self::default();
43 f(slf.0.as_mut());
44 slf
45 }
46
47 pub fn try_new_with<E>(f: impl FnOnce(&mut [u8]) -> Result<(), E>) -> Result<Self, E> {
49 let mut slf = Self::default();
50 f(slf.0.as_mut())?;
51 Ok(slf)
52 }
53
54 pub fn temp<R>(f: impl FnOnce(&mut GenericArray<u8, L>) -> R) -> R {
56 let mut slf = Self::default();
57 f(&mut slf.0)
58 }
59
60 #[inline]
62 pub fn extract(self) -> GenericArray<u8, L> {
63 self.0.clone()
64 }
65
66 #[inline]
70 pub fn from_slice(data: &[u8]) -> Self {
71 Self::from(GenericArray::from_slice(data))
72 }
73
74 #[inline]
76 pub fn len() -> usize {
77 Self::SIZE
78 }
79
80 #[cfg(feature = "getrandom")]
82 #[inline]
83 pub fn random() -> Self {
84 Self::generate(crate::random::default_rng())
85 }
86
87 pub fn as_hex(&self) -> HexRepr<&[u8]> {
89 HexRepr(self.0.as_ref())
90 }
91}
92
93impl<L: ArrayLength<u8>> AsRef<GenericArray<u8, L>> for ArrayKey<L> {
94 #[inline(always)]
95 fn as_ref(&self) -> &GenericArray<u8, L> {
96 &self.0
97 }
98}
99
100impl<L: ArrayLength<u8>> Deref for ArrayKey<L> {
101 type Target = [u8];
102
103 #[inline(always)]
104 fn deref(&self) -> &[u8] {
105 self.0.as_ref()
106 }
107}
108
109impl<L: ArrayLength<u8>> Default for ArrayKey<L> {
110 #[inline(always)]
111 fn default() -> Self {
112 Self(GenericArray::default(), PhantomPinned)
113 }
114}
115
116impl<L: ArrayLength<u8>> From<&GenericArray<u8, L>> for ArrayKey<L> {
117 #[inline(always)]
118 fn from(key: &GenericArray<u8, L>) -> Self {
119 Self(key.clone(), PhantomPinned)
120 }
121}
122
123impl<L: ArrayLength<u8>> From<GenericArray<u8, L>> for ArrayKey<L> {
124 #[inline(always)]
125 fn from(key: GenericArray<u8, L>) -> Self {
126 Self(key, PhantomPinned)
127 }
128}
129
130impl<'a, L: ArrayLength<u8>, const N: usize> TryFrom<&'a ArrayKey<L>> for &'a [u8; N] {
131 type Error = TryFromSliceError;
132
133 #[inline(always)]
134 fn try_from(key: &ArrayKey<L>) -> Result<&[u8; N], TryFromSliceError> {
135 key.0.as_slice().try_into()
136 }
137}
138
139impl<L: ArrayLength<u8>> Debug for ArrayKey<L> {
140 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
141 if cfg!(test) {
142 f.debug_tuple("ArrayKey").field(&self.0).finish()
143 } else {
144 f.debug_tuple("ArrayKey").field(&"<secret>").finish()
145 }
146 }
147}
148
149impl<L: ArrayLength<u8>> ConstantTimeEq for ArrayKey<L> {
150 fn ct_eq(&self, other: &Self) -> Choice {
151 ConstantTimeEq::ct_eq(self.0.as_ref(), other.0.as_ref())
152 }
153}
154
155impl<L: ArrayLength<u8>> PartialEq for ArrayKey<L> {
156 #[inline]
157 fn eq(&self, other: &Self) -> bool {
158 self.ct_eq(other).into()
159 }
160}
161impl<L: ArrayLength<u8>> Eq for ArrayKey<L> {}
162
163impl<L: ArrayLength<u8>> hash::Hash for ArrayKey<L> {
164 fn hash<H: hash::Hasher>(&self, state: &mut H) {
165 self.0.hash(state);
166 }
167}
168
169impl<L: ArrayLength<u8>> Serialize for ArrayKey<L> {
170 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
171 where
172 S: Serializer,
173 {
174 serializer.serialize_bytes(self.as_ref())
175 }
176}
177
178impl<'de, L: ArrayLength<u8>> Deserialize<'de> for ArrayKey<L> {
179 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
180 where
181 D: Deserializer<'de>,
182 {
183 deserializer.deserialize_bytes(KeyVisitor { _pd: PhantomData })
184 }
185}
186
187impl<L: ArrayLength<u8>> Zeroize for ArrayKey<L> {
188 fn zeroize(&mut self) {
189 self.0.zeroize();
190 }
191}
192
193impl<L: ArrayLength<u8>> Drop for ArrayKey<L> {
194 fn drop(&mut self) {
195 self.zeroize();
196 }
197}
198
199struct KeyVisitor<L: ArrayLength<u8>> {
200 _pd: PhantomData<L>,
201}
202
203impl<L: ArrayLength<u8>> de::Visitor<'_> for KeyVisitor<L> {
204 type Value = ArrayKey<L>;
205
206 fn expecting(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
207 formatter.write_str("ArrayKey")
208 }
209
210 fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
211 where
212 E: de::Error,
213 {
214 if value.len() != L::USIZE {
215 return Err(E::invalid_length(value.len(), &self));
216 }
217 Ok(ArrayKey::from_slice(value))
218 }
219}
220
221impl<L: ArrayLength<u8>> FromKeyDerivation for ArrayKey<L> {
222 fn from_key_derivation<D: KeyDerivation>(mut derive: D) -> Result<Self, Error>
223 where
224 Self: Sized,
225 {
226 Self::try_new_with(|buf| derive.derive_key_bytes(buf))
227 }
228}