use core::convert::identity;
use core::fmt::{Debug, Display};
use ciborium::value::Value;
use coset::iana::Algorithm;
use coset::{iana, CoseKey, CoseKeyBuilder, Header, Label, ProtectedHeader};
use rand::{CryptoRng, Error, RngCore};
#[cfg(not(feature = "std"))]
use {
alloc::string::{String, ToString},
alloc::vec,
alloc::vec::Vec,
};
use crate::common::cbor_map::ToCborMap;
use crate::error::{AccessTokenError, CoseCipherError, MultipleCoseError};
use crate::token::{CoseCipher, MultipleEncryptCipher, MultipleSignCipher};
use crate::{CoseEncryptCipher, CoseMacCipher, CoseSignCipher};
/// Returns a copy of the raw symmetric key bytes (the `k` parameter) stored in `key`.
///
/// Only the first parameter labelled `k` is considered.
///
/// # Panics
/// Panics with "Key value must be present!" if `key` has no `k` parameter, or if the first
/// such parameter is not a CBOR byte string.
fn get_symmetric_key_value(key: &CoseKey) -> Vec<u8> {
    const K_LABEL: i64 = iana::SymmetricKeyParameter::K as i64;
    let first_k_param = key
        .params
        .iter()
        .find(|(label, _)| matches!(label, Label::Int(K_LABEL)));
    let key_bytes = match first_k_param {
        Some((_, Value::Bytes(bytes))) => Some(bytes),
        _ => None,
    };
    key_bytes.expect("Key value must be present!").to_vec()
}
/// Round-trip test helper for CBOR-map serializable types.
///
/// The helper works in three steps:
/// 1. It serializes `value` and asserts that the bytes equal the hex-decoded `expected_hex`.
/// 2. It deserializes those bytes back into a `T`.
/// 3. It asserts that the result equals the original `value`.
///
/// If `transform_value` is given, it is applied to the decoded value before the final
/// comparison. Otherwise the decoded value is compared unchanged.
///
/// # Errors
/// Returns the stringified error in three cases:
/// - serialization fails;
/// - `expected_hex` is not valid hex;
/// - deserialization fails.
///
/// # Panics
/// Panics (via `assert_eq!`) if the serialized bytes or the decoded value do not match.
pub(crate) fn expect_ser_de<T>(
    value: T,
    transform_value: Option<fn(T) -> T>,
    expected_hex: &str,
) -> Result<(), String>
where
    T: ToCborMap + Clone + Debug + PartialEq,
{
    // Keep a copy for the final comparison, since serialization consumes `value`.
    let copy = value.clone();
    let mut result = Vec::new();
    value
        .serialize_into(&mut result)
        .map_err(|x| x.to_string())?;
    #[cfg(feature = "std")]
    println!("Result: {:?}, Original: {:?}", hex::encode(&result), &copy);
    assert_eq!(
        &result,
        &hex::decode(expected_hex).map_err(|x| x.to_string())?
    );
    let decoded_value = T::deserialize_from(result.as_slice()).map_err(|x| x.to_string())?;
    let decoded_value = transform_value.unwrap_or(identity)(decoded_value);
    assert_eq!(copy, decoded_value);
    Ok(())
}
/// Test-only cipher that performs no real cryptography.
///
/// Its encrypt, sign and MAC operations simply combine the raw symmetric key bytes with the
/// input data, so expected outputs in tests are easy to predict. Never use this outside of
/// tests.
#[derive(Copy, Clone)]
pub(crate) struct FakeCrypto {}

impl CoseCipher for FakeCrypto {
    type Error = String;

