1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
#![cfg_attr(not(feature = "std"), no_std)]
#![doc = include_str!("../README.md")]
#![warn(missing_docs)]

use crate::aegis_128l::Aegis128L;

use hkdf::HkdfExtract;
use sha2::Sha256;
pub use subtle;
use subtle::ConstantTimeEq;

mod aegis_128l;
mod intrinsics;

// Documentation-only module (empty body) that renders `design.md` in rustdoc when the `docs`
// feature is enabled.
#[cfg(feature = "docs")]
#[doc = include_str!("../design.md")]
pub mod design {}

// Documentation-only module (empty body) that renders `perf.md` in rustdoc when the `docs`
// feature is enabled.
#[cfg(feature = "docs")]
#[doc = include_str!("../perf.md")]
pub mod perf {}

/// Size in bytes of the authentication tag that [`Protocol::seal`] writes at the end of its
/// buffer and [`Protocol::open`] checks (a 128-bit tag).
pub const TAG_LEN: usize = 128 / 8;

/// A stateful object providing fine-grained symmetric-key cryptographic services like hashing,
/// message authentication codes, pseudo-random functions, authenticated encryption, and more.
///
/// Cloning a `Protocol` copies its transcript state; the clone and the original then evolve
/// independently.
#[derive(Debug, Clone)]
pub struct Protocol {
    // HKDF-SHA256 extract state absorbing the encoded transcript of operations. It starts
    // unsalted in `Protocol::new` and is replaced by a fresh extractor keyed with a newly derived
    // key-derivation key at the end of every `derive`.
    transcript: HkdfExtract<Sha256>,
}

impl Protocol {
    /// Creates a new protocol with the given domain.
    #[inline]
    pub fn new(domain: &str) -> Protocol {
        // Initialize a protocol with an empty transcript.
        let mut protocol = Protocol { transcript: HkdfExtract::new(None) };

        // Append the Init op header to the transcript with the domain as the label.
        //
        //   0x01 || domain || right_encode(|domain|)
        protocol.op_header(OpCode::Init, domain);

        protocol
    }

    /// Mixes the given label and slice into the protocol state.
    #[inline]
    pub fn mix(&mut self, label: &str, input: &[u8]) {
        // Append a Mix op header with the label to the transcript.
        //
        //   0x02 || label || right_encode(|label|)
        self.op_header(OpCode::Mix, label);

        // Append the input to the transcript with right-encoded length.
        //
        //   input || right_encode(|input|)
        self.transcript.input_ikm(input);
        self.transcript.input_ikm(right_encode(&mut [0u8; 9], input.len() as u64 * 8));
    }

    /// Moves the protocol into a [`std::io::Write`] implementation, mixing all written data in a
    /// single operation and passing all writes to `inner`.
    ///
    /// Use [`MixWriter::into_inner`] to finish the operation and recover the protocol and `inner`.
    #[inline]
    #[cfg(feature = "std")]
    pub fn mix_writer<W: std::io::Write>(mut self, label: &str, inner: W) -> MixWriter<W> {
        // Append a Mix op header with the label to the transcript.
        self.op_header(OpCode::Mix, label);

        // Move the protocol to a MixWriter.
        MixWriter { protocol: self, inner, len: 0 }
    }

    /// Derives output from the protocol's current state and fills the given slice with it.
    ///
    /// The output is dependent on the protocol's prior transcript, the label, and the length of
    /// `out`.
    #[inline]
    pub fn derive(&mut self, label: &str, out: &mut [u8]) {
        // Append a Derive op header with the label to the transcript.
        //
        //   0x03 || label || right_encode(|label|)
        self.op_header(OpCode::Derive, label);

        // Perform a Mix operation with the output length.
        self.mix("len", right_encode(&mut [0u8; 9], out.len() as u64 * 8));

        // Derive a PRK via HKDF-Extract(kdk, transcript).
        let (_, prk) = self.transcript.clone().finalize();

        // Use HKDF-Expand and the PRK to derive a new KDK and the requested output.
        let mut kdk = [0u8; 32];
        prk.expand(b"kdk", &mut kdk).expect("should expand KDK");
        prk.expand(b"output", out).expect("should expand output");

        // Clear the transcript and prepare for HKDF-Extract(kdk', transcript).
        self.transcript = HkdfExtract::new(Some(&kdk));
    }

