use crate::error::ErrorStack;
use native_ossl_sys as sys;
use std::ffi::CStr;
use std::sync::Arc;
/// A fetched OpenSSL cipher algorithm (`EVP_CIPHER`).
///
/// Cloning takes an extra OpenSSL reference (`EVP_CIPHER_up_ref`); dropping
/// releases one (`EVP_CIPHER_free`).
pub struct CipherAlg {
    /// Owned reference to the fetched `EVP_CIPHER`; never NULL once constructed.
    ptr: *mut sys::EVP_CIPHER,
    /// Library context the algorithm was fetched from (`None` when fetched with
    /// a NULL context); stored so the `Arc` stays alive as long as this handle.
    lib_ctx: Option<Arc<crate::lib_ctx::LibCtx>>,
}
impl CipherAlg {
    /// `EVP_CIPH_FLAG_AEAD_CIPHER` from OpenSSL's `evp.h`.
    const FLAG_AEAD_CIPHER: u64 = 0x0020_0000;

    /// Wraps a pointer returned by `EVP_CIPHER_fetch`, converting NULL into the
    /// pending OpenSSL error stack.
    fn from_fetched(
        ptr: *mut sys::EVP_CIPHER,
        lib_ctx: Option<Arc<crate::lib_ctx::LibCtx>>,
    ) -> Result<Self, ErrorStack> {
        if ptr.is_null() {
            return Err(ErrorStack::drain());
        }
        Ok(CipherAlg { ptr, lib_ctx })
    }

    /// Fetches the cipher `name` using a NULL (default) library context.
    ///
    /// `props` is an optional property query string forwarded to OpenSSL.
    ///
    /// # Errors
    /// Returns the drained OpenSSL error stack if the fetch yields NULL.
    pub fn fetch(name: &CStr, props: Option<&CStr>) -> Result<Self, ErrorStack> {
        let props_ptr = props.map_or(std::ptr::null(), CStr::as_ptr);
        // SAFETY: `name` and `props_ptr` are NUL-terminated strings (or NULL)
        // that outlive the call.
        let ptr = unsafe { sys::EVP_CIPHER_fetch(std::ptr::null_mut(), name.as_ptr(), props_ptr) };
        Self::from_fetched(ptr, None)
    }

    /// Fetches the cipher `name` from the library context `ctx`.
    ///
    /// The returned handle keeps a clone of `ctx`.
    ///
    /// # Errors
    /// Returns the drained OpenSSL error stack if the fetch yields NULL.
    pub fn fetch_in(
        ctx: &Arc<crate::lib_ctx::LibCtx>,
        name: &CStr,
        props: Option<&CStr>,
    ) -> Result<Self, ErrorStack> {
        let props_ptr = props.map_or(std::ptr::null(), CStr::as_ptr);
        // SAFETY: `ctx` is alive for the call; strings as in `fetch`.
        let ptr = unsafe { sys::EVP_CIPHER_fetch(ctx.as_ptr(), name.as_ptr(), props_ptr) };
        Self::from_fetched(ptr, Some(Arc::clone(ctx)))
    }

    /// Default key length in bytes (0 if OpenSSL reports a negative value).
    #[must_use]
    pub fn key_len(&self) -> usize {
        usize::try_from(unsafe { sys::EVP_CIPHER_get_key_length(self.ptr) }).unwrap_or(0)
    }

    /// Default IV length in bytes (0 if OpenSSL reports a negative value).
    #[must_use]
    pub fn iv_len(&self) -> usize {
        usize::try_from(unsafe { sys::EVP_CIPHER_get_iv_length(self.ptr) }).unwrap_or(0)
    }

    /// Block size in bytes (0 if OpenSSL reports a negative value).
    #[must_use]
    pub fn block_size(&self) -> usize {
        usize::try_from(unsafe { sys::EVP_CIPHER_get_block_size(self.ptr) }).unwrap_or(0)
    }

