use std::ptr;
use foreign_types::{ForeignType, ForeignTypeRef};
use openssl_macros::corresponds;
use crate::error::ErrorStack;
use crate::ffi;
use crate::{cvt, cvt_p};
/// An AEAD algorithm descriptor, wrapping a BoringSSL `EVP_AEAD` pointer.
///
/// Values are `Copy`; obtain one from a constructor such as [`Algorithm::aes_128_gcm`]
/// and pass it to [`AeadCtx::new`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Algorithm(*const ffi::EVP_AEAD);

impl Algorithm {
    /// Wraps a raw `EVP_AEAD` pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must be a valid, non-null pointer to an `EVP_AEAD` that stays valid for as long
    /// as the returned value (or any copy of it) is used: every query method below passes it
    /// straight to BoringSSL without checking. Because [`Algorithm`] is `Send + Sync`, the
    /// pointee must also be safe to read concurrently from multiple threads.
    #[must_use]
    pub const unsafe fn from_ptr(ptr: *const ffi::EVP_AEAD) -> Self {
        Self(ptr)
    }

    /// AES-128 in GCM mode.
    #[corresponds(EVP_aead_aes_128_gcm)]
    #[must_use]
    pub fn aes_128_gcm() -> Self {
        unsafe { Self(ffi::EVP_aead_aes_128_gcm()) }
    }

    /// AES-256 in GCM mode.
    #[corresponds(EVP_aead_aes_256_gcm)]
    #[must_use]
    pub fn aes_256_gcm() -> Self {
        unsafe { Self(ffi::EVP_aead_aes_256_gcm()) }
    }

    /// ChaCha20-Poly1305.
    #[corresponds(EVP_aead_chacha20_poly1305)]
    #[must_use]
    pub fn chacha20_poly1305() -> Self {
        unsafe { Self(ffi::EVP_aead_chacha20_poly1305()) }
    }

    /// XChaCha20-Poly1305 (extended-nonce ChaCha20-Poly1305).
    #[corresponds(EVP_aead_xchacha20_poly1305)]
    #[must_use]
    pub fn xchacha20_poly1305() -> Self {
        unsafe { Self(ffi::EVP_aead_xchacha20_poly1305()) }
    }

    /// Returns the key length, in bytes, that [`AeadCtx::new`] requires for this algorithm.
    #[corresponds(EVP_AEAD_key_length)]
    // `&self` is kept (rather than `self`) for API compatibility with existing callers.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[must_use]
    pub fn key_length(&self) -> usize {
        unsafe { ffi::EVP_AEAD_key_length(self.0) }
    }

    /// Returns the maximum number of bytes a seal operation may add to the plaintext.
    #[corresponds(EVP_AEAD_max_overhead)]
    // `&self` is kept (rather than `self`) for API compatibility with existing callers.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[must_use]
    pub fn max_overhead(&self) -> usize {
        unsafe { ffi::EVP_AEAD_max_overhead(self.0) }
    }

    /// Returns the maximum tag length this algorithm supports.
    #[corresponds(EVP_AEAD_max_tag_len)]
    // `&self` is kept (rather than `self`) for API compatibility with existing callers.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[must_use]
    pub fn max_tag_len(&self) -> usize {
        unsafe { ffi::EVP_AEAD_max_tag_len(self.0) }
    }

    /// Returns the nonce length, in bytes, expected by this algorithm.
    #[corresponds(EVP_AEAD_nonce_length)]
    // `&self` is kept (rather than `self`) for API compatibility with existing callers.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[must_use]
    pub fn nonce_len(&self) -> usize {
        unsafe { ffi::EVP_AEAD_nonce_length(self.0) }
    }

    /// Returns the underlying raw `EVP_AEAD` pointer.
    // `&self` is kept (rather than `self`) for API compatibility with existing callers.
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[must_use]
    pub const fn as_ptr(&self) -> *const ffi::EVP_AEAD {
        self.0
    }
}