    /// Derives output from the protocol's current state and returns it as an `N`-byte array.
    #[inline]
    pub fn derive_array<const N: usize>(&mut self, label: &str) -> [u8; N] {
        let mut out = [0u8; N];
        self.derive(label, &mut out);
        out
    }

    /// Encrypts the given slice in place.
    #[inline]
    pub fn encrypt(&mut self, label: &str, in_out: &mut [u8]) {
        // Append a Crypt op header with the label to the transcript.
        //
        //   0x04 || label || right_encode(|label|)
        self.op_header(OpCode::Crypt, label);

        // Perform a Mix operation with the plaintext length.
        self.mix("len", right_encode(&mut [0u8; 9], in_out.len() as u64 * 8));

        // Derive an AEGIS-128L key and nonce.
        let kn = self.derive_array::<32>("key");
        let (k, n) = kn.split_at(16);
        let mut aegis = Aegis128L::new(
            k.try_into().expect("should be 16 bytes"),
            n.try_into().expect("should be 16 bytes"),
        );

        // Encrypt the plaintext.
        aegis.encrypt(in_out);

        // Finalize the AEGIS-128L tags.
        let (_, tag256) = aegis.finalize();

        // Perform a Mix operation with the 256-bit AEGIS-128L tag.
        self.mix("tag", &tag256);
    }

    /// Decrypts the given slice in place.
    #[inline]
    pub fn decrypt(&mut self, label: &str, in_out: &mut [u8]) {
        // Append a Crypt op header with the label to the transcript.
        //
        //   0x04 || label || right_encode(|label|)
        self.op_header(OpCode::Crypt, label);

        // Perform a Mix operation with the plaintext length.
        self.mix("len", right_encode(&mut [0u8; 9], in_out.len() as u64 * 8));

        // Derive an AEGIS-128L key and nonce.
        let kn = self.derive_array::<32>("key");
        let (k, n) = kn.split_at(16);
        let mut aegis = Aegis128L::new(
            k.try_into().expect("should be 16 bytes"),
            n.try_into().expect("should be 16 bytes"),
        );

        // Decrypt the ciphertext.
        aegis.decrypt(in_out);

        // Finalize the AEGIS-128L tags.
        let (_, tag256) = aegis.finalize();

        // Perform a Mix operation with the 256-bit AEGIS-128L tag.
        self.mix("tag", &tag256);
    }

    /// Seals the given mutable slice in place.
    ///
    /// The last [`TAG_LEN`] bytes of the slice will be overwritten with the authentication tag.
    #[inline]
    pub fn seal(&mut self, label: &str, in_out: &mut [u8]) {
        // Split the buffer into plaintext and tag.
        let (in_out, tag128_out) = in_out.split_at_mut(in_out.len() - TAG_LEN);

        // Append an AuthCrypt op header with the label to the transcript.
        //
        //   0x05 || label || right_encode(|label|)
        self.op_header(OpCode::AuthCrypt, label);

        // Perform a Mix operation with the plaintext length.
        self.mix("len", right_encode(&mut [0u8; 9], in_out.len() as u64 * 8));

        // Derive an AEGIS-128L key and nonce.
        let kn = self.derive_array::<32>("key");
        let (k, n) = kn.split_at(16);
        let mut aegis = Aegis128L::new(
            k.try_into().expect("should be 16 bytes"),
            n.try_into().expect("should be 16 bytes"),
        );

        // Encrypt the plaintext.
        aegis.encrypt(in_out);

        // Finalize the AEGIS-128L tags.
        let (tag128, tag256) = aegis.finalize();

        // Append the 128-bit AEGIS-128L tag to the ciphertext.
        tag128_out.copy_from_slice(&tag128);

        // Perform a Mix operation with the 256-bit AEGIS-128L tag.
        self.mix("tag", &tag256);
    }

