1#![forbid(unsafe_code)]
15#![allow(non_snake_case)]
17
18use core::{fmt, iter, marker::PhantomData, num::NonZeroU16, result::Result};
19
20use buggy::{bug, Bug, BugExt};
21use generic_array::ArrayLength;
22use subtle::{Choice, ConstantTimeEq};
23use typenum::Unsigned as _;
24
25use crate::{
26 aead::{Aead, IndCca2, KeyData, Nonce, OpenError, SealError},
27 csprng::Csprng,
28 import::{ExportError, Import as _, ImportError},
29 kdf::{Expand, Kdf, KdfError, Prk},
30 kem::{Kem, KemError},
31 keys::RawSecretBytes as _,
32 AlgId,
33};
34
/// Integer-to-octet-string primitive (`I2OSP`).
///
/// * `i2osp!(v)` yields the big-endian bytes of `v` at its natural width.
/// * `i2osp!(v, N)` yields a `GenericArray<u8, N>` holding the big-endian
///   bytes of `v`. When `N` is wider than `v`, the value is left-padded with
///   zero bytes. When `N` is narrower, only the `N` least-significant bytes
///   are kept.
macro_rules! i2osp {
    ($v:expr) => {
        $v.to_be_bytes()
    };
    ($v:expr, $n:ty) => {{
        let be = $v.to_be_bytes();
        let mut out = generic_array::GenericArray::<u8, $n>::default();
        if out.len() < be.len() {
            // Narrow output: keep only the trailing (least-significant) bytes.
            let skip = be.len() - out.len();
            out.copy_from_slice(&be[skip..]);
        } else {
            // Wide (or equal) output: zero-pad on the left.
            let pad = out.len() - be.len();
            out[pad..].copy_from_slice(&be);
        }
        out
    }};
}
57
/// The HPKE mode, which selects how the sender is authenticated.
///
/// `T` is the authenticating key: `Hpke::setup_send` takes the sender's
/// private key (`Mode<'_, &K::DecapKey>`), and `Hpke::setup_recv` takes the
/// sender's public key (`Mode<'_, &K::EncapKey>`).
#[derive(Debug)]
pub enum Mode<'a, T> {
    /// No sender authentication (`mode_base`, id `0x00`).
    Base,
    /// Authentication with a pre-shared key (`mode_psk`, id `0x01`).
    Psk(Psk<'a>),
    /// Authentication with the sender's asymmetric key (`mode_auth`,
    /// id `0x02`).
    Auth(T),
    /// Authentication with both the sender's asymmetric key and a
    /// pre-shared key (`mode_auth_psk`, id `0x03`).
    AuthPsk(T, Psk<'a>),
}
74
75impl<T> fmt::Display for Mode<'_, T> {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match self {
78 Self::Base => write!(f, "mode_base"),
79 Self::Psk(_) => write!(f, "mode_psk"),
80 Self::Auth(_) => write!(f, "mode_auth"),
81 Self::AuthPsk(_, _) => write!(f, "mode_auth_psk"),
82 }
83 }
84}
85
86impl<'a, T> Mode<'a, T> {
87 const DEFAULT_PSK: Psk<'static> = Psk {
90 psk: &[],
91 psk_id: &[],
92 };
93
94 pub const fn as_ref(&self) -> Mode<'_, &T> {
96 match *self {
97 Self::Base => Mode::Base,
98 Self::Psk(psk) => Mode::Psk(psk),
99 Self::Auth(ref k) => Mode::Auth(k),
100 Self::AuthPsk(ref k, psk) => Mode::AuthPsk(k, psk),
101 }
102 }
103
104 fn psk(&self) -> &Psk<'a> {
105 match self {
106 Mode::Psk(psk) => psk,
107 Mode::AuthPsk(_, psk) => psk,
108 _ => &Self::DEFAULT_PSK,
109 }
110 }
111
112 const fn id(&self) -> u8 {
113 match self {
114 Self::Base => 0x00,
115 Self::Psk(_) => 0x01,
116 Self::Auth(_) => 0x02,
117 Self::AuthPsk(_, _) => 0x03,
118 }
119 }
120}
121
/// Error returned by `Psk::new` when the PSK or its ID is empty.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct InvalidPsk;

impl fmt::Display for InvalidPsk {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid pre-shared key: PSK or PSK ID are empty")
    }
}