    /// Raw `EVP_CIPHER_get_flags` bit set, widened to `u64`.
    #[must_use]
    // `EVP_CIPHER_get_flags` returns C `unsigned long`, which is 32 bits on
    // Windows and 64 bits on most Unix targets; the explicit widening compiles
    // on both, and is an identity conversion on the latter.
    #[allow(clippy::useless_conversion)]
    pub fn flags(&self) -> u64 {
        u64::from(unsafe { sys::EVP_CIPHER_get_flags(self.ptr) })
    }

    /// Whether the algorithm carries OpenSSL's AEAD cipher flag.
    #[must_use]
    pub fn is_aead(&self) -> bool {
        (self.flags() & Self::FLAG_AEAD_CIPHER) != 0
    }

    /// Borrowed raw pointer to the underlying `EVP_CIPHER`.
    #[must_use]
    pub fn as_ptr(&self) -> *const sys::EVP_CIPHER {
        self.ptr
    }
}
impl Clone for CipherAlg {
    /// Produces a second handle to the same `EVP_CIPHER`, taking an extra
    /// OpenSSL reference that the new handle's `Drop` will release.
    fn clone(&self) -> Self {
        let shared = self.ptr;
        // SAFETY: `shared` is a live fetched cipher owned by `self`.
        unsafe {
            sys::EVP_CIPHER_up_ref(shared);
        }
        Self {
            lib_ctx: self.lib_ctx.clone(),
            ptr: shared,
        }
    }
}
impl Drop for CipherAlg {
    /// Gives back the single OpenSSL reference this handle holds.
    fn drop(&mut self) {
        let cipher = self.ptr;
        // SAFETY: each `CipherAlg` owns exactly one reference to `cipher`
        // (from fetch or `up_ref` in `clone`), released exactly once here.
        unsafe { sys::EVP_CIPHER_free(cipher) }
    }
}
// SAFETY: `CipherAlg` owns one reference to a fetched `EVP_CIPHER`, which it
// only passes to OpenSSL getters, `EVP_CIPHER_up_ref`, `EVP_CIPHER_free` and
// context initialisation; moving that ownership to another thread is sound.
// NOTE(review): also assumes `LibCtx` is safe to send, since the `Arc` moves
// with the handle — confirm.
unsafe impl Send for CipherAlg {}
// SAFETY: all `&self` methods only read the `EVP_CIPHER` or bump its
// refcount. NOTE(review): relies on OpenSSL's thread-safe refcounting for
// fetched ciphers and on `LibCtx` being shareable — confirm for supported builds.
unsafe impl Sync for CipherAlg {}
/// Type-level marker selecting the encryption direction of a `CipherCtx`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Encrypt;
/// Type-level marker selecting the decryption direction of a `CipherCtx`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Decrypt;
mod sealed {
pub trait Direction {}
impl Direction for super::Encrypt {}
impl Direction for super::Decrypt {}
}
/// An owned `EVP_CIPHER_CTX` bound to one direction of a cipher operation.
///
/// `Dir` is the `Encrypt` or `Decrypt` marker and selects which OpenSSL
/// update/final functions are invoked. In this file it is constructed by
/// `CipherAlg::encrypt` / `CipherAlg::decrypt`; the context is freed on drop.
///
/// NOTE(review): unlike `CipherAlg`, this does not keep the algorithm's
/// `LibCtx` alive — confirm a context can never outlive its library context.
pub struct CipherCtx<Dir> {
    /// Owned context pointer, released by `EVP_CIPHER_CTX_free` in `Drop`.
    ptr: *mut sys::EVP_CIPHER_CTX,
    /// Zero-sized; records the direction in the type only.
    _dir: std::marker::PhantomData<Dir>,
}
impl<Dir: sealed::Direction> CipherCtx<Dir> {
pub fn update(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, ErrorStack>
where
Dir: IsEncrypt,
{
unsafe { Dir::do_update(self.ptr, input, output) }
}
pub fn update_to_vec(&mut self, input: &[u8]) -> Result<Vec<u8>, ErrorStack>
where
Dir: IsEncrypt,
{
let block_size =
usize::try_from(unsafe { sys::EVP_CIPHER_CTX_get_block_size(self.ptr) }).unwrap_or(0);
let max = input.len() + block_size;
let mut out = vec![0u8; max];
let n = self.update(input, &mut out)?;
out.truncate(n);
Ok(out)
}
pub fn finalize(&mut self, output: &mut [u8]) -> Result<usize, ErrorStack>
where
Dir: IsEncrypt,
{
unsafe { Dir::do_finalize(self.ptr, output) }
}
pub fn set_params(&mut self, params: &crate::params::Params<'_>) -> Result<(), ErrorStack> {
crate::ossl_call!(sys::EVP_CIPHER_CTX_set_params(self.ptr, params.as_ptr()))
}
#[must_use]
pub fn as_ptr(&self) -> *mut sys::EVP_CIPHER_CTX {
self.ptr
}
}
impl<D> Drop for CipherCtx<D> {
    /// Frees the owned `EVP_CIPHER_CTX`.
    fn drop(&mut self) {
        let ctx = self.ptr;
        // SAFETY: `ctx` was allocated by `EVP_CIPHER_CTX_new`, is owned solely
        // by this value, and is freed exactly once here.
        unsafe { sys::EVP_CIPHER_CTX_free(ctx) }
    }
}
// SAFETY: a `CipherCtx` exclusively owns its `EVP_CIPHER_CTX` (freed in
// `Drop`), and its methods that drive OpenSSL take `&mut self`, so moving it
// to another thread is sound. No `Sync` impl is provided.
unsafe impl<D> Send for CipherCtx<D>
where
    D: sealed::Direction,
{
}
/// Direction-specific OpenSSL update/final calls used by `CipherCtx`.
///
/// Implemented for `Encrypt` (`EVP_EncryptUpdate` / `EVP_EncryptFinal_ex`) and
/// `Decrypt` (`EVP_DecryptUpdate` / `EVP_DecryptFinal_ex`). The
/// `sealed::Direction` supertrait prevents other implementations.
pub trait IsEncrypt: sealed::Direction {
    /// Runs this direction's `*Update` call on `input`, writing into `output`.
    ///
    /// Returns the byte count OpenSSL reports as written.
    ///
    /// # Errors
    /// Fails if `input.len()` does not fit in a C `int`, or if OpenSSL reports
    /// an error.
    ///
    /// # Safety
    /// `ctx` must point to a live `EVP_CIPHER_CTX` that was initialised for
    /// this direction. `output` must have room for everything OpenSSL may
    /// write: `input.len()` plus one cipher block (no extra when the block
    /// size is 1).
    unsafe fn do_update(
        ctx: *mut sys::EVP_CIPHER_CTX,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<usize, ErrorStack>;

    /// Runs this direction's `*Final_ex` call, writing any remaining bytes.
    ///
    /// Returns the byte count OpenSSL reports as written.
    ///
    /// # Errors
    /// Fails if OpenSSL reports an error.
    ///
    /// # Safety
    /// `ctx` must satisfy the same requirements as for `do_update`. `output`
    /// must have room for one cipher block (nothing when the block size is 1).
    unsafe fn do_finalize(
        ctx: *mut sys::EVP_CIPHER_CTX,
        output: &mut [u8],
    ) -> Result<usize, ErrorStack>;
}
impl IsEncrypt for Encrypt {
    /// Encrypts `input` with `EVP_EncryptUpdate`, writing ciphertext to `output`.
    unsafe fn do_update(
        ctx: *mut sys::EVP_CIPHER_CTX,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<usize, ErrorStack> {
        // OpenSSL takes the input length as a C `int`.
        let Ok(in_len) = i32::try_from(input.len()) else {
            return Err(ErrorStack::drain());
        };
        let mut written: i32 = 0;
        crate::ossl_call!(sys::EVP_EncryptUpdate(
            ctx,
            output.as_mut_ptr(),
            &mut written,
            input.as_ptr(),
            in_len
        ))?;
        Ok(usize::try_from(written).unwrap_or_default())
    }

    /// Flushes the final ciphertext with `EVP_EncryptFinal_ex`.
    unsafe fn do_finalize(
        ctx: *mut sys::EVP_CIPHER_CTX,
        output: &mut [u8],
    ) -> Result<usize, ErrorStack> {
        let mut written: i32 = 0;
        crate::ossl_call!(sys::EVP_EncryptFinal_ex(ctx, output.as_mut_ptr(), &mut written))?;
        Ok(usize::try_from(written).unwrap_or_default())
    }
}
impl IsEncrypt for Decrypt {
    /// Decrypts `input` with `EVP_DecryptUpdate`, writing plaintext to `output`.
    unsafe fn do_update(
        ctx: *mut sys::EVP_CIPHER_CTX,
        input: &[u8],
        output: &mut [u8],
    ) -> Result<usize, ErrorStack> {
        // OpenSSL takes the input length as a C `int`.
        let Ok(in_len) = i32::try_from(input.len()) else {
            return Err(ErrorStack::drain());
        };
        let mut written: i32 = 0;
        crate::ossl_call!(sys::EVP_DecryptUpdate(
            ctx,
            output.as_mut_ptr(),
            &mut written,
            input.as_ptr(),
            in_len
        ))?;
        Ok(usize::try_from(written).unwrap_or_default())
    }

