gmcrypto-core 0.3.0

Constant-time-designed pure-Rust SM2/SM3 primitives (no_std + alloc) with an in-CI dudect timing-leak regression harness
Documentation
//! SM4 in CBC mode with PKCS#7 padding (GB/T 32907-2016 + RFC 5652 §6.3).
//!
//! # IV contract
//!
//! Per **NIST SP 800-38A Appendix C**, CBC IVs must be **unpredictable** —
//! generated by an FIPS-approved RBG or equivalent CSPRNG, never reused
//! under the same key. "Unique per key" is the **CTR**-mode rule and is
//! **insufficient** for CBC (predictable IVs leak via chosen-plaintext
//! attacks; see e.g. BEAST). Caller-supplied: this module does not
//! generate IVs internally — pull from `OsRng` or equivalent at the call
//! site.
//!
//! # Padding-oracle / authenticity caveat
//!
//! Raw CBC is **unauthenticated**. A network-attached attacker who can
//! distinguish "decrypt succeeded" from "decrypt failed" via timing or
//! side channels can mount a padding-oracle attack on the plaintext.
//! Callers needing integrity **MUST** pair CBC with HMAC-SM3 (W3) in
//! encrypt-then-MAC: serialize the IV plus ciphertext, compute the MAC
//! over `(IV || ciphertext)`, send `IV || ciphertext || tag`, verify
//! the MAC before invoking `decrypt`.
//!
//! [`decrypt`] strips PKCS#7 padding with a constant-time scan over the
//! final block, built on [`subtle`] primitives; the amount of work is
//! independent of `pad_len`'s value. The returned `Option<Vec<u8>>`
//! still signals validity (one bit), which is unavoidable for this
//! primitive shape.
//!
//! # API
//!
//! Single-shot top-level functions, no streaming `Update`-style trait
//! in v0.2:
//!
//! - [`encrypt`] returns ciphertext that is always 1 to 16 bytes longer
//!   than the plaintext (between `plaintext.len() + 1` and
//!   `plaintext.len() + 16` bytes) — PKCS#7 appends a full block of
//!   `0x10` when the plaintext is already a multiple of 16, per RFC 5652
//!   §6.3.
//! - [`decrypt`] returns `Option<Vec<u8>>`. `None` collapses every
//!   failure mode (length not multiple of 16, malformed padding,
//!   inconsistent padding bytes) — no distinguishing variants per the
//!   project's failure-mode invariant.
//!
//! Streaming `BlockMode` / `Pad` trait wiring lands in v0.3 alongside
//! the broader trait generalization.

use crate::sm4::cipher::{BLOCK_SIZE, KEY_SIZE, Sm4Cipher};
use alloc::vec::Vec;
use subtle::{ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater};

/// Encrypt `plaintext` under SM4-CBC, applying PKCS#7 padding first.
///
/// The returned ciphertext length is `plaintext.len().next_multiple_of(16)
/// + (16 if already a multiple of 16, else 0)`. It is therefore always
/// strictly longer than `plaintext`, by between 1 and 16 bytes.
///
/// `iv` must be **unpredictable** (drawn from a CSPRNG) and must never be
/// reused with the same `key`; see the module-level documentation.
///
/// # Panics
///
/// Never. The `try_into().expect(...)` conversion below cannot fail because
/// `chunks_exact_mut(BLOCK_SIZE)` only ever yields full-size chunks.
#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn encrypt(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE], plaintext: &[u8]) -> Vec<u8> {
    // PKCS#7 pad length lies in 1..=BLOCK_SIZE. An already-aligned input
    // receives a whole extra block, so the padding is never ambiguous.
    let remainder = plaintext.len() % BLOCK_SIZE;
    let pad_len = BLOCK_SIZE - remainder;
    // `pad_len` is at most BLOCK_SIZE (16), so it always fits in a u8.
    #[allow(clippy::cast_possible_truncation)]
    let pad_byte = pad_len as u8;

    let mut out = Vec::with_capacity(plaintext.len() + pad_len);
    out.extend_from_slice(plaintext);
    out.resize(plaintext.len() + pad_len, pad_byte);

    let cipher = Sm4Cipher::new(key);
    // Chaining value: the IV for the first block, then each ciphertext block.
    let mut chain: [u8; BLOCK_SIZE] = *iv;
    for chunk in out.chunks_exact_mut(BLOCK_SIZE) {
        let block: &mut [u8; BLOCK_SIZE] = chunk.try_into().expect("chunk is 16 bytes");
        block
            .iter_mut()
            .zip(chain.iter())
            .for_each(|(b, c)| *b ^= *c);
        cipher.encrypt_block(block);
        chain = *block;
    }
    out
}

