use crate::algs;
use crate::common;
use crate::errors::{CoseError, CoseResult, CoseResultWithRet};
use cbor::{decoder::DecodeError, types::Type, Config, Decoder, Encoder};
use openssl::bn::BigNum;
use openssl::rsa::Rsa;
use std::io::Cursor;
use std::str::from_utf8;
// ----- Key types (`kty` values) -------------------------------------------

/// Reserved key-type value.
pub const RESERVED: i32 = 0;
/// Octet key pair key type.
pub const OKP: i32 = 1;
/// Two-coordinate elliptic-curve key type.
pub const EC2: i32 = 2;
/// RSA key type.
pub const RSA: i32 = 3;
/// Symmetric key type.
pub const SYMMETRIC: i32 = 4;
/// Every recognised key type; index-aligned with [`KTY_NAMES`].
pub(crate) const KTY_ALL: [i32; 5] = [RESERVED, OKP, EC2, RSA, SYMMETRIC];
/// Display names of the key types; index-aligned with [`KTY_ALL`].
pub(crate) const KTY_NAMES: [&str; 5] = ["Reserved", "OKP", "EC2", "RSA", "Symmetric"];
/// Key types listed for ECDH use.
pub(crate) const ECDH_KTY: [i32; 2] = [OKP, EC2];

// ----- Labels common to every key type --------------------------------------

/// Label of the key-type parameter.
pub const KTY: i32 = 1;
/// Label of the key-identifier parameter.
pub const KID: i32 = 2;
/// Label of the algorithm parameter.
pub const ALG: i32 = 3;
/// Label of the key-operations parameter.
pub const KEY_OPS: i32 = 4;
/// Label of the base-IV parameter.
pub const BASE_IV: i32 = 5;

// ----- EC2 / OKP / symmetric labels ---------------------------------------

/// Label shared by the curve (`crv`) and the symmetric key value (`k`).
pub const CRV_K: i32 = -1;
/// Label of the `x` coordinate.
pub const X: i32 = -2;
/// Label of the `y` coordinate.
pub const Y: i32 = -3;
/// Label of the private value `d`.
pub const D: i32 = -4;

// ----- RSA labels ------------------------------------------------------------
// N, E, RSA_D and P reuse the numeric values of CRV_K, X, Y and D, so the
// meaning of labels -1..=-4 depends on the key type.

/// Label of the RSA modulus.
pub const N: i32 = -1;
/// Label of the RSA public exponent.
pub const E: i32 = -2;
/// Label of the RSA private exponent.
pub const RSA_D: i32 = -3;
/// Label of the first RSA prime.
pub const P: i32 = -4;
/// Label of the second RSA prime.
pub const Q: i32 = -5;
/// Label of the first CRT exponent.
pub const DP: i32 = -6;
/// Label of the second CRT exponent.
pub const DQ: i32 = -7;
/// Label of the CRT coefficient.
pub const QINV: i32 = -8;
/// Label of the additional-primes array.
pub const OTHER: i32 = -9;
/// Label of an additional prime `r_i`.
pub const RI: i32 = -10;
/// Label of an additional CRT exponent `d_i`.
pub const DI: i32 = -11;
/// Label of an additional CRT coefficient `t_i`.
pub const TI: i32 = -12;

// ----- Key operations (`key_ops` values) ----------------------------------

pub const KEY_OPS_SIGN: i32 = 1;
pub const KEY_OPS_VERIFY: i32 = 2;
pub const KEY_OPS_ENCRYPT: i32 = 3;
pub const KEY_OPS_DECRYPT: i32 = 4;
pub const KEY_OPS_WRAP: i32 = 5;
pub const KEY_OPS_UNWRAP: i32 = 6;
pub const KEY_OPS_DERIVE: i32 = 7;
pub const KEY_OPS_DERIVE_BITS: i32 = 8;
pub const KEY_OPS_MAC: i32 = 9;
pub const KEY_OPS_MAC_VERIFY: i32 = 10;
/// Every key operation; index-aligned with [`KEY_OPS_NAMES`].
pub(crate) const KEY_OPS_ALL: [i32; 10] = [
    KEY_OPS_SIGN, KEY_OPS_VERIFY,
    KEY_OPS_ENCRYPT, KEY_OPS_DECRYPT,
    KEY_OPS_WRAP, KEY_OPS_UNWRAP,
    KEY_OPS_DERIVE, KEY_OPS_DERIVE_BITS,
    KEY_OPS_MAC, KEY_OPS_MAC_VERIFY,
];
/// Display names of the key operations; index-aligned with [`KEY_OPS_ALL`].
pub(crate) const KEY_OPS_NAMES: [&str; 10] = [
    "sign", "verify",
    "encrypt", "decrypt",
    "wrap key", "unwrap key",
    "derive key", "derive bits",
    "MAC create", "MAC verify",
];