// SAFETY: `Algorithm` is only a `*const` handle used for read-only queries. This relies on
// the pointed-to `EVP_AEAD` never being mutated (the BoringSSL constructors hand it out as a
// `const` pointer); callers of `from_ptr` must uphold the same for custom pointers.
unsafe impl Send for Algorithm {}
// SAFETY: see the `Send` impl above; shared access only ever reads the descriptor.
unsafe impl Sync for Algorithm {}
// Owned AEAD context (`AeadCtx`) plus its borrowed counterpart (`AeadCtxRef`, whose methods
// are implemented below), wrapping a BoringSSL `EVP_AEAD_CTX`.
foreign_type_and_impl_send_sync! {
    type CType = ffi::EVP_AEAD_CTX;
    // Dropping an `AeadCtx` releases the underlying context with `EVP_AEAD_CTX_free`.
    fn drop = ffi::EVP_AEAD_CTX_free;
    // Created via `AeadCtx::new` / `AeadCtx::new_default_tag`. Per the macro's name it is
    // also given `Send`/`Sync` impls — NOTE(review): confirm against the macro definition.
    pub struct AeadCtx;
}
impl AeadCtx {
    /// Creates a new AEAD context for `algorithm`, keyed with `key`.
    ///
    /// `tag_len` is forwarded unchanged to BoringSSL; use [`AeadCtx::new_default_tag`] to
    /// request the algorithm's default tag length.
    ///
    /// # Errors
    ///
    /// Returns an error if `key.len()` differs from [`Algorithm::key_length`], or if BoringSSL
    /// fails to create the context (for example, for an unsupported `tag_len`).
    #[corresponds(EVP_AEAD_CTX_new)]
    pub fn new(algorithm: &Algorithm, key: &[u8], tag_len: usize) -> Result<Self, ErrorStack> {
        ffi::init();

        // Validate the key length on the Rust side before calling into BoringSSL.
        if key.len() != algorithm.key_length() {
            return Err(ErrorStack::internal_error_str("invalid key size"));
        }

        // SAFETY: `algorithm` wraps a valid `EVP_AEAD`, and the key pointer/length pair comes
        // from a single live slice.
        let raw_ctx = unsafe {
            cvt_p(ffi::EVP_AEAD_CTX_new(
                algorithm.as_ptr(),
                key.as_ptr(),
                key.len(),
                tag_len,
            ))?
        };

        // SAFETY: `cvt_p` only succeeds with a non-null, freshly created context that nothing
        // else owns, so `AeadCtx` can take ownership of it (and free it on drop).
        Ok(unsafe { AeadCtx::from_ptr(raw_ctx) })
    }

    /// Creates a new AEAD context using the algorithm's default tag length.
    ///
    /// # Errors
    ///
    /// Same conditions as [`AeadCtx::new`].
    pub fn new_default_tag(algorithm: &Algorithm, key: &[u8]) -> Result<Self, ErrorStack> {
        let default_tag_len = ffi::EVP_AEAD_DEFAULT_TAG_LENGTH as usize;
        AeadCtx::new(algorithm, key, default_tag_len)
    }
}
impl AeadCtxRef {
    /// Computes how many bytes [`AeadCtxRef::seal_scatter`] will write to `out_tag`.
    ///
    /// `in_len` is the length of the in-place plaintext and `extra_in_len` the length of any
    /// `extra_in` plaintext passed to `seal_scatter`.
    ///
    /// # Errors
    ///
    /// Returns the BoringSSL error stack if the computation fails.
    #[corresponds(EVP_AEAD_CTX_tag_len)]
    pub fn tag_len(&self, in_len: usize, extra_in_len: usize) -> Result<usize, ErrorStack> {
        let mut out_tag_len: usize = 0;
        // SAFETY: `self` wraps a live context and `out_tag_len` is a valid `usize` to write to.
        unsafe {
            cvt(ffi::EVP_AEAD_CTX_tag_len(
                self.as_ptr(),
                &mut out_tag_len,
                in_len,
                extra_in_len,
            ))?;
        }
        Ok(out_tag_len)
    }

