// aranya_crypto_core/hpke.rs

//! Hybrid Public Key Encryption per [RFC 9180].
//!
//! ## Notation
//!
//! - `sk`: a private key; shorthand for "*S*ecret *K*ey"
//! - `pk`: a public key; shorthand for "*P*ublic *K*ey"
//! - `skR`, `pkR`: a receiver's secret or public key
//! - `skS`, `pkS`: a sender's secret or public key
//! - `skE`, `pkE`: an ephemeral secret or public key
//! - `encap`, `decap`: see [Encapsulate](#Encapsulate).
//!
//! [RFC 9180]: https://www.rfc-editor.org/rfc/rfc9180.html

14#![forbid(unsafe_code)]
15// We use the same variable names used in the HPKE RFC.
16#![allow(non_snake_case)]
17
18use core::{
19    fmt::{self, Debug, Display},
20    marker::PhantomData,
21    num::NonZeroU16,
22    result::Result,
23};
24
25use aranya_buggy::{bug, Bug, BugExt};
26use generic_array::ArrayLength;
27use subtle::{Choice, ConstantTimeEq};
28
29use crate::{
30    aead::{Aead, IndCca2, KeyData, Nonce, OpenError, SealError},
31    csprng::Csprng,
32    import::{ExportError, Import, ImportError},
33    kdf::{Context, Expand, Kdf, KdfError, Prk},
34    kem::{Kem, KemError},
35    AlgId,
36};
37
/// Converts `v` to a big-endian byte array (RFC 9180's `I2OSP`).
///
/// - `i2osp!(v)` returns `v.to_be_bytes()` unchanged, so the
///   output length is the byte width of `v`'s type.
/// - `i2osp!(v, N)` returns a `GenericArray<u8, N>`. When `N` is
///   wider than `v`, the bytes are left-padded with zeros; when
///   it is narrower, only the `N` low-order bytes are kept.
///   NOTE(review): unlike RFC 8017's `I2OSP`, the narrowing case
///   silently truncates instead of failing.
macro_rules! i2osp {
    ($v:expr) => {
        $v.to_be_bytes()
    };
    ($v:expr, $n:ty) => {{
        let src = $v.to_be_bytes();
        let mut dst = generic_array::GenericArray::<u8, $n>::default();
        // Copy `src` into `dst`, padding with zeros on the
        // left.
        //
        // NB: the compiler knows how to optimize this. Don't
        // rewrite it without verifying the assembly.
        let idx = dst.len().abs_diff(src.len());
        if dst.len() >= src.len() {
            dst[idx..].copy_from_slice(&src);
        } else {
            // `dst` is narrower: keep the trailing (low-order)
            // bytes of the big-endian `src`.
            dst.copy_from_slice(&src[idx..]);
        }
        dst
    }};
}
60
/// An HPKE operation mode (RFC 9180, Section 5).
///
/// `T` is the key that authenticates the sender in
/// [`Mode::Auth`] and [`Mode::AuthPsk`]. [`Hpke::setup_send`]
/// takes the sender's private key here, while
/// [`Hpke::setup_recv`] takes the sender's public key.
#[cfg_attr(test, derive(Debug))]
pub enum Mode<'a, T> {
    /// The most basic operation mode.
    Base,
    /// Extends the base mode by allowing the recipient to
    /// authenticate that the sender possessed a particular
    /// pre-shared key.
    Psk(Psk<'a>),
    /// Extends the base mode by allowing the recipient to
    /// authenticate that the sender possessed a particular
    /// private key.
    Auth(T),
    /// A combination of [`Mode::Auth`] and [`Mode::Psk`].
    AuthPsk(T, Psk<'a>),
}
77
78impl<T> Display for Mode<'_, T> {
79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        match self {
81            Self::Base => write!(f, "mode_base"),
82            Self::Psk(_) => write!(f, "mode_psk"),
83            Self::Auth(_) => write!(f, "mode_auth"),
84            Self::AuthPsk(_, _) => write!(f, "mode_auth_psk"),
85        }
86    }
87}
88
89impl<'a, T> Mode<'a, T> {
90    // The default `psk` and `psk_id` are empty strings. See
91    // section 5.1.
92    const DEFAULT_PSK: Psk<'static> = Psk {
93        psk: &[],
94        psk_id: &[],
95    };
96
97    /// Converts from `Mode<'_, T>` to `Mode<'_, &T>`.
98    pub const fn as_ref(&self) -> Mode<'_, &T> {
99        match *self {
100            Self::Base => Mode::Base,
101            Self::Psk(psk) => Mode::Psk(psk),
102            Self::Auth(ref k) => Mode::Auth(k),
103            Self::AuthPsk(ref k, psk) => Mode::AuthPsk(k, psk),
104        }
105    }
106
107    fn psk(&self) -> &Psk<'a> {
108        match self {
109            Mode::Psk(psk) => psk,
110            Mode::AuthPsk(_, psk) => psk,
111            _ => &Self::DEFAULT_PSK,
112        }
113    }
114
115    const fn id(&self) -> u8 {
116        match self {
117            Self::Base => 0x00,
118            Self::Psk(_) => 0x01,
119            Self::Auth(_) => 0x02,
120            Self::AuthPsk(_, _) => 0x03,
121        }
122    }
123}
124
/// Returned by [`Psk::new`] when the PSK or its ID is empty.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct InvalidPsk;

impl Display for InvalidPsk {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid pre-shared key: PSK or PSK ID are empty")
    }
}