// ----- Curves (`crv` values) ----------------------------------------------

pub const P_256: i32 = 1;
pub const P_384: i32 = 2;
pub const P_521: i32 = 3;
pub const X25519: i32 = 4;
pub const X448: i32 = 5;
pub const ED25519: i32 = 6;
pub const ED448: i32 = 7;
pub const SECP256K1: i32 = 8;
/// Every recognised curve; index-aligned with [`CURVES_NAMES`].
pub(crate) const CURVES_ALL: [i32; 8] = [
    P_256, P_384, P_521, X25519, X448, ED25519, ED448, SECP256K1,
];
/// Curves accepted for EC2 keys without any algorithm restriction.
pub(crate) const EC2_CRVS: [i32; 3] = [P_256, P_384, P_521];
/// Display names of the curves; index-aligned with [`CURVES_ALL`].
pub(crate) const CURVES_NAMES: [&str; 8] = [
    "P-256", "P-384", "P-521", "X25519", "X448", "Ed25519", "Ed448", "secp256k1",
];
/// A COSE_Key: key type, algorithm, permitted operations and the
/// key-type-specific parameters, plus its CBOR encoding.
///
/// Only parameters whose labels are registered in `used` (through the setter
/// methods or by decoding) are serialized; assigning a public field directly
/// does not register its label.
#[derive(Clone, Default)]
pub struct CoseKey {
    /// CBOR encoding of the key (written by `encode`, read by `decode`).
    pub bytes: Vec<u8>,
    /// Registered labels, in serialization order.
    used: Vec<i32>,
    /// Key type (`OKP`, `EC2`, `RSA`, `SYMMETRIC`, ...).
    pub kty: Option<i32>,
    /// Base IV parameter.
    pub base_iv: Option<Vec<u8>>,
    /// Permitted key operations (`KEY_OPS_*` values).
    pub key_ops: Vec<i32>,
    /// Algorithm identifier.
    pub alg: Option<i32>,
    /// EC2/OKP `x` coordinate (OKP keys use only `x`).
    pub x: Option<Vec<u8>>,
    /// EC2 `y` coordinate.
    pub y: Option<Vec<u8>>,
    /// EC2/OKP private value.
    pub d: Option<Vec<u8>>,
    /// Symmetric key value; shares label [`CRV_K`] with `crv`.
    pub k: Option<Vec<u8>>,
    /// Key identifier.
    pub kid: Option<Vec<u8>>,
    /// Curve identifier (EC2/OKP).
    pub crv: Option<i32>,
    /// RSA modulus.
    pub n: Option<Vec<u8>>,
    /// RSA public exponent.
    pub e: Option<Vec<u8>>,
    /// RSA private exponent.
    pub rsa_d: Option<Vec<u8>>,
    /// First RSA prime.
    pub p: Option<Vec<u8>>,
    /// Second RSA prime.
    pub q: Option<Vec<u8>>,
    /// First CRT exponent.
    pub dp: Option<Vec<u8>>,
    /// Second CRT exponent.
    pub dq: Option<Vec<u8>>,
    /// CRT coefficient.
    pub qinv: Option<Vec<u8>>,
    /// Additional-primes entries, each kept as raw bytes.
    pub other: Option<Vec<Vec<u8>>>,
    /// Additional prime `r_i`.
    pub ri: Option<Vec<u8>>,
    /// Additional CRT exponent `d_i`.
    pub di: Option<Vec<u8>>,
    /// Additional CRT coefficient `t_i`.
    pub ti: Option<Vec<u8>>,
}
impl CoseKey {
        pub fn new() -> CoseKey {
        CoseKey {
            bytes: Vec::new(),
            used: Vec::new(),
            key_ops: Vec::new(),
            base_iv: None,
            kty: None,
            alg: None,
            x: None,
            y: None,
            d: None,
            k: None,
            kid: None,
            crv: None,
            n: None,
            e: None,
            rsa_d: None,
            p: None,
            q: None,
            dp: None,
            dq: None,
            qinv: None,
            other: None,
            ri: None,
            di: None,
            ti: None,
        }
    }
    fn reg_label(&mut self, label: i32) {
        self.used.retain(|&x| x != label);
        self.used.push(label);
    }
    pub(crate) fn remove_label(&mut self, label: i32) {
        self.used.retain(|&x| x != label);
    }
        pub fn kty(&mut self, kty: i32) {
        self.reg_label(KTY);
        self.kty = Some(kty);
    }
        pub fn unset_alg(&mut self) {
        self.remove_label(ALG);
        self.alg = None;
    }
        pub fn kid(&mut self, kid: Vec<u8>) {
        self.reg_label(KID);
        self.kid = Some(kid);
    }
        pub fn alg(&mut self, alg: i32) {
        self.reg_label(ALG);
        self.alg = Some(alg);
    }
        pub fn key_ops(&mut self, key_ops: Vec<i32>) {
        self.reg_label(KEY_OPS);
        self.key_ops = key_ops;
    }
        pub fn base_iv(&mut self, base_iv: Vec<u8>) {
        self.reg_label(BASE_IV);
        self.base_iv = Some(base_iv);
    }
        pub fn crv(&mut self, crv: i32) {
        self.reg_label(CRV_K);
        self.crv = Some(crv);
    }
        pub fn x(&mut self, x: Vec<u8>) {
        self.reg_label(X);
        self.x = Some(x);
    }
        pub fn y(&mut self, y: Vec<u8>) {
        self.reg_label(Y);
        self.y = Some(y);
    }
        pub fn d(&mut self, d: Vec<u8>) {
        self.reg_label(D);
        self.d = Some(d);
    }
        pub fn k(&mut self, k: Vec<u8>) {
        self.reg_label(CRV_K);
        self.k = Some(k);
    }
    pub fn n(&mut self, n: Vec<u8>) {
        self.reg_label(N);
        self.n = Some(n);
    }
    pub fn e(&mut self, e: Vec<u8>) {
        self.reg_label(E);
        self.e = Some(e);
    }
    pub fn rsa_d(&mut self, rsa_d: Vec<u8>) {
        self.reg_label(RSA_D);
        self.rsa_d = Some(rsa_d);
    }
    pub fn p(&mut self, p: Vec<u8>) {
        self.reg_label(P);
        self.p = Some(p);
    }
    pub fn q(&mut self, q: Vec<u8>) {
        self.reg_label(Q);
        self.q = Some(q);
    }
    pub fn dp(&mut self, dp: Vec<u8>) {
        self.reg_label(DP);
        self.dp = Some(dp);
    }
    pub fn dq(&mut self, dq: Vec<u8>) {
        self.reg_label(DQ);
        self.dq = Some(dq);
    }
    pub fn qinv(&mut self, qinv: Vec<u8>) {
        self.reg_label(QINV);
        self.qinv = Some(qinv);
    }
    pub fn other(&mut self, other: Vec<Vec<u8>>) {
        self.reg_label(OTHER);
        self.other = Some(other);
    }
    pub fn ri(&mut self, ri: Vec<u8>) {
        self.reg_label(RI);
        self.ri = Some(ri);
    }
    pub fn di(&mut self, di: Vec<u8>) {
        self.reg_label(DI);
        self.di = Some(di);
    }
    pub fn ti(&mut self, ti: Vec<u8>) {
        self.reg_label(TI);
        self.ti = Some(ti);
    }
    pub(crate) fn verify_curve(&self) -> CoseResult {
        let kty = self.kty.ok_or(CoseError::MissingKTY())?;
        if kty == SYMMETRIC || kty == RSA {
            return Ok(());
        }
        let crv = self.crv.ok_or(CoseError::MissingCRV())?;
        if kty == OKP && [ED25519, ED448, X25519, X448].contains(&crv) {
            Ok(())
        } else if kty == EC2 && EC2_CRVS.contains(&crv) {
            Ok(())
        } else if self.alg.ok_or(CoseError::MissingAlg())? == algs::ES256K && crv == SECP256K1 {
            Ok(())
        } else {
            Err(CoseError::InvalidCRV())
        }
    }
    pub(crate) fn verify_kty(&self) -> CoseResult {
        if !KTY_ALL.contains(&self.kty.ok_or(CoseError::MissingKTY())?) {
            return Err(CoseError::InvalidKTY());
        }
        self.verify_curve()?;
        Ok(())
    }
        pub fn encode(&mut self) -> CoseResult {
        let mut e = Encoder::new(Vec::new());
        if self.alg != None {
            self.verify_kty()?;
        } else {
            self.verify_curve()?;
        }
        self.encode_key(&mut e)?;
        self.bytes = e.into_writer().to_vec();
        Ok(())
    }
    pub(crate) fn encode_key(&self, e: &mut Encoder<Vec<u8>>) -> CoseResult {
        let kty = self.kty.ok_or(CoseError::MissingKTY())?;
        let key_ops_len = self.key_ops.len();
        if key_ops_len > 0 {
            if kty == EC2 || kty == OKP {
                if self.key_ops.contains(&KEY_OPS_VERIFY)
                    || self.key_ops.contains(&KEY_OPS_DERIVE)
                    || self.key_ops.contains(&KEY_OPS_DERIVE_BITS)
                {
                    if self.x == None {
                        return Err(CoseError::MissingX());
                    } else if self.crv == None {
                        return Err(CoseError::MissingCRV());
                    }
                }
                if self.key_ops.contains(&KEY_OPS_SIGN) {
                    if self.d == None {
                        return Err(CoseError::MissingD());
                    } else if self.crv == None {
                        return Err(CoseError::MissingCRV());
                    }
                }
            } else if kty == SYMMETRIC {
                if self.key_ops.contains(&KEY_OPS_ENCRYPT)
                    || self.key_ops.contains(&KEY_OPS_MAC_VERIFY)
                    || self.key_ops.contains(&KEY_OPS_MAC)
                    || self.key_ops.contains(&KEY_OPS_DECRYPT)
                    || self.key_ops.contains(&KEY_OPS_UNWRAP)
                    || self.key_ops.contains(&KEY_OPS_WRAP)
                {
                    if self.x != None {
                        return Err(CoseError::InvalidX());
                    } else if self.y != None {
                        return Err(CoseError::InvalidY());
                    } else if self.d != None {
                        return Err(CoseError::InvalidD());
                    }
                    if self.k == None {
                        return Err(CoseError::MissingK());
                    }
                }
            }
        }
        e.object(self.used.len())?;
        for i in &self.used {
            e.i32(*i)?;
            if *i == KTY {
                e.i32(kty)?;
            } else if *i == KEY_OPS {
                e.array(self.key_ops.len())?;
                for x in &self.key_ops {
                    e.i32(*x)?;
                }
            } else if *i == CRV_K {
                if self.crv != None {
                    e.i32(self.crv.ok_or(CoseError::MissingCRV())?)?;
                } else {
                    e.bytes(&self.k.as_ref().ok_or(CoseError::MissingK())?)?;
                }
            } else if *i == KID {
                e.bytes(&self.kid.as_ref().ok_or(CoseError::MissingKID())?)?;
            } else if *i == ALG {
                e.i32(self.alg.ok_or(CoseError::MissingAlg())?)?
            } else if *i == BASE_IV {
                e.bytes(&self.base_iv.as_ref().ok_or(CoseError::MissingBaseIV())?)?
            } else if *i == X {
                e.bytes(&self.x.as_ref().ok_or(CoseError::MissingX())?)?
            } else if *i == Y {
                e.bytes(&self.y.as_ref().ok_or(CoseError::MissingY())?)?
            } else if *i == D {
                e.bytes(&self.d.as_ref().ok_or(CoseError::MissingD())?)?
            } else if *i == N {
                e.bytes(&self.n.as_ref().ok_or(CoseError::MissingN())?)?
            } else if *i == E {
                e.bytes(&self.e.as_ref().ok_or(CoseError::MissingE())?)?
            } else if *i == RSA_D {
                e.bytes(&self.rsa_d.as_ref().ok_or(CoseError::MissingRsaD())?)?
            } else if *i == P {
                e.bytes(&self.p.as_ref().ok_or(CoseError::MissingP())?)?
            } else if *i == Q {
                e.bytes(&self.q.as_ref().ok_or(CoseError::MissingQ())?)?
            } else if *i == DP {
                e.bytes(&self.dp.as_ref().ok_or(CoseError::MissingDP())?)?
            } else if *i == DQ {
                e.bytes(&self.dq.as_ref().ok_or(CoseError::MissingDQ())?)?
            } else if *i == QINV {
                e.bytes(&self.qinv.as_ref().ok_or(CoseError::MissingQINV())?)?
            } else if *i == OTHER {
                let other = self.other.as_ref().ok_or(CoseError::MissingOther())?;
                e.array(other.len())?;
                for i in other {
                    e.bytes(i)?
                }
            } else if *i == RI {
                e.bytes(&self.ri.as_ref().ok_or(CoseError::MissingRI())?)?
            } else if *i == DI {
                e.bytes(&self.di.as_ref().ok_or(CoseError::MissingDI())?)?
            } else if *i == TI {
                e.bytes(&self.ti.as_ref().ok_or(CoseError::MissingTI())?)?
            } else {
                return Err(CoseError::InvalidLabel(*i));
            }
        }
        Ok(())
    }
        pub fn decode(&mut self) -> CoseResult {
        let input = Cursor::new(self.bytes.clone());
        let mut d = Decoder::new(Config::default(), input);
        self.decode_key(&mut d)?;
        if self.alg != None {
            self.verify_kty()?;
        } else {
            self.verify_curve()?;
        }
        Ok(())
    }
    pub(crate) fn decode_key(&mut self, d: &mut Decoder<Cursor<Vec<u8>>>) -> CoseResult {
        let mut label: i32;
        let mut labels_found = Vec::new();
        self.used = Vec::new();
        for _ in 0..d.object()? {
            label = d.i32()?;
            if !labels_found.contains(&label) {
                labels_found.push(label);
            } else {
                return Err(CoseError::DuplicateLabel(label));
            }
            if label == KTY {
                let type_info = d.kernel().typeinfo()?;
                if type_info.0 == Type::Text {
                    self.kty = Some(common::get_kty_id(
                        from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
                            .unwrap()
                            .to_string(),
                    )?);
                } else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
                    self.kty = Some(d.kernel().i32(&type_info)?);
                } else {
                    return Err(CoseError::InvalidCoseStructure());
                }
                self.used.push(label);
            } else if label == ALG {
                let type_info = d.kernel().typeinfo()?;
                if type_info.0 == Type::Text {
                    self.alg = Some(common::get_alg_id(
                        from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
                            .unwrap()
                            .to_string(),
                    )?);
                } else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
                    self.alg = Some(d.kernel().i32(&type_info)?);
                } else {
                    return Err(CoseError::InvalidCoseStructure());
                }
                self.used.push(label);
            } else if label == KID {
                self.kid = Some(d.bytes()?);
                self.used.push(label);
            } else if label == KEY_OPS {
                let mut key_ops = Vec::new();
                for _i in 0..d.array()? {
                    let type_info = d.kernel().typeinfo()?;
                    if type_info.0 == Type::Text {
                        key_ops.push(common::get_key_op_id(
                            from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
                                .unwrap()
                                .to_string(),
                        )?);
                    } else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
                        key_ops.push(d.kernel().i32(&type_info)?);
                    } else {
                        return Err(CoseError::InvalidCoseStructure());
                    }
                }
                self.key_ops = key_ops;
                self.used.push(label);
            } else if label == BASE_IV {
                self.base_iv = Some(d.bytes()?);
                self.used.push(label);
            } else if label == CRV_K {
                let type_info = d.kernel().typeinfo()?;
                if type_info.0 == Type::Text {
                    self.crv = Some(common::get_crv_id(
                        from_utf8(&d.kernel().raw_data(type_info.1, common::MAX_BYTES)?)
                            .unwrap()
                            .to_string(),
                    )?);
                } else if common::CBOR_NUMBER_TYPES.contains(&type_info.0) {
                    self.crv = Some(d.kernel().i32(&type_info)?);
                } else if type_info.0 == Type::Bytes {
                    self.k = Some(d.kernel().raw_data(type_info.1, common::MAX_BYTES)?);
                } else {
                    return Err(CoseError::InvalidCoseStructure());
                }
                self.used.push(label);
            } else if label == X {
                self.x = Some(d.bytes()?);
                self.used.push(label);
            } else if label == Y {
                self.y = match d.bytes() {
                    Ok(value) => {
                        self.used.push(label);
                        Some(value)
                    }
                    Err(ref err) => match err {
                        DecodeError::UnexpectedType { datatype, info: _ } => {
                            if *datatype == Type::Bool {
                                None
                            } else {
                                return Err(CoseError::InvalidCoseStructure());
                            }
                        }
                        _ => {
                            return Err(CoseError::InvalidCoseStructure());
                        }
                    },
                };
            } else if label == D {
                self.d = Some(d.bytes()?);
                self.used.push(label);
            } else if label == N {
                self.n = Some(d.bytes()?);
                self.used.push(label);
            } else if label == E {
                self.e = Some(d.bytes()?);
                self.used.push(label);
            } else if label == RSA_D {
                self.rsa_d = Some(d.bytes()?);
                self.used.push(label);
            } else if label == P {
                self.p = Some(d.bytes()?);
                self.used.push(label);
            } else if label == Q {
                self.q = Some(d.bytes()?);
                self.used.push(label);
            } else if label == DP {
                self.dp = Some(d.bytes()?);
                self.used.push(label);
            } else if label == DQ {
                self.dq = Some(d.bytes()?);
                self.used.push(label);
            } else if label == QINV {
                self.qinv = Some(d.bytes()?);
                self.used.push(label);
            } else if label == OTHER {
                let mut other = Vec::new();
                for _ in 0..d.array()? {
                    other.push(d.bytes()?);
                }
                self.other = Some(other);
                self.used.push(label);
            } else if label == RI {
                self.ri = Some(d.bytes()?);
                self.used.push(label);
            } else if label == DI {
                self.di = Some(d.bytes()?);
                self.used.push(label);
            } else if label == TI {
                self.ti = Some(d.bytes()?);
                self.used.push(label);
            } else {
                return Err(CoseError::InvalidLabel(label));
            }
        }
        Ok(())
    }
    pub(crate) fn get_s_key(&self) -> CoseResultWithRet<Vec<u8>> {
        let kty = self.kty.ok_or(CoseError::MissingKTY())?;
        if kty == EC2 || kty == OKP {
            let d = self.d.as_ref().ok_or(CoseError::MissingD())?.to_vec();
            if d.len() <= 0 {
                return Err(CoseError::MissingD());
            }
            Ok(d)
        } else if kty == RSA {
            Ok(Rsa::from_private_components(
                BigNum::from_slice(self.n.as_ref().ok_or(CoseError::MissingN())?)?,
                BigNum::from_slice(self.e.as_ref().ok_or(CoseError::MissingE())?)?,
                BigNum::from_slice(self.rsa_d.as_ref().ok_or(CoseError::MissingRsaD())?)?,
                BigNum::from_slice(self.p.as_ref().ok_or(CoseError::MissingP())?)?,
                BigNum::from_slice(self.q.as_ref().ok_or(CoseError::MissingQ())?)?,
                BigNum::from_slice(self.dp.as_ref().ok_or(CoseError::MissingDP())?)?,
                BigNum::from_slice(self.dq.as_ref().ok_or(CoseError::MissingDQ())?)?,
                BigNum::from_slice(self.qinv.as_ref().ok_or(CoseError::MissingQINV())?)?,
            )?
            .private_key_to_der()?)
        } else if kty == SYMMETRIC {
            let k = self.k.as_ref().ok_or(CoseError::MissingK())?.to_vec();
            if k.len() <= 0 {
                return Err(CoseError::MissingK());
            }
            Ok(k)
        } else {
            Err(CoseError::InvalidKTY())
        }
    }
    pub(crate) fn get_pub_key(&self) -> CoseResultWithRet<Vec<u8>> {
        let kty = self.kty.ok_or(CoseError::MissingKTY())?;
        if kty == EC2 || kty == OKP {
            let mut x = self.x.as_ref().ok_or(CoseError::MissingX())?.to_vec();
            if x.len() <= 0 {
                return Err(CoseError::MissingX());
            }
            let mut pub_key;
            if kty == EC2 {
                if self.y != None && self.y.as_ref().unwrap().len() > 0 {
                    let mut y = self.y.as_ref().unwrap().to_vec();
                    pub_key = vec![4];
                    pub_key.append(&mut x);
                    pub_key.append(&mut y);
                } else {
                    pub_key = vec![3];
                    pub_key.append(&mut x);
                }
            } else {
                pub_key = x;
            }
            Ok(pub_key)
        } else if kty == RSA {
            Ok(Rsa::from_public_components(
                BigNum::from_slice(self.n.as_ref().ok_or(CoseError::MissingN())?)?,
                BigNum::from_slice(self.e.as_ref().ok_or(CoseError::MissingE())?)?,
            )?
            .public_key_to_der()?)
        } else {
            Err(CoseError::InvalidKTY())
        }
    }
}
/// An ordered collection of [`CoseKey`]s together with its CBOR encoding.
#[derive(Default)]
pub struct CoseKeySet {
    /// Keys in the set, in encoding order.
    pub cose_keys: Vec<CoseKey>,
    /// CBOR array encoding of the set (written by `encode`, read by `decode`).
    pub bytes: Vec<u8>,
}
impl CoseKeySet {
    /// Creates an empty key set.
    pub fn new() -> CoseKeySet {
        CoseKeySet {
            cose_keys: Vec::new(),
            bytes: Vec::new(),
        }
    }

