use crate::algs;
use crate::common;
use crate::errors::{CoseError, CoseResult, CoseResultWithRet};
use cbor::{decoder::DecodeError, types::Type, Config, Decoder, Encoder};
use openssl::bn::BigNum;
use openssl::rsa::Rsa;
use std::io::Cursor;
use std::str::from_utf8;
/// Key types (OKP, EC2) grouped under the ECDH name; presumably the key types
/// accepted for ECDH key agreement — NOTE(review): confirm against callers.
pub(crate) const ECDH_KTY: [i32; 2] = [OKP, EC2];
// ---- EC2 / OKP / Symmetric key parameter labels ----
/// Label -4: private key `d`.
pub const D: i32 = -4;
/// Label -3: `y` coordinate.
pub const Y: i32 = -3;
/// Label -2: `x` coordinate.
pub const X: i32 = -2;
/// Label -1: curve identifier (`crv`) for EC2/OKP keys and key value (`k`) for
/// symmetric keys.
pub const CRV_K: i32 = -1;
// ---- Common key parameter labels ----
/// Label 1: key type.
pub const KTY: i32 = 1;
/// Label 2: key identifier.
pub const KID: i32 = 2;
/// Label 3: algorithm.
pub const ALG: i32 = 3;
/// Label 4: permitted key operations.
pub const KEY_OPS: i32 = 4;
/// Label 5: base IV.
pub const BASE_IV: i32 = 5;
// ---- RSA key parameter labels (RFC 8230) ----
// N, E, RSA_D and P have the same values as CRV_K, X, Y and D. A label alone
// therefore does not identify the field: code handling labels -1..-4 must also
// look at the key type.
/// Label -1: modulus `n`.
pub const N: i32 = -1;
/// Label -2: public exponent `e`.
pub const E: i32 = -2;
/// Label -3: private exponent `d`.
pub const RSA_D: i32 = -3;
/// Label -4: first prime factor `p`.
pub const P: i32 = -4;
/// Label -5: second prime factor `q`.
pub const Q: i32 = -5;
/// Label -6: CRT exponent `dP` = d mod (p - 1).
pub const DP: i32 = -6;
/// Label -7: CRT exponent `dQ` = d mod (q - 1).
pub const DQ: i32 = -7;
/// Label -8: CRT coefficient `qInv` = q^-1 mod p.
pub const QINV: i32 = -8;
/// Label -9: other prime infos (multi-prime keys).
pub const OTHER: i32 = -9;
/// Label -10: `r_i`, an additional prime factor.
pub const RI: i32 = -10;
/// Label -11: `d_i`, the CRT exponent for `r_i`.
pub const DI: i32 = -11;
/// Label -12: `t_i`, the CRT coefficient for `r_i`.
pub const TI: i32 = -12;
// ---- Key type (kty) values ----
/// Octet key pair.
pub const OKP: i32 = 1;
/// Elliptic curve key with x/y coordinates.
pub const EC2: i32 = 2;
/// RSA key.
pub const RSA: i32 = 3;
/// Symmetric key.
pub const SYMMETRIC: i32 = 4;
/// Reserved key type value.
pub const RESERVED: i32 = 0;
/// Every key type value; index-aligned with `KTY_NAMES`.
pub(crate) const KTY_ALL: [i32; 5] = [RESERVED, OKP, EC2, RSA, SYMMETRIC];
/// Text names of the key types, index-aligned with `KTY_ALL`.
pub(crate) const KTY_NAMES: [&str; 5] = ["Reserved", "OKP", "EC2", "RSA", "Symmetric"];
// ---- key_ops values ----
pub const KEY_OPS_SIGN: i32 = 1;
pub const KEY_OPS_VERIFY: i32 = 2;
pub const KEY_OPS_ENCRYPT: i32 = 3;
pub const KEY_OPS_DECRYPT: i32 = 4;
pub const KEY_OPS_WRAP: i32 = 5;
pub const KEY_OPS_UNWRAP: i32 = 6;
pub const KEY_OPS_DERIVE: i32 = 7;
pub const KEY_OPS_DERIVE_BITS: i32 = 8;
pub const KEY_OPS_MAC: i32 = 9;
pub const KEY_OPS_MAC_VERIFY: i32 = 10;
/// Every key_ops value; index-aligned with `KEY_OPS_NAMES`.
pub(crate) const KEY_OPS_ALL: [i32; 10] = [
KEY_OPS_SIGN,
KEY_OPS_VERIFY,
KEY_OPS_ENCRYPT,
KEY_OPS_DECRYPT,
KEY_OPS_WRAP,
KEY_OPS_UNWRAP,
KEY_OPS_DERIVE,
KEY_OPS_DERIVE_BITS,
KEY_OPS_MAC,
KEY_OPS_MAC_VERIFY,
];
/// Text names of the key operations, index-aligned with `KEY_OPS_ALL`.
pub(crate) const KEY_OPS_NAMES: [&str; 10] = [
"sign",
"verify",
"encrypt",
"decrypt",
"wrap key",
"unwrap key",
"derive key",
"derive bits",
"MAC create",
"MAC verify",
];
// ---- Curve (crv) identifiers ----
pub const P_256: i32 = 1;
/// secp256k1 (RFC 8812).
pub const SECP256K1: i32 = 8;
pub const P_384: i32 = 2;
pub const P_521: i32 = 3;
pub const X25519: i32 = 4;
pub const X448: i32 = 5;
pub const ED25519: i32 = 6;
pub const ED448: i32 = 7;
/// Every curve identifier. Index-aligned with `CURVES_NAMES`, not with the numeric
/// ids (SECP256K1 = 8 comes last).
pub(crate) const CURVES_ALL: [i32; 8] =
[P_256, P_384, P_521, X25519, X448, ED25519, ED448, SECP256K1];
/// NIST curves accepted for EC2 keys without an algorithm check; secp256k1 is
/// handled separately by `verify_curve`.
pub(crate) const EC2_CRVS: [i32; 3] = [P_256, P_384, P_521];
/// Text names of the curves, index-aligned with `CURVES_ALL`.
pub(crate) const CURVES_NAMES: [&str; 8] = [
"P-256",
"P-384",
"P-521",
"X25519",
"X448",
"Ed25519",
"Ed448",
"secp256k1",
];
/// A single COSE key: its parameters, the set of labels it carries, and its
/// CBOR encoding.
///
/// Parameters are normally set through the setter methods, which also record the
/// label in `used`. Only labels in `used` are written by `encode_key`.
#[derive(Clone)]
pub struct CoseKey {
    // --- serialization state ---
    /// CBOR encoding of the key: output of `encode`, input of `decode`.
    pub bytes: Vec<u8>,
    /// Labels present on this key, in the order they are encoded.
    used: Vec<i32>,