    /// Flushes the final plaintext with `EVP_DecryptFinal_ex`; this is where
    /// padding or AEAD tag verification failures surface.
    unsafe fn do_finalize(
        ctx: *mut sys::EVP_CIPHER_CTX,
        output: &mut [u8],
    ) -> Result<usize, ErrorStack> {
        let mut written: i32 = 0;
        crate::ossl_call!(sys::EVP_DecryptFinal_ex(ctx, output.as_mut_ptr(), &mut written))?;
        Ok(usize::try_from(written).unwrap_or_default())
    }
}
impl CipherAlg {
    /// Rejects `key`/`iv` slices shorter than this algorithm's key/IV lengths.
    ///
    /// `EVP_EncryptInit_ex2`/`EVP_DecryptInit_ex2` receive only pointers and
    /// read the context's key and IV length (the algorithm defaults for a
    /// fresh context) from them. A shorter slice would be read out of bounds.
    fn check_key_iv(&self, key: &[u8], iv: &[u8]) -> Result<(), ErrorStack> {
        if key.len() < self.key_len() || iv.len() < self.iv_len() {
            return Err(ErrorStack::drain());
        }
        Ok(())
    }

    /// Creates an encryption context keyed with `key` and `iv`.
    ///
    /// `params` is forwarded to `EVP_EncryptInit_ex2`; `None` passes
    /// `crate::params::null_params()`.
    ///
    /// # Errors
    /// Fails if `key` is shorter than `key_len()`, `iv` is shorter than
    /// `iv_len()`, or OpenSSL cannot allocate or initialise the context.
    pub fn encrypt(
        &self,
        key: &[u8],
        iv: &[u8],
        params: Option<&crate::params::Params<'_>>,
    ) -> Result<CipherCtx<Encrypt>, ErrorStack> {
        self.check_key_iv(key, iv)?;
        // SAFETY: plain allocation; NULL is handled below.
        let ptr = unsafe { sys::EVP_CIPHER_CTX_new() };
        if ptr.is_null() {
            return Err(ErrorStack::drain());
        }
        // Take ownership immediately so `Drop` frees the context if init fails.
        let ctx: CipherCtx<Encrypt> = CipherCtx {
            ptr,
            _dir: std::marker::PhantomData,
        };
        let params_ptr = params.map_or(crate::params::null_params(), crate::params::Params::as_ptr);
        crate::ossl_call!(sys::EVP_EncryptInit_ex2(
            ctx.ptr,
            self.ptr,
            key.as_ptr(),
            iv.as_ptr(),
            params_ptr
        ))?;
        Ok(ctx)
    }

