mls_rs_crypto_rustcrypto/
kdf.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::fmt::Debug;
6
7use hkdf::SimpleHkdf;
8use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
9use mls_rs_crypto_traits::{KdfId, KdfType};
10use sha2::{Sha256, Sha384, Sha512};
11
12use alloc::vec;
13use alloc::vec::Vec;
14
/// Errors reported by the RustCrypto-backed HKDF implementation in this module.
#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum KdfError {
    /// `hkdf` rejected the pseudorandom key (converted from `hkdf::InvalidPrkLength`).
    #[cfg_attr(feature = "std", error("invalid prk length"))]
    InvalidPrkLength,

    /// `hkdf` rejected the requested output length (converted from `hkdf::InvalidLength`).
    #[cfg_attr(feature = "std", error("invalid length"))]
    InvalidLength,

    /// Key material was too short; fields are `(provided_len, minimum_len)`.
    #[cfg_attr(
        feature = "std",
        error("the provided length of the key {0} is shorter than the minimum length {1}")
    )]
    TooShortKey(usize, usize),

    /// The KDF selected by the cipher suite is not one this module handles.
    #[cfg_attr(feature = "std", error("unsupported cipher suite"))]
    UnsupportedCipherSuite,
}
30
31impl From<hkdf::InvalidPrkLength> for KdfError {
32    fn from(_value: hkdf::InvalidPrkLength) -> Self {
33        KdfError::InvalidPrkLength
34    }
35}
36
37impl From<hkdf::InvalidLength> for KdfError {
38    fn from(_value: hkdf::InvalidLength) -> Self {
39        KdfError::InvalidLength
40    }
41}
42
43impl IntoAnyError for KdfError {
44    #[cfg(feature = "std")]
45    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
46        Ok(self.into())
47    }
48}
49
50#[derive(Clone, Copy, Debug, Eq, PartialEq)]
51pub struct Kdf(KdfId);
52
53impl Kdf {
54    pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
55        KdfId::new(cipher_suite).map(Self)
56    }
57}
58
59#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
60#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
61#[cfg_attr(
62    all(not(target_arch = "wasm32"), mls_build_async),
63    maybe_async::must_be_async
64)]
65impl KdfType for Kdf {
66    type Error = KdfError;
67
68    async fn expand(&self, prk: &[u8], info: &[u8], len: usize) -> Result<Vec<u8>, KdfError> {
69        if prk.len() < self.extract_size() {
70            return Err(KdfError::TooShortKey(prk.len(), self.extract_size()));
71        }
72
73        let mut buf = vec![0u8; len];
74
75        match self.0 {
76            KdfId::HkdfSha256 => Ok(SimpleHkdf::<Sha256>::from_prk(prk)?.expand(info, &mut buf)?),
77            KdfId::HkdfSha384 => Ok(SimpleHkdf::<Sha384>::from_prk(prk)?.expand(info, &mut buf)?),
78            KdfId::HkdfSha512 => Ok(SimpleHkdf::<Sha512>::from_prk(prk)?.expand(info, &mut buf)?),
79            _ => Err(KdfError::UnsupportedCipherSuite),
80        }?;
81
82        Ok(buf)
83    }
84
85    async fn extract(&self, salt: &[u8], ikm: &[u8]) -> Result<Vec<u8>, KdfError> {
86        if ikm.is_empty() {
87            return Err(KdfError::TooShortKey(0, 1));
88        }
89
90        let salt = if salt.is_empty() { None } else { Some(salt) };
91
92        match self.0 {
93            KdfId::HkdfSha256 => Ok(SimpleHkdf::<Sha256>::extract(salt, ikm).0.to_vec()),
94            KdfId::HkdfSha384 => Ok(SimpleHkdf::<Sha384>::extract(salt, ikm).0.to_vec()),
95            KdfId::HkdfSha512 => Ok(SimpleHkdf::<Sha512>::extract(salt, ikm).0.to_vec()),
96            _ => Err(KdfError::UnsupportedCipherSuite),
97        }
98    }
99
100    fn extract_size(&self) -> usize {
101        self.0.extract_size()
102    }
103
104    fn kdf_id(&self) -> u16 {
105        self.0 as u16
106    }
107}
108
#[cfg(all(test, not(mls_build_async)))]
mod test {
    use assert_matches::assert_matches;
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::KdfType;

    use crate::kdf::{Kdf, KdfError};

    use alloc::vec;

    /// Decodes an even-length ASCII hex string into bytes (test-only helper).
    fn from_hex(s: &str) -> vec::Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
            .collect()
    }

    #[test]
    fn no_key() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        assert!(kdf.extract(b"key", &[]).is_err());
    }

    #[test]
    fn no_salt() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        assert!(kdf.extract(&[], b"key").is_ok());
    }

    #[test]
    fn no_info() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        let key = vec![0u8; kdf.extract_size()];
        assert!(kdf.expand(&key, &[], 42).is_ok());
    }

    #[test]
    fn test_short_key() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        let key = vec![0u8; kdf.extract_size() - 1];

        assert_matches!(kdf.expand(&key, &[], 42), Err(KdfError::TooShortKey(_, _)));
    }

    // MLS cipher suite 0x0001 (CURVE25519_AES128) uses SHA-256 (RFC 9420), so the
    // RFC 5869 HKDF-SHA256 vectors below pin the exact bytes this module produces.

    /// RFC 5869 Appendix A.1: basic HKDF-SHA256 known-answer test.
    #[test]
    fn rfc5869_sha256_basic() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        let ikm = [0x0bu8; 22];
        let salt = from_hex("000102030405060708090a0b0c");
        let info = from_hex("f0f1f2f3f4f5f6f7f8f9");

        let prk = kdf.extract(&salt, &ikm).unwrap();
        assert_eq!(
            prk,
            from_hex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")
        );

        let okm = kdf.expand(&prk, &info, 42).unwrap();
        assert_eq!(
            okm,
            from_hex(
                "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865"
            )
        );
    }

    /// RFC 5869 Appendix A.3: empty salt and info. This exercises the path where an
    /// empty salt is passed to `hkdf` as `None`.
    #[test]
    fn rfc5869_sha256_empty_salt_and_info() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        let ikm = [0x0bu8; 22];

        let prk = kdf.extract(&[], &ikm).unwrap();
        assert_eq!(
            prk,
            from_hex("19ef24a32c717b167f33a91d6f648bdf96596776afdb6377ac434c1c293ccb04")
        );

        let okm = kdf.expand(&prk, &[], 42).unwrap();
        assert_eq!(
            okm,
            from_hex(
                "8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d9d201395faa4b61a96c8"
            )
        );
    }

    /// HKDF-Expand output is capped at 255 * HashLen bytes (RFC 5869 §2.3). One byte
    /// over the cap must surface as `KdfError::InvalidLength`.
    #[test]
    fn expand_too_long() {
        let kdf = Kdf::new(CipherSuite::CURVE25519_AES128).unwrap();
        let key = vec![0u8; kdf.extract_size()];
        let max = 255 * kdf.extract_size();

        assert!(kdf.expand(&key, &[], max).is_ok());
        assert_matches!(kdf.expand(&key, &[], max + 1), Err(KdfError::InvalidLength));
    }
}