    /// Opens the given mutable slice in place. Returns the plaintext slice of `in_out` if the input
    /// was authenticated. The last [`TAG_LEN`] bytes of the slice will be unmodified.
    #[inline]
    #[must_use]
    pub fn open<'ct>(&mut self, label: &str, in_out: &'ct mut [u8]) -> Option<&'ct [u8]> {
        // Split the buffer into ciphertext and tag.
        let (in_out, tag128_in) = in_out.split_at_mut(in_out.len() - TAG_LEN);

        // Append an AuthCrypt op header with the label to the transcript.
        //
        //   0x05 || label || right_encode(|label|)
        self.op_header(OpCode::AuthCrypt, label);

        // Perform a Mix operation with the plaintext length.
        self.mix("len", right_encode(&mut [0u8; 9], in_out.len() as u64 * 8));

        // Derive an AEGIS-128L key and nonce.
        let kn = self.derive_array::<32>("key");
        let (k, n) = kn.split_at(16);
        let mut aegis = Aegis128L::new(
            k.try_into().expect("should be 16 bytes"),
            n.try_into().expect("should be 16 bytes"),
        );

        // Decrypt the ciphertext.
        aegis.decrypt(in_out);

        // Finalize the AEGIS-128L tags.
        let (tag128, tag256) = aegis.finalize();

        // Perform a Mix operation with the 256-bit AEGIS-128L tag.
        self.mix("tag", &tag256);

        // Check the tag against the counterfactual tag in constant time.
        if tag128_in.ct_eq(&tag128).into() {
            // If the tag is verified, then the ciphertext is authentic. Return the slice of the
            // input which contains the plaintext.
            Some(in_out)
        } else {
            // Otherwise, the ciphertext is inauthentic and we zero out the inauthentic plaintext to
            // avoid bugs where the caller forgets to check the return value of this function and
            // discloses inauthentic plaintext.
            in_out.fill(0);
            None
        }
    }

    /// Clones the protocol and mixes `secrets` plus 64 random bytes into the clone. Passes the
    /// clone to `f` and if `f` returns `Some(R)`, returns `R`. Iterates until a value is returned.
    #[cfg(feature = "hedge")]
    #[must_use]
    pub fn hedge<R>(
        &self,
        mut rng: impl rand_core::CryptoRngCore,
        secrets: &[impl AsRef<[u8]>],
        max_tries: usize,
        f: impl Fn(&mut Self) -> Option<R>,
    ) -> R {
        for _ in 0..max_tries {
            // Clone the protocol's state.
            let mut clone = self.clone();

            // Mix each secret into the clone.
            for s in secrets {
                clone.mix("secret", s.as_ref());
            }

            // Mix a random value into the clone.
            let mut r = [0u8; 64];
            rng.fill_bytes(&mut r);
            clone.mix("nonce", &r);

            // Call the given function with the clone and return if the function was successful.
            if let Some(r) = f(&mut clone) {
                return r;
            }
        }

        unreachable!("unable to hedge a valid value in {} tries", max_tries);
    }

    /// Appends an operation header with an optional label to the protocol transcript.
    #[inline]
    fn op_header(&mut self, op_code: OpCode, label: &str) {
        // Append the operation code and label to the transcript:
        //
        //   op_code || label || right_encode(|label|)
        self.transcript.input_ikm(&[op_code as u8]);
        self.transcript.input_ikm(label.as_bytes());
        self.transcript.input_ikm(right_encode(&mut [0u8; 9], label.len() as u64 * 8));
    }
}

/// The kinds of operation recorded in a protocol transcript.
///
/// Each variant's discriminant is the single byte written at the start of that operation's
/// header.
#[derive(Debug, Clone, Copy)]
enum OpCode {
    /// Starts a protocol, using the domain string as the label.
    Init = 1,
    /// Absorbs a labeled input into the transcript.
    Mix = 2,
    /// Produces labeled output from the transcript.
    Derive = 3,
    /// Unauthenticated encryption or decryption keyed by the transcript.
    Crypt = 4,
    /// Authenticated encryption (seal) or decryption (open) keyed by the transcript.
    AuthCrypt = 5,
}

/// A [`std::io::Write`] implementation which combines all written data into a single `Mix`
/// operation and passes all writes to an inner writer.
///
/// Created by [`Protocol::mix_writer`]; call [`MixWriter::into_inner`] to finish the operation.
#[cfg(feature = "std")]
#[derive(Debug)]
pub struct MixWriter<W> {
    /// The protocol whose transcript absorbs the written bytes.
    protocol: Protocol,
    /// The writer that every write is forwarded to.
    inner: W,
    /// Number of bytes appended to the transcript so far; right-encoded (as a bit length) when
    /// the operation is finished.
    len: u64,
}