    /// Appends `key` to the set.
    pub fn add_key(&mut self, key: CoseKey) {
        self.cose_keys.push(key);
    }

    /// Encodes every key, as a CBOR array, into `self.bytes`.
    ///
    /// Keys are written with `CoseKey::encode_key`, i.e. without the kty/curve
    /// validation that `CoseKey::encode` performs.
    ///
    /// # Errors
    /// `MissingKey` if the set is empty, otherwise the first error raised while
    /// encoding a key.
    pub fn encode(&mut self) -> CoseResult {
        if self.cose_keys.is_empty() {
            return Err(CoseError::MissingKey());
        }
        let mut e = Encoder::new(Vec::new());
        e.array(self.cose_keys.len())?;
        for key in &self.cose_keys {
            key.encode_key(&mut e)?;
        }
        self.bytes = e.into_writer();
        Ok(())
    }

    /// Decodes `self.bytes` (a CBOR array of keys), appending each decoded key
    /// to `cose_keys`.
    ///
    /// Decoding is best-effort. It stops at the first key that fails to decode
    /// and keeps the keys decoded before it.
    ///
    /// # Errors
    /// `MissingKey` if the array is empty, or a decoder error if `self.bytes`
    /// does not start with a CBOR array.
    pub fn decode(&mut self) -> CoseResult {
        let input = Cursor::new(self.bytes.clone());
        let mut d = Decoder::new(Config::default(), input);
        let len = d.array()?;
        if len == 0 {
            return Err(CoseError::MissingKey());
        }
        for _ in 0..len {
            let mut cose_key = CoseKey::new();
            if cose_key.decode_key(&mut d).is_err() {
                // A failed key leaves the shared decoder somewhere inside that
                // key's map, so anything read afterwards would be parsed from a
                // misaligned position (e.g. a nested map could surface as a
                // bogus key). Stop rather than import such data.
                break;
            }
            self.cose_keys.push(cose_key);
        }
        Ok(())
    }

    /// Returns a clone of the first key whose `kid` equals `kid`. Keys without a
    /// `kid` are skipped.
    ///
    /// # Errors
    /// `MissingKey` if no key matches.
    pub fn get_key(&self, kid: &Vec<u8>) -> CoseResultWithRet<CoseKey> {
        self.cose_keys
            .iter()
            .find(|key| key.kid.as_ref() == Some(kid))
            .cloned()
            .ok_or(CoseError::MissingKey())
    }
}