#[cfg(not(feature = "std"))]
use crate::prelude::*;
use crate::error::Unspecified;
use crate::fips::indicator_check;
use crate::ptr::{ConstPointer, LcPtr};
use crate::wolfcrypt_rs::{
CMAC_CTX_new, CMAC_Final, CMAC_Init, CMAC_Update, EVP_aes_128_cbc, EVP_aes_192_cbc,
EVP_aes_256_cbc, CMAC_CTX, EVP_CIPHER,
};
use crate::{constant_time, rand};
use core::mem::MaybeUninit;
use core::ptr::null_mut;
/// Identifies which AES key size backs a CMAC `Algorithm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AlgorithmId {
    /// AES with a 128-bit key (mapped to `EVP_aes_128_cbc`).
    Aes128,
    /// AES with a 192-bit key (mapped to `EVP_aes_192_cbc`).
    Aes192,
    /// AES with a 256-bit key (mapped to `EVP_aes_256_cbc`).
    Aes256,
}
/// A CMAC algorithm: the AES variant used as the block cipher plus its key and
/// tag lengths in bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Algorithm {
    /// Which AES cipher backs this algorithm.
    id: AlgorithmId,
    /// Exact key length in bytes that `Key::new` requires.
    key_len: usize,
    /// Length in bytes of the tags this algorithm produces.
    tag_len: usize,
}

impl Algorithm {
    /// Returns the key length in bytes required for this algorithm.
    #[inline]
    #[must_use]
    pub fn key_len(&self) -> usize {
        let Self { key_len, .. } = *self;
        key_len
    }

    /// Returns the length in bytes of tags produced by this algorithm.
    #[inline]
    #[must_use]
    pub fn tag_len(&self) -> usize {
        let Self { tag_len, .. } = *self;
        tag_len
    }
}
impl AlgorithmId {
    /// Returns the AES-CBC `EVP_CIPHER` handle that CMAC uses as its underlying
    /// block cipher for this key size.
    ///
    /// # Panics
    /// Panics if the backend hands back a pointer that `ConstPointer::new_static`
    /// rejects (presumably null), i.e. the linked library lacks the AES variant.
    fn evp_cipher(&self) -> ConstPointer<'_, EVP_CIPHER> {
        // SAFETY: the `EVP_aes_*_cbc` accessors take no arguments; they are called
        // only to obtain the library's cipher descriptor.
        let raw_cipher = unsafe {
            match self {
                AlgorithmId::Aes128 => EVP_aes_128_cbc(),
                AlgorithmId::Aes192 => EVP_aes_192_cbc(),
                AlgorithmId::Aes256 => EVP_aes_256_cbc(),
            }
        };
        // SAFETY: the descriptor returned above is treated as a library-owned static
        // table, which is what `new_static` wraps.
        let wrapped = unsafe { ConstPointer::new_static(raw_cipher) };
        wrapped.expect("AES-CBC cipher descriptor must be available in the crypto backend")
    }
}
/// CMAC with AES-128: 128-bit (16-byte) key, 16-byte tag.
pub const AES_128: Algorithm = Algorithm {
    id: AlgorithmId::Aes128,
    key_len: 128 / 8,
    tag_len: 16,
};

/// CMAC with AES-192: 192-bit (24-byte) key, 16-byte tag.
pub const AES_192: Algorithm = Algorithm {
    id: AlgorithmId::Aes192,
    key_len: 192 / 8,
    tag_len: 16,
};

/// CMAC with AES-256: 256-bit (32-byte) key, 16-byte tag.
pub const AES_256: Algorithm = Algorithm {
    id: AlgorithmId::Aes256,
    key_len: 256 / 8,
    tag_len: 16,
};
/// Upper bound on a CMAC tag length in bytes; sizes the inline buffer in `Tag`.
const MAX_CMAC_TAG_LEN: usize = 16;

/// A computed CMAC tag, stored inline (no heap allocation).
#[derive(Debug, Clone, Copy)]
pub struct Tag {
    /// Backing storage; only the first `len` bytes are meaningful.
    bytes: [u8; MAX_CMAC_TAG_LEN],
    /// Number of valid bytes at the front of `bytes`.
    len: usize,
}

