rustls_mbedcrypto_provider/
sign.rs

1use alloc::string::String;
2use alloc::vec;
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4use core::fmt::{self, Debug};
5use mbedtls::pk::{EcGroupId, ECDSA_MAX_LEN};
6use mbedtls::rng::RngCallback;
7use rustls::{pki_types, SignatureScheme};
8use std::sync::Mutex;
9use utils::error::mbedtls_err_into_rustls_err;
10use utils::hash::{buffer_for_hash_type, rustls_signature_scheme_to_mbedtls_hash_type};
11use utils::pk::{get_signature_schema_from_offered, pk_type_to_signature_algo, rustls_signature_scheme_to_mbedtls_pk_options};
12
13struct MbedTlsSigner<T: RngCallback> {
14    pk: Arc<Mutex<mbedtls::pk::Pk>>,
15    signature_scheme: SignatureScheme,
16    rng_provider: fn() -> Option<T>,
17}
18
19impl<T: RngCallback> Debug for MbedTlsSigner<T> {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        f.debug_tuple("MbedTlsSigner")
22            .field(&"Arc<Mutex<mbedtls::pk::Pk>>")
23            .field(&self.signature_scheme)
24            .finish()
25    }
26}
27
28impl<T: RngCallback> rustls::sign::Signer for MbedTlsSigner<T> {
29    fn sign(&self, message: &[u8]) -> Result<Vec<u8>, rustls::Error> {
30        let hash_type = rustls_signature_scheme_to_mbedtls_hash_type(self.signature_scheme);
31        let mut hash = buffer_for_hash_type(hash_type).ok_or_else(|| rustls::Error::General("unexpected hash type".into()))?;
32        let hash_size = mbedtls::hash::Md::hash(hash_type, message, &mut hash).map_err(mbedtls_err_into_rustls_err)?;
33
34        let mut pk = self
35            .pk
36            .lock()
37            .expect("poisoned PK lock!");
38        if let Some(opts) = rustls_signature_scheme_to_mbedtls_pk_options(self.signature_scheme) {
39            pk.set_options(opts);
40        }
41
42        fn sig_len_for_pk(pk: &mbedtls::pk::Pk) -> usize {
43            match pk.pk_type() {
44                mbedtls::pk::Type::Eckey | mbedtls::pk::Type::EckeyDh | mbedtls::pk::Type::Ecdsa => ECDSA_MAX_LEN,
45                _ => pk.len() / 8,
46            }
47        }
48        let mut sig = vec![0; sig_len_for_pk(&pk)];
49        let sig_len = pk
50            .sign(
51                hash_type,
52                &hash[..hash_size],
53                &mut sig,
54                &mut (self.rng_provider)().ok_or(rustls::Error::FailedToGetRandomBytes)?,
55            )
56            .map_err(mbedtls_err_into_rustls_err)?;
57        sig.truncate(sig_len);
58        Ok(sig)
59    }
60
61    fn scheme(&self) -> SignatureScheme {
62        self.signature_scheme
63    }
64}
65
66struct MbedTlsPkSigningKey {
67    pk: Arc<Mutex<mbedtls::pk::Pk>>,
68    pk_type: mbedtls::pk::Type,
69    signature_algorithm: rustls::SignatureAlgorithm,
70    ec_signature_scheme: Option<SignatureScheme>,
71    rsa_scheme_prefer_order_list: &'static [SignatureScheme],
72}
73
74impl Debug for MbedTlsPkSigningKey {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("MbedTlsPkSigningKeyWrapper")
77            .field("pk", &"Arc<Mutex<mbedtls::pk::Pk>>")
78            .field("pk_type", &self.pk_type)
79            .field("signature_algorithm", &self.signature_algorithm)
80            .field("ec_signature_scheme", &self.ec_signature_scheme)
81            .finish()
82    }
83}
84
85/// A type implements [`SigningKey`] by using [`mbedtls`].
86///
87/// [`SigningKey`]: rustls::sign::SigningKey
88pub struct MbedTlsPkSigningKeyWrapper<T: RngCallback> {
89    signing_key: MbedTlsPkSigningKey,
90    rng_provider: fn() -> Option<T>,
91}
92
93impl<T: RngCallback> Debug for MbedTlsPkSigningKeyWrapper<T> {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.debug_struct("MbedTlsPkSigningKeyWrapper")
96            .field("signing_key", &self.signing_key)
97            .finish()
98    }
99}
100
101impl<T: RngCallback> MbedTlsPkSigningKeyWrapper<T> {
102    /// Make a new [`MbedTlsPkSigningKeyWrapper`] from a DER encoding.
103    pub fn new(der: &pki_types::PrivateKeyDer<'_>, rng_provider: fn() -> Option<T>) -> Result<Self, rustls::Error> {
104        let pk = mbedtls::pk::Pk::from_private_key(der.secret_der(), None)
105            .map_err(|err| rustls::Error::Other(rustls::OtherError(Arc::new(err))))?;
106        Self::from_pk(pk, rng_provider)
107    }
108
109    /// Make a new [`MbedTlsPkSigningKeyWrapper`] from a [`mbedtls::pk::Pk`].
110    pub fn from_pk(pk: mbedtls::pk::Pk, rng_provider: fn() -> Option<T>) -> Result<Self, rustls::Error> {
111        let pk_type = pk.pk_type();
112        let signature_algorithm = pk_type_to_signature_algo(pk_type).ok_or(rustls::Error::General(String::from(
113            "MbedTlsPkSigningKeyWrapper: invalid pk type",
114        )))?;
115        let ec_signature_scheme = if signature_algorithm == rustls::SignatureAlgorithm::ECDSA {
116            Some(
117                match pk
118                    .curve()
119                    .map_err(|err| rustls::Error::Other(rustls::OtherError(Arc::new(err))))?
120                {
121                    EcGroupId::SecP256R1 => SignatureScheme::ECDSA_NISTP256_SHA256,
122                    EcGroupId::SecP384R1 => SignatureScheme::ECDSA_NISTP384_SHA384,
123                    EcGroupId::SecP521R1 => SignatureScheme::ECDSA_NISTP521_SHA512,
124                    _ => {
125                        return Err(rustls::Error::General(String::from(
126                            "MbedTlsPkSigningKeyWrapper: unsupported ec curve",
127                        )))
128                    }
129                },
130            )
131        } else {
132            None
133        };
134        Ok(Self {
135            signing_key: MbedTlsPkSigningKey {
136                pk: Arc::new(Mutex::new(pk)),
137                pk_type,
138                signature_algorithm,
139                ec_signature_scheme,
140                rsa_scheme_prefer_order_list: DEFAULT_RSA_SIGNATURE_SCHEME_PREFER_LIST,
141            },
142            rng_provider,
143        })
144    }
145
146    /// Change the rsa signature scheme prefer list
147    pub fn set_rsa_signature_scheme_prefer_list(&mut self, prefer_order_list: &'static [SignatureScheme]) {
148        self.signing_key
149            .rsa_scheme_prefer_order_list = prefer_order_list
150    }
151}
152
153/// An ordered list of RSA [`SignatureScheme`] used for choosing scheme in [`rustls::sign::SigningKey`]
154pub const DEFAULT_RSA_SIGNATURE_SCHEME_PREFER_LIST: &[SignatureScheme] = &[
155    SignatureScheme::RSA_PSS_SHA512,
156    SignatureScheme::RSA_PSS_SHA384,
157    SignatureScheme::RSA_PSS_SHA256,
158    SignatureScheme::RSA_PKCS1_SHA512,
159    SignatureScheme::RSA_PKCS1_SHA384,
160    SignatureScheme::RSA_PKCS1_SHA256,
161];
162
163impl<T: RngCallback + 'static> rustls::sign::SigningKey for MbedTlsPkSigningKeyWrapper<T> {
164    fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option<Box<dyn rustls::sign::Signer>> {
165        let scheme = get_signature_schema_from_offered(
166            self.signing_key.pk_type,
167            offered,
168            self.signing_key.ec_signature_scheme,
169            self.signing_key
170                .rsa_scheme_prefer_order_list,
171        )?;
172        let signer = MbedTlsSigner {
173            pk: Arc::clone(&self.signing_key.pk),
174            signature_scheme: scheme,
175            rng_provider: self.rng_provider,
176        };
177        Some(Box::new(signer))
178    }
179
180    fn algorithm(&self) -> rustls::SignatureAlgorithm {
181        self.signing_key.signature_algorithm
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::{sign::SigningKey, SignatureAlgorithm};

    /// Wraps the ECDSA test key, then checks its `Debug` text and that
    /// scheme negotiation accepts P-256 but rejects an RSA-only offer.
    #[test]
    fn test_signing_key() {
        let pem_bytes = include_str!("../../test-ca/ecdsa/end.key").as_bytes();
        let mut pem_reader = std::io::BufReader::new(pem_bytes);
        let der: pki_types::PrivateKeyDer<'static> = rustls_pemfile::pkcs8_private_keys(&mut pem_reader)
            .next()
            .unwrap()
            .unwrap()
            .into();
        let key = MbedTlsPkSigningKeyWrapper::new(&der, crate::rng::rng_new).unwrap();

        let expected_key_debug = "MbedTlsPkSigningKeyWrapper { signing_key: MbedTlsPkSigningKeyWrapper { pk: \"Arc<Mutex<mbedtls::pk::Pk>>\", pk_type: Eckey, signature_algorithm: ECDSA, ec_signature_scheme: Some(ECDSA_NISTP256_SHA256) } }";
        assert_eq!(format!("{key:?}"), expected_key_debug);

        let rsa_only = key.choose_scheme(&[SignatureScheme::RSA_PKCS1_SHA1]);
        assert!(rsa_only.is_none());

        let chosen = key.choose_scheme(&[SignatureScheme::ECDSA_NISTP256_SHA256]);
        assert!(chosen.is_some());
        let expected_signer_debug = "Some(MbedTlsSigner(\"Arc<Mutex<mbedtls::pk::Pk>>\", ECDSA_NISTP256_SHA256))";
        assert_eq!(format!("{chosen:?}"), expected_signer_debug);
    }

    /// Table-driven check of the mbedtls key type -> rustls algorithm mapping.
    #[test]
    fn test_pk_type_to_signature_algo() {
        let cases = [
            (mbedtls::pk::Type::Rsa, Some(SignatureAlgorithm::RSA)),
            (mbedtls::pk::Type::Ecdsa, Some(SignatureAlgorithm::ECDSA)),
            (mbedtls::pk::Type::RsassaPss, Some(SignatureAlgorithm::RSA)),
            (mbedtls::pk::Type::RsaAlt, Some(SignatureAlgorithm::RSA)),
            (mbedtls::pk::Type::Eckey, Some(SignatureAlgorithm::ECDSA)),
            (mbedtls::pk::Type::EckeyDh, Some(SignatureAlgorithm::ECDSA)),
            (mbedtls::pk::Type::Custom, None),
            (mbedtls::pk::Type::None, None),
        ];
        for (pk_type, expected) in cases {
            assert_eq!(pk_type_to_signature_algo(pk_type), expected);
        }
    }
}