1#![forbid(unsafe_code)]
15#![allow(non_snake_case)]
17
18use core::{
19 fmt::{self, Debug, Display},
20 marker::PhantomData,
21 num::NonZeroU16,
22 result::Result,
23};
24
25use aranya_buggy::{bug, Bug, BugExt};
26use generic_array::ArrayLength;
27use subtle::{Choice, ConstantTimeEq};
28
29use crate::{
30 aead::{Aead, IndCca2, KeyData, Nonce, OpenError, SealError},
31 csprng::Csprng,
32 import::{ExportError, Import, ImportError},
33 kdf::{Context, Expand, Kdf, KdfError, Prk},
34 kem::{Kem, KemError},
35 AlgId,
36};
37
/// Integer-to-octet-string primitive (`I2OSP`): big-endian encoding of an
/// unsigned integer.
///
/// - `i2osp!(v)` expands to `v.to_be_bytes()`, i.e. a fixed-size array whose
///   length is the byte width of `v`'s type. Usable in `const` contexts.
/// - `i2osp!(v, N)` produces a `GenericArray<u8, N>`: when `N` is at least
///   the byte width of `v`, the big-endian bytes are right-aligned and the
///   leading bytes are zero; when `N` is smaller, only the `N` low-order
///   bytes are kept.
macro_rules! i2osp {
    ($v:expr) => {
        $v.to_be_bytes()
    };
    ($v:expr, $n:ty) => {{
        let bytes = $v.to_be_bytes();
        let mut out = generic_array::GenericArray::<u8, $n>::default();
        match out.len().checked_sub(bytes.len()) {
            // Output is at least as wide as the input: left-pad with zeros.
            Some(pad) => out[pad..].copy_from_slice(&bytes),
            // Output is narrower than the input: keep the low-order bytes.
            None => {
                let skip = bytes.len() - out.len();
                out.copy_from_slice(&bytes[skip..]);
            }
        }
        out
    }};
}
60
/// The HPKE mode used to set up a sending or receiving context.
///
/// `T` is the sender-authentication key: `Hpke::setup_send` takes a
/// `Mode<'_, &K::DecapKey>` (the sender's secret key) and `Hpke::setup_recv`
/// takes a `Mode<'_, &K::EncapKey>` (the sender's public key).
///
/// NOTE(review): `Debug` is only derived under `cfg(test)`, presumably so
/// PSK bytes are never printed in production builds — confirm.
#[cfg_attr(test, derive(Debug))]
pub enum Mode<'a, T> {
    /// No PSK and no sender authentication (mode ID `0x00`).
    Base,
    /// Authenticated with a pre-shared key (mode ID `0x01`).
    Psk(Psk<'a>),
    /// Authenticated with the sender's key `T` (mode ID `0x02`).
    Auth(T),
    /// Authenticated with both the sender's key `T` and a pre-shared key
    /// (mode ID `0x03`).
    AuthPsk(T, Psk<'a>),
}
77
78impl<T> Display for Mode<'_, T> {
79 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 match self {
81 Self::Base => write!(f, "mode_base"),
82 Self::Psk(_) => write!(f, "mode_psk"),
83 Self::Auth(_) => write!(f, "mode_auth"),
84 Self::AuthPsk(_, _) => write!(f, "mode_auth_psk"),
85 }
86 }
87}
88
89impl<'a, T> Mode<'a, T> {
90 const DEFAULT_PSK: Psk<'static> = Psk {
93 psk: &[],
94 psk_id: &[],
95 };
96
97 pub const fn as_ref(&self) -> Mode<'_, &T> {
99 match *self {
100 Self::Base => Mode::Base,
101 Self::Psk(psk) => Mode::Psk(psk),
102 Self::Auth(ref k) => Mode::Auth(k),
103 Self::AuthPsk(ref k, psk) => Mode::AuthPsk(k, psk),
104 }
105 }
106
107 fn psk(&self) -> &Psk<'a> {
108 match self {
109 Mode::Psk(psk) => psk,
110 Mode::AuthPsk(_, psk) => psk,
111 _ => &Self::DEFAULT_PSK,
112 }
113 }
114
115 const fn id(&self) -> u8 {
116 match self {
117 Self::Base => 0x00,
118 Self::Psk(_) => 0x01,
119 Self::Auth(_) => 0x02,
120 Self::AuthPsk(_, _) => 0x03,
121 }
122 }
123}
124
/// Error returned by `Psk::new` when the PSK or the PSK ID is empty.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct InvalidPsk;

impl Display for InvalidPsk {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid pre-shared key: PSK or PSK ID are empty")
    }
}