impl core::error::Error for InvalidPsk {}
136
137/// A pre-shared key and its ID.
138#[cfg_attr(test, derive(Debug))]
139#[derive(Copy, Clone)]
140pub struct Psk<'a> {
141    /// The pre-shared key.
142    psk: &'a [u8],
143    // The pre-shared key's ID.
144    psk_id: &'a [u8],
145}
146
147impl<'a> Psk<'a> {
148    /// Creates a [`Psk`] from a pre-shared key and its ID.
149    pub fn new(psk: &'a [u8], psk_id: &'a [u8]) -> Result<Self, InvalidPsk> {
150        // See Section 5.1, `VerifyPSKInputs`.
151        if psk.is_empty() || psk_id.is_empty() {
152            Err(InvalidPsk)
153        } else {
154            Ok(Self { psk, psk_id })
155        }
156    }
157}
158
159impl ConstantTimeEq for Psk<'_> {
160    fn ct_eq(&self, other: &Self) -> Choice {
161        self.psk.ct_eq(other.psk) & self.psk_id.ct_eq(other.psk_id)
162    }
163}
164
/// KEM algorithm identifiers per [IANA].
///
/// The `alg_id` attribute on each variant holds its two-byte
/// identifier. Identifiers without a dedicated variant are
/// represented by [`KemId::Other`].
///
/// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KemId {
    /// DHKEM(P-256, HKDF-SHA256).
    #[alg_id(0x0010)]
    DhKemP256HkdfSha256,
    /// DHKEM(P-384, HKDF-SHA384).
    #[alg_id(0x0011)]
    DhKemP384HkdfSha384,
    /// DHKEM(P-521, HKDF-SHA512).
    #[alg_id(0x0012)]
    DhKemP521HkdfSha512,
    /// DHKEM(CP-256, HKDF-SHA256).
    #[alg_id(0x0013)]
    DhKemCp256HkdfSha256,
    /// DHKEM(CP-384, HKDF-SHA384).
    #[alg_id(0x0014)]
    DhKemCp384HkdfSha384,
    /// DHKEM(CP-521, HKDF-SHA512).
    #[alg_id(0x0015)]
    DhKemCp521HkdfSha512,
    /// DHKEM(secp256k1, HKDF-SHA256).
    #[alg_id(0x0016)]
    DhKemSecp256k1HkdfSha256,
    /// DHKEM(X25519, HKDF-SHA256).
    #[alg_id(0x0020)]
    DhKemX25519HkdfSha256,
    /// DHKEM(X448, HKDF-SHA512).
    #[alg_id(0x0021)]
    DhKemX448HkdfSha512,
    /// X25519Kyber768Draft00.
    #[alg_id(0x0030)]
    X25519Kyber768Draft00,
    /// Some other KEM.
    ///
    /// Non-zero since 0x0000 is marked as 'reserved'.
    #[alg_id(Other)]
    Other(NonZeroU16),
}
206
207impl Display for KemId {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        match self {
210            Self::DhKemP256HkdfSha256 => write!(f, "DHKEM(P-256, HKDF-SHA256)"),
211            Self::DhKemP384HkdfSha384 => write!(f, "DHKEM(P-384, HKDF-SHA384)"),
212            Self::DhKemP521HkdfSha512 => write!(f, "DHKEM(P-521, HKDF-SHA512)"),
213            Self::DhKemCp256HkdfSha256 => write!(f, "DHKEM(CP-256, HKDF-SHA256)"),
214            Self::DhKemCp384HkdfSha384 => write!(f, "DHKEM(CP-384, HKDF-SHA384)"),
215            Self::DhKemCp521HkdfSha512 => write!(f, "DHKEM(CP-521, HKDF-SHA512)"),
216            Self::DhKemSecp256k1HkdfSha256 => write!(f, "DHKEM(secp256k1, HKDF-SHA256)"),
217            Self::DhKemX25519HkdfSha256 => write!(f, "DHKEM(X25519, HKDF-SHA256)"),
218            Self::DhKemX448HkdfSha512 => write!(f, "DHKEM(X448, HKDF-SHA512)"),
219            Self::X25519Kyber768Draft00 => write!(f, "X25519Kyber768Draft00"),
220            Self::Other(id) => write!(f, "Kem({:#02x})", id),
221        }
222    }
223}
224
/// KDF algorithm identifiers per [IANA].
///
/// The `alg_id` attribute on each variant holds its two-byte
/// identifier. Identifiers without a dedicated variant are
/// represented by [`KdfId::Other`].
///
/// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum KdfId {
    /// HKDF-SHA256.
    #[alg_id(0x0001)]
    HkdfSha256,
    /// HKDF-SHA384.
    #[alg_id(0x0002)]
    HkdfSha384,
    /// HKDF-SHA512.
    #[alg_id(0x0003)]
    HkdfSha512,
    /// Some other KDF.
    ///
    /// Non-zero since 0x0000 is marked as 'reserved'.
    #[alg_id(Other)]
    Other(NonZeroU16),
}
245
246impl Display for KdfId {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        match self {
249            Self::HkdfSha256 => write!(f, "HkdfSha256"),
250            Self::HkdfSha384 => write!(f, "HkdfSha384"),
251            Self::HkdfSha512 => write!(f, "HkdfSha512"),
252            Self::Other(id) => write!(f, "Kdf({:#02x})", id),
253        }
254    }
255}
256
/// AEAD algorithm identifiers per [IANA].
///
/// The `alg_id` attribute on each variant holds its two-byte
/// identifier. Identifiers without a dedicated variant are
/// represented by [`AeadId::Other`].
///
/// [IANA]: https://www.iana.org/assignments/hpke/hpke.xhtml
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq, AlgId)]
pub enum AeadId {
    /// AES-128-GCM.
    #[alg_id(0x0001)]
    Aes128Gcm,
    /// AES-256-GCM.
    #[alg_id(0x0002)]
    Aes256Gcm,
    /// ChaCha20Poly1305.
    #[alg_id(0x0003)]
    ChaCha20Poly1305,
    /// CMT-1 AES-256-GCM.
    ///
    /// Not an official RFC ID.
    #[alg_id(0xfffd)]
    Cmt1Aes256Gcm,
    /// CMT-4 AES-256-GCM.
    ///
    /// Not an official RFC ID.
    #[alg_id(0xfffe)]
    Cmt4Aes256Gcm,
    /// Some other AEAD.
    ///
    /// Non-zero since 0x0000 is marked as 'reserved'.
    #[alg_id(Other)]
    Other(NonZeroU16),
    /// Export-only AEAD (RFC 9180 reserves 0xffff for this).
    #[alg_id(0xffff)]
    ExportOnly,
}
290
291impl Display for AeadId {
292    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293        match self {
294            Self::Aes128Gcm => write!(f, "Aes128Gcm"),
295            Self::Aes256Gcm => write!(f, "Aes256Gcm"),
296            Self::ChaCha20Poly1305 => write!(f, "ChaCha20Poly1305"),
297            Self::Cmt1Aes256Gcm => write!(f, "Cmt1Aes256Gcm"),
298            Self::Cmt4Aes256Gcm => write!(f, "Cmt4Aes256Gcm"),
299            Self::Other(id) => write!(f, "Aead({:#02x})", id),
300            Self::ExportOnly => write!(f, "ExportOnly"),
301        }
302    }
303}
304
/// An error from an [`Hpke`] operation or one of its
/// encryption contexts.
///
/// Most variants wrap the error reported by the underlying
/// primitive (AEAD, KDF, KEM) or by key import/export.
#[derive(Debug, Eq, PartialEq)]
pub enum HpkeError {
    /// An AEAD seal operation failed.
    Seal(SealError),
    /// An AEAD open operation failed.
    Open(OpenError),
    /// A KDF operation failed.
    Kdf(KdfError),
    /// A KEM operation failed.
    Kem(KemError),
    /// A key could not be imported.
    Import(ImportError),
    /// A key could not be exported.
    Export(ExportError),
    /// The encryption context has been used to send the maximum
    /// number of messages.
    MessageLimitReached,
    /// An internal bug was discovered.
    Bug(Bug),
}
326
327impl Display for HpkeError {
328    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329        match self {
330            Self::Seal(err) => write!(f, "{}", err),
331            Self::Open(err) => write!(f, "{}", err),
332            Self::Kdf(err) => write!(f, "{}", err),
333            Self::Kem(err) => write!(f, "{}", err),
334            Self::Import(err) => write!(f, "{}", err),
335            Self::Export(err) => write!(f, "{}", err),
336            Self::MessageLimitReached => write!(f, "message limit reached"),
337            Self::Bug(err) => write!(f, "{err}"),
338        }
339    }
340}
341
342impl core::error::Error for HpkeError {
343    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
344        match self {
345            Self::Seal(err) => Some(err),
346            Self::Open(err) => Some(err),
347            Self::Kdf(err) => Some(err),
348            Self::Kem(err) => Some(err),
349            Self::Import(err) => Some(err),
350            Self::Export(err) => Some(err),
351            Self::MessageLimitReached => None,
352            Self::Bug(err) => Some(err),
353        }
354    }
355}
356
357impl From<SealError> for HpkeError {
358    fn from(err: SealError) -> Self {
359        Self::Seal(err)
360    }
361}
362
363impl From<OpenError> for HpkeError {
364    fn from(err: OpenError) -> Self {
365        Self::Open(err)
366    }
367}
368
369impl From<KdfError> for HpkeError {
370    fn from(err: KdfError) -> Self {
371        Self::Kdf(err)
372    }
373}
374
375impl From<KemError> for HpkeError {
376    fn from(err: KemError) -> Self {
377        Self::Kem(err)
378    }
379}
380
381impl From<ImportError> for HpkeError {
382    fn from(err: ImportError) -> Self {
383        Self::Import(err)
384    }
385}
386
387impl From<ExportError> for HpkeError {
388    fn from(err: ExportError) -> Self {
389        Self::Export(err)
390    }
391}
392
393impl From<Bug> for HpkeError {
394    fn from(err: Bug) -> Self {
395        Self::Bug(err)
396    }
397}
398
399impl From<MessageLimitReached> for HpkeError {
400    fn from(_err: MessageLimitReached) -> Self {
401        Self::MessageLimitReached
402    }
403}
404
/// Hybrid Public Key Encryption (HPKE) per [RFC 9180].
///
/// The fields are zero-sized markers. The type only records
/// the choice of KEM (`K`), KDF (`F`), and AEAD (`A`) used by
/// its associated functions.
///
/// [RFC 9180]: <https://www.rfc-editor.org/rfc/rfc9180.html>
pub struct Hpke<K, F, A> {
    _kem: PhantomData<K>,
    _kdf: PhantomData<F>,
    _aead: PhantomData<A>,
}
413
impl<K: Kem, F: Kdf, A: Aead + IndCca2> Hpke<K, F, A> {
    /// Creates a randomized encryption context for encrypting
    /// messages for the receiver, `pkR`.
    ///
    /// It returns the encryption context and an encapsulated
    /// symmetric key which can be used by the receiver to
    /// decrypt messages.
    ///
    /// The `info` parameter provides contextual binding.
    ///
    /// In [`Mode::Auth`] and [`Mode::AuthPsk`], the mode carries
    /// the sender's private key, `skS`, and the shared secret
    /// comes from [`Kem::auth_encap`]. Otherwise [`Kem::encap`]
    /// is used.
    ///
    /// # Errors
    ///
    /// Returns an [`HpkeError`] if encapsulation or the key
    /// schedule fails.
    // The tuple return type trips `clippy::type_complexity`.
    #[allow(clippy::type_complexity)]
    pub fn setup_send<R: Csprng>(
        rng: &mut R,
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: &[u8],
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => K::auth_encap::<R>(rng, pkR, skS)?,
            Mode::Base | Mode::Psk(_) => K::encap::<R>(rng, pkR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }

    /// Deterministically creates an encryption context for
    /// encrypting messages for the receiver, `pkR`.
    ///
    /// It returns the encryption context and an encapsulated
    /// symmetric key which can be used by the receiver to
    /// decrypt messages.
    ///
    /// The `info` parameter provides contextual binding.
    ///
    /// # Warning
    ///
    /// The security of this function relies on choosing the
    /// correct value for `skE`. It is a catastrophic error if
    /// you do not ensure all of the following properties:
    ///
    /// - it must be cryptographically secure
    /// - it must never be reused
    ///
    /// # Errors
    ///
    /// Returns an [`HpkeError`] if encapsulation or the key
    /// schedule fails.
    // The tuple return type trips `clippy::type_complexity`.
    #[allow(clippy::type_complexity)]
    pub fn setup_send_deterministically(
        mode: Mode<'_, &K::DecapKey>,
        pkR: &K::EncapKey,
        info: &[u8],
        skE: K::DecapKey,
    ) -> Result<(K::Encap, SendCtx<K, F, A>), HpkeError> {
        let (shared_secret, enc) = match mode {
            Mode::Auth(skS) | Mode::AuthPsk(skS, _) => {
                K::auth_encap_deterministically(pkR, skS, skE)?
            }
            Mode::Base | Mode::Psk(_) => K::encap_deterministically(pkR, skE)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok((enc, ctx.into_send_ctx()))
    }

    /// Creates an encryption context that can decrypt messages
    /// from a particular sender (the creator of `enc`).
    ///
    /// The `mode` and `info` parameters must be the same
    /// parameters used by the sender. In [`Mode::Auth`] and
    /// [`Mode::AuthPsk`], the mode carries the sender's public
    /// key, `pkS`.
    ///
    /// # Errors
    ///
    /// Returns an [`HpkeError`] if decapsulation or the key
    /// schedule fails.
    pub fn setup_recv(
        mode: Mode<'_, &K::EncapKey>,
        enc: &K::Encap,
        skR: &K::DecapKey,
        info: &[u8],
    ) -> Result<RecvCtx<K, F, A>, HpkeError> {
        let shared_secret = match mode {
            Mode::Auth(pkS) | Mode::AuthPsk(pkS, _) => K::auth_decap(enc, skR, pkS)?,
            Mode::Base | Mode::Psk(_) => K::decap(enc, skR)?,
        };
        let ctx = Self::key_schedule(mode, &shared_secret, info)?;
        Ok(ctx.into_recv_ctx())
    }

    /// The "HPKE" suite ID.
    ///
    /// ```text
    /// suite_id = concat(
    ///     "HPKE",
    ///     I2OSP(kem_id, 2),
    ///     I2OSP(kdf_id, 2),
    ///     I2OSP(aead_id, 2),
    /// )
    /// ```
    #[rustfmt::skip]
    const HPKE_SUITE_ID: [u8; 10] = [
        b'H',
        b'P',
        b'K',
        b'E',
        i2osp!(K::ID)[0], i2osp!(K::ID)[1],
        i2osp!(F::ID)[0], i2osp!(F::ID)[1],
        i2osp!(A::ID)[0], i2osp!(A::ID)[1],
    ];

    /// Runs `KeySchedule` from RFC 9180, Section 5.1.
    ///
    /// It derives the AEAD key, base nonce, and exporter secret
    /// from `shared_secret`, the mode's PSK, and `info`. Only the
    /// mode's ID and PSK are read; its key (`T`) is ignored.
    ///
    /// `VerifyPSKInputs` is enforced by construction rather than
    /// here: [`Psk::new`] rejects empty PSKs, and modes without
    /// a PSK fall back to the empty default.
    fn key_schedule<T>(
        mode: Mode<'_, T>,
        shared_secret: &K::Secret,
        info: &[u8],
    ) -> Result<Schedule<K, F, A>, HpkeError> {
        let Psk { psk, psk_id } = mode.psk();

        //  psk_id_hash = LabeledExtract("", "psk_id_hash", psk_id)
        let psk_id_hash = Self::labeled_extract(b"", "psk_id_hash", psk_id);

        //  info_hash = LabeledExtract("", "info_hash", info)
        let info_hash = Self::labeled_extract(b"", "info_hash", info);

        //  key_schedule_context = concat(mode, psk_id_hash, info_hash)
        let ks_ctx = [&[mode.id()], psk_id_hash.as_bytes(), info_hash.as_bytes()];

        //  secret = LabeledExtract(shared_secret, "secret", psk)
        let secret = Self::labeled_extract(shared_secret.as_ref(), "secret", psk);

        // key = LabeledExpand(secret, "key", key_schedule_context, Nk)
        let key = Self::labeled_expand(&secret, "key", &ks_ctx)?;

        // base_nonce = LabeledExpand(secret, "base_nonce",
        //                      key_schedule_context, Nn)
        let base_nonce = Self::labeled_expand(&secret, "base_nonce", &ks_ctx)?;

        // exporter_secret = LabeledExpand(secret, "exp",
        //                           key_schedule_context, Nh)
        let exporter_secret = Self::labeled_expand(&secret, "exp", &ks_ctx)?;

        Ok(Schedule {
            key,
            base_nonce,
            exporter_secret,
            _kem: PhantomData,
        })
    }

    /// The labeling context shared by `LabeledExtract` and
    /// `LabeledExpand`: the `"HPKE-v1"` domain followed by
    /// [`Self::HPKE_SUITE_ID`].
    const HPKE_CTX: Context = Context {
        domain: "HPKE-v1",
        suite_ids: &Self::HPKE_SUITE_ID,
    };

    /// Performs `LabeledExtract`.
    fn labeled_extract(salt: &[u8], label: &'static str, ikm: &[u8]) -> Prk<F::PrkSize> {
        // def LabeledExtract(salt, label, ikm):
        //     labeled_ikm = concat("HPKE-v1", suite_id, label, ikm)
        //     return Extract(salt, labeled_ikm)
        Self::HPKE_CTX.labeled_extract::<F>(salt, label, ikm)
    }

    /// Performs `LabeledExpand`, returning the output as a `T`.
    fn labeled_expand<T: Expand>(
        prk: &Prk<F::PrkSize>,
        label: &'static str,
        info: &[&[u8]],
    ) -> Result<T, KdfError> {
        // def LabeledExpand(prk, label, info, L):
        //     labeled_info = concat(I2OSP(L, 2), "HPKE-v1", suite_id,
        //                 label, info)
        //     return Expand(prk, labeled_info, L)
        let key = Self::HPKE_CTX.labeled_expand::<F, T>(prk, label, info)?;
        Ok(key)
    }

    /// Performs `LabeledExpand`, writing the output into `out`.
    ///
    /// NOTE(review): `L` is presumably `out.len()`; confirm
    /// against `Context::labeled_expand_into`.
    fn labeled_expand_into(
        out: &mut [u8],
        prk: &Prk<F::PrkSize>,
        label: &'static str,
        info: &[&[u8]],
    ) -> Result<(), KdfError> {
        // def LabeledExpand(prk, label, info, L):
        //     labeled_info = concat(I2OSP(L, 2), "HPKE-v1", suite_id,
        //                 label, info)
        //     return Expand(prk, labeled_info, L)
        Self::HPKE_CTX.labeled_expand_into::<F>(out, prk, label, info)
    }
}
591
/// The secrets derived by `Hpke::key_schedule`
/// (RFC 9180, Section 5.1).
struct Schedule<K: Kem, F: Kdf, A: Aead + IndCca2> {
    /// The AEAD key (`key`).
    key: KeyData<A>,
    /// The base nonce (`base_nonce`).
    base_nonce: Nonce<A::NonceSize>,
    /// The exporter secret (`exporter_secret`).
    exporter_secret: Prk<F::PrkSize>,
    _kem: PhantomData<K>,
}
598
599impl<K: Kem, F: Kdf, A: Aead + IndCca2> Schedule<K, F, A> {
600    fn into_send_ctx(self) -> SendCtx<K, F, A> {
601        SendCtx {
602            seal: Either::Right((self.key, self.base_nonce)),
603            export: ExportCtx::new(self.exporter_secret),
604        }
605    }
606
607    fn into_recv_ctx(self) -> RecvCtx<K, F, A> {
608        RecvCtx {
609            open: Either::Right((self.key, self.base_nonce)),
610            export: ExportCtx::new(self.exporter_secret),
611        }
612    }
613}
614
/// Either `L` or `R`.
///
/// [`SendCtx`] and [`RecvCtx`] use it to hold either a ready
/// AEAD context (`Left`) or the raw key material it is built
/// from (`Right`).
enum Either<L, R> {
    Left(L),
    Right(R),
}
620
621impl<L, R> Either<L, R> {
622    fn get_or_insert_left<F, E>(&mut self, f: F) -> Result<&mut L, E>
623    where
624        F: FnOnce(&R) -> Result<L, E>,
625        E: From<Bug>,
626    {
627        match self {
628            Self::Left(left) => Ok(left),
629            Self::Right(right) => {
630                *self = Self::Left(f(right)?);
631                match self {
632                    Self::Left(left) => Ok(left),
633                    Self::Right(_) => bug!("we just assigned `Self::Left`"),
634                }
635            }
636        }
637    }
638}
639
/// The raw AEAD key and base nonce, kept until the AEAD context
/// is first needed.
type RawKey<A> = (KeyData<A>, Nonce<<A as Aead>::NonceSize>);
641
/// An encryption context that encrypts messages for a particular
/// recipient.
///
/// Returned by [`Hpke::setup_send`] and
/// [`Hpke::setup_send_deterministically`].
pub struct SendCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    /// The sealing context, or the raw key material it is lazily
    /// built from.
    seal: Either<SealCtx<A>, RawKey<A>>,
    /// The secret-export context.
    export: ExportCtx<K, F, A>,
}
648
649impl<K: Kem, F: Kdf, A: Aead + IndCca2> SendCtx<K, F, A> {
650    /// The size in bytes of the overhead added to the plaintext.
651    pub const OVERHEAD: usize = SealCtx::<A>::OVERHEAD;
652
653    // Exposed for `aranya-crypto`, do not use.
654    #[doc(hidden)]
655    pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
656        match self.seal {
657            Either::Left(_) => None,
658            Either::Right((key, base_nonce)) => Some((key, base_nonce)),
659        }
660    }
661
662    fn seal_ctx(&mut self) -> Result<&mut SealCtx<A>, ImportError> {
663        self.seal
664            .get_or_insert_left(|(key, nonce)| SealCtx::new(key, nonce, Seq::ZERO))
665    }
666
667    /// Encrypts and authenticates `plaintext`, returning the
668    /// sequence number.
669    ///
670    /// The resulting ciphertext is written to `dst`, which must
671    /// be at least `plaintext.len()` + [`OVERHEAD`][Self::OVERHEAD]
672    /// bytes long.
673    pub fn seal(
674        &mut self,
675        dst: &mut [u8],
676        plaintext: &[u8],
677        additional_data: &[u8],
678    ) -> Result<Seq, HpkeError> {
679        self.seal_ctx()?.seal(dst, plaintext, additional_data)
680    }
681
682    /// Encrypts and authenticates `data` in-place, returning the
683    /// sequence number.
684    pub fn seal_in_place(
685        &mut self,
686        data: impl AsMut<[u8]>,
687        tag: &mut [u8],
688        additional_data: &[u8],
689    ) -> Result<Seq, HpkeError> {
690        self.seal_ctx()?.seal_in_place(data, tag, additional_data)
691    }
692
693    /// Exports a secret from the encryption context.
694    pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
695    where
696        T: Expand,
697    {
698        self.export.export(context)
699    }
700
701    /// Exports a secret from the encryption context, writing it
702    /// to `out`.
703    pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
704        self.export.export_into(out, context)
705    }
706}
707
/// An encryption context that can only encrypt messages for
/// a particular recipient.
///
/// Unlike [`SendCtx`], it cannot export secrets.
pub struct SealCtx<A: Aead + IndCca2> {
    /// The AEAD instance built from the context's key.
    aead: A,
    /// The base nonce. Each message's nonce is computed from it
    /// and the current sequence number.
    base_nonce: Nonce<A::NonceSize>,
    /// Incremented after each successful call to `seal` or
    /// `seal_in_place`.
    seq: Seq,
}
718
719impl<A: Aead + IndCca2> SealCtx<A> {
720    /// The size in bytes of the overhead added to the plaintext.
721    pub const OVERHEAD: usize = A::OVERHEAD;
722
723    // Exported for `aranya-crypto`. Do not use.
724    #[doc(hidden)]
725    pub fn new(
726        key: &KeyData<A>,
727        base_nonce: &Nonce<A::NonceSize>,
728        seq: Seq,
729    ) -> Result<Self, ImportError> {
730        let key = A::Key::import(key.as_bytes())?;
731        Ok(Self {
732            aead: A::new(&key),
733            base_nonce: base_nonce.clone(),
734            seq,
735        })
736    }
737
738    fn compute_nonce(&self) -> Result<Nonce<A::NonceSize>, MessageLimitReached> {
739        self.seq.compute_nonce::<A::NonceSize>(&self.base_nonce)
740    }
741
742    fn increment_seq(&mut self) -> Result<Seq, Bug> {
743        self.seq.increment::<A::NonceSize>()
744    }
745
746    /// Encrypts and authenticates `plaintext`, returning the
747    /// sequence number.
748    ///
749    /// The resulting ciphertext is written to `dst`, which must
750    /// be at least `plaintext.len()` + [`OVERHEAD`][Self::OVERHEAD]
751    /// bytes long.
752    pub fn seal(
753        &mut self,
754        dst: &mut [u8],
755        plaintext: &[u8],
756        additional_data: &[u8],
757    ) -> Result<Seq, HpkeError> {
758        let nonce = self.compute_nonce()?;
759        self.aead.seal(dst, &nonce, plaintext, additional_data)?;
760        let prev = self.increment_seq()?;
761        Ok(prev)
762    }
763
764    /// Encrypts and authenticates `data` in place, returning the
765    /// sequence number.
766    pub fn seal_in_place(
767        &mut self,
768        mut data: impl AsMut<[u8]>,
769        tag: &mut [u8],
770        additional_data: &[u8],
771    ) -> Result<Seq, HpkeError> {
772        let nonce = self.compute_nonce()?;
773        self.aead
774            .seal_in_place(&nonce, data.as_mut(), tag, additional_data)?;
775        let prev = self.increment_seq()?;
776        Ok(prev)
777    }
778
779    /// Returns the current sequence number.
780    pub fn seq(&self) -> Seq {
781        self.seq
782    }
783}
784
/// An encryption context that decrypts messages from
/// a particular sender.
///
/// Returned by [`Hpke::setup_recv`].
pub struct RecvCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
    /// The opening context, or the raw key material it is lazily
    /// built from.
    open: Either<OpenCtx<A>, RawKey<A>>,
    /// The secret-export context.
    export: ExportCtx<K, F, A>,
}
791
792impl<K: Kem, F: Kdf, A: Aead + IndCca2> RecvCtx<K, F, A> {
793    /// The size in bytes of the overhead added to the plaintext.
794    pub const OVERHEAD: usize = OpenCtx::<A>::OVERHEAD;
795
796    // Exposed for `aranya-crypto`, do not use.
797    #[doc(hidden)]
798    pub fn into_raw_parts(self) -> Option<(KeyData<A>, Nonce<A::NonceSize>)> {
799        match self.open {
800            Either::Left(_) => None,
801            Either::Right((key, base_nonce)) => Some((key, base_nonce)),
802        }
803    }
804
805    fn open_ctx(&mut self) -> Result<&mut OpenCtx<A>, ImportError> {
806        self.open
807            .get_or_insert_left(|(key, nonce)| OpenCtx::new(key, nonce, Seq::ZERO))
808    }
809
810    /// Decrypts and authenticates `ciphertext` using the
811    /// internal sequence number.
812    ///
813    /// The resulting plaintext is written to `dst`, which must
814    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
815    /// bytes long.
816    pub fn open(
817        &mut self,
818        dst: &mut [u8],
819        ciphertext: &[u8],
820        additional_data: &[u8],
821    ) -> Result<(), HpkeError> {
822        self.open_ctx()?.open(dst, ciphertext, additional_data)
823    }
824
825    /// Decrypts and authenticates `ciphertext` at a particular
826    /// sequence number.
827    ///
828    /// The resulting plaintext is written to `dst`, which must
829    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
830    /// bytes long.
831    pub fn open_at(
832        &mut self,
833        dst: &mut [u8],
834        ciphertext: &[u8],
835        additional_data: &[u8],
836        seq: Seq,
837    ) -> Result<(), HpkeError> {
838        self.open_ctx()?
839            .open_at(dst, ciphertext, additional_data, seq)
840    }
841
842    /// Decrypts and authenticates `ciphertext`.
843    pub fn open_in_place(
844        &mut self,
845        data: impl AsMut<[u8]>,
846        tag: &[u8],
847        additional_data: &[u8],
848    ) -> Result<(), HpkeError> {
849        self.open_ctx()?.open_in_place(data, tag, additional_data)
850    }
851
852    /// Decrypts and authenticates `ciphertext` at a particular
853    /// sequence number.
854    pub fn open_in_place_at(
855        &mut self,
856        data: impl AsMut<[u8]>,
857        tag: &[u8],
858        additional_data: &[u8],
859        seq: Seq,
860    ) -> Result<(), HpkeError> {
861        self.open_ctx()?
862            .open_in_place_at(data, tag, additional_data, seq)
863    }
864
865    /// Exports a secret from the encryption context.
866    pub fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
867    where
868        T: Expand,
869    {
870        self.export.export(context)
871    }
872
873    /// Exports a secret from the encryption context, writing it
874    /// to `out`.
875    pub fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
876        self.export.export_into(out, context)
877    }
878}
879
/// An encryption context that can only decrypt messages from
/// a particular sender.
///
/// Unlike [`RecvCtx`], it cannot export secrets.
pub struct OpenCtx<A: Aead + IndCca2> {
    /// The AEAD instance built from the context's key.
    aead: A,
    /// The base nonce. Each message's nonce is computed from it
    /// and a sequence number.
    base_nonce: Nonce<A::NonceSize>,
    /// Incremented after each successful call to `open` or
    /// `open_in_place`. The `*_at` methods leave it unchanged.
    seq: Seq,
}
890
891impl<A: Aead + IndCca2> OpenCtx<A> {
892    /// The size in bytes of the overhead added to the plaintext.
893    pub const OVERHEAD: usize = A::OVERHEAD;
894
895    // Exported for `aranya-crypto`. Do not use.
896    #[doc(hidden)]
897    pub fn new(
898        key: &KeyData<A>,
899        base_nonce: &Nonce<A::NonceSize>,
900        seq: Seq,
901    ) -> Result<Self, ImportError> {
902        let key = A::Key::import(key.as_bytes())?;
903        Ok(Self {
904            aead: A::new(&key),
905            base_nonce: base_nonce.clone(),
906            seq,
907        })
908    }
909
910    fn increment_seq(&mut self) -> Result<Seq, Bug> {
911        self.seq.increment::<A::NonceSize>()
912    }
913
914    /// Decrypts and authenticates `ciphertext`.
915    ///
916    /// The resulting plaintext is written to `dst`, which must
917    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
918    /// bytes long.
919    pub fn open(
920        &mut self,
921        dst: &mut [u8],
922        ciphertext: &[u8],
923        additional_data: &[u8],
924    ) -> Result<(), HpkeError> {
925        self.open_at(dst, ciphertext, additional_data, self.seq)?;
926        self.increment_seq()?;
927        Ok(())
928    }
929
930    /// Decrypts and authenticates `ciphertext` at a particular
931    /// sequence number.
932    ///
933    /// The resulting plaintext is written to `dst`, which must
934    /// must be at least `ciphertext.len()` - [`Self::OVERHEAD`]
935    /// bytes long.
936    pub fn open_at(
937        &self,
938        dst: &mut [u8],
939        ciphertext: &[u8],
940        additional_data: &[u8],
941        seq: Seq,
942    ) -> Result<(), HpkeError> {
943        let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
944        self.aead.open(dst, &nonce, ciphertext, additional_data)?;
945        Ok(())
946    }
947
948    /// Decrypts and authenticates `ciphertext`.
949    pub fn open_in_place(
950        &mut self,
951        mut data: impl AsMut<[u8]>,
952        tag: &[u8],
953        additional_data: &[u8],
954    ) -> Result<(), HpkeError> {
955        self.open_in_place_at(data.as_mut(), tag, additional_data, self.seq)?;
956        self.increment_seq()?;
957        Ok(())
958    }
959
960    /// Decrypts and authenticates `ciphertext` at a particular
961    /// sequence number.
962    pub fn open_in_place_at(
963        &self,
964        mut data: impl AsMut<[u8]>,
965        tag: &[u8],
966        additional_data: &[u8],
967        seq: Seq,
968    ) -> Result<(), HpkeError> {
969        let nonce = seq.compute_nonce::<A::NonceSize>(&self.base_nonce)?;
970        self.aead
971            .open_in_place(&nonce, data.as_mut(), tag, additional_data)?;
972        Ok(())
973    }
974}
975
/// Error returned once HPKE's per-context message limit has
/// been reached and no further nonces can be derived.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct MessageLimitReached;