#[cfg(feature = "std")]
impl<W: std::io::Write> MixWriter<W> {
    /// Completes the `Mix` operation started by [`Protocol::mix_writer`] and hands back the
    /// [`Protocol`] together with the inner writer.
    ///
    /// This appends the total number of written bytes, as a right-encoded bit length, to the
    /// transcript, which closes the operation.
    #[inline]
    pub fn into_inner(self) -> (Protocol, W) {
        let MixWriter { mut protocol, inner, len } = self;

        // Close the operation with the encoded bit length of everything that was mixed.
        let mut encoded = [0u8; 9];
        protocol.transcript.input_ikm(right_encode(&mut encoded, len * 8));

        (protocol, inner)
    }
}

#[cfg(feature = "std")]
impl<W: std::io::Write> std::io::Write for MixWriter<W> {
    /// Forwards `buf` to the inner writer, then mixes exactly the bytes it accepted into the
    /// protocol transcript.
    ///
    /// Only the accepted prefix is mixed and counted. A short write, or an `Interrupted` error,
    /// is followed by `write_all`/`io::copy` retrying the remainder, so mixing the whole `buf`
    /// up front would absorb those bytes twice.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // Write first; on error nothing has been mixed, so a retry stays consistent.
        let n = self.inner.write(buf)?;

        // `Write::write` guarantees `n <= buf.len()`.
        let written = &buf[..n];

        // Track the written length and append the accepted bytes to the transcript.
        self.len += written.len() as u64;
        self.protocol.transcript.input_ikm(written);

        Ok(n)
    }

    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}

/// Encodes `value` with [NIST SP 800-185][]'s `right_encode`.
///
/// The output is the minimal big-endian bytes of `value`, at least one byte so zero encodes as
/// `0x00`. These are followed by one byte holding how many value bytes were used. The encoding
/// is written into `buf`, and the returned slice borrows the used suffix of it.
///
/// [NIST SP 800-185]: https://www.nist.gov/publications/sha-3-derived-functions-cshake-kmac-tuplehash-and-parallelhash
#[inline]
fn right_encode(buf: &mut [u8; 9], value: u64) -> &[u8] {
    // Count of significant big-endian bytes; zero still occupies a single byte.
    let width = (8 - (value.leading_zeros() / 8) as usize).max(1);

    let (digits, trailer) = buf.split_at_mut(8);
    digits.copy_from_slice(&value.to_be_bytes());
    trailer[0] = width as u8;

    &buf[8 - width..]
}

// Unit tests. They require the `std` feature (for `std::io` and `MixWriter`) and use the external
// crates `expect_test`, `hex`, `bolero`, and (in `hedging`) `rand`.
#[cfg(all(test, feature = "std"))]
mod tests {
    use std::io::{self, Cursor};

    use expect_test::expect;

    use super::*;

    // Known-answer test: pins the exact outputs of `derive`, `encrypt`, and `seal` for a fixed
    // sequence of operations. Any change to the transcript encoding or operation order changes
    // these values.
    #[test]
    fn known_answers() {
        let mut protocol = Protocol::new("com.example.kat");
        protocol.mix("first", b"one");
        protocol.mix("second", b"two");

        expect!["20ea2bf0d8234351"].assert_eq(&hex::encode(protocol.derive_array::<8>("third")));

        let mut plaintext = b"this is an example".to_vec();
        protocol.encrypt("fourth", &mut plaintext);
        expect!["e06289eeea8f938c65ca984eb1c1a9df6557"].assert_eq(&hex::encode(plaintext));

        // `seal` needs TAG_LEN spare bytes at the end of the buffer for the tag.
        let plaintext = b"this is an example";
        let mut sealed = vec![0u8; plaintext.len() + TAG_LEN];
        sealed[..plaintext.len()].copy_from_slice(plaintext);
        protocol.seal("fifth", &mut sealed);

        expect!["c5e08d9df027dab5f83c30314c098bd65eb4ac6866dd154802b47b0c4cce5b14ab7a"]
            .assert_eq(&hex::encode(sealed));

        expect!["2ddaec8811f6092a"].assert_eq(&hex::encode(protocol.derive_array::<8>("sixth")));
    }