impl AsRef<[u8]> for Tag {
    /// Returns the valid prefix of the tag buffer.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        let valid = 0..self.len;
        &self.bytes[valid]
    }
}
/// A CMAC key bound to a specific `Algorithm`.
///
/// Signing never uses this key's own context directly: `Context::with_key`
/// clones the key, and each clone gets its own freshly initialized context.
pub struct Key {
    /// Algorithm this key was created for.
    algorithm: Algorithm,
    /// CMAC context initialized with this key in `Key::new`.
    ctx: LcPtr<CMAC_CTX>,
    /// Copy of the raw key bytes, retained so `Clone` can re-run `Key::new`.
    /// Zeroed on drop so the key does not linger in freed heap memory.
    key_bytes: Vec<u8>,
}

impl Drop for Key {
    /// Wipes the retained raw key bytes before the buffer is deallocated.
    fn drop(&mut self) {
        for byte in self.key_bytes.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusively borrowed `u8`.
            // A volatile write is used so the compiler cannot elide a store to
            // memory that is about to be freed.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        // Keep the wiping stores from being reordered past the deallocation.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
// Cloning rebuilds the key from its stored bytes, so the copy gets its own freshly
// initialized CMAC context instead of sharing the original's.
impl Clone for Key {
    fn clone(&self) -> Self {
        let rebuilt = Key::new(self.algorithm, self.key_bytes.as_slice());
        rebuilt.expect("CMAC Key clone failed: re-initialization should succeed")
    }
}

// SAFETY: a `Key` uniquely owns its CMAC context and key buffer, so moving it to
// another thread moves that ownership. NOTE(review): this assumes the backend does
// not tie a CMAC context to the thread that created it; confirm for wolfcrypt.
unsafe impl Send for Key {}
// SAFETY: the only `&self` operations on `Key` in this module (`algorithm`, `clone`,
// `fmt`) read `algorithm` and `key_bytes`; the CMAC context is only driven through
// owned or `&mut` access, so shared references never touch it concurrently.
unsafe impl Sync for Key {}

// Deliberately prints only the algorithm: `ctx` and `key_bytes` carry key material.
#[allow(clippy::missing_fields_in_debug)]
impl core::fmt::Debug for Key {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
        let mut builder = f.debug_struct("Key");
        builder.field("algorithm", &self.algorithm);
        builder.finish()
    }
}
impl Key {
    /// Generates a new key for `algorithm` from freshly drawn random bytes.
    ///
    /// # Errors
    /// Returns `Unspecified` if the random source fails or if `Key::new` rejects
    /// the generated key.
    pub fn generate(algorithm: Algorithm) -> Result<Self, Unspecified> {
        let mut key_bytes = vec![0u8; algorithm.key_len()];
        let result = Self::from_random_bytes(algorithm, &mut key_bytes);
        // The returned key keeps its own copy. Wipe the temporary on both the success
        // and error paths so raw key material is not left in freed heap memory.
        Self::wipe(&mut key_bytes);
        result
    }

    /// Fills `buf` from the random source, then builds a key from its contents.
    fn from_random_bytes(algorithm: Algorithm, buf: &mut [u8]) -> Result<Self, Unspecified> {
        rand::fill(&mut *buf)?;
        Self::new(algorithm, buf)
    }

    /// Overwrites `buf` with zeros using volatile writes so the stores cannot be
    /// optimized away.
    fn wipe(buf: &mut [u8]) {
        for byte in buf.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusively borrowed `u8`.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }

    /// Creates a key for `algorithm` from the raw bytes in `key_value`.
    ///
    /// A CMAC context is initialized immediately, so an unusable key/cipher
    /// combination is reported here rather than at signing time.
    ///
    /// # Errors
    /// Returns `Unspecified` if `key_value.len()` differs from
    /// `algorithm.key_len()`, if the CMAC context cannot be allocated, or if
    /// `CMAC_Init` reports failure.
    pub fn new(algorithm: Algorithm, key_value: &[u8]) -> Result<Self, Unspecified> {
        if key_value.len() != algorithm.key_len() {
            return Err(Unspecified);
        }
        // SAFETY: `CMAC_CTX_new` takes no arguments. `LcPtr::new` is expected to
        // reject a null result, which is why its error is propagated with `?`.
        let mut ctx = LcPtr::new(unsafe { CMAC_CTX_new() })?;
        // SAFETY: `ctx` is a live context owned by `ctx`; the key pointer/length come
        // from a valid slice; `cipher` wraps the backend's AES-CBC descriptor; the
        // final argument (presumably an optional ENGINE) is passed as null.
        unsafe {
            let cipher = algorithm.id.evp_cipher();
            if 1 != CMAC_Init(
                ctx.as_mut_ptr(),
                key_value.as_ptr().cast(),
                key_value.len(),
                cipher.as_const_ptr(),
                null_mut(),
            ) {
                return Err(Unspecified);
            }
        }
        Ok(Self {
            algorithm,
            ctx,
            key_bytes: key_value.to_vec(),
        })
    }

