1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use std::{fmt::Debug, ops::Deref};

use mls_rs_core::{crypto::CipherSuite, error::IntoAnyError};
use mls_rs_crypto_traits::{KdfId, KdfType};
use openssl::{
    md::{Md, MdRef},
    pkey::Id,
    pkey_ctx::{HkdfMode, PkeyCtx},
};
use thiserror::Error;

/// Errors produced by the OpenSSL-backed [`Kdf`].
#[derive(Debug, Error)]
pub enum KdfError {
    /// An error stack reported by OpenSSL, forwarded transparently.
    #[error(transparent)]
    OpensslError(#[from] openssl::error::ErrorStack),
    /// The supplied key material is too short. The first field is the provided
    /// length and the second is the minimum accepted length.
    #[error("the provided length of the key {0} is shorter than the minimum length {1}")]
    TooShortKey(usize, usize),
    /// The requested cipher suite is not supported.
    #[error("unsupported cipher suite")]
    UnsupportedCipherSuite,
}

impl IntoAnyError for KdfError {
    /// Boxes this error as a `dyn Error + Send + Sync` trait object.
    ///
    /// The conversion always succeeds, so the `Err(Self)` case is never returned.
    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
        let boxed: Box<dyn std::error::Error + Send + Sync> = Box::new(self);
        Ok(boxed)
    }
}

/// HKDF implementation backed by OpenSSL.
///
/// Pairs the OpenSSL message digest used for the HKDF operations with the
/// [`KdfId`] it was selected from.
#[derive(Clone)]
pub struct Kdf {
    // Digest handed to OpenSSL when an HKDF context is configured.
    message_digest: &'static MdRef,
    // Identifier reported by `KdfType::kdf_id` and shown in the `Debug` output.
    kdf_id: KdfId,
}

impl Debug for Kdf {
    /// Formats the KDF by its identifier only; the digest reference is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let kdf_id = &self.kdf_id;
        write!(f, "Kdf with kdf_id {kdf_id:?}")
    }
}

impl Deref for Kdf {
    type Target = MdRef;

    /// Exposes the underlying message digest, so a `&Kdf` can be used wherever
    /// a `&MdRef` is expected.
    fn deref(&self) -> &MdRef {
        self.message_digest
    }
}

impl Kdf {
    /// Creates a KDF for `cipher_suite`.
    ///
    /// Returns `None` when [`KdfId::new`] yields no identifier for the cipher
    /// suite, or when the identifier is not one of the HKDF-SHA256, HKDF-SHA384
    /// or HKDF-SHA512 variants handled here.
    pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
        let kdf_id = KdfId::new(cipher_suite)?;

        // Pick the OpenSSL digest matching the HKDF variant; bail out on any other id.
        let message_digest = match kdf_id {
            KdfId::HkdfSha256 => Md::sha256(),
            KdfId::HkdfSha384 => Md::sha384(),
            KdfId::HkdfSha512 => Md::sha512(),
            _ => return None,
        };

        Some(Self {
            message_digest,
            kdf_id,
        })
    }
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
#[cfg_attr(
    all(not(target_arch = "wasm32"), mls_build_async),
    maybe_async::must_be_async
)]
impl KdfType for Kdf {
    type Error = KdfError;

    /// Runs HKDF in expand-only mode, deriving `len` bytes of output from the
    /// pseudorandom key `prk` and the context string `info`.
    ///
    /// # Errors
    /// Returns [`KdfError::TooShortKey`] when `prk` is shorter than
    /// [`extract_size`](KdfType::extract_size), and
    /// [`KdfError::OpensslError`] when any OpenSSL call fails.
    ///
    /// # Warning
    /// The length of info can *not* exceed 1024 bytes when using the OpenSSL Engine due to
    /// underlying restrictions in OpenSSL. This function returns an
    /// [OpensslError](KdfError::OpensslError) in the event info is > 1024 bytes.
    async fn expand(&self, prk: &[u8], info: &[u8], len: usize) -> Result<Vec<u8>, KdfError> {
        let min_len = self.extract_size();
        let prk_len = prk.len();

        if prk_len < min_len {
            return Err(KdfError::TooShortKey(prk_len, min_len));
        }

        let mut hkdf = self.create_hkdf_ctx(HkdfMode::EXPAND_ONLY, prk)?;
        hkdf.add_hkdf_info(info)?;

        let mut okm = vec![0u8; len];
        hkdf.derive(Some(okm.as_mut_slice()))?;

        Ok(okm)
    }

