use crate::error::ErrorStack;
use native_ossl_sys as sys;
use std::ffi::CStr;
use std::marker::PhantomData;
use std::ptr;
/// Owned, OpenSSL-allocated `OSSL_PARAM` array, released with
/// `OSSL_PARAM_free` when dropped.
///
/// The lifetime `'a` bounds caller buffers the array may still point into
/// (entries added with [`ParamBuilder::push_octet_ptr`]). Arrays that own all
/// of their data can be `Params<'static>`.
pub struct Params<'a> {
    /// Head of the parameter array; owned exclusively by this value.
    ptr: *mut sys::OSSL_PARAM,
    /// Records the borrow of any externally owned data the array references.
    _lifetime: PhantomData<&'a [u8]>,
}

impl Params<'_> {
    /// Returns the array as a read-only pointer for OpenSSL calls.
    ///
    /// Ownership stays with `self`; the pointer must not be used after `self`
    /// has been dropped.
    #[must_use]
    pub fn as_ptr(&self) -> *const sys::OSSL_PARAM {
        self.ptr.cast_const()
    }
}

impl Drop for Params<'_> {
    fn drop(&mut self) {
        let array = self.ptr;
        // SAFETY: `array` came from `OSSL_PARAM_BLD_to_param` or was handed
        // over through `from_owned_ptr`, and no other owner frees it.
        unsafe { sys::OSSL_PARAM_free(array) }
    }
}

// SAFETY: `Params` exclusively owns its array, so moving it to another thread
// moves sole ownership. Mutable access needs `&mut self` (`as_mut_ptr`).
// NOTE(review): `Sync` assumes the OpenSSL lookup/getter calls made through
// `&Params` never write to the array — confirm when adding new accessors.
unsafe impl Send for Params<'_> {}
unsafe impl Sync for Params<'_> {}
/// Incremental builder for [`Params`], wrapping OpenSSL's `OSSL_PARAM_BLD`.
///
/// Entries are added with the consuming `push_*` methods. [`ParamBuilder::build`]
/// then converts them into an owned [`Params`] array.
///
/// `OSSL_PARAM_BLD` records value buffers and BIGNUMs *by reference* and only
/// reads them inside `OSSL_PARAM_BLD_to_param`. The builder therefore does two
/// things until `build` (or drop):
///
/// - it keeps its own copies of by-value octet/UTF-8 values;
/// - it owns every BIGNUM.
///
/// As a result, the caller's borrows may end right after each `push_*` call.
///
/// The lifetime `'a` covers caller buffers that the *built* array keeps
/// pointing to (see [`ParamBuilder::push_octet_ptr`]).
///
/// NOTE(review): OpenSSL also stores each `key` pointer by reference and copies
/// it into the built array. Keys are therefore expected to outlive the
/// resulting [`Params`]; in practice they are `'static` literals such as
/// `c"..."`.
pub struct ParamBuilder<'a> {
    /// Builder handle from `OSSL_PARAM_BLD_new`; freed on drop.
    ptr: *mut sys::OSSL_PARAM_BLD,
    /// BIGNUMs pushed via `push_bn`; the builder reads them during `build`.
    bns: Vec<*mut sys::BIGNUM>,
    /// Heap copies of by-value octet/UTF-8 values whose addresses were handed
    /// to OpenSSL. Moving a `Vec` never moves its heap buffer, so those
    /// addresses stay valid until the copies are dropped.
    bufs: Vec<Vec<u8>>,
    /// Borrow of caller buffers referenced by `push_octet_ptr` entries.
    _lifetime: PhantomData<&'a [u8]>,
}

// SAFETY: the builder exclusively owns its handle, BIGNUMs and value copies,
// so sending it to another thread transfers sole ownership. It is not `Sync`.
// NOTE(review): assumes `OSSL_PARAM_BLD` has no thread affinity.
unsafe impl Send for ParamBuilder<'_> {}