impl core::error::Error for InvalidPsk {}
133
134#[derive(Copy, Clone)]
136pub struct Psk<'a> {
137 psk: &'a [u8],
139 psk_id: &'a [u8],
141}
142
143impl<'a> Psk<'a> {
144 pub fn new(psk: &'a [u8], psk_id: &'a [u8]) -> Result<Self, InvalidPsk> {
146 if psk.is_empty() || psk_id.is_empty() {
148 Err(InvalidPsk)
149 } else {
150 Ok(Self { psk, psk_id })
151 }
152 }
153}
154
155impl ConstantTimeEq for Psk<'_> {
156 fn ct_eq(&self, other: &Self) -> Choice {
157 self.psk.ct_eq(other.psk) & self.psk_id.ct_eq(other.psk_id)
158 }
159}
160
161impl fmt::Debug for Psk<'_> {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("Psk")
164 .field("psk_id", &self.psk_id)
165 .finish_non_exhaustive()
166 }
167}
168
/// HPKE KEM identifiers (the `kem_id` component of the HPKE suite ID).
///
/// Each variant maps to the 16-bit identifier in its `alg_id` attribute. Any
/// other nonzero identifier decodes as [`KemId::Other`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KemId {
    /// DHKEM(P-256, HKDF-SHA256).
    #[alg_id(0x0010)]
    DhKemP256HkdfSha256,
    /// DHKEM(P-384, HKDF-SHA384).
    #[alg_id(0x0011)]
    DhKemP384HkdfSha384,
    /// DHKEM(P-521, HKDF-SHA512).
    #[alg_id(0x0012)]
    DhKemP521HkdfSha512,
    /// DHKEM(CP-256, HKDF-SHA256).
    #[alg_id(0x0013)]
    DhKemCp256HkdfSha256,
    /// DHKEM(CP-384, HKDF-SHA384).
    #[alg_id(0x0014)]
    DhKemCp384HkdfSha384,
    /// DHKEM(CP-521, HKDF-SHA512).
    #[alg_id(0x0015)]
    DhKemCp521HkdfSha512,
    /// DHKEM(secp256k1, HKDF-SHA256).
    #[alg_id(0x0016)]
    DhKemSecp256k1HkdfSha256,
    /// DHKEM(X25519, HKDF-SHA256).
    #[alg_id(0x0020)]
    DhKemX25519HkdfSha256,
    /// DHKEM(X448, HKDF-SHA512).
    #[alg_id(0x0021)]
    DhKemX448HkdfSha512,
    /// The X25519Kyber768Draft00 hybrid KEM.
    #[alg_id(0x0030)]
    X25519Kyber768Draft00,
    /// ML-KEM-512.
    // `0x040` is the same value as `0x0040`.
    #[alg_id(0x040)]
    MlKem512,
    /// ML-KEM-768.
    #[alg_id(0x041)]
    MlKem768,
    /// ML-KEM-1024.
    #[alg_id(0x042)]
    MlKem1024,
    /// X-Wing.
    #[alg_id(0x647a)]
    XWing,
    /// Any other (nonzero) KEM identifier not listed above.
    #[alg_id(Other)]
    Other(NonZeroU16),
}
222
223impl fmt::Display for KemId {
224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225 match self {
226 Self::DhKemP256HkdfSha256 => write!(f, "DHKEM(P-256, HKDF-SHA256)"),
227 Self::DhKemP384HkdfSha384 => write!(f, "DHKEM(P-384, HKDF-SHA384)"),
228 Self::DhKemP521HkdfSha512 => write!(f, "DHKEM(P-521, HKDF-SHA512)"),
229 Self::DhKemCp256HkdfSha256 => write!(f, "DHKEM(CP-256, HKDF-SHA256)"),
230 Self::DhKemCp384HkdfSha384 => write!(f, "DHKEM(CP-384, HKDF-SHA384)"),
231 Self::DhKemCp521HkdfSha512 => write!(f, "DHKEM(CP-521, HKDF-SHA512)"),
232 Self::DhKemSecp256k1HkdfSha256 => write!(f, "DHKEM(secp256k1, HKDF-SHA256)"),
233 Self::DhKemX25519HkdfSha256 => write!(f, "DHKEM(X25519, HKDF-SHA256)"),
234 Self::DhKemX448HkdfSha512 => write!(f, "DHKEM(X448, HKDF-SHA512)"),
235 Self::X25519Kyber768Draft00 => write!(f, "X25519Kyber768Draft00"),
236 Self::MlKem512 => write!(f, "ML-KEM-512"),
237 Self::MlKem768 => write!(f, "ML-KEM-768"),
238 Self::MlKem1024 => write!(f, "ML-KEM-1024"),
239 Self::XWing => write!(f, "X-Wing"),
240 Self::Other(id) => write!(f, "Kem({:#02x})", id),
241 }
242 }
243}
244
/// HPKE KDF identifiers (the `kdf_id` component of the HPKE suite ID).
///
/// Any nonzero identifier not listed here decodes as [`KdfId::Other`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KdfId {
    /// HKDF-SHA256.
    #[alg_id(0x0001)]
    HkdfSha256,
    /// HKDF-SHA384.
    #[alg_id(0x0002)]
    HkdfSha384,
    /// HKDF-SHA512.
    #[alg_id(0x0003)]
    HkdfSha512,
    /// Any other (nonzero) KDF identifier.
    #[alg_id(Other)]
    Other(NonZeroU16),
}
265
266impl fmt::Display for KdfId {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 match self {
269 Self::HkdfSha256 => write!(f, "HkdfSha256"),
270 Self::HkdfSha384 => write!(f, "HkdfSha384"),
271 Self::HkdfSha512 => write!(f, "HkdfSha512"),
272 Self::Other(id) => write!(f, "Kdf({:#02x})", id),
273 }
274 }
275}
276
/// HPKE AEAD identifiers (the `aead_id` component of the HPKE suite ID).
///
/// Any nonzero identifier other than those listed (including `0xffff`)
/// decodes as [`AeadId::Other`].
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum AeadId {
    /// AES-128-GCM.
    #[alg_id(0x0001)]
    Aes128Gcm,
    /// AES-256-GCM.
    #[alg_id(0x0002)]
    Aes256Gcm,
    /// ChaCha20-Poly1305.
    #[alg_id(0x0003)]
    ChaCha20Poly1305,
    /// Any other (nonzero) AEAD identifier.
    #[alg_id(Other)]
    Other(NonZeroU16),
    /// The export-only identifier `0xffff`, which RFC 9180 reserves for
    /// contexts that only export secrets.
    #[alg_id(0xffff)]
    ExportOnly,
}
300
301impl fmt::Display for AeadId {
302 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
303 match self {
304 Self::Aes128Gcm => write!(f, "Aes128Gcm"),
305 Self::Aes256Gcm => write!(f, "Aes256Gcm"),
306 Self::ChaCha20Poly1305 => write!(f, "ChaCha20Poly1305"),
307 Self::Other(id) => write!(f, "Aead({:#02x})", id),
308 Self::ExportOnly => write!(f, "ExportOnly"),
309 }
310 }
311}
312
/// A [`Kem`] usable with HPKE.
pub trait HpkeKem: Kem {
    /// The KEM's HPKE identifier, encoded into the HPKE suite ID.
    const ID: KemId;
}

