1#![forbid(unsafe_code)]
4
5use core::{fmt, iter::IntoIterator, mem, result::Result};
6
7use buggy::Bug;
8use generic_array::{ArrayLength, ConstArrayLength, GenericArray, IntoArrayLength};
9use subtle::{Choice, ConstantTimeEq};
10use typenum::{
11 type_operators::{IsGreaterOrEqual, IsLess},
12 Const, Unsigned, U32, U64, U65536,
13};
14
15use crate::{
16 keys::{RawSecretBytes, SecretKeyBytes},
17 zeroize::{is_zeroize_on_drop, ZeroizeOnDrop},
18};
19
20#[derive(Debug, Eq, PartialEq)]
22pub enum KdfError {
23 OutputTooLong,
26 Bug(Bug),
28}
29
30impl fmt::Display for KdfError {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 match self {
33 Self::OutputTooLong => write!(f, "requested KDF output too long"),
34 Self::Bug(err) => write!(f, "{err}"),
35 }
36 }
37}
38
39impl core::error::Error for KdfError {
40 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
41 match self {
42 Self::Bug(err) => Some(err),
43 _ => None,
44 }
45 }
46}
47
48impl From<Bug> for KdfError {
49 fn from(err: Bug) -> Self {
50 Self::Bug(err)
51 }
52}
53
54pub trait Kdf {
71 type MaxOutput: ArrayLength + IsGreaterOrEqual<U64> + 'static;
78 const MAX_OUTPUT: usize = Self::MaxOutput::USIZE;
85
86 type PrkSize: ArrayLength + IsGreaterOrEqual<U32> + IsLess<U65536> + 'static;
91 const PRK_SIZE: usize = Self::PrkSize::USIZE;
93
94 #[inline]
100 fn extract(ikm: &[u8], salt: &[u8]) -> Prk<Self::PrkSize> {
101 Self::extract_multi([ikm], salt)
102 }
103
104 fn extract_multi<'a, I>(ikm: I, salt: &[u8]) -> Prk<Self::PrkSize>
112 where
113 I: IntoIterator<Item = &'a [u8]>;
114
115 #[inline]
123 fn expand(out: &mut [u8], prk: &Prk<Self::PrkSize>, info: &[u8]) -> Result<(), KdfError> {
124 Self::expand_multi(out, prk, [info])
125 }
126
127 fn expand_multi<'a, I>(
138 out: &mut [u8],
139 prk: &Prk<Self::PrkSize>,
140 info: I,
141 ) -> Result<(), KdfError>
142 where
143 I: IntoIterator<Item = &'a [u8], IntoIter: Clone>;
144
145 #[inline]
157 fn extract_and_expand(
158 out: &mut [u8],
159 ikm: &[u8],
160 salt: &[u8],
161 info: &[u8],
162 ) -> Result<(), KdfError> {
163 Self::extract_and_expand_multi(out, [ikm], salt, [info])
164 }
165
166 fn extract_and_expand_multi<'a, Ikm, Info>(
178 out: &mut [u8],
179 ikm: Ikm,
180 salt: &[u8],
181 info: Info,
182 ) -> Result<(), KdfError>
183 where
184 Ikm: IntoIterator<Item = &'a [u8]>,
185 Info: IntoIterator<Item = &'a [u8], IntoIter: Clone>,
186 {
187 if out.len() > Self::MAX_OUTPUT {
188 Err(KdfError::OutputTooLong)
189 } else {
190 let prk = Self::extract_multi(ikm, salt);
191 Self::expand_multi(out, &prk, info)
192 }
193 }
194}
195
196#[derive(Clone, Default)]
198#[repr(transparent)]
199pub struct Prk<N: ArrayLength>(SecretKeyBytes<N>);
200
201impl<N: ArrayLength> Prk<N> {
202 #[inline]
204 pub const fn new(prk: SecretKeyBytes<N>) -> Self {
205 Self(prk)
206 }
207
208 #[inline]
210 #[allow(clippy::len_without_is_empty)]
211 pub const fn len(&self) -> usize {
212 self.0.len()
213 }
214
215 #[inline]
217 pub const fn as_bytes(&self) -> &[u8] {
218 self.0.as_bytes()
219 }
220
221 pub(crate) fn as_bytes_mut(&mut self) -> &mut [u8] {
223 self.0.as_bytes_mut()
224 }
225
226 #[inline]
228 pub fn into_bytes(mut self) -> SecretKeyBytes<N> {
229 mem::take(&mut self.0)
234 }
235}
236
237impl<N: ArrayLength> ConstantTimeEq for Prk<N> {
238 #[inline]
239 fn ct_eq(&self, other: &Self) -> Choice {
240 self.0.ct_eq(&other.0)
241 }
242}
243
244impl<N: ArrayLength> RawSecretBytes for Prk<N> {
245 #[inline]
246 fn raw_secret_bytes(&self) -> &[u8] {
247 self.as_bytes()
248 }
249}
250
251impl<N: ArrayLength> Expand for Prk<N>
252where
253 N: IsLess<U65536>,
254{
255 type Size = N;
256
257 fn expand_multi<'a, K, I>(prk: &Prk<K::PrkSize>, info: I) -> Result<Self, KdfError>
258 where
259 K: Kdf,
260 I: IntoIterator<Item = &'a [u8]>,
261 I::IntoIter: Clone,
262 {
263 Ok(Self(Expand::expand_multi::<K, I>(prk, info)?))
264 }
265}
266
267impl<N: ArrayLength> fmt::Debug for Prk<N> {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 f.debug_tuple("Prk").finish_non_exhaustive()
270 }
271}
272
273impl<N: ArrayLength> ZeroizeOnDrop for Prk<N> {}
274impl<N: ArrayLength> Drop for Prk<N> {
275 fn drop(&mut self) {
276 is_zeroize_on_drop(&self.0);
277 }
278}
279
280pub trait Expand: Sized {
282 type Size: ArrayLength + IsLess<U65536> + 'static;
284
285 fn expand<K: Kdf>(prk: &Prk<K::PrkSize>, info: &[u8]) -> Result<Self, KdfError> {
287 Self::expand_multi::<K, _>(prk, [info])
288 }
289
290 fn expand_multi<'a, K, I>(prk: &Prk<K::PrkSize>, info: I) -> Result<Self, KdfError>
292 where
293 K: Kdf,
294 I: IntoIterator<Item = &'a [u8], IntoIter: Clone>;
295}
296
297impl<N: ArrayLength> Expand for GenericArray<u8, N>
298where
299 N: IsLess<U65536>,
300{
301 type Size = N;
302
303 fn expand_multi<'a, K, I>(prk: &Prk<K::PrkSize>, info: I) -> Result<Self, KdfError>
304 where
305 K: Kdf,
306 I: IntoIterator<Item = &'a [u8]>,
307 I::IntoIter: Clone,
308 {
309 let mut out = GenericArray::default();
310 K::expand_multi(&mut out, prk, info)?;
311 Ok(out)
312 }
313}
314
315impl<const N: usize> Expand for [u8; N]
316where
317 Const<N>: IntoArrayLength,
318 ConstArrayLength<N>: IsLess<U65536>,
319{
320 type Size = ConstArrayLength<N>;
321
322 fn expand_multi<'a, K, I>(prk: &Prk<K::PrkSize>, info: I) -> Result<Self, KdfError>
323 where
324 K: Kdf,
325 I: IntoIterator<Item = &'a [u8]>,
326 I::IntoIter: Clone,
327 {
328 let mut out = [0u8; N];
329 K::expand_multi(&mut out, prk, info)?;
330 Ok(out)
331 }
332}