    // Checks that mixing through `mix_writer` (fed by `io::copy`) reaches the same state as
    // `mix` with slices, and that written data is passed through to the inner writer. Despite
    // its name, this test exercises `MixWriter`.
    #[test]
    fn readers() {
        let mut slices = Protocol::new("com.example.streams");
        slices.mix("first", b"one");
        slices.mix("second", b"two");

        let streams = Protocol::new("com.example.streams");
        let mut streams_write = streams.mix_writer("first", io::sink());
        io::copy(&mut Cursor::new(b"one"), &mut streams_write)
            .expect("cursor reads and sink writes should be infallible");
        let (streams, _) = streams_write.into_inner();

        // Second pass writes into a Vec to check pass-through of the written bytes.
        let mut output = Vec::new();
        let mut streams_write = streams.mix_writer("second", &mut output);
        io::copy(&mut Cursor::new(b"two"), &mut streams_write)
            .expect("cursor reads and sink writes should be infallible");
        let (mut streams, output) = streams_write.into_inner();

        assert_eq!(slices.derive_array::<16>("third"), streams.derive_array::<16>("third"));
        assert_eq!(b"two".as_slice(), output);
    }

    // Checks that `hedge` keeps retrying until the closure succeeds: here, until a derived tag
    // begins with a zero byte.
    #[test]
    #[cfg(feature = "hedge")]
    fn hedging() {
        let mut hedger = Protocol::new("com.example.hedge");
        hedger.mix("first", b"one");
        let tag = hedger.hedge(rand::thread_rng(), &[b"two"], 10_000, |clone| {
            let tag = clone.derive_array::<16>("tag");
            (tag[0] == 0).then_some(tag)
        });

        assert_eq!(tag[0], 0);
    }

    // Round-trips a 17-byte message through `encrypt`/`decrypt` with an empty domain and checks
    // that both sides end with matching transcripts.
    // NOTE(review): the specific edge case is not stated; presumably the 17-byte length is chosen
    // to hit a partial final cipher block — confirm.
    #[test]
    fn edge_case() {
        let mut sender = Protocol::new("");
        let mut message = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
        sender.encrypt("message", &mut message);
        let tag_s = sender.derive_array::<TAG_LEN>("tag");

        let mut receiver = Protocol::new("");
        receiver.decrypt("message", &mut message);
        let tag_r = receiver.derive_array::<TAG_LEN>("tag");

        assert_eq!(message, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]);
        assert_eq!(tag_s, tag_r);
    }

    // Property test: two u64 values encode to the same bytes exactly when they are equal.
    #[test]
    fn right_encode_injective() {
        bolero::check!().with_type::<(u64, u64)>().cloned().for_each(|(a, b)| {
            let mut buf_a = [0u8; 9];
            let mut buf_b = [0u8; 9];

            let a_e = right_encode(&mut buf_a, a);
            let b_e = right_encode(&mut buf_b, b);

            if a == b {
                assert_eq!(a_e, b_e);
            } else {
                assert_ne!(a_e, b_e);
            }
        });
    }

    // Property test: appending right_encode(bit length) to a byte string, as `op_header` does for
    // labels, maps distinct strings to distinct encodings.
    #[test]
    fn encoded_label_injective() {
        bolero::check!().with_type::<(Vec<u8>, Vec<u8>)>().cloned().for_each(|(a, b)| {
            let mut a_e = a.clone();
            a_e.extend_from_slice(right_encode(&mut [0u8; 9], a.len() as u64 * 8));

            let mut b_e = b.clone();
            b_e.extend_from_slice(right_encode(&mut [0u8; 9], b.len() as u64 * 8));

            if a == b {
                assert_eq!(a_e, b_e, "equal labels must have equal encoded forms");
            } else {
                assert_ne!(a_e, b_e, "non-equal labels must have non-equal encoded forms");
            }
        });
    }

    // Fixed vectors for `right_encode`, including 0 (a single zero byte) and u64::MAX (all nine
    // bytes of the buffer).
    #[test]
    fn right_encode_test_vectors() {
        let mut buf = [0; 9];

        assert_eq!(right_encode(&mut buf, 0), [0, 1]);

        assert_eq!(right_encode(&mut buf, 128), [128, 1]);

        assert_eq!(right_encode(&mut buf, 65536), [1, 0, 0, 3]);

        assert_eq!(right_encode(&mut buf, 4096), [16, 0, 2]);

        assert_eq!(
            right_encode(&mut buf, 18446744073709551615),
            [255, 255, 255, 255, 255, 255, 255, 255, 8]
        );

        assert_eq!(right_encode(&mut buf, 12345), [48, 57, 2]);
    }
}