impl<'a> ParamBuilder<'a> {
    /// Creates an empty builder.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if `OSSL_PARAM_BLD_new` fails.
    pub fn new() -> Result<Self, ErrorStack> {
        // SAFETY: `OSSL_PARAM_BLD_new` has no preconditions; null means failure.
        let ptr = unsafe { sys::OSSL_PARAM_BLD_new() };
        if ptr.is_null() {
            return Err(ErrorStack::drain());
        }
        Ok(ParamBuilder {
            ptr,
            bns: Vec::new(),
            bufs: Vec::new(),
            _lifetime: PhantomData,
        })
    }

    /// Maps an OpenSSL "1 on success" return code to a `Result`, draining the
    /// error queue on failure.
    fn check(rc: std::os::raw::c_int) -> Result<(), ErrorStack> {
        if rc == 1 {
            Ok(())
        } else {
            Err(ErrorStack::drain())
        }
    }

    /// Stores `bytes` in the builder and returns a pointer to them. The pointer
    /// stays valid until the builder is dropped, including the drop inside
    /// `build`.
    fn keep(&mut self, bytes: Vec<u8>) -> *const u8 {
        let data = bytes.as_ptr();
        // Moving the Vec into `bufs` does not move its heap buffer.
        self.bufs.push(bytes);
        data
    }

