1use crate::{Error, PublicKey, SecretKey};
2use derive_more::{Display, FromStr, TryFrom, TryInto};
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use sha3::{
5 Shake128, Shake256,
6 digest::{ExtendableOutput, HashMarker, Update, XofReader},
7};
8use zeroize::Zeroize;
9
/// Selects the post-quantum KEM used for an envelope and the SHAKE variant used to derive
/// the AES-256 key-wrapping key from the KEM's shared secret (see `Scheme::create_kek`).
///
/// Variant → liboqs algorithm (see `From<Scheme> for oqs::kem::Kem`):
/// - `Nist`   → `MlKem768` (the `Default`)
/// - `Small`  → `ClassicMcEliece348864`
/// - `Secure` → `FrodoKem1344Aes`
///
/// Encodings:
/// - Human-readable serde formats use the display names (`"Nist"`, `"Small"`, `"Secure"`).
/// - Compact formats use the `#[repr(u8)]` discriminant, so the discriminant values are part
///   of the wire format and must not be renumbered.
///
/// NOTE(review): the rkyv archive layout may also depend on variant order — confirm before
/// reordering.
#[derive(
    Copy,
    Clone,
    Debug,
    Default,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
    Display,
    FromStr,
    TryFrom,
    TryInto,
)]
#[display("{}")]
#[try_from(repr)]
#[repr(u8)]
#[rkyv(compare(PartialEq), derive(Debug))]
pub enum Scheme {
    /// ML-KEM-768; the default scheme. Compact encoding: `1`.
    #[default]
    #[display("Nist")]
    Nist = 1,
    /// Classic McEliece 348864. Compact encoding: `2`.
    #[display("Small")]
    Small = 2,
    /// FrodoKEM-1344-AES; its key-wrapping key is derived with SHAKE256. Compact encoding: `3`.
    #[display("Secure")]
    Secure = 3,
}
79
80impl Serialize for Scheme {
81 fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
82 where
83 S: Serializer,
84 {
85 if s.is_human_readable() {
86 s.serialize_str(&self.to_string())
87 } else {
88 s.serialize_u8(self.into())
89 }
90 }
91}
92
93impl<'de> Deserialize<'de> for Scheme {
94 fn deserialize<D>(d: D) -> Result<Self, D::Error>
95 where
96 D: Deserializer<'de>,
97 {
98 if d.is_human_readable() {
99 let s = String::deserialize(d)?;
100 s.parse().map_err(serde::de::Error::custom)
101 } else {
102 let v = u8::deserialize(d)?;
103 v.try_into().map_err(serde::de::Error::custom)
104 }
105 }
106}
107
108impl From<Scheme> for u8 {
109 fn from(scheme: Scheme) -> Self {
110 scheme as u8
111 }
112}
113
114impl From<&Scheme> for u8 {
115 fn from(scheme: &Scheme) -> Self {
116 *scheme as u8
117 }
118}
119
120impl From<Scheme> for oqs::kem::Kem {
121 fn from(scheme: Scheme) -> Self {
122 match scheme {
123 Scheme::Small => {
124 oqs::kem::Kem::new(oqs::kem::Algorithm::ClassicMcEliece348864).expect("Invalid KEM")
125 }
126 Scheme::Secure => {
127 oqs::kem::Kem::new(oqs::kem::Algorithm::FrodoKem1344Aes).expect("Invalid KEM")
128 }
129 Scheme::Nist => oqs::kem::Kem::new(oqs::kem::Algorithm::MlKem768).expect("Invalid KEM"),
130 }
131 }
132}
133
134impl From<&Scheme> for oqs::kem::Kem {
135 fn from(value: &Scheme) -> Self {
136 oqs::kem::Kem::from(*value)
137 }
138}
139
140impl TryFrom<&oqs::kem::Kem> for Scheme {
141 type Error = Error;
142
143 fn try_from(kem: &oqs::kem::Kem) -> Result<Self, Self::Error> {
144 Self::try_from(kem.algorithm())
145 }
146}
147
148impl TryFrom<oqs::kem::Algorithm> for Scheme {
149 type Error = Error;
150
151 fn try_from(alg: oqs::kem::Algorithm) -> Result<Self, Self::Error> {
152 match alg {
153 oqs::kem::Algorithm::ClassicMcEliece348864 => Ok(Scheme::Small),
154 oqs::kem::Algorithm::FrodoKem1344Aes => Ok(Scheme::Secure),
155 oqs::kem::Algorithm::MlKem768 => Ok(Scheme::Nist),
156 _ => Err(Error::InvalidSchemeValue(derive_more::TryFromReprError {
157 input: alg as u8,
158 })),
159 }
160 }
161}
162
163impl Scheme {
164 pub fn key_pair(&self) -> crate::Result<(PublicKey, SecretKey)> {
166 let kem: oqs::kem::Kem = self.into();
167 let (pk, sk) = kem.keypair()?;
168 Ok((pk.into(), sk.into()))
169 }
170 #[cfg(test)]
171 pub(crate) const fn recipient_binary_size(&self) -> usize {
172 match self {
173 Scheme::Small => crate::SCHEME_SMALL_CAPSULE_LENGTH + 41,
174 Scheme::Nist => crate::SCHEME_NIST_CAPSULE_LENGTH + 42,
175 Scheme::Secure => crate::SCHEME_SECURE_CAPSULE_LENGTH + 43,
176 }
177 }
178
179 pub(crate) fn create_kek<B: AsRef<[u8]>>(&self, shared_secret: B) -> aes_kw::KekAes256 {
180 let mut kek = match self {
181 Scheme::Small | Scheme::Nist => self.kdf_kek::<Shake128, B>(shared_secret, 32),
182 Scheme::Secure => self.kdf_kek::<Shake256, B>(shared_secret, 64),
183 };
184
185 let kw = aes_kw::KekAes256::new(
186 &aes_gcm::aes::cipher::generic_array::GenericArray::clone_from_slice(&kek),
187 );
188 kek.zeroize();
189 kw
190 }
191
192 fn kdf_kek<H: ExtendableOutput + Update + Default + HashMarker, B: AsRef<[u8]>>(
193 &self,
194 shared_secret: B,
195 required_length: usize,
196 ) -> [u8; 32] {
197 let mut shaker = H::default();
198 shaker.update(b"pq-envelope");
199 shaker.update(b"key-encryption-key");
200 shaker.update(self.to_string().as_bytes());
201 shaker.update(shared_secret.as_ref());
202 shaker.update(&[32u8]);
203 let mut reader = shaker.finalize_xof();
204 let mut kek = vec![0u8; required_length];
205 reader.read(&mut kek);
206 <[u8; 32]>::try_from(&kek[required_length - 32..required_length])
207 .expect("KDF output length is always >= 32 bytes")
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use rkyv::{access, deserialize, rancor::Error, to_bytes};
    use rstest::*;

    /// Human-readable formats round-trip through the display name.
    #[rstest]
    #[case::small(Scheme::Small, "Small")]
    #[case::nist(Scheme::Nist, "Nist")]
    #[case::secure(Scheme::Secure, "Secure")]
    fn serialization_human_readable(#[case] scheme: Scheme, #[case] value: &str) {
        let serialized = serde_json::to_string(&scheme).unwrap();
        assert_eq!(serialized, format!("\"{}\"", value));
        let deserialized: Scheme = serde_json::from_str(&serialized).unwrap();
        assert_eq!(scheme, deserialized);
    }

    /// Compact formats round-trip through the single-byte discriminant.
    #[rstest]
    #[case::nist(Scheme::Nist, 1u8)]
    #[case::small(Scheme::Small, 2u8)]
    #[case::secure(Scheme::Secure, 3u8)]
    fn serialization_non_human_readable(#[case] scheme: Scheme, #[case] value: u8) {
        let serialized = postcard::to_stdvec(&scheme).unwrap();
        assert_eq!(serialized.len(), 1);
        assert_eq!(serialized[0], value);
        let deserialized: Scheme = postcard::from_bytes(&serialized).unwrap();
        assert_eq!(scheme, deserialized);
    }

    /// Compact input outside the defined discriminants (1..=3) must be rejected.
    #[rstest]
    #[case::zero(0u8)]
    #[case::past_last(4u8)]
    fn non_human_readable_rejects_unknown_discriminant(#[case] value: u8) {
        assert!(postcard::from_bytes::<Scheme>(&[value]).is_err());
    }

    /// Human-readable input that names no variant must be rejected.
    #[test]
    fn human_readable_rejects_unknown_name() {
        assert!(serde_json::from_str::<Scheme>("\"Unknown\"").is_err());
    }

    /// `#[default]` is on `Nist`.
    #[test]
    fn default_is_nist() {
        assert_eq!(Scheme::default(), Scheme::Nist);
    }

    /// Archiving with rkyv preserves the variant.
    #[rstest]
    #[case::nist(Scheme::Nist)]
    #[case::small(Scheme::Small)]
    #[case::secure(Scheme::Secure)]
    fn rkyv_tests(#[case] scheme: Scheme) {
        let serialized = to_bytes::<Error>(&scheme).unwrap();
        let archive = access::<ArchivedScheme, Error>(&serialized[..]).unwrap();
        assert_eq!(archive, &scheme);
        let deserialized = deserialize::<Scheme, Error>(archive).unwrap();
        assert_eq!(deserialized, scheme);
    }

    /// The KDF is deterministic for a fixed scheme and secret, and the scheme name it absorbs
    /// separates `Nist` from `Small` even though both use SHAKE128.
    #[test]
    fn kdf_kek_is_deterministic_and_domain_separated() {
        let secret = [7u8; 32];
        let nist = Scheme::Nist.kdf_kek::<Shake128, _>(secret, 32);
        assert_eq!(nist, Scheme::Nist.kdf_kek::<Shake128, _>(secret, 32));
        assert_ne!(nist, Scheme::Small.kdf_kek::<Shake128, _>(secret, 32));

        let secure = Scheme::Secure.kdf_kek::<Shake256, _>(secret, 64);
        assert_eq!(secure, Scheme::Secure.kdf_kek::<Shake256, _>(secret, 64));
    }
}