rustls_mbedcrypto_provider/
sign.rs1use alloc::string::String;
2use alloc::vec;
3use alloc::{boxed::Box, sync::Arc, vec::Vec};
4use core::fmt::{self, Debug};
5use mbedtls::pk::{EcGroupId, ECDSA_MAX_LEN};
6use mbedtls::rng::RngCallback;
7use rustls::{pki_types, SignatureScheme};
8use std::sync::Mutex;
9use utils::error::mbedtls_err_into_rustls_err;
10use utils::hash::{buffer_for_hash_type, rustls_signature_scheme_to_mbedtls_hash_type};
11use utils::pk::{get_signature_schema_from_offered, pk_type_to_signature_algo, rustls_signature_scheme_to_mbedtls_pk_options};
12
13struct MbedTlsSigner<T: RngCallback> {
14 pk: Arc<Mutex<mbedtls::pk::Pk>>,
15 signature_scheme: SignatureScheme,
16 rng_provider: fn() -> Option<T>,
17}
18
19impl<T: RngCallback> Debug for MbedTlsSigner<T> {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 f.debug_tuple("MbedTlsSigner")
22 .field(&"Arc<Mutex<mbedtls::pk::Pk>>")
23 .field(&self.signature_scheme)
24 .finish()
25 }
26}
27
28impl<T: RngCallback> rustls::sign::Signer for MbedTlsSigner<T> {
29 fn sign(&self, message: &[u8]) -> Result<Vec<u8>, rustls::Error> {
30 let hash_type = rustls_signature_scheme_to_mbedtls_hash_type(self.signature_scheme);
31 let mut hash = buffer_for_hash_type(hash_type).ok_or_else(|| rustls::Error::General("unexpected hash type".into()))?;
32 let hash_size = mbedtls::hash::Md::hash(hash_type, message, &mut hash).map_err(mbedtls_err_into_rustls_err)?;
33
34 let mut pk = self
35 .pk
36 .lock()
37 .expect("poisoned PK lock!");
38 if let Some(opts) = rustls_signature_scheme_to_mbedtls_pk_options(self.signature_scheme) {
39 pk.set_options(opts);
40 }
41
42 fn sig_len_for_pk(pk: &mbedtls::pk::Pk) -> usize {
43 match pk.pk_type() {
44 mbedtls::pk::Type::Eckey | mbedtls::pk::Type::EckeyDh | mbedtls::pk::Type::Ecdsa => ECDSA_MAX_LEN,
45 _ => pk.len() / 8,
46 }
47 }
48 let mut sig = vec![0; sig_len_for_pk(&pk)];
49 let sig_len = pk
50 .sign(
51 hash_type,
52 &hash[..hash_size],
53 &mut sig,
54 &mut (self.rng_provider)().ok_or(rustls::Error::FailedToGetRandomBytes)?,
55 )
56 .map_err(mbedtls_err_into_rustls_err)?;
57 sig.truncate(sig_len);
58 Ok(sig)
59 }
60
61 fn scheme(&self) -> SignatureScheme {
62 self.signature_scheme
63 }
64}
65
66struct MbedTlsPkSigningKey {
67 pk: Arc<Mutex<mbedtls::pk::Pk>>,
68 pk_type: mbedtls::pk::Type,
69 signature_algorithm: rustls::SignatureAlgorithm,
70 ec_signature_scheme: Option<SignatureScheme>,
71 rsa_scheme_prefer_order_list: &'static [SignatureScheme],
72}
73
74impl Debug for MbedTlsPkSigningKey {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_struct("MbedTlsPkSigningKeyWrapper")
77 .field("pk", &"Arc<Mutex<mbedtls::pk::Pk>>")
78 .field("pk_type", &self.pk_type)
79 .field("signature_algorithm", &self.signature_algorithm)
80 .field("ec_signature_scheme", &self.ec_signature_scheme)
81 .finish()
82 }
83}
84
85pub struct MbedTlsPkSigningKeyWrapper<T: RngCallback> {
89 signing_key: MbedTlsPkSigningKey,
90 rng_provider: fn() -> Option<T>,
91}
92
93impl<T: RngCallback> Debug for MbedTlsPkSigningKeyWrapper<T> {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 f.debug_struct("MbedTlsPkSigningKeyWrapper")
96 .field("signing_key", &self.signing_key)
97 .finish()
98 }
99}
100
101impl<T: RngCallback> MbedTlsPkSigningKeyWrapper<T> {
102 pub fn new(der: &pki_types::PrivateKeyDer<'_>, rng_provider: fn() -> Option<T>) -> Result<Self, rustls::Error> {
104 let pk = mbedtls::pk::Pk::from_private_key(der.secret_der(), None)
105 .map_err(|err| rustls::Error::Other(rustls::OtherError(Arc::new(err))))?;
106 Self::from_pk(pk, rng_provider)
107 }
108
109 pub fn from_pk(pk: mbedtls::pk::Pk, rng_provider: fn() -> Option<T>) -> Result<Self, rustls::Error> {
111 let pk_type = pk.pk_type();
112 let signature_algorithm = pk_type_to_signature_algo(pk_type).ok_or(rustls::Error::General(String::from(
113 "MbedTlsPkSigningKeyWrapper: invalid pk type",
114 )))?;
115 let ec_signature_scheme = if signature_algorithm == rustls::SignatureAlgorithm::ECDSA {
116 Some(
117 match pk
118 .curve()
119 .map_err(|err| rustls::Error::Other(rustls::OtherError(Arc::new(err))))?
120 {
121 EcGroupId::SecP256R1 => SignatureScheme::ECDSA_NISTP256_SHA256,
122 EcGroupId::SecP384R1 => SignatureScheme::ECDSA_NISTP384_SHA384,
123 EcGroupId::SecP521R1 => SignatureScheme::ECDSA_NISTP521_SHA512,
124 _ => {
125 return Err(rustls::Error::General(String::from(
126 "MbedTlsPkSigningKeyWrapper: unsupported ec curve",
127 )))
128 }
129 },
130 )
131 } else {
132 None
133 };
134 Ok(Self {
135 signing_key: MbedTlsPkSigningKey {
136 pk: Arc::new(Mutex::new(pk)),
137 pk_type,
138 signature_algorithm,
139 ec_signature_scheme,
140 rsa_scheme_prefer_order_list: DEFAULT_RSA_SIGNATURE_SCHEME_PREFER_LIST,
141 },
142 rng_provider,
143 })
144 }
145
146 pub fn set_rsa_signature_scheme_prefer_list(&mut self, prefer_order_list: &'static [SignatureScheme]) {
148 self.signing_key
149 .rsa_scheme_prefer_order_list = prefer_order_list
150 }
151}
152
153pub const DEFAULT_RSA_SIGNATURE_SCHEME_PREFER_LIST: &[SignatureScheme] = &[
155 SignatureScheme::RSA_PSS_SHA512,
156 SignatureScheme::RSA_PSS_SHA384,
157 SignatureScheme::RSA_PSS_SHA256,
158 SignatureScheme::RSA_PKCS1_SHA512,
159 SignatureScheme::RSA_PKCS1_SHA384,
160 SignatureScheme::RSA_PKCS1_SHA256,
161];
162
163impl<T: RngCallback + 'static> rustls::sign::SigningKey for MbedTlsPkSigningKeyWrapper<T> {
164 fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option<Box<dyn rustls::sign::Signer>> {
165 let scheme = get_signature_schema_from_offered(
166 self.signing_key.pk_type,
167 offered,
168 self.signing_key.ec_signature_scheme,
169 self.signing_key
170 .rsa_scheme_prefer_order_list,
171 )?;
172 let signer = MbedTlsSigner {
173 pk: Arc::clone(&self.signing_key.pk),
174 signature_scheme: scheme,
175 rng_provider: self.rng_provider,
176 };
177 Some(Box::new(signer))
178 }
179
180 fn algorithm(&self) -> rustls::SignatureAlgorithm {
181 self.signing_key.signature_algorithm
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use mbedtls::pk::Type as PkType;
    use rustls::{sign::SigningKey, SignatureAlgorithm};

    /// Loads the first PKCS#8 private key from `test-ca/ecdsa/end.key`.
    fn ecdsa_test_key() -> pki_types::PrivateKeyDer<'static> {
        let pem = include_str!("../../test-ca/ecdsa/end.key");
        let mut reader = std::io::BufReader::new(pem.as_bytes());
        let first = rustls_pemfile::pkcs8_private_keys(&mut reader)
            .next()
            .unwrap()
            .unwrap();
        first.into()
    }

    #[test]
    fn test_signing_key() {
        let der = ecdsa_test_key();
        let key = MbedTlsPkSigningKeyWrapper::new(&der, crate::rng::rng_new).unwrap();
        assert_eq!(
            format!("{key:?}"),
            "MbedTlsPkSigningKeyWrapper { signing_key: MbedTlsPkSigningKeyWrapper { pk: \"Arc<Mutex<mbedtls::pk::Pk>>\", pk_type: Eckey, signature_algorithm: ECDSA, ec_signature_scheme: Some(ECDSA_NISTP256_SHA256) } }"
        );

        // An offer containing only an RSA scheme must not match an ECDSA key.
        assert!(key.choose_scheme(&[SignatureScheme::RSA_PKCS1_SHA1]).is_none());

        let signer = key.choose_scheme(&[SignatureScheme::ECDSA_NISTP256_SHA256]);
        assert!(signer.is_some());
        assert_eq!(
            format!("{signer:?}"),
            "Some(MbedTlsSigner(\"Arc<Mutex<mbedtls::pk::Pk>>\", ECDSA_NISTP256_SHA256))",
        );
    }

    #[test]
    fn test_pk_type_to_signature_algo() {
        let expectations = [
            (PkType::Rsa, Some(SignatureAlgorithm::RSA)),
            (PkType::Ecdsa, Some(SignatureAlgorithm::ECDSA)),
            (PkType::RsassaPss, Some(SignatureAlgorithm::RSA)),
            (PkType::RsaAlt, Some(SignatureAlgorithm::RSA)),
            (PkType::Eckey, Some(SignatureAlgorithm::ECDSA)),
            (PkType::EckeyDh, Some(SignatureAlgorithm::ECDSA)),
            (PkType::Custom, None),
            (PkType::None, None),
        ];
        for (pk_type, expected) in expectations {
            assert_eq!(pk_type_to_signature_algo(pk_type), expected);
        }
    }
}