// pq_envelope/scheme.rs

use crate::{Error, PublicKey, SecretKey};
use derive_more::{Display, FromStr, TryFrom, TryInto};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha3::{
    Shake128, Shake256,
    digest::{ExtendableOutput, HashMarker, Update, XofReader},
};
use zeroize::{Zeroize, Zeroizing};

10/// The type of [`Scheme`]s supported by this crate.
11///
12/// They are divided into three categories:
13///
14/// * `Small`: Where the focus is on optimizing the size of the envelope. Underneath, it creates an envelope using
15///     - AES-256-GCM as the data encryption algorithm or data encryption key (DEK).
16///     - ClassicMcEliece348864 as the key encapsulation mechanism (KEM) to produce the key encryption key (KEK).
17///     - SHAKE256 as the key derivation function (KDF).
18///     - AES-256-KW to encrypt the DEK with the KEK.
19///     - The resulting envelope is 96 bytes for the KEM ciphertext, 40 bytes for the wrapped DEK
20///       and N bytes for the encrypted data, 16 bytes for the authentication tag.
21/// * `Secure`: Where the focus is on optimizing security. Underneath, it creates an envelope using
22///      - AES-256-GCM as the data encryption algorithm or data encryption key (DEK).
23///      - FrodoKem1344Aes as the key encapsulation mechanism (KEM) to produce the key encryption key (KEK).
24///      - SHAKE256 as the key derivation function (KDF).
25///      - AES-256-KW to encrypt the DEK with the KEK.
26///     - The resulting envelope is 21632 bytes for the KEM ciphertext, 40 bytes for the wrapped DEK
27///       and N bytes for the encrypted data, 16 bytes for the authentication tag.
28/// * `Nist`: Where the focus is on using NIST standardized algorithms. Underneath, it creates an envelope using
29///    - AES-256-GCM as the data encryption algorithm or data encryption key (DEK).
30///    - MLKEM768 as the key encapsulation mechanism (KEM) to produce the key encryption key (KEK).
31///    - SHAKE256 as the key derivation function (KDF).
32///    - AES-256-KW to encrypt the DEK with the KEK.
33///     - The resulting envelope is 1088 bytes for the KEM ciphertext, 40 bytes for the wrapped DEK
34///       and N bytes for the encrypted data, 16 bytes for the authentication tag.
35///
36/// `Nist` has a good balance between size and security,
37/// while using only NIST standardized algorithms.
38/// The key sizes are relatively small and the best performance.
39///
40/// `Small` is suitable for scenarios where envelope size is a critical factor,
41/// however, it requires the largest key sizes.
42///
43/// `Secure` offers the highest security level, but comes with a
44/// significant increase in envelope and key size, and the slowest performance.
45#[derive(
46    Copy,
47    Clone,
48    Debug,
49    Default,
50    PartialEq,
51    Eq,
52    PartialOrd,
53    Ord,
54    Hash,
55    rkyv::Archive,
56    rkyv::Serialize,
57    rkyv::Deserialize,
58    Display,
59    FromStr,
60    TryFrom,
61    TryInto,
62)]
63#[display("{}")]
64#[try_from(repr)]
65#[repr(u8)]
66#[rkyv(compare(PartialEq), derive(Debug))]
67pub enum Scheme {
68    #[default]
69    #[display("Nist")]
70    /// Enveloped using NIST standardized algorithms.
71    Nist = 1,
72    #[display("Small")]
73    /// Enveloped optimized for space.
74    Small = 2,
75    #[display("Secure")]
76    /// Enveloped optimized for security.
77    Secure = 3,
78}
79
80impl Serialize for Scheme {
81    fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
82    where
83        S: Serializer,
84    {
85        if s.is_human_readable() {
86            s.serialize_str(&self.to_string())
87        } else {
88            s.serialize_u8(self.into())
89        }
90    }
91}
92
93impl<'de> Deserialize<'de> for Scheme {
94    fn deserialize<D>(d: D) -> Result<Self, D::Error>
95    where
96        D: Deserializer<'de>,
97    {
98        if d.is_human_readable() {
99            let s = String::deserialize(d)?;
100            s.parse().map_err(serde::de::Error::custom)
101        } else {
102            let v = u8::deserialize(d)?;
103            v.try_into().map_err(serde::de::Error::custom)
104        }
105    }
106}
107
108impl From<Scheme> for u8 {
109    fn from(scheme: Scheme) -> Self {
110        scheme as u8
111    }
112}
113
114impl From<&Scheme> for u8 {
115    fn from(scheme: &Scheme) -> Self {
116        *scheme as u8
117    }
118}
119
120impl From<Scheme> for oqs::kem::Kem {
121    fn from(scheme: Scheme) -> Self {
122        match scheme {
123            Scheme::Small => {
124                oqs::kem::Kem::new(oqs::kem::Algorithm::ClassicMcEliece348864).expect("Invalid KEM")
125            }
126            Scheme::Secure => {
127                oqs::kem::Kem::new(oqs::kem::Algorithm::FrodoKem1344Aes).expect("Invalid KEM")
128            }
129            Scheme::Nist => oqs::kem::Kem::new(oqs::kem::Algorithm::MlKem768).expect("Invalid KEM"),
130        }
131    }
132}
133
134impl From<&Scheme> for oqs::kem::Kem {
135    fn from(value: &Scheme) -> Self {
136        oqs::kem::Kem::from(*value)
137    }
138}
139
140impl TryFrom<&oqs::kem::Kem> for Scheme {
141    type Error = Error;
142
143    fn try_from(kem: &oqs::kem::Kem) -> Result<Self, Self::Error> {
144        Self::try_from(kem.algorithm())
145    }
146}
147
148impl TryFrom<oqs::kem::Algorithm> for Scheme {
149    type Error = Error;
150
151    fn try_from(alg: oqs::kem::Algorithm) -> Result<Self, Self::Error> {
152        match alg {
153            oqs::kem::Algorithm::ClassicMcEliece348864 => Ok(Scheme::Small),
154            oqs::kem::Algorithm::FrodoKem1344Aes => Ok(Scheme::Secure),
155            oqs::kem::Algorithm::MlKem768 => Ok(Scheme::Nist),
156            _ => Err(Error::InvalidSchemeValue(derive_more::TryFromReprError {
157                input: alg as u8,
158            })),
159        }
160    }
161}
162
163impl Scheme {
164    /// Generate a new public/private key pair for the specified scheme.
165    pub fn key_pair(&self) -> crate::Result<(PublicKey, SecretKey)> {
166        let kem: oqs::kem::Kem = self.into();
167        let (pk, sk) = kem.keypair()?;
168        Ok((pk.into(), sk.into()))
169    }
170    #[cfg(test)]
171    pub(crate) const fn recipient_binary_size(&self) -> usize {
172        match self {
173            Scheme::Small => crate::SCHEME_SMALL_CAPSULE_LENGTH + 41,
174            Scheme::Nist => crate::SCHEME_NIST_CAPSULE_LENGTH + 42,
175            Scheme::Secure => crate::SCHEME_SECURE_CAPSULE_LENGTH + 43,
176        }
177    }
178
179    pub(crate) fn create_kek<B: AsRef<[u8]>>(&self, shared_secret: B) -> aes_kw::KekAes256 {
180        let mut kek = match self {
181            Scheme::Small | Scheme::Nist => self.kdf_kek::<Shake128, B>(shared_secret, 32),
182            Scheme::Secure => self.kdf_kek::<Shake256, B>(shared_secret, 64),
183        };
184
185        let kw = aes_kw::KekAes256::new(
186            &aes_gcm::aes::cipher::generic_array::GenericArray::clone_from_slice(&kek),
187        );
188        kek.zeroize();
189        kw
190    }
191
192    fn kdf_kek<H: ExtendableOutput + Update + Default + HashMarker, B: AsRef<[u8]>>(
193        &self,
194        shared_secret: B,
195        required_length: usize,
196    ) -> [u8; 32] {
197        let mut shaker = H::default();
198        shaker.update(b"pq-envelope");
199        shaker.update(b"key-encryption-key");
200        shaker.update(self.to_string().as_bytes());
201        shaker.update(shared_secret.as_ref());
202        shaker.update(&[32u8]);
203        let mut reader = shaker.finalize_xof();
204        let mut kek = vec![0u8; required_length];
205        reader.read(&mut kek);
206        <[u8; 32]>::try_from(&kek[required_length - 32..required_length])
207            .expect("KDF output length is always >= 32 bytes")
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use rkyv::{access, deserialize, rancor::Error, to_bytes};
    use rstest::*;

    #[rstest]
    #[case::small(Scheme::Small, "Small")]
    #[case::nist(Scheme::Nist, "Nist")]
    #[case::secure(Scheme::Secure, "Secure")]
    fn serialization_human_readable(#[case] scheme: Scheme, #[case] value: &str) {
        let serialized = serde_json::to_string(&scheme).unwrap();
        assert_eq!(serialized, format!("\"{}\"", value));
        let deserialized: Scheme = serde_json::from_str(&serialized).unwrap();
        assert_eq!(scheme, deserialized);
    }

    /// Unknown names, and values that are not strings, must be rejected.
    #[rstest]
    #[case::unknown_name("\"Bogus\"")]
    #[case::empty_name("\"\"")]
    #[case::number("1")]
    fn deserialization_human_readable_rejects_invalid(#[case] json: &str) {
        assert!(serde_json::from_str::<Scheme>(json).is_err());
    }

    #[rstest]
    #[case::nist(Scheme::Nist, 1u8)]
    #[case::small(Scheme::Small, 2u8)]
    #[case::secure(Scheme::Secure, 3u8)]
    fn serialization_non_human_readable(#[case] scheme: Scheme, #[case] value: u8) {
        let serialized = postcard::to_stdvec(&scheme).unwrap();
        assert_eq!(serialized.len(), 1);
        assert_eq!(serialized[0], value);
        let deserialized: Scheme = postcard::from_bytes(&serialized).unwrap();
        assert_eq!(scheme, deserialized);
    }

    /// Bytes that are not a `Scheme` discriminant must be rejected.
    #[rstest]
    #[case::zero(0u8)]
    #[case::past_last(4u8)]
    #[case::max(u8::MAX)]
    fn deserialization_non_human_readable_rejects_invalid(#[case] value: u8) {
        assert!(postcard::from_bytes::<Scheme>(&[value]).is_err());
    }

    #[rstest]
    #[case::nist(Scheme::Nist, 1u8)]
    #[case::small(Scheme::Small, 2u8)]
    #[case::secure(Scheme::Secure, 3u8)]
    fn u8_conversions_round_trip(#[case] scheme: Scheme, #[case] value: u8) {
        assert_eq!(u8::from(scheme), value);
        assert_eq!(u8::from(&scheme), value);
        assert_eq!(Scheme::try_from(value).ok(), Some(scheme));
    }

    #[test]
    fn default_is_nist() {
        assert_eq!(Scheme::default(), Scheme::Nist);
    }

    #[rstest]
    #[case::nist(Scheme::Nist)]
    #[case::small(Scheme::Small)]
    #[case::secure(Scheme::Secure)]
    fn kem_round_trip(#[case] scheme: Scheme) {
        let kem = oqs::kem::Kem::from(scheme);
        assert_eq!(Scheme::try_from(&kem).ok(), Some(scheme));
    }

    #[test]
    fn unsupported_kem_algorithm_is_rejected() {
        assert!(Scheme::try_from(oqs::kem::Algorithm::MlKem512).is_err());
    }

    #[test]
    fn kdf_kek_is_deterministic_and_scheme_separated() {
        let secret = [7u8; 32];
        let nist = Scheme::Nist.kdf_kek::<Shake128, _>(secret, 32);
        assert_eq!(nist, Scheme::Nist.kdf_kek::<Shake128, _>(secret, 32));
        // The scheme name is absorbed, so the same secret yields different KEKs.
        assert_ne!(nist, Scheme::Small.kdf_kek::<Shake128, _>(secret, 32));
    }

    /// Pins the KDF input encoding: existing envelopes depend on it.
    #[test]
    fn kdf_kek_keeps_tail_of_xof_output() {
        let secret = b"shared secret";
        let mut shaker = Shake256::default();
        shaker.update(b"pq-envelope");
        shaker.update(b"key-encryption-key");
        shaker.update(b"Secure");
        shaker.update(secret);
        shaker.update(&[32u8]);
        let mut expected = [0u8; 64];
        shaker.finalize_xof().read(&mut expected);

        let kek = Scheme::Secure.kdf_kek::<Shake256, _>(secret, 64);
        assert_eq!(&kek[..], &expected[32..]);
    }

    #[rstest]
    #[case::nist(Scheme::Nist)]
    #[case::small(Scheme::Small)]
    #[case::secure(Scheme::Secure)]
    fn rkyv_tests(#[case] scheme: Scheme) {
        let serialized = to_bytes::<Error>(&scheme).unwrap();
        let archive = access::<ArchivedScheme, Error>(&serialized[..]).unwrap();
        assert_eq!(archive, &scheme);
        let deserialized = deserialize::<Scheme, Error>(archive).unwrap();
        assert_eq!(deserialized, scheme);
    }
}