use alloc::format;
use alloc::vec::Vec;
/// Class identifier string of the builtin cryptographic plugin, which
/// offers the AES-GCM / AES-GMAC transforms modelled by [`CryptoTransformKind`].
/// NOTE(review): the unit tests treat this value as spec-mandated; confirm
/// against the DDS-Security specification before changing it.
pub const BUILTIN_CRYPTO_PLUGIN: &str = "DDS:Crypto:AES_GCM_GMAC";
/// Cryptographic transformation applied to protected data.
///
/// On the wire a kind is a 4-byte big-endian integer equal to the variant's
/// discriminant. `Gcm` variants encrypt (see [`CryptoTransformKind::encrypts`]);
/// `Gmac` variants carry a 16-byte tag without encrypting; `None` does neither.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum CryptoTransformKind {
    /// No transformation (wire value 0).
    None = 0,
    /// AES-GMAC with a 128-bit key (wire value 1).
    Aes128Gmac = 1,
    /// AES-GCM with a 128-bit key (wire value 2).
    Aes128Gcm = 2,
    /// AES-GMAC with a 256-bit key (wire value 3).
    Aes256Gmac = 3,
    /// AES-GCM with a 256-bit key (wire value 4).
    Aes256Gcm = 4,
}

impl CryptoTransformKind {
    /// Every variant, used to map a raw wire value back to a kind.
    const ALL: [Self; 5] = [
        Self::None,
        Self::Aes128Gmac,
        Self::Aes128Gcm,
        Self::Aes256Gmac,
        Self::Aes256Gcm,
    ];

    /// Encodes this kind as its 4-byte big-endian wire value.
    #[must_use]
    pub const fn to_be_bytes(self) -> [u8; 4] {
        let raw = self as u32;
        raw.to_be_bytes()
    }

    /// Decodes a kind from its 4-byte big-endian wire value.
    ///
    /// # Errors
    /// Returns `"unknown CryptoTransformKind"` when the value matches no variant.
    pub fn from_be_bytes(bytes: [u8; 4]) -> Result<Self, &'static str> {
        let raw = u32::from_be_bytes(bytes);
        Self::ALL
            .iter()
            .copied()
            .find(|&kind| kind as u32 == raw)
            .ok_or("unknown CryptoTransformKind")
    }

    /// `true` only for the GCM variants, i.e. the ones that encrypt the payload.
    #[must_use]
    pub const fn encrypts(self) -> bool {
        match self {
            Self::Aes128Gcm | Self::Aes256Gcm => true,
            Self::None | Self::Aes128Gmac | Self::Aes256Gmac => false,
        }
    }

    /// Authentication tag length in bytes: 16 for every AES variant, 0 for `None`.
    #[must_use]
    pub const fn tag_size(self) -> usize {
        if matches!(self, Self::None) {
            0
        } else {
            16
        }
    }

    /// AES key length in bytes (0 for `None`).
    #[must_use]
    pub const fn key_size(self) -> usize {
        match self {
            Self::None => 0,
            // 128-bit key.
            Self::Aes128Gcm | Self::Aes128Gmac => 128 / 8,
            // 256-bit key.
            Self::Aes256Gcm | Self::Aes256Gmac => 256 / 8,
        }
    }
}
/// Names the transformation applied and the key it used.
///
/// Wire form is exactly 8 bytes: the 4-byte big-endian kind, then the
/// 4-byte key id copied verbatim.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CryptoTransformIdentifier {
    /// Transformation kind (wire bytes 0..4).
    pub kind: CryptoTransformKind,
    /// Key identifier (wire bytes 4..8).
    pub key_id: [u8; 4],
}

impl CryptoTransformIdentifier {
    /// Builds an identifier from a kind and a key id.
    #[must_use]
    pub fn new(kind: CryptoTransformKind, key_id: [u8; 4]) -> Self {
        Self { kind, key_id }
    }

    /// Serializes the identifier to its 8-byte wire form.
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 8] {
        let k = self.kind.to_be_bytes();
        let id = self.key_id;
        [k[0], k[1], k[2], k[3], id[0], id[1], id[2], id[3]]
    }

    /// Parses an identifier from exactly 8 bytes.
    ///
    /// # Errors
    /// `"CryptoTransformIdentifier needs 8 bytes"` when `bytes.len() != 8`, or
    /// the error from [`CryptoTransformKind::from_be_bytes`] for an unknown kind.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, &'static str> {
        if bytes.len() != 8 {
            return Err("CryptoTransformIdentifier needs 8 bytes");
        }
        let (kind_part, key_part) = bytes.split_at(4);
        let mut kind_raw = [0u8; 4];
        kind_raw.copy_from_slice(kind_part);
        let mut key_id = [0u8; 4];
        key_id.copy_from_slice(key_part);
        let kind = CryptoTransformKind::from_be_bytes(kind_raw)?;
        Ok(Self { kind, key_id })
    }
}
/// Crypto header: 20 bytes on the wire.
///
/// Layout: transformation id (8) | session id (4) | IV suffix (8).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CryptoHeader {
    /// Transformation kind and key id (wire bytes 0..8).
    pub transformation_id: CryptoTransformIdentifier,
    /// Session id; also the first 4 bytes of the IV (wire bytes 8..12).
    pub session_id: [u8; 4],
    /// Last 8 bytes of the IV (wire bytes 12..20).
    pub init_vector_suffix: [u8; 8],
}

