mod compress;
pub mod hkdf;
use compress::{compress, IV};
/// Length in bytes of an SM3 digest (eight 32-bit words, i.e. 256 bits).
pub const DIGEST_LEN: usize = 32;
/// Incremental (streaming) SM3 hasher.
///
/// Absorb input with `update`, then produce the 32-byte digest with `finalize`
/// (consumes the hasher) or `finalize_reset` (leaves it ready for reuse).
/// `digest` is the one-shot convenience form.
#[derive(Clone)]
pub struct Sm3Hasher {
/// Chaining value: eight 32-bit words, initialised to `IV` and updated by `compress`.
state: [u32; 8],
/// Pending input that does not yet form a complete 64-byte block.
buffer: [u8; 64],
/// Number of valid bytes at the front of `buffer`; `update` keeps this below 64.
buf_len: usize,
/// Number of message bits already passed through `compress` (wrapping add).
bit_len: u64,
}
impl Sm3Hasher {
    /// Creates a hasher in its initial state: chaining value `IV`, no input absorbed.
    pub fn new() -> Self {
        Self {
            state: IV,
            buffer: [0u8; 64],
            buf_len: 0,
            bit_len: 0,
        }
    }

    /// Computes the SM3 digest of `data` in a single call.
    ///
    /// Equivalent to `new()`, `update(data)`, `finalize()`.
    #[must_use]
    pub fn digest(data: &[u8]) -> [u8; DIGEST_LEN] {
        let mut h = Self::new();
        h.update(data);
        h.finalize()
    }

    /// Absorbs `data` into the hash state.
    ///
    /// Input may be split across any number of calls with identical results.
    /// Complete 64-byte blocks are compressed immediately; a trailing partial
    /// block (< 64 bytes) is buffered until more input or finalization arrives.
    pub fn update(&mut self, data: &[u8]) {
        let mut remaining = data;

        // Top up a partially filled buffer first so blocks are compressed in order.
        if self.buf_len > 0 {
            let take = (64 - self.buf_len).min(remaining.len());
            let (head, tail) = remaining.split_at(take);
            self.buffer[self.buf_len..self.buf_len + take].copy_from_slice(head);
            self.buf_len += take;
            remaining = tail;
            if self.buf_len == 64 {
                compress(&mut self.state, &self.buffer);
                self.bit_len = self.bit_len.wrapping_add(512);
                self.buf_len = 0;
            }
        }

        // Compress whole blocks straight from the caller's slice (no copy).
        let mut blocks = remaining.chunks_exact(64);
        for block in &mut blocks {
            let block: &[u8; 64] = block
                .try_into()
                .expect("chunks_exact(64) always yields 64-byte chunks");
            compress(&mut self.state, block);
            self.bit_len = self.bit_len.wrapping_add(512);
        }

        // Buffer the leftover tail. When it is empty, `buf_len` must stay as is:
        // the buffer may still hold a partial block from the top-up step above.
        let tail = blocks.remainder();
        if !tail.is_empty() {
            self.buffer[..tail.len()].copy_from_slice(tail);
            self.buf_len = tail.len();
        }
    }

    /// Consumes the hasher and returns the digest of all absorbed input.
    #[must_use]
    pub fn finalize(mut self) -> [u8; DIGEST_LEN] {
        self.finalize_inner()
    }

    /// Returns the digest of all absorbed input and resets the hasher to the
    /// state produced by [`Sm3Hasher::new`].
    pub fn finalize_reset(&mut self) -> [u8; DIGEST_LEN] {
        let out = self.finalize_inner();
        self.reset();
        out
    }

    /// Discards all absorbed input, restoring the state produced by [`Sm3Hasher::new`].
    pub fn reset(&mut self) {
        *self = Self::new();
    }