    /// Runs HKDF in extract-only mode over the input keying material `ikm`
    /// with the given `salt`, returning a buffer of
    /// [`extract_size`](KdfType::extract_size) bytes.
    ///
    /// # Errors
    /// Returns [`KdfError::TooShortKey`] when `ikm` is empty, and
    /// [`KdfError::OpensslError`] when any OpenSSL call fails.
    async fn extract(&self, salt: &[u8], ikm: &[u8]) -> Result<Vec<u8>, KdfError> {
        if ikm.is_empty() {
            return Err(KdfError::TooShortKey(0, 1));
        }

        let mut hkdf = self.create_hkdf_ctx(HkdfMode::EXTRACT_ONLY, ikm)?;
        hkdf.set_hkdf_salt(salt)?;

        let mut prk = vec![0u8; self.extract_size()];
        hkdf.derive(Some(prk.as_mut_slice()))?;

        Ok(prk)
    }

    /// Size reported by the configured message digest; used as the extract
    /// output length and as the minimum `prk` length accepted by `expand`.
    fn extract_size(&self) -> usize {
        self.message_digest.size()
    }

    /// Numeric value of the [`KdfId`] this KDF was created with.
    fn kdf_id(&self) -> u16 {
        self.kdf_id as u16
    }
}

impl Kdf {
    /// Builds an OpenSSL HKDF derivation context configured with the given
    /// `mode`, this KDF's message digest, and `key`.
    ///
    /// Any OpenSSL failure is converted into [`KdfError::OpensslError`].
    fn create_hkdf_ctx(&self, mode: HkdfMode, key: &[u8]) -> Result<PkeyCtx<()>, KdfError> {
        let mut hkdf = PkeyCtx::new_id(Id::HKDF)?;

        // The context is initialised for derivation before any HKDF parameter is set.
        hkdf.derive_init()?;
        hkdf.set_hkdf_mode(mode)?;
        hkdf.set_hkdf_md(self.message_digest)?;
        hkdf.set_hkdf_key(key)?;

        Ok(hkdf)
    }
}

#[cfg(all(test, not(mls_build_async)))]
mod test {
    use assert_matches::assert_matches;
    use mls_rs_core::crypto::CipherSuite;
    use mls_rs_crypto_traits::KdfType;

    use crate::kdf::{Kdf, KdfError};

    /// Builds the KDF instance shared by every test in this module.
    fn test_kdf() -> Kdf {
        Kdf::new(CipherSuite::CURVE25519_AES128).unwrap()
    }

    /// Extracting with empty input keying material must be rejected.
    #[test]
    fn no_key() {
        let result = test_kdf().extract(b"key", &[]);
        assert!(result.is_err());
    }

    /// An empty salt is accepted by extract.
    #[test]
    fn no_salt() {
        let result = test_kdf().extract(&[], b"key");
        assert!(result.is_ok());
    }

    /// An empty info string is accepted by expand.
    #[test]
    fn no_info() {
        let kdf = test_kdf();
        let prk = vec![0u8; kdf.extract_size()];

        let result = kdf.expand(&prk, &[], 42);
        assert!(result.is_ok());
    }

    /// A pseudorandom key one byte shorter than the extract size is rejected.
    #[test]
    fn test_short_key() {
        let kdf = test_kdf();
        let short_prk = vec![0u8; kdf.extract_size() - 1];

        let result = kdf.expand(&short_prk, &[], 42);
        assert_matches!(result, Err(KdfError::TooShortKey(_, _)));
    }
}