    /// Adds a C `int` parameter.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_int(self, key: &CStr, val: i32) -> Result<Self, ErrorStack> {
        // SAFETY: `self.ptr` is a live builder and `key` is NUL-terminated.
        let rc = unsafe { sys::OSSL_PARAM_BLD_push_int(self.ptr, key.as_ptr(), val) };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds a C `unsigned int` parameter.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_uint(self, key: &CStr, val: u32) -> Result<Self, ErrorStack> {
        // SAFETY: `self.ptr` is a live builder and `key` is NUL-terminated.
        let rc = unsafe { sys::OSSL_PARAM_BLD_push_uint(self.ptr, key.as_ptr(), val) };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds a `uint64_t` parameter.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_uint64(self, key: &CStr, val: u64) -> Result<Self, ErrorStack> {
        // SAFETY: `self.ptr` is a live builder and `key` is NUL-terminated.
        let rc = unsafe { sys::OSSL_PARAM_BLD_push_uint64(self.ptr, key.as_ptr(), val) };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds a `size_t` parameter.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_size(self, key: &CStr, val: usize) -> Result<Self, ErrorStack> {
        // SAFETY: `self.ptr` is a live builder and `key` is NUL-terminated.
        let rc = unsafe { sys::OSSL_PARAM_BLD_push_size_t(self.ptr, key.as_ptr(), val) };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds an octet-string parameter holding a copy of `val`.
    ///
    /// `val` is copied into the builder first because OpenSSL only reads the
    /// buffer later, inside `build`. The caller's borrow may therefore end as
    /// soon as this call returns.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_octet_slice(mut self, key: &CStr, val: &[u8]) -> Result<Self, ErrorStack> {
        let data = self.keep(val.to_vec());
        // SAFETY: `self.ptr` is live, `key` is NUL-terminated, and `data` points
        // to `val.len()` bytes owned by `self.bufs` until the builder is dropped.
        let rc = unsafe {
            sys::OSSL_PARAM_BLD_push_octet_string(self.ptr, key.as_ptr(), data.cast(), val.len())
        };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds an octet-pointer parameter that references `val` without copying.
    ///
    /// The returned builder, and the [`Params`] it builds, borrow `val` for
    /// `'b`. The buffer therefore outlives the built array.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_octet_ptr<'b>(
        self,
        key: &CStr,
        val: &'b [u8],
    ) -> Result<ParamBuilder<'b>, ErrorStack>
    where
        'a: 'b,
    {
        // SAFETY: `self.ptr` is live, `key` is NUL-terminated, and `val` stays
        // borrowed for `'b` through the returned builder's lifetime. The cast
        // to `*mut` only satisfies the C prototype.
        let rc = unsafe {
            sys::OSSL_PARAM_BLD_push_octet_ptr(
                self.ptr,
                key.as_ptr(),
                val.as_ptr() as *mut std::os::raw::c_void,
                val.len(),
            )
        };
        Self::check(rc)?;
        // `ParamBuilder` is covariant in its lifetime, so `self` narrows from
        // `'a` to `'b` without being taken apart and rebuilt.
        Ok(self)
    }

    /// Adds a UTF-8 string parameter holding a copy of `val`.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_utf8_string(mut self, key: &CStr, val: &CStr) -> Result<Self, ErrorStack> {
        // Keep the terminator: with `bsize == 0` OpenSSL measures the string
        // with `strlen`, and it reads the bytes only later, inside `build`.
        let data = self.keep(val.to_bytes_with_nul().to_vec());
        // SAFETY: `self.ptr` is live, `key` is NUL-terminated, and `data` is a
        // NUL-terminated copy owned by `self.bufs` until the builder is dropped.
        let rc = unsafe {
            sys::OSSL_PARAM_BLD_push_utf8_string(self.ptr, key.as_ptr(), data.cast(), 0)
        };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds a UTF-8 string parameter from a `'static` string.
    ///
    /// Despite the name, this calls `OSSL_PARAM_BLD_push_utf8_string`, so the
    /// built array holds a copy of the string. No builder-side copy is needed
    /// because `val` lives for the whole program.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if the push fails.
    pub fn push_utf8_ptr(self, key: &CStr, val: &'static CStr) -> Result<Self, ErrorStack> {
        // SAFETY: `self.ptr` is live and both strings are NUL-terminated; `val`
        // is `'static`, so it is still valid when `build` reads it.
        let rc = unsafe {
            sys::OSSL_PARAM_BLD_push_utf8_string(self.ptr, key.as_ptr(), val.as_ptr(), 0)
        };
        Self::check(rc)?;
        Ok(self)
    }

    /// Adds a BIGNUM parameter decoded from big-endian bytes.
    ///
    /// The BIGNUM stays owned by the builder until it is dropped, because
    /// OpenSSL reads it during `build`.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if decoding or the push fails.
    ///
    /// # Panics
    ///
    /// Panics if `bigendian_bytes` is longer than `i32::MAX` bytes.
    pub fn push_bn(mut self, key: &CStr, bigendian_bytes: &[u8]) -> Result<Self, ErrorStack> {
        let len = i32::try_from(bigendian_bytes.len()).expect("BN too large");
        // SAFETY: reads `len` bytes from the slice; a null `ret` asks OpenSSL to
        // allocate a new BIGNUM.
        let bn = unsafe { sys::BN_bin2bn(bigendian_bytes.as_ptr(), len, ptr::null_mut()) };
        if bn.is_null() {
            return Err(ErrorStack::drain());
        }
        // SAFETY: `self.ptr` is live, `key` is NUL-terminated, `bn` is valid.
        let rc = unsafe { sys::OSSL_PARAM_BLD_push_BN(self.ptr, key.as_ptr(), bn) };
        if rc != 1 {
            // Not yet tracked in `bns`, so release it here before reporting.
            unsafe { sys::BN_free(bn) };
            return Err(ErrorStack::drain());
        }
        self.bns.push(bn);
        Ok(self)
    }

    /// Converts the accumulated entries into an owned [`Params`] array.
    ///
    /// Afterwards the builder handle, its BIGNUMs and its value copies are
    /// released, whether or not the conversion succeeded.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if `OSSL_PARAM_BLD_to_param` fails.
    pub fn build(self) -> Result<Params<'a>, ErrorStack> {
        // SAFETY: `self.ptr` is live and every BIGNUM and value copy it refers
        // to is still owned by `self`.
        let param_ptr = unsafe { sys::OSSL_PARAM_BLD_to_param(self.ptr) };
        // Capture the failure reason before tearing anything down.
        let failure = param_ptr.is_null().then(ErrorStack::drain);
        // `Drop` frees the builder handle and BIGNUMs and wipes the value copies.
        drop(self);
        match failure {
            Some(err) => Err(err),
            None => Ok(Params {
                ptr: param_ptr,
                _lifetime: PhantomData,
            }),
        }
    }
}

impl Drop for ParamBuilder<'_> {
    /// Frees the builder handle and BIGNUMs. It also zeroes the value copies,
    /// which may hold sensitive material, before they are deallocated.
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            // SAFETY: a non-null handle is owned solely by this builder.
            unsafe { sys::OSSL_PARAM_BLD_free(self.ptr) };
        }
        for bn in self.bns.drain(..) {
            // SAFETY: each BIGNUM came from `BN_bin2bn` in `push_bn` and is
            // owned solely by this builder.
            unsafe { sys::BN_free(bn) };
        }
        for buf in &mut self.bufs {
            for byte in buf.iter_mut() {
                // SAFETY: `byte` is an exclusive reference into `buf`. The
                // volatile write keeps the wipe from being optimised away.
                unsafe { ptr::write_volatile(byte, 0) };
            }
        }
    }
}
/// Owning handle for a `BIGNUM` allocated by OpenSSL; released with `BN_free`
/// when dropped.
struct Bn(*mut sys::BIGNUM);