// No underlying cause, so the default `source` (`None`) is correct.
impl core::error::Error for InvalidPsk {}
136
137#[cfg_attr(test, derive(Debug))]
139#[derive(Copy, Clone)]
140pub struct Psk<'a> {
141 psk: &'a [u8],
143 psk_id: &'a [u8],
145}
146
147impl<'a> Psk<'a> {
148 pub fn new(psk: &'a [u8], psk_id: &'a [u8]) -> Result<Self, InvalidPsk> {
150 if psk.is_empty() || psk_id.is_empty() {
152 Err(InvalidPsk)
153 } else {
154 Ok(Self { psk, psk_id })
155 }
156 }
157}
158
159impl ConstantTimeEq for Psk<'_> {
160 fn ct_eq(&self, other: &Self) -> Choice {
161 self.psk.ct_eq(other.psk) & self.psk_id.ct_eq(other.psk_id)
162 }
163}
164
/// Identifies a KEM algorithm by its two-byte ID.
///
/// IDs without a named variant decode to `KemId::Other` (see the
/// `test_kem_id` test).
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KemId {
    #[alg_id(0x0010)]
    /// DHKEM(P-256, HKDF-SHA256).
    DhKemP256HkdfSha256,
    #[alg_id(0x0011)]
    /// DHKEM(P-384, HKDF-SHA384).
    DhKemP384HkdfSha384,
    #[alg_id(0x0012)]
    /// DHKEM(P-521, HKDF-SHA512).
    DhKemP521HkdfSha512,
    #[alg_id(0x0013)]
    /// DHKEM(CP-256, HKDF-SHA256).
    DhKemCp256HkdfSha256,
    #[alg_id(0x0014)]
    /// DHKEM(CP-384, HKDF-SHA384).
    DhKemCp384HkdfSha384,
    #[alg_id(0x0015)]
    /// DHKEM(CP-521, HKDF-SHA512).
    DhKemCp521HkdfSha512,
    #[alg_id(0x0016)]
    /// DHKEM(secp256k1, HKDF-SHA256).
    DhKemSecp256k1HkdfSha256,
    #[alg_id(0x0020)]
    /// DHKEM(X25519, HKDF-SHA256).
    DhKemX25519HkdfSha256,
    #[alg_id(0x0021)]
    /// DHKEM(X448, HKDF-SHA512).
    DhKemX448HkdfSha512,
    #[alg_id(0x0030)]
    /// X25519Kyber768Draft00.
    X25519Kyber768Draft00,
    #[alg_id(Other)]
    /// Any other (non-zero) KEM ID without a named variant.
    Other(NonZeroU16),
}
206
207impl Display for KemId {
208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209 match self {
210 Self::DhKemP256HkdfSha256 => write!(f, "DHKEM(P-256, HKDF-SHA256)"),
211 Self::DhKemP384HkdfSha384 => write!(f, "DHKEM(P-384, HKDF-SHA384)"),
212 Self::DhKemP521HkdfSha512 => write!(f, "DHKEM(P-521, HKDF-SHA512)"),
213 Self::DhKemCp256HkdfSha256 => write!(f, "DHKEM(CP-256, HKDF-SHA256)"),
214 Self::DhKemCp384HkdfSha384 => write!(f, "DHKEM(CP-384, HKDF-SHA384)"),
215 Self::DhKemCp521HkdfSha512 => write!(f, "DHKEM(CP-521, HKDF-SHA512)"),
216 Self::DhKemSecp256k1HkdfSha256 => write!(f, "DHKEM(secp256k1, HKDF-SHA256)"),
217 Self::DhKemX25519HkdfSha256 => write!(f, "DHKEM(X25519, HKDF-SHA256)"),
218 Self::DhKemX448HkdfSha512 => write!(f, "DHKEM(X448, HKDF-SHA512)"),
219 Self::X25519Kyber768Draft00 => write!(f, "X25519Kyber768Draft00"),
220 Self::Other(id) => write!(f, "Kem({:#02x})", id),
221 }
222 }
223}
224
/// Identifies a KDF algorithm by its two-byte ID.
///
/// IDs without a named variant decode to `KdfId::Other` (see the
/// `test_kdf_id` test).
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KdfId {
    #[alg_id(0x0001)]
    /// HKDF-SHA256.
    HkdfSha256,
    #[alg_id(0x0002)]
    /// HKDF-SHA384.
    HkdfSha384,
    #[alg_id(0x0003)]
    /// HKDF-SHA512.
    HkdfSha512,
    #[alg_id(Other)]
    /// Any other (non-zero) KDF ID without a named variant.
    Other(NonZeroU16),
}
245
246impl Display for KdfId {
247 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248 match self {
249 Self::HkdfSha256 => write!(f, "HkdfSha256"),
250 Self::HkdfSha384 => write!(f, "HkdfSha384"),
251 Self::HkdfSha512 => write!(f, "HkdfSha512"),
252 Self::Other(id) => write!(f, "Kdf({:#02x})", id),
253 }
254 }
255}
256
/// Identifies an AEAD algorithm by its two-byte ID.
///
/// IDs without a named variant decode to `AeadId::Other` (see the
/// `test_aead_id` tests).
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum AeadId {
    #[alg_id(0x0001)]
    /// AES-128-GCM.
    Aes128Gcm,
    #[alg_id(0x0002)]
    /// AES-256-GCM.
    Aes256Gcm,
    #[alg_id(0x0003)]
    /// ChaCha20-Poly1305.
    ChaCha20Poly1305,
    #[alg_id(0xfffd)]
    /// `Cmt1Aes256Gcm` (ID `0xfffd`).
    /// NOTE(review): presumably a key-committing AES-256-GCM variant — confirm.
    Cmt1Aes256Gcm,
    #[alg_id(0xfffe)]
    /// `Cmt4Aes256Gcm` (ID `0xfffe`).
    /// NOTE(review): presumably a key-committing AES-256-GCM variant — confirm.
    Cmt4Aes256Gcm,
    #[alg_id(Other)]
    /// Any other (non-zero) AEAD ID without a named variant.
    Other(NonZeroU16),
    #[alg_id(0xffff)]
    /// Export-only (ID `0xffff`).
    ExportOnly,
}
290
291impl Display for AeadId {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 match self {
294 Self::Aes128Gcm => write!(f, "Aes128Gcm"),
295 Self::Aes256Gcm => write!(f, "Aes256Gcm"),
296 Self::ChaCha20Poly1305 => write!(f, "ChaCha20Poly1305"),
297 Self::Cmt1Aes256Gcm => write!(f, "Cmt1Aes256Gcm"),
298 Self::Cmt4Aes256Gcm => write!(f, "Cmt4Aes256Gcm"),
299 Self::Other(id) => write!(f, "Aead({:#02x})", id),
300 Self::ExportOnly => write!(f, "ExportOnly"),
301 }
302 }
303}
304
/// Errors returned by HPKE setup, sealing, opening, and exporting.
#[derive(Debug, Eq, PartialEq)]
pub enum HpkeError {
    /// The AEAD failed to seal (encrypt) a message.
    Seal(SealError),
    /// The AEAD failed to open (decrypt) a message.
    Open(OpenError),
    /// A KDF operation failed (e.g. during the key schedule).
    Kdf(KdfError),
    /// A KEM operation (encapsulation or decapsulation) failed.
    Kem(KemError),
    /// Importing key material failed (e.g. building the AEAD key).
    Import(ImportError),
    /// Wraps an `ExportError`.
    Export(ExportError),
    /// The context's sequence number reached its limit, so no further
    /// nonces can be derived.
    MessageLimitReached,
    /// An internal invariant was violated.
    Bug(Bug),
}
326
327impl Display for HpkeError {
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 match self {
330 Self::Seal(err) => write!(f, "{}", err),
331 Self::Open(err) => write!(f, "{}", err),
332 Self::Kdf(err) => write!(f, "{}", err),
333 Self::Kem(err) => write!(f, "{}", err),
334 Self::Import(err) => write!(f, "{}", err),
335 Self::Export(err) => write!(f, "{}", err),
336 Self::MessageLimitReached => write!(f, "message limit reached"),
337 Self::Bug(err) => write!(f, "{err}"),
338 }
339 }
340}
341
342impl core::error::Error for HpkeError {
343 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
344 match self {
345 Self::Seal(err) => Some(err),
346 Self::Open(err) => Some(err),
347 Self::Kdf(err) => Some(err),
348 Self::Kem(err) => Some(err),
349 Self::Import(err) => Some(err),
350 Self::Export(err) => Some(err),
351 Self::MessageLimitReached => None,
352 Self::Bug(err) => Some(err),
353 }
354 }
355}
356
357impl From<SealError> for HpkeError {
358 fn from(err: SealError) -> Self {
359 Self::Seal(err)
360 }
361}
362
363impl From<OpenError> for HpkeError {
364 fn from(err: OpenError) -> Self {
365 Self::Open(err)
366 }
367}
368
369impl From<KdfError> for HpkeError {
370 fn from(err: KdfError) -> Self {
371 Self::Kdf(err)
372 }
373}
374
375impl From<KemError> for HpkeError {
376 fn from(err: KemError) -> Self {
377 Self::Kem(err)
378 }
379}
380
381impl From<ImportError> for HpkeError {
382 fn from(err: ImportError) -> Self {
383 Self::Import(err)
384 }
385}
386
387impl From<ExportError> for HpkeError {
388 fn from(err: ExportError) -> Self {
389 Self::Export(err)
390 }
391}
392
393impl From<Bug> for HpkeError {
394 fn from(err: Bug) -> Self {
395 Self::Bug(err)
396 }
397}
398
399impl From<MessageLimitReached> for HpkeError {
400 fn from(_err: MessageLimitReached) -> Self {
401 Self::MessageLimitReached
402 }
403}
404
/// An HPKE cipher suite built from a KEM `K`, a KDF `F`, and an AEAD `A`.
///
/// The type holds no data (only `PhantomData`). All functionality is in
/// associated functions such as `setup_send`,
/// `setup_send_deterministically`, and `setup_recv`.
pub struct Hpke<K, F, A> {
    _kem: PhantomData<K>,
    _kdf: PhantomData<F>,
    _aead: PhantomData<A>,
}
413
impl<K: Kem, F: Kdf, A: Aead + IndCca2> Hpke<K, F, A> {
    // The tuple return type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    /// Sets up a sending context for the recipient public key `pkR`.
    ///
    /// In `Mode::Auth`/`Mode::AuthPsk`, the sender's secret key `skS` is
    /// used via `K::auth_encap`. Otherwise `K::encap` is used. The resulting
    /// shared secret goes through the key schedule together with `info` and
    /// the mode's PSK.
    ///
    /// Returns the encapsulation (which the recipient passes to
    /// `setup_recv`) and the sending context.
    ///
    /// # Errors
    ///
    /// Returns an error if encapsulation or the key schedule fails.
    pub fn setup_send<R: Csprng>(
        rng: &mut R,
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: &[u8],
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => K::auth_encap::<R>(rng, pkR, skS)?,
            Mode::Base | Mode::Psk(_) => K::encap::<R>(rng, pkR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }

    // The tuple return type trips clippy's `type_complexity` lint.
    #[allow(clippy::type_complexity)]
    /// Like `setup_send`, but uses the caller-supplied ephemeral secret key
    /// `skE` instead of drawing randomness from an RNG.
    ///
    /// NOTE(review): presumably intended for known-answer tests; reusing
    /// `skE` across calls would be unsafe — confirm intended usage.
    ///
    /// # Errors
    ///
    /// Returns an error if encapsulation or the key schedule fails.
    pub fn setup_send_deterministically(
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: &[u8],
        skE: K::DecapKey,
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => {
                K::auth_encap_deterministically(pkR, skS, skE)?
            }
            Mode::Base | Mode::Psk(_) => K::encap_deterministically(pkR, skE)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }

    /// Sets up a receiving context from the sender's encapsulation `enc`
    /// and the recipient secret key `skR`.
    ///
    /// In `Mode::Auth`/`Mode::AuthPsk`, the sender's public key `pkS` is
    /// checked via `K::auth_decap`. Otherwise `K::decap` is used. `mode`
    /// and `info` must match the sender's for the contexts to agree.
    ///
    /// # Errors
    ///
    /// Returns an error if decapsulation or the key schedule fails.
    pub fn setup_recv(
        mode: Mode<'_, &K::EncapKey>,
        enc: &K::Encap,
        skR: &K::DecapKey,
        info: &[u8],
    ) -> Result<RecvCtx<K, F, A>, HpkeError> {
        let shared_secret = match mode {
            Mode::Auth(pkS) | Mode::AuthPsk(pkS, _) => K::auth_decap(enc, skR, pkS)?,
            Mode::Base | Mode::Psk(_) => K::decap(enc, skR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok(ctx.into_recv_ctx())
    }

    #[rustfmt::skip]
    /// Suite identifier: `"HPKE"` followed by the big-endian KEM, KDF, and
    /// AEAD IDs.
    ///
    /// Assumes each `ID` is a `u16`, so `to_be_bytes()` yields exactly the
    /// two bytes taken here.
    const HPKE_SUITE_ID: [u8; 10] = [
        b'H',
        b'P',
        b'K',
        b'E',
        i2osp!(K::ID)[0], i2osp!(K::ID)[1],
        i2osp!(F::ID)[0], i2osp!(F::ID)[1],
        i2osp!(A::ID)[0], i2osp!(A::ID)[1],
    ];

    /// Derives the AEAD key, base nonce, and exporter secret from the KEM
    /// shared secret. This corresponds to `KeySchedule` in RFC 9180,
    /// Section 5.1.
    ///
    /// Modes without a PSK use the empty default PSK (see `Mode::psk`).
    fn key_schedule<T>(
        mode: Mode<'_, T>,
        shared_secret: &K::Secret,
        info: &[u8],
    ) -> Result<Schedule<K, F, A>, HpkeError> {
        let Psk { psk, psk_id } = mode.psk();

        // psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
        let psk_id_hash = Self::labeled_extract(b"", "psk_id_hash", psk_id);

        // info_hash = LabeledExtract("", "info_hash", info)
        let info_hash = Self::labeled_extract(b"", "info_hash", info);

        // key_schedule_context = mode || psk_id_hash || info_hash, passed as
        // pieces rather than concatenated into a buffer.
        let ks_ctx = [&[mode.id()], psk_id_hash.as_bytes(), info_hash.as_bytes()];

        // secret = LabeledExtract(shared_secret, "secret", psk)
        let secret = Self::labeled_extract(shared_secret.as_ref(), "secret", psk);

        // key = LabeledExpand(secret, "key", key_schedule_context, Nk)
        let key = Self::labeled_expand(&secret, "key", &ks_ctx)?;

        // base_nonce = LabeledExpand(secret, "base_nonce", key_schedule_context, Nn)
        let base_nonce = Self::labeled_expand(&secret, "base_nonce", &ks_ctx)?;

        // exporter_secret = LabeledExpand(secret, "exp", key_schedule_context, Nh)
        let exporter_secret = Self::labeled_expand(&secret, "exp", &ks_ctx)?;

        Ok(Schedule {
            key,
            base_nonce,
            exporter_secret,
            _kem: PhantomData,
        })
    }

    /// Labeled-KDF context: domain `"HPKE-v1"` plus this suite's ID.
    const HPKE_CTX: Context = Context {
        domain: "HPKE-v1",
        suite_ids: &Self::HPKE_SUITE_ID,
    };

    /// `LabeledExtract(salt, label, ikm)` using `HPKE_CTX` and KDF `F`.
    fn labeled_extract(salt: &[u8], label: &'static str, ikm: &[u8]) -> Prk<F::PrkSize> {
        Self::HPKE_CTX.labeled_extract::<F>(salt, label, ikm)
    }

    /// `LabeledExpand(prk, label, info)` into a value of type `T`. The output
    /// length is determined by `T`.
    fn labeled_expand<T: Expand>(
        prk: &Prk<F::PrkSize>,
        label: &'static str,
        info: &[&[u8]],
    ) -> Result<T, KdfError> {
        let key = Self::HPKE_CTX.labeled_expand::<F, T>(prk, label, info)?;
        Ok(key)
    }

    /// `LabeledExpand(prk, label, info)` into the caller's buffer `out`. The
    /// output length is `out.len()`.
    fn labeled_expand_into(
        out: &mut [u8],
        prk: &Prk<F::PrkSize>,
        label: &'static str,
        info: &[&[u8]],
    ) -> Result<(), KdfError> {
        Self::HPKE_CTX.labeled_expand_into::<F>(out, prk, label, info)
    }
}
591
/// Output of `Hpke::key_schedule`: the secrets needed to build a sending or
/// receiving context.
struct Schedule<K: Kem, F: Kdf, A: Aead + IndCca2> {
    /// AEAD key bytes.
    key: KeyData<A>,
    /// Base nonce, XORed with the sequence number for each message.
    base_nonce: Nonce<A::NonceSize>,
    /// Secret used by the exporter interface.
    exporter_secret: Prk<F::PrkSize>,
    // Only carries the KEM type parameter.
    _kem: PhantomData<K>,
}
598
599impl<K: Kem, F: Kdf, A: Aead + IndCca2> Schedule<K, F, A> {
600 fn into_send_ctx(self) -> SendCtx<K, F, A> {
601 SendCtx {
602 seal: Either::Right((self.key, self.base_nonce)),
603 export: ExportCtx::new(self.exporter_secret),
604 }
605 }
606
607 fn into_recv_ctx(self) -> RecvCtx<K, F, A> {
608 RecvCtx {
609 open: Either::Right((self.key, self.base_nonce)),
610 export: ExportCtx::new(self.exporter_secret),
611 }
612 }
613}
614
/// A value that is either `Left(L)` or `Right(R)`.
///
/// `SendCtx` and `RecvCtx` use it to hold raw key material (`Right`) until
/// the AEAD context (`Left`) is first needed (see `get_or_insert_left`).
enum Either<L, R> {
    Left(L),
    Right(R),
}
620
621impl<L, R> Either<L, R> {
622 fn get_or_insert_left<F, E>(&mut self, f: F) -> Result<&mut L, E>
623 where
624 F: FnOnce(&R) -> Result<L, E>,
625 E: From<Bug>,
626 {
627 match self {
628 Self::Left(left) => Ok(left),
629 Self::Right(right) => {
630 *self = Self::Left(f(right)?);
631 match self {
632 Self::Left(left) => Ok(left),
633 Self::Right(_) => bug!("we just assigned `Self::Left`"),
634 }
635 }
636 }
637 }
638}
639
/// Raw AEAD key bytes and base nonce, held until an AEAD context is built
/// from them.
type RawKey<A> = (KeyData<A>, Nonce<<A as Aead>::NonceSize>);

/// Sender's HPKE context, created by `Hpke::setup_send`. Used to seal
/// messages and export secrets.
pub struct SendCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    // Starts as `Right(RawKey)` (see `Schedule::into_send_ctx`) and becomes
    // `Left(SealCtx)` on first use (see `SendCtx::seal_ctx`).
    seal: Either<SealCtx<A>, RawKey<A>>,
    export: ExportCtx<K, F, A>,
}
648
649impl<K: Kem, F: Kdf, A: Aead + IndCca2> SendCtx<K, F, A> {
650 pub const OVERHEAD: usize = SealCtx::<A>::OVERHEAD;
652
653 #[doc(hidden)]
655 pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
656 match self.seal {
657 Either::Left(_) => None,
658 Either::Right((key, base_nonce)) => Some((key, base_nonce)),
659 }
660 }
661
662 fn seal_ctx(&mut self) -> Result<&mut SealCtx<A>, ImportError> {
663 self.seal
664 .get_or_insert_left(|(key, nonce)| SealCtx::new(key, nonce, Seq::ZERO))
665 }
666
667 pub fn seal(
674 &mut self,
675 dst: &mut [u8],
676 plaintext: &[u8],
677 additional_data: &[u8],
678 ) -> Result<Seq, HpkeError> {
679 self.seal_ctx()?.seal(dst, plaintext, additional_data)
680 }
681
682 pub fn seal_in_place(
685 &mut self,
686 data: impl AsMut<[u8]>,
687 tag: &mut [u8],
688 additional_data: &[u8],
689 ) -> Result<Seq, HpkeError> {
690 self.seal_ctx()?.seal_in_place(data, tag, additional_data)
691 }
692
693 pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
695 where
696 T: Expand,
697 {
698 self.export.export(context)
699 }
700
701 pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
704 self.export.export_into(out, context)
705 }
706}
707
/// AEAD sealing state: the keyed cipher, the base nonce, and the next
/// sequence number.
pub struct SealCtx<A: Aead + IndCca2> {
    aead: A,
    base_nonce: Nonce<A::NonceSize>,
    // Sequence number for the next message. Advanced after each successful seal.
    seq: Seq,
}
718
719impl<A: Aead + IndCca2> SealCtx<A> {
720 pub const OVERHEAD: usize = A::OVERHEAD;
722
723 #[doc(hidden)]
725 pub fn new(
726 key: &KeyData<A>,
727 base_nonce: &Nonce<A::NonceSize>,
728 seq: Seq,
729 ) -> Result<Self, ImportError> {
730 let key = A::Key::import(key.as_bytes())?;
731 Ok(Self {
732 aead: A::new(&key),
733 base_nonce: base_nonce.clone(),
734 seq,
735 })
736 }
737
738 fn compute_nonce(&self) -> Result<Nonce<A::NonceSize>, MessageLimitReached> {
739 self.seq.compute_nonce::<A::NonceSize>(&self.base_nonce)
740 }
741
742 fn increment_seq(&mut self) -> Result<Seq, Bug> {
743 self.seq.increment::<A::NonceSize>()
744 }
745
746 pub fn seal(
753 &mut self,
754 dst: &mut [u8],
755 plaintext: &[u8],
756 additional_data: &[u8],
757 ) -> Result<Seq, HpkeError> {
758 let nonce = self.compute_nonce()?;
759 self.aead.seal(dst, &nonce, plaintext, additional_data)?;
760 let prev = self.increment_seq()?;
761 Ok(prev)
762 }
763
764 pub fn seal_in_place(
767 &mut self,
768 mut data: impl AsMut<[u8]>,
769 tag: &mut [u8],
770 additional_data: &[u8],
771 ) -> Result<Seq, HpkeError> {
772 let nonce = self.compute_nonce()?;
773 self.aead
774 .seal_in_place(&nonce, data.as_mut(), tag, additional_data)?;
775 let prev = self.increment_seq()?;
776 Ok(prev)
777 }
778
779 pub fn seq(&self) -> Seq {
781 self.seq
782 }
783}
784
/// Recipient's HPKE context, created by `Hpke::setup_recv`. Used to open
/// messages and export secrets.
pub struct RecvCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    // Starts as `Right(RawKey)` (see `Schedule::into_recv_ctx`) and becomes
    // `Left(OpenCtx)` on first use (see `RecvCtx::open_ctx`).
    open: Either<OpenCtx<A>, RawKey<A>>,
    export: ExportCtx<K, F, A>,
}
791
792impl<K: Kem, F: Kdf, A: Aead + IndCca2> RecvCtx<K, F, A> {
793 pub const OVERHEAD: usize = OpenCtx::<A>::OVERHEAD;
795
796 #[doc(hidden)]
798 pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
799 match self.open {
800 Either::Left(_) => None,
801 Either::Right((key, base_nonce)) => Some((key, base_nonce)),
802 }
803 }
804
805 fn open_ctx(&mut self) -> Result<&mut OpenCtx<A>, ImportError> {
806 self.open
807 .get_or_insert_left(|(key, nonce)| OpenCtx::new(key, nonce, Seq::ZERO))
808 }
809
810 pub fn open(
817 &mut self,
818 dst: &mut [u8],
819 ciphertext: &[u8],
820 additional_data: &[u8],
821 ) -> Result<(), HpkeError> {
822 self.open_ctx()?.open(dst, ciphertext, additional_data)
823 }
824
825 pub fn open_at(
832 &mut self,
833 dst: &mut [u8],
834 ciphertext: &[u8],
835 additional_data: &[u8],
836 seq: Seq,
837 ) -> Result<(), HpkeError> {
838 self.open_ctx()?
839 .open_at(dst, ciphertext, additional_data, seq)
840 }
841
842 pub fn open_in_place(
844 &mut self,
845 data: impl AsMut<[u8]>,
846 tag: &[u8],
847 additional_data: &[u8],
848 ) -> Result<(), HpkeError> {
849 self.open_ctx()?.open_in_place(data, tag, additional_data)
850 }
851
852 pub fn open_in_place_at(
855 &mut self,
856 data: impl AsMut<[u8]>,
857 tag: &[u8],
858 additional_data: &[u8],
859 seq: Seq,
860 ) -> Result<(), HpkeError> {
861 self.open_ctx()?
862 .open_in_place_at(data, tag, additional_data, seq)
863 }
864
865 pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
867 where
868 T: Expand,
869 {
870 self.export.export(context)
871 }
872
873 pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
876 self.export.export_into(out, context)
877 }
878}
879
/// AEAD opening state: the keyed cipher, the base nonce, and the next
/// sequence number.
pub struct OpenCtx<A: Aead + IndCca2> {
    aead: A,
    base_nonce: Nonce<A::NonceSize>,
    // Sequence number for the next in-order `open`. Explicit-`seq` methods
    // do not advance it.
    seq: Seq,
}
890
891impl<A: Aead + IndCca2> OpenCtx<A> {
892 pub const OVERHEAD: usize = A::OVERHEAD;
894
895 #[doc(hidden)]
897 pub fn new(
898 key: &KeyData<A>,
899 base_nonce: &Nonce<A::NonceSize>,
900 seq: Seq,
901 ) -> Result<Self, ImportError> {
902 let key = A::Key::import(key.as_bytes())?;
903 Ok(Self {
904 aead: A::new(&key),
905 base_nonce: base_nonce.clone(),
906 seq,
907 })
908 }
909
910 fn increment_seq(&mut self) -> Result<Seq, Bug> {
911 self.seq.increment::<A::NonceSize>()
912 }
913
914 pub fn open(
920 &mut self,
921 dst: &mut [u8],
922 ciphertext: &[u8],
923 additional_data: &[u8],
924 ) -> Result<(), HpkeError> {
925 self.open_at(dst, ciphertext, additional_data, self.seq)?;
926 self.increment_seq()?;
927 Ok(())
928 }
929
930 pub fn open_at(
937 &self,
938 dst: &mut [u8],
939 ciphertext: &[u8],
940 additional_data: &[u8],
941 seq: Seq,
942 ) -> Result<(), HpkeError> {
943 let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
944 self.aead.open(dst, &nonce, ciphertext, additional_data)?;
945 Ok(())
946 }
947
948 pub fn open_in_place(
950 &mut self,
951 mut data: impl AsMut<[u8]>,
952 tag: &[u8],
953 additional_data: &[u8],
954 ) -> Result<(), HpkeError> {
955 self.open_in_place_at(data.as_mut(), tag, additional_data, self.seq)?;
956 self.increment_seq()?;
957 Ok(())
958 }
959
960 pub fn open_in_place_at(
963 &self,
964 mut data: impl AsMut<[u8]>,
965 tag: &[u8],
966 additional_data: &[u8],
967 seq: Seq,
968 ) -> Result<(), HpkeError> {
969 let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
970 self.aead
971 .open_in_place(&nonce, data.as_mut(), tag, additional_data)?;
972 Ok(())
973 }
974}
975
/// Error returned when a sequence number has reached its limit, so no
/// further nonce can be derived from it.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct MessageLimitReached;