    /// Encrypts `in_out` in place and writes the detached authentication output to `out_tag`.
    ///
    /// If `extra_in` is given, it is encrypted as additional trailing plaintext; its ciphertext
    /// is written to the start of `out_tag`, followed by the tag. `associated_data` is
    /// authenticated but not encrypted.
    ///
    /// On success, returns the prefix of `out_tag` that was written.
    ///
    /// # Errors
    ///
    /// Returns an error if, for example, `nonce` has the wrong length or `out_tag` is too small
    /// (see [`AeadCtxRef::tag_len`] for the required size).
    #[corresponds(EVP_AEAD_CTX_seal_scatter)]
    pub fn seal_scatter<'a>(
        &self,
        nonce: &[u8],
        in_out: &mut [u8],
        out_tag: &'a mut [u8],
        extra_in: Option<&[u8]>,
        associated_data: &[u8],
    ) -> Result<&'a mut [u8], ErrorStack> {
        let (extra_in_ptr, extra_in_len) = match extra_in {
            Some(buf) => (buf.as_ptr(), buf.len()),
            None => (ptr::null(), 0),
        };

        // Capture lengths first, then derive exactly one raw pointer per mutable buffer.
        // The same `in_out_ptr` is passed as both output and input (it coerces to
        // `*const u8`). Taking a second, shared borrow via `as_ptr()` would instead create
        // an aliasing pointer whose provenance is invalidated when C writes through the
        // mutable one.
        let in_out_len = in_out.len();
        let in_out_ptr = in_out.as_mut_ptr();
        let max_out_tag_len = out_tag.len();
        let out_tag_ptr = out_tag.as_mut_ptr();
        let mut out_tag_len = max_out_tag_len;

        // SAFETY: every pointer/length pair describes a live slice; `in_out` is used for both
        // input and output, which BoringSSL permits for exactly-overlapping buffers, and
        // `out_tag` has capacity `max_out_tag_len`.
        unsafe {
            cvt(ffi::EVP_AEAD_CTX_seal_scatter(
                self.as_ptr(),
                in_out_ptr,
                out_tag_ptr,
                &mut out_tag_len,
                max_out_tag_len,
                nonce.as_ptr(),
                nonce.len(),
                in_out_ptr,
                in_out_len,
                extra_in_ptr,
                extra_in_len,
                associated_data.as_ptr(),
                associated_data.len(),
            ))?;
        }

        Ok(&mut out_tag[..out_tag_len])
    }

    /// Authenticates and decrypts `in_out` in place using the detached tag `in_tag`.
    ///
    /// `associated_data` must match the data supplied when sealing.
    ///
    /// # Errors
    ///
    /// Returns an error if authentication fails or the inputs are malformed (e.g. a nonce of
    /// the wrong length). On error, the contents of `in_out` must not be treated as plaintext.
    #[corresponds(EVP_AEAD_CTX_open_gather)]
    pub fn open_gather(
        &self,
        nonce: &[u8],
        in_out: &mut [u8],
        in_tag: &[u8],
        associated_data: &[u8],
    ) -> Result<(), ErrorStack> {
        // Single pointer used for both input and output; see `seal_scatter` for why.
        let in_out_len = in_out.len();
        let in_out_ptr = in_out.as_mut_ptr();

        // SAFETY: every pointer/length pair describes a live slice, and `in_out` is used for
        // both input and output with identical bounds.
        unsafe {
            cvt(ffi::EVP_AEAD_CTX_open_gather(
                self.as_ptr(),
                in_out_ptr,
                nonce.as_ptr(),
                nonce.len(),
                in_out_ptr,
                in_out_len,
                in_tag.as_ptr(),
                in_tag.len(),
                associated_data.as_ptr(),
                associated_data.len(),
            ))
        }
    }

    /// Encrypts `buffer` in place and writes the tag into `tag`, returning the written prefix.
    ///
    /// Equivalent to [`AeadCtxRef::seal_scatter`] with no `extra_in`.
    ///
    /// # Errors
    ///
    /// Same conditions as [`AeadCtxRef::seal_scatter`].
    pub fn seal_in_place<'a>(
        &self,
        nonce: &[u8],
        buffer: &mut [u8],
        tag: &'a mut [u8],
        associated_data: &[u8],
    ) -> Result<&'a mut [u8], ErrorStack> {
        self.seal_scatter(nonce, buffer, tag, None, associated_data)
    }

    /// Authenticates and decrypts `buffer` in place using `tag`.
    ///
    /// Equivalent to [`AeadCtxRef::open_gather`].
    ///
    /// # Errors
    ///
    /// Same conditions as [`AeadCtxRef::open_gather`].
    pub fn open_in_place(
        &self,
        nonce: &[u8],
        buffer: &mut [u8],
        tag: &[u8],
        associated_data: &[u8],
    ) -> Result<(), ErrorStack> {
        self.open_gather(nonce, buffer, tag, associated_data)
    }
}
#[cfg(test)]
mod tests {
    use super::{AeadCtx, Algorithm};

