use alloc::boxed::Box;
use alloc::vec::Vec;
use super::common::{PcStatus, guard, out_write, slice};
use super::hash::id;
use crate::bignum::Uint;
use crate::der::{pem_decode, pem_encode};
use crate::hash::{Sha224, Sha256, Sha384, Sha512};
use crate::rng::OsRng;
use crate::rsa::{BoxedRsaPrivateKey, RsaPrivateKey};
use crate::x509::AnyPublicKey;
/// PEM label passed to `pem_encode` / `pem_decode` for private keys in this module.
const PEM_LABEL: &str = "RSA PRIVATE KEY";
/// Value converted with `Uint::from_u64` and handed to `RsaPrivateKey::generate` as its
/// first argument (presumably the RSA public exponent — confirm against `generate`).
const E: u64 = 65537;
/// Third argument to `RsaPrivateKey::generate`; presumably the number of primality-test
/// rounds — TODO confirm against `generate`.
const ROUNDS: usize = 20;
/// RSA private-key handle that the `pc_rsa_*` functions pass across the C ABI by pointer.
pub struct PcRsaKey {
// Parsed private key used for signing, OAEP decryption and public-key export.
key: BoxedRsaPrivateKey,
// PKCS#1 DER bytes the key was parsed from; PEM-encoded by `pc_rsa_private_to_pem`
// and zeroed by the `Drop` impl.
der: Vec<u8>,
}
impl PcRsaKey {
    /// Builds a handle from PKCS#1 DER bytes, taking ownership of `der`.
    ///
    /// Returns `None` when `BoxedRsaPrivateKey::from_pkcs1_der` rejects the encoding;
    /// in that case the buffer is zeroed with `wipe_vec` before it is dropped.
    /// On success the same buffer is stored alongside the parsed key.
    fn from_pkcs1_der(mut der: Vec<u8>) -> Option<Self> {
        let Ok(key) = BoxedRsaPrivateKey::from_pkcs1_der(&der) else {
            // Do not leave rejected key material lying around in freed memory.
            wipe_vec(&mut der);
            return None;
        };
        Some(PcRsaKey { key, der })
    }
}
/// Zeroes the stored DER encoding whenever a handle is dropped.
impl Drop for PcRsaKey {
    fn drop(&mut self) {
        // Borrow only the DER buffer; the parsed key is left to its own destructor.
        let Self { der, .. } = self;
        wipe_vec(der);
    }
}
pub(super) fn pc_rsa_inner_key(handle: &PcRsaKey) -> &BoxedRsaPrivateKey {
&handle.key
}
/// Generates a fresh RSA private key and returns it as a heap-allocated handle.
///
/// `bits` selects the modulus size; only 2048, 3072 and 4096 are accepted. The key is
/// generated with `E`, `OsRng` and `ROUNDS`, serialized to PKCS#1 DER, and then parsed
/// back through `PcRsaKey::from_pkcs1_der` so the handle also owns the DER bytes.
///
/// Returns a null pointer for any other `bits` value or when the DER is rejected.
/// A non-null result comes from `Box::into_raw` and is reclaimed by `pc_rsa_free`.
#[unsafe(no_mangle)]
pub extern "C" fn pc_rsa_generate(bits: u32) -> *mut PcRsaKey {
    crate::ffi::common::guard_ptr(|| {
        // The const parameter differs per size (presumably the limb count of the
        // modulus), so each arm has to name its own concrete key type.
        let encoded = match bits {
            2048 => RsaPrivateKey::<32>::generate(Uint::from_u64(E), &mut OsRng, ROUNDS)
                .to_pkcs1_der(),
            3072 => RsaPrivateKey::<48>::generate(Uint::from_u64(E), &mut OsRng, ROUNDS)
                .to_pkcs1_der(),
            4096 => RsaPrivateKey::<64>::generate(Uint::from_u64(E), &mut OsRng, ROUNDS)
                .to_pkcs1_der(),
            _ => return core::ptr::null_mut(),
        };
        let Some(handle) = PcRsaKey::from_pkcs1_der(encoded) else {
            return core::ptr::null_mut();
        };
        Box::into_raw(Box::new(handle))
    })
}
/// Parses a PEM-encoded PKCS#1 RSA private key into a new heap-allocated handle.
///
/// The input must be UTF-8 text that `pem_decode` accepts under `PEM_LABEL`.
/// Returns a null pointer when `slice` yields `None`, the bytes are not valid UTF-8,
/// PEM decoding fails, or `PcRsaKey::from_pkcs1_der` rejects the DER payload.
/// A non-null result comes from `Box::into_raw` and is reclaimed by `pc_rsa_free`.
///
/// # Safety
///
/// `pem` and `len` are forwarded to `slice` from `super::common` and must satisfy that
/// helper's contract (presumably: `pem` readable for `len` bytes — confirm there).
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_from_pem(pem: *const u8, len: usize) -> *mut PcRsaKey {
    crate::ffi::common::guard_ptr(|| {
        // SAFETY: the caller upholds the pointer/length contract documented above.
        let raw = match unsafe { slice(pem, len) } {
            Some(bytes) => bytes,
            None => return core::ptr::null_mut(),
        };
        let text = match core::str::from_utf8(raw) {
            Ok(text) => text,
            Err(_) => return core::ptr::null_mut(),
        };
        let der = match pem_decode(text, PEM_LABEL) {
            Ok(der) => der,
            Err(_) => return core::ptr::null_mut(),
        };
        if let Some(handle) = PcRsaKey::from_pkcs1_der(der) {
            Box::into_raw(Box::new(handle))
        } else {
            core::ptr::null_mut()
        }
    })
}
/// Exports the private key held by `key` as PEM text labelled `PEM_LABEL`.
///
/// The PEM is built from the handle's stored DER bytes and delivered through
/// `out_write`, whose status is returned unchanged. The temporary PEM buffer is
/// zeroed with `wipe_vec` once `out_write` has returned.
///
/// Returns `PcStatus::NullPointer` when `key` is null.
///
/// # Safety
///
/// `key` must be null or point to a live `PcRsaKey`. `out` and `out_len` are forwarded
/// to `out_write` from `super::common` and must satisfy that helper's contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_private_to_pem(
    key: *const PcRsaKey,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        if key.is_null() {
            return PcStatus::NullPointer;
        }
        // SAFETY: `key` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*key };
        let mut encoded = pem_encode(PEM_LABEL, &handle.der).into_bytes();
        // SAFETY: `out` / `out_len` validity is the caller's obligation.
        let status = unsafe { out_write(&encoded, out, out_len) };
        // The PEM text contains the private key, so scrub our copy before it is freed.
        wipe_vec(&mut encoded);
        status
    })
}
/// Exports the public half of `key` as SPKI PEM text.
///
/// The public key is wrapped in `AnyPublicKey::Rsa`, encoded with `to_spki_pem`, and
/// delivered through `out_write`, whose status is returned unchanged.
///
/// Returns `PcStatus::NullPointer` when `key` is null.
///
/// # Safety
///
/// `key` must be null or point to a live `PcRsaKey`. `out` and `out_len` are forwarded
/// to `out_write` from `super::common` and must satisfy that helper's contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_public_to_pem(
    key: *const PcRsaKey,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        if key.is_null() {
            return PcStatus::NullPointer;
        }
        // SAFETY: `key` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*key };
        let public = AnyPublicKey::Rsa(handle.key.public_key());
        let pem = public.to_spki_pem();
        // SAFETY: `out` / `out_len` validity is the caller's obligation.
        unsafe { out_write(pem.as_bytes(), out, out_len) }
    })
}
/// Signs `msg` with PKCS#1 v1.5 using the private key in `key`.
///
/// `alg` selects the digest and must be one of `id::SHA224`, `id::SHA256`,
/// `id::SHA384` or `id::SHA512`. The signature is delivered through `out_write`,
/// whose status is returned unchanged.
///
/// Other statuses:
/// - `PcStatus::NullPointer` when `key` is null or `slice` yields `None` for the message.
/// - `PcStatus::Unsupported` for any other `alg` value.
/// - `PcStatus::Internal` when the signing operation returns an error.
///
/// # Safety
///
/// `key` must be null or point to a live `PcRsaKey`. The `msg`/`msg_len` and
/// `out`/`out_len` arguments are forwarded to `slice` and `out_write` from
/// `super::common` and must satisfy those helpers' contracts.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_sign_pkcs1(
    key: *const PcRsaKey,
    alg: i32,
    msg: *const u8,
    msg_len: usize,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        if key.is_null() {
            return PcStatus::NullPointer;
        }
        // SAFETY: the caller upholds the pointer/length contract for `msg`.
        let message = match unsafe { slice(msg, msg_len) } {
            Some(bytes) => bytes,
            None => return PcStatus::NullPointer,
        };
        // SAFETY: `key` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*key };
        let signer = &handle.key;
        let outcome = match alg {
            id::SHA224 => signer.sign_pkcs1v15::<Sha224>(message),
            id::SHA256 => signer.sign_pkcs1v15::<Sha256>(message),
            id::SHA384 => signer.sign_pkcs1v15::<Sha384>(message),
            id::SHA512 => signer.sign_pkcs1v15::<Sha512>(message),
            _ => return PcStatus::Unsupported,
        };
        let Ok(signature) = outcome else {
            return PcStatus::Internal;
        };
        // SAFETY: `out` / `out_len` validity is the caller's obligation.
        unsafe { out_write(&signature, out, out_len) }
    })
}
/// Verifies a PKCS#1 v1.5 signature over `msg` against an SPKI DER public key.
///
/// `alg` selects the digest and must be one of `id::SHA224`, `id::SHA256`,
/// `id::SHA384` or `id::SHA512`.
///
/// Returns:
/// - `PcStatus::Ok` when the signature verifies.
/// - `PcStatus::Verification` when verification returns an error.
/// - `PcStatus::NullPointer` when `slice` yields `None` for any of the three inputs.
/// - `PcStatus::BadEncoding` when the SPKI bytes fail to parse.
/// - `PcStatus::Unsupported` when the SPKI holds a non-RSA key or `alg` is unknown.
///
/// The SPKI is parsed before `alg` is inspected, so a bad key is reported first.
///
/// # Safety
///
/// Each pointer/length pair is forwarded to `slice` from `super::common` and must
/// satisfy that helper's contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_verify_pkcs1(
    spki: *const u8,
    spki_len: usize,
    alg: i32,
    msg: *const u8,
    msg_len: usize,
    sig: *const u8,
    sig_len: usize,
) -> PcStatus {
    guard(|| {
        // SAFETY: the caller upholds the pointer/length contract for all three inputs.
        let inputs = (
            unsafe { slice(spki, spki_len) },
            unsafe { slice(msg, msg_len) },
            unsafe { slice(sig, sig_len) },
        );
        let (spki_der, message, signature) = match inputs {
            (Some(spki_der), Some(message), Some(signature)) => (spki_der, message, signature),
            _ => return PcStatus::NullPointer,
        };
        let public = match AnyPublicKey::from_spki_der(spki_der) {
            Ok(AnyPublicKey::Rsa(rsa)) => rsa,
            Ok(_) => return PcStatus::Unsupported,
            Err(_) => return PcStatus::BadEncoding,
        };
        let verdict = match alg {
            id::SHA224 => public.verify_pkcs1v15::<Sha224>(message, signature),
            id::SHA256 => public.verify_pkcs1v15::<Sha256>(message, signature),
            id::SHA384 => public.verify_pkcs1v15::<Sha384>(message, signature),
            id::SHA512 => public.verify_pkcs1v15::<Sha512>(message, signature),
            _ => return PcStatus::Unsupported,
        };
        match verdict {
            Ok(_) => PcStatus::Ok,
            Err(_) => PcStatus::Verification,
        }
    })
}
/// Destroys a key handle; passing null is a no-op.
///
/// Dropping the handle runs `PcRsaKey`'s `Drop` impl, which zeroes the stored DER bytes.
///
/// # Safety
///
/// `key` must be null or a pointer obtained from `Box::into_raw` on a `PcRsaKey`
/// (as returned by `pc_rsa_generate` / `pc_rsa_from_pem`) that has not been freed yet.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_free(key: *mut PcRsaKey) {
    if key.is_null() {
        return;
    }
    // SAFETY: non-null, and the caller guarantees it came from `Box::into_raw`.
    let owned = unsafe { Box::from_raw(key) };
    drop(owned);
}
/// Signs `msg` with RSA-PSS using the private key in `key` and randomness from `OsRng`.
///
/// `alg` selects the digest and must be one of `id::SHA256`, `id::SHA384` or
/// `id::SHA512`. The signature is delivered through `out_write`, whose status is
/// returned unchanged.
///
/// Other statuses:
/// - `PcStatus::NullPointer` when `key` is null or `slice` yields `None` for the message.
/// - `PcStatus::Unsupported` for any other `alg` value.
/// - `PcStatus::Internal` when the signing operation returns an error.
///
/// # Safety
///
/// `key` must be null or point to a live `PcRsaKey`. The `msg`/`msg_len` and
/// `out`/`out_len` arguments are forwarded to `slice` and `out_write` from
/// `super::common` and must satisfy those helpers' contracts.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_sign_pss(
    key: *const PcRsaKey,
    alg: i32,
    msg: *const u8,
    msg_len: usize,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        if key.is_null() {
            return PcStatus::NullPointer;
        }
        // SAFETY: the caller upholds the pointer/length contract for `msg`.
        let message = match unsafe { slice(msg, msg_len) } {
            Some(bytes) => bytes,
            None => return PcStatus::NullPointer,
        };
        // SAFETY: `key` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*key };
        let signer = &handle.key;
        let mut entropy = crate::rng::OsRng;
        let outcome = match alg {
            id::SHA256 => signer.sign_pss::<Sha256, _>(message, &mut entropy),
            id::SHA384 => signer.sign_pss::<Sha384, _>(message, &mut entropy),
            id::SHA512 => signer.sign_pss::<Sha512, _>(message, &mut entropy),
            _ => return PcStatus::Unsupported,
        };
        let Ok(signature) = outcome else {
            return PcStatus::Internal;
        };
        // SAFETY: `out` / `out_len` validity is the caller's obligation.
        unsafe { out_write(&signature, out, out_len) }
    })
}
/// Verifies an RSA-PSS signature over `msg` against an SPKI DER public key.
///
/// `alg` selects the digest and must be one of `id::SHA256`, `id::SHA384` or
/// `id::SHA512`.
///
/// Returns:
/// - `PcStatus::Ok` when the signature verifies.
/// - `PcStatus::Verification` when verification returns an error.
/// - `PcStatus::NullPointer` when `slice` yields `None` for any of the three inputs.
/// - `PcStatus::BadEncoding` when the SPKI bytes fail to parse.
/// - `PcStatus::Unsupported` when the SPKI holds a non-RSA key or `alg` is unknown.
///
/// The SPKI is parsed before `alg` is inspected, so a bad key is reported first.
///
/// # Safety
///
/// Each pointer/length pair is forwarded to `slice` from `super::common` and must
/// satisfy that helper's contract.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_verify_pss(
    spki: *const u8,
    spki_len: usize,
    alg: i32,
    msg: *const u8,
    msg_len: usize,
    sig: *const u8,
    sig_len: usize,
) -> PcStatus {
    guard(|| {
        // SAFETY: the caller upholds the pointer/length contract for all three inputs.
        let inputs = (
            unsafe { slice(spki, spki_len) },
            unsafe { slice(msg, msg_len) },
            unsafe { slice(sig, sig_len) },
        );
        let (spki_der, message, signature) = match inputs {
            (Some(spki_der), Some(message), Some(signature)) => (spki_der, message, signature),
            _ => return PcStatus::NullPointer,
        };
        let public = match AnyPublicKey::from_spki_der(spki_der) {
            Ok(AnyPublicKey::Rsa(rsa)) => rsa,
            Ok(_) => return PcStatus::Unsupported,
            Err(_) => return PcStatus::BadEncoding,
        };
        let verdict = match alg {
            id::SHA256 => public.verify_pss::<Sha256>(message, signature),
            id::SHA384 => public.verify_pss::<Sha384>(message, signature),
            id::SHA512 => public.verify_pss::<Sha512>(message, signature),
            _ => return PcStatus::Unsupported,
        };
        match verdict {
            Ok(_) => PcStatus::Ok,
            Err(_) => PcStatus::Verification,
        }
    })
}
/// Encrypts `pt` with RSA-OAEP under an SPKI DER public key, using `OsRng` for randomness.
///
/// `hash` selects the digest and must be one of `id::SHA256`, `id::SHA384` or
/// `id::SHA512`; `label` is the OAEP label. The ciphertext is delivered through
/// `out_write`, whose status is returned unchanged.
///
/// Other statuses:
/// - `PcStatus::NullPointer` when `slice` yields `None` for the SPKI, label or plaintext.
/// - `PcStatus::BadEncoding` when the SPKI bytes fail to parse.
/// - `PcStatus::Unsupported` when the SPKI holds a non-RSA key or `hash` is unknown.
/// - `PcStatus::Internal` when the encryption operation returns an error.
///
/// # Safety
///
/// Each input pointer/length pair is forwarded to `slice`, and `out`/`out_len` to
/// `out_write`, both from `super::common`; they must satisfy those helpers' contracts.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_encrypt_oaep(
    spki: *const u8,
    spki_len: usize,
    hash: i32,
    label: *const u8,
    label_len: usize,
    pt: *const u8,
    pt_len: usize,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        // SAFETY: the caller upholds the pointer/length contract for all three inputs.
        let inputs = (
            unsafe { slice(spki, spki_len) },
            unsafe { slice(label, label_len) },
            unsafe { slice(pt, pt_len) },
        );
        let (spki_der, oaep_label, plaintext) = match inputs {
            (Some(spki_der), Some(oaep_label), Some(plaintext)) => {
                (spki_der, oaep_label, plaintext)
            }
            _ => return PcStatus::NullPointer,
        };
        let public = match AnyPublicKey::from_spki_der(spki_der) {
            Ok(AnyPublicKey::Rsa(rsa)) => rsa,
            Ok(_) => return PcStatus::Unsupported,
            Err(_) => return PcStatus::BadEncoding,
        };
        let mut entropy = crate::rng::OsRng;
        let outcome = match hash {
            id::SHA256 => public.encrypt_oaep::<Sha256, _>(plaintext, oaep_label, &mut entropy),
            id::SHA384 => public.encrypt_oaep::<Sha384, _>(plaintext, oaep_label, &mut entropy),
            id::SHA512 => public.encrypt_oaep::<Sha512, _>(plaintext, oaep_label, &mut entropy),
            _ => return PcStatus::Unsupported,
        };
        let Ok(ciphertext) = outcome else {
            return PcStatus::Internal;
        };
        // SAFETY: `out` / `out_len` validity is the caller's obligation.
        unsafe { out_write(&ciphertext, out, out_len) }
    })
}
/// Decrypts an RSA-OAEP ciphertext with the private key in `key`.
///
/// `hash` selects the digest and must be one of `id::SHA256`, `id::SHA384` or
/// `id::SHA512`; `label` is the OAEP label. The plaintext is delivered through
/// `out_write`, whose status is returned unchanged, and the temporary plaintext
/// buffer is zeroed with `wipe_vec` afterwards.
///
/// Other statuses:
/// - `PcStatus::NullPointer` when `key` is null or `slice` yields `None` for the label
///   or ciphertext.
/// - `PcStatus::Unsupported` for any other `hash` value.
/// - `PcStatus::Verification` when the decryption operation returns any error.
///
/// # Safety
///
/// `key` must be null or point to a live `PcRsaKey`. The input pointer/length pairs are
/// forwarded to `slice`, and `out`/`out_len` to `out_write`, both from `super::common`;
/// they must satisfy those helpers' contracts.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn pc_rsa_decrypt_oaep(
    key: *const PcRsaKey,
    hash: i32,
    label: *const u8,
    label_len: usize,
    ct: *const u8,
    ct_len: usize,
    out: *mut u8,
    out_len: *mut usize,
) -> PcStatus {
    guard(|| {
        if key.is_null() {
            return PcStatus::NullPointer;
        }
        // SAFETY: the caller upholds the pointer/length contract for both inputs.
        let inputs = (
            unsafe { slice(label, label_len) },
            unsafe { slice(ct, ct_len) },
        );
        let (oaep_label, ciphertext) = match inputs {
            (Some(oaep_label), Some(ciphertext)) => (oaep_label, ciphertext),
            _ => return PcStatus::NullPointer,
        };
        // SAFETY: `key` is non-null and the caller guarantees it is a live handle.
        let handle = unsafe { &*key };
        let decryptor = &handle.key;
        let outcome = match hash {
            id::SHA256 => decryptor.decrypt_oaep::<Sha256>(ciphertext, oaep_label),
            id::SHA384 => decryptor.decrypt_oaep::<Sha384>(ciphertext, oaep_label),
            id::SHA512 => decryptor.decrypt_oaep::<Sha512>(ciphertext, oaep_label),
            _ => return PcStatus::Unsupported,
        };
        let Ok(mut plaintext) = outcome else {
            return PcStatus::Verification;
        };
        // SAFETY: `out` / `out_len` validity is the caller's obligation.
        let status = unsafe { out_write(&plaintext, out, out_len) };
        // Scrub our copy of the recovered plaintext before it is freed.
        wipe_vec(&mut plaintext);
        status
    })
}
/// Overwrites every initialized byte of `buf` with zero using stores the optimizer
/// is not allowed to elide.
///
/// The length and capacity of `buf` are left unchanged; spare capacity beyond `len()`
/// is not touched.
///
/// Plain `*b = 0` stores followed by `core::hint::black_box` are not sufficient here:
/// `black_box` is documented as providing no guarantees for security purposes, so the
/// compiler may treat writes to a buffer that is about to be freed as dead. Volatile
/// writes must be emitted, and the fence keeps them ordered before whatever follows.
fn wipe_vec(buf: &mut Vec<u8>) {
    for byte in buf.iter_mut() {
        // SAFETY: `byte` is a valid, aligned, exclusive reference into `buf`.
        unsafe { core::ptr::write_volatile(byte, 0) };
    }
    // Prevent the compiler from moving later operations (e.g. the deallocation)
    // ahead of the zeroing stores.
    core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
}