    /// Returns the algorithm this key was created for.
    #[inline]
    #[must_use]
    pub fn algorithm(&self) -> Algorithm {
        self.algorithm
    }
}
/// A streaming CMAC computation over data supplied through `update`.
///
/// Besides the live CMAC state (held in the context inside `key`), every byte
/// passed to `update` is also buffered in `accumulated_data` so that `Clone` can
/// rebuild an equivalent context by replaying it. Memory use therefore grows with
/// the total amount of input.
pub struct Context {
    /// Private copy of the caller's key; its CMAC context holds the running state.
    key: Key,
    /// All data fed to `update` so far, in order; read only by `Clone`.
    accumulated_data: Vec<u8>,
}

impl Clone for Context {
    /// Builds an independent context in the same state. It re-derives the key (giving
    /// a fresh CMAC context) and replays all accumulated data into it.
    ///
    /// # Panics
    /// Panics if re-initializing the key or replaying the data fails. Returning a
    /// context whose state silently differs from the original would yield wrong tags.
    fn clone(&self) -> Self {
        let mut cloned = Self {
            key: self.key.clone(),
            accumulated_data: self.accumulated_data.clone(),
        };
        if !cloned.accumulated_data.is_empty() {
            // SAFETY: the context is owned by `cloned.key` and was initialized by
            // `Key::new`; the data pointer/length come from a live `Vec<u8>` that is
            // not mutated during the call.
            let status = unsafe {
                CMAC_Update(
                    cloned.key.ctx.as_mut_ptr(),
                    cloned.accumulated_data.as_ptr(),
                    cloned.accumulated_data.len(),
                )
            };
            assert!(
                status == 1,
                "CMAC Context clone failed: replaying accumulated data should succeed"
            );
        }
        cloned
    }
}
// SAFETY: a `Context` exclusively owns its `Key` (and that key's CMAC context) plus its
// data buffer, and the CMAC state is only advanced through `&mut` access, so moving a
// `Context` to another thread transfers that ownership wholesale. NOTE(review): this
// assumes the backend does not tie a CMAC context to its creating thread; confirm.
unsafe impl Send for Context {}

// Reports only the algorithm; key material and buffered message data are omitted.
impl core::fmt::Debug for Context {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
        let algorithm = &self.key.algorithm;
        f.debug_struct("Context").field("algorithm", algorithm).finish()
    }
}
impl Context {
    /// Starts a new CMAC computation keyed by a private copy of `key`.
    #[inline]
    #[must_use]
    pub fn with_key(key: &Key) -> Self {
        let own_key = key.clone();
        Self {
            key: own_key,
            accumulated_data: Vec::new(),
        }
    }

    /// Feeds `data` into the running CMAC computation.
    ///
    /// On success the bytes are also appended to the internal replay buffer used
    /// by `Clone`.
    ///
    /// # Errors
    /// Returns `Unspecified` if `CMAC_Update` reports failure. In that case `data`
    /// is not recorded in the replay buffer.
    pub fn update(&mut self, data: &[u8]) -> Result<(), Unspecified> {
        // SAFETY: the context is owned by `self.key` and was initialized by
        // `Key::new`; the pointer/length pair comes from a valid slice.
        let status = unsafe { CMAC_Update(self.key.ctx.as_mut_ptr(), data.as_ptr(), data.len()) };
        if status != 1 {
            return Err(Unspecified);
        }
        self.accumulated_data.extend_from_slice(data);
        Ok(())
    }

    /// Consumes the context and returns the CMAC tag over all data supplied.
    ///
    /// # Errors
    /// Returns `Unspecified` if finalization fails or yields an unexpected tag length.
    pub fn sign(mut self) -> Result<Tag, Unspecified> {
        let mut bytes = [0u8; MAX_CMAC_TAG_LEN];
        let len = internal_sign(&mut self, &mut bytes)?.len();
        Ok(Tag { bytes, len })
    }