/// A [`Kdf`] usable with HPKE.
pub trait HpkeKdf: Kdf {
    /// The KDF's HPKE identifier, encoded into the HPKE suite ID.
    const ID: KdfId;
}

/// An IND-CCA2 [`Aead`] usable with HPKE.
pub trait HpkeAead: Aead + IndCca2 {
    /// The AEAD's HPKE identifier, encoded into the HPKE suite ID.
    const ID: AeadId;
}
336
/// Errors returned by HPKE operations.
#[derive(Debug, Eq, PartialEq)]
pub enum HpkeError {
    /// An AEAD seal operation failed.
    Seal(SealError),
    /// An AEAD open operation failed.
    Open(OpenError),
    /// A KDF operation failed.
    Kdf(KdfError),
    /// A KEM operation failed.
    Kem(KemError),
    /// Importing key material failed.
    Import(ImportError),
    /// An export operation failed.
    Export(ExportError),
    /// The context's sequence number reached its limit, so no further
    /// nonces can be derived.
    MessageLimitReached,
    /// An internal invariant was violated.
    Bug(Bug),
}
358
359impl fmt::Display for HpkeError {
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 match self {
362 Self::Seal(err) => write!(f, "{}", err),
363 Self::Open(err) => write!(f, "{}", err),
364 Self::Kdf(err) => write!(f, "{}", err),
365 Self::Kem(err) => write!(f, "{}", err),
366 Self::Import(err) => write!(f, "{}", err),
367 Self::Export(err) => write!(f, "{}", err),
368 Self::MessageLimitReached => write!(f, "message limit reached"),
369 Self::Bug(err) => write!(f, "{err}"),
370 }
371 }
372}
373
374impl core::error::Error for HpkeError {
375 fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
376 match self {
377 Self::Seal(err) => Some(err),
378 Self::Open(err) => Some(err),
379 Self::Kdf(err) => Some(err),
380 Self::Kem(err) => Some(err),
381 Self::Import(err) => Some(err),
382 Self::Export(err) => Some(err),
383 Self::MessageLimitReached => None,
384 Self::Bug(err) => Some(err),
385 }
386 }
387}
388
389impl From<SealError> for HpkeError {
390 fn from(err: SealError) -> Self {
391 Self::Seal(err)
392 }
393}
394
395impl From<OpenError> for HpkeError {
396 fn from(err: OpenError) -> Self {
397 Self::Open(err)
398 }
399}
400
401impl From<KdfError> for HpkeError {
402 fn from(err: KdfError) -> Self {
403 Self::Kdf(err)
404 }
405}
406
407impl From<KemError> for HpkeError {
408 fn from(err: KemError) -> Self {
409 Self::Kem(err)
410 }
411}
412
413impl From<ImportError> for HpkeError {
414 fn from(err: ImportError) -> Self {
415 Self::Import(err)
416 }
417}
418
419impl From<ExportError> for HpkeError {
420 fn from(err: ExportError) -> Self {
421 Self::Export(err)
422 }
423}
424
425impl From<Bug> for HpkeError {
426 fn from(err: Bug) -> Self {
427 Self::Bug(err)
428 }
429}
430
431impl From<MessageLimitReached> for HpkeError {
432 fn from(_err: MessageLimitReached) -> Self {
433 Self::MessageLimitReached
434 }
435}
436
/// An HPKE cipher suite parameterized by its KEM (`K`), KDF (`F`), and AEAD
/// (`A`).
///
/// This type holds no data; its associated functions set up sending and
/// receiving contexts.
// `PhantomData<fn() -> X>` records the type parameters without storing
// values of them, and keeps `Hpke` `Send`/`Sync` regardless of `K`, `F`, `A`.
#[derive(Debug)]
pub struct Hpke<K, F, A> {
    _kem: PhantomData<fn() -> K>,
    _kdf: PhantomData<fn() -> F>,
    _aead: PhantomData<fn() -> A>,
}
446
impl<K, F, A> Hpke<K, F, A>
where
    K: HpkeKem,
    F: HpkeKdf,
    A: HpkeAead,
{
    /// Sets up a sending context for the recipient public key `pkR`.
    ///
    /// A shared secret is encapsulated using `rng`. In `Auth`/`AuthPsk`
    /// modes the encapsulation is authenticated with the sender private key
    /// carried by `mode`. The key schedule is then run over the shared
    /// secret, `mode`, and `info`, whose items are concatenated.
    ///
    /// Returns the encapsulation to send to the recipient along with the
    /// sender context.
    ///
    /// # Errors
    ///
    /// Returns an error if encapsulation or key derivation fails.
    #[allow(clippy::type_complexity)]
    pub fn setup_send<'a, R: Csprng>(
        rng: &mut R,
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: impl IntoIterator<Item = &'a [u8]>,
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => K::auth_encap::<R>(rng, pkR, skS)?,
            Mode::Base | Mode::Psk(_) => K::encap::<R>(rng, pkR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }

    /// Like [`Hpke::setup_send`], but uses the caller-supplied ephemeral
    /// private key `skE` instead of generating one from an RNG.
    ///
    /// The result is therefore fully determined by the inputs.
    ///
    /// # Errors
    ///
    /// Returns an error if encapsulation or key derivation fails.
    #[allow(clippy::type_complexity)]
    pub fn setup_send_deterministically<'a>(
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: impl IntoIterator<Item = &'a [u8]>,
        skE: K::DecapKey,
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => {
                K::auth_encap_deterministically(pkR, skS, skE)?
            }
            Mode::Base | Mode::Psk(_) => K::encap_deterministically(pkR, skE)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }

    /// Sets up a receiving context from the sender's encapsulation `enc`
    /// and the recipient private key `skR`.
    ///
    /// In `Auth`/`AuthPsk` modes, decapsulation is authenticated against the
    /// sender public key carried by `mode`. `mode` and `info` must match
    /// what the sender used for both sides to derive the same keys.
    ///
    /// # Errors
    ///
    /// Returns an error if decapsulation or key derivation fails.
    pub fn setup_recv<'a>(
        mode: Mode<'_, &K::EncapKey>,
        enc: &K::Encap,
        skR: &K::DecapKey,
        info: impl IntoIterator<Item = &'a [u8]>,
    ) -> Result<RecvCtx<K, F, A>, HpkeError> {
        let shared_secret = match mode {
            Mode::Auth(pkS) | Mode::AuthPsk(pkS, _) => K::auth_decap(enc, skR, pkS)?,
            Mode::Base | Mode::Psk(_) => K::decap(enc, skR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok(ctx.into_recv_ctx())
    }

    /// The HPKE suite ID:
    /// `"HPKE" || I2OSP(kem_id, 2) || I2OSP(kdf_id, 2) || I2OSP(aead_id, 2)`.
    // Indexing `[0]` and `[1]` assumes each ID's `to_be_bytes` yields the
    // two-byte big-endian encoding of its 16-bit identifier.
    #[rustfmt::skip]
    const HPKE_SUITE_ID: &[u8] = &[
        b'H',
        b'P',
        b'K',
        b'E',
        i2osp!(K::ID)[0], i2osp!(K::ID)[1],
        i2osp!(F::ID)[0], i2osp!(F::ID)[1],
        i2osp!(A::ID)[0], i2osp!(A::ID)[1],
    ];

    /// Version label prefixed to every labeled KDF input.
    const DOMAIN: &[u8] = b"HPKE-v1";

    /// Runs the HPKE key schedule (RFC 9180 §5.1), deriving the AEAD key,
    /// base nonce, and exporter secret from the KEM shared secret.
    ///
    /// Modes without a PSK use the empty default PSK and PSK ID (see
    /// `Mode::psk`). Each expanded output's length comes from its target
    /// type's `Expand::Size`.
    fn key_schedule<'a, T>(
        mode: Mode<'_, T>,
        shared_secret: &K::Secret,
        info: impl IntoIterator<Item = &'a [u8]>,
    ) -> Result<Schedule<K, F, A>, HpkeError> {
        let Psk { psk, psk_id } = mode.psk();

        // psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
        let psk_id_hash = Self::labeled_extract(b"", b"psk_id_hash", iter::once(psk_id).copied());

        // info_hash = LabeledExtract("", "info_hash", info)
        let info_hash = Self::labeled_extract(b"", b"info_hash", info);

        // key_schedule_context = mode || psk_id_hash || info_hash
        let ks_ctx = [&[mode.id()], psk_id_hash.as_bytes(), info_hash.as_bytes()];

        // secret = LabeledExtract(shared_secret, "secret", psk)
        let secret = Self::labeled_extract(
            shared_secret.raw_secret_bytes(),
            b"secret",
            iter::once(psk).copied(),
        );

        // key = LabeledExpand(secret, "key", key_schedule_context, Nk)
        let key = Self::labeled_expand(&secret, b"key", ks_ctx)?;

        // base_nonce = LabeledExpand(secret, "base_nonce", key_schedule_context, Nn)
        let base_nonce = Self::labeled_expand(&secret, b"base_nonce", ks_ctx)?;

        // exporter_secret = LabeledExpand(secret, "exp", key_schedule_context, Nh)
        let exporter_secret = Self::labeled_expand(&secret, b"exp", ks_ctx)?;

        Ok(Schedule {
            key,
            base_nonce,
            exporter_secret,
            _kem: PhantomData,
        })
    }

    /// `LabeledExtract(salt, label, ikm)`: runs `F`'s extract over
    /// `"HPKE-v1" || suite_id || label || ikm`, where the items of `ikm`
    /// are concatenated.
    fn labeled_extract<'a>(
        salt: &[u8],
        label: &'static [u8],
        ikm: impl IntoIterator<Item = &'a [u8]>,
    ) -> Prk<F::PrkSize> {
        let labeled_ikm = [Self::DOMAIN, Self::HPKE_SUITE_ID, label]
            .into_iter()
            .chain(ikm);
        F::extract_multi(labeled_ikm, salt)
    }

    /// `LabeledExpand(prk, label, info, L)` with `L = T::Size`: expands
    /// `prk` over `I2OSP(L, 2) || "HPKE-v1" || suite_id || label || info`.
    ///
    /// # Errors
    ///
    /// Propagates any [`KdfError`] from the expansion.
    fn labeled_expand<'a, T: Expand>(
        prk: &Prk<F::PrkSize>,
        label: &'static [u8],
        info: impl IntoIterator<Item = &'a [u8], IntoIter: Clone>,
    ) -> Result<T, KdfError> {
        // I2OSP(L, 2): the requested output length as two big-endian bytes.
        let size = T::Size::U16.to_be_bytes();
        let labeled_info = iter::once(size.as_slice())
            .chain(iter::once(Self::DOMAIN))
            .chain(iter::once(Self::HPKE_SUITE_ID))
            .chain(iter::once(label))
            .chain(
                // The identity `map` appears to re-borrow each `&'a [u8]`
                // at the shorter lifetime of the local `size` buffer so the
                // chained item types unify (hence the allowed lint).
                #[allow(clippy::map_identity)]
                info.into_iter().map(|v| v),
            );
        T::expand_multi::<F, _>(prk, labeled_info)
    }

    /// Like `labeled_expand`, but writes `L = out.len()` bytes into `out`.
    ///
    /// # Errors
    ///
    /// Returns [`KdfError::OutputTooLong`] if `out.len()` does not fit in a
    /// `u16`. Otherwise propagates any error from the expansion.
    fn labeled_expand_into<'a>(
        out: &mut [u8],
        prk: &Prk<F::PrkSize>,
        label: &'static [u8],
        info: impl IntoIterator<Item = &'a [u8], IntoIter: Clone>,
    ) -> Result<(), KdfError> {
        // I2OSP(L, 2), where L must fit in two bytes.
        let size = u16::try_from(out.len())
            .map_err(|_| KdfError::OutputTooLong)?
            .to_be_bytes();
        let labeled_info = iter::once(size.as_slice())
            .chain(iter::once(Self::DOMAIN))
            .chain(iter::once(Self::HPKE_SUITE_ID))
            .chain(iter::once(label))
            .chain(
                // See `labeled_expand` for why the identity `map` is here.
                #[allow(clippy::map_identity)]
                info.into_iter().map(|v| v),
            );
        F::expand_multi(out, prk, labeled_info)
    }
}
662
663#[derive(Debug)]
664struct Schedule<K, F, A>
665where
666 K: HpkeKem,
667 F: HpkeKdf,
668 A: HpkeAead,
669{
670 key: KeyData<A>,
671 base_nonce: Nonce<A::NonceSize>,
672 exporter_secret: Prk<F::PrkSize>,
673 _kem: PhantomData<fn() -> K>,
674}
675
676impl<K, F, A> Schedule<K, F, A>
677where
678 K: HpkeKem,
679 F: HpkeKdf,
680 A: HpkeAead,
681{
682 fn into_send_ctx(self) -> SendCtx<K, F, A> {
683 SendCtx {
684 seal: Either::Right((self.key, self.base_nonce)),
685 export: ExportCtx::new(self.exporter_secret),
686 }
687 }
688
689 fn into_recv_ctx(self) -> RecvCtx<K, F, A> {
690 RecvCtx {
691 open: Either::Right((self.key, self.base_nonce)),
692 export: ExportCtx::new(self.exporter_secret),
693 }
694 }
695}
696
697#[derive(Debug)]
699enum Either<L, R> {
700 Left(L),
701 Right(R),
702}
703
704impl<L, R> Either<L, R> {
705 fn get_or_insert_left<F, E>(&mut self, f: F) -> Result<&mut L, E>
706 where
707 F: FnOnce(&R) -> Result<L, E>,
708 E: From<Bug>,
709 {
710 match self {
711 Self::Left(left) => Ok(left),
712 Self::Right(right) => {
713 *self = Self::Left(f(right)?);
714 match self {
715 Self::Left(left) => Ok(left),
716 Self::Right(_) => bug!("we just assigned `Self::Left`"),
717 }
718 }
719 }
720 }
721}
722
/// Raw AEAD key material: the key and the base nonce.
///
/// `SendCtx`/`RecvCtx` hold this until their seal/open context is first
/// needed.
type RawKey<A> = (KeyData<A>, Nonce<<A as Aead>::NonceSize>);
724
725pub struct SendCtx<K, F, A>
728where
729 K: HpkeKem,
730 F: HpkeKdf,
731 A: HpkeAead,
732{
733 seal: Either<SealCtx<A>, RawKey<A>>,
734 export: ExportCtx<K, F, A>,
735}
736
737impl<K, F, A> SendCtx<K, F, A>
738where
739 K: HpkeKem,
740 F: HpkeKdf,
741 A: HpkeAead,
742{
743 pub const OVERHEAD: usize = SealCtx::<A>::OVERHEAD;
745
746 #[doc(hidden)]
748 pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
749 match self.seal {
750 Either::Left(_) => None,
751 Either::Right((key, base_nonce)) => Some((key, base_nonce)),
752 }
753 }
754
755 fn seal_ctx(&mut self) -> Result<&mut SealCtx<A>, ImportError> {
756 self.seal
757 .get_or_insert_left(|(key, nonce)| SealCtx::new(key, nonce, Seq::ZERO))
758 }
759
760 pub fn seal(
767 &mut self,
768 dst: &mut [u8],
769 plaintext: &[u8],
770 additional_data: &[u8],
771 ) -> Result<Seq, HpkeError> {
772 self.seal_ctx()?.seal(dst, plaintext, additional_data)
773 }
774
775 pub fn seal_in_place(
778 &mut self,
779 data: impl AsMut<[u8]>,
780 tag: &mut [u8],
781 additional_data: &[u8],
782 ) -> Result<Seq, HpkeError> {
783 self.seal_ctx()?.seal_in_place(data, tag, additional_data)
784 }
785
786 pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
788 where
789 T: Expand,
790 {
791 self.export.export(context)
792 }
793
794 pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
797 self.export.export_into(out, context)
798 }
799}
800
801impl<K, F, A> fmt::Debug for SendCtx<K, F, A>
802where
803 K: HpkeKem,
804 F: HpkeKdf,
805 A: HpkeAead + fmt::Debug,
806{
807 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
808 f.debug_struct("SendCtx")
809 .field("seal", &self.seal)
810 .field("export", &self.export)
811 .finish()
812 }
813}
814
815#[doc(hidden)]
820pub struct SealCtx<A: HpkeAead> {
821 aead: A,
822 base_nonce: Nonce<A::NonceSize>,
823 seq: Seq,
825}
826
827impl<A: HpkeAead> SealCtx<A> {
828 pub const OVERHEAD: usize = A::OVERHEAD;
830
831 #[doc(hidden)]
833 pub fn new(
834 key: &KeyData<A>,
835 base_nonce: &Nonce<A::NonceSize>,
836 seq: Seq,
837 ) -> Result<Self, ImportError> {
838 let key = A::Key::import(key.as_bytes())?;
839 Ok(Self {
840 aead: A::new(&key),
841 base_nonce: base_nonce.clone(),
842 seq,
843 })
844 }
845
846 fn compute_nonce(&self) -> Result<Nonce<A::NonceSize>, MessageLimitReached> {
847 self.seq.compute_nonce::<A::NonceSize>(&self.base_nonce)
848 }
849
850 fn increment_seq(&mut self) -> Result<Seq, Bug> {
851 self.seq.increment::<A::NonceSize>()
852 }
853
854 pub fn seal(
861 &mut self,
862 dst: &mut [u8],
863 plaintext: &[u8],
864 additional_data: &[u8],
865 ) -> Result<Seq, HpkeError> {
866 let nonce = self.compute_nonce()?;
867 self.aead.seal(dst, &nonce, plaintext, additional_data)?;
868 let prev = self.increment_seq()?;
869 Ok(prev)
870 }
871
872 pub fn seal_in_place(
875 &mut self,
876 mut data: impl AsMut<[u8]>,
877 tag: &mut [u8],
878 additional_data: &[u8],
879 ) -> Result<Seq, HpkeError> {
880 let nonce = self.compute_nonce()?;
881 self.aead
882 .seal_in_place(&nonce, data.as_mut(), tag, additional_data)?;
883 let prev = self.increment_seq()?;
884 Ok(prev)
885 }
886
887 pub fn seq(&self) -> Seq {
889 self.seq
890 }
891}
892
893impl<A: HpkeAead + fmt::Debug> fmt::Debug for SealCtx<A> {
894 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
895 f.debug_struct("SealCtx")
896 .field("aead", &self.aead)
897 .field("base_nonce", &self.base_nonce)
898 .field("seq", &self.seq)
899 .finish()
900 }
901}
902
903pub struct RecvCtx<K, F, A>
906where
907 K: HpkeKem,
908 F: HpkeKdf,
909 A: HpkeAead,
910{
911 open: Either<OpenCtx<A>, RawKey<A>>,
912 export: ExportCtx<K, F, A>,
913}
914
915impl<K, F, A> RecvCtx<K, F, A>
916where
917 K: HpkeKem,
918 F: HpkeKdf,
919 A: HpkeAead,
920{
921 pub const OVERHEAD: usize = OpenCtx::<A>::OVERHEAD;
923
924 #[doc(hidden)]
926 pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
927 match self.open {
928 Either::Left(_) => None,
929 Either::Right((key, base_nonce)) => Some((key, base_nonce)),
930 }
931 }
932
933 fn open_ctx(&mut self) -> Result<&mut OpenCtx<A>, ImportError> {
934 self.open
935 .get_or_insert_left(|(key, nonce)| OpenCtx::new(key, nonce, Seq::ZERO))
936 }
937
938 pub fn open(
945 &mut self,
946 dst: &mut [u8],
947 ciphertext: &[u8],
948 additional_data: &[u8],
949 ) -> Result<(), HpkeError> {
950 self.open_ctx()?.open(dst, ciphertext, additional_data)
951 }
952
953 pub fn open_at(
960 &mut self,
961 dst: &mut [u8],
962 ciphertext: &[u8],
963 additional_data: &[u8],
964 seq: Seq,
965 ) -> Result<(), HpkeError> {
966 self.open_ctx()?
967 .open_at(dst, ciphertext, additional_data, seq)
968 }
969
970 pub fn open_in_place(
972 &mut self,
973 data: impl AsMut<[u8]>,
974 tag: &[u8],
975 additional_data: &[u8],
976 ) -> Result<(), HpkeError> {
977 self.open_ctx()?.open_in_place(data, tag, additional_data)
978 }
979
980 pub fn open_in_place_at(
983 &mut self,
984 data: impl AsMut<[u8]>,
985 tag: &[u8],
986 additional_data: &[u8],
987 seq: Seq,
988 ) -> Result<(), HpkeError> {
989 self.open_ctx()?
990 .open_in_place_at(data, tag, additional_data, seq)
991 }
992
993 pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
995 where
996 T: Expand,
997 {
998 self.export.export(context)
999 }
1000
1001 pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
1004 self.export.export_into(out, context)
1005 }
1006}
1007
1008impl<K, F, A> fmt::Debug for RecvCtx<K, F, A>
1009where
1010 K: HpkeKem,
1011 F: HpkeKdf,
1012 A: HpkeAead + fmt::Debug,
1013{
1014 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1015 f.debug_struct("RecvCtx")
1016 .field("open", &self.open)
1017 .field("export", &self.export)
1018 .finish()
1019 }
1020}
1021
1022#[doc(hidden)]
1027pub struct OpenCtx<A: HpkeAead> {
1028 aead: A,
1029 base_nonce: Nonce<A::NonceSize>,
1030 seq: Seq,
1032}
1033
1034impl<A: HpkeAead> OpenCtx<A> {
1035 pub const OVERHEAD: usize = A::OVERHEAD;
1037
1038 #[doc(hidden)]
1040 pub fn new(
1041 key: &KeyData<A>,
1042 base_nonce: &Nonce<A::NonceSize>,
1043 seq: Seq,
1044 ) -> Result<Self, ImportError> {
1045 let key = A::Key::import(key.as_bytes())?;
1046 Ok(Self {
1047 aead: A::new(&key),
1048 base_nonce: base_nonce.clone(),
1049 seq,
1050 })
1051 }
1052
1053 fn increment_seq(&mut self) -> Result<Seq, Bug> {
1054 self.seq.increment::<A::NonceSize>()
1055 }
1056
1057 pub fn open(
1063 &mut self,
1064 dst: &mut [u8],
1065 ciphertext: &[u8],
1066 additional_data: &[u8],
1067 ) -> Result<(), HpkeError> {
1068 self.open_at(dst, ciphertext, additional_data, self.seq)?;
1069 self.increment_seq()?;
1070 Ok(())
1071 }
1072
1073 pub fn open_at(
1080 &self,
1081 dst: &mut [u8],
1082 ciphertext: &[u8],
1083 additional_data: &[u8],
1084 seq: Seq,
1085 ) -> Result<(), HpkeError> {
1086 let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
1087 self.aead.open(dst, &nonce, ciphertext, additional_data)?;
1088 Ok(())
1089 }
1090
1091 pub fn open_in_place(
1093 &mut self,
1094 mut data: impl AsMut<[u8]>,
1095 tag: &[u8],
1096 additional_data: &[u8],
1097 ) -> Result<(), HpkeError> {
1098 self.open_in_place_at(data.as_mut(), tag, additional_data, self.seq)?;
1099 self.increment_seq()?;
1100 Ok(())
1101 }
1102
1103 pub fn open_in_place_at(
1106 &self,
1107 mut data: impl AsMut<[u8]>,
1108 tag: &[u8],
1109 additional_data: &[u8],
1110 seq: Seq,
1111 ) -> Result<(), HpkeError> {
1112 let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
1113 self.aead
1114 .open_in_place(&nonce, data.as_mut(), tag, additional_data)?;
1115 Ok(())
1116 }
1117}
1118
1119impl<A: HpkeAead + fmt::Debug> fmt::Debug for OpenCtx<A> {
1120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1121 f.debug_struct("OpenCtx")
1122 .field("aead", &self.aead)
1123 .field("base_nonce", &self.base_nonce)
1124 .field("seq", &self.seq)
1125 .finish()
1126 }
1127}
1128
/// The sequence number has reached its limit, so no further unique nonce can
/// be derived.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct MessageLimitReached;

