1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#![no_std]
extern crate digest;
extern crate hmac;
#[cfg(feature = "std")] extern crate std;
use digest::{BlockInput, FixedOutput, Input, Reset};
use digest::generic_array::{self, ArrayLength, GenericArray};
use hmac::{Hmac, Mac};
use core::fmt;
/// Error returned by [`Hkdf::expand`] when the requested output is longer
/// than HKDF can produce (more than 255 blocks of the digest output size).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct InvalidLength;
/// HKDF state: the pseudorandom key (PRK) produced by [`Hkdf::extract`],
/// from which [`Hkdf::expand`] derives output keying material.
///
/// `D` is the hash function that HMAC is instantiated with.
#[derive(Clone)]
pub struct Hkdf<D: Input + BlockInput + FixedOutput + Reset + Default + Clone>
where
    D::OutputSize: ArrayLength<u8>,
{
    /// The pseudorandom key; its length is the digest output size of `D`.
    pub prk: GenericArray<u8, D::OutputSize>,
}
impl<D> Hkdf<D>
where
    D: Input + BlockInput + FixedOutput + Reset + Default + Clone,
    D::BlockSize: ArrayLength<u8> + Clone,
    D::OutputSize: ArrayLength<u8>,
{
    /// HKDF-Extract: computes the pseudorandom key `PRK = HMAC(salt, ikm)`.
    ///
    /// When `salt` is `Some`, it is used directly as the HMAC key (any length
    /// is accepted). When `salt` is `None`, HMAC is keyed with a default
    /// (all-zero) block-sized key. HMAC zero-pads short keys to the block
    /// size, so this matches the RFC 5869 default salt of `HashLen` zero bytes.
    pub fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Hkdf<D> {
        let mut hmac = match salt {
            Some(s) => Hmac::<D>::new_varkey(s).expect("HMAC can take a key of any size"),
            None => Hmac::<D>::new(&Default::default()),
        };
        hmac.input(ikm);
        Hkdf {
            prk: hmac.result().code(),
        }
    }

    /// HKDF-Expand: fills `okm` with output keying material derived from the
    /// stored PRK and the context string `info`.
    ///
    /// Output block `i` (1-based) is `T(i) = HMAC(PRK, T(i-1) || info || i)`,
    /// where `T(0)` is empty. The blocks are concatenated and truncated to
    /// `okm.len()`. An empty `okm` succeeds without computing anything.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidLength`] if `okm.len()` exceeds `255 * HashLen`, where
    /// `HashLen` is the digest output size of `D`.
    pub fn expand(&self, info: &[u8], okm: &mut [u8]) -> Result<(), InvalidLength> {
        use generic_array::typenum::Unsigned;

        let hash_len = D::OutputSize::to_usize();
        if okm.len() > hash_len * 255 {
            return Err(InvalidLength);
        }

        let mut hmac = Hmac::<D>::new_varkey(&self.prk)
            .expect("HMAC can take a key of any size");
        // T(i-1); `None` stands for the empty T(0).
        let mut prev: Option<GenericArray<u8, D::OutputSize>> = None;

        for (blocknum, okm_block) in okm.chunks_mut(hash_len).enumerate() {
            let block_len = okm_block.len();
            if let Some(ref prev) = prev {
                hmac.input(prev);
            }
            hmac.input(info);
            // The length check above limits the number of chunks to 255, so
            // `blocknum` is at most 254 and `blocknum + 1` always fits in a u8.
            hmac.input(&[blocknum as u8 + 1]);
            // `result_reset` yields this block's tag and resets `hmac` so the
            // same keyed instance can be reused for the next block.
            let output = hmac.result_reset().code();
            // The final chunk may be shorter than HashLen; keep only its prefix.
            okm_block.copy_from_slice(&output[..block_len]);
            prev = Some(output);
        }
        Ok(())
    }
}
impl fmt::Display for InvalidLength {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        const MESSAGE: &str = "invalid number of blocks, too large output";
        formatter.write_str(MESSAGE)
    }
}
/// Implements `std::error::Error` for `InvalidLength`, relying on its `Debug`
/// and `Display` impls, so it interoperates with std error handling.
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
impl ::std::error::Error for InvalidLength {}