    /// Creates a decryption context keyed with `key` and `iv`.
    ///
    /// `params` is forwarded to `EVP_DecryptInit_ex2`; `None` passes
    /// `crate::params::null_params()`.
    ///
    /// # Errors
    /// Fails if `key` is shorter than `key_len()`, `iv` is shorter than
    /// `iv_len()`, or OpenSSL cannot allocate or initialise the context.
    pub fn decrypt(
        &self,
        key: &[u8],
        iv: &[u8],
        params: Option<&crate::params::Params<'_>>,
    ) -> Result<CipherCtx<Decrypt>, ErrorStack> {
        self.check_key_iv(key, iv)?;
        // SAFETY: plain allocation; NULL is handled below.
        let ptr = unsafe { sys::EVP_CIPHER_CTX_new() };
        if ptr.is_null() {
            return Err(ErrorStack::drain());
        }
        // Take ownership immediately so `Drop` frees the context if init fails.
        let ctx: CipherCtx<Decrypt> = CipherCtx {
            ptr,
            _dir: std::marker::PhantomData,
        };
        let params_ptr = params.map_or(crate::params::null_params(), crate::params::Params::as_ptr);
        crate::ossl_call!(sys::EVP_DecryptInit_ex2(
            ctx.ptr,
            self.ptr,
            key.as_ptr(),
            iv.as_ptr(),
            params_ptr
        ))?;
        Ok(ctx)
    }
}
/// Encryption context for AEAD algorithms: wraps a `CipherCtx<Encrypt>` and
/// adds AAD input and tag retrieval.
pub struct AeadEncryptCtx(CipherCtx<Encrypt>);
impl AeadEncryptCtx {
pub fn new(
alg: &CipherAlg,
key: &[u8],
iv: &[u8],
params: Option<&crate::params::Params<'_>>,
) -> Result<Self, ErrorStack> {
assert!(alg.is_aead(), "CipherAlg is not an AEAD algorithm");
Ok(AeadEncryptCtx(alg.encrypt(key, iv, params)?))
}
pub fn set_aad(&mut self, aad: &[u8]) -> Result<(), ErrorStack> {
let alen = i32::try_from(aad.len()).expect("AAD too large for EVP_EncryptUpdate");
let mut outl: i32 = 0;
crate::ossl_call!(sys::EVP_EncryptUpdate(
self.0.ptr,
std::ptr::null_mut(),
std::ptr::addr_of_mut!(outl),
aad.as_ptr(),
alen
))
}
pub fn update(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, ErrorStack> {
self.0.update(input, output)
}
pub fn finalize(&mut self, output: &mut [u8]) -> Result<usize, ErrorStack> {
self.0.finalize(output)
}
pub fn tag(&self, tag: &mut [u8]) -> Result<(), ErrorStack> {
let tlen = i32::try_from(tag.len()).expect("tag slice too large");
let rc = unsafe {
sys::EVP_CIPHER_CTX_ctrl(
self.0.ptr,
16, tlen,
tag.as_mut_ptr().cast(),
)
};
if rc != 1 {
return Err(ErrorStack::drain());
}
Ok(())
}
}
/// Decryption context for AEAD algorithms: wraps a `CipherCtx<Decrypt>` and
/// adds AAD input and expected-tag configuration.
pub struct AeadDecryptCtx(CipherCtx<Decrypt>);
impl AeadDecryptCtx {
pub fn new(
alg: &CipherAlg,
key: &[u8],
iv: &[u8],
params: Option<&crate::params::Params<'_>>,
) -> Result<Self, ErrorStack> {
assert!(alg.is_aead(), "CipherAlg is not an AEAD algorithm");
Ok(AeadDecryptCtx(alg.decrypt(key, iv, params)?))
}
pub fn set_aad(&mut self, aad: &[u8]) -> Result<(), ErrorStack> {
let alen = i32::try_from(aad.len()).expect("AAD too large for EVP_DecryptUpdate");
let mut outl: i32 = 0;
crate::ossl_call!(sys::EVP_DecryptUpdate(
self.0.ptr,
std::ptr::null_mut(),
std::ptr::addr_of_mut!(outl),
aad.as_ptr(),
alen
))
}
pub fn update(&mut self, input: &[u8], output: &mut [u8]) -> Result<usize, ErrorStack> {
self.0.update(input, output)
}
pub fn set_tag(&mut self, tag: &[u8]) -> Result<(), ErrorStack> {
let tlen = i32::try_from(tag.len()).expect("tag slice too large");
let rc = unsafe {
sys::EVP_CIPHER_CTX_ctrl(
self.0.ptr,
17, tlen,
tag.as_ptr().cast_mut().cast(),
)
};
if rc != 1 {
return Err(ErrorStack::drain());
}
Ok(())
}
pub fn finalize(&mut self, output: &mut [u8]) -> Result<usize, ErrorStack> {
self.0.finalize(output)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fetches `name` and checks the lengths and AEAD flag OpenSSL reports.
    fn check_alg(name: &std::ffi::CStr, key: usize, iv: usize, block: usize, aead: bool) {
        let alg = CipherAlg::fetch(name, None).unwrap();
        assert_eq!(alg.key_len(), key);
        assert_eq!(alg.iv_len(), iv);
        assert_eq!(alg.block_size(), block);
        assert_eq!(alg.is_aead(), aead);
    }

    #[test]
    fn fetch_aes_256_gcm_properties() {
        check_alg(c"AES-256-GCM", 32, 12, 1, true);
    }

    #[test]
    fn fetch_aes_256_cbc_properties() {
        check_alg(c"AES-256-CBC", 32, 16, 16, false);
    }

    #[test]
    fn fetch_nonexistent_fails() {
        let result = CipherAlg::fetch(c"NONEXISTENT_CIPHER_XYZ", None);
        assert!(result.is_err());
    }

    #[test]
    fn clone_then_drop_both() {
        let first = CipherAlg::fetch(c"AES-256-GCM", None).unwrap();
        let second = first.clone();
        // Both handles release their own reference without a double free.
        drop(first);
        drop(second);
    }

    #[test]
    fn aes_256_cbc_round_trip() {
        let alg = CipherAlg::fetch(c"AES-256-CBC", None).unwrap();
        let key = [0x42u8; 32];
        let iv = [0x24u8; 16];
        let message = b"Hello, cipher world!";
        let bs = alg.block_size();

        let mut enc = alg.encrypt(&key, &iv, None).unwrap();
        let mut ct = vec![0u8; message.len() + bs];
        let head = enc.update(message, &mut ct).unwrap();
        let tail = enc.finalize(&mut ct[head..]).unwrap();
        ct.truncate(head + tail);

        let mut dec = alg.decrypt(&key, &iv, None).unwrap();
        let mut pt = vec![0u8; ct.len() + bs];
        let head = dec.update(&ct, &mut pt).unwrap();
        let tail = dec.finalize(&mut pt[head..]).unwrap();
        pt.truncate(head + tail);

        assert_eq!(pt, message);
    }

    #[test]
    fn aes_256_gcm_round_trip_and_tag_failure() {
        let alg = CipherAlg::fetch(c"AES-256-GCM", None).unwrap();
        let key = [0x11u8; 32];
        let iv = [0x22u8; 12];
        let aad = b"additional data";
        let plaintext = b"secret message!";

        // Encrypt and capture the tag.
        let mut enc = AeadEncryptCtx::new(&alg, &key, &iv, None).unwrap();
        enc.set_aad(aad).unwrap();
        let mut ct = vec![0u8; plaintext.len()];
        let written = enc.update(plaintext, &mut ct).unwrap();
        enc.finalize(&mut ct[written..]).unwrap();
        let mut tag = [0u8; 16];
        enc.tag(&mut tag).unwrap();

        // The genuine tag verifies and the plaintext is recovered.
        let mut dec = AeadDecryptCtx::new(&alg, &key, &iv, None).unwrap();
        dec.set_aad(aad).unwrap();
        let mut pt = vec![0u8; ct.len()];
        let read = dec.update(&ct, &mut pt).unwrap();
        dec.set_tag(&tag).unwrap();
        dec.finalize(&mut pt[read..]).unwrap();
        assert_eq!(&pt[..read], plaintext);

        // A tag with one flipped byte is rejected at finalize.
        let mut forged = tag;
        forged[0] ^= 0xff;
        let mut dec_bad = AeadDecryptCtx::new(&alg, &key, &iv, None).unwrap();
        dec_bad.set_aad(aad).unwrap();
        let mut scratch = vec![0u8; ct.len()];
        dec_bad.update(&ct, &mut scratch).unwrap();
        dec_bad.set_tag(&forged).unwrap();
        assert!(dec_bad.finalize(&mut scratch).is_err());
    }
}