Skip to main content

gmcrypto_core/sm4/
ctr_streaming.rs

1//! Streaming SM4-CTR cipher (v0.7 W3).
2//!
3//! Stateful counterpart to [`super::mode_ctr`]'s single-shot
4//! [`encrypt`] / [`decrypt`]: usable when the plaintext isn't
5//! available all at once (network framing, file streaming, etc).
6//!
7//! See [`super::mode_ctr`]'s module docstring for the counter
8//! contract, the no-authenticity caveat, and the BE counter-encoding
9//! semantics — the same rules apply here. CTR is its own inverse, so
10//! [`Sm4CtrCipher`] serves both directions; there's no separate
11//! `Sm4CtrEncryptor` / `Sm4CtrDecryptor` split (compare with CBC's
12//! [`super::cbc_streaming`] which needs separate types for the
13//! padding-bookkeeping asymmetry).
14//!
15//! [`encrypt`]: super::mode_ctr::encrypt
16//! [`decrypt`]: super::mode_ctr::decrypt
17
18use alloc::vec::Vec;
19
20use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
21
22/// Streaming SM4-CTR cipher.
23///
24/// Construct with [`Sm4CtrCipher::new`], feed input/output buffer
25/// pairs through [`update`], drop or [`finalize`] when done. Single
26/// struct serves both encrypt and decrypt (CTR is symmetric).
27///
28/// State machine:
29///
30/// - `counter`: next 128-bit BE counter to evaluate.
31/// - `leftover`: the most recently evaluated keystream block.
32/// - `leftover_pos`: in `0..=16`. Bytes `[leftover_pos..16]` of
33///   `leftover` are unconsumed keystream from a previous partial
34///   call; on next [`update`] they're consumed first before the
35///   counter advances.
36///
37/// Internal invariant: when `leftover_pos == 16`, no carried-over
38/// keystream; the next byte requires a fresh `encrypt_block` of
39/// `counter`. Initial state is `leftover_pos = 16` (no leftover).
40///
41/// [`update`]: Sm4CtrCipher::update
42/// [`finalize`]: Sm4CtrCipher::finalize
43pub struct Sm4CtrCipher {
44    cipher: Sm4Cipher,
45    counter: [u8; BLOCK_SIZE],
46    leftover: [u8; BLOCK_SIZE],
47    leftover_pos: usize,
48}
49
50impl Sm4CtrCipher {
51    /// Construct from a 16-byte key and a 16-byte initial counter.
52    /// Counter is treated as a 128-bit BE integer; per-block
53    /// keystream is `SM4_E(key, counter + i)` for `i = 0..N-1`.
54    #[must_use]
55    pub fn new(key: &[u8; KEY_SIZE], counter: &[u8; BLOCK_SIZE]) -> Self {
56        Self {
57            cipher: Sm4Cipher::new(key),
58            counter: *counter,
59            leftover: [0u8; BLOCK_SIZE],
60            // Signals "no carried-over keystream" — the first `update`
61            // will generate a fresh block from `counter`.
62            leftover_pos: BLOCK_SIZE,
63        }
64    }
65
66    /// Consume `input` and write `input.len()` bytes of output. The
67    /// output buffer must be at least as long as the input; only the
68    /// leading `input.len()` bytes are written.
69    ///
70    /// # Panics
71    ///
72    /// Panics if `output.len() < input.len()`.
73    pub fn update(&mut self, input: &[u8], output: &mut [u8]) {
74        assert!(
75            output.len() >= input.len(),
76            "Sm4CtrCipher::update: output buffer too short ({} < {})",
77            output.len(),
78            input.len(),
79        );
80        let mut i = 0usize;
81
82        // Step 1: drain any carried-over keystream from a previous
83        // call. Bytes [self.leftover_pos..16] of self.leftover are
84        // unconsumed and apply to the leading bytes of `input`.
85        while i < input.len() && self.leftover_pos < BLOCK_SIZE {
86            output[i] = input[i] ^ self.leftover[self.leftover_pos];
87            self.leftover_pos += 1;
88            i += 1;
89        }
90
91        // Step 2: process the largest aligned run of full keystream
92        // blocks via the SIMD batch path in `Sm4Cipher::encrypt_blocks`
93        // (v0.7 W1 — fans through `sbox_x32` / `sbox_x16` under
94        // `sm4-bitsliced-simd`).
95        let remaining = input.len() - i;
96        let full_blocks = remaining / BLOCK_SIZE;
97        if full_blocks > 0 {
98            let mut keystream: Vec<[u8; BLOCK_SIZE]> = (0..full_blocks)
99                .map(|j| counter_add(&self.counter, j as u128))
100                .collect();
101            self.cipher.encrypt_blocks(&mut keystream);
102
103            for (b, ks) in keystream.iter().enumerate() {
104                let off = i + b * BLOCK_SIZE;
105                for lane in 0..BLOCK_SIZE {
106                    output[off + lane] = input[off + lane] ^ ks[lane];
107                }
108            }
109            self.counter = counter_add(&self.counter, full_blocks as u128);
110            i += full_blocks * BLOCK_SIZE;
111        }
112
113        // Step 3: tail — strictly less than one block remains.
114        // Generate one keystream block from the current counter,
115        // consume what's needed, save the rest as leftover for the
116        // next `update` call.
117        if i < input.len() {
118            self.leftover = self.counter;
119            self.cipher.encrypt_block(&mut self.leftover);
120            self.counter = counter_add(&self.counter, 1);
121            self.leftover_pos = 0;
122            while i < input.len() {
123                output[i] = input[i] ^ self.leftover[self.leftover_pos];
124                self.leftover_pos += 1;
125                i += 1;
126            }
127        }
128    }
129
130    /// Finalize and drop the cipher. CTR has no padding to flush and
131    /// no authenticity bits to emit, so this is a stateless drop.
132    /// Provided for symmetry with [`super::cbc_streaming`] and so
133    /// the call site reads intuitively.
134    pub fn finalize(self) {
135        // Drop self.
136    }
137}
138
139/// Treat `counter` as a 128-bit big-endian integer; return
140/// `(counter + offset) mod 2^128`, BE-encoded.
141const fn counter_add(counter: &[u8; BLOCK_SIZE], offset: u128) -> [u8; BLOCK_SIZE] {
142    let n = u128::from_be_bytes(*counter);
143    n.wrapping_add(offset).to_be_bytes()
144}
145
146#[cfg(test)]
147mod tests {
148    use alloc::vec;
149
150    use super::*;
151    use crate::sm4::mode_ctr;
152
153    const KEY: [u8; 16] = [
154        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
155        0x10,
156    ];
157    const COUNTER: [u8; 16] = [0x42u8; 16];
158
159    #[allow(clippy::cast_possible_truncation)]
160    fn make_plaintext(len: usize) -> Vec<u8> {
161        (0..len)
162            .map(|i| {
163                let s = (i as u32).wrapping_mul(0x9E37_79B9);
164                (s ^ (s >> 17)) as u8
165            })
166            .collect()
167    }
168
169    /// Single-shot through the streaming API matches the single-shot
170    /// top-level function from `mode_ctr`.
171    #[test]
172    fn streaming_single_call_matches_single_shot() {
173        for len in 0..=33 {
174            let plaintext = make_plaintext(len);
175            let mut out = vec![0u8; len];
176            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
177            cipher.update(&plaintext, &mut out);
178            cipher.finalize();
179
180            let expected = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
181            assert_eq!(out, expected, "single-call divergence at length {len}");
182        }
183    }
184
185    /// Chunked `update` at every chunk size 1..=17 produces
186    /// byte-identical output to a single-shot call. Exercises the
187    /// leftover-keystream-buffer state machine across every
188    /// possible carry position.
189    #[test]
190    fn chunked_update_sweep_matches_single_shot() {
191        let total = 64;
192        let plaintext = make_plaintext(total);
193        let reference = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
194
195        for chunk_size in 1..=17 {
196            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
197            let mut out = vec![0u8; total];
198            let mut written = 0;
199            while written < total {
200                let take = chunk_size.min(total - written);
201                cipher.update(
202                    &plaintext[written..written + take],
203                    &mut out[written..written + take],
204                );
205                written += take;
206            }
207            cipher.finalize();
208            assert_eq!(
209                out, reference,
210                "chunked update divergence at chunk_size {chunk_size}",
211            );
212        }
213    }
214
215    /// Round-trip through the streaming API: encrypt then decrypt
216    /// recovers plaintext at every length 0..=33.
217    #[test]
218    fn streaming_round_trip_is_identity() {
219        for len in 0..=33 {
220            let plaintext = make_plaintext(len);
221            let mut ciphertext = vec![0u8; len];
222            let mut enc = Sm4CtrCipher::new(&KEY, &COUNTER);
223            enc.update(&plaintext, &mut ciphertext);
224            enc.finalize();
225
226            let mut recovered = vec![0u8; len];
227            let mut dec = Sm4CtrCipher::new(&KEY, &COUNTER);
228            dec.update(&ciphertext, &mut recovered);
229            dec.finalize();
230
231            assert_eq!(recovered, plaintext, "streaming round-trip at length {len}");
232        }
233    }
234
235    /// Empty input through `update` is a no-op.
236    #[test]
237    fn empty_update() {
238        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
239        let mut out = [];
240        cipher.update(&[], &mut out);
241        cipher.finalize();
242    }
243}