    /// Applies SM3 padding, compresses the final block(s) and serialises the state.
    ///
    /// Leaves `self` in a post-padding state; callers either drop it or `reset` it.
    fn finalize_inner(&mut self) -> [u8; DIGEST_LEN] {
        // Total message length in bits, captured before any padding is appended.
        let total_bits = self.bit_len.wrapping_add((self.buf_len as u64) * 8);

        // Padding: a single 1 bit (byte 0x80), zero bytes, then the 64-bit
        // big-endian bit length in the last 8 bytes of the final block.
        self.buffer[self.buf_len] = 0x80;
        self.buf_len += 1;
        if self.buf_len > 56 {
            // No room left for the length field: flush this block, pad a fresh one.
            self.buffer[self.buf_len..].fill(0);
            compress(&mut self.state, &self.buffer);
            self.buffer = [0u8; 64];
        } else {
            self.buffer[self.buf_len..56].fill(0);
        }
        self.buffer[56..].copy_from_slice(&total_bits.to_be_bytes());
        compress(&mut self.state, &self.buffer);

        // Digest = the eight state words, each serialised big-endian.
        let mut out = [0u8; DIGEST_LEN];
        for (dst, word) in out.chunks_exact_mut(4).zip(self.state.iter()) {
            dst.copy_from_slice(&word.to_be_bytes());
        }
        out
    }
}
impl Default for Sm3Hasher {
fn default() -> Self {
Self::new()
}
}
/// Computes HMAC-SM3 of `data` under `key` in one call.
///
/// This is the RFC 2104 construction with SM3 and a 64-byte block size:
/// `SM3((K0 ^ opad) || SM3((K0 ^ ipad) || data))`. Here `K0` is `key`
/// zero-padded to 64 bytes; keys longer than 64 bytes are first replaced by
/// their SM3 digest.
///
/// All key-derived temporaries (the hashed key, `K0`, and both pad blocks)
/// are scrubbed with `zeroize` before returning.
#[must_use]
pub fn hmac_sm3(key: &[u8], data: &[u8]) -> [u8; DIGEST_LEN] {
    use zeroize::Zeroize;

    // K0: the block-sized key.
    let mut k_pad = [0u8; 64];
    if key.len() > 64 {
        let mut hashed_key = Sm3Hasher::digest(key);
        k_pad[..DIGEST_LEN].copy_from_slice(&hashed_key);
        // For long keys this digest *is* the effective key; scrub it like k_pad.
        hashed_key.zeroize();
    } else {
        k_pad[..key.len()].copy_from_slice(key);
    }

    let mut ipad = [0u8; 64];
    let mut opad = [0u8; 64];
    for ((i, o), &k) in ipad.iter_mut().zip(opad.iter_mut()).zip(k_pad.iter()) {
        *i = k ^ 0x36;
        *o = k ^ 0x5C;
    }
    k_pad.zeroize();

    let mut inner = Sm3Hasher::new();
    inner.update(&ipad);
    inner.update(data);
    let inner_hash = inner.finalize();

    let mut outer = Sm3Hasher::new();
    outer.update(&opad);
    outer.update(&inner_hash);
    let result = outer.finalize();

    ipad.zeroize();
    opad.zeroize();
    result
}
/// Streaming HMAC-SM3 context (RFC 2104 construction over SM3, 64-byte block).
///
/// In the field docs below, `K0` is the key zero-padded to 64 bytes; a key
/// longer than 64 bytes is first replaced by its SM3 digest.
/// NOTE: `Clone` duplicates this key-derived state.
#[derive(Clone)]
pub struct HmacSm3 {
/// SM3 hasher that has already absorbed the block `K0 ^ 0x36` (ipad).
inner: Sm3Hasher,
/// The block `K0 ^ 0x5C` (opad), hashed ahead of the inner digest in `finalize`.
opad_key: [u8; 64],
}
impl HmacSm3 {
    /// Creates a streaming HMAC-SM3 context keyed with `key`.
    ///
    /// Keys longer than 64 bytes are first replaced by their SM3 digest;
    /// shorter keys are zero-padded to 64 bytes (`K0`). The inner hasher is
    /// pre-loaded with `K0 ^ ipad`, and `K0 ^ opad` is stored for `finalize`.
    /// Temporary copies of key material are scrubbed with `zeroize` (best effort).
    pub fn new(key: &[u8]) -> Self {
        use zeroize::Zeroize;

        let mut k_pad = [0u8; 64];
        if key.len() > 64 {
            let mut hashed_key = Sm3Hasher::digest(key);
            k_pad[..DIGEST_LEN].copy_from_slice(&hashed_key);
            // For long keys this digest *is* the effective key; scrub it too.
            hashed_key.zeroize();
        } else {
            k_pad[..key.len()].copy_from_slice(key);
        }

        let mut ipad_key = [0u8; 64];
        let mut opad_key = [0u8; 64];
        for ((i, o), &k) in ipad_key.iter_mut().zip(opad_key.iter_mut()).zip(k_pad.iter()) {
            *i = k ^ 0x36;
            *o = k ^ 0x5C;
        }
        k_pad.zeroize();

        let mut inner = Sm3Hasher::new();
        inner.update(&ipad_key);
        ipad_key.zeroize();

        let ctx = Self { inner, opad_key };
        // `[u8; 64]` is `Copy`, so the struct received a copy; scrub the local one.
        opad_key.zeroize();
        ctx
    }

