Skip to main content

rns_crypto/
hkdf.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::hmac::hmac_sha256;
5
/// Errors returned by [`hkdf`] when its arguments are rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HkdfError {
    /// The requested output length was zero.
    InvalidLength,
    /// The input key material (`derive_from`) was empty.
    EmptyInput,
}
11
12impl fmt::Display for HkdfError {
13    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
14        match self {
15            HkdfError::InvalidLength => write!(f, "Invalid output key length"),
16            HkdfError::EmptyInput => write!(f, "Cannot derive key from empty input material"),
17        }
18    }
19}
20
21/// Custom HKDF implementation matching RNS/Cryptography/HKDF.py.
22/// WARNING: This is NOT RFC 5869. The counter wraps modulo 256.
23pub fn hkdf(
24    length: usize,
25    derive_from: &[u8],
26    salt: Option<&[u8]>,
27    context: Option<&[u8]>,
28) -> Result<Vec<u8>, HkdfError> {
29    let hash_len: usize = 32;
30
31    if length < 1 {
32        return Err(HkdfError::InvalidLength);
33    }
34
35    if derive_from.is_empty() {
36        return Err(HkdfError::EmptyInput);
37    }
38
39    let salt = match salt {
40        Some(s) if !s.is_empty() => s.to_vec(),
41        _ => alloc::vec![0u8; hash_len],
42    };
43
44    let context = context.unwrap_or(b"");
45
46    // Extract
47    let prk = hmac_sha256(&salt, derive_from);
48
49    // Expand
50    let mut block: Vec<u8> = Vec::new();
51    let mut derived = Vec::with_capacity(length);
52
53    let iterations = (length + hash_len - 1) / hash_len;
54    for i in 0..iterations {
55        let mut input = Vec::new();
56        input.extend_from_slice(&block);
57        input.extend_from_slice(context);
58        input.push(((i + 1) % 256) as u8);
59
60        block = hmac_sha256(&prk, &input).to_vec();
61        derived.extend_from_slice(&block);
62    }
63
64    derived.truncate(length);
65    Ok(derived)
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a hex string into bytes (test helper; panics on malformed input).
    fn unhex(hex: &str) -> Vec<u8> {
        hex.as_bytes()
            .chunks(2)
            .map(|pair| {
                let digits = core::str::from_utf8(pair).expect("hex input is ASCII");
                u8::from_str_radix(digits, 16).expect("valid hex digit pair")
            })
            .collect()
    }

    #[test]
    fn test_hkdf_32bytes() {
        let ikm = b"input key material";
        let salt = b"salt value";
        let result = hkdf(32, ikm, Some(salt), None).unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_hkdf_64bytes() {
        let ikm = b"input key material";
        let salt = b"salt value";
        let result = hkdf(64, ikm, Some(salt), None).unwrap();
        assert_eq!(result.len(), 64);
    }

    #[test]
    fn test_hkdf_with_context() {
        let ikm = b"input key material";
        let salt = b"salt";
        let ctx = b"context info";
        let result = hkdf(32, ikm, Some(salt), Some(ctx)).unwrap();
        assert_eq!(result.len(), 32);
        // With context should differ from without
        let result2 = hkdf(32, ikm, Some(salt), None).unwrap();
        assert_ne!(result, result2);
    }

    #[test]
    fn test_hkdf_none_salt() {
        let ikm = b"input key material";
        let result = hkdf(32, ikm, None, None).unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_hkdf_empty_salt() {
        let ikm = b"input key material";
        let result1 = hkdf(32, ikm, Some(b""), None).unwrap();
        let result2 = hkdf(32, ikm, None, None).unwrap();
        // Empty salt and None salt should produce same result
        assert_eq!(result1, result2);
    }

    #[test]
    fn test_hkdf_invalid_length() {
        assert_eq!(hkdf(0, b"ikm", None, None), Err(HkdfError::InvalidLength));
    }

    #[test]
    fn test_hkdf_empty_ikm() {
        assert_eq!(hkdf(32, b"", None, None), Err(HkdfError::EmptyInput));
    }

    /// RFC 5869 Appendix A.1 (SHA-256 basic case). For outputs of at most
    /// 255 * 32 bytes the construction should coincide with RFC 5869
    /// HKDF-SHA256. The published vector therefore pins the actual bytes,
    /// not just the length.
    #[test]
    fn test_hkdf_rfc5869_case1() {
        let ikm = [0x0bu8; 22];
        let salt: Vec<u8> = (0x00u8..=0x0c).collect();
        let info: Vec<u8> = (0xf0u8..=0xf9).collect();
        let okm = hkdf(42, &ikm, Some(&salt[..]), Some(&info[..])).unwrap();
        assert_eq!(
            okm,
            unhex(
                "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"
            )
        );
    }

    /// RFC 5869 Appendix A.3 (SHA-256, empty salt and info). Here an empty or
    /// missing salt becomes 32 zero bytes. HMAC zero-pads short keys, so that
    /// key is equivalent to the RFC's empty salt.
    #[test]
    fn test_hkdf_rfc5869_case3() {
        let ikm = [0x0bu8; 22];
        let expected = unhex(
            "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8",
        );
        assert_eq!(hkdf(42, &ikm, Some(&b""[..]), Some(&b""[..])).unwrap(), expected);
        assert_eq!(hkdf(42, &ikm, None, None).unwrap(), expected);
    }

    /// Outputs longer than 255 blocks are produced, because the counter wraps
    /// instead of erroring. A shorter output is always a prefix of a longer one.
    #[test]
    fn test_hkdf_output_past_counter_wrap() {
        let ikm = b"input key material";
        let ctx: &[u8] = b"context info";
        let long = hkdf(256 * 32 + 10, ikm, None, Some(ctx)).unwrap();
        assert_eq!(long.len(), 256 * 32 + 10);
        let short = hkdf(255 * 32 + 5, ikm, None, Some(ctx)).unwrap();
        assert_eq!(&long[..short.len()], &short[..]);
    }
}