use std::ffi::{c_char, c_uchar};
use std::ptr;
use foreign_types::{ForeignType, ForeignTypeRef};
use openssl::{
error::ErrorStack,
pkey::{PKey, PKeyRef, Public},
pkey_ctx::{PkeyCtx, PkeyCtxRef},
};
use openssl_sys::{EVP_PKEY, EVP_PKEY_CTX, EVP_PKEY_new, OSSL_LIB_CTX, OSSL_PARAM, c_int};
use super::{cvt, cvt_p};
/// Key-encapsulation (KEM) operations on a borrowed OpenSSL key context.
///
/// Following OpenSSL's API, run the matching `*_init` method on the context before
/// performing the corresponding operation.
pub(crate) trait PkeyCtxRefKemExt {
    /// Prepares the context for [`encapsulate_to_vec`](Self::encapsulate_to_vec).
    fn encapsulate_init(&self) -> Result<(), ErrorStack>;

    /// Performs an encapsulation and returns `(encapsulated_key, shared_secret)`.
    fn encapsulate_to_vec(&mut self) -> Result<(Vec<u8>, Vec<u8>), ErrorStack>;

    /// Prepares the context for [`decapsulate_to_vec`](Self::decapsulate_to_vec).
    fn decapsulate_init(&self) -> Result<(), ErrorStack>;

