bc_components/kyber/
kyber_public_key.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use anyhow::{Result, Error, anyhow, bail};
use dcbor::prelude::*;
use pqcrypto_kyber::*;
use pqcrypto_traits::kem::{PublicKey, SharedSecret};

use crate::{tags, SymmetricKey};

use super::{Kyber, KyberCiphertext};

/// A Kyber key-encapsulation public key at one of three supported levels.
///
/// Each variant wraps the matching `pqcrypto_kyber` public key type in a
/// `Box` (presumably so the enum itself stays small regardless of how large
/// the underlying key is — TODO confirm the motivation).
#[derive(Clone)]
pub enum KyberPublicKey {
    /// Public key for the `kyber512` parameter set.
    Kyber512(Box<kyber512::PublicKey>),
    /// Public key for the `kyber768` parameter set.
    Kyber768(Box<kyber768::PublicKey>),
    /// Public key for the `kyber1024` parameter set.
    Kyber1024(Box<kyber1024::PublicKey>),
}

/// Two public keys are equal when they have the same level and identical raw
/// key bytes.
impl PartialEq for KyberPublicKey {
    fn eq(&self, other: &Self) -> bool {
        // The level comparison comes first so that keys of different levels
        // never reach the byte comparison.
        let levels_match = self.level() == other.level();
        levels_match && self.as_bytes() == other.as_bytes()
    }
}

// Marker impl: `PartialEq::eq` for this type compares the level and the raw
// key bytes, so the type is declared to have full equality.
impl Eq for KyberPublicKey {}

/// Hashes the level followed by the raw key bytes — the same two pieces of
/// data that `eq` compares, so equal keys produce equal hashes.
impl std::hash::Hash for KyberPublicKey {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let level = self.level();
        let key_bytes = self.as_bytes();

        // Order matters for the resulting hash: level first, then bytes.
        std::hash::Hash::hash(&level, state);
        std::hash::Hash::hash(key_bytes, state);
    }
}

impl KyberPublicKey {
    pub fn level(&self) -> Kyber {
        match self {
            KyberPublicKey::Kyber512(_) => Kyber::Kyber512,
            KyberPublicKey::Kyber768(_) => Kyber::Kyber768,
            KyberPublicKey::Kyber1024(_) => Kyber::Kyber1024,
        }
    }

    pub fn size(&self) -> usize {
        self.level().public_key_size()
    }

    pub fn as_bytes(&self) -> &[u8] {
        match self {
            KyberPublicKey::Kyber512(pk) => pk.as_ref().as_bytes(),
            KyberPublicKey::Kyber768(pk) => pk.as_ref().as_bytes(),
            KyberPublicKey::Kyber1024(pk) => pk.as_ref().as_bytes(),
        }
    }

    pub fn from_bytes(level: Kyber, bytes: &[u8]) -> Result<Self> {
        match level {
            Kyber::Kyber512 => Ok(KyberPublicKey::Kyber512(Box::new(kyber512::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?))),
            Kyber::Kyber768 => Ok(KyberPublicKey::Kyber768(Box::new(kyber768::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?))),
            Kyber::Kyber1024 => Ok(KyberPublicKey::Kyber1024(Box::new(kyber1024::PublicKey::from_bytes(bytes).map_err(|e| anyhow!(e))?))),
        }
    }

    pub fn encapsulate_new_shared_secret(&self) -> (SymmetricKey, KyberCiphertext) {
        match self {
            KyberPublicKey::Kyber512(pk) => {
                let (ss, ct) = kyber512::encapsulate(pk.as_ref());
                (SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(), KyberCiphertext::Kyber512(ct.into()))
            }
            KyberPublicKey::Kyber768(pk) => {
                let (ss, ct) = kyber768::encapsulate(pk.as_ref());
                (SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(), KyberCiphertext::Kyber768(ct.into()))
            }
            KyberPublicKey::Kyber1024(pk) => {
                let (ss, ct) = kyber1024::encapsulate(pk.as_ref());
                (SymmetricKey::from_data_ref(ss.as_bytes()).unwrap(), KyberCiphertext::Kyber1024(ct.into()))
            }
        }
    }
}

impl std::fmt::Debug for KyberPublicKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            KyberPublicKey::Kyber512(_) => f.write_str("Kyber512PublicKey"),
            KyberPublicKey::Kyber768(_) => f.write_str("Kyber768PublicKey"),
            KyberPublicKey::Kyber1024(_) => f.write_str("Kyber1024PublicKey"),
        }
    }
}

/// Associates `KyberPublicKey` with the `TAG_KYBER_PUBLIC_KEY` CBOR tag.
impl CBORTagged for KyberPublicKey {
    fn cbor_tags() -> Vec<Tag> {
        // Only one tag value is registered for this type.
        let tag_values = [tags::TAG_KYBER_PUBLIC_KEY];
        tags_for_values(&tag_values)
    }
}

impl From<KyberPublicKey> for CBOR {
    fn from(value: KyberPublicKey) -> Self {
        value.tagged_cbor()
    }
}

/// Untagged encoding: a two-element array of the level followed by the raw
/// key bytes as a byte string.
impl CBORTaggedEncodable for KyberPublicKey {
    fn untagged_cbor(&self) -> CBOR {
        let level: CBOR = self.level().into();
        let key_bytes = CBOR::to_byte_string(self.as_bytes());
        vec![level, key_bytes].into()
    }
}

impl TryFrom<CBOR> for KyberPublicKey {
    type Error = Error;

    fn try_from(cbor: CBOR) -> Result<Self, Self::Error> {
        Self::from_tagged_cbor(cbor)
    }
}

/// Untagged decoding: expects a two-element array of the level followed by
/// the raw key bytes as a byte string.
impl CBORTaggedDecodable for KyberPublicKey {
    /// Reconstructs a public key from its untagged CBOR form.
    ///
    /// # Errors
    ///
    /// Fails when the CBOR is not an array, when the array does not have
    /// exactly two elements, when either element fails to decode, or when
    /// `KyberPublicKey::from_bytes` rejects the key bytes.
    fn from_untagged_cbor(untagged_cbor: CBOR) -> Result<Self> {
        let elements = match untagged_cbor.as_case() {
            CBORCase::Array(elements) => elements,
            _ => bail!("KyberPublicKey must be an array"),
        };

        // The slice pattern both checks the element count and names the items.
        match &elements[..] {
            [level_item, key_item] => {
                let level = Kyber::try_from(level_item.clone())?;
                let data = CBOR::try_into_byte_string(key_item.clone())?;
                KyberPublicKey::from_bytes(level, &data)
            }
            _ => bail!("KyberPublicKey must have two elements"),
        }
    }
}