/// Decrypt SM4-CBC `ciphertext` and remove its PKCS#7 padding.
///
/// Every failure yields the same `None`, with deliberately no
/// distinguishing variants (the project's failure-mode invariant). The
/// failures are:
/// - empty input;
/// - a length that is not a multiple of 16;
/// - an out-of-range pad length;
/// - padding bytes that disagree.
///
/// CBC alone provides no authenticity (see the module-level padding-oracle
/// caveat). Verify an HMAC-SM3 tag over `IV || ciphertext`
/// (encrypt-then-MAC) before calling this on attacker-reachable input.
///
/// # Panics
///
/// Never; the block conversion relies on the same `chunks_exact_mut(16)`
/// guarantee as [`encrypt`].
#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn decrypt(key: &[u8; KEY_SIZE], iv: &[u8; BLOCK_SIZE], ciphertext: &[u8]) -> Option<Vec<u8>> {
    // The ciphertext length is public, so rejecting bad shapes early leaks
    // nothing secret.
    let well_formed = !ciphertext.is_empty() && ciphertext.len() % BLOCK_SIZE == 0;
    if !well_formed {
        return None;
    }

    let cipher = Sm4Cipher::new(key);
    let mut out = ciphertext.to_vec();

    // Chaining value: the IV for the first block, then the previous
    // *ciphertext* block.
    let mut chain: [u8; BLOCK_SIZE] = *iv;
    for chunk in out.chunks_exact_mut(BLOCK_SIZE) {
        let block: &mut [u8; BLOCK_SIZE] = chunk.try_into().expect("chunk is 16 bytes");
        // Capture the ciphertext before it is overwritten in place; it is the
        // chaining value for the following block.
        let next_chain = *block;
        cipher.decrypt_block(block);
        for (b, c) in block.iter_mut().zip(chain.iter()) {
            *b ^= *c;
        }
        chain = next_chain;
    }

    // Constant-time PKCS#7 removal: only the Some/None outcome depends on
    // the padding contents.
    strip_pkcs7_ct(&mut out)?;
    Some(out)
}

