1use std::{
2 convert::TryFrom,
3 io::{BufRead, BufReader, Cursor, Read},
4};
5
6use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
7
8#[cfg(feature = "nss")]
9use crate::nss::{
10 hpke::{generate_key_pair, Config as HpkeConfig, HpkeR},
11 PrivateKey, PublicKey,
12};
13#[cfg(feature = "rust-hpke")]
14use crate::rh::hpke::{
15 derive_key_pair, generate_key_pair, Config as HpkeConfig, HpkeR, PrivateKey, PublicKey,
16};
17use crate::{
18 err::{Error, Res},
19 hpke::{Aead as AeadId, Kdf, Kem},
20 KeyId,
21};
22
23#[derive(Debug, Copy, Clone, PartialEq, Eq)]
25pub struct SymmetricSuite {
26 kdf: Kdf,
27 aead: AeadId,
28}
29
30impl SymmetricSuite {
31 #[must_use]
32 pub const fn new(kdf: Kdf, aead: AeadId) -> Self {
33 Self { kdf, aead }
34 }
35
36 #[must_use]
37 pub fn kdf(self) -> Kdf {
38 self.kdf
39 }
40
41 #[must_use]
42 pub fn aead(self) -> AeadId {
43 self.aead
44 }
45}
46
/// An HPKE key configuration: a key identifier, a KEM, the symmetric suites
/// that may be used with it, and the associated key material.
// NOTE(review): `module_name_repetitions` is presumably silenced because the
// type name repeats the enclosing module's name (likely `config`) — confirm.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub struct KeyConfig {
    /// Identifier written as the first byte of the encoded configuration and
    /// matched against incoming requests in `decode_hpke_config`.
    pub(crate) key_id: KeyId,
    /// The key encapsulation mechanism used by this configuration.
    pub(crate) kem: Kem,
    /// The KDF/AEAD pairs advertised for this configuration, filtered to
    /// those the HPKE backend reports as supported.
    pub(crate) symmetric: Vec<SymmetricSuite>,
    /// The private key: `Some` for configurations created by `new`/`derive`,
    /// `None` for configurations obtained through `decode`.
    pub(crate) sk: Option<PrivateKey>,
    /// The public key; its serialized form is included in the encoding.
    pub(crate) pk: PublicKey,
}
59
60impl KeyConfig {
61 fn strip_unsupported(symmetric: &mut Vec<SymmetricSuite>, kem: Kem) {
62 symmetric.retain(|s| HpkeConfig::new(kem, s.kdf(), s.aead()).supported());
63 }
64
65 pub fn new(key_id: u8, kem: Kem, mut symmetric: Vec<SymmetricSuite>) -> Res<Self> {
69 Self::strip_unsupported(&mut symmetric, kem);
70 assert!(!symmetric.is_empty());
71 let (sk, pk) = generate_key_pair(kem)?;
72 Ok(Self {
73 key_id,
74 kem,
75 symmetric,
76 sk: Some(sk),
77 pk,
78 })
79 }
80
81 #[allow(unused)]
87 pub fn derive(
88 key_id: u8,
89 kem: Kem,
90 mut symmetric: Vec<SymmetricSuite>,
91 ikm: &[u8],
92 ) -> Res<Self> {
93 #[cfg(feature = "rust-hpke")]
94 {
95 Self::strip_unsupported(&mut symmetric, kem);
96 assert!(!symmetric.is_empty());
97 let (sk, pk) = derive_key_pair(kem, ikm)?;
98 return Ok(Self {
99 key_id,
100 kem,
101 symmetric,
102 sk: Some(sk),
103 pk,
104 });
105 }
106 Err(Error::Unsupported)
107 }
108
109 pub fn encode_list(list: &[impl AsRef<Self>]) -> Res<Vec<u8>> {
120 let mut buf = Vec::new();
121 for c in list {
122 let offset = buf.len();
123 buf.write_u16::<NetworkEndian>(0)?;
124 c.as_ref().write(&mut buf)?;
125 let len = buf.len() - offset - 2;
126 buf[offset] = u8::try_from(len >> 8)?;
127 buf[offset + 1] = u8::try_from(len & 0xff).unwrap();
128 }
129 Ok(buf)
130 }
131
132 fn write(&self, buf: &mut Vec<u8>) -> Res<()> {
133 buf.write_u8(self.key_id)?;
134 buf.write_u16::<NetworkEndian>(u16::from(self.kem))?;
135 let pk_buf = self.pk.key_data()?;
136 buf.extend_from_slice(&pk_buf);
137 buf.write_u16::<NetworkEndian>((self.symmetric.len() * 4).try_into()?)?;
138 for s in &self.symmetric {
139 buf.write_u16::<NetworkEndian>(u16::from(s.kdf()))?;
140 buf.write_u16::<NetworkEndian>(u16::from(s.aead()))?;
141 }
142 Ok(())
143 }
144
145 pub fn encode(&self) -> Res<Vec<u8>> {
168 let mut buf = Vec::new();
169 self.write(&mut buf)?;
170 Ok(buf)
171 }
172
173 pub fn decode(encoded_config: &[u8]) -> Res<Self> {
176 let end_position = u64::try_from(encoded_config.len())?;
177 let mut r = Cursor::new(encoded_config);
178 let key_id = r.read_u8()?;
179 let kem = Kem::try_from(r.read_u16::<NetworkEndian>()?)?;
180
181 let kem_config = HpkeConfig::new(kem, Kdf::HkdfSha256, AeadId::Aes128Gcm);
183 if !kem_config.supported() {
184 return Err(Error::Unsupported);
185 }
186 let mut pk_buf = vec![0; kem_config.kem().n_pk()];
187 r.read_exact(&mut pk_buf)?;
188
189 let sym_len = r.read_u16::<NetworkEndian>()?;
190 let mut sym = vec![0; usize::from(sym_len)];
191 r.read_exact(&mut sym)?;
192 if sym.is_empty() || (sym.len() % 4 != 0) {
193 return Err(Error::Format);
194 }
195 let sym_count = sym.len() / 4;
196 let mut sym_r = BufReader::new(&sym[..]);
197 let mut symmetric = Vec::with_capacity(sym_count);
198 for _ in 0..sym_count {
199 let kdf = Kdf::try_from(sym_r.read_u16::<NetworkEndian>()?)?;
200 let aead = AeadId::try_from(sym_r.read_u16::<NetworkEndian>()?)?;
201 symmetric.push(SymmetricSuite::new(kdf, aead));
202 }
203
204 if r.position() != end_position {
206 return Err(Error::Format);
207 }
208
209 Self::strip_unsupported(&mut symmetric, kem);
210 let pk = HpkeR::decode_public_key(kem_config.kem(), &pk_buf)?;
211
212 Ok(Self {
213 key_id,
214 kem,
215 symmetric,
216 sk: None,
217 pk,
218 })
219 }
220
221 pub fn decode_list(encoded_list: &[u8]) -> Res<Vec<Self>> {
225 let end_position = u64::try_from(encoded_list.len())?;
226 let mut r = Cursor::new(encoded_list);
227 let mut configs = Vec::new();
228 loop {
229 if r.position() == end_position {
230 break;
231 }
232 let len = usize::from(r.read_u16::<NetworkEndian>()?);
233 let buf = r.fill_buf()?;
234 if len > buf.len() {
235 return Err(Error::Truncated);
236 }
237 let res = Self::decode(&buf[..len]);
238 r.consume(len);
239 match res {
240 Ok(config) => configs.push(config),
241 Err(Error::Unsupported) => {}
242 Err(e) => return Err(e),
243 }
244 }
245 Ok(configs)
246 }
247
248 pub fn select(&self, sym: SymmetricSuite) -> Res<HpkeConfig> {
253 if self.symmetric.contains(&sym) {
254 let config = HpkeConfig::new(self.kem, sym.kdf(), sym.aead());
255 Ok(config)
256 } else {
257 Err(Error::Unsupported)
258 }
259 }
260
261 #[allow(clippy::similar_names)] pub(crate) fn decode_hpke_config(&self, r: &mut Cursor<&[u8]>) -> Res<HpkeConfig> {
263 let key_id = r.read_u8()?;
264 if key_id != self.key_id {
265 return Err(Error::KeyId);
266 }
267 let kem_id = Kem::try_from(r.read_u16::<NetworkEndian>()?)?;
268 if kem_id != self.kem {
269 return Err(Error::InvalidKem);
270 }
271 let kdf_id = Kdf::try_from(r.read_u16::<NetworkEndian>()?)?;
272 let aead_id = AeadId::try_from(r.read_u16::<NetworkEndian>()?)?;
273 let hpke_config = HpkeConfig::new(self.kem, kdf_id, aead_id);
274 Ok(hpke_config)
275 }
276}
277
278impl AsRef<Self> for KeyConfig {
279 fn as_ref(&self) -> &Self {
280 self
281 }
282}
283
#[cfg(test)]
mod test {
    use std::iter::zip;