    /// Consumes the context and checks, in constant time, that `tag` equals the
    /// CMAC over all data supplied.
    ///
    /// # Errors
    /// Returns `Unspecified` if finalization fails or if `tag` does not match.
    #[inline]
    pub fn verify(mut self, tag: &[u8]) -> Result<(), Unspecified> {
        let mut computed = [0u8; MAX_CMAC_TAG_LEN];
        let computed_len = internal_sign(&mut self, &mut computed)?.len();
        constant_time::verify_slices_are_equal(&computed[..computed_len], tag)
    }
}
/// Finalizes the CMAC computation in `ctx` and writes the tag into the front of
/// `output`.
///
/// Returns the prefix of `output` that holds the tag. Its length always equals
/// `ctx.key.algorithm.tag_len()`.
///
/// # Errors
/// Returns `Unspecified` in any of these cases:
/// - `output` is shorter than the algorithm's tag length;
/// - `CMAC_Final` reports failure;
/// - `CMAC_Final` reports a tag length other than the algorithm's.
pub(crate) fn internal_sign<'in_out>(
    ctx: &mut Context,
    output: &'in_out mut [u8],
) -> Result<&'in_out mut [u8], Unspecified> {
    let expected_len = ctx.key.algorithm.tag_len();
    // `CMAC_Final` receives only a raw output pointer, never the buffer length, so
    // the size check has to happen here. Without it, a short buffer from any
    // crate-internal caller would be written past its end.
    if output.len() < expected_len {
        return Err(Unspecified);
    }
    let mut out_len = MaybeUninit::<usize>::uninit();
    // SAFETY: the context is owned by `ctx.key` and was initialized by `Key::new`.
    // `output` holds at least `expected_len` bytes (checked above). This assumes
    // `CMAC_Final` writes no more than one tag of that length. `out_len` is a valid
    // place for the library to store the produced length.
    if 1 != indicator_check!(unsafe {
        CMAC_Final(
            ctx.key.ctx.as_mut_ptr(),
            output.as_mut_ptr(),
            out_len.as_mut_ptr(),
        )
    }) {
        return Err(Unspecified);
    }
    // SAFETY: relies on `CMAC_Final` setting `*out_len` whenever it returns 1.
    let actual_len = unsafe { out_len.assume_init() };
    debug_assert!(
        actual_len <= MAX_CMAC_TAG_LEN,
        "CMAC tag length {actual_len} exceeds maximum {MAX_CMAC_TAG_LEN}"
    );
    if actual_len != expected_len {
        return Err(Unspecified);
    }
    Ok(&mut output[..actual_len])
}
/// Computes the CMAC tag of `data` under `key` in one shot.
///
/// # Errors
/// Returns `Unspecified` if updating or finalizing the CMAC computation fails.
#[inline]
pub fn sign(key: &Key, data: &[u8]) -> Result<Tag, Unspecified> {
    let mut context = Context::with_key(key);
    match context.update(data) {
        Ok(()) => context.sign(),
        Err(err) => Err(err),
    }
}
/// Computes the CMAC of `data` under `key` and writes the tag into the front of
/// `output`.
///
/// Returns the prefix of `output` that holds the tag, which is exactly
/// `key.algorithm().tag_len()` bytes long. Any remaining bytes of `output` are
/// left untouched.
///
/// # Errors
/// Returns `Unspecified` if `output` is shorter than the algorithm's tag length
/// or if the CMAC computation fails.
#[inline]
pub fn sign_to_buffer<'out>(
    key: &Key,
    data: &[u8],
    output: &'out mut [u8],
) -> Result<&'out mut [u8], Unspecified> {
    let required = key.algorithm().tag_len();
    if required > output.len() {
        return Err(Unspecified);
    }
    let mut context = Context::with_key(key);
    context.update(data)?;
    internal_sign(&mut context, output)
}
/// Recomputes the CMAC of `data` under `key` and checks, in constant time, that
/// it equals `tag`.
///
/// # Errors
/// Returns `Unspecified` if the CMAC computation fails or if `tag` does not match
/// (including when the lengths differ).
#[inline]
pub fn verify(key: &Key, data: &[u8], tag: &[u8]) -> Result<(), Unspecified> {
    let mut expected = [0u8; MAX_CMAC_TAG_LEN];
    let expected_len = sign_to_buffer(key, data, &mut expected)?.len();
    constant_time::verify_slices_are_equal(&expected[..expected_len], tag)
}
// Unit tests for the CMAC API: sign/verify round trips, streaming vs. one-shot
// equivalence, context cloning, key-length validation, output-buffer checks, and
// RFC 4493 known-answer vectors.
#[cfg(test)]
mod tests {
    use super::*;