impl Display for MessageLimitReached {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "message limit reached")
    }
}

// No underlying cause, so the default `source` (`None`) is correct.
impl core::error::Error for MessageLimitReached {}
987
/// A message sequence number for a sealing or opening context.
///
/// Each nonce is derived by XORing the big-endian sequence number into the
/// base nonce (see `Seq::compute_nonce`). `Seq::increment` refuses to
/// advance past `Seq::max`.
#[derive(Copy, Clone, Debug, Default, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Seq {
    seq: u64,
}
1005
1006impl Seq {
1007 pub const ZERO: Self = Self::new(0);
1009
1010 #[inline]
1012 pub const fn new(seq: u64) -> Self {
1013 Self { seq }
1014 }
1015
1016 #[inline]
1018 pub const fn to_u64(self) -> u64 {
1019 self.seq
1020 }
1021
1022 #[doc(hidden)]
1026 pub const fn max<N: ArrayLength>() -> u64 {
1027 let shift = 8usize.saturating_mul(N::USIZE);
1029 match 1u64.checked_shl(shift as u32) {
1030 Some(v) => v.saturating_sub(1),
1031 None => u64::MAX,
1032 }
1033 }
1034
1035 fn increment<N: ArrayLength>(&mut self) -> Result<Self, Bug> {
1038 if self.seq >= Self::max::<N>() {
1041 bug!("`Seq::increment` called after limit reached");
1044 }
1045 let prev = self.seq;
1047 self.seq = prev
1048 .checked_add(1)
1049 .assume("`Seq` overflow should be impossible")?;
1050 Ok(Self { seq: prev })
1051 }
1052
1053 fn compute_nonce<N: ArrayLength>(
1055 self,
1056 base_nonce: &Nonce<N>,
1057 ) -> Result<Nonce<N>, MessageLimitReached> {
1058 if self.seq >= Self::max::<N>() {
1059 Err(MessageLimitReached)
1060 } else {
1061 let seq_bytes = i2osp!(self.seq, N);
1063 Ok(base_nonce ^ &Nonce::from_bytes(seq_bytes))
1065 }
1066 }
1067}
1068
1069impl Display for Seq {
1070 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1071 write!(f, "{}", self.seq)
1072 }
1073}
1074
/// Exporter state shared by `SendCtx` and `RecvCtx`.
struct ExportCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    /// The `exporter_secret` produced by the key schedule.
    exporter_secret: Prk<F::PrkSize>,
    // Carries `K` and `A` so the suite's labeled expansion can be named.
    _etc: PhantomData<(K, A)>,
}
1079
1080impl<K: Kem, F: Kdf, A: Aead + IndCca2> ExportCtx<K, F, A> {
1081 fn new(exporter_secret: Prk<F::PrkSize>) -> Self {
1082 Self {
1083 exporter_secret,
1084 _etc: PhantomData,
1085 }
1086 }
1087
1088 fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
1090 where
1091 T: Expand,
1092 {
1093 Hpke::<K, F, A>::labeled_expand(&self.exporter_secret, "sec", &[context])
1097 }
1098
1099 fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
1101 Hpke::<K, F, A>::labeled_expand_into(out, &self.exporter_secret, "sec", &[context])
1105 }
1106}
1107
#[cfg(test)]
mod tests {
    #![allow(clippy::panic)]

