gmcrypto_core/sm4/ctr_streaming.rs
1//! Streaming SM4-CTR cipher (v0.7 W3).
2//!
3//! Stateful counterpart to [`super::mode_ctr`]'s single-shot
4//! [`encrypt`] / [`decrypt`]: usable when the plaintext isn't
//! available all at once (network framing, file streaming, etc.).
6//!
7//! See [`super::mode_ctr`]'s module docstring for the counter
8//! contract, the no-authenticity caveat, and the BE counter-encoding
9//! semantics — the same rules apply here. CTR is its own inverse, so
10//! [`Sm4CtrCipher`] serves both directions; there's no separate
11//! `Sm4CtrEncryptor` / `Sm4CtrDecryptor` split (compare with CBC's
12//! [`super::cbc_streaming`] which needs separate types for the
13//! padding-bookkeeping asymmetry).
14//!
15//! [`encrypt`]: super::mode_ctr::encrypt
16//! [`decrypt`]: super::mode_ctr::decrypt
17
18use alloc::vec::Vec;
19
20use super::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
21
22/// Streaming SM4-CTR cipher.
23///
24/// Construct with [`Sm4CtrCipher::new`], feed input/output buffer
25/// pairs through [`update`], drop or [`finalize`] when done. Single
26/// struct serves both encrypt and decrypt (CTR is symmetric).
27///
28/// State machine:
29///
30/// - `counter`: next 128-bit BE counter to evaluate.
31/// - `leftover`: the most recently evaluated keystream block.
32/// - `leftover_pos`: in `0..=16`. Bytes `[leftover_pos..16]` of
33/// `leftover` are unconsumed keystream from a previous partial
34/// call; on next [`update`] they're consumed first before the
35/// counter advances.
36///
37/// Internal invariant: when `leftover_pos == 16`, no carried-over
38/// keystream; the next byte requires a fresh `encrypt_block` of
39/// `counter`. Initial state is `leftover_pos = 16` (no leftover).
40///
41/// [`update`]: Sm4CtrCipher::update
42/// [`finalize`]: Sm4CtrCipher::finalize
43pub struct Sm4CtrCipher {
44 cipher: Sm4Cipher,
45 counter: [u8; BLOCK_SIZE],
46 leftover: [u8; BLOCK_SIZE],
47 leftover_pos: usize,
48}
49
50impl Sm4CtrCipher {
51 /// Construct from a 16-byte key and a 16-byte initial counter.
52 /// Counter is treated as a 128-bit BE integer; per-block
53 /// keystream is `SM4_E(key, counter + i)` for `i = 0..N-1`.
54 #[must_use]
55 pub fn new(key: &[u8; KEY_SIZE], counter: &[u8; BLOCK_SIZE]) -> Self {
56 Self {
57 cipher: Sm4Cipher::new(key),
58 counter: *counter,
59 leftover: [0u8; BLOCK_SIZE],
60 // Signals "no carried-over keystream" — the first `update`
61 // will generate a fresh block from `counter`.
62 leftover_pos: BLOCK_SIZE,
63 }
64 }
65
66 /// Consume `input` and write `input.len()` bytes of output. The
67 /// output buffer must be at least as long as the input; only the
68 /// leading `input.len()` bytes are written.
69 ///
70 /// # Panics
71 ///
72 /// Panics if `output.len() < input.len()`.
73 pub fn update(&mut self, input: &[u8], output: &mut [u8]) {
74 assert!(
75 output.len() >= input.len(),
76 "Sm4CtrCipher::update: output buffer too short ({} < {})",
77 output.len(),
78 input.len(),
79 );
80 let mut i = 0usize;
81
82 // Step 1: drain any carried-over keystream from a previous
83 // call. Bytes [self.leftover_pos..16] of self.leftover are
84 // unconsumed and apply to the leading bytes of `input`.
85 while i < input.len() && self.leftover_pos < BLOCK_SIZE {
86 output[i] = input[i] ^ self.leftover[self.leftover_pos];
87 self.leftover_pos += 1;
88 i += 1;
89 }
90
91 // Step 2: process the largest aligned run of full keystream
92 // blocks via the SIMD batch path in `Sm4Cipher::encrypt_blocks`
93 // (v0.7 W1 — fans through `sbox_x32` / `sbox_x16` under
94 // `sm4-bitsliced-simd`).
95 let remaining = input.len() - i;
96 let full_blocks = remaining / BLOCK_SIZE;
97 if full_blocks > 0 {
98 let mut keystream: Vec<[u8; BLOCK_SIZE]> = (0..full_blocks)
99 .map(|j| counter_add(&self.counter, j as u128))
100 .collect();
101 self.cipher.encrypt_blocks(&mut keystream);
102
103 for (b, ks) in keystream.iter().enumerate() {
104 let off = i + b * BLOCK_SIZE;
105 for lane in 0..BLOCK_SIZE {
106 output[off + lane] = input[off + lane] ^ ks[lane];
107 }
108 }
109 self.counter = counter_add(&self.counter, full_blocks as u128);
110 i += full_blocks * BLOCK_SIZE;
111 }
112
113 // Step 3: tail — strictly less than one block remains.
114 // Generate one keystream block from the current counter,
115 // consume what's needed, save the rest as leftover for the
116 // next `update` call.
117 if i < input.len() {
118 self.leftover = self.counter;
119 self.cipher.encrypt_block(&mut self.leftover);
120 self.counter = counter_add(&self.counter, 1);
121 self.leftover_pos = 0;
122 while i < input.len() {
123 output[i] = input[i] ^ self.leftover[self.leftover_pos];
124 self.leftover_pos += 1;
125 i += 1;
126 }
127 }
128 }
129
130 /// Finalize and drop the cipher. CTR has no padding to flush and
131 /// no authenticity bits to emit, so this is a stateless drop.
132 /// Provided for symmetry with [`super::cbc_streaming`] and so
133 /// the call site reads intuitively.
134 pub fn finalize(self) {
135 // Drop self.
136 }
137}
138
139/// Treat `counter` as a 128-bit big-endian integer; return
140/// `(counter + offset) mod 2^128`, BE-encoded.
141const fn counter_add(counter: &[u8; BLOCK_SIZE], offset: u128) -> [u8; BLOCK_SIZE] {
142 let n = u128::from_be_bytes(*counter);
143 n.wrapping_add(offset).to_be_bytes()
144}
145
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::sm4::mode_ctr;

    const KEY: [u8; 16] = [
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32,
        0x10,
    ];
    const COUNTER: [u8; 16] = [0x42u8; 16];

    /// Deterministic pseudo-random test bytes of length `len`.
    // Truncation to `u8` is intentional: keep the low byte of the mixed word.
    #[allow(clippy::cast_possible_truncation)]
    fn make_plaintext(len: usize) -> Vec<u8> {
        (0..len)
            .map(|i| {
                let s = (i as u32).wrapping_mul(0x9E37_79B9);
                (s ^ (s >> 17)) as u8
            })
            .collect()
    }

    /// Single-shot through the streaming API matches the single-shot
    /// top-level function from `mode_ctr`.
    #[test]
    fn streaming_single_call_matches_single_shot() {
        for len in 0..=33 {
            let plaintext = make_plaintext(len);
            let mut out = vec![0u8; len];
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            cipher.update(&plaintext, &mut out);
            cipher.finalize();

            let expected = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
            assert_eq!(out, expected, "single-call divergence at length {len}");
        }
    }

    /// A single `update` over inputs spanning many blocks (with and
    /// without a partial tail) matches the single-shot function.
    #[test]
    fn long_single_update_matches_single_shot() {
        for &len in &[16 * 64, 16 * 64 + 1, 16 * 257 + 7, 5000] {
            let plaintext = make_plaintext(len);
            let mut out = vec![0u8; len];
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            cipher.update(&plaintext, &mut out);
            cipher.finalize();

            let expected = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);
            assert_eq!(out, expected, "long single-call divergence at length {len}");
        }
    }

    /// Chunked `update` at every chunk size 1..=17 produces
    /// byte-identical output to a single-shot call. Exercises the
    /// leftover-keystream-buffer state machine across every
    /// possible carry position.
    #[test]
    fn chunked_update_sweep_matches_single_shot() {
        let total = 64;
        let plaintext = make_plaintext(total);
        let reference = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);

        for chunk_size in 1..=17 {
            let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
            let mut out = vec![0u8; total];
            let mut written = 0;
            while written < total {
                let take = chunk_size.min(total - written);
                cipher.update(
                    &plaintext[written..written + take],
                    &mut out[written..written + take],
                );
                written += take;
            }
            cipher.finalize();
            assert_eq!(
                out, reference,
                "chunked update divergence at chunk_size {chunk_size}",
            );
        }
    }

    /// Irregular chunk lengths hit combinations the uniform sweep does
    /// not: draining a leftover and then processing several full blocks
    /// plus a tail in the same call, and zero-length calls mid-stream.
    #[test]
    fn irregular_chunks_match_single_shot() {
        const CHUNKS: [usize; 8] = [5, 43, 0, 16, 1, 130, 15, 300];
        let total: usize = CHUNKS.iter().sum();
        let plaintext = make_plaintext(total);
        let reference = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);

        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        let mut out = vec![0u8; total];
        let mut written = 0;
        for &take in &CHUNKS {
            cipher.update(
                &plaintext[written..written + take],
                &mut out[written..written + take],
            );
            written += take;
        }
        cipher.finalize();
        assert_eq!(out, reference, "irregular chunk divergence");
    }

    /// Round-trip through the streaming API: encrypt then decrypt
    /// recovers plaintext at every length 0..=33.
    #[test]
    fn streaming_round_trip_is_identity() {
        for len in 0..=33 {
            let plaintext = make_plaintext(len);
            let mut ciphertext = vec![0u8; len];
            let mut enc = Sm4CtrCipher::new(&KEY, &COUNTER);
            enc.update(&plaintext, &mut ciphertext);
            enc.finalize();

            let mut recovered = vec![0u8; len];
            let mut dec = Sm4CtrCipher::new(&KEY, &COUNTER);
            dec.update(&ciphertext, &mut recovered);
            dec.finalize();

            assert_eq!(recovered, plaintext, "streaming round-trip at length {len}");
        }
    }

    /// An output buffer longer than the input only has its leading
    /// `input.len()` bytes written; the rest is left untouched.
    #[test]
    fn oversized_output_only_prefix_written() {
        let plaintext = make_plaintext(20);
        let expected = mode_ctr::encrypt(&KEY, &COUNTER, &plaintext);

        let mut out = vec![0xAAu8; 32];
        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        cipher.update(&plaintext, &mut out);
        cipher.finalize();

        assert_eq!(&out[..20], &expected[..]);
        assert!(out[20..].iter().all(|&b| b == 0xAA));
    }

    /// An output buffer shorter than the input is rejected with a panic.
    #[test]
    #[should_panic(expected = "output buffer too short")]
    fn short_output_panics() {
        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        let mut out = [0u8; 3];
        cipher.update(&[0u8; 4], &mut out);
    }

    /// Empty input through `update` is a no-op: it does not disturb the
    /// keystream position seen by later calls.
    #[test]
    fn empty_update() {
        let plaintext = make_plaintext(20);
        let mut cipher = Sm4CtrCipher::new(&KEY, &COUNTER);
        let mut empty: [u8; 0] = [];
        cipher.update(&[], &mut empty);

        let mut out = vec![0u8; plaintext.len()];
        cipher.update(&plaintext[..7], &mut out[..7]);
        cipher.update(&[], &mut empty);
        cipher.update(&plaintext[7..], &mut out[7..]);
        cipher.finalize();

        assert_eq!(out, mode_ctr::encrypt(&KEY, &COUNTER, &plaintext));
    }
}