mls_rs_crypto_rustcrypto/
ecdh.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use core::ops::Deref;
6
7use alloc::vec::Vec;
8
9use mls_rs_crypto_traits::{Curve, DhType, SamplingMethod};
10
11use mls_rs_core::{
12    crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey},
13    error::IntoAnyError,
14};
15
16use crate::ec::{
17    generate_keypair, private_key_bytes_to_public, private_key_ecdh, private_key_from_bytes,
18    pub_key_from_uncompressed, EcError, EcPublicKey,
19};
20
21#[derive(Debug)]
22#[cfg_attr(feature = "std", derive(thiserror::Error))]
23pub enum EcdhKemError {
24    #[cfg_attr(feature = "std", error(transparent))]
25    EcError(EcError),
26    #[cfg_attr(feature = "std", error("unsupported cipher suite"))]
27    UnsupportedCipherSuite,
28}
29
30impl From<EcError> for EcdhKemError {
31    fn from(e: EcError) -> Self {
32        EcdhKemError::EcError(e)
33    }
34}
35
36impl IntoAnyError for EcdhKemError {
37    #[cfg(feature = "std")]
38    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
39        Ok(self.into())
40    }
41}
42
43#[derive(Clone, Debug, Eq, PartialEq)]
44pub struct Ecdh(Curve);
45
46impl Deref for Ecdh {
47    type Target = Curve;
48
49    fn deref(&self) -> &Self::Target {
50        &self.0
51    }
52}
53
54impl Ecdh {
55    pub fn new(cipher_suite: CipherSuite) -> Option<Self> {
56        Curve::from_ciphersuite(cipher_suite, false).map(Self)
57    }
58}
59
60#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
61#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
62#[cfg_attr(
63    all(not(target_arch = "wasm32"), mls_build_async),
64    maybe_async::must_be_async
65)]
66impl DhType for Ecdh {
67    type Error = EcdhKemError;
68
69    async fn dh(
70        &self,
71        secret_key: &HpkeSecretKey,
72        public_key: &HpkePublicKey,
73    ) -> Result<Vec<u8>, Self::Error> {
74        Ok(private_key_ecdh(
75            &private_key_from_bytes(secret_key, self.0)?,
76            &self.to_ec_public_key(public_key)?,
77        )?)
78    }
79
80    async fn to_public(&self, secret_key: &HpkeSecretKey) -> Result<HpkePublicKey, Self::Error> {
81        Ok(private_key_bytes_to_public(secret_key, self.0)?.into())
82    }
83
84    async fn generate(&self) -> Result<(HpkeSecretKey, HpkePublicKey), Self::Error> {
85        let key_pair = generate_keypair(self.0)?;
86        Ok((key_pair.secret.into(), key_pair.public.into()))
87    }
88
89    fn bitmask_for_rejection_sampling(&self) -> SamplingMethod {
90        self.hpke_sampling_method()
91    }
92
93    fn public_key_validate(&self, key: &HpkePublicKey) -> Result<(), Self::Error> {
94        self.to_ec_public_key(key).map(|_| ())
95    }
96
97    fn secret_key_size(&self) -> usize {
98        self.0.secret_key_size()
99    }
100
101    fn public_key_size(&self) -> usize {
102        self.0.public_key_size()
103    }
104}
105
106impl Ecdh {
107    fn to_ec_public_key(&self, public_key: &HpkePublicKey) -> Result<EcPublicKey, EcdhKemError> {
108        Ok(pub_key_from_uncompressed(public_key, self.0)?)
109    }
110}
111
#[cfg(all(test, not(mls_build_async)))]
mod test {
    use alloc::vec::Vec;

    use mls_rs_core::crypto::{CipherSuite, HpkePublicKey, HpkeSecretKey};
    use mls_rs_crypto_traits::DhType;
    use serde::Deserialize;

    use crate::ecdh::Ecdh;

    /// One `Ecdh` per curve exercised by the cross-curve test.
    fn get_ecdhs() -> Vec<Ecdh> {
        let suites = [CipherSuite::P256_AES128, CipherSuite::CURVE25519_AES128];
        let mut ecdhs = Vec::with_capacity(suites.len());
        for suite in suites {
            ecdhs.push(Ecdh::new(suite).unwrap());
        }
        ecdhs
    }

    /// A known-answer vector loaded from `test_data/test_ecdh.json`. Key
    /// material is hex-encoded in the JSON.
    #[derive(Deserialize)]
    struct TestCase {
        ciphersuite: u16,
        #[serde(with = "hex::serde")]
        alice_pub: Vec<u8>,
        #[serde(with = "hex::serde")]
        alice_pri: Vec<u8>,
        #[serde(with = "hex::serde")]
        bob_pub: Vec<u8>,
        #[serde(with = "hex::serde")]
        bob_pri: Vec<u8>,
        #[serde(with = "hex::serde")]
        shared_secret: Vec<u8>,
    }

    /// Checks public-key derivation for both parties. Then checks that ECDH in
    /// both directions yields the expected shared secret.
    fn run_test_case(test_case: TestCase) {
        let TestCase {
            ciphersuite,
            alice_pub,
            alice_pri,
            bob_pub,
            bob_pri,
            shared_secret,
        } = test_case;

        let ecdh = Ecdh::new(ciphersuite.into()).unwrap();

        let alice_pub: HpkePublicKey = alice_pub.into();
        let bob_pub: HpkePublicKey = bob_pub.into();
        let alice_pri: HpkeSecretKey = alice_pri.into();
        let bob_pri: HpkeSecretKey = bob_pri.into();

        for (secret, expected_public) in [(&alice_pri, &alice_pub), (&bob_pri, &bob_pub)] {
            assert_eq!(&ecdh.to_public(secret).unwrap(), expected_public);
        }

        for (secret, peer_public) in [(&alice_pri, &bob_pub), (&bob_pri, &alice_pub)] {
            assert_eq!(ecdh.dh(secret, peer_public).unwrap(), shared_secret);
        }
    }

    #[test]
    fn test_algo_test_cases() {
        let cases: Vec<TestCase> =
            serde_json::from_str(include_str!("../test_data/test_ecdh.json")).unwrap();

        cases.into_iter().for_each(run_test_case);
    }

    #[test]
    fn test_mismatched_curve() {
        let ecdhs = get_ecdhs();

        for ecdh in &ecdhs {
            let (secret_key, _) = ecdh.generate().unwrap();

            // A public key from a different curve must be rejected.
            for other in ecdhs.iter().filter(|candidate| *candidate != ecdh) {
                let (_, other_public_key) = other.generate().unwrap();
                assert!(ecdh.dh(&secret_key, &other_public_key).is_err());
            }
        }
    }
}