    // --- common parameters ---
    /// Key type (label 1).
    pub kty: Option<i32>,
    /// Key identifier (label 2).
    pub kid: Option<Vec<u8>>,
    /// Algorithm (label 3).
    pub alg: Option<i32>,
    /// Permitted key operations (label 4); empty when not set.
    pub key_ops: Vec<i32>,
    /// Base IV (label 5).
    pub base_iv: Option<Vec<u8>>,

    // --- EC2 / OKP / symmetric parameters ---
    /// Curve identifier (label -1).
    pub crv: Option<i32>,
    /// `x` coordinate (label -2).
    pub x: Option<Vec<u8>>,
    /// `y` coordinate (label -3). A boolean `y` in decoded input leaves this `None`.
    pub y: Option<Vec<u8>>,
    /// Private key `d` (label -4).
    pub d: Option<Vec<u8>>,
    /// Symmetric key value (label -1, shared with `crv`).
    pub k: Option<Vec<u8>>,

    // --- RSA parameters (labels -1..-12) ---
    /// Modulus.
    pub n: Option<Vec<u8>>,
    /// Public exponent.
    pub e: Option<Vec<u8>>,
    /// Private exponent.
    pub rsa_d: Option<Vec<u8>>,
    /// First prime factor.
    pub p: Option<Vec<u8>>,
    /// Second prime factor.
    pub q: Option<Vec<u8>>,
    /// CRT exponent for `p`.
    pub dp: Option<Vec<u8>>,
    /// CRT exponent for `q`.
    pub dq: Option<Vec<u8>>,
    /// CRT coefficient.
    pub qinv: Option<Vec<u8>>,
    /// Other prime infos, one byte string per entry.
    pub other: Option<Vec<Vec<u8>>>,
    /// Additional prime factor `r_i`.
    pub ri: Option<Vec<u8>>,
    /// CRT exponent `d_i`.
    pub di: Option<Vec<u8>>,
    /// CRT coefficient `t_i`.
    pub ti: Option<Vec<u8>>,
}
impl CoseKey {
pub fn new() -> CoseKey {
CoseKey {
bytes: Vec::new(),
used: Vec::new(),
key_ops: Vec::new(),
base_iv: None,
kty: None,
alg: None,
x: None,
y: None,
d: None,
k: None,
kid: None,
crv: None,
n: None,
e: None,
rsa_d: None,
p: None,
q: None,
dp: None,
dq: None,
qinv: None,
other: None,
ri: None,
di: None,
ti: None,
}
}
/// Marks `label` as present so `encode_key` writes it. A label that was already
/// registered is moved to the end of the encoding order.
fn reg_label(&mut self, label: i32) {
    self.remove_label(label);
    self.used.push(label);
}

/// Forgets every occurrence of `label`, so it is no longer encoded.
pub(crate) fn remove_label(&mut self, label: i32) {
    self.used.retain(|&registered| registered != label);
}
/// Sets the key type (label 1).
pub fn kty(&mut self, kty: i32) {
    self.kty = Some(kty);
    self.reg_label(KTY);
}

/// Clears the algorithm and stops emitting label 3.
pub fn unset_alg(&mut self) {
    self.alg = None;
    self.remove_label(ALG);
}

/// Sets the key identifier (label 2).
pub fn kid(&mut self, kid: Vec<u8>) {
    self.kid = Some(kid);
    self.reg_label(KID);
}

/// Sets the algorithm (label 3).
pub fn alg(&mut self, alg: i32) {
    self.alg = Some(alg);
    self.reg_label(ALG);
}

/// Replaces the list of permitted key operations (label 4).
pub fn key_ops(&mut self, key_ops: Vec<i32>) {
    self.key_ops = key_ops;
    self.reg_label(KEY_OPS);
}

/// Sets the base IV (label 5).
pub fn base_iv(&mut self, base_iv: Vec<u8>) {
    self.base_iv = Some(base_iv);
    self.reg_label(BASE_IV);
}

/// Sets the curve identifier (label -1, the same label `k` uses).
pub fn crv(&mut self, crv: i32) {
    self.crv = Some(crv);
    self.reg_label(CRV_K);
}
/// Sets the `x` coordinate (label -2).
pub fn x(&mut self, x: Vec<u8>) {
    self.x = Some(x);
    self.reg_label(X);
}

/// Sets the `y` coordinate (label -3).
pub fn y(&mut self, y: Vec<u8>) {
    self.y = Some(y);
    self.reg_label(Y);
}

/// Sets the private key `d` (label -4).
pub fn d(&mut self, d: Vec<u8>) {
    self.d = Some(d);
    self.reg_label(D);
}

/// Sets the symmetric key value. `k` uses label -1, the same label as `crv`,
/// so only one map entry is registered for the two.
pub fn k(&mut self, k: Vec<u8>) {
    self.k = Some(k);
    self.reg_label(CRV_K);
}
/// Sets the RSA modulus (label -1).
pub fn n(&mut self, n: Vec<u8>) {
    self.n = Some(n);
    self.reg_label(N);
}

/// Sets the RSA public exponent (label -2).
pub fn e(&mut self, e: Vec<u8>) {
    self.e = Some(e);
    self.reg_label(E);
}

/// Sets the RSA private exponent (label -3).
pub fn rsa_d(&mut self, rsa_d: Vec<u8>) {
    self.rsa_d = Some(rsa_d);
    self.reg_label(RSA_D);
}

/// Sets the first RSA prime factor (label -4).
pub fn p(&mut self, p: Vec<u8>) {
    self.p = Some(p);
    self.reg_label(P);
}

/// Sets the second RSA prime factor (label -5).
pub fn q(&mut self, q: Vec<u8>) {
    self.q = Some(q);
    self.reg_label(Q);
}

/// Sets the CRT exponent for `p` (label -6).
pub fn dp(&mut self, dp: Vec<u8>) {
    self.dp = Some(dp);
    self.reg_label(DP);
}

/// Sets the CRT exponent for `q` (label -7).
pub fn dq(&mut self, dq: Vec<u8>) {
    self.dq = Some(dq);
    self.reg_label(DQ);
}

/// Sets the CRT coefficient (label -8).
pub fn qinv(&mut self, qinv: Vec<u8>) {
    self.qinv = Some(qinv);
    self.reg_label(QINV);
}

/// Sets the other-prime-info entries (label -9).
pub fn other(&mut self, other: Vec<Vec<u8>>) {
    self.other = Some(other);
    self.reg_label(OTHER);
}

/// Sets the additional prime factor `r_i` (label -10).
pub fn ri(&mut self, ri: Vec<u8>) {
    self.ri = Some(ri);
    self.reg_label(RI);
}

/// Sets the CRT exponent `d_i` (label -11).
pub fn di(&mut self, di: Vec<u8>) {
    self.di = Some(di);
    self.reg_label(DI);
}

/// Sets the CRT coefficient `t_i` (label -12).
pub fn ti(&mut self, ti: Vec<u8>) {
    self.ti = Some(ti);
    self.reg_label(TI);
}
/// Checks that `crv` is a curve this module accepts for the key type.
///
/// Symmetric and RSA keys have no curve and always pass.
///
/// * OKP keys accept Ed25519, Ed448, X25519 and X448.
/// * EC2 keys accept P-256, P-384 and P-521.
/// * EC2 keys also accept secp256k1, but only together with the ES256K algorithm.
///
/// # Errors
///
/// * `MissingKTY` / `MissingCRV` when those are unset.
/// * `MissingAlg` for an EC2 secp256k1 key without an algorithm.
/// * `InvalidCRV` for every other combination.
pub(crate) fn verify_curve(&self) -> CoseResult {
    let kty = self.kty.ok_or(CoseError::MissingKTY())?;
    if kty == SYMMETRIC || kty == RSA {
        return Ok(());
    }
    let crv = self.crv.ok_or(CoseError::MissingCRV())?;
    let valid = match kty {
        OKP => [ED25519, ED448, X25519, X448].contains(&crv),
        // secp256k1 is an EC2 curve that is only allowed with ES256K. The
        // algorithm is consulted only here, so an invalid curve on a key without
        // `alg` is reported as InvalidCRV rather than MissingAlg.
        EC2 if crv == SECP256K1 => self.alg.ok_or(CoseError::MissingAlg())? == algs::ES256K,
        EC2 => EC2_CRVS.contains(&crv),
        _ => false,
    };
    if valid {
        Ok(())
    } else {
        Err(CoseError::InvalidCRV())
    }
}
/// Checks that `kty` is one of `KTY_ALL` (Reserved included), then delegates the
/// curve check to `verify_curve`.
///
/// # Errors
///
/// * `MissingKTY` if no key type is set.
/// * `InvalidKTY` for an unknown key type.
/// * Any error from `verify_curve`.
pub(crate) fn verify_kty(&self) -> CoseResult {
    let kty = self.kty.ok_or(CoseError::MissingKTY())?;
    if KTY_ALL.contains(&kty) {
        self.verify_curve()
    } else {
        Err(CoseError::InvalidKTY())
    }
}
/// Validates this key and serializes it into `self.bytes` as a CBOR map.
///
/// Validation always runs `verify_kty`, which rejects unknown key types and then
/// checks the curve. Previously the key-type check was skipped when no `alg` was
/// set, so unknown key types were reported as curve/alg errors instead.
///
/// # Errors
///
/// Validation errors from `verify_kty`, and any error from `encode_key`.
pub fn encode(&mut self) -> CoseResult {
    self.verify_kty()?;
    let mut e = Encoder::new(Vec::new());
    self.encode_key(&mut e)?;
    // into_writer hands back the owned buffer; no copy is needed.
    self.bytes = e.into_writer();
    Ok(())
}
pub(crate) fn encode_key(&self, e: &mut Encoder<Vec<u8>>) -> CoseResult {
let kty = self.kty.ok_or(CoseError::MissingKTY())?;
let key_ops_len = self.key_ops.len();
if key_ops_len > 0 {
if kty == EC2 || kty == OKP {
if self.key_ops.contains(&KEY_OPS_VERIFY)
|| self.key_ops.contains(&KEY_OPS_DERIVE)
|| self.key_ops.contains(&KEY_OPS_DERIVE_BITS)
{
if self.x == None {
return Err(CoseError::MissingX());
} else if self.crv == None {
return Err(CoseError::MissingCRV());
}
}
if self.key_ops.contains(&KEY_OPS_SIGN) {
if self.d == None {
return Err(CoseError::MissingD());
} else if self.crv == None {
return Err(CoseError::MissingCRV());
}
}
} else if kty == SYMMETRIC {
if self.key_ops.contains(&KEY_OPS_ENCRYPT)
|| self.key_ops.contains(&KEY_OPS_MAC_VERIFY)
|| self.key_ops.contains(&KEY_OPS_MAC)
|| self.key_ops.contains(&KEY_OPS_DECRYPT)
|| self.key_ops.contains(&KEY_OPS_UNWRAP)
|| self.key_ops.contains(&KEY_OPS_WRAP)
{
if self.x != None {
return Err(CoseError::InvalidX());
} else if self.y != None {
return Err(CoseError::InvalidY());
} else if self.d != None {
return Err(CoseError::InvalidD());
}
if self.k == None {
return Err(CoseError::MissingK());
}
}
}
}
e.object(self.used.len())?;
for i in &self.used {
e.i32(*i)?;
if *i == KTY {
e.i32(kty)?;
} else if *i == KEY_OPS {
e.array(self.key_ops.len())?;
for x in &self.key_ops {
e.i32(*x)?;
}
} else if *i == CRV_K {
if self.crv != None {
e.i32(self.crv.ok_or(CoseError::MissingCRV())?)?;
} else {
e.bytes(&self.k.as_ref().ok_or(CoseError::MissingK())?)?;
}
} else if *i == KID {
e.bytes(&self.kid.as_ref().ok_or(CoseError::MissingKID())?)?;
} else if *i == ALG {
e.i32(self.alg.ok_or(CoseError::MissingAlg())?)?
} else if *i == BASE_IV {
e.bytes(&self.base_iv.as_ref().ok_or(CoseError::MissingBaseIV())?)?
} else if *i == X {
e.bytes(&self.x.as_ref().ok_or(CoseError::MissingX())?)?
} else if *i == Y {
e.bytes(&self.y.as_ref().ok_or(CoseError::MissingY())?)?
} else if *i == D {
e.bytes(&self.d.as_ref().ok_or(CoseError::MissingD())?)?
} else if *i == N {
e.bytes(&self.n.as_ref().ok_or(CoseError::MissingN())?)?
} else if *i == E {
e.bytes(&self.e.as_ref().ok_or(CoseError::MissingE())?)?
} else if *i == RSA_D {
e.bytes(&self.rsa_d.as_ref().ok_or(CoseError::MissingRsaD())?)?
} else if *i == P {
e.bytes(&self.p.as_ref().ok_or(CoseError::MissingP())?)?
} else if *i == Q {
e.bytes(&self.q.as_ref().ok_or(CoseError::MissingQ())?)?
} else if *i == DP {
e.bytes(&self.dp.as_ref().ok_or(CoseError::MissingDP())?)?
} else if *i == DQ {
e.bytes(&self.dq.as_ref().ok_or(CoseError::MissingDQ())?)?
} else if *i == QINV {
e.bytes(&self.qinv.as_ref().ok_or(CoseError::MissingQINV())?)?
} else if *i == OTHER {
let other = self.other.as_ref().ok_or(CoseError::MissingOther())?;
e.array(other.len())?;
for i in other {
e.bytes(i)?
}
} else if *i == RI {
e.bytes(&self.ri.as_ref().ok_or(CoseError::MissingRI())?)?
} else if *i == DI {
e.bytes(&self.di.as_ref().ok_or(CoseError::MissingDI())?)?
} else if *i == TI {
e.bytes(&self.ti.as_ref().ok_or(CoseError::MissingTI())?)?
} else {
return Err(CoseError::InvalidLabel(*i));
}
}
Ok(())
}
/// Parses `self.bytes` as a single CBOR-encoded COSE key and validates it.
///
/// Validation always runs `verify_kty` (known key type, then the curve check).
/// This matches `encode`; previously the key-type check was skipped when no
/// `alg` was present.
///
/// # Errors
///
/// Any error from `decode_key`, and validation errors from `verify_kty`.
pub fn decode(&mut self) -> CoseResult {
    // decode_key needs an owned cursor, so the buffer is copied.
    let mut d = Decoder::new(Config::default(), Cursor::new(self.bytes.clone()));
    self.decode_key(&mut d)?;
    self.verify_kty()
}
pub(crate) fn decode_key(&mut self, d: &mut Decoder<Cursor<Vec<u8>>>) -> CoseResult {
let mut label: i32;
let mut labels_found = Vec::new();
self.used = Vec::new();
for _ in 0..d.object()? {
label = d.i32()?;
if !labels_found.contains(&label) {
labels_found.push(label);
} else {
return Err(CoseError::DuplicateLabel(label));
}
if label == KTY {
let type_info = d.kernel().typeinfo()?;
if type_info.0 == Type::Text {
self.kty = Some(common::get_kty_id(
from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
.unwrap()
.to_string(),
)?);
} else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
self.kty = Some(d.kernel().i32(&type_info)?);
} else {
return Err(CoseError::InvalidCoseStructure());
}
self.used.push(label);
} else if label == ALG {
let type_info = d.kernel().typeinfo()?;
if type_info.0 == Type::Text {
self.alg = Some(common::get_alg_id(
from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
.unwrap()
.to_string(),
)?);
} else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
self.alg = Some(d.kernel().i32(&type_info)?);
} else {
return Err(CoseError::InvalidCoseStructure());
}
self.used.push(label);
} else if label == KID {
self.kid = Some(d.bytes()?);
self.used.push(label);
} else if label == KEY_OPS {
let mut key_ops = Vec::new();
for _i in 0..d.array()? {
let type_info = d.kernel().typeinfo()?;
if type_info.0 == Type::Text {
key_ops.push(common::get_key_op_id(
from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
.unwrap()
.to_string(),
)?);
} else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
key_ops.push(d.kernel().i32(&type_info)?);
} else {
return Err(CoseError::InvalidCoseStructure());
}
}
self.key_ops = key_ops;
self.used.push(label);
} else if label == BASE_IV {
self.base_iv = Some(d.bytes()?);
self.used.push(label);
} else if label == CRV_K {
let type_info = d.kernel().typeinfo()?;
if type_info.0 == Type::Text {
self.crv = Some(common::get_crv_id(
from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
.unwrap()
.to_string(),
)?);
} else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
self.crv = Some(d.kernel().i32(&type_info)?);
} else if type_info.0 == Type::Bytes {
self.k = Some(d.kernel().raw_data(type_info.1, common::MAX_BYTES)?);
} else {
return Err(CoseError::InvalidCoseStructure());
}
self.used.push(label);
} else if label == X {
self.x = Some(d.bytes()?);
self.used.push(label);
} else if label == Y {
self.y = match d.bytes() {
Ok(value) => {
self.used.push(label);
Some(value)
}
Err(ref err) => match err {
DecodeError::UnexpectedType { datatype, info: _ } => {
if *datatype == Type::Bool {
None
} else {
return Err(CoseError::InvalidCoseStructure());
}
}
_ => {
return Err(CoseError::InvalidCoseStructure());
}
},
};
} else if label == D {
self.d = Some(d.bytes()?);
self.used.push(label);
} else if label == N {
self.n = Some(d.bytes()?);
self.used.push(label);
} else if label == E {
self.e = Some(d.bytes()?);
self.used.push(label);
} else if label == RSA_D {
self.rsa_d = Some(d.bytes()?);
self.used.push(label);
} else if label == P {
self.p = Some(d.bytes()?);
self.used.push(label);
} else if label == Q {
self.q = Some(d.bytes()?);
self.used.push(label);
} else if label == DP {
self.dp = Some(d.bytes()?);
self.used.push(label);
} else if label == DQ {
self.dq = Some(d.bytes()?);
self.used.push(label);
} else if label == QINV {
self.qinv = Some(d.bytes()?);
self.used.push(label);
} else if label == OTHER {
let mut other = Vec::new();
for _ in 0..d.array()? {
other.push(d.bytes()?);
}
self.other = Some(other);
self.used.push(label);
} else if label == RI {
self.ri = Some(d.bytes()?);
self.used.push(label);
} else if label == DI {
self.di = Some(d.bytes()?);
self.used.push(label);
} else if label == TI {
self.ti = Some(d.bytes()?);
self.used.push(label);
} else {
return Err(CoseError::InvalidLabel(label));
}
}
Ok(())
}
pub(crate) fn get_s_key(&self) -> CoseResultWithRet<Vec<u8>> {
let kty = self.kty.ok_or(CoseError::MissingKTY())?;
if kty == EC2 || kty == OKP {
let d = self.d.as_ref().ok_or(CoseError::MissingD())?.to_vec();
if d.len() <= 0 {
return Err(CoseError::MissingD());
}
Ok(d)
} else if kty == RSA {
Ok(Rsa::from_private_components(
BigNum::from_slice(self.n.as_ref().ok_or(CoseError::MissingN())?)?,
BigNum::from_slice(self.e.as_ref().ok_or(CoseError::MissingE())?)?,
BigNum::from_slice(self.rsa_d.as_ref().ok_or(CoseError::MissingRsaD())?)?,
BigNum::from_slice(self.p.as_ref().ok_or(CoseError::MissingP())?)?,
BigNum::from_slice(self.q.as_ref().ok_or(CoseError::MissingQ())?)?,
BigNum::from_slice(self.dp.as_ref().ok_or(CoseError::MissingDP())?)?,
BigNum::from_slice(self.dq.as_ref().ok_or(CoseError::MissingDQ())?)?,
BigNum::from_slice(self.qinv.as_ref().ok_or(CoseError::MissingQINV())?)?,
)?
.private_key_to_der()?)
} else if kty == SYMMETRIC {
let k = self.k.as_ref().ok_or(CoseError::MissingK())?.to_vec();
if k.len() <= 0 {
return Err(CoseError::MissingK());
}
Ok(k)
} else {
Err(CoseError::InvalidKTY())
}
}
pub(crate) fn get_pub_key(&self) -> CoseResultWithRet<Vec<u8>> {
let kty = self.kty.ok_or(CoseError::MissingKTY())?;
if kty == EC2 || kty == OKP {
let mut x = self.x.as_ref().ok_or(CoseError::MissingX())?.to_vec();
if x.len() <= 0 {
return Err(CoseError::MissingX());
}
let mut pub_key;
if kty == EC2 {
if self.y != None && self.y.as_ref().unwrap().len() > 0 {
let mut y = self.y.as_ref().unwrap().to_vec();
pub_key = vec![4];
pub_key.append(&mut x);
pub_key.append(&mut y);
} else {
pub_key = vec![3];
pub_key.append(&mut x);
}
} else {
pub_key = x;
}
Ok(pub_key)
} else if kty == RSA {
Ok(Rsa::from_public_components(
BigNum::from_slice(self.n.as_ref().ok_or(CoseError::MissingN())?)?,
BigNum::from_slice(self.e.as_ref().ok_or(CoseError::MissingE())?)?,
)?
.public_key_to_der()?)
} else {
Err(CoseError::InvalidKTY())
}
}
}
/// A set of COSE keys (COSE_KeySet), encoded as a CBOR array of key maps.
pub struct CoseKeySet {
    /// The keys in the set, in encoding order.
    pub cose_keys: Vec<CoseKey>,
    /// CBOR encoding of the set: written by `encode`, read by `decode`.
    pub bytes: Vec<u8>,
}
impl CoseKeySet {
    /// Creates an empty key set with no encoded bytes.
    pub fn new() -> CoseKeySet {
        CoseKeySet {
            cose_keys: Vec::new(),
            bytes: Vec::new(),
        }
    }