impl Display for MessageLimitReached {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "message limit reached")
    }
}

impl core::error::Error for MessageLimitReached {}
987
/// A per-context message counter; each value yields a distinct
/// nonce, which is what keeps nonces unique.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Seq {
    /// The current sequence number.
    ///
    /// To derive a nonce, it is encoded as a big-endian integer
    /// (I2OSP) and XORed with the `base_nonce`.
    ///
    /// Strictly speaking, this should be as wide as the nonce.
    /// A `u64` suffices in practice: encryption contexts
    /// ([`SealCtx`], etc.) can only be used serially, so an
    /// overflow requires really performing 2^64-1 operations.
    /// Even at an impossible one nanosecond per encryption,
    /// that would take upward of 500 years.
    seq: u64,
}
1005
1006impl Seq {
1007    /// The zero value of a `Seq`.
1008    pub const ZERO: Self = Self::new(0);
1009
1010    /// Creates a sequence number.
1011    #[inline]
1012    pub const fn new(seq: u64) -> Self {
1013        Self { seq }
1014    }
1015
1016    /// Converts itself to a `u64`.
1017    #[inline]
1018    pub const fn to_u64(self) -> u64 {
1019        self.seq
1020    }
1021
1022    /// Returns the maximum allowed sequence number.
1023    ///
1024    /// Exported for `aranya-crypto`. Do not use.
1025    #[doc(hidden)]
1026    pub const fn max<N: ArrayLength>() -> u64 {
1027        // 1<<(8*N) - 1
1028        let shift = 8usize.saturating_mul(N::USIZE);
1029        match 1u64.checked_shl(shift as u32) {
1030            Some(v) => v.saturating_sub(1),
1031            None => u64::MAX,
1032        }
1033    }
1034
1035    /// Increments the sequence by one and returns the *previous*
1036    /// sequence number.
1037    fn increment<N: ArrayLength>(&mut self) -> Result<Self, Bug> {
1038        // if self.seq >= (1 << (8*Nn)) - 1:
1039        //     raise MessageLimitReachedError
1040        if self.seq >= Self::max::<N>() {
1041            // We only call `Seq::increment` after computing the
1042            // nonce, which requires `seq < Self::max`.
1043            bug!("`Seq::increment` called after limit reached");
1044        }
1045        // self.seq += 1
1046        let prev = self.seq;
1047        self.seq = prev
1048            .checked_add(1)
1049            .assume("`Seq` overflow should be impossible")?;
1050        Ok(Self { seq: prev })
1051    }
1052
1053    /// Computes the per-message nonce.
1054    fn compute_nonce<N: ArrayLength>(
1055        self,
1056        base_nonce: &Nonce<N>,
1057    ) -> Result<Nonce<N>, MessageLimitReached> {
1058        if self.seq >= Self::max::<N>() {
1059            Err(MessageLimitReached)
1060        } else {
1061            //  seq_bytes = I2OSP(seq, Nn)
1062            let seq_bytes = i2osp!(self.seq, N);
1063            // xor(self.base_nonce, seq_bytes)
1064            Ok(base_nonce ^ &Nonce::from_bytes(seq_bytes))
1065        }
1066    }
1067}
1068
1069impl Display for Seq {
1070    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1071        write!(f, "{}", self.seq)
1072    }
1073}
1074
1075struct ExportCtx<K: Kem, F: Kdf, A: Aead + IndCca2> {
1076    exporter_secret: Prk<F::PrkSize>,
1077    _etc: PhantomData<(K, A)>,
1078}
1079
1080impl<K: Kem, F: Kdf, A: Aead + IndCca2> ExportCtx<K, F, A> {
1081    fn new(exporter_secret: Prk<F::PrkSize>) -> Self {
1082        Self {
1083            exporter_secret,
1084            _etc: PhantomData,
1085        }
1086    }
1087
1088    /// Exports a secret from the context.
1089    fn export<T>(&self, context: &[u8]) -> Result<T, KdfError>
1090    where
1091        T: Expand,
1092    {
1093        // def Context.Export(exporter_context, L):
1094        //   return LabeledExpand(self.exporter_secret, "sec",
1095        //                        exporter_context, L)
1096        Hpke::<K, F, A>::labeled_expand(&self.exporter_secret, "sec", &[context])
1097    }
1098
1099    /// Exports a secret from the context, writing it to `out`.
1100    fn export_into(&self, out: &mut [u8], context: &[u8]) -> Result<(), KdfError> {
1101        // def Context.Export(exporter_context, L):
1102        //   return LabeledExpand(self.exporter_secret, "sec",
1103        //                        exporter_context, L)
1104        Hpke::<K, F, A>::labeled_expand_into(out, &self.exporter_secret, "sec", &[context])
1105    }
1106}
1107
#[cfg(test)]
mod tests {
    #![allow(clippy::panic)]