    use crate::{
        hpke::{Aead, Kdf, Kem},
        init, Error, KeyConfig, KeyId, SymmetricSuite,
    };

    const KEY_ID: KeyId = 1;
    const KEM: Kem = Kem::X25519Sha256;
    const SYMMETRIC: &[SymmetricSuite] = &[
        SymmetricSuite::new(Kdf::HkdfSha256, Aead::Aes128Gcm),
        SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305),
    ];

    #[test]
    fn encode_decode_config_list() {
        const COUNT: usize = 3;
        init();

        let configs: Vec<_> = (0..COUNT)
            .map(|_| KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap())
            .collect();

        let buf = KeyConfig::encode_list(&configs).unwrap();
        let decoded_list = KeyConfig::decode_list(&buf).unwrap();
        // `zip` stops at the shorter input, so without this an empty or short
        // decode would pass the loop below vacuously.
        assert_eq!(decoded_list.len(), configs.len());
        for (original, decoded) in zip(&configs, &decoded_list) {
            assert_eq!(decoded.key_id, original.key_id);
            assert_eq!(decoded.kem, original.kem);
            assert_eq!(decoded.symmetric, original.symmetric);
            assert_eq!(
                decoded.pk.key_data().unwrap(),
                original.pk.key_data().unwrap()
            );
            assert!(decoded.sk.is_none());
            assert!(original.sk.is_some());
        }

        // Cutting into the last config makes its length prefix overrun the buffer.
        assert!(KeyConfig::decode_list(&buf[..buf.len() - 3]).is_err());
    }

    #[test]
    fn empty_config_list() {
        init();

        let list = KeyConfig::decode_list(&[]).unwrap();
        assert!(list.is_empty());

        // A single 3-byte config (key id 0, KEM id 0) whose KEM is rejected as
        // unsupported, so the entry is skipped rather than failing the list.
        let list = KeyConfig::decode_list(&[0, 3, 0, 0, 0]).unwrap();
        assert!(list.is_empty());
    }

    #[test]
    fn bad_config_list_length() {
        init();

        // One byte cannot hold the two-byte length prefix.
        let res = KeyConfig::decode_list(&[0]);
        assert!(matches!(res, Err(Error::Io(_))));
    }

    #[test]
    fn decode_bad_config() {
        init();

        // Layout: key id [0], KEM [1..3], X25519 public key [3..35],
        // suite-list length [35..37], suites [37..].
        let mut x25519 = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))
            .unwrap()
            .encode()
            .unwrap();
        {
            let trunc = |n: usize| KeyConfig::decode(&x25519[..n]);

            // KEM id cut short.
            assert!(matches!(trunc(2), Err(Error::Io(_))));
            // Public key cut short.
            assert!(matches!(trunc(4), Err(Error::Io(_))));
            // Suite-list length cut short.
            assert!(matches!(trunc(36), Err(Error::Io(_))));
            // Suite list cut short.
            assert!(matches!(trunc(38), Err(Error::Io(_))));
        }

        // A trailing byte after a complete config is a format error.
        x25519.push(0);
        assert!(matches!(KeyConfig::decode(&x25519), Err(Error::Format)));
    }

    #[test]
    fn truncate_kdf_aead_list() {
        init();

        let mut x25519 = KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC))
            .unwrap()
            .encode()
            .unwrap();
        // Keep the header, the suite-list length and one byte of the first suite.
        x25519.truncate(38);
        // Byte 36 is the low byte of the suite-list length.
        assert_eq!(usize::from(x25519[36]), SYMMETRIC.len() * 4);
        // A length of 1 is not a multiple of the 4-byte suite size.
        x25519[36] = 1;
        assert!(matches!(KeyConfig::decode(&x25519), Err(Error::Format)));
    }
}