    /// Appends `key` to the set.
    pub fn add_key(&mut self, key: CoseKey) {
        self.cose_keys.push(key);
    }

    /// Encodes every key in the set into `self.bytes` as a CBOR array.
    ///
    /// Each key is written with `CoseKey::encode_key`. That checks `key_ops`
    /// consistency but not the kty/curve validation performed by `CoseKey::encode`.
    ///
    /// # Errors
    ///
    /// `MissingKey` if the set is empty; otherwise the first key-encoding error.
    pub fn encode(&mut self) -> CoseResult {
        if self.cose_keys.is_empty() {
            return Err(CoseError::MissingKey());
        }
        let mut e = Encoder::new(Vec::new());
        e.array(self.cose_keys.len())?;
        for key in &self.cose_keys {
            key.encode_key(&mut e)?;
        }
        self.bytes = e.into_writer();
        Ok(())
    }

    /// Decodes `self.bytes` as a CBOR array of keys. Each decoded key is appended
    /// to `cose_keys`; existing entries are kept.
    ///
    /// Decoding is best-effort but stops at the first key that fails. After a
    /// failure the decoder is positioned somewhere inside that key's map, so any
    /// later "key" would be read from misaligned data. Keys decoded before the
    /// failure are kept. Decoded keys are not validated with `verify_kty`.
    ///
    /// # Errors
    ///
    /// `MissingKey` for an empty array, or a CBOR error if the array header
    /// cannot be read.
    pub fn decode(&mut self) -> CoseResult {
        let mut d = Decoder::new(Config::default(), Cursor::new(self.bytes.clone()));
        let len = d.array()?;
        if len == 0 {
            return Err(CoseError::MissingKey());
        }
        for _ in 0..len {
            let mut cose_key = CoseKey::new();
            if cose_key.decode_key(&mut d).is_err() {
                break;
            }
            self.cose_keys.push(cose_key);
        }
        Ok(())
    }

    /// Returns a copy of the first key whose `kid` equals `kid`.
    ///
    /// Keys without a `kid` are skipped rather than aborting the search.
    ///
    /// # Errors
    ///
    /// `MissingKey` if no key in the set has a matching `kid`.
    pub fn get_key(&self, kid: &Vec<u8>) -> CoseResultWithRet<CoseKey> {
        self.cose_keys
            .iter()
            .find(|key| key.kid.as_ref() == Some(kid))
            .cloned()
            .ok_or(CoseError::MissingKey())
    }
}