use crate::algs;
use crate::cose_struct;
use crate::errors::{CoseError, CoseField, CoseResult, CoseResultWithRet};
use crate::headers::{CoseHeader, COUNTER_SIG};
use crate::keys;
use cbor::{Decoder, Encoder};
use std::io::Cursor;
/// One participant in a COSE structure: a signer, a recipient, or a
/// counter-signer.
///
/// Its role is selected by the header algorithm and by `context`.
#[derive(Clone)]
pub struct CoseAgent {
/// Protected and unprotected header parameters for this agent.
pub header: CoseHeader,
/// Third element written by `encode`.
/// Holds the signature produced by `sign`, or the wrapped/encrypted key
/// produced by `derive_key` on the sender side.
pub payload: Vec<u8>,
/// Serialized protected header, written as the first element by `encode`.
pub(crate) ph_bstr: Vec<u8>,
/// Public key material loaded by `key`.
pub pub_key: Vec<u8>,
/// Secret/private key material loaded by `key`.
pub s_key: Vec<u8>,
/// Signature context string passed to the signing helpers.
/// Equals `cose_struct::COUNTER_SIGNATURE` for counter-signers.
pub(crate) context: String,
/// Curve identifier copied from the key by `key`.
pub(crate) crv: Option<i32>,
/// Permitted key operations copied from the key.
/// Empty means no restriction is enforced by `sign`/`verify`.
pub(crate) key_ops: Vec<i32>,
/// Base IV copied from the key.
/// Not read in this file — presumably consumed by the enclosing message; confirm.
pub(crate) base_iv: Option<Vec<u8>>,
/// Must be true before the counter-signature helpers run; otherwise they
/// report a missing payload.
/// NOTE(review): set outside this file — confirm when it becomes true.
pub(crate) enc: bool,
}
/// Key operations that make `CoseAgent::key` load the secret key.
/// Applies to key-distribution and content-encryption algorithms.
/// If a key's `key_ops` contains any of these, the secret key is required.
const KEY_OPS_SKEY: [i32; 8] = [
keys::KEY_OPS_DERIVE_BITS,
keys::KEY_OPS_DERIVE,
keys::KEY_OPS_DECRYPT,
keys::KEY_OPS_ENCRYPT,
keys::KEY_OPS_WRAP,
keys::KEY_OPS_UNWRAP,
keys::KEY_OPS_MAC_VERIFY,
keys::KEY_OPS_MAC,
];
/// Number of elements in the CBOR array written by `CoseAgent::encode`.
/// The elements are the protected header bstr, the unprotected header, and
/// the payload bstr.
const SIZE: usize = 3;
impl CoseAgent {
/// Builds a blank agent.
///
/// The agent has a fresh header, no key material, no payload, no curve and
/// no key-op restrictions. The signature context is empty and `enc` is
/// cleared.
pub fn new() -> CoseAgent {
    CoseAgent {
        header: CoseHeader::new(),
        payload: Vec::new(),
        ph_bstr: Vec::new(),
        pub_key: Vec::new(),
        s_key: Vec::new(),
        context: String::new(),
        crv: None,
        key_ops: Vec::new(),
        base_iv: None,
        enc: false,
    }
}
pub fn new_counter_sig() -> CoseAgent {
CoseAgent {
header: CoseHeader::new(),
payload: Vec::new(),
ph_bstr: Vec::new(),
pub_key: Vec::new(),
key_ops: Vec::new(),
s_key: Vec::new(),
crv: None,
base_iv: None,
context: cose_struct::COUNTER_SIGNATURE.to_string(),
enc: false,
}
}
/// Installs `header` as this agent's header, discarding the previous one.
pub fn add_header(&mut self, header: CoseHeader) {
    self.header = header;
}
/// Loads key material from `key` for the algorithm in this agent's header.
///
/// Validation:
/// * The header must carry an algorithm (`Missing(Alg)`), and
///   `key.verify_kty()` must pass.
/// * ECDH algorithms: the key type must be in `keys::ECDH_KTY`
///   (`Missing(Kty)` / `Invalid(Kty)`).
/// * For all algorithms except direct, AES key wrap and RSA-OAEP, a key that
///   declares an algorithm must declare this one (`AlgMismatch`).
///
/// Loading:
/// * Signing algorithms: the secret key is loaded when `key_ops` allows
///   signing, and the public key when it allows verifying. Any load error
///   propagates. With empty `key_ops`, both are loaded best-effort (empty on
///   failure).
/// * Key-distribution / encryption algorithms: the secret key is required
///   when `key_ops` lists any op in `KEY_OPS_SKEY`. With empty `key_ops` it
///   is loaded best-effort, and for ECDH/OAEP the public key is then
///   required.
///
/// The curve, base IV and key ops are always copied on success.
pub fn key(&mut self, key: &keys::CoseKey) -> CoseResult {
    let alg = self.header.alg.ok_or(CoseError::Missing(CoseField::Alg))?;
    key.verify_kty()?;

    // A key that names no algorithm is compatible with anything.
    let declared_alg_differs = matches!(key.alg, Some(key_alg) if key_alg != alg);

    if algs::ECDH_ALGS.contains(&alg) {
        let kty = key.kty.as_ref().ok_or(CoseError::Missing(CoseField::Kty))?;
        if !keys::ECDH_KTY.contains(kty) {
            return Err(CoseError::Invalid(CoseField::Kty));
        }
        if declared_alg_differs {
            return Err(CoseError::AlgMismatch());
        }
    } else {
        // These algorithms skip the key-alg consistency check.
        let skips_alg_check = alg == algs::DIRECT
            || algs::A_KW.contains(&alg)
            || algs::RSA_OAEP.contains(&alg);
        if !skips_alg_check && declared_alg_differs {
            return Err(CoseError::AlgMismatch());
        }
    }

    let ops = &key.key_ops;
    if algs::SIGNING_ALGS.contains(&alg) {
        if ops.is_empty() {
            self.s_key = key.get_s_key().unwrap_or_default();
            self.pub_key = key.get_pub_key().unwrap_or_default();
        } else {
            if ops.contains(&keys::KEY_OPS_SIGN) {
                self.s_key = key.get_s_key()?;
            }
            if ops.contains(&keys::KEY_OPS_VERIFY) {
                self.pub_key = key.get_pub_key()?;
            }
        }
    } else if algs::KEY_DISTRIBUTION_ALGS.contains(&alg) || algs::ENCRYPT_ALGS.contains(&alg) {
        if ops.is_empty() {
            self.s_key = key.get_s_key().unwrap_or_default();
            if algs::ECDH_ALGS.contains(&alg) || algs::OAEP_ALGS.contains(&alg) {
                self.pub_key = key.get_pub_key()?;
            }
        } else if KEY_OPS_SKEY.iter().any(|op| ops.contains(op)) {
            self.s_key = key.get_s_key()?;
        }
    }

    self.crv = key.crv;
    self.base_iv = key.base_iv.clone();
    self.key_ops = key.key_ops.clone();
    Ok(())
}
/// Signs `content` with this agent's secret key and stores the signature in
/// `self.payload`.
///
/// The protected header is re-serialized into `self.ph_bstr` first.
/// `body_protected` is the protected-header bytes supplied by the caller for
/// the enclosing structure.
///
/// # Errors
/// * `Invalid(KeyOp)` when `key_ops` is non-empty and does not allow signing.
/// * `Missing(Alg)` when the header has no algorithm. This is checked after
///   `ph_bstr` has been refreshed.
/// * Any header-serialization or signing error.
pub(crate) fn sign(
    &mut self,
    content: &Vec<u8>,
    external_aad: &Vec<u8>,
    body_protected: &Vec<u8>,
) -> CoseResult {
    let may_sign = self.key_ops.is_empty() || self.key_ops.contains(&keys::KEY_OPS_SIGN);
    if !may_sign {
        return Err(CoseError::Invalid(CoseField::KeyOp));
    }
    self.ph_bstr = self.header.get_protected_bstr(false)?;
    let alg = self.header.alg.ok_or(CoseError::Missing(CoseField::Alg))?;
    let signature = cose_struct::gen_sig(
        &self.s_key,
        &alg,
        &self.crv,
        external_aad,
        &self.context,
        body_protected,
        &self.ph_bstr,
        content,
    )?;
    self.payload = signature;
    Ok(())
}
/// Checks the signature held in `self.payload` against `content`.
///
/// Uses this agent's public key, curve, context and cached protected header
/// (`self.ph_bstr`).
///
/// Returns the boolean result of the verification helper.
///
/// # Errors
/// * `Invalid(KeyOp)` when `key_ops` is non-empty and does not allow
///   verification.
/// * `Missing(Alg)` when the header has no algorithm.
/// * Any error raised by the verification helper.
pub(crate) fn verify(
    &self,
    content: &Vec<u8>,
    external_aad: &Vec<u8>,
    body_protected: &Vec<u8>,
) -> CoseResultWithRet<bool> {
    let may_verify = self.key_ops.is_empty() || self.key_ops.contains(&keys::KEY_OPS_VERIFY);
    if !may_verify {
        return Err(CoseError::Invalid(CoseField::KeyOp));
    }
    let alg = self.header.alg.ok_or(CoseError::Missing(CoseField::Alg))?;
    let is_valid = cose_struct::verify_sig(
        &self.pub_key,
        &alg,
        &self.crv,
        external_aad,
        &self.context,
        body_protected,
        &self.ph_bstr,
        content,
        &self.payload,
    )?;
    Ok(is_valid)
}
/// Stores an externally produced signature as this agent's payload.
///
/// Only counter-signature agents accept this. Any other context yields
/// `InvalidContext` carrying the current context string.
pub fn add_signature(&mut self, signature: Vec<u8>) -> CoseResult {
    if self.context == cose_struct::COUNTER_SIGNATURE {
        self.payload = signature;
        Ok(())
    } else {
        Err(CoseError::InvalidContext(self.context.clone()))
    }
}
/// Builds the to-be-signed bytes for this counter-signature agent without
/// signing them.
///
/// Refreshes `self.ph_bstr` from the header, then delegates to
/// `cose_struct::get_to_sign` with the counter-signature context.
///
/// # Errors
/// `InvalidContext` if this agent is not a counter-signer, plus any
/// serialization error.
pub(crate) fn get_sign_content(
    &mut self,
    content: &Vec<u8>,
    external_aad: &Vec<u8>,
    body_protected: &Vec<u8>,
) -> CoseResultWithRet<Vec<u8>> {
    if self.context == cose_struct::COUNTER_SIGNATURE {
        self.ph_bstr = self.header.get_protected_bstr(false)?;
        cose_struct::get_to_sign(
            external_aad,
            cose_struct::COUNTER_SIGNATURE,
            body_protected,
            &self.ph_bstr,
            content,
        )
    } else {
        Err(CoseError::InvalidContext(self.context.clone()))
    }
}
/// Has `counter` sign this agent's payload.
///
/// This agent's `ph_bstr` is passed as the body-protected bytes, and a
/// missing `external_aad` is treated as empty.
///
/// # Errors
/// `Missing(Payload)` while `self.enc` is false, or any error from
/// `counter.sign`.
pub fn counter_sig(
    &self,
    external_aad: Option<Vec<u8>>,
    counter: &mut CoseAgent,
) -> CoseResult {
    if !self.enc {
        return Err(CoseError::Missing(CoseField::Payload));
    }
    let aad = external_aad.unwrap_or_default();
    counter.sign(&self.payload, &aad, &self.ph_bstr)
}
/// Returns the bytes `counter` would sign over this agent's payload, for
/// signing outside this library.
///
/// The resulting signature is presumably attached afterwards via
/// `add_signature` — confirm against callers. A missing `external_aad` is
/// treated as empty.
///
/// # Errors
/// `Missing(Payload)` while `self.enc` is false, or any error from
/// `counter.get_sign_content` (e.g. a non-counter-signature context).
pub fn get_to_sign(
    &self,
    external_aad: Option<Vec<u8>>,
    counter: &mut CoseAgent,
) -> CoseResultWithRet<Vec<u8>> {
    if !self.enc {
        return Err(CoseError::Missing(CoseField::Payload));
    }
    let aad = external_aad.unwrap_or_default();
    counter.get_sign_content(&self.payload, &aad, &self.ph_bstr)
}
/// Returns the to-be-signed bytes for one of this agent's attached counter
/// signatures, so the caller can verify it outside this library.
///
/// `counter` indexes `self.header.counters`. A missing `external_aad` is
/// treated as empty.
///
/// # Errors
/// * `Missing(Payload)` while `self.enc` is false.
/// * `Missing(CounterSignature)` when `counter` is not a valid index.
/// * Any error from the counter agent's `get_sign_content`.
pub fn get_to_verify(
    &mut self,
    external_aad: Option<Vec<u8>>,
    counter: &usize,
) -> CoseResultWithRet<Vec<u8>> {
    if !self.enc {
        return Err(CoseError::Missing(CoseField::Payload));
    }
    let aad = external_aad.unwrap_or_default();
    // The index is caller-supplied: report a bad index as an error instead of
    // panicking on an out-of-bounds slice access.
    let counter_agent = self
        .header
        .counters
        .get_mut(*counter)
        .ok_or(CoseError::Missing(CoseField::CounterSignature))?;
    counter_agent.get_sign_content(&self.payload, &aad, &self.ph_bstr)
}
/// Verifies one of this agent's attached counter signatures over this
/// agent's payload.
///
/// `counter` indexes `self.header.counters`. A missing `external_aad` is
/// treated as empty.
///
/// # Errors
/// * `Missing(Payload)` while `self.enc` is false.
/// * `Missing(CounterSignature)` when `counter` is not a valid index.
/// * `Invalid(CounterSignature)` when the signature does not verify.
/// * Any error from the counter agent's `verify`.
pub fn counters_verify(&mut self, external_aad: Option<Vec<u8>>, counter: usize) -> CoseResult {
    if !self.enc {
        return Err(CoseError::Missing(CoseField::Payload));
    }
    let aad = external_aad.unwrap_or_default();
    // The index is caller-supplied: report a bad index as an error instead of
    // panicking on an out-of-bounds slice access.
    let counter_agent = self
        .header
        .counters
        .get(counter)
        .ok_or(CoseError::Missing(CoseField::CounterSignature))?;
    if counter_agent.verify(&self.payload, &aad, &self.ph_bstr)? {
        Ok(())
    } else {
        Err(CoseError::Invalid(CoseField::CounterSignature))
    }
}
/// Attaches `counter` as a counter signature of this agent.
///
/// The counter must use a signing algorithm and the counter-signature
/// context. The first time a counter is attached, the `COUNTER_SIG` label is
/// removed from wherever `remove_label` finds it and listed as unprotected.
///
/// # Errors
/// `Missing(Alg)` / `Invalid(Alg)` for a missing or non-signing algorithm,
/// and `InvalidContext` for a counter built with the wrong context.
pub fn add_counter_sig(&mut self, counter: CoseAgent) -> CoseResult {
    let counter_alg = counter
        .header
        .alg
        .ok_or(CoseError::Missing(CoseField::Alg))?;
    if !algs::SIGNING_ALGS.contains(&counter_alg) {
        return Err(CoseError::Invalid(CoseField::Alg));
    }
    if counter.context != cose_struct::COUNTER_SIGNATURE {
        return Err(CoseError::InvalidContext(counter.context));
    }
    let label_already_unprotected = self.header.unprotected.contains(&COUNTER_SIG);
    self.header.counters.push(counter);
    if !label_already_unprotected {
        self.header.remove_label(COUNTER_SIG);
        self.header.unprotected.push(COUNTER_SIG);
    }
    Ok(())
}
/// Derives, wraps, or unwraps key material according to the algorithm in
/// this agent's header.
///
/// * `cek` — on the sender side, the content key to protect. On the receiver
///   side, the bytes to unwrap/decrypt. Used only by the AES-KW, RSA-OAEP
///   and ECDH+AES-KW branches.
/// * `size` — output key length in bytes handed to the wrap/unwrap/HKDF
///   helpers. The direct-HKDF and ECDH+HKDF branches also put `size * 8`
///   into the KDF context.
/// * `sender` — true when producing a message. The wrapped/encrypted key is
///   then stored in `self.payload`.
/// * `true_alg` — algorithm identifier placed in the KDF context by the
///   direct-HKDF and ECDH+HKDF branches. ECDH+AES-KW uses the matching
///   AES-KW identifier instead.
///
/// Returns:
/// * the HKDF output for the derivation branches;
/// * the unwrapped/decrypted key on the receiver side of the wrapping
///   branches;
/// * `cek` unchanged on their sender side.
///
/// # Errors
/// * `Missing(Alg)` if the header has no algorithm.
/// * `Invalid(Alg)` for an algorithm outside these families.
/// * `MissingKey()` when the ECDH key needed on this side is empty.
/// * Any error from the crypto/KDF helpers or from `verify_chain`.
///
/// NOTE(review): the ECDH branch calls `unwrap()` on `self.crv` and on
/// `self.header.ecdh_key.crv`. A key without a curve panics here instead of
/// returning an error.
pub(crate) fn derive_key(
&mut self,
cek: &Vec<u8>,
size: usize,
sender: bool,
true_alg: &i32,
) -> CoseResultWithRet<Vec<u8>> {
// Serialize the protected header once; it is fed into the KDF context below.
if self.ph_bstr.is_empty() {
self.ph_bstr = self.header.get_protected_bstr(false)?;
}
let alg = self
.header
.alg
.as_ref()
.ok_or(CoseError::Missing(CoseField::Alg))?;
// AES key wrap directly with the secret key.
if algs::A_KW.contains(alg) {
if sender {
self.payload = algs::aes_key_wrap(&self.s_key, size, &cek)?;
} else {
return Ok(algs::aes_key_unwrap(&self.s_key, size, &cek)?);
}
return Ok(cek.to_vec());
// RSA-OAEP: encrypt with the public key, decrypt with the secret key.
} else if algs::RSA_OAEP.contains(alg) {
if sender {
self.payload = algs::rsa_oaep_enc(&self.pub_key, &cek, alg)?;
} else {
return Ok(algs::rsa_oaep_dec(&self.s_key, size, &cek, alg)?);
}
return Ok(cek.to_vec());
// Direct HKDF: the secret key is the HKDF input keying material.
} else if algs::D_HA.contains(alg) || algs::D_HS.contains(alg) {
let mut kdf_context = cose_struct::gen_kdf(
true_alg,
&self.header.party_u_identity,
&self.header.party_u_nonce,
&self.header.party_u_other,
&self.header.party_v_identity,
&self.header.party_v_nonce,
&self.header.party_v_other,
size as u16 * 8,
&self.ph_bstr,
&self.header.pub_other,
&self.header.priv_info,
)?;
return Ok(algs::hkdf(
size,
&self.s_key,
self.header.salt.as_ref(),
&mut kdf_context,
self.header.alg.unwrap(),
)?);
// ECDH: compute a shared secret, then either HKDF it directly (ECDH_H) or
// derive a KEK and AES-key-wrap the CEK (ECDH_A).
} else if algs::ECDH_H.contains(alg) || algs::ECDH_A.contains(alg) {
let (receiver_key, sender_key, crv_rec, crv_send);
if sender {
// Sender side: the peer's public key comes from `self.pub_key`; our
// key comes from the header (`x5_private` if present, else `ecdh_key`).
if self.pub_key.is_empty() {
return Err(CoseError::MissingKey());
}
receiver_key = self.pub_key.clone();
if !self.header.x5_private.is_empty() {
sender_key = self.header.x5_private.clone();
crv_send = None;
} else {
sender_key = self.header.ecdh_key.get_s_key()?;
crv_send = Some(self.header.ecdh_key.crv.unwrap());
}
crv_rec = Some(self.crv.unwrap());
} else {
// Receiver side: the peer's public key comes from a verified
// `x5chain_sender` (first certificate) or from the header `ecdh_key`.
// Our key is `self.s_key`.
// NOTE(review): the peer key is assumed to use `self.crv` — confirm.
if self.s_key.is_empty() {
return Err(CoseError::MissingKey());
}
if self.header.x5chain_sender.is_some() {
algs::verify_chain(self.header.x5chain_sender.as_ref().unwrap())?;
receiver_key = self.header.x5chain_sender.as_ref().unwrap()[0].clone();
crv_rec = None;
} else {
receiver_key = self.header.ecdh_key.get_pub_key()?;
crv_rec = Some(self.crv.unwrap());
}
sender_key = self.s_key.clone();
crv_send = Some(self.crv.unwrap());
}
let shared = algs::ecdh_derive_key(crv_rec, crv_send, &receiver_key, &sender_key)?;
if algs::ECDH_H.contains(alg) {
let mut kdf_context = cose_struct::gen_kdf(
true_alg,
&self.header.party_u_identity,
&self.header.party_u_nonce,
&self.header.party_u_other,
&self.header.party_v_identity,
&self.header.party_v_nonce,
&self.header.party_v_other,
size as u16 * 8,
&self.ph_bstr,
&self.header.pub_other,
&self.header.priv_info,
)?;
return Ok(algs::hkdf(
size,
&shared,
self.header.salt.as_ref(),
&mut kdf_context,
self.header.alg.unwrap(),
)?);
} else {
// Derive a KEK sized for this algorithm. The KDF context carries the
// matching AES-KW identifier, not `true_alg`.
let size_akw = algs::get_cek_size(&alg)?;
let alg_akw;
if [algs::ECDH_ES_A128KW, algs::ECDH_SS_A128KW].contains(alg) {
alg_akw = algs::A128KW;
} else if [algs::ECDH_ES_A192KW, algs::ECDH_SS_A192KW].contains(alg) {
alg_akw = algs::A192KW;
} else {
alg_akw = algs::A256KW;
}
let mut kdf_context = cose_struct::gen_kdf(
&alg_akw,
&self.header.party_u_identity,
&self.header.party_u_nonce,
&self.header.party_u_other,
&self.header.party_v_identity,
&self.header.party_v_nonce,
&self.header.party_v_other,
size_akw as u16 * 8,
&self.ph_bstr,
&self.header.pub_other,
&self.header.priv_info,
)?;
let kek = algs::hkdf(
size_akw,
&shared,
self.header.salt.as_ref(),
&mut kdf_context,
self.header.alg.unwrap(),
)?;
if sender {
self.payload = algs::aes_key_wrap(&kek, size, &cek)?;
} else {
return Ok(algs::aes_key_unwrap(&kek, size, &cek)?);
}
return Ok(cek.to_vec());
}
} else {
return Err(CoseError::Invalid(CoseField::Alg));
}
}
/// Populates this agent from CBOR.
///
/// If `ph_bstr` is already set (presumably by the caller), it is decoded into
/// the protected header first. The unprotected header and the payload bstr
/// are then read from `d`. The header's `labels_found` list is reset
/// afterwards.
pub(crate) fn decode(&mut self, d: &mut Decoder<Cursor<Vec<u8>>>) -> CoseResult {
    let is_counter_sig = self.context == cose_struct::COUNTER_SIGNATURE;
    let has_protected = !self.ph_bstr.is_empty();
    if has_protected {
        self.header.decode_protected_bstr(&self.ph_bstr)?;
    }
    self.header.decode_unprotected(d, is_counter_sig)?;
    let body = d.bytes()?;
    self.payload = body;
    self.header.labels_found = Vec::new();
    Ok(())
}
/// Writes this agent as a three-element CBOR array: the protected header
/// bstr, the unprotected header, and the payload bstr.
///
/// The header's `labels_found` list is reset afterwards.
pub(crate) fn encode(&mut self, e: &mut Encoder<Vec<u8>>) -> CoseResult {
    e.array(SIZE)?;
    // Element 1: serialized protected header.
    e.bytes(&self.ph_bstr)?;
    // Element 2: unprotected header map.
    self.header.encode_unprotected(e)?;
    // Element 3: signature / wrapped key.
    e.bytes(&self.payload)?;
    self.header.labels_found = Vec::new();
    Ok(())
}
}