impl fmt::Display for MessageLimitReached {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "message limit reached")
    }
}

impl core::error::Error for MessageLimitReached {}
1140
1141#[derive(Copy, Clone, Debug, Default, Hash, Eq, PartialEq, Ord, PartialOrd)]
1143pub struct Seq {
1144 seq: u64,
1157}
1158
1159impl Seq {
1160 pub const ZERO: Self = Self::new(0);
1162
1163 #[inline]
1165 pub const fn new(seq: u64) -> Self {
1166 Self { seq }
1167 }
1168
1169 #[inline]
1171 pub const fn to_u64(self) -> u64 {
1172 self.seq
1173 }
1174
1175 #[doc(hidden)]
1179 pub const fn max<N: ArrayLength>() -> u64 {
1180 let shift = 8usize.saturating_mul(N::USIZE);
1182 match 1u64.checked_shl(shift as u32) {
1183 Some(v) => v.saturating_sub(1),
1184 None => u64::MAX,
1185 }
1186 }
1187
1188 fn increment<N: ArrayLength>(&mut self) -> Result<Self, Bug> {
1191 if self.seq >= Self::max::<N>() {
1194 bug!("`Seq::increment` called after limit reached");
1197 }
1198 let prev = self.seq;
1200 self.seq = prev
1201 .checked_add(1)
1202 .assume("`Seq` overflow should be impossible")?;
1203 Ok(Self { seq: prev })
1204 }
1205
1206 fn compute_nonce<N: ArrayLength>(
1208 self,
1209 base_nonce: &Nonce<N>,
1210 ) -> Result<Nonce<N>, MessageLimitReached> {
1211 if self.seq >= Self::max::<N>() {
1212 Err(MessageLimitReached)
1213 } else {
1214 let seq_bytes = i2osp!(self.seq, N);
1216 Ok(base_nonce ^ &Nonce::from_bytes(seq_bytes))
1218 }
1219 }
1220}
1221
1222impl fmt::Display for Seq {
1223 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1224 write!(f, "{}", self.seq)
1225 }
1226}
1227
1228struct ExportCtx<K, F, A>
1229where
1230 K: HpkeKem,
1231 F: HpkeKdf,
1232 A: HpkeAead,
1233{
1234 exporter_secret: Prk<F::PrkSize>,
1235 _etc: PhantomData<fn() -> (K, A)>,
1236}
1237
1238impl<K, F, A> ExportCtx<K, F, A>
1239where
1240 K: HpkeKem,
1241 F: HpkeKdf,
1242 A: HpkeAead,
1243{
1244 fn new(exporter_secret: Prk<F::PrkSize>) -> Self {
1245 Self {
1246 exporter_secret,
1247 _etc: PhantomData,
1248 }
1249 }
1250
1251 fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
1253 where
1254 T: Expand,
1255 {
1256 Hpke::<K, F, A>::labeled_expand(&self.exporter_secret, b"sec", [context])
1260 }
1261
1262 fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
1264 Hpke::<K, F, A>::labeled_expand_into(out, &self.exporter_secret, b"sec", [context])
1268 }
1269}
1270
1271impl<K, F, A> fmt::Debug for ExportCtx<K, F, A>
1272where
1273 K: HpkeKem,
1274 F: HpkeKdf,
1275 A: HpkeAead,
1276{
1277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1278 f.debug_struct("ExportCtx").finish_non_exhaustive()
1279 }
1280}
1281
#[cfg(test)]
mod tests {
    #![allow(clippy::panic)]

