mls_rs_crypto_rustcrypto/
kdf.rs1use core::fmt::Debug;
6
7use hkdf::SimpleHkdf;
8use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
9use mls_rs_crypto_traits::{KdfId, KdfType};
10use sha2::{Sha256, Sha384, Sha512};
11
12use alloc::vec;
13use alloc::vec::Vec;
14
/// Errors reported by the HKDF-based [`Kdf`] implementation in this module.
///
/// With the `std` feature enabled the type also derives `thiserror::Error`,
/// using the messages attached to each variant.
#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum KdfError {
    /// Converted from `hkdf::InvalidPrkLength`.
    #[cfg_attr(feature = "std", error("invalid prk length"))]
    InvalidPrkLength,
    /// Converted from `hkdf::InvalidLength`.
    #[cfg_attr(feature = "std", error("invalid length"))]
    InvalidLength,
    /// The supplied key is too short.
    ///
    /// The first field is the length that was provided, the second is the
    /// minimum length that was required.
    #[cfg_attr(
        feature = "std",
        error("the provided length of the key {0} is shorter than the minimum length {1}")
    )]
    TooShortKey(usize, usize),
    /// The selected KDF identifier is not one this module implements.
    #[cfg_attr(feature = "std", error("unsupported cipher suite"))]
    UnsupportedCipherSuite,
}
30
31impl From<hkdf::InvalidPrkLength> for KdfError {
32 fn from(_value: hkdf::InvalidPrkLength) -> Self {
33 KdfError::InvalidPrkLength
34 }
35}
36
37impl From<hkdf::InvalidLength> for KdfError {
38 fn from(_value: hkdf::InvalidLength) -> Self {
39 KdfError::InvalidLength
40 }
41}
42
43impl IntoAnyError for KdfError {
44 #[cfg(feature = "std")]
45 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
46 Ok(self.into())
47 }
48}
49
50#[derive(Clone, Copy, Debug, Eq, PartialEq)]
51pub struct Kdf(KdfId);
52
53impl Kdf {
54 pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
55 KdfId::new(cipher_suite).map(Self)
56 }
57}
58
59#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
60#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
61#[cfg_attr(
62 all(not(target_arch = "wasm32"), mls_build_async),
63 maybe_async::must_be_async
64)]
65impl KdfType for Kdf {
66 type Error = KdfError;
67
68 async fn expand(&self, prk: &[u8], info: &[u8], len: usize) -> Result<Vec<u8>, KdfError> {
69 if prk.len() < self.extract_size() {
70 return Err(KdfError::TooShortKey(prk.len(), self.extract_size()));
71 }
72
73 let mut buf = vec![0u8; len];
74
75 match self.0 {
76 KdfId::HkdfSha256 => Ok(SimpleHkdf::<Sha256>::from_prk(prk)?.expand(info, &mut buf)?),
77 KdfId::HkdfSha384 => Ok(SimpleHkdf::<Sha384>::from_prk(prk)?.expand(info, &mut buf)?),
78 KdfId::HkdfSha512 => Ok(SimpleHkdf::<Sha512>::from_prk(prk)?.expand(info, &mut buf)?),
79 _ => Err(KdfError::UnsupportedCipherSuite),
80 }?;
81
82 Ok(buf)
83 }
84
85 async fn extract(&self, salt: &[u8], ikm: &[u8]) -> Result<Vec<u8>, KdfError> {
86 if ikm.is_empty() {
87 return Err(KdfError::TooShortKey(0, 1));
88 }
89
90 let salt = if salt.is_empty() { None } else { Some(salt) };
91
92 match self.0 {
93 KdfId::HkdfSha256 => Ok(SimpleHkdf::<Sha256>::extract(salt, ikm).0.to_vec()),
94 KdfId::HkdfSha384 => Ok(SimpleHkdf::<Sha384>::extract(salt, ikm).0.to_vec()),
95 KdfId::HkdfSha512 => Ok(SimpleHkdf::<Sha512>::extract(salt, ikm).0.to_vec()),
96 _ => Err(KdfError::UnsupportedCipherSuite),
97 }
98 }
99
100 fn extract_size(&self) -> usize {
101 self.0.extract_size()
102 }
103
104 fn kdf_id(&self) -> u16 {
105 self.0 as u16
106 }
107}
108
#[cfg(all(test, not(mls_build_async)))]
mod test {
    use assert_matches::assert_matches;
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::KdfType;

    use crate::kdf::{Kdf, KdfError};

    use alloc::vec;

    /// Builds the KDF used by every test in this module.
    fn test_kdf() -> Kdf {
        Kdf::new(CipherSuite::CURVE25519_AES128).unwrap()
    }

    /// Extract must reject empty input keying material.
    #[test]
    fn no_key() {
        let outcome = test_kdf().extract(b"key", &[]);
        assert!(outcome.is_err());
    }

    /// Extract must accept an empty salt.
    #[test]
    fn no_salt() {
        let outcome = test_kdf().extract(&[], b"key");
        assert!(outcome.is_ok());
    }

    /// Expand must accept an empty info string.
    #[test]
    fn no_info() {
        let kdf = test_kdf();
        let prk = vec![0u8; kdf.extract_size()];

        let outcome = kdf.expand(&prk, &[], 42);
        assert!(outcome.is_ok());
    }

    /// Expand must reject a PRK one byte shorter than the extract size.
    #[test]
    fn test_short_key() {
        let kdf = test_kdf();
        let short_prk = vec![0u8; kdf.extract_size() - 1];

        let outcome = kdf.expand(&short_prk, &[], 42);
        assert_matches!(outcome, Err(KdfError::TooShortKey(_, _)));
    }
}