impl CryptoHeader {
    /// Serialized size of a header in bytes.
    pub const WIRE_SIZE: usize = 20;

    /// Serializes the header to its fixed-size wire form.
    #[must_use]
    pub fn to_bytes(&self) -> [u8; Self::WIRE_SIZE] {
        let mut buf = [0u8; Self::WIRE_SIZE];
        let (id_part, tail) = buf.split_at_mut(8);
        let (session_part, iv_part) = tail.split_at_mut(4);
        id_part.copy_from_slice(&self.transformation_id.to_bytes());
        session_part.copy_from_slice(&self.session_id);
        iv_part.copy_from_slice(&self.init_vector_suffix);
        buf
    }

    /// Parses a header from the first 20 bytes of `bytes`; any further bytes are ignored.
    ///
    /// # Errors
    /// `"CryptoHeader needs 20 bytes"` when fewer than 20 bytes are given, or
    /// the error from [`CryptoTransformIdentifier::from_bytes`].
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, &'static str> {
        let header = bytes
            .get(..Self::WIRE_SIZE)
            .ok_or("CryptoHeader needs 20 bytes")?;
        let transformation_id = CryptoTransformIdentifier::from_bytes(&header[..8])?;
        let mut session_id = [0u8; 4];
        session_id.copy_from_slice(&header[8..12]);
        let mut init_vector_suffix = [0u8; 8];
        init_vector_suffix.copy_from_slice(&header[12..]);
        Ok(Self {
            transformation_id,
            session_id,
            init_vector_suffix,
        })
    }

    /// Returns the 12-byte IV: `session_id` followed by `init_vector_suffix`.
    #[must_use]
    pub fn full_iv(&self) -> [u8; 12] {
        let mut iv = [0u8; 12];
        let source = self.session_id.iter().chain(self.init_vector_suffix.iter());
        for (dst, src) in iv.iter_mut().zip(source) {
            *dst = *src;
        }
        iv
    }
}
/// Crypto footer carrying a common MAC and zero or more receiver-specific MACs.
///
/// Wire layout: common MAC (16) | big-endian `u32` entry count |
/// `count` entries of key id (4) + MAC (16).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CryptoFooter {
    /// MAC shared by all receivers (wire bytes 0..16).
    pub common_mac: [u8; 16],
    /// `(key_id, mac)` pairs, serialized in order after the count field.
    pub receiver_specific_macs: Vec<([u8; 4], [u8; 16])>,
}

impl CryptoFooter {
    /// Size of the fixed prefix: common MAC (16) + entry count (4).
    const FIXED_SIZE: usize = 20;
    /// Size of one receiver-specific entry: key id (4) + MAC (16).
    const RECEIVER_ENTRY_SIZE: usize = 20;