    // FIPS-only tests live in a separate `fips` module file.
    #[cfg(feature = "fips")]
    mod fips;

    // Round trip for every AES variant; a one-byte change to the message must fail.
    #[test]
    fn cmac_basic_test() {
        for &algorithm in &[AES_128, AES_192, AES_256] {
            let key = Key::generate(algorithm).unwrap();
            let data = b"hello, world";
            let tag = sign(&key, data).unwrap();
            assert!(verify(&key, data, tag.as_ref()).is_ok());
            assert!(verify(&key, b"hello, worle", tag.as_ref()).is_err());
        }
    }

    // Same round trip; also exercises `Key`'s `Debug` impl through `println!`.
    #[test]
    pub fn cmac_signing_key_coverage() {
        const HELLO_WORLD_GOOD: &[u8] = b"hello, world";
        const HELLO_WORLD_BAD: &[u8] = b"hello, worle";
        for algorithm in &[AES_128, AES_192, AES_256] {
            let key = Key::generate(*algorithm).unwrap();
            let tag = sign(&key, HELLO_WORLD_GOOD).unwrap();
            println!("{key:?}");
            assert!(verify(&key, HELLO_WORLD_GOOD, tag.as_ref()).is_ok());
            assert!(verify(&key, HELLO_WORLD_BAD, tag.as_ref()).is_err());
        }
    }

    // Algorithm inequality, plus `Context::clone` after an update: the clone must
    // produce the same tag as the original. The `orig_tag.clone()` call exercises
    // `Tag`'s derived `Clone`.
    #[test]
    fn cmac_coverage() {
        assert_ne!(AES_128, AES_256);
        assert_ne!(AES_192, AES_256);
        for &alg in &[AES_128, AES_192, AES_256] {
            let key_bytes = vec![0u8; alg.key_len()];
            let key = Key::new(alg, &key_bytes).unwrap();
            let mut ctx = Context::with_key(&key);
            ctx.update(b"hello, world").unwrap();
            let ctx_clone = ctx.clone();
            let orig_tag = ctx.sign().unwrap();
            let clone_tag = ctx_clone.sign().unwrap();
            assert_eq!(orig_tag.as_ref(), clone_tag.as_ref());
            assert_eq!(orig_tag.clone().as_ref(), clone_tag.as_ref());
        }
    }

    // Several `update` calls must give the same tag as one-shot `sign` over the
    // concatenated message.
    #[test]
    fn cmac_context_test() {
        let key = Key::generate(AES_192).unwrap();
        let mut ctx = Context::with_key(&key);
        ctx.update(b"hello").unwrap();
        ctx.update(b", ").unwrap();
        ctx.update(b"world").unwrap();
        let tag1 = ctx.sign().unwrap();
        let tag2 = sign(&key, b"hello, world").unwrap();
        assert_eq!(tag1.as_ref(), tag2.as_ref());
    }

    // A streamed tag over the parts must verify against the joined message.
    #[test]
    fn cmac_multi_part_test() {
        let parts = ["hello", ", ", "world"];
        for &algorithm in &[AES_128, AES_256] {
            let key = Key::generate(algorithm).unwrap();
            let mut ctx = Context::with_key(&key);
            for part in &parts {
                ctx.update(part.as_bytes()).unwrap();
            }
            let tag = ctx.sign().unwrap();
            let mut msg = Vec::<u8>::new();
            for part in &parts {
                msg.extend(part.as_bytes());
            }
            assert!(verify(&key, &msg, tag.as_ref()).is_ok());
        }
    }

    // `Key::new` accepts correctly sized (all-zero) keys and they can sign.
    #[test]
    fn cmac_key_new_test() {
        let key_128 = [0u8; 16];
        let key_192 = [0u8; 24];
        let key_256 = [0u8; 32];
        let k1 = Key::new(AES_128, &key_128).unwrap();
        let k2 = Key::new(AES_192, &key_192).unwrap();
        let k3 = Key::new(AES_256, &key_256).unwrap();
        let data = b"test message";
        let _ = sign(&k1, data).unwrap();
        let _ = sign(&k2, data).unwrap();
        let _ = sign(&k3, data).unwrap();
    }