    use std::{collections::HashSet, ops::RangeInclusive};

    use postcard::experimental::max_size::MaxSize;
    use typenum::{U1, U2};

    use super::*;

    /// Checks `Seq::compute_nonce` against hand-computed XORs for a one-byte
    /// nonce, including the cut-off at `Seq::max::<U1>() == 255`.
    #[test]
    fn test_seq_compute_nonce() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");
        let cases = [
            (0, Ok(&[0xfe])),
            (1, Ok(&[0xff])),
            (2, Ok(&[0xfc])),
            (4, Ok(&[0xfa])),
            (254, Ok(&[0x00])),
            (255, Err(MessageLimitReached)),
            (256, Err(MessageLimitReached)),
            (257, Err(MessageLimitReached)),
            (u64::MAX, Err(MessageLimitReached)),
        ];
        for (input, output) in cases {
            let got = Seq::new(input).compute_nonce::<U1>(&base);
            let want = output.map(|s| Nonce::try_from_slice(s).expect("unable to create nonce"));
            assert_eq!(got, want, "seq = {input}");
        }
    }

    /// Every usable sequence number for a two-byte nonce (`0..u16::MAX`,
    /// since the limit is 65535) yields a distinct nonce.
    #[test]
    fn test_seq_unique_nonce() {
        let base =
            Nonce::<U2>::try_from_slice(&[0xfe, 0xfe]).expect("should be able to create nonce");
        let mut seen = HashSet::new();
        for v in 0..u16::MAX {
            let got = Seq::new(u64::from(v))
                .compute_nonce::<U2>(&base)
                .expect("unable to create nonce");
            assert!(seen.insert(got), "duplicate nonce: {got:?}");
        }
    }

    /// `Psk::new` rejects an empty PSK, an empty PSK ID, or both.
    #[test]
    fn test_invalid_psk() {
        let cases: [(&[u8], &[u8]); 3] = [(b"", b""), (b"psk", b""), (b"", b"psk_id")];
        for (psk, psk_id) in cases {
            let err = Psk::new(psk, psk_id).expect_err("should get `InvalidPsk`");
            assert_eq!(err, InvalidPsk);
        }
    }

    #[test]
    fn test_psk_ct_eq() {
        let cases = [
            (true, ("abc", "123"), ("abc", "123")),
            (false, ("a", "b"), ("a", "x")),
            (false, ("a", "b"), ("x", "b")),
            (false, ("a", "b"), ("c", "d")),
        ];
        for (pass, lhs, rhs) in cases {
            let lhs = Psk::new(lhs.0.as_bytes(), lhs.1.as_bytes()).expect("should not fail");
            let rhs = Psk::new(rhs.0.as_bytes(), rhs.1.as_bytes()).expect("should not fail");
            assert_eq!(pass, bool::from(lhs.ct_eq(&rhs)));
        }
    }

    /// Unassigned AEAD IDs decode (postcard) to `AeadId::Other`.
    #[test]
    fn test_aead_id() {
        // 0x0001..=0x0003 and 0xFFFD..=0xFFFF are assigned.
        let unassigned = 0x0004..=0xFFFC;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: AeadId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Unassigned AEAD IDs decode (JSON) to `AeadId::Other`.
    #[test]
    fn test_aead_id_json() {
        // 0x0001..=0x0003 and 0xFFFD..=0xFFFF are assigned.
        let unassigned = 0x0004..=0xFFFC;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = serde_json::to_string(&id).expect("should be able to encode `u16`");
            let got: AeadId = serde_json::from_str(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Unassigned KDF IDs decode (postcard) to `KdfId::Other`.
    #[test]
    fn test_kdf_id() {
        // 0x0001..=0x0003 are assigned.
        let unassigned = 0x0004..=0xFFFF;
        for id in unassigned {
            let want = KdfId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: KdfId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KdfId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Unassigned KEM IDs decode (postcard) to `KemId::Other`.
    #[test]
    fn test_kem_id() {
        // Assigned: 0x0010..=0x0016, 0x0020, 0x0021, 0x0030. The gap
        // 0x0017..=0x001F was previously not exercised.
        let unassigned: [RangeInclusive<u16>; 4] = [
            0x0001..=0x000F,
            0x0017..=0x001F,
            0x0022..=0x002F,
            0x0031..=0xFFFF,
        ];
        for id in unassigned.into_iter().flatten() {
            let want = KemId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: KemId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KemId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }
}