    #[test]
    fn in_out() {
        let algorithm = Algorithm::aes_128_gcm();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[0u8; 16]).unwrap();
        let nonce = [0u8; 12];
        let associated_data = b"this is authenticated";
        let mut buffer = b"ABCDE".to_vec();
        let mut tag = [0u8; 16];
        let tag_written_len = ctx
            .seal_in_place(&nonce, buffer.as_mut_slice(), &mut tag, associated_data)
            .unwrap()
            .len();
        // Without these checks a no-op seal/open pair would still pass the round trip below.
        assert_eq!(tag_written_len, tag.len());
        assert_ne!(b"ABCDE", buffer.as_slice());
        ctx.open_in_place(&nonce, buffer.as_mut_slice(), &tag, associated_data)
            .unwrap();
        assert_eq!(b"ABCDE", buffer.as_slice());
    }

    #[test]
    fn xchacha_in_out() {
        let algorithm = Algorithm::xchacha20_poly1305();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[0u8; 32]).unwrap();
        let nonce = [0u8; 24];
        let associated_data = b"xchacha";
        let mut buffer = b"payload".to_vec();
        let mut tag = [0u8; 16];
        let tag_written = ctx
            .seal_in_place(&nonce, buffer.as_mut_slice(), &mut tag, associated_data)
            .unwrap();
        let tag_len = tag_written.len();
        ctx.open_in_place(
            &nonce,
            buffer.as_mut_slice(),
            &tag[..tag_len],
            associated_data,
        )
        .unwrap();
        assert_eq!(b"payload", buffer.as_slice());
    }

    #[test]
    fn seal_scatter_with_extra_in() {
        let algorithm = Algorithm::chacha20_poly1305();
        let ctx = AeadCtx::new(&algorithm, &[7u8; 32], algorithm.max_tag_len()).unwrap();
        let nonce = [1u8; 12];
        let aad = b"frame-header";
        let mut main = b"hello".to_vec();
        let extra = b" world";
        let mut detached = vec![0u8; extra.len() + algorithm.max_overhead()];
        let detached_written = ctx
            .seal_scatter(
                &nonce,
                main.as_mut_slice(),
                detached.as_mut_slice(),
                Some(extra),
                aad,
            )
            .unwrap();
        let extra_ct_len = extra.len();
        let tag = &detached_written[extra_ct_len..];
        let mut full_ciphertext = main;
        full_ciphertext.extend_from_slice(&detached_written[..extra_ct_len]);
        ctx.open_gather(&nonce, full_ciphertext.as_mut_slice(), tag, aad)
            .unwrap();
        assert_eq!(full_ciphertext.as_slice(), b"hello world");
    }

    #[test]
    fn new_rejects_invalid_key_length() {
        let result = AeadCtx::new_default_tag(&Algorithm::aes_128_gcm(), &[0u8; 15]);
        assert!(result.is_err());
    }

    #[test]
    fn tag_len_returns_expected_value() {
        let algorithm = Algorithm::aes_128_gcm();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[0u8; 16]).unwrap();
        let tag_len = ctx.tag_len(0, 0).unwrap();
        assert_eq!(tag_len, algorithm.max_overhead());
    }

    #[test]
    fn seal_rejects_invalid_nonce_length() {
        let algorithm = Algorithm::chacha20_poly1305();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[0u8; 32]).unwrap();
        let mut payload = [0u8; 8];
        let mut tag = [0u8; 16];
        let result = ctx.seal_in_place(&[0u8; 11], &mut payload, &mut tag, b"");
        assert!(result.is_err());
    }

    #[test]
    fn seal_rejects_insufficient_tag_buffer() {
        let algorithm = Algorithm::aes_128_gcm();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[0u8; 16]).unwrap();
        let mut payload = [0u8; 8];
        let mut short_tag = [0u8; 8];
        let result = ctx.seal_in_place(&[0u8; 12], &mut payload, &mut short_tag, b"");
        assert!(result.is_err());
    }

    #[test]
    fn open_rejects_invalid_nonce_length() {
        let algorithm = Algorithm::chacha20_poly1305();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[0u8; 32]).unwrap();
        let mut payload = [0u8; 8];
        let tag = [0u8; 16];
        let result = ctx.open_in_place(&[0u8; 11], &mut payload, &tag, b"");
        assert!(result.is_err());
    }

    #[test]
    fn open_rejects_tampered_tag() {
        let algorithm = Algorithm::chacha20_poly1305();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[3u8; 32]).unwrap();
        let nonce = [5u8; 12];
        let mut buffer = b"integrity".to_vec();
        let mut tag = [0u8; 16];
        ctx.seal_in_place(&nonce, buffer.as_mut_slice(), &mut tag, b"aad")
            .unwrap();
        // Flip one tag bit: authentication must now fail.
        tag[0] ^= 0x01;
        let result = ctx.open_in_place(&nonce, buffer.as_mut_slice(), &tag, b"aad");
        assert!(result.is_err());
    }

    #[test]
    fn open_rejects_mismatched_associated_data() {
        let algorithm = Algorithm::aes_256_gcm();
        let ctx = AeadCtx::new_default_tag(&algorithm, &[9u8; 32]).unwrap();
        let nonce = [2u8; 12];
        let mut buffer = b"bound to header".to_vec();
        let mut tag = [0u8; 16];
        ctx.seal_in_place(&nonce, buffer.as_mut_slice(), &mut tag, b"header-a")
            .unwrap();
        let result = ctx.open_in_place(&nonce, buffer.as_mut_slice(), &tag, b"header-b");
        assert!(result.is_err());
    }
}