use crate::ecdh::rsa_encrypt;
use crate::errors::{BottleError, Result};
use rand::{CryptoRng, RngCore};
use rsa::RsaPublicKey;
use zeroize::Zeroize;
pub fn mem_clr(data: &mut [u8]) {
data.zeroize();
}
/// Encrypts a short `plaintext` for the holder of `public_key`.
///
/// Only DER-encoded `SubjectPublicKeyInfo` keys that carry an RSA public key
/// are handled. The actual encryption is delegated to [`rsa_encrypt`].
///
/// # Errors
///
/// Returns [`BottleError::UnsupportedAlgorithm`] in three cases:
/// - `public_key` is empty;
/// - it does not start with the DER SEQUENCE tag `0x30`;
/// - it cannot be parsed as an RSA SPKI key.
///
/// Any error returned by [`rsa_encrypt`] is propagated unchanged.
pub fn encrypt_short_buffer<R: RngCore + CryptoRng>(
    rng: &mut R,
    plaintext: &[u8],
    public_key: &[u8],
) -> Result<Vec<u8>> {
    // 0x30 is the DER SEQUENCE tag that every SPKI structure begins with.
    let looks_like_der = public_key.first() == Some(&0x30);

    // Parse failures are intentionally not surfaced. A key we cannot read as
    // RSA SPKI is reported as an unsupported algorithm.
    let rsa_key = if looks_like_der {
        parse_rsa_public_key_from_pkix(public_key).ok()
    } else {
        None
    };

    match rsa_key {
        Some(key) => rsa_encrypt(rng, plaintext, &key),
        None => Err(BottleError::UnsupportedAlgorithm),
    }
}
/// Parses a DER-encoded X.509 `SubjectPublicKeyInfo` holding an RSA key.
///
/// The algorithm identifier must be `rsaEncryption`. Its parameters are not
/// inspected. The contents of the `subjectPublicKey` bit string are handed to
/// the PKCS#1 `RSAPublicKey` parser.
///
/// # Errors
///
/// Returns [`BottleError::InvalidKeyType`] if the SPKI cannot be decoded or if
/// the algorithm OID is not `rsaEncryption`. Errors from the PKCS#1 parser are
/// propagated as-is.
fn parse_rsa_public_key_from_pkix(der_bytes: &[u8]) -> Result<RsaPublicKey> {
    use const_oid::db::rfc5912::RSA_ENCRYPTION;
    use der::asn1::{AnyRef, BitString};
    use der::Decode;
    use spki::SubjectPublicKeyInfo;

    let info: SubjectPublicKeyInfo<AnyRef, BitString> =
        match SubjectPublicKeyInfo::from_der(der_bytes) {
            Ok(decoded) => decoded,
            Err(_) => return Err(BottleError::InvalidKeyType),
        };

    if info.algorithm.oid == RSA_ENCRYPTION {
        parse_rsa_public_key_pkcs1(info.subject_public_key.raw_bytes())
    } else {
        Err(BottleError::InvalidKeyType)
    }
}
fn parse_rsa_public_key_pkcs1(der_bytes: &[u8]) -> Result<RsaPublicKey> {
use der::Decode;
use rsa::BigUint;
if der_bytes.is_empty() || der_bytes[0] != 0x30 {
return Err(BottleError::InvalidKeyType);
}
let mut pos = 1;
if pos >= der_bytes.len() {
return Err(BottleError::InvalidKeyType);
}
let seq_len = if (der_bytes[pos] & 0x80) == 0 {
let len = der_bytes[pos] as usize;
pos += 1;
len
} else {
let len_bytes = (der_bytes[pos] & 0x7f) as usize;
if len_bytes == 0 || len_bytes > 4 || pos + len_bytes >= der_bytes.len() {
return Err(BottleError::InvalidKeyType);
}
pos += 1;
let mut len = 0usize;
for i in 0..len_bytes {
len = (len << 8) | (der_bytes[pos + i] as usize);
}
pos += len_bytes;
len
};
if pos + seq_len > der_bytes.len() {
return Err(BottleError::InvalidKeyType);
}
let seq_content = &der_bytes[pos..pos + seq_len];
let n_uint = der::asn1::Uint::from_der(seq_content).map_err(|_| BottleError::InvalidKeyType)?;
let n_len = if seq_content.is_empty() || seq_content[0] != 0x02 {
return Err(BottleError::InvalidKeyType);
} else {
let mut n_pos = 1;
if n_pos >= seq_content.len() {
return Err(BottleError::InvalidKeyType);
}
let n_val_len = if (seq_content[n_pos] & 0x80) == 0 {
let len = seq_content[n_pos] as usize;
n_pos += 1;
len
} else {
let len_bytes = (seq_content[n_pos] & 0x7f) as usize;
if len_bytes == 0 || len_bytes > 4 || n_pos + len_bytes >= seq_content.len() {
return Err(BottleError::InvalidKeyType);
}
n_pos += 1;
let mut len = 0usize;
for i in 0..len_bytes {
len = (len << 8) | (seq_content[n_pos + i] as usize);
}
n_pos += len_bytes;
len
};
n_pos + n_val_len
};
if n_len >= seq_content.len() {
return Err(BottleError::InvalidKeyType);
}
let e_uint = der::asn1::Uint::from_der(&seq_content[n_len..])
.map_err(|_| BottleError::InvalidKeyType)?;
let n = BigUint::from_bytes_be(n_uint.as_bytes());
let e = BigUint::from_bytes_be(e_uint.as_bytes());
RsaPublicKey::new(n, e).map_err(|_| BottleError::InvalidKeyType)
}
/// Decrypts a buffer produced by [`encrypt_short_buffer`].
///
/// Not implemented yet. Both arguments are ignored.
///
/// # Errors
///
/// Always returns [`BottleError::UnsupportedAlgorithm`].
pub fn decrypt_short_buffer(_ciphertext: &[u8], _private_key: &[u8]) -> Result<Vec<u8>> {
    let unsupported = BottleError::UnsupportedAlgorithm;
    Err(unsupported)
}