    // A key whose length does not match the algorithm is rejected.
    #[test]
    fn cmac_key_new_wrong_length_test() {
        let key_256 = [0u8; 32];
        assert!(Key::new(AES_128, &key_256).is_err());
    }

    // Pins the published key and tag lengths for each algorithm.
    #[test]
    fn cmac_algorithm_properties() {
        assert_eq!(AES_128.key_len(), 16);
        assert_eq!(AES_128.tag_len(), 16);
        assert_eq!(AES_192.key_len(), 24);
        assert_eq!(AES_192.tag_len(), 16);
        assert_eq!(AES_256.key_len(), 32);
        assert_eq!(AES_256.tag_len(), 16);
    }

    // Empty message: one-shot sign/verify works, and a context that never saw
    // `update` yields the same tag.
    #[test]
    fn cmac_empty_data() {
        let key = Key::generate(AES_128).unwrap();
        let tag = sign(&key, b"").unwrap();
        assert!(verify(&key, b"", tag.as_ref()).is_ok());
        let ctx = Context::with_key(&key);
        let tag2 = ctx.sign().unwrap();
        assert_eq!(tag.as_ref(), tag2.as_ref());
    }

    // `sign_to_buffer` with exact-size and oversized buffers returns a prefix of
    // exactly `tag_len` bytes that matches `sign`.
    #[test]
    fn cmac_sign_to_buffer_test() {
        for &algorithm in &[AES_128, AES_192, AES_256] {
            let key = Key::generate(algorithm).unwrap();
            let data = b"hello, world";
            let mut output = vec![0u8; algorithm.tag_len()];
            let result = sign_to_buffer(&key, data, &mut output).unwrap();
            assert_eq!(result.len(), algorithm.tag_len());
            let tag = sign(&key, data).unwrap();
            assert_eq!(result, tag.as_ref());
            let mut large_output = vec![0u8; algorithm.tag_len() + 10];
            let result2 = sign_to_buffer(&key, data, &mut large_output).unwrap();
            assert_eq!(result2.len(), algorithm.tag_len());
            assert_eq!(result2, tag.as_ref());
        }
    }

    // Output buffers shorter than the tag, including empty ones, are rejected.
    #[test]
    fn cmac_sign_to_buffer_too_small_test() {
        let key = Key::generate(AES_128).unwrap();
        let data = b"hello";
        let mut small_buffer = vec![0u8; AES_128.tag_len() - 1];
        assert!(sign_to_buffer(&key, data, &mut small_buffer).is_err());
        let mut empty_buffer = vec![];
        assert!(sign_to_buffer(&key, data, &mut empty_buffer).is_err());
    }

    // `Context::verify` accepts the correct tag and rejects a wrong tag, as well as
    // a correct tag over different data.
    #[test]
    fn cmac_context_verify_test() {
        for &algorithm in &[AES_128, AES_192, AES_256] {
            let key = Key::generate(algorithm).unwrap();
            let data = b"hello, world";
            let tag = sign(&key, data).unwrap();
            let mut ctx = Context::with_key(&key);
            ctx.update(data).unwrap();
            assert!(ctx.verify(tag.as_ref()).is_ok());
            let mut ctx2 = Context::with_key(&key);
            ctx2.update(data).unwrap();
            let wrong_tag = vec![0u8; algorithm.tag_len()];
            assert!(ctx2.verify(&wrong_tag).is_err());
            let mut ctx3 = Context::with_key(&key);
            ctx3.update(b"wrong data").unwrap();
            assert!(ctx3.verify(tag.as_ref()).is_err());
        }
    }

    // Streaming every part verifies; a context missing the final part must not.
    #[test]
    fn cmac_context_verify_multipart_test() {
        let key = Key::generate(AES_256).unwrap();
        let parts = ["hello", ", ", "world"];
        let mut full_msg = Vec::new();
        for part in &parts {
            full_msg.extend_from_slice(part.as_bytes());
        }
        let tag = sign(&key, &full_msg).unwrap();
        let mut ctx = Context::with_key(&key);
        for part in &parts {
            ctx.update(part.as_bytes()).unwrap();
        }
        assert!(ctx.verify(tag.as_ref()).is_ok());
        let mut ctx2 = Context::with_key(&key);
        ctx2.update(parts[0].as_bytes()).unwrap();
        ctx2.update(parts[1].as_bytes()).unwrap();
        assert!(ctx2.verify(tag.as_ref()).is_err());
    }

