use crate::error::ErrorStack;
use native_ossl_sys as sys;
use std::ffi::CStr;
use std::sync::Arc;
/// A fetched OpenSSL message-digest algorithm (`EVP_MD`).
///
/// Each value owns one reference to the underlying `EVP_MD`: `Clone` takes an
/// additional reference with `EVP_MD_up_ref` and `Drop` releases it with
/// `EVP_MD_free`.
#[derive(Debug)]
pub struct DigestAlg {
    // Algorithm handle returned by `EVP_MD_fetch` (non-null for values built by
    // `fetch` / `fetch_in`, which reject a null result).
    ptr: *mut sys::EVP_MD,
    // Library context the algorithm was fetched from, held so it stays alive as
    // long as this handle does; `None` for handles created by `fetch`, which
    // passes a null library context.
    lib_ctx: Option<Arc<crate::lib_ctx::LibCtx>>,
}
impl DigestAlg {
    /// Fetches the digest algorithm `name` without an explicit library context.
    ///
    /// `props` is an optional OpenSSL property query string.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if `EVP_MD_fetch` returns null
    /// (for example, for an unknown algorithm name).
    pub fn fetch(name: &CStr, props: Option<&CStr>) -> Result<Self, ErrorStack> {
        Self::fetch_from(None, name, props)
    }

    /// Fetches the digest algorithm `name` from `ctx`.
    ///
    /// The returned handle keeps its own clone of the `Arc`, so the library
    /// context outlives the algorithm.
    ///
    /// # Errors
    ///
    /// Returns the drained OpenSSL error stack if `EVP_MD_fetch` returns null.
    pub fn fetch_in(
        ctx: &Arc<crate::lib_ctx::LibCtx>,
        name: &CStr,
        props: Option<&CStr>,
    ) -> Result<Self, ErrorStack> {
        Self::fetch_from(Some(ctx), name, props)
    }

    /// Shared body of `fetch` / `fetch_in`: `None` means a null library context.
    fn fetch_from(
        ctx: Option<&Arc<crate::lib_ctx::LibCtx>>,
        name: &CStr,
        props: Option<&CStr>,
    ) -> Result<Self, ErrorStack> {
        let query = match props {
            Some(p) => p.as_ptr(),
            None => std::ptr::null(),
        };
        let libctx = match ctx {
            Some(c) => c.as_ptr(),
            None => std::ptr::null_mut(),
        };
        // SAFETY: `name` is NUL-terminated, `query` is NUL-terminated or null,
        // and `libctx` is null or owned by an `Arc` borrowed for this call.
        let md = unsafe { sys::EVP_MD_fetch(libctx, name.as_ptr(), query) };
        if md.is_null() {
            Err(ErrorStack::drain())
        } else {
            Ok(DigestAlg {
                ptr: md,
                lib_ctx: ctx.map(Arc::clone),
            })
        }
    }

    /// Digest length in bytes, or 0 if OpenSSL reports a negative size.
    #[must_use]
    pub fn output_len(&self) -> usize {
        // SAFETY: `self.ptr` is the live `EVP_MD` referenced by `self`.
        let raw = unsafe { sys::EVP_MD_get_size(self.ptr) };
        usize::try_from(raw).unwrap_or_default()
    }

    /// Internal block length in bytes, or 0 if OpenSSL reports a negative size.
    #[must_use]
    pub fn block_size(&self) -> usize {
        // SAFETY: `self.ptr` is the live `EVP_MD` referenced by `self`.
        let raw = unsafe { sys::EVP_MD_get_block_size(self.ptr) };
        usize::try_from(raw).unwrap_or_default()
    }

    /// Numeric identifier returned by `EVP_MD_get_type`.
    #[must_use]
    pub fn nid(&self) -> i32 {
        // SAFETY: `self.ptr` is the live `EVP_MD` referenced by `self`.
        unsafe { sys::EVP_MD_get_type(self.ptr) }
    }

