Skip to main content

crypt_io/stream/
encryptor.rs

1//! Streaming AEAD encryptor.
2
3use alloc::vec::Vec;
4
5use crate::aead::Algorithm;
6use crate::error::{Error, Result};
7
8use super::aead::encrypt_chunk;
9use super::frame::{
10    DEFAULT_CHUNK_SIZE_LOG2, HEADER_LEN, MAX_CHUNK_SIZE_LOG2, MIN_CHUNK_SIZE_LOG2,
11    NONCE_PREFIX_LEN, build_header, build_nonce, chunk_size_from_log2,
12};
13
14/// Streaming AEAD encryptor. Buffers caller-supplied plaintext into
15/// fixed-size chunks, encrypts each chunk with a STREAM-construction
16/// nonce, and emits `ciphertext || tag` per chunk.
17///
18/// Usage is symmetric with the [`super::StreamDecryptor`]:
19///
20/// 1. Construct with [`StreamEncryptor::new`]. The constructor returns
21///    the encryptor and a 24-byte header — write this header to the
22///    output sink first.
23/// 2. Feed plaintext via [`update`](Self::update). The method returns
24///    zero or more encrypted chunks (each `chunk_size + 16` bytes) as
25///    buffer fills are reached.
26/// 3. Call [`finalize`](Self::finalize) to emit any remaining buffered
27///    data as the final chunk. The final chunk is **always** emitted
28///    (even if zero plaintext bytes remain) and is always strictly
29///    smaller than `chunk_size + 16` bytes, so the decryptor can
30///    detect it unambiguously by length.
31///
32/// # Example
33///
34/// ```
35/// # #[cfg(all(feature = "stream", feature = "aead-chacha20"))] {
36/// use crypt_io::stream::{StreamDecryptor, StreamEncryptor};
37/// use crypt_io::Algorithm;
38///
39/// let key = [0u8; 32];
40/// let plaintext = b"the quick brown fox jumps over the lazy dog".repeat(1000);
41///
42/// // ---- Encrypt ----
43/// let (mut enc, header) = StreamEncryptor::new(&key, Algorithm::ChaCha20Poly1305)?;
44/// let mut wire = header.to_vec();
45/// wire.extend(enc.update(&plaintext)?);
46/// wire.extend(enc.finalize()?);
47///
48/// // ---- Decrypt ----
49/// let mut dec = StreamDecryptor::new(&key, &wire[..24])?;
50/// let mut recovered = dec.update(&wire[24..])?;
51/// recovered.extend(dec.finalize()?);
52/// assert_eq!(recovered, plaintext);
53/// # }
54/// # Ok::<(), crypt_io::Error>(())
55/// ```
56#[derive(Debug)]
57pub struct StreamEncryptor {
58    algorithm: Algorithm,
59    key: [u8; 32],
60    nonce_prefix: [u8; NONCE_PREFIX_LEN],
61    aad: [u8; HEADER_LEN],
62    counter: u32,
63    chunk_size: usize,
64    chunk_size_log2: u8,
65    /// Plaintext awaiting chunking. Held capacity is `chunk_size`.
66    buffer: Vec<u8>,
67}
68
69impl StreamEncryptor {
70    /// Construct a new stream encryptor with the default 64 KiB chunk
71    /// size. Returns the encryptor plus the 24-byte header to be
72    /// written to the output sink before any encrypted chunks.
73    ///
74    /// # Errors
75    ///
76    /// - [`Error::InvalidKey`] if `key` is not 32 bytes.
77    /// - [`Error::RandomFailure`] if the OS RNG cannot produce a
78    ///   nonce prefix.
79    pub fn new(key: &[u8], algorithm: Algorithm) -> Result<(Self, [u8; HEADER_LEN])> {
80        Self::new_with_chunk_size(key, algorithm, DEFAULT_CHUNK_SIZE_LOG2)
81    }
82
83    /// Construct with an explicit chunk size. `chunk_size_log2` must
84    /// be in [`MIN_CHUNK_SIZE_LOG2`]`..=`[`MAX_CHUNK_SIZE_LOG2`]
85    /// (10..=24).
86    ///
87    /// # Errors
88    ///
89    /// See [`new`](Self::new), plus
90    /// [`Error::InvalidCiphertext`](crate::Error::InvalidCiphertext)
91    /// on out-of-range chunk size.
92    pub fn new_with_chunk_size(
93        key: &[u8],
94        algorithm: Algorithm,
95        chunk_size_log2: u8,
96    ) -> Result<(Self, [u8; HEADER_LEN])> {
97        check_key(key)?;
98        if !(MIN_CHUNK_SIZE_LOG2..=MAX_CHUNK_SIZE_LOG2).contains(&chunk_size_log2) {
99            return Err(Error::InvalidCiphertext(alloc::format!(
100                "chunk_size_log2 out of range: {chunk_size_log2}"
101            )));
102        }
103
104        let mut nonce_prefix = [0u8; NONCE_PREFIX_LEN];
105        mod_rand::tier3::fill_bytes(&mut nonce_prefix)
106            .map_err(|_| Error::RandomFailure("mod_rand::tier3::fill_bytes"))?;
107
108        let header = build_header(algorithm, chunk_size_log2, &nonce_prefix);
109        let chunk_size = chunk_size_from_log2(chunk_size_log2);
110
111        let mut key_arr = [0u8; 32];
112        key_arr.copy_from_slice(key);
113
114        let enc = Self {
115            algorithm,
116            key: key_arr,
117            nonce_prefix,
118            aad: header,
119            counter: 0,
120            chunk_size,
121            chunk_size_log2,
122            buffer: Vec::with_capacity(chunk_size),
123        };
124        Ok((enc, header))
125    }
126
127    /// Chunk size in bytes used by this encryptor.
128    #[must_use]
129    pub fn chunk_size(&self) -> usize {
130        self.chunk_size
131    }
132
133    /// Log2 of the chunk size, as stored in the header.
134    #[must_use]
135    pub fn chunk_size_log2(&self) -> u8 {
136        self.chunk_size_log2
137    }
138
139    /// Feed plaintext bytes. Returns zero or more complete encrypted
140    /// chunks (each `chunk_size + 16` bytes) concatenated.
141    ///
142    /// # Errors
143    ///
144    /// - [`Error::AuthenticationFailed`] on an upstream AEAD failure
145    ///   (unreachable in practice).
146    pub fn update(&mut self, data: &[u8]) -> Result<Vec<u8>> {
147        if data.is_empty() {
148            return Ok(Vec::new());
149        }
150
151        // Worst case: every byte triggers a chunk boundary. In practice
152        // it's at most `data.len() / chunk_size + 1` chunks.
153        let estimated_chunks = data.len() / self.chunk_size + 1;
154        let mut out = Vec::with_capacity(estimated_chunks * (self.chunk_size + 16));
155
156        // Fill the buffer up to `chunk_size`, emit a non-final chunk,
157        // then repeat with the remainder.
158        let mut cursor = 0usize;
159        while cursor < data.len() {
160            let needed = self.chunk_size - self.buffer.len();
161            let take = needed.min(data.len() - cursor);
162            self.buffer.extend_from_slice(&data[cursor..cursor + take]);
163            cursor += take;
164
165            if self.buffer.len() == self.chunk_size {
166                let nonce = build_nonce(&self.nonce_prefix, self.counter, false);
167                let chunk =
168                    encrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)?;
169                out.extend_from_slice(&chunk);
170                self.counter = self.counter.checked_add(1).ok_or(Error::InvalidCiphertext(
171                    alloc::string::String::from("stream chunk counter overflow"),
172                ))?;
173                self.buffer.clear();
174            }
175        }
176
177        Ok(out)
178    }
179
180    /// Flush remaining buffered plaintext as the final chunk. Always
181    /// emits at least 16 bytes (the AEAD tag), so the receiver sees
182    /// an unambiguous "final" frame.
183    ///
184    /// # Errors
185    ///
186    /// Same as [`update`](Self::update).
187    pub fn finalize(mut self) -> Result<Vec<u8>> {
188        // If buffer is exactly chunk_size, emit it as non-final first,
189        // then a 0-byte final chunk. This keeps the invariant
190        // "final chunk is strictly < chunk_size + 16 bytes".
191        let mut out = Vec::with_capacity(self.chunk_size + 16);
192        if self.buffer.len() == self.chunk_size {
193            let nonce = build_nonce(&self.nonce_prefix, self.counter, false);
194            let chunk = encrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)?;
195            out.extend_from_slice(&chunk);
196            self.counter = self.counter.checked_add(1).ok_or(Error::InvalidCiphertext(
197                alloc::string::String::from("stream chunk counter overflow"),
198            ))?;
199            self.buffer.clear();
200        }
201
202        let nonce = build_nonce(&self.nonce_prefix, self.counter, true);
203        let final_chunk =
204            encrypt_chunk(self.algorithm, &self.key, &nonce, &self.buffer, &self.aad)?;
205        out.extend_from_slice(&final_chunk);
206        Ok(out)
207    }
208}
209
210fn check_key(key: &[u8]) -> Result<()> {
211    if key.len() == 32 {
212        Ok(())
213    } else {
214        Err(Error::InvalidKey {
215            expected: 32,
216            actual: key.len(),
217        })
218    }
219}