    /// Serializes the footer to its wire form.
    #[must_use]
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(
            Self::FIXED_SIZE + self.receiver_specific_macs.len() * Self::RECEIVER_ENTRY_SIZE,
        );
        out.extend_from_slice(&self.common_mac);
        // NOTE(review): the count field is 32 bits wide; `as u32` would silently
        // truncate above u32::MAX entries. Assumed unreachable in practice.
        let n = self.receiver_specific_macs.len() as u32;
        out.extend_from_slice(&n.to_be_bytes());
        for (key_id, mac) in &self.receiver_specific_macs {
            out.extend_from_slice(key_id);
            out.extend_from_slice(mac);
        }
        out
    }

    /// Parses a footer from the start of `bytes`; bytes after the last
    /// declared entry are ignored.
    ///
    /// # Errors
    /// - `"CryptoFooter needs >= 20 bytes"` if the fixed prefix is incomplete.
    /// - `"receiver-specific MAC truncated"` if fewer bytes follow the prefix
    ///   than the declared entry count requires.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, &'static str> {
        if bytes.len() < Self::FIXED_SIZE {
            return Err("CryptoFooter needs >= 20 bytes");
        }
        let (fixed, rest) = bytes.split_at(Self::FIXED_SIZE);
        let mut common_mac = [0u8; 16];
        common_mac.copy_from_slice(&fixed[..16]);
        let mut n_buf = [0u8; 4];
        n_buf.copy_from_slice(&fixed[16..]);
        let n = u32::from_be_bytes(n_buf) as usize;
        // `n` comes straight from the input. Validate it against the bytes that
        // are actually present BEFORE allocating. Otherwise a forged count near
        // u32::MAX makes `Vec::with_capacity` try to reserve ~80 GiB (abort on
        // OOM), or overflow the capacity computation on 32-bit targets (panic).
        let needed = n
            .checked_mul(Self::RECEIVER_ENTRY_SIZE)
            .filter(|&len| len <= rest.len())
            .ok_or("receiver-specific MAC truncated")?;
        let receiver_specific_macs = rest[..needed]
            .chunks_exact(Self::RECEIVER_ENTRY_SIZE)
            .map(|entry| {
                let mut key_id = [0u8; 4];
                key_id.copy_from_slice(&entry[..4]);
                let mut mac = [0u8; 16];
                mac.copy_from_slice(&entry[4..]);
                (key_id, mac)
            })
            .collect();
        Ok(Self {
            common_mac,
            receiver_specific_macs,
        })
    }
}
/// Chooses the transformation to use with a peer from the kinds it offers.
///
/// The preference order is fixed: AES-256-GCM, AES-128-GCM, AES-256-GMAC,
/// AES-128-GMAC. Any GCM kind therefore beats any GMAC kind, and within each
/// family the larger key wins. `None` is never selected.
///
/// # Errors
/// Returns a message including the peer's offered kinds when none of the
/// preferred kinds appears in `remote_kinds`.
pub fn negotiate_transform(
    remote_kinds: &[CryptoTransformKind],
) -> Result<CryptoTransformKind, alloc::string::String> {
    const PREFERENCE: [CryptoTransformKind; 4] = [
        CryptoTransformKind::Aes256Gcm,
        CryptoTransformKind::Aes128Gcm,
        CryptoTransformKind::Aes256Gmac,
        CryptoTransformKind::Aes128Gmac,
    ];
    PREFERENCE
        .iter()
        .copied()
        .find(|candidate| remote_kinds.contains(candidate))
        .ok_or_else(|| {
            format!("no common crypto transform with peer (peer-offered: {remote_kinds:?})")
        })
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    //! Unit tests for the crypto wire types and transform negotiation.
    use super::*;
    #[test]
    fn kind_round_trip_all_variants() {
        for k in [
            CryptoTransformKind::None,
            CryptoTransformKind::Aes128Gmac,
            CryptoTransformKind::Aes128Gcm,
            CryptoTransformKind::Aes256Gmac,
            CryptoTransformKind::Aes256Gcm,
        ] {
            assert_eq!(
                CryptoTransformKind::from_be_bytes(k.to_be_bytes()).unwrap(),
                k
            );
        }
    }
    #[test]
    fn unknown_kind_rejected() {
        assert!(CryptoTransformKind::from_be_bytes([0, 0, 0, 99]).is_err());
    }
    #[test]
    fn key_sizes_match_spec() {
        assert_eq!(CryptoTransformKind::Aes128Gmac.key_size(), 16);
        assert_eq!(CryptoTransformKind::Aes128Gcm.key_size(), 16);
        assert_eq!(CryptoTransformKind::Aes256Gmac.key_size(), 32);
        assert_eq!(CryptoTransformKind::Aes256Gcm.key_size(), 32);
        assert_eq!(CryptoTransformKind::None.key_size(), 0);
    }
    #[test]
    fn tag_size_is_16_for_all_aead_variants() {
        for k in [
            CryptoTransformKind::Aes128Gmac,
            CryptoTransformKind::Aes128Gcm,
            CryptoTransformKind::Aes256Gmac,
            CryptoTransformKind::Aes256Gcm,
        ] {
            assert_eq!(k.tag_size(), 16);
        }
    }
    #[test]
    fn tag_size_is_zero_for_none() {
        assert_eq!(CryptoTransformKind::None.tag_size(), 0);
    }
    #[test]
    fn encrypts_only_for_gcm() {
        assert!(CryptoTransformKind::Aes128Gcm.encrypts());
        assert!(CryptoTransformKind::Aes256Gcm.encrypts());
        assert!(!CryptoTransformKind::Aes128Gmac.encrypts());
        assert!(!CryptoTransformKind::Aes256Gmac.encrypts());
        assert!(!CryptoTransformKind::None.encrypts());
    }
    #[test]
    fn transform_identifier_round_trip() {
        let id = CryptoTransformIdentifier::new(
            CryptoTransformKind::Aes256Gcm,
            [0xCA, 0xFE, 0xBA, 0xBE],
        );
        let bytes = id.to_bytes();
        assert_eq!(bytes.len(), 8);
        let back = CryptoTransformIdentifier::from_bytes(&bytes).unwrap();
        assert_eq!(back, id);
    }
    #[test]
    fn transform_identifier_wrong_length_rejected() {
        // The identifier parser requires exactly 8 bytes, not a prefix.
        assert!(CryptoTransformIdentifier::from_bytes(&[0; 7]).is_err());
        assert!(CryptoTransformIdentifier::from_bytes(&[0; 9]).is_err());
    }
    #[test]
    fn header_round_trip_with_full_iv() {
        let h = CryptoHeader {
            transformation_id: CryptoTransformIdentifier::new(
                CryptoTransformKind::Aes256Gcm,
                [1, 2, 3, 4],
            ),
            session_id: [10, 20, 30, 40],
            init_vector_suffix: [50, 60, 70, 80, 90, 100, 110, 120],
        };
        let bytes = h.to_bytes();
        assert_eq!(bytes.len(), CryptoHeader::WIRE_SIZE);
        let back = CryptoHeader::from_bytes(&bytes).unwrap();
        assert_eq!(back, h);
        assert_eq!(
            back.full_iv(),
            [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]
        );
    }
    #[test]
    fn header_short_buffer_rejected() {
        assert!(CryptoHeader::from_bytes(&[0; 10]).is_err());
    }
    #[test]
    fn header_unknown_kind_rejected() {
        // Kind field (bytes 0..4) set to 99, which maps to no variant.
        let mut bytes = [0u8; CryptoHeader::WIRE_SIZE];
        bytes[3] = 99;
        assert!(CryptoHeader::from_bytes(&bytes).is_err());
    }
    #[test]
    fn footer_round_trip_no_receivers() {
        let f = CryptoFooter {
            common_mac: [0xAA; 16],
            receiver_specific_macs: alloc::vec![],
        };
        let bytes = f.to_bytes();
        let back = CryptoFooter::from_bytes(&bytes).unwrap();
        assert_eq!(back, f);
    }
    #[test]
    fn footer_round_trip_with_receivers() {
        let f = CryptoFooter {
            common_mac: [0xAA; 16],
            receiver_specific_macs: alloc::vec![
                ([1, 2, 3, 4], [0xBB; 16]),
                ([5, 6, 7, 8], [0xCC; 16]),
            ],
        };
        let bytes = f.to_bytes();
        let back = CryptoFooter::from_bytes(&bytes).unwrap();
        assert_eq!(back, f);
    }
    #[test]
    fn footer_short_buffer_rejected() {
        assert!(CryptoFooter::from_bytes(&[0; 10]).is_err());
    }
    #[test]
    fn footer_truncated_receiver_mac_rejected() {
        let f = CryptoFooter {
            common_mac: [0xAA; 16],
            receiver_specific_macs: alloc::vec![([1, 2, 3, 4], [0xBB; 16])],
        };
        let bytes = f.to_bytes();
        assert!(CryptoFooter::from_bytes(&bytes[..bytes.len() - 1]).is_err());
    }
    #[test]
    fn footer_count_exceeding_payload_rejected() {
        let f = CryptoFooter {
            common_mac: [0xAA; 16],
            receiver_specific_macs: alloc::vec![([1, 2, 3, 4], [0xBB; 16])],
        };
        let mut bytes = f.to_bytes();
        // Claim two receiver entries while only one is present.
        bytes[16..20].copy_from_slice(&2u32.to_be_bytes());
        assert!(CryptoFooter::from_bytes(&bytes).is_err());
    }
    #[test]
    fn negotiate_picks_strongest_common() {
        let r = negotiate_transform(&[
            CryptoTransformKind::Aes128Gmac,
            CryptoTransformKind::Aes256Gcm,
            CryptoTransformKind::Aes128Gcm,
        ])
        .unwrap();
        assert_eq!(r, CryptoTransformKind::Aes256Gcm);
    }
    #[test]
    fn negotiate_prefers_gcm_over_larger_key_gmac() {
        let r = negotiate_transform(&[
            CryptoTransformKind::Aes256Gmac,
            CryptoTransformKind::Aes128Gcm,
        ])
        .unwrap();
        assert_eq!(r, CryptoTransformKind::Aes128Gcm);
    }
    #[test]
    fn negotiate_falls_back_to_gmac_if_no_gcm() {
        let r = negotiate_transform(&[CryptoTransformKind::Aes128Gmac]).unwrap();
        assert_eq!(r, CryptoTransformKind::Aes128Gmac);
    }
    #[test]
    fn negotiate_fails_with_no_overlap() {
        assert!(negotiate_transform(&[CryptoTransformKind::None]).is_err());
        assert!(negotiate_transform(&[]).is_err());
    }
    #[test]
    fn builtin_plugin_id_matches_spec() {
        assert_eq!(BUILTIN_CRYPTO_PLUGIN, "DDS:Crypto:AES_GCM_GMAC");
    }
}