    use std::{collections::HashSet, ops::RangeInclusive};

    use typenum::{U1, U2};

    use super::*;

    /// Nonce derivation for a one-byte nonce, including the message limit.
    #[test]
    fn test_seq_compute_nonce() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");
        let cases = [
            (0, Ok(&[0xfe])),
            (1, Ok(&[0xff])),
            (2, Ok(&[0xfc])),
            (4, Ok(&[0xfa])),
            (254, Ok(&[0x00])),
            (255, Err(MessageLimitReached)),
            (256, Err(MessageLimitReached)),
            (257, Err(MessageLimitReached)),
            (u64::MAX, Err(MessageLimitReached)),
        ];
        for (input, output) in cases {
            let got = Seq::new(input).compute_nonce::<U1>(&base);
            let want = output.map(|s| Nonce::try_from_slice(s).expect("unable to create nonce"));
            assert_eq!(got, want, "seq = {input}");
        }
    }

    /// Every valid sequence number yields a distinct nonce.
    #[test]
    fn test_seq_unique_nonce() {
        let base =
            Nonce::<U2>::try_from_slice(&[0xfe, 0xfe]).expect("should be able to create nonce");
        let mut seen = HashSet::new();
        for v in 0..u16::MAX {
            let got = Seq::new(u64::from(v))
                .compute_nonce::<U2>(&base)
                .expect("unable to create nonce");
            assert!(seen.insert(got), "duplicate nonce: {got:?}");
        }
    }

    /// `Psk::new` rejects an empty PSK, an empty PSK ID, or both.
    #[test]
    fn test_invalid_psk() {
        let cases: [(&[u8], &[u8]); 3] = [(&[], &[]), (b"psk", &[]), (&[], b"id")];
        for (psk, psk_id) in cases {
            let err = Psk::new(psk, psk_id).expect_err("should get `InvalidPsk`");
            assert_eq!(err, InvalidPsk);
        }
    }

    #[test]
    fn test_psk_ct_eq() {
        let cases = [
            (true, ("abc", "123"), ("abc", "123")),
            (false, ("a", "b"), ("a", "x")),
            (false, ("a", "b"), ("x", "b")),
            (false, ("a", "b"), ("c", "d")),
            // Differing lengths must also compare unequal.
            (false, ("abc", "123"), ("abcd", "123")),
        ];
        for (pass, lhs, rhs) in cases {
            let lhs = Psk::new(lhs.0.as_bytes(), lhs.1.as_bytes()).expect("should not fail");
            let rhs = Psk::new(rhs.0.as_bytes(), rhs.1.as_bytes()).expect("should not fail");
            assert_eq!(pass, bool::from(lhs.ct_eq(&rhs)));
        }
    }

    /// Unassigned AEAD ids round-trip through `AeadId::Other`.
    #[test]
    fn test_aead_id() {
        // 0xFFFF is `ExportOnly`, so it is excluded.
        let unassigned = 0x0004..=0xFFFE;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = want.to_be_bytes();
            let got = AeadId::try_from_be_bytes(encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Unassigned KDF ids round-trip through `KdfId::Other`.
    #[test]
    fn test_kdf_id() {
        let unassigned = 0x0004..=0xFFFF;
        for id in unassigned {
            let want = KdfId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = want.to_be_bytes();
            let got = KdfId::try_from_be_bytes(encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KdfId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Unassigned KEM ids round-trip through `KemId::Other`.
    #[test]
    fn test_kem_id() {
        // Gaps between the assigned ids 0x0010-0x0016, 0x0020-0x0021, 0x0030,
        // 0x0040-0x0042, and 0x647a.
        let unassigned: [RangeInclusive<u16>; 6] = [
            0x0001..=0x000F,
            0x0017..=0x001F,
            0x0022..=0x002F,
            0x0031..=0x003F,
            0x0043..=0x6479,
            0x647b..=0xFFFF,
        ];
        for id in unassigned.into_iter().flatten() {
            let want = KemId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = want.to_be_bytes();
            let got = KemId::try_from_be_bytes(encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KemId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }
}