Skip to main content

lib_q_duplex_aead/
crypto.rs

1//! Duplex AEAD encrypt/decrypt over byte slices.
2
3use core::fmt;
4
5use subtle::ConstantTimeEq;
6use zeroize::Zeroize;
7
8#[cfg(feature = "alloc")]
9extern crate alloc;
10
11#[cfg(feature = "alloc")]
12use lib_q_core::DecryptSemanticOutcome;
13#[cfg(feature = "alloc")]
14use zeroize::Zeroizing;
15
16use crate::params::{
17    KEY_BYTES,
18    NONCE_BYTES,
19    PLEN,
20    RATE_BYTES,
21    TAG_BYTES,
22};
23use crate::state::{
24    absorb_all,
25    duplex_decrypt_chunk,
26    duplex_encrypt_chunk,
27    init_key_nonce,
28    tag_from_state,
29};
30
/// Opaque error for duplex AEAD operations.
///
/// Raised when an output buffer is too small, when a length computation would overflow, when a
/// ciphertext is shorter than the tag, or (on decrypt) when authentication fails. The cause is
/// deliberately not distinguished.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DuplexCryptoError;

impl fmt::Debug for DuplexCryptoError {
    /// Prints only the type name; the error carries no payload to reveal.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "DuplexCryptoError")
    }
}
40
41/// Shared duplex decrypt: one duplex walk over the ciphertext body writes plaintext into
42/// `out[..body_len]`, derives the tag from the final sponge state, and returns whether the tag
43/// was valid (`subtle::ConstantTimeEq`). The walk always runs to completion regardless of tag
44/// validity (timing discipline). A second full body pass would only duplicate `f1600` work: the
45/// tag is already fixed by this single trajectory (the inverse of [`encrypt`]).
46///
47/// Returns `Err` if `ct_in` is shorter than `TAG_BYTES` or `out` is shorter than the body length.
48pub(crate) fn decrypt_core(
49    key: &[u8; KEY_BYTES],
50    nonce: &[u8; NONCE_BYTES],
51    ad: &[u8],
52    ct_in: &[u8],
53    out: &mut [u8],
54) -> Result<bool, DuplexCryptoError> {
55    if ct_in.len() < TAG_BYTES {
56        return Err(DuplexCryptoError);
57    }
58    let body_len = ct_in.len() - TAG_BYTES;
59    if out.len() < body_len {
60        return Err(DuplexCryptoError);
61    }
62    let ct_body = &ct_in[..body_len];
63    let tag_recv = &ct_in[body_len..body_len + TAG_BYTES];
64
65    let mut state = [0u64; PLEN];
66    init_key_nonce(&mut state, key, nonce);
67    absorb_all(&mut state, ad);
68
69    let pt = &mut out[..body_len];
70    let mut off = 0usize;
71    while off + RATE_BYTES <= body_len {
72        duplex_decrypt_chunk(
73            &mut state,
74            &ct_body[off..off + RATE_BYTES],
75            &mut pt[off..off + RATE_BYTES],
76        );
77        off += RATE_BYTES;
78    }
79    if off < body_len {
80        duplex_decrypt_chunk(&mut state, &ct_body[off..], &mut pt[off..]);
81    }
82
83    let tag_calc = tag_from_state(&state);
84    let tag_recv_arr: [u8; TAG_BYTES] = tag_recv.try_into().map_err(|_| DuplexCryptoError)?;
85    let tag_ok = tag_calc.ct_eq(&tag_recv_arr).unwrap_u8() == 1;
86
87    state.zeroize();
88
89    Ok(tag_ok)
90}
91
92/// Encrypt: `ciphertext` is `pt.len() + TAG_BYTES`; `ct` must hold at least that.
93pub fn encrypt(
94    key: &[u8; KEY_BYTES],
95    nonce: &[u8; NONCE_BYTES],
96    ad: &[u8],
97    pt: &[u8],
98    out: &mut [u8],
99) -> Result<(), DuplexCryptoError> {
100    let total = pt.len().checked_add(TAG_BYTES).ok_or(DuplexCryptoError)?;
101    if out.len() < total {
102        return Err(DuplexCryptoError);
103    }
104    let mut state = [0u64; PLEN];
105    init_key_nonce(&mut state, key, nonce);
106    absorb_all(&mut state, ad);
107
108    let ct = &mut out[..pt.len()];
109    let mut off = 0usize;
110    while off + RATE_BYTES <= pt.len() {
111        duplex_encrypt_chunk(
112            &mut state,
113            &pt[off..off + RATE_BYTES],
114            &mut ct[off..off + RATE_BYTES],
115        );
116        off += RATE_BYTES;
117    }
118    if off < pt.len() {
119        duplex_encrypt_chunk(&mut state, &pt[off..], &mut ct[off..]);
120    }
121
122    let tag = tag_from_state(&state);
123    out[pt.len()..pt.len() + TAG_BYTES].copy_from_slice(&tag);
124    state.zeroize();
125    Ok(())
126}
127
128/// Decrypt `ct_in` (ciphertext including tag) in constant time.
129///
130/// On success, plaintext is written to `out` (length `ct_in.len() - TAG_BYTES`).
131/// On authentication failure, zeroes `out[..body_len]` and returns `Err`.
132/// The duplex body walk always runs to completion regardless of tag validity (timing discipline).
133pub fn decrypt(
134    key: &[u8; KEY_BYTES],
135    nonce: &[u8; NONCE_BYTES],
136    ad: &[u8],
137    ct_in: &[u8],
138    out: &mut [u8],
139) -> Result<(), DuplexCryptoError> {
140    if ct_in.len() < TAG_BYTES {
141        return Err(DuplexCryptoError);
142    }
143    let body_len = ct_in.len() - TAG_BYTES;
144    let tag_ok = decrypt_core(key, nonce, ad, ct_in, out)?;
145    if tag_ok {
146        Ok(())
147    } else {
148        out[..body_len].zeroize();
149        Err(DuplexCryptoError)
150    }
151}
152
/// Layer B semantic decrypt: a single shared [`decrypt_core`] walk over the body.
///
/// Returns `Ok(DecryptSemanticOutcome::Success(plaintext))` when the tag verifies, and
/// `Ok(DecryptSemanticOutcome::AuthenticationFailed)` when it does not. Authentication failure
/// is reported as a value, not as an `Err`.
///
/// # Errors
///
/// Returns `Err` only for malformed input, i.e. when `ct_in` is shorter than `TAG_BYTES`.
#[cfg(feature = "alloc")]
pub(crate) fn decrypt_semantic_outcome(
    key: &[u8; KEY_BYTES],
    nonce: &[u8; NONCE_BYTES],
    ad: &[u8],
    ct_in: &[u8],
) -> Result<DecryptSemanticOutcome, DuplexCryptoError> {
    let body_len = ct_in.len().checked_sub(TAG_BYTES).ok_or(DuplexCryptoError)?;
    // Wrap the buffer at allocation so any (partially) written plaintext is wiped on every exit
    // path, including unwinding, not only on the explicit authentication-failure branch.
    let mut pt = Zeroizing::new(vec![0u8; body_len]);
    if decrypt_core(key, nonce, ad, ct_in, &mut pt)? {
        Ok(DecryptSemanticOutcome::Success(pt))
    } else {
        // `pt` is dropped here; `Zeroizing` wipes the unauthenticated plaintext.
        Ok(DecryptSemanticOutcome::AuthenticationFailed)
    }
}