    /// Feeds more message bytes into the MAC; may be called any number of times.
    pub fn update(&mut self, data: &[u8]) {
        self.inner.update(data);
    }

    /// Consumes the context and returns the 32-byte HMAC-SM3 tag.
    ///
    /// The stored outer key block is scrubbed before returning.
    #[must_use]
    pub fn finalize(mut self) -> [u8; DIGEST_LEN] {
        use zeroize::Zeroize;

        let inner_hash = self.inner.finalize();
        let mut outer = Sm3Hasher::new();
        outer.update(&self.opad_key);
        outer.update(&inner_hash);
        let result = outer.finalize();
        // Scrub the stored key block itself rather than a temporary copy of it.
        self.opad_key.zeroize();
        result
    }
}
impl zeroize::Zeroize for HmacSm3 {
fn zeroize(&mut self) {
self.opad_key.zeroize();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed 200-byte message `0, 1, 2, ...` so tests need no heap allocation.
    fn sample_message() -> [u8; 200] {
        let mut msg = [0u8; 200];
        for (i, b) in msg.iter_mut().enumerate() {
            *b = i as u8;
        }
        msg
    }

    #[test]
    fn test_sm3_vector_abc() {
        let digest = Sm3Hasher::digest(b"abc");
        let expected =
            hex_literal("66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0");
        assert_eq!(digest, expected, "SM3(\"abc\") 测试向量不匹配");
    }

    #[test]
    fn test_sm3_vector_64bytes() {
        let msg = b"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd";
        let digest = Sm3Hasher::digest(msg);
        let expected =
            hex_literal("debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732");
        assert_eq!(digest, expected, "SM3(64字节) 测试向量不匹配");
    }

    #[test]
    fn test_sm3_streaming_equals_onceshot() {
        let data = b"hello world this is a test message for streaming";
        let once = Sm3Hasher::digest(data);
        let mut h = Sm3Hasher::new();
        for chunk in data.chunks(7) {
            h.update(chunk);
        }
        let streamed = h.finalize();
        assert_eq!(once, streamed, "流式哈希与一次性哈希结果不一致");
    }

    /// Streaming must match one-shot across every padding/block boundary
    /// (lengths 55, 56, 63, 64, 65, 119, 120, 128 are all covered).
    #[test]
    fn test_sm3_streaming_all_lengths_and_chunk_sizes() {
        let msg = sample_message();
        for len in 0..=130 {
            let once = Sm3Hasher::digest(&msg[..len]);
            for chunk_size in [1usize, 7, 55, 56, 63, 64, 65] {
                let mut h = Sm3Hasher::new();
                for chunk in msg[..len].chunks(chunk_size) {
                    h.update(chunk);
                }
                assert_eq!(h.finalize(), once, "len={} chunk={}", len, chunk_size);
            }
        }
    }

    #[test]
    fn test_sm3_empty() {
        let digest = Sm3Hasher::digest(b"");
        let expected =
            hex_literal("1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b");
        assert_eq!(digest, expected, "SM3(\"\") 测试向量不匹配");
    }

    #[test]
    fn test_hmac_sm3_basic() {
        let key = b"test-key";
        let data = b"test-message";
        let mac1 = hmac_sm3(key, data);
        let mac2 = hmac_sm3(key, data);
        assert_eq!(mac1, mac2, "HMAC-SM3 应为确定性函数");
        // The tag must depend on both the key and the message.
        assert_ne!(mac1, hmac_sm3(b"other-key", data));
        assert_ne!(mac1, hmac_sm3(key, b"other-message"));
    }

    /// RFC 2104: a key longer than the 64-byte block is replaced by its hash,
    /// so it must give the same tag as using that 32-byte digest as the key.
    #[test]
    fn test_hmac_sm3_long_key() {
        let long_key = [0x42u8; 100];
        let data = b"data";
        let hashed_key = Sm3Hasher::digest(&long_key);
        assert_eq!(hmac_sm3(&long_key, data), hmac_sm3(&hashed_key, data));

        // Boundary: 65 bytes is hashed, exactly 64 bytes is used verbatim.
        let key65 = [0x11u8; 65];
        assert_eq!(
            hmac_sm3(&key65, data),
            hmac_sm3(&Sm3Hasher::digest(&key65), data)
        );
        let key64 = [0x11u8; 64];
        assert_ne!(
            hmac_sm3(&key64, data),
            hmac_sm3(&Sm3Hasher::digest(&key64), data)
        );
    }

    #[test]
    fn test_reset_equals_new() {
        let mut h = Sm3Hasher::new();
        h.update(b"some data");
        h.reset();
        let digest_after_reset = h.finalize();
        let digest_fresh = Sm3Hasher::digest(b"");
        assert_eq!(digest_after_reset, digest_fresh);
    }

    #[test]
    fn test_finalize_reset_correctness() {
        let mut h = Sm3Hasher::new();
        h.update(b"abc");
        let d1 = h.finalize_reset();
        assert_eq!(d1, Sm3Hasher::digest(b"abc"));
        let d2 = h.finalize();
        assert_eq!(d2, Sm3Hasher::digest(b""));
    }

    #[test]
    fn test_finalize_reset_repeatable() {
        let mut h = Sm3Hasher::new();
        h.update(b"test");
        let d1 = h.finalize_reset();
        h.update(b"test");
        let d2 = h.finalize_reset();
        assert_eq!(d1, d2);
    }

    #[test]
    fn test_hmac_sm3_streaming_equals_oneshot() {
        let key = b"streaming-key";
        let parts: &[&[u8]] = &[b"hello", b" ", b"world"];
        // Concatenation of `parts`, spelled out so no allocation is needed.
        let expected = hmac_sm3(key, b"hello world");
        let mut h = HmacSm3::new(key);
        for p in parts {
            h.update(p);
        }
        let got = h.finalize();
        assert_eq!(expected, got);
    }

    /// Streaming HMAC must match the one-shot function across block boundaries,
    /// for both a short key and a long (hashed) key.
    #[test]
    fn test_hmac_sm3_streaming_block_boundaries() {
        let msg = sample_message();
        let long_key = [0x42u8; 100];
        let keys: [&[u8]; 2] = [b"boundary-key", &long_key];
        for key in keys {
            let expected = hmac_sm3(key, &msg);
            for chunk_size in [1usize, 13, 63, 64, 65] {
                let mut h = HmacSm3::new(key);
                for chunk in msg.chunks(chunk_size) {
                    h.update(chunk);
                }
                assert_eq!(h.finalize(), expected, "chunk={}", chunk_size);
            }
        }
    }

    /// Decodes exactly 64 hex digits into a 32-byte array; panics on bad input.
    fn hex_literal(s: &str) -> [u8; 32] {
        fn nibble(c: u8) -> u8 {
            match c {
                b'0'..=b'9' => c - b'0',
                b'a'..=b'f' => c - b'a' + 10,
                b'A'..=b'F' => c - b'A' + 10,
                _ => panic!("invalid hex"),
            }
        }
        let b = s.as_bytes();
        assert_eq!(b.len(), 64, "expected exactly 64 hex digits");
        let mut out = [0u8; 32];
        for (byte, pair) in out.iter_mut().zip(b.chunks_exact(2)) {
            *byte = nibble(pair[0]) << 4 | nibble(pair[1]);
        }
        out
    }
}