/// Remove PKCS#7 padding from the end of `buf` in place.
///
/// The validity check always inspects all 16 bytes of the final block using
/// `subtle` constant-time primitives. The work done therefore does not vary
/// with `pad_len` or with the byte values.
///
/// Only the final `Some(())` / `None` outcome — a single bit — is
/// observable. Keep that bit out of an attacker's reach by verifying
/// HMAC-SM3 in encrypt-then-MAC (W3) first.
///
/// Returns `None`, leaving `buf` untouched, when `buf` is empty, is not
/// block-aligned, or carries invalid padding.
fn strip_pkcs7_ct(buf: &mut Vec<u8>) -> Option<()> {
    let len = buf.len();
    // The buffer length is public, so these checks may branch freely.
    if len == 0 || len % BLOCK_SIZE != 0 {
        return None;
    }

    let tail = &buf[len - BLOCK_SIZE..];
    let pad = tail[BLOCK_SIZE - 1];

    // A usable pad length lies in 1..=BLOCK_SIZE; evaluate that as a
    // `subtle::Choice` instead of branching on `pad`.
    #[allow(clippy::cast_possible_truncation)]
    let block_len = BLOCK_SIZE as u8;
    let pad_in_range = !pad.ct_eq(&0u8) & !pad.ct_gt(&block_len);

    // Byte `i` of the final block sits at distance `BLOCK_SIZE - i` (16 down
    // to 1) from the end. Bytes whose distance is <= `pad` belong to the
    // padding and must equal `pad`; OR every masked mismatch into `mismatch`.
    let mut mismatch: u8 = 0;
    for (i, &byte) in tail.iter().enumerate() {
        // `i` is in 0..16, so the distance is in 1..=16 and fits in a u8.
        #[allow(clippy::cast_possible_truncation)]
        let dist = (BLOCK_SIZE - i) as u8;
        let covered = !dist.ct_gt(&pad);
        mismatch |= u8::conditional_select(&0u8, &(byte ^ pad), covered);
    }

    let valid = pad_in_range & mismatch.ct_eq(&0u8);
    // Branching on the final verdict is unavoidable: it selects Some vs None.
    if !bool::from(valid) {
        return None;
    }
    buf.truncate(len - usize::from(pad));
    Some(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip across boundary plaintext lengths:
    /// - empty;
    /// - 1 byte (under a block);
    /// - 15 bytes (one byte short of a block);
    /// - 16 bytes (exact block — triggers full-block padding per PKCS#7);
    /// - 17 bytes (one over a block);
    /// - 31 bytes;
    /// - 32 bytes (two exact blocks);
    /// - 100 bytes (mid-multi-block).
    #[test]
    fn round_trip_boundary_lengths() {
        let key: [u8; 16] = [0x42; 16];
        let iv: [u8; 16] = [0x33; 16];

        for len in [0usize, 1, 15, 16, 17, 31, 32, 100] {
            // `len` is at most 100, so `i as u8` doesn't truncate.
            #[allow(clippy::cast_possible_truncation)]
            let plaintext: Vec<u8> = (0..len).map(|i| (i as u8).wrapping_mul(7)).collect();
            let ciphertext = encrypt(&key, &iv, &plaintext);

            // Output is always strictly longer than input: PKCS#7 always
            // appends at least one byte, and appends a full block when the
            // input is already aligned.
            assert!(
                ciphertext.len() > plaintext.len(),
                "ciphertext must be longer than plaintext for len={len}"
            );
            assert_eq!(
                ciphertext.len() % BLOCK_SIZE,
                0,
                "ciphertext must be block-aligned"
            );

            let recovered = decrypt(&key, &iv, &ciphertext).expect("decrypt must succeed");
            assert_eq!(recovered, plaintext, "round-trip mismatch at len={len}");
        }
    }

    /// Per RFC 5652 §6.3, an exact-multiple-of-16 plaintext gets a FULL
    /// extra block of `0x10` padding. Empty plaintext plus that full block
    /// of `0x10` thus yields a single 16-byte ciphertext block.
    #[test]
    fn empty_plaintext_yields_one_block() {
        let key: [u8; 16] = [0x42; 16];
        let iv: [u8; 16] = [0x33; 16];
        let ciphertext = encrypt(&key, &iv, b"");
        assert_eq!(ciphertext.len(), BLOCK_SIZE, "empty PT → exactly one block");
        let recovered = decrypt(&key, &iv, &ciphertext).expect("decrypt empty");
        assert_eq!(recovered, b"");
    }

    /// Decrypt rejects ciphertext whose length isn't a multiple of 16.
    #[test]
    fn decrypt_rejects_misaligned_length() {
        let key: [u8; 16] = [0x42; 16];
        let iv: [u8; 16] = [0x33; 16];
        // 15 bytes — one short of a block.
        assert!(decrypt(&key, &iv, &[0u8; 15]).is_none());
        // 17 bytes — one past a block.
        assert!(decrypt(&key, &iv, &[0u8; 17]).is_none());
    }

    /// Decrypt rejects empty input.
    #[test]
    fn decrypt_rejects_empty() {
        let key: [u8; 16] = [0x42; 16];
        let iv: [u8; 16] = [0x33; 16];
        assert!(decrypt(&key, &iv, &[]).is_none());
    }

    /// Tampering with the final ciphertext block scrambles the whole
    /// decrypted final block: a one-bit change to a block-cipher input
    /// yields unrelated output. The PKCS#7 padding is therefore almost
    /// certainly destroyed, and decrypt must reject.
    ///
    /// This is NOT an integrity guarantee. A random final block still has
    /// roughly a 1/256 chance of ending in valid padding, and tampering
    /// with earlier blocks generally leaves the padding intact. For real
    /// integrity, pair with HMAC-SM3 — see module-doc.
    #[test]
    fn decrypt_rejects_tampered_final_block() {
        let key: [u8; 16] = [0x42; 16];
        let iv: [u8; 16] = [0x33; 16];
        let plaintext = b"this is a test message that spans multiple blocks";
        let mut ciphertext = encrypt(&key, &iv, plaintext);
        // Flip a bit in the last *ciphertext* byte. In CBC this does NOT
        // simply flip the decrypted pad-length byte; the entire final
        // plaintext block decrypts to unrelated bytes. With the fixed
        // key/IV/plaintext used here the outcome is deterministic, and this
        // test asserts that it is invalid padding. See
        // `decrypt_rejects_flipped_pad_length_byte` for a tamper that
        // invalidates the padding by construction.
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;
        assert!(
            decrypt(&key, &iv, &ciphertext).is_none(),
            "tampered last byte must break PKCS#7"
        );
    }

    /// CBC malleability: flipping a bit in the ciphertext block *preceding*
    /// the final one (or in the IV, for a one-block message) flips exactly
    /// the same bit of the decrypted final block.
    ///
    /// Flipping the low bit of the byte that lines up with the pad-length
    /// byte therefore changes the pad length while the other padding bytes
    /// keep their old value. The result is invalid PKCS#7 by construction,
    /// not by chance.
    #[test]
    fn decrypt_rejects_flipped_pad_length_byte() {
        let key: [u8; 16] = [0x42; 16];
        let iv: [u8; 16] = [0x33; 16];

        // 49-byte plaintext → pad_len 15 (0x0f) over four ciphertext blocks.
        // Flipping bit 0 makes the decrypted pad byte 0x0e, while the other
        // 14 padding bytes remain 0x0f → inconsistent padding.
        let plaintext = b"this is a test message that spans multiple blocks";
        assert_eq!(plaintext.len() % BLOCK_SIZE, 1);
        let mut ciphertext = encrypt(&key, &iv, plaintext);
        let penultimate_last = ciphertext.len() - BLOCK_SIZE - 1;
        ciphertext[penultimate_last] ^= 0x01;
        assert!(
            decrypt(&key, &iv, &ciphertext).is_none(),
            "pad byte 0x0f → 0x0e must be rejected"
        );

        // Single-block ciphertext (empty plaintext, 16 × 0x10): flipping the
        // IV's last bit turns the pad byte into 0x11 = 17, out of range.
        let ciphertext = encrypt(&key, &iv, b"");
        let mut bad_iv = iv;
        bad_iv[BLOCK_SIZE - 1] ^= 0x01;
        assert!(
            decrypt(&key, &bad_iv, &ciphertext).is_none(),
            "pad byte 0x10 → 0x11 must be rejected"
        );
    }

    /// `strip_pkcs7_ct` cross-check: a buffer ending in a known-good PKCS#7
    /// layout strips correctly. This exercises the function directly,
    /// without going through the encrypt/decrypt round-trip.
    #[test]
    fn strip_pkcs7_known_good() {
        // 16-byte buffer of 16 copies of 0x10 — valid PKCS#7 for an empty
        // plaintext.
        let mut buf = alloc::vec![0x10u8; 16];
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 0);

        // 16-byte buffer ending in `01` — valid PKCS#7 stripping one byte.
        let mut buf = alloc::vec![0u8; 16];
        buf[15] = 0x01;
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 15);

        // 16-byte buffer ending in `04 04 04 04` (and other bytes) — valid
        // PKCS#7 stripping 4 bytes.
        let mut buf = alloc::vec![0u8; 16];
        buf[12] = 0x04;
        buf[13] = 0x04;
        buf[14] = 0x04;
        buf[15] = 0x04;
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 12);
    }

    /// `strip_pkcs7_ct` rejection cases: `pad_len` 0, `pad_len` > 16, and
    /// inconsistent padding bytes.
    #[test]
    fn strip_pkcs7_known_bad() {
        // pad_len == 0.
        let mut buf = alloc::vec![0u8; 16];
        assert!(strip_pkcs7_ct(&mut buf).is_none());

        // pad_len == 17 (out of range).
        let mut buf = alloc::vec![0u8; 16];
        buf[15] = 17;
        assert!(strip_pkcs7_ct(&mut buf).is_none());

        // pad_len = 4 but only 3 of the 4 padding bytes match.
        let mut buf = alloc::vec![0u8; 16];
        buf[12] = 0x04;
        buf[13] = 0xff; // wrong
        buf[14] = 0x04;
        buf[15] = 0x04;
        assert!(strip_pkcs7_ct(&mut buf).is_none());
    }

    /// Boundary cases for the padding-region mask in `strip_pkcs7_ct`. The
    /// byte 16 positions from the end is checked when `pad_len == 16`, and
    /// a byte just outside a shorter padding region is ignored.
    #[test]
    fn strip_pkcs7_region_boundaries() {
        // pad_len == 16 with the farthest padding byte wrong → reject, and
        // a rejected buffer is left untouched.
        let mut buf = alloc::vec![0x10u8; 16];
        buf[0] = 0x11;
        assert!(strip_pkcs7_ct(&mut buf).is_none());
        assert_eq!(buf.len(), 16, "rejected buffer must be left untouched");

        // pad_len == 4; the byte immediately before the padding is non-zero,
        // but it lies outside the padding region, so it must not matter.
        let mut buf = alloc::vec![0u8; 16];
        buf[11] = 0x05;
        buf[12..].fill(0x04);
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 12);

        // Multi-block buffer: only the final block is inspected, and the
        // strip is relative to the buffer's end.
        let mut buf = alloc::vec![0xAAu8; 32];
        buf[30..].fill(0x02);
        assert!(strip_pkcs7_ct(&mut buf).is_some());
        assert_eq!(buf.len(), 30);

        // A misaligned length is rejected before any padding inspection.
        let mut buf = alloc::vec![0x01u8; 15];
        assert!(strip_pkcs7_ct(&mut buf).is_none());
    }
}