use crate::Result;
use core::marker::PhantomData;
use super::aead::Aead;
use super::domain::DS_WRAP;
#[derive(Debug, Clone, Copy)]
pub struct WrapBinding<'a> {
pub credential_id: &'a [u8],
pub version: u16,
}
impl<'a> WrapBinding<'a> {
pub fn to_canonical_ad(&self) -> Vec<u8> {
let mut ad = Vec::with_capacity(DS_WRAP.len() + self.credential_id.len() + 2);
ad.extend_from_slice(DS_WRAP);
ad.extend_from_slice(self.credential_id);
ad.extend_from_slice(&self.version.to_be_bytes());
ad
}
}
/// Wraps and unwraps key material under a wrapping key `w`, tied to a
/// [`WrapBinding`].
///
/// Both operations are associated functions (no `self`), so implementors are
/// used through the type, e.g. `Impl::wrap(..)`.
pub trait KeyWrap {
    /// Wraps `key_material` under `w`, authenticating `binding` alongside it.
    ///
    /// # Errors
    /// Propagates any error reported by the implementation's primitive.
    fn wrap(w: &[u8], key_material: &[u8], binding: &WrapBinding<'_>) -> Result<Vec<u8>>;

    /// Recovers key material from a blob produced by [`KeyWrap::wrap`].
    ///
    /// Callers are expected to pass the same `w` and an identical `binding`
    /// that were used when wrapping.
    ///
    /// # Errors
    /// Returns an error when the implementation rejects the input, e.g. on a
    /// wrong key, a mismatched binding, or a modified `wrapped` blob.
    fn unwrap(w: &[u8], wrapped: &[u8], binding: &WrapBinding<'_>) -> Result<Vec<u8>>;
}

/// [`KeyWrap`] implementation that delegates to an [`Aead`] cipher `A`.
///
/// `wrap` calls `A::seal` and `unwrap` calls `A::open`. Each passes `w`, the
/// payload, and [`WrapBinding::to_canonical_ad`] as the associated data.
/// A mismatched binding is therefore expected to fail AEAD authentication.
///
/// This is a zero-sized marker type and is never instantiated for wrapping.
// `PhantomData<fn() -> A>` instead of `PhantomData<A>`: no `A` value is ever
// stored, so the marker should not inherit `A`'s auto traits (`Send`/`Sync`).
// The `Aead` bound lives on the impl that needs it, not on the type.
pub struct AeadWrap<A>(PhantomData<fn() -> A>);

impl<A: Aead> KeyWrap for AeadWrap<A> {
    fn wrap(w: &[u8], key_material: &[u8], binding: &WrapBinding<'_>) -> Result<Vec<u8>> {
        A::seal(w, key_material, &binding.to_canonical_ad())
    }

    fn unwrap(w: &[u8], wrapped: &[u8], binding: &WrapBinding<'_>) -> Result<Vec<u8>> {
        A::open(w, wrapped, &binding.to_canonical_ad())
    }
}
#[cfg(all(test, feature = "std-primitives"))]
mod tests {
    use super::*;
    use crate::primitives::aead::ChaCha20Poly1305;

    /// Concrete wrapper exercised by these tests.
    type Wrap = AeadWrap<ChaCha20Poly1305>;

    /// Builds a binding for credential `cid` at version `ver`.
    fn binding(cid: &[u8], ver: u16) -> WrapBinding<'_> {
        WrapBinding {
            credential_id: cid,
            version: ver,
        }
    }

    #[test]
    fn wrap_unwrap_roundtrip() {
        let w = [0x33u8; 32];
        let k = [0x77u8; 32];
        let b = binding(b"cred-1", 1);
        let wrapped = Wrap::wrap(&w, &k, &b).unwrap();
        let unwrapped = Wrap::unwrap(&w, &wrapped, &b).unwrap();
        assert_eq!(unwrapped.as_slice(), &k);
    }

    #[test]
    fn wrong_w_fails() {
        let w1 = [0x01u8; 32];
        let w2 = [0x02u8; 32];
        let k = [0x77u8; 32];
        let b = binding(b"cred-1", 1);
        let wrapped = Wrap::wrap(&w1, &k, &b).unwrap();
        assert!(Wrap::unwrap(&w2, &wrapped, &b).is_err());
    }

    #[test]
    fn wrong_cid_fails_unwrap() {
        let w = [0x33u8; 32];
        let k = [0x77u8; 32];
        let b1 = binding(b"cred-A", 1);
        let b2 = binding(b"cred-B", 1);
        let wrapped = Wrap::wrap(&w, &k, &b1).unwrap();
        assert!(Wrap::unwrap(&w, &wrapped, &b2).is_err());
    }

    #[test]
    fn wrong_ver_fails_unwrap() {
        let w = [0x33u8; 32];
        let k = [0x77u8; 32];
        let b1 = binding(b"cred-A", 1);
        let b2 = binding(b"cred-A", 2);
        let wrapped = Wrap::wrap(&w, &k, &b1).unwrap();
        assert!(Wrap::unwrap(&w, &wrapped, &b2).is_err());
    }

    /// Any modification of the wrapped blob must be rejected, not decrypted.
    #[test]
    fn tampered_ciphertext_fails_unwrap() {
        let w = [0x33u8; 32];
        let k = [0x77u8; 32];
        let b = binding(b"cred-1", 1);
        let mut wrapped = Wrap::wrap(&w, &k, &b).unwrap();
        assert!(!wrapped.is_empty());
        let last = wrapped.len() - 1;
        wrapped[last] ^= 0x01;
        assert!(Wrap::unwrap(&w, &wrapped, &b).is_err());
    }

    /// Pins the associated-data layout. A silent change here would make every
    /// previously wrapped key fail to unwrap.
    #[test]
    fn canonical_ad_layout() {
        let ad = binding(b"cred-1", 0x0102).to_canonical_ad();
        assert_eq!(ad.len(), DS_WRAP.len() + b"cred-1".len() + 2);
        assert!(ad.starts_with(DS_WRAP));
        assert_eq!(&ad[DS_WRAP.len()..ad.len() - 2], b"cred-1");
        assert_eq!(&ad[ad.len() - 2..], &[0x01u8, 0x02u8]);
    }
}