impl Bn {
    /// Serialises the value with `BN_bn2bin` into `BN_num_bits` rounded up to
    /// whole bytes, big-endian.
    ///
    /// If `BN_num_bits` reports zero (or a negative count), the result is an
    /// empty vector.
    fn to_bigendian_vec(&self) -> Vec<u8> {
        // SAFETY: `self.0` is a valid BIGNUM for as long as `self` is alive.
        let bits = unsafe { sys::BN_num_bits(self.0) };
        let byte_len = match usize::try_from(bits) {
            Ok(b) => b.div_ceil(8),
            Err(_) => 0,
        };
        if byte_len == 0 {
            return Vec::new();
        }
        let mut bytes = vec![0u8; byte_len];
        // SAFETY: `bytes` holds exactly `byte_len` bytes, the size derived from
        // `BN_num_bits` for this same BIGNUM.
        unsafe { sys::BN_bn2bin(self.0, bytes.as_mut_ptr()) };
        bytes
    }
}

impl Drop for Bn {
    fn drop(&mut self) {
        let bn = self.0;
        // SAFETY: the wrapper is the sole owner of `bn`.
        unsafe { sys::BN_free(bn) }
    }
}
impl Params<'_> {
#[must_use]
pub unsafe fn from_owned_ptr(ptr: *mut sys::OSSL_PARAM) -> Params<'static> {
Params {
ptr,
_lifetime: PhantomData,
}
}
#[must_use]
pub fn as_mut_ptr(&mut self) -> *mut sys::OSSL_PARAM {
self.ptr
}
#[must_use]
pub fn has_param(&self, key: &CStr) -> bool {
!unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) }.is_null()
}
pub fn get_int(&self, key: &CStr) -> Result<i32, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut val: std::os::raw::c_int = 0;
crate::ossl_call!(sys::OSSL_PARAM_get_int(elem, std::ptr::addr_of_mut!(val)))?;
Ok(val)
}
pub fn get_uint(&self, key: &CStr) -> Result<u32, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut val: std::os::raw::c_uint = 0;
crate::ossl_call!(sys::OSSL_PARAM_get_uint(elem, std::ptr::addr_of_mut!(val)))?;
Ok(val)
}
pub fn get_size_t(&self, key: &CStr) -> Result<usize, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut val: usize = 0;
crate::ossl_call!(sys::OSSL_PARAM_get_size_t(
elem,
std::ptr::addr_of_mut!(val)
))?;
Ok(val)
}
pub fn get_i64(&self, key: &CStr) -> Result<i64, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut val: i64 = 0;
crate::ossl_call!(sys::OSSL_PARAM_get_int64(elem, std::ptr::addr_of_mut!(val)))?;
Ok(val)
}
pub fn get_u64(&self, key: &CStr) -> Result<u64, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut val: u64 = 0;
crate::ossl_call!(sys::OSSL_PARAM_get_uint64(
elem,
std::ptr::addr_of_mut!(val)
))?;
Ok(val)
}
pub fn get_bn(&self, key: &CStr) -> Result<Vec<u8>, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut bn_ptr: *mut sys::BIGNUM = ptr::null_mut();
crate::ossl_call!(sys::OSSL_PARAM_get_BN(elem, std::ptr::addr_of_mut!(bn_ptr)))?;
let bn = Bn(bn_ptr);
Ok(bn.to_bigendian_vec())
}
pub fn get_octet_string(&self, key: &CStr) -> Result<&[u8], ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut p: *const std::os::raw::c_void = ptr::null();
let mut len: usize = 0;
crate::ossl_call!(sys::OSSL_PARAM_get_octet_string_ptr(
elem,
std::ptr::addr_of_mut!(p),
std::ptr::addr_of_mut!(len),
))?;
Ok(unsafe { std::slice::from_raw_parts(p.cast::<u8>(), len) })
}
pub fn get_utf8_string(&self, key: &CStr) -> Result<&CStr, ErrorStack> {
let elem = unsafe { sys::OSSL_PARAM_locate(self.ptr, key.as_ptr()) };
if elem.is_null() {
return Err(ErrorStack::drain());
}
let mut p: *const std::os::raw::c_char = ptr::null();
crate::ossl_call!(sys::OSSL_PARAM_get_utf8_string_ptr(
elem,
std::ptr::addr_of_mut!(p),
))?;
Ok(unsafe { CStr::from_ptr(p) })
}
}
#[must_use]
pub(crate) fn null_params() -> *const sys::OSSL_PARAM {
ptr::null()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up `key` with the raw `OSSL_PARAM_locate`, bypassing the `Params`
    /// accessors so the builder output is checked independently of them.
    fn locate(params: &Params<'_>, key: &CStr) -> *mut sys::OSSL_PARAM {
        unsafe { sys::OSSL_PARAM_locate(params.as_ptr().cast_mut(), key.as_ptr()) }
    }

    #[test]
    fn int_round_trip() {
        let builder = ParamBuilder::new().unwrap().push_int(c"mykey", 42).unwrap();
        let params = builder.build().unwrap();

        let elem = locate(&params, c"mykey");
        assert!(!elem.is_null(), "OSSL_PARAM_locate failed");

        let mut value: i32 = 0;
        let rc = unsafe { sys::OSSL_PARAM_get_int(elem, &raw mut value) };
        assert_eq!(rc, 1, "OSSL_PARAM_get_int failed");
        assert_eq!(value, 42);
    }

    #[test]
    fn octet_slice_round_trip() {
        let data = b"hello world";
        let builder = ParamBuilder::new()
            .unwrap()
            .push_octet_slice(c"blob", data)
            .unwrap();
        let params = builder.build().unwrap();

        let elem = locate(&params, c"blob");
        assert!(!elem.is_null());

        let mut data_ptr: *const std::os::raw::c_void = ptr::null();
        let mut data_len: usize = 0;
        let rc = unsafe {
            sys::OSSL_PARAM_get_octet_string_ptr(elem, &raw mut data_ptr, &raw mut data_len)
        };
        assert_eq!(rc, 1, "OSSL_PARAM_get_octet_string_ptr failed");
        assert_eq!(data_len, data.len());

        let got = unsafe { std::slice::from_raw_parts(data_ptr.cast::<u8>(), data_len) };
        assert_eq!(got, data);
    }

    #[test]
    fn utf8_string_round_trip() {
        let builder = ParamBuilder::new()
            .unwrap()
            .push_utf8_string(c"alg", c"SHA-256")
            .unwrap();
        let params = builder.build().unwrap();

        let elem = locate(&params, c"alg");
        assert!(!elem.is_null());

        let mut text: *const std::os::raw::c_char = ptr::null();
        let rc = unsafe { sys::OSSL_PARAM_get_utf8_string_ptr(elem, &raw mut text) };
        assert_eq!(rc, 1, "OSSL_PARAM_get_utf8_string_ptr failed");

        let got = unsafe { CStr::from_ptr(text) };
        assert_eq!(got.to_bytes(), b"SHA-256");
    }
}