    /// Fills in headers so tests can observe that the cipher was invoked.
    ///
    /// The following changes are made:
    /// - label `47` is added to the unprotected header with a `null` value;
    /// - the protected header's `alg` is set to `Direct`;
    /// - the key's ID is copied into the protected header's `kid`.
    ///
    /// # Errors
    /// Fails without modifying either header in these cases:
    /// - the unprotected header already contains label `47`;
    /// - the protected header already has an `alg`;
    /// - the protected header already has a non-empty `kid`.
    fn set_headers<RNG: RngCore + CryptoRng>(
        key: &CoseKey,
        unprotected_header: &mut Header,
        protected_header: &mut Header,
        _rng: RNG,
    ) -> Result<(), CoseCipherError<Self::Error>> {
        // Arbitrary label used purely as a marker that this cipher touched the header.
        const MARKER_LABEL: i64 = 47;
        if let Some((label, _)) = unprotected_header
            .rest
            .iter()
            .find(|(label, _)| *label == Label::Int(MARKER_LABEL))
        {
            return Err(CoseCipherError::existing_header_label(label));
        }
        if protected_header.alg.is_some() {
            return Err(CoseCipherError::existing_header("alg"));
        }
        if !protected_header.key_id.is_empty() {
            return Err(CoseCipherError::existing_header("kid"));
        }
        unprotected_header
            .rest
            .push((Label::Int(MARKER_LABEL), Value::Null));
        protected_header.alg = Some(coset::Algorithm::Assigned(Algorithm::Direct));
        protected_header.key_id = key.key_id.clone();
        Ok(())
    }
}
impl CoseEncryptCipher for FakeCrypto {
    /// "Encrypts" by returning `key_value || aad || plaintext`; no actual encryption happens.
    fn encrypt(
        key: &CoseKey,
        plaintext: &[u8],
        aad: &[u8],
        _protected_header: &Header,
        _unprotected_header: &Header,
    ) -> Vec<u8> {
        let mut result: Vec<u8> = get_symmetric_key_value(key);
        result.reserve(aad.len() + plaintext.len());
        result.extend_from_slice(aad);
        result.extend_from_slice(plaintext);
        result
    }

    /// Reverses `encrypt`.
    ///
    /// It checks that `ciphertext` starts with the key value followed by `aad`, and returns
    /// the remaining bytes as the plaintext.
    ///
    /// # Errors
    /// - `DecryptionFailure` if the key ID differs from the protected header's `kid`.
    /// - `DecryptionFailure` if the key-value or AAD prefix does not match.
    /// - `Other` if `ciphertext` is shorter than the key value plus `aad`.
    fn decrypt(
        key: &CoseKey,
        ciphertext: &[u8],
        aad: &[u8],
        _unprotected_header: &Header,
        protected_header: &ProtectedHeader,
    ) -> Result<Vec<u8>, CoseCipherError<Self::Error>> {
        if key.key_id != protected_header.header.key_id {
            return Err(CoseCipherError::DecryptionFailure);
        }
        let key_value = get_symmetric_key_value(key);
        let prefix_len = key_value.len() + aad.len();
        if ciphertext.len() < prefix_len {
            return Err(CoseCipherError::Other(
                "Encrypted data has invalid length!".to_string(),
            ));
        }
        // Split the borrowed input instead of copying it and splitting the copy.
        let (prefix, plaintext) = ciphertext.split_at(prefix_len);
        let (key_part, aad_part) = prefix.split_at(key_value.len());
        if aad_part == aad && key_part == key_value.as_slice() {
            Ok(plaintext.to_vec())
        } else {
            Err(CoseCipherError::DecryptionFailure)
        }
    }
}
impl CoseSignCipher for FakeCrypto {
    /// Produces a fake signature: the signed bytes followed by the raw symmetric key value.
    fn sign(
        key: &CoseKey,
        target: &[u8],
        _unprotected_header: &Header,
        _protected_header: &Header,
    ) -> Vec<u8> {
        let key_value = get_symmetric_key_value(key);
        [target, key_value.as_slice()].concat()
    }

    /// Verifies a fake signature by recomputing it with `sign`.
    ///
    /// The `kid` is taken from `protected_signature_header` when present, otherwise from
    /// `protected_header`. That `kid` must equal the key's ID.
    ///
    /// # Errors
    /// Returns `VerificationFailure` if the `kid` or the recomputed signature does not match.
    fn verify(
        key: &CoseKey,
        signature: &[u8],
        signed_data: &[u8],
        unprotected_header: &Header,
        protected_header: &ProtectedHeader,
        _unprotected_signature_header: Option<&Header>,
        protected_signature_header: Option<&ProtectedHeader>,
    ) -> Result<(), CoseCipherError<Self::Error>> {
        let kid_source = protected_signature_header.unwrap_or(protected_header);
        let kid_matches = kid_source.header.key_id == key.key_id;
        // The expected signature is always recomputed, even if the kid already mismatched.
        let expected = Self::sign(
            key,
            signed_data,
            unprotected_header,
            &protected_header.header,
        );
        if !kid_matches || expected != signature {
            return Err(CoseCipherError::VerificationFailure);
        }
        Ok(())
    }
}
impl CoseMacCipher for FakeCrypto {
    /// Produces a fake MAC tag: the MAC'd bytes followed by the raw symmetric key value.
    fn compute(
        key: &CoseKey,
        target: &[u8],
        _unprotected_header: &Header,
        _protected_header: &Header,
    ) -> Vec<u8> {
        let key_value = get_symmetric_key_value(key);
        target.iter().chain(key_value.iter()).copied().collect()
    }