    // Known-answer vectors from RFC 4493 (AES-128 key, messages of 0, 16, 40 and
    // 64 bytes). The byte arrays must stay exactly as published.
    #[test]
    fn cmac_rfc4493_known_answer() {
        let key_bytes: [u8; 16] = [
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ];
        let key = Key::new(AES_128, &key_bytes).unwrap();
        // Empty message.
        let tag = sign(&key, b"").unwrap();
        assert_eq!(
            tag.as_ref(),
            &[
                0xbb, 0x1d, 0x69, 0x29, 0xe9, 0x59, 0x37, 0x28, 0x7f, 0xa3, 0x7d, 0x12, 0x9b, 0x75,
                0x67, 0x46
            ]
        );
        // One full block.
        let msg2: [u8; 16] = [
            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93,
            0x17, 0x2a,
        ];
        let tag2 = sign(&key, &msg2).unwrap();
        assert_eq!(
            tag2.as_ref(),
            &[
                0x07, 0x0a, 0x16, 0xb4, 0x6b, 0x4d, 0x41, 0x44, 0xf7, 0x9b, 0xdd, 0x9d, 0xd0, 0x4a,
                0x28, 0x7c
            ]
        );
        // Partial final block (40 bytes = 2.5 blocks).
        let msg3: [u8; 40] = [
            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93,
            0x17, 0x2a, 0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 0x9e, 0xb7, 0x6f, 0xac,
            0x45, 0xaf, 0x8e, 0x51, 0x30, 0xc8, 0x1c, 0x46, 0xa3, 0x5c, 0xe4, 0x11,
        ];
        let tag3 = sign(&key, &msg3).unwrap();
        assert_eq!(
            tag3.as_ref(),
            &[
                0xdf, 0xa6, 0x67, 0x47, 0xde, 0x9a, 0xe6, 0x30, 0x30, 0xca, 0x32, 0x61, 0x14, 0x97,
                0xc8, 0x27
            ]
        );
        // Four full blocks.
        let msg4: [u8; 64] = [
            0x6b, 0xc1, 0xbe, 0xe2, 0x2e, 0x40, 0x9f, 0x96, 0xe9, 0x3d, 0x7e, 0x11, 0x73, 0x93,
            0x17, 0x2a, 0xae, 0x2d, 0x8a, 0x57, 0x1e, 0x03, 0xac, 0x9c, 0x9e, 0xb7, 0x6f, 0xac,
            0x45, 0xaf, 0x8e, 0x51, 0x30, 0xc8, 0x1c, 0x46, 0xa3, 0x5c, 0xe4, 0x11, 0xe5, 0xfb,
            0xc1, 0x19, 0x1a, 0x0a, 0x52, 0xef, 0xf6, 0x9f, 0x24, 0x45, 0xdf, 0x4f, 0x9b, 0x17,
            0xad, 0x2b, 0x41, 0x7b, 0xe6, 0x6c, 0x37, 0x10,
        ];
        let tag4 = sign(&key, &msg4).unwrap();
        assert_eq!(
            tag4.as_ref(),
            &[
                0x51, 0xf0, 0xbe, 0xbf, 0x7e, 0x3b, 0x9d, 0x92, 0xfc, 0x49, 0x74, 0x17, 0x79, 0x36,
                0x3c, 0xfe
            ]
        );
    }

    // A single flipped bit, a truncated tag and an empty tag must all be rejected.
    #[test]
    fn cmac_verify_rejects_corrupted_tag() {
        let key_bytes: [u8; 16] = [
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ];
        let key = Key::new(AES_128, &key_bytes).unwrap();
        let msg = b"hello, world";
        let tag = sign(&key, msg).unwrap();
        let mut bad_tag = tag.as_ref().to_vec();
        bad_tag[0] ^= 0x01;
        assert!(verify(&key, msg, &bad_tag).is_err());
        assert!(verify(&key, msg, &tag.as_ref()[..15]).is_err());
        assert!(verify(&key, msg, &[]).is_err());
    }
}