    /// Borrowed raw `EVP_MD` pointer; valid only while `self` is alive.
    #[must_use]
    pub fn as_ptr(&self) -> *const sys::EVP_MD {
        self.ptr.cast_const()
    }
}
impl Clone for DigestAlg {
    /// Returns another handle to the same `EVP_MD`.
    ///
    /// The clone takes an extra OpenSSL reference and clones the `LibCtx`
    /// `Arc` (if any).
    ///
    /// # Panics
    ///
    /// Panics if `EVP_MD_up_ref` reports failure. Carrying on without the
    /// extra reference would let both handles' `Drop` call `EVP_MD_free`,
    /// releasing the algorithm one time too many (a double free).
    fn clone(&self) -> Self {
        // SAFETY: `self.ptr` is the live `EVP_MD` referenced by `self`.
        let rc = unsafe { sys::EVP_MD_up_ref(self.ptr) };
        assert!(rc == 1, "EVP_MD_up_ref failed");
        DigestAlg {
            ptr: self.ptr,
            lib_ctx: self.lib_ctx.clone(),
        }
    }
}
impl Drop for DigestAlg {
    /// Releases this handle's reference to the `EVP_MD`.
    ///
    /// The `lib_ctx` field is dropped only after this body returns, so the
    /// algorithm is always released before its library context.
    fn drop(&mut self) {
        let md = self.ptr;
        // SAFETY: `md` was obtained from `EVP_MD_fetch` or an up-ref'd clone,
        // and this handle's reference is released exactly once, here.
        unsafe {
            sys::EVP_MD_free(md);
        }
    }
}
// SAFETY: the `EVP_MD` is reference-counted through `Clone`/`Drop`, and every
// method on `DigestAlg` takes `&self` and only hands the pointer to OpenSSL
// (`as_ptr` exposes it as `*const`). The library context is held in an `Arc`.
// NOTE(review): soundness relies on OpenSSL treating fetched `EVP_MD` objects
// as immutable and up-ref'ing/freeing them thread-safely. Confirm against the
// OpenSSL threading documentation.
unsafe impl Send for DigestAlg {}
unsafe impl Sync for DigestAlg {}
/// An OpenSSL message-digest context (`EVP_MD_CTX`) carrying running digest state.
///
/// The context is freed with `EVP_MD_CTX_free` on `Drop`. Only the raw context
/// pointer is stored: the `DigestAlg` (and any `LibCtx`) it was initialised
/// from is not kept alive by this type.
/// NOTE(review): confirm that OpenSSL's internal reference on the `EVP_MD` is
/// enough when a context outlives an `Arc<LibCtx>`-backed `DigestAlg`.
#[derive(Debug)]
pub struct DigestCtx {
    // Owned context pointer; non-null for contexts built by `new_empty`,
    // `fork` or `DigestAlg::new_context`, which reject a null allocation.
    ptr: *mut sys::EVP_MD_CTX,
}
impl DigestCtx {
pub fn update(&mut self, data: &[u8]) -> Result<(), ErrorStack> {
crate::ossl_call!(sys::EVP_DigestUpdate(
self.ptr,
data.as_ptr().cast(),
data.len()
))
}
pub fn finish(&mut self, out: &mut [u8]) -> Result<usize, ErrorStack> {
let mut len: u32 = 0;
crate::ossl_call!(sys::EVP_DigestFinal_ex(
self.ptr,
out.as_mut_ptr(),
std::ptr::addr_of_mut!(len)
))?;
Ok(usize::try_from(len).unwrap_or(0))
}
pub fn finish_xof(&mut self, out: &mut [u8]) -> Result<(), ErrorStack> {
crate::ossl_call!(sys::EVP_DigestFinalXOF(
self.ptr,
out.as_mut_ptr(),
out.len()
))
}
pub fn fork(&self) -> Result<DigestCtx, ErrorStack> {
let new_ctx = unsafe { sys::EVP_MD_CTX_new() };
if new_ctx.is_null() {
return Err(ErrorStack::drain());
}
crate::ossl_call!(sys::EVP_MD_CTX_copy_ex(new_ctx, self.ptr))?;
Ok(DigestCtx { ptr: new_ctx })
}
#[cfg(ossl_v400)]
pub fn serialize_size(&self) -> Result<usize, ErrorStack> {
let mut outlen: usize = 0;
crate::ossl_call!(sys::EVP_MD_CTX_serialize(
self.ptr,
std::ptr::null_mut(),
std::ptr::addr_of_mut!(outlen)
))?;
Ok(outlen)
}
#[cfg(ossl_v400)]
pub fn serialize(&self, out: &mut [u8]) -> Result<usize, ErrorStack> {
let mut outlen: usize = out.len();
crate::ossl_call!(sys::EVP_MD_CTX_serialize(
self.ptr,
out.as_mut_ptr(),
std::ptr::addr_of_mut!(outlen)
))?;
Ok(outlen)
}
#[cfg(ossl_v400)]
pub fn deserialize(&mut self, data: &[u8]) -> Result<(), ErrorStack> {
crate::ossl_call!(sys::EVP_MD_CTX_deserialize(
self.ptr,
data.as_ptr(),
data.len()
))
}
pub fn new_empty() -> Result<Self, ErrorStack> {
let ptr = unsafe { sys::EVP_MD_CTX_new() };
if ptr.is_null() {
return Err(ErrorStack::drain());
}
Ok(DigestCtx { ptr })
}
pub fn reinit(
&mut self,
alg: &DigestAlg,
params: Option<&crate::params::Params<'_>>,
) -> Result<(), ErrorStack> {
crate::ossl_call!(sys::EVP_DigestInit_ex2(
self.ptr,
alg.ptr,
params.map_or(std::ptr::null(), super::params::Params::as_ptr),
))
}
pub unsafe fn from_ptr(ptr: *mut sys::EVP_MD_CTX) -> Self {
DigestCtx { ptr }
}
#[must_use]
pub fn as_ptr(&self) -> *mut sys::EVP_MD_CTX {
self.ptr
}
}
impl Drop for DigestCtx {
    /// Frees the owned `EVP_MD_CTX`.
    fn drop(&mut self) {
        let ctx = self.ptr;
        // SAFETY: `ctx` is owned by this value and is freed only here.
        unsafe {
            sys::EVP_MD_CTX_free(ctx);
        }
    }
}
// SAFETY: a `DigestCtx` owns its `EVP_MD_CTX` and frees it in `Drop`, so moving
// it to another thread moves that ownership. Methods that advance the digest
// (`update`, `finish`, `finish_xof`, `reinit`, `deserialize`) take `&mut self`.
// The `&self` methods (`fork`, `serialize_size`, `serialize`) only pass the
// context to OpenSSL, presumably as a read-only source.
// NOTE(review): `Sync` depends on those OpenSSL calls not mutating the context.
// Confirm this, and note that `as_ptr` hands out a `*mut` from `&self`.
unsafe impl Send for DigestCtx {}
unsafe impl Sync for DigestCtx {}
impl DigestAlg {
pub fn new_context(&self) -> Result<DigestCtx, ErrorStack> {
let ctx_ptr = unsafe { sys::EVP_MD_CTX_new() };
if ctx_ptr.is_null() {
return Err(ErrorStack::drain());
}
crate::ossl_call!(sys::EVP_DigestInit_ex2(ctx_ptr, self.ptr, std::ptr::null())).map_err(
|e| {
unsafe { sys::EVP_MD_CTX_free(ctx_ptr) };
e
},
)?;
Ok(DigestCtx { ptr: ctx_ptr })
}
pub fn digest(&self, data: &[u8], out: &mut [u8]) -> Result<usize, ErrorStack> {
let mut len: u32 = 0;
crate::ossl_call!(sys::EVP_Digest(
data.as_ptr().cast(),
data.len(),
out.as_mut_ptr(),
std::ptr::addr_of_mut!(len),
self.ptr,
std::ptr::null_mut()
))?;
Ok(usize::try_from(len).unwrap_or(0))
}
pub fn digest_to_vec(&self, data: &[u8]) -> Result<Vec<u8>, ErrorStack> {
let mut out = vec![0u8; self.output_len()];
let len = self.digest(data, &mut out)?;
out.truncate(len);
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected SHA-256 digest of `b"abc"`, hex-encoded.
    const SHA256_ABC: &str = "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";

    /// Fetches SHA2-256 from the default library context, panicking on failure.
    fn sha256() -> DigestAlg {
        DigestAlg::fetch(c"SHA2-256", None).unwrap()
    }

    /// Finalises `ctx` into a new 32-byte buffer, panicking on failure.
    fn finish_32(ctx: &mut DigestCtx) -> [u8; 32] {
        let mut digest = [0u8; 32];
        ctx.finish(&mut digest).unwrap();
        digest
    }

    #[test]
    fn fetch_sha256_properties() {
        let sha = sha256();
        assert_eq!(sha.output_len(), 32);
        assert_eq!(sha.block_size(), 64);
    }

    #[test]
    fn fetch_nonexistent_fails() {
        let outcome = DigestAlg::fetch(c"NONEXISTENT_DIGEST_XYZ", None);
        assert!(outcome.is_err());
    }

    #[test]
    fn clone_then_drop_both() {
        let original = sha256();
        let copy = original.clone();
        drop(original);
        drop(copy);
    }

    #[test]
    fn sha256_known_answer() {
        let alg = sha256();
        let mut ctx = alg.new_context().unwrap();
        ctx.update(b"abc").unwrap();
        let mut digest = [0u8; 32];
        let written = ctx.finish(&mut digest).unwrap();
        assert_eq!(written, 32);
        assert_eq!(hex::encode(digest), SHA256_ABC);
    }

    #[test]
    fn sha256_oneshot() {
        let got = sha256().digest_to_vec(b"abc").unwrap();
        assert_eq!(got, hex::decode(SHA256_ABC).unwrap());
    }

    #[test]
    fn fork_mid_stream() {
        let alg = sha256();
        let mut trunk = alg.new_context().unwrap();
        trunk.update(b"common prefix").unwrap();
        let mut branch = trunk.fork().unwrap();
        trunk.update(b" A").unwrap();
        branch.update(b" B").unwrap();
        let out_a = finish_32(&mut trunk);
        let out_b = finish_32(&mut branch);
        assert_ne!(out_a, out_b);
    }

    #[cfg(ossl_v400)]
    #[test]
    fn serialize_deserialize_roundtrip() {
        let alg = sha256();
        let mut original = alg.new_context().unwrap();
        original.update(b"hello").unwrap();
        let size = original.serialize_size().unwrap();
        assert!(size > 0, "serialized state must be non-empty");
        let mut state = vec![0u8; size];
        let written = original.serialize(&mut state).unwrap();
        assert_eq!(written, size, "serialize wrote unexpected byte count");
        original.update(b" world").unwrap();
        let expected = finish_32(&mut original);

        let mut restored = alg.new_context().unwrap();
        restored.deserialize(&state).unwrap();
        restored.update(b" world").unwrap();
        let actual = finish_32(&mut restored);
        assert_eq!(
            expected, actual,
            "restored context produced different digest"
        );
    }

    #[cfg(ossl_v400)]
    #[test]
    fn serialize_different_suffix_differs() {
        let alg = sha256();
        let mut base = alg.new_context().unwrap();
        base.update(b"hello").unwrap();
        let mut state = vec![0u8; base.serialize_size().unwrap()];
        base.serialize(&mut state).unwrap();

        // Restore the saved prefix state into a new context, append `suffix`,
        // and finalise.
        let resume = |suffix: &[u8]| {
            let mut ctx = alg.new_context().unwrap();
            ctx.deserialize(&state).unwrap();
            ctx.update(suffix).unwrap();
            finish_32(&mut ctx)
        };
        let out_a = resume(b" world");
        let out_b = resume(b" WORLD");
        assert_ne!(
            out_a, out_b,
            "different suffixes must produce different digests"
        );
    }
}