    /// Recovers the shared secret from the encapsulated key `enc`.
    fn decapsulate_to_vec(&self, enc: &[u8]) -> Result<Vec<u8>, ErrorStack>;
}
/// Construction of an owned OpenSSL key context from an algorithm name.
pub(crate) trait PkeyCtxExt
where
    Self: Sized,
{
    /// Creates a context for the algorithm `name`.
    ///
    /// `name` is handed to OpenSSL as a C string, so it must be NUL-terminated.
    fn new_from_name(name: &'static [u8]) -> Result<Self, ErrorStack>;
}
/// Construction of an OpenSSL key from its encoded public-key bytes.
pub(crate) trait PkeyExt
where
    Self: Sized,
{
    /// Builds a key of the algorithm `algorithm_name` whose public part is
    /// `encoded_public_key`.
    ///
    /// `algorithm_name` is handed to OpenSSL as a C string, so it must be NUL-terminated.
    fn from_encoded_public_key(
        encoded_public_key: &[u8],
        algorithm_name: &'static [u8],
    ) -> Result<Self, ErrorStack>;
}
/// Parameter accessors for a borrowed OpenSSL key.
pub(crate) trait PKeyRefExt {
    /// Returns a copy of the octet-string parameter named `key_name`.
    ///
    /// `key_name` is handed to OpenSSL as a C string, so it must be NUL-terminated.
    fn get_octet_string_param(&self, key_name: &[u8]) -> Result<Vec<u8>, ErrorStack>;
}
impl<T> PkeyCtxRefKemExt for PkeyCtxRef<T> {
    /// Initializes the context for encapsulation with no extra parameters.
    fn encapsulate_init(&self) -> Result<(), ErrorStack> {
        // SAFETY: `self.as_ptr()` is the live context borrowed through `self`; a null
        // `OSSL_PARAM` array means "no parameters".
        unsafe {
            cvt(EVP_PKEY_encapsulate_init(self.as_ptr(), ptr::null()))?;
        }
        Ok(())
    }

    /// Performs the encapsulation and returns `(encapsulated_key, shared_secret)`.
    ///
    /// OpenSSL is called twice. The first call, with null output buffers, only reports the
    /// maximum buffer sizes. The second call fills the buffers and writes back the lengths
    /// actually produced, and the returned vectors are truncated to those lengths.
    fn encapsulate_to_vec(&mut self) -> Result<(Vec<u8>, Vec<u8>), ErrorStack> {
        let mut out_len = 0;
        let mut secret_len = 0;
        // SAFETY: null output pointers ask OpenSSL only for the required sizes, which are
        // written to the two valid `usize` locations passed in.
        unsafe {
            cvt(EVP_PKEY_encapsulate(
                self.as_ptr(),
                ptr::null_mut(),
                &mut out_len,
                ptr::null_mut(),
                &mut secret_len,
            ))?;
        }
        let mut out = vec![0u8; out_len];
        let mut secret = vec![0u8; secret_len];
        // SAFETY: `out` and `secret` are writable buffers of exactly `out_len` and
        // `secret_len` bytes, which are the capacities passed alongside them.
        unsafe {
            cvt(EVP_PKEY_encapsulate(
                self.as_ptr(),
                out.as_mut_ptr().cast(),
                &mut out_len,
                secret.as_mut_ptr().cast(),
                &mut secret_len,
            ))?;
        }
        // The size query is only an upper bound; keep just the bytes actually written.
        out.truncate(out_len);
        secret.truncate(secret_len);
        Ok((out, secret))
    }

    /// Initializes the context for decapsulation with no extra parameters.
    fn decapsulate_init(&self) -> Result<(), ErrorStack> {
        // SAFETY: `self.as_ptr()` is the live context borrowed through `self`; a null
        // `OSSL_PARAM` array means "no parameters".
        unsafe {
            cvt(EVP_PKEY_decapsulate_init(self.as_ptr(), ptr::null()))?;
        }
        Ok(())
    }

    /// Decapsulates `enc` and returns the recovered shared secret.
    ///
    /// OpenSSL is called twice. The first call sizes the output and the second fills it.
    /// The result is truncated to the length OpenSSL reports having written.
    fn decapsulate_to_vec(&self, enc: &[u8]) -> Result<Vec<u8>, ErrorStack> {
        let mut unwrapped_len = 0;
        // SAFETY: a null output pointer asks only for the required size; `enc` is a valid
        // readable slice of `enc.len()` bytes.
        unsafe {
            cvt(EVP_PKEY_decapsulate(
                self.as_ptr(),
                ptr::null_mut(),
                &mut unwrapped_len,
                enc.as_ptr().cast(),
                enc.len(),
            ))?;
        }
        let mut unwrapped = vec![0u8; unwrapped_len];
        // SAFETY: `unwrapped` is a writable buffer of exactly `unwrapped_len` bytes, and
        // `enc` is still a valid readable slice.
        unsafe {
            cvt(EVP_PKEY_decapsulate(
                self.as_ptr(),
                unwrapped.as_mut_ptr().cast(),
                &mut unwrapped_len,
                enc.as_ptr().cast(),
                enc.len(),
            ))?;
        }
        // The size query is only an upper bound; keep just the bytes actually written.
        unwrapped.truncate(unwrapped_len);
        Ok(unwrapped)
    }
}
impl<T> PkeyCtxExt for PkeyCtx<T> {
    /// Creates a key context for the algorithm `name`, using the default library context
    /// and no property query.
    ///
    /// # Errors
    /// Returns the OpenSSL error stack if OpenSSL cannot create the context.
    ///
    /// # Panics
    /// Panics if `name` contains no NUL byte. The name is handed to OpenSSL as a C string,
    /// so an unterminated slice would be read out of bounds.
    fn new_from_name(name: &'static [u8]) -> Result<Self, ErrorStack> {
        assert!(
            name.contains(&0),
            "algorithm name passed to EVP_PKEY_CTX_new_from_name must be NUL-terminated"
        );
        openssl_sys::init();
        // SAFETY: `name` contains a NUL (checked above) and lives for 'static, so OpenSSL
        // reads a valid C string. The null library context and property query select the
        // defaults. `cvt_p` is relied on to reject a null result, so `from_ptr` receives a
        // freshly created context that we now own.
        unsafe {
            let raw = cvt_p(EVP_PKEY_CTX_new_from_name(
                ptr::null_mut(),
                name.as_ptr().cast(),
                ptr::null(),
            ))?;
            Ok(PkeyCtx::from_ptr(raw))
        }
    }
}
impl PkeyExt for PKey<Public> {
    /// Builds a public key of `algorithm_name` from its encoded public-key bytes.
    ///
    /// The steps are:
    /// 1. Create a context for the algorithm.
    /// 2. Run parameter generation into a fresh `EVP_PKEY`.
    /// 3. Install `encoded_public_key` as that key's public part.
    ///
    /// # Errors
    /// Returns the OpenSSL error stack if any step fails. The partially built key is
    /// freed in that case.
    fn from_encoded_public_key(
        encoded_public_key: &[u8],
        algorithm_name: &'static [u8],
    ) -> Result<Self, ErrorStack> {
        let ctx = PkeyCtx::<()>::new_from_name(algorithm_name)?;
        unsafe {
            // Wrap the new key straight away so it is freed if any later call fails and
            // returns early through `?`.
            let pkey = PKey::<Public>::from_ptr(cvt_p(EVP_PKEY_new())?);
            cvt(EVP_PKEY_paramgen_init(ctx.as_ptr()))?;
            // A non-null `*ppkey` makes OpenSSL populate this key in place rather than
            // allocate (or, on failure, free) a different one.
            let mut raw = pkey.as_ptr();
            cvt(EVP_PKEY_paramgen(ctx.as_ptr(), &mut raw))?;
            debug_assert_eq!(raw, pkey.as_ptr());
            cvt(EVP_PKEY_set1_encoded_public_key(
                pkey.as_ptr(),
                encoded_public_key.as_ptr(),
                encoded_public_key.len(),
            ))?;
            Ok(pkey)
        }
    }
}
impl<T> PKeyRefExt for PKeyRef<T> {
    /// Returns a copy of the octet-string parameter `key_name` of this key.
    ///
    /// `key_name` must be a NUL-terminated parameter name (e.g. a `b"...\0"` literal).
    ///
    /// # Errors
    /// Returns the OpenSSL error stack if either the size query or the read fails.
    ///
    /// # Panics
    /// Panics if `key_name` contains no NUL byte. The name is handed to OpenSSL as a C
    /// string, so an unterminated slice would be read out of bounds.
    fn get_octet_string_param(&self, key_name: &[u8]) -> Result<Vec<u8>, ErrorStack> {
        assert!(
            key_name.contains(&0),
            "key_name passed to EVP_PKEY_get_octet_string_param must be NUL-terminated"
        );
        let mut out_len = 0;
        // SAFETY: `key_name` is NUL-terminated (checked above). A null buffer with maximum
        // size 0 asks OpenSSL only for the parameter's length, written to `out_len`.
        unsafe {
            // Propagate failures instead of panicking: this function already returns
            // `Result`.
            cvt(EVP_PKEY_get_octet_string_param(
                self.as_ptr(),
                key_name.as_ptr().cast(),
                ptr::null_mut(),
                0,
                &mut out_len,
            ))?;
        }
        let mut out = vec![0u8; out_len];
        // SAFETY: `out` is a writable buffer of exactly `out_len` bytes, matching the
        // maximum size passed alongside it.
        unsafe {
            cvt(EVP_PKEY_get_octet_string_param(
                self.as_ptr(),
                key_name.as_ptr().cast(),
                out.as_mut_ptr(),
                out_len,
                &mut out_len,
            ))?;
        }
        // Keep only the bytes OpenSSL reports having written.
        out.truncate(out_len);
        Ok(out)
    }
}
// Raw FFI declarations for the OpenSSL EVP functions used by the wrappers above.
// Every item declared in an `unsafe extern` block is unsafe to call. Callers must supply
// valid pointers and buffer lengths as described in the OpenSSL manual pages.
unsafe extern "C" {
    /// Binds OpenSSL's `EVP_PKEY_encapsulate_init`.
    pub unsafe fn EVP_PKEY_encapsulate_init(
        ctx: *mut EVP_PKEY_CTX,
        params: *const OSSL_PARAM,
    ) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_encapsulate`.
    pub unsafe fn EVP_PKEY_encapsulate(
        ctx: *mut EVP_PKEY_CTX,
        wrappedkey: *mut c_uchar,
        wrappedkeylen: *mut usize,
        genkey: *mut c_uchar,
        genkeylen: *mut usize,
    ) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_decapsulate_init`.
    pub unsafe fn EVP_PKEY_decapsulate_init(
        ctx: *mut EVP_PKEY_CTX,
        params: *const OSSL_PARAM,
    ) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_decapsulate`.
    pub unsafe fn EVP_PKEY_decapsulate(
        ctx: *mut EVP_PKEY_CTX,
        unwrapped: *mut c_uchar,
        unwrappedlen: *mut usize,
        wrapped: *const c_uchar,
        wrappedlen: usize,
    ) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_CTX_new_from_name`.
    pub unsafe fn EVP_PKEY_CTX_new_from_name(
        libctx: *mut OSSL_LIB_CTX,
        name: *const c_char,
        propquery: *const c_char,
    ) -> *mut EVP_PKEY_CTX;

    /// Binds OpenSSL's `EVP_PKEY_get_octet_string_param`.
    pub unsafe fn EVP_PKEY_get_octet_string_param(
        pkey: *const EVP_PKEY,
        key_name: *const c_char,
        buf: *mut c_uchar,
        max_buf_sz: usize,
        out_sz: *mut usize,
    ) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_set1_encoded_public_key`.
    pub unsafe fn EVP_PKEY_set1_encoded_public_key(
        pkey: *mut EVP_PKEY,
        pub_: *const c_uchar,
        publen: usize,
    ) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_paramgen_init`.
    pub unsafe fn EVP_PKEY_paramgen_init(ctx: *mut EVP_PKEY_CTX) -> c_int;

    /// Binds OpenSSL's `EVP_PKEY_paramgen`.
    pub unsafe fn EVP_PKEY_paramgen(ctx: *mut EVP_PKEY_CTX, ppkey: *mut *mut EVP_PKEY) -> c_int;
}