mls_rs_crypto_openssl/
ecdh.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use std::ops::Deref;
6
7use mls_rs_crypto_traits::{Curve, DhType, SamplingMethod};
8use thiserror::Error;
9
10use mls_rs_core::{
11    crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey},
12    error::IntoAnyError,
13};
14
15use crate::ec::{
16    generate_keypair, private_key_bytes_to_public, private_key_ecdh, private_key_from_bytes,
17    pub_key_from_uncompressed, EcError, EcPublicKey,
18};
19
20#[derive(Debug, Error)]
21pub enum EcdhKemError {
22    #[error(transparent)]
23    OpensslError(#[from] openssl::error::ErrorStack),
24    #[error(transparent)]
25    EcError(#[from] EcError),
26    #[error("unsupported cipher suite")]
27    UnsupportedCipherSuite,
28}
29
30impl IntoAnyError for EcdhKemError {
31    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
32        Ok(self.into())
33    }
34}
35
36#[derive(Clone, Debug, Eq, PartialEq)]
37pub struct Ecdh(Curve);
38
39impl Deref for Ecdh {
40    type Target = Curve;
41
42    fn deref(&self) -> &Self::Target {
43        &self.0
44    }
45}
46
47impl Ecdh {
48    pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
49        Curve::from_ciphersuite(cipher_suite, false).map(Self)
50    }
51}
52
53#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
54#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
55#[cfg_attr(
56    all(not(target_arch = "wasm32"), mls_build_async),
57    maybe_async::must_be_async
58)]
59impl DhType for Ecdh {
60    type Error = EcdhKemError;
61
62    async fn dh(
63        &self,
64        secret_key: &HpkeSecretKey,
65        public_key: &HpkePublicKey,
66    ) -> Result<Vec<u8>, Self::Error> {
67        Ok(private_key_ecdh(
68            &private_key_from_bytes(secret_key, self.0, false)?,
69            &self.to_ec_public_key(public_key)?,
70        )?)
71    }
72
73    async fn to_public(&self, secret_key: &HpkeSecretKey) -> Result<HpkePublicKey, Self::Error> {
74        Ok(private_key_bytes_to_public(secret_key, self.0)?.into())
75    }
76
77    async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> {
78        let key_pair = generate_keypair(self.0)?;
79        Ok((key_pair.secret.into(), key_pair.public.into()))
80    }
81
82    fn bitmask_for_rejection_sampling(&self) -> SamplingMethod {
83        self.hpke_sampling_method()
84    }
85
86    fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error> {
87        self.to_ec_public_key(key).map(|_| ())
88    }
89
90    fn secret_key_size(&self) -> usize {
91        self.0.secret_key_size()
92    }
93
94    fn public_key_size(&self) -> usize {
95        self.0.secret_key_size()
96    }
97}
98
99impl Ecdh {
100    fn to_ec_public_key(&self, public_key: &HpkePublicKey) -> Result<EcPublicKey, EcdhKemError> {
101        Ok(pub_key_from_uncompressed(public_key, self.0)?)
102    }
103}
104
#[cfg(all(test, not(mls_build_async)))]
mod test {
    use mls_rs_core::crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey};
    use mls_rs_crypto_traits::DhType;
    use serde::Deserialize;

    use crate::ecdh::Ecdh;

    /// Builds one `Ecdh` per cipher suite exercised by these tests.
    fn get_ecdhs() -> Vec<Ecdh> {
        let suites = [
            CipherSuite::P256_AES128,
            CipherSuite::P384_AES256,
            CipherSuite::P521_AES256,
            CipherSuite::CURVE25519_AES128,
            CipherSuite::CURVE448_AES256,
        ];

        let mut ecdhs = Vec::with_capacity(suites.len());
        for suite in suites {
            ecdhs.push(Ecdh::new(suite).unwrap());
        }
        ecdhs
    }

    /// One known-answer vector loaded from `test_ecdh.json` (hex-encoded byte fields).
    #[derive(Deserialize)]
    struct TestCase {
        pub ciphersuite: u16,
        #[serde(with = "hex::serde")]
        pub alice_pub: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub alice_pri: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub bob_pub: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub bob_pri: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub shared_secret: Vec<u8>,
    }

    /// Checks one vector.
    ///
    /// Each public key must be derivable from its secret key, and ECDH in
    /// either direction must yield `shared_secret`.
    fn run_test_case(test_case: TestCase) {
        let TestCase {
            ciphersuite,
            alice_pub,
            alice_pri,
            bob_pub,
            bob_pri,
            shared_secret,
        } = test_case;

        println!("Running ECDH test for ciphersuite: {:?}", ciphersuite);

        let ecdh = Ecdh::new(ciphersuite.into()).unwrap();

        // Wrap the raw byte vectors in the HPKE key types the DH API expects.
        let alice_pub: HpkePublicKey = alice_pub.into();
        let bob_pub: HpkePublicKey = bob_pub.into();
        let alice_pri: HpkeSecretKey = alice_pri.into();
        let bob_pri: HpkeSecretKey = bob_pri.into();

        let derived_alice_pub = ecdh.to_public(&alice_pri).unwrap();
        assert_eq!(derived_alice_pub, alice_pub);

        let derived_bob_pub = ecdh.to_public(&bob_pri).unwrap();
        assert_eq!(derived_bob_pub, bob_pub);

        let alice_view = ecdh.dh(&alice_pri, &bob_pub).unwrap();
        assert_eq!(alice_view, shared_secret);

        let bob_view = ecdh.dh(&bob_pri, &alice_pub).unwrap();
        assert_eq!(bob_view, shared_secret);
    }

    #[test]
    fn test_algo_test_cases() {
        let raw_cases = include_str!("../test_data/test_ecdh.json");

        serde_json::from_str::<Vec<TestCase>>(raw_cases)
            .unwrap()
            .into_iter()
            .for_each(run_test_case);
    }

    #[test]
    fn test_mismatched_curve() {
        let ecdhs = get_ecdhs();

        for ecdh in &ecdhs {
            println!("Testing mismatched curve error for {:?}", **ecdh);

            let (secret_key, _) = ecdh.generate().unwrap();

            // A public key from any other curve must be rejected by `dh`.
            ecdhs
                .iter()
                .filter(|other| *other != ecdh)
                .for_each(|other| {
                    let (_, other_public_key) = other.generate().unwrap();
                    assert!(ecdh.dh(&secret_key, &other_public_key).is_err());
                });
        }
    }
}