1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use crate::algorithm::{Algorithm, HashAlgorithm, HkdfAlgorithm};
use crate::errors::Error;
use ring::hkdf::{Prk, Salt};

/// Expands to the empty byte-string literal `b""` (type `&'static [u8; 0]`).
///
/// Serves as a readable placeholder wherever a byte input, such as a salt or
/// an `info` string, is intentionally left empty.
#[macro_export]
macro_rules! none {
    () => (b"");
}

/// HKDF helper that pairs a `ring` [`Salt`] with the [`HkdfAlgorithm`] it was
/// created for.
pub struct HkdfWrapper {
    /// Algorithm selection; consulted for the output length and when expanding.
    algo: HkdfAlgorithm,
    /// Salt used for the HKDF extract step.
    salt: Salt,
}

impl HkdfWrapper {
    pub fn new(salt: &[u8], algo: HkdfAlgorithm) -> Self {
        Self {
            salt: Salt::new(algo.to_hkdf(), salt),
            algo,
        }
    }

    pub fn extract(&self, ikm: &[u8]) -> Prk {
        self.salt.extract(ikm)
    }

    pub fn get_okm_len(&self) -> usize {
        self.algo.output_length()
    }

    pub fn expand(&self, ikm: &[u8], info: &[u8]) -> Result<Vec<u8>, Error> {
        let algo = &self.algo;
        let prk = self.extract(ikm);

        let mut okm = vec![0u8; self.get_okm_len()];
        let okm_slice = &mut okm[..];
        prk.expand(&[info], Algorithm::from_hkdf(algo).to_hmac())
            .map_err(|_| Error::HkdfExpandError)?
            .fill(okm_slice)
            .map_err(|_| Error::HkdfFillError)?;
        Ok(okm)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algorithm::HkdfAlgorithm;
    use ring::rand::SecureRandom;

    /// Renders `bytes` as a lowercase hex string.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02x}", b)).collect()
    }

    /// Returns `len` bytes from the system CSPRNG.
    fn get_random_bytes(len: usize) -> Vec<u8> {
        let rng = ring::rand::SystemRandom::new();
        let mut bytes = vec![0u8; len];
        rng.fill(&mut bytes)
            .expect("Failed to generate random bytes");
        bytes
    }

    /// Derives key material and checks the properties every derivation must
    /// satisfy: the output has the advertised length, and repeating the call
    /// with identical inputs yields identical bytes.
    fn expand_checked(label: &str, hkdf: &HkdfWrapper, ikm: &[u8], info: &[u8]) -> Vec<u8> {
        let okm = hkdf.expand(ikm, info).expect("HKDF derivation failed");
        println!("{} okm: {}", label, bytes_to_hex(&okm));

        assert_eq!(okm.len(), hkdf.get_okm_len());

        let repeated = hkdf.expand(ikm, info).expect("HKDF derivation failed");
        assert_eq!(okm, repeated, "HKDF must be deterministic");

        okm
    }

    #[test]
    fn test_empty_key_hkdf_expand() {
        let hkdf = HkdfWrapper::new(none!(), HkdfAlgorithm::SHA1);
        expand_checked("sha1", &hkdf, b"", none!());
    }

    #[test]
    fn test_hkdf_expand_with_salt() {
        let salt = get_random_bytes(32);
        let hkdf = HkdfWrapper::new(&salt, HkdfAlgorithm::SHA256);
        expand_checked("sha256", &hkdf, b"", none!());
    }

    #[test]
    fn test_hkdf_expand_with_ikm() {
        let hkdf = HkdfWrapper::new(none!(), HkdfAlgorithm::SHA384);
        expand_checked("sha384", &hkdf, b"kjhjason", none!());
    }

    #[test]
    fn test_hkdf_expand_with_info() {
        let hkdf = HkdfWrapper::new(none!(), HkdfAlgorithm::SHA512);
        let with_info = expand_checked("sha512", &hkdf, b"", b"kjhjason");

        // `info` is meant to separate derivation contexts, so changing it
        // must change the derived bytes.
        let without_info = expand_checked("sha512 (no info)", &hkdf, b"", none!());
        assert_ne!(with_info, without_info);
    }

    #[test]
    fn test_hkdf_expand_with_all() {
        let hkdf = HkdfWrapper::new(b"kjhjason.com", HkdfAlgorithm::SHA256);
        expand_checked("sha256", &hkdf, b"jason", b"kjhjason");
    }

    /// Known-answer test from RFC 5869, Appendix A, Test Case 1 (SHA-256).
    ///
    /// The RFC asks for 42 bytes of output; this test compares against the
    /// first 32 bytes, which are the first HKDF output block and do not
    /// depend on the requested length.
    #[test]
    fn test_hkdf_expand_rfc5869_case_1() {
        let ikm = [0x0bu8; 22];
        let salt: [u8; 13] = [
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c,
        ];
        let info: [u8; 10] = [0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9];

        let hkdf = HkdfWrapper::new(&salt, HkdfAlgorithm::SHA256);
        let okm = expand_checked("sha256 (rfc5869)", &hkdf, &ikm, &info);

        assert_eq!(
            bytes_to_hex(&okm),
            "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf"
        );
    }
}