    /// Verifies a fake MAC tag.
    ///
    /// It first checks that the protected header's `kid` equals the key's ID. Only if it
    /// does is the expected tag recomputed with `compute` and compared against `tag`.
    ///
    /// # Errors
    /// Returns `VerificationFailure` if the `kid` or the tag does not match.
    fn verify(
        key: &CoseKey,
        tag: &[u8],
        maced_data: &[u8],
        unprotected_header: &Header,
        protected_header: &ProtectedHeader,
    ) -> Result<(), CoseCipherError<Self::Error>> {
        if protected_header.header.key_id != key.key_id {
            return Err(CoseCipherError::VerificationFailure);
        }
        let expected_tag = Self::compute(
            key,
            maced_data,
            unprotected_header,
            &protected_header.header,
        );
        if tag == expected_tag {
            Ok(())
        } else {
            Err(CoseCipherError::VerificationFailure)
        }
    }
}
impl MultipleEncryptCipher for FakeCrypto {
    /// Generates a content-encryption key from `rng`.
    ///
    /// The key has a 5-byte symmetric key value and a 2-byte key ID. The key value bytes are
    /// drawn first, then the key ID bytes.
    fn generate_cek<RNG: RngCore + CryptoRng>(rng: &mut RNG) -> CoseKey {
        let mut key_bytes = vec![0u8; 5];
        rng.fill_bytes(&mut key_bytes);
        let mut kid_bytes = vec![0u8; 2];
        rng.fill_bytes(&mut kid_bytes);
        CoseKeyBuilder::new_symmetric_key(key_bytes)
            .key_id(kid_bytes)
            .build()
    }
}

/// Multi-signer support for `FakeCrypto`; relies entirely on the trait's default methods.
impl MultipleSignCipher for FakeCrypto {}
/// Deterministic "random" number generator for tests: every value it produces is zero.
///
/// It implements `CryptoRng` only so it can be passed to APIs that require one. It provides
/// no randomness at all.
#[derive(Clone, Copy)]
pub(crate) struct FakeRng;

impl RngCore for FakeRng {
    fn next_u32(&mut self) -> u32 {
        u32::MIN
    }

    fn next_u64(&mut self) -> u64 {
        u64::MIN
    }

    fn fill_bytes(&mut self, dest: &mut [u8]) {
        for byte in dest.iter_mut() {
            *byte = 0;
        }
    }

    /// Zero-fills `dest`; never fails.
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
        self.fill_bytes(dest);
        Ok(())
    }
}

impl CryptoRng for FakeRng {}
impl<C, K> From<CoseCipherError<MultipleCoseError<C, K>>> for CoseCipherError<String>
where
C: Display,
K: Display,
{
fn from(x: CoseCipherError<MultipleCoseError<C, K>>) -> Self {
match x {
CoseCipherError::HeaderAlreadySet {
existing_header_name,
} => CoseCipherError::HeaderAlreadySet {
existing_header_name,
},
CoseCipherError::VerificationFailure => CoseCipherError::VerificationFailure,
CoseCipherError::DecryptionFailure => CoseCipherError::DecryptionFailure,
CoseCipherError::Other(x) => CoseCipherError::Other(x.to_string()),
}
}
}
/// Converts an access-token error from multi-recipient operations into the `String`-based
/// error type.
///
/// A nested cipher error goes through the matching `CoseCipherError` conversion. All other
/// variants are carried over unchanged.
impl<C, K> From<AccessTokenError<MultipleCoseError<C, K>>> for AccessTokenError<String>
where
    C: Display,
    K: Display,
{
    fn from(x: AccessTokenError<MultipleCoseError<C, K>>) -> Self {
        match x {
            AccessTokenError::CoseCipherError(cipher_error) => {
                Self::CoseCipherError(CoseCipherError::from(cipher_error))
            }
            AccessTokenError::CoseError(cose_error) => Self::CoseError(cose_error),
            AccessTokenError::MultipleMatchingRecipients => Self::MultipleMatchingRecipients,
            AccessTokenError::NoMatchingRecipient => Self::NoMatchingRecipient,
            AccessTokenError::UnknownCoseStructure => Self::UnknownCoseStructure,
        }
    }
}