    use std::{collections::HashSet, ops::RangeInclusive};

    use postcard::experimental::max_size::MaxSize;
    use typenum::{U1, U2};

    use super::*;

    /// Tests that [`Seq::compute_nonce`] generates correct
    /// nonces.
    #[test]
    fn test_seq_compute_nonce() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");
        let cases = [
            (0, Ok(&[0xfe])),
            (1, Ok(&[0xff])),
            (2, Ok(&[0xfc])),
            (4, Ok(&[0xfa])),
            (254, Ok(&[0x00])),
            (255, Err(MessageLimitReached)),
            (256, Err(MessageLimitReached)),
            (257, Err(MessageLimitReached)),
            (u64::MAX, Err(MessageLimitReached)),
        ];
        for (input, output) in cases {
            let got = Seq::new(input).compute_nonce::<U1>(&base);
            let want = output.map(|s| Nonce::try_from_slice(s).expect("unable to create nonce"));
            assert_eq!(got, want, "seq = {input}");
        }
    }

    /// Tests that all nonces are unique.
    #[test]
    fn test_seq_unique_nonce() {
        let base =
            Nonce::<U2>::try_from_slice(&[0xfe, 0xfe]).expect("should be able to create nonce");
        let mut seen = HashSet::new();
        for v in 0..u16::MAX {
            let got = Seq::new(u64::from(v))
                .compute_nonce::<U2>(&base)
                .expect("unable to create nonce");
            assert!(seen.insert(got), "duplicate nonce: {got:?}");
        }
    }

    /// Tests that [`Seq::max`] is `(1 << (8*N)) - 1`, clamped
    /// to `u64::MAX` once the nonce is 8 bytes or wider.
    #[test]
    fn test_seq_max() {
        assert_eq!(Seq::max::<U1>(), 0xFF);
        assert_eq!(Seq::max::<U2>(), 0xFFFF);
        assert_eq!(Seq::max::<typenum::U7>(), 0x00FF_FFFF_FFFF_FFFF);
        assert_eq!(Seq::max::<typenum::U8>(), u64::MAX);
        assert_eq!(Seq::max::<typenum::U12>(), u64::MAX);
    }

    /// Tests that [`Seq::increment`] returns the previous
    /// sequence number and advances by exactly one, including
    /// at the last valid sequence number.
    #[test]
    fn test_seq_increment() {
        let base = Nonce::<U1>::try_from_slice(&[0xfe]).expect("should be able to create nonce");

        let mut seq = Seq::ZERO;
        let prev = seq
            .increment::<U1>()
            .expect("should not reach the limit");
        assert_eq!(prev, Seq::ZERO);
        assert_eq!(seq, Seq::new(1));

        // 254 is the last valid sequence number for a 1-byte
        // nonce; after it, nonce computation must fail.
        let mut seq = Seq::new(254);
        assert!(seq.compute_nonce::<U1>(&base).is_ok());
        let prev = seq
            .increment::<U1>()
            .expect("should not reach the limit");
        assert_eq!(prev, Seq::new(254));
        assert_eq!(seq, Seq::new(255));
        assert_eq!(seq.compute_nonce::<U1>(&base), Err(MessageLimitReached));
    }

    #[test]
    fn test_invalid_psk() {
        let err = Psk::new(&[], &[]).expect_err("should get `InvalidPsk`");
        assert_eq!(err, InvalidPsk);
    }

    #[test]
    fn test_psk_ct_eq() {
        let cases = [
            (true, ("abc", "123"), ("abc", "123")),
            (false, ("a", "b"), ("a", "x")),
            (false, ("a", "b"), ("x", "b")),
            (false, ("a", "b"), ("c", "d")),
        ];
        for (pass, lhs, rhs) in cases {
            let lhs = Psk::new(lhs.0.as_bytes(), lhs.1.as_bytes()).expect("should not fail");
            let rhs = Psk::new(rhs.0.as_bytes(), rhs.1.as_bytes()).expect("should not fail");
            assert_eq!(pass, bool::from(lhs.ct_eq(&rhs)));
        }
    }

    /// Tests that [`AeadId`] is assigned correctly.
    #[test]
    fn test_aead_id() {
        // NB: the two IDs just below 0xFFFF are unofficial IDs
        // that `AeadId` assigns, so the unassigned range stops
        // at 0xFFFE - 2.
        let unassigned = 0x0004..=0xFFFE - 2;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: AeadId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Tests that [`AeadId`] can be serialized and deserialized via [`serde_json`].
    #[test]
    fn test_aead_id_json() {
        // NB: the two IDs just below 0xFFFF are unofficial IDs
        // that `AeadId` assigns, so the unassigned range stops
        // at 0xFFFE - 2.
        let unassigned = 0x0004..=0xFFFE - 2;
        for id in unassigned {
            let want = AeadId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = serde_json::to_string(&id).expect("should be able to encode `u16`");
            let got: AeadId = serde_json::from_str(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `AeadId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Tests that [`KdfId`] is assigned correctly.
    #[test]
    fn test_kdf_id() {
        let unassigned = 0x0004..=0xFFFF;
        for id in unassigned {
            let want = KdfId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: KdfId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KdfId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }

    /// Tests that [`KemId`] is assigned correctly.
    #[test]
    fn test_kem_id() {
        let unassigned: [RangeInclusive<u16>; 3] =
            [0x0001..=0x000F, 0x0022..=0x002F, 0x0031..=0xFFFF];
        for id in unassigned.into_iter().flatten() {
            let want = KemId::Other(NonZeroU16::new(id).expect("`id` should be non-zero"));
            let encoded = postcard::to_vec::<_, { u16::POSTCARD_MAX_SIZE }>(&id)
                .expect("should be able to encode `u16`");
            let got: KemId = postcard::from_bytes(&encoded).unwrap_or_else(|err| {
                panic!("should be able to decode unassigned `KemId` {id}: {err}")
            });
            assert_eq!(got, want);
        }
    }
}