1use std::{
7 self,
8 fmt::{self, Debug, Formatter},
9 result::Result,
10};
11
12use openssl::{hash, pkey, rsa, sign};
13
14use crate::types::status_code::StatusCode;
15
/// RSA padding schemes understood by the key wrappers in this module.
///
/// `Pkcs1`, `OaepSha1` and `OaepSha256` can be used to encrypt/decrypt;
/// `Pkcs1` and `Pkcs1Pss` are used to sign/verify. `Pkcs1Pss` cannot be used
/// for encryption (`KeySize::plain_text_block_size` panics on it).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RsaPadding {
    /// PKCS #1 v1.5 padding.
    Pkcs1,
    /// OAEP padding using OpenSSL's default OAEP digest (SHA-1).
    OaepSha1,
    /// OAEP padding with SHA-256 as both the OAEP digest and the MGF1 digest.
    OaepSha256,
    /// PSS padding (signatures only).
    Pkcs1Pss,
}
23
24impl Into<rsa::Padding> for RsaPadding {
25 fn into(self) -> rsa::Padding {
26 match self {
27 RsaPadding::Pkcs1 => rsa::Padding::PKCS1,
28 RsaPadding::OaepSha1 => rsa::Padding::PKCS1_OAEP,
29 RsaPadding::Pkcs1Pss => rsa::Padding::PKCS1_PSS,
30 RsaPadding::OaepSha256 => rsa::Padding::PKCS1_OAEP,
32 }
33 }
34}
35
/// Opaque failure from a key operation (PEM parsing/serialisation, encryption or
/// decryption). The failing call site logs the underlying OpenSSL error before
/// returning this value, so it carries no payload itself.
#[derive(Debug)]
pub struct PKeyError;

impl fmt::Display for PKeyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("PKeyError")
    }
}

impl std::error::Error for PKeyError {}
46
47pub struct PKey<T> {
50 pub(crate) value: pkey::PKey<T>,
51}
52
53pub type PublicKey = PKey<pkey::Public>;
55pub type PrivateKey = PKey<pkey::Private>;
57
58impl<T> Debug for PKey<T> {
59 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
60 write!(f, "[pkey]")
63 }
64}
65
66pub trait KeySize {
67 fn bit_length(&self) -> usize;
68
69 fn size(&self) -> usize {
70 self.bit_length() / 8
71 }
72
73 fn calculate_cipher_text_size(&self, data_size: usize, padding: RsaPadding) -> usize {
74 let plain_text_block_size = self.plain_text_block_size(padding);
75 let block_count = if data_size % plain_text_block_size == 0 {
76 data_size / plain_text_block_size
77 } else {
78 (data_size / plain_text_block_size) + 1
79 };
80 block_count * self.cipher_text_block_size()
81 }
82
83 fn plain_text_block_size(&self, padding: RsaPadding) -> usize {
84 match padding {
88 RsaPadding::Pkcs1 => self.size() - 11,
89 RsaPadding::OaepSha1 => self.size() - 42,
90 RsaPadding::OaepSha256 => self.size() - 66,
91 _ => panic!("Unsupported padding"),
92 }
93 }
94
95 fn cipher_text_block_size(&self) -> usize {
96 self.size()
97 }
98}
99
100impl KeySize for PrivateKey {
101 fn bit_length(&self) -> usize {
103 self.value.bits() as usize
104 }
105}
106
107impl PrivateKey {
108 pub fn new(bit_length: u32) -> PrivateKey {
109 PKey {
110 value: {
111 let rsa = rsa::Rsa::generate(bit_length).unwrap();
112 pkey::PKey::from_rsa(rsa).unwrap()
113 },
114 }
115 }
116
117 pub fn wrap_private_key(pkey: pkey::PKey<pkey::Private>) -> PrivateKey {
118 PrivateKey { value: pkey }
119 }
120
121 pub fn from_pem(pem: &[u8]) -> Result<PrivateKey, PKeyError> {
122 pkey::PKey::private_key_from_pem(pem)
123 .map(|value| PKey { value })
124 .map_err(|_| {
125 error!("Cannot produce a private key from the data supplied");
126 PKeyError
127 })
128 }
129
130 pub fn private_key_to_pem(&self) -> Result<Vec<u8>, PKeyError> {
131 self.value.private_key_to_pem_pkcs8().map_err(|_| {
132 error!("Cannot turn private key to PEM");
133 PKeyError
134 })
135 }
136
137 fn sign(
139 &self,
140 message_digest: hash::MessageDigest,
141 data: &[u8],
142 signature: &mut [u8],
143 padding: RsaPadding,
144 ) -> Result<usize, StatusCode> {
145 trace!("RSA signing");
146 if let Ok(mut signer) = sign::Signer::new(message_digest, &self.value) {
147 let _ = signer.set_rsa_padding(padding.into());
148 let _ = signer.set_rsa_pss_saltlen(sign::RsaPssSaltlen::DIGEST_LENGTH);
149 if signer.update(data).is_ok() {
150 return signer
151 .sign_to_vec()
152 .map(|result| {
153 trace!(
154 "Signature result, len {} = {:?}, copying to signature len {}",
155 result.len(),
156 result,
157 signature.len()
158 );
159 signature.copy_from_slice(&result);
160 result.len()
161 })
162 .map_err(|err| {
163 debug!("Cannot sign data - error = {:?}", err);
164 StatusCode::BadUnexpectedError
165 });
166 }
167 }
168 Err(StatusCode::BadUnexpectedError)
169 }
170
171 pub fn sign_sha1(&self, data: &[u8], signature: &mut [u8]) -> Result<usize, StatusCode> {
173 self.sign(
174 hash::MessageDigest::sha1(),
175 data,
176 signature,
177 RsaPadding::Pkcs1,
178 )
179 }
180
181 pub fn sign_sha256(&self, data: &[u8], signature: &mut [u8]) -> Result<usize, StatusCode> {
183 self.sign(
184 hash::MessageDigest::sha256(),
185 data,
186 signature,
187 RsaPadding::Pkcs1,
188 )
189 }
190
191 pub fn sign_sha256_pss(&self, data: &[u8], signature: &mut [u8]) -> Result<usize, StatusCode> {
193 self.sign(
194 hash::MessageDigest::sha256(),
195 data,
196 signature,
197 RsaPadding::Pkcs1Pss,
198 )
199 }
200
201 pub fn private_decrypt(
204 &self,
205 src: &[u8],
206 dst: &mut [u8],
207 padding: RsaPadding,
208 ) -> Result<usize, PKeyError> {
209 let cipher_text_block_size = self.cipher_text_block_size();
211 let rsa = self.value.rsa().unwrap();
212 let is_oaep_sha256 = padding == RsaPadding::OaepSha256;
213 let rsa_padding: rsa::Padding = padding.into();
214
215 let mut src_idx = 0;
217 let mut dst_idx = 0;
218
219 let src_len = src.len();
220 while src_idx < src_len {
221 dst_idx += {
223 let src = &src[src_idx..(src_idx + cipher_text_block_size)];
224 let dst = &mut dst[dst_idx..(dst_idx + cipher_text_block_size)];
225
226 if is_oaep_sha256 {
227 oaep_sha256::decrypt(&rsa, src, dst)
228 } else {
229 rsa.private_decrypt(src, dst, rsa_padding)
230 }.map_err(|err| {
231 error!("Decryption failed for key size {}, src idx {}, dst idx {}, padding {:?}, error - {:?}", cipher_text_block_size, src_idx, dst_idx, padding, err);
232 PKeyError
233 })?
234 };
235 src_idx += cipher_text_block_size;
236 }
237 Ok(dst_idx)
238 }
239}
240
241impl KeySize for PublicKey {
242 fn bit_length(&self) -> usize {
244 self.value.bits() as usize
245 }
246}
247
248impl PublicKey {
249 pub fn wrap_public_key(pkey: pkey::PKey<pkey::Public>) -> PublicKey {
250 PublicKey { value: pkey }
251 }
252
253 fn verify(
255 &self,
256 message_digest: hash::MessageDigest,
257 data: &[u8],
258 signature: &[u8],
259 padding: RsaPadding,
260 ) -> Result<bool, StatusCode> {
261 trace!(
262 "RSA verifying, against signature {:?}, len {}",
263 signature,
264 signature.len()
265 );
266 if let Ok(mut verifier) = sign::Verifier::new(message_digest, &self.value) {
267 let _ = verifier.set_rsa_padding(padding.into());
268 let _ = verifier.set_rsa_pss_saltlen(sign::RsaPssSaltlen::DIGEST_LENGTH);
269 if verifier.update(data).is_ok() {
270 return verifier
271 .verify(signature)
272 .map(|result| {
273 trace!("Key verified = {:?}", result);
274 result
275 })
276 .map_err(|err| {
277 debug!("Cannot verify key - error = {:?}", err);
278 StatusCode::BadUnexpectedError
279 });
280 }
281 }
282 Err(StatusCode::BadUnexpectedError)
283 }
284
285 pub fn verify_sha1(&self, data: &[u8], signature: &[u8]) -> Result<bool, StatusCode> {
287 self.verify(
288 hash::MessageDigest::sha1(),
289 data,
290 signature,
291 RsaPadding::Pkcs1,
292 )
293 }
294
295 pub fn verify_sha256(&self, data: &[u8], signature: &[u8]) -> Result<bool, StatusCode> {
297 self.verify(
298 hash::MessageDigest::sha256(),
299 data,
300 signature,
301 RsaPadding::Pkcs1,
302 )
303 }
304
305 pub fn verify_sha256_pss(&self, data: &[u8], signature: &[u8]) -> Result<bool, StatusCode> {
307 self.verify(
308 hash::MessageDigest::sha256(),
309 data,
310 signature,
311 RsaPadding::Pkcs1Pss,
312 )
313 }
314
315 pub fn public_encrypt(
318 &self,
319 src: &[u8],
320 dst: &mut [u8],
321 padding: RsaPadding,
322 ) -> Result<usize, PKeyError> {
323 let cipher_text_block_size = self.cipher_text_block_size();
324 let plain_text_block_size = self.plain_text_block_size(padding);
325
326 let rsa = self.value.rsa().unwrap();
330 let is_oaep_sha256 = padding == RsaPadding::OaepSha256;
331 let padding: rsa::Padding = padding.into();
332
333 let mut src_idx = 0;
335 let mut dst_idx = 0;
336
337 let src_len = src.len();
338 while src_idx < src_len {
339 let bytes_to_encrypt = if src_len < plain_text_block_size {
340 src_len
341 } else if (src_len - src_idx) < plain_text_block_size {
342 src_len - src_idx
343 } else {
344 plain_text_block_size
345 };
346
347 dst_idx += {
349 let src = &src[src_idx..(src_idx + bytes_to_encrypt)];
350 let dst = &mut dst[dst_idx..(dst_idx + cipher_text_block_size)];
351
352 if is_oaep_sha256 {
353 oaep_sha256::encrypt(&rsa, src, dst)
354 } else {
355 rsa.public_encrypt(src, dst, padding)
356 }.map_err(|err| {
357 error!("Encryption failed for bytes_to_encrypt {}, src len {}, src_idx {}, dst len {}, dst_idx {}, cipher_text_block_size {}, plain_text_block_size {}, error - {:?}",
358 bytes_to_encrypt, src.len(), src_idx, dst.len(), dst_idx, cipher_text_block_size, plain_text_block_size, err);
359 PKeyError
360 })?
361 };
362
363 src_idx += bytes_to_encrypt;
365 }
366
367 Ok(dst_idx)
368 }
369}
370
371mod oaep_sha256 {
375 use foreign_types::ForeignType;
376 use libc::*;
377 use openssl::{
378 error,
379 pkey::{Private, Public},
380 rsa::{self, Rsa},
381 };
382 use openssl_sys::{
383 ERR_get_error, EVP_PKEY_CTX_ctrl, EVP_PKEY_CTX_free, EVP_PKEY_CTX_new,
384 EVP_PKEY_CTX_set_rsa_mgf1_md, EVP_PKEY_CTX_set_rsa_padding, EVP_PKEY_decrypt,
385 EVP_PKEY_decrypt_init, EVP_PKEY_encrypt, EVP_PKEY_encrypt_init, EVP_PKEY_free,
386 EVP_PKEY_new, EVP_PKEY_set1_RSA, EVP_sha256, EVP_MD, EVP_PKEY_ALG_CTRL, EVP_PKEY_CTX,
387 EVP_PKEY_OP_TYPE_CRYPT, EVP_PKEY_RSA,
388 };
389 use std::ptr;
390
391 unsafe fn set_evp_ctrl_oaep_sha256(ctx: *mut EVP_PKEY_CTX) {
393 EVP_PKEY_CTX_set_rsa_padding(ctx, rsa::Padding::PKCS1_OAEP.as_raw());
394 let md = EVP_sha256() as *mut EVP_MD;
395 EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, md);
396 const EVP_PKEY_CTRL_RSA_OAEP_MD: c_int = EVP_PKEY_ALG_CTRL + 9;
398 EVP_PKEY_CTX_ctrl(
399 ctx,
400 EVP_PKEY_RSA,
401 EVP_PKEY_OP_TYPE_CRYPT,
402 EVP_PKEY_CTRL_RSA_OAEP_MD,
403 0,
404 md as *mut c_void,
405 );
406 }
407
408 pub fn decrypt(
410 pkey: &Rsa<Private>,
411 from: &[u8],
412 to: &mut [u8],
413 ) -> Result<usize, error::ErrorStack> {
414 let result;
415 unsafe {
416 let priv_key = EVP_PKEY_new();
417 if !priv_key.is_null() {
418 EVP_PKEY_set1_RSA(priv_key, pkey.as_ptr());
419 let ctx = EVP_PKEY_CTX_new(priv_key, ptr::null_mut());
420 EVP_PKEY_free(priv_key);
421
422 if !ctx.is_null() {
423 let _ret = EVP_PKEY_decrypt_init(ctx);
424 set_evp_ctrl_oaep_sha256(ctx);
425
426 let mut out_len: size_t = to.len();
427 let ret = EVP_PKEY_decrypt(
428 ctx,
429 to.as_mut_ptr(),
430 &mut out_len,
431 from.as_ptr(),
432 from.len(),
433 );
434 if ret > 0 && out_len > 0 {
435 result = Ok(out_len as usize);
436 } else {
437 trace!(
438 "oaep_sha256::decrypt EVP_PKEY_decrypt, ret = {}, out_len = {}",
439 ret,
440 out_len
441 );
442 result = Err(error::ErrorStack::get());
443 }
444 EVP_PKEY_CTX_free(ctx);
445 } else {
446 trace!("oaep_sha256::decrypt EVP_PKEY_CTX_new");
447 result = Err(error::ErrorStack::get());
448 }
449 } else {
450 trace!(
451 "oaep_sha256::decrypt EVP_PKEY_new failed, err {}",
452 ERR_get_error()
453 );
454 result = Err(error::ErrorStack::get());
455 }
456 }
457
458 result
459 }
460
461 pub fn encrypt(
463 pkey: &Rsa<Public>,
464 from: &[u8],
465 to: &mut [u8],
466 ) -> Result<usize, error::ErrorStack> {
467 let result;
468 unsafe {
469 let pub_key = EVP_PKEY_new();
470 if !pub_key.is_null() {
471 EVP_PKEY_set1_RSA(pub_key, pkey.as_ptr());
472 let ctx = EVP_PKEY_CTX_new(pub_key, ptr::null_mut());
473 EVP_PKEY_free(pub_key);
474
475 if !ctx.is_null() {
476 let _ret = EVP_PKEY_encrypt_init(ctx);
477 set_evp_ctrl_oaep_sha256(ctx);
478
479 let mut out_len: size_t = to.len();
480 let ret = EVP_PKEY_encrypt(
481 ctx,
482 to.as_mut_ptr(),
483 &mut out_len,
484 from.as_ptr(),
485 from.len(),
486 );
487 if ret > 0 && out_len > 0 {
488 result = Ok(out_len as usize);
489 } else {
490 trace!(
491 "oaep_sha256::encrypt EVP_PKEY_encrypt, ret = {}, out_len = {}",
492 ret,
493 out_len
494 );
495 result = Err(error::ErrorStack::get());
496 }
497 EVP_PKEY_CTX_free(ctx);
498 } else {
499 trace!("oaep_sha256::encrypt EVP_PKEY_CTX_new");
500 result = Err(error::ErrorStack::get());
501 }
502 } else {
503 trace!(
504 "oaep_sha256::encrypt EVP_PKEY_new failed, err {}",
505 ERR_get_error()
506 );
507 result = Err(error::ErrorStack::get());
508 }
509 }
510 result
511 }
512}