gm_rs/
sm3.rs

1use std::fmt::{Display, Formatter};
2
/// Errors produced by the SM3 implementation.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub enum Sm3Error {
    /// The padded message length is not a whole number of 64-byte blocks.
    ErrorMsgLen,
}

impl Sm3Error {
    /// Returns the static, human-readable message for this error.
    ///
    /// This is the single source of truth for the text emitted by the
    /// `Display`, `Debug` and `From<Sm3Error> for &str` implementations.
    pub fn as_str(&self) -> &'static str {
        match self {
            Sm3Error::ErrorMsgLen => "SM3 Pad error: error msg len",
        }
    }
}

impl std::fmt::Debug for Sm3Error {
    // Debug deliberately prints the same message as Display.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl From<Sm3Error> for &str {
    fn from(e: Sm3Error) -> Self {
        e.as_str()
    }
}

impl Display for Sm3Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

/// Lets `Sm3Error` be used with `?` into `Box<dyn Error>` and error-handling crates.
impl std::error::Error for Sm3Error {}
29
/// Round constant T_j used for rounds 0 ≤ j ≤ 15.
pub(crate) const T00: u32 = 0x79cc_4519;

/// Round constant T_j used for rounds 16 ≤ j ≤ 63.
pub(crate) const T16: u32 = 0x7a87_9d8a;

/// Initial chaining value V(0) (the SM3 IV), eight 32-bit words.
pub(crate) static IV: [u32; 8] = [
    0x7380_166f, 0x4914_b2b9, 0x1724_42d7, 0xda8a_0600,
    0xa96f_30bc, 0x1631_38aa, 0xe38d_ee4d, 0xb0fb_0e4e,
];
39
/// Permutation P0 used in the compression function:
/// P0(X) = X ⊕ (X ≪ 9) ⊕ (X ≪ 17), where ≪ is a 32-bit left rotation.
fn p0(x: u32) -> u32 {
    let rot9 = x.rotate_left(9);
    let rot17 = x.rotate_left(17);
    x ^ rot9 ^ rot17
}
44
/// Permutation P1 used in message expansion:
/// P1(X) = X ⊕ (X ≪ 15) ⊕ (X ≪ 23), where ≪ is a 32-bit left rotation.
fn p1(x: u32) -> u32 {
    [15u32, 23]
        .iter()
        .fold(x, |acc, &amount| acc ^ x.rotate_left(amount))
}
49
/// Boolean function FF_j used in the compression function.
///
/// * `0 ≤ j ≤ 15`:  FF_j(X, Y, Z) = X ⊕ Y ⊕ Z
/// * `16 ≤ j ≤ 63`: FF_j(X, Y, Z) = (X ∧ Y) ∨ (X ∧ Z) ∨ (Y ∧ Z)  (bitwise majority)
///
/// # Panics
///
/// Panics if `j > 63`. SM3 has exactly 64 rounds, so any other index is a
/// caller bug; previously it silently returned `0`, which would corrupt the digest.
fn ff(x: u32, y: u32, z: u32, j: u32) -> u32 {
    match j {
        0..=15 => x ^ y ^ z,
        16..=63 => (x & y) | (x & z) | (y & z),
        _ => unreachable!("SM3 round index out of range: {}", j),
    }
}
58
/// Boolean function GG_j used in the compression function.
///
/// * `0 ≤ j ≤ 15`:  GG_j(X, Y, Z) = X ⊕ Y ⊕ Z
/// * `16 ≤ j ≤ 63`: GG_j(X, Y, Z) = (X ∧ Y) ∨ (¬X ∧ Z)  (bitwise choose)
///
/// # Panics
///
/// Panics if `j > 63`. SM3 has exactly 64 rounds, so any other index is a
/// caller bug; previously it silently returned `0`, which would corrupt the digest.
fn gg(x: u32, y: u32, z: u32, j: u32) -> u32 {
    match j {
        0..=15 => x ^ y ^ z,
        16..=63 => (x & y) | (!x & z),
        _ => unreachable!("SM3 round index out of range: {}", j),
    }
}
67
68fn t(j: usize) -> u32 {
69    if j <= 15 {
70        return T00;
71    } else if j >= 16 && j <= 63 {
72        return T16;
73    }
74    0
75}
76
77/// # Example
78/// ```rust
79/// use crate::gm_rs::sm3::sm3_hash;
80/// fn main(){
81///     let hash = sm3_hash(b"abc");
82///     let r = hex::encode(hash);
83///     assert_eq!("66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0", r);
84/// }
85///
86/// ```
87///
88pub fn sm3_hash(msg: &[u8]) -> [u8; 32] {
89    let msg = pad(msg).unwrap();
90    let len = msg.len();
91    let mut b_i: [u8; 64] = [0; 64];
92    let mut count_group: usize = 0;
93    let mut v_i = IV;
94    while count_group * 64 != len {
95        for i in (count_group * 64)..(count_group * 64 + 64) {
96            b_i[i - count_group * 64] = msg[i];
97        }
98        cf(&mut v_i, b_i);
99        count_group += 1;
100    }
101    let mut output: [u8; 32] = [0; 32];
102    for i in 0..8 {
103        output[i * 4] = (v_i[i] >> 24) as u8;
104        output[i * 4 + 1] = (v_i[i] >> 16) as u8;
105        output[i * 4 + 2] = (v_i[i] >> 8) as u8;
106        output[i * 4 + 3] = v_i[i] as u8;
107    }
108    output
109}
110
111fn cf(v_i: &mut [u32; 8], b_i: [u8; 64]) {
112    // expend msg
113    let mut w: [u32; 68] = [0; 68];
114    let mut w1: [u32; 64] = [0; 64];
115
116    // a. 将消息分组B(i)划分为16个字W0, W1, · · · , W15。
117    let mut j = 0;
118    while j <= 15 {
119        w[j] = u32::from(b_i[j * 4]) << 24
120            | u32::from(b_i[j * 4 + 1]) << 16
121            | u32::from(b_i[j * 4 + 2]) << 8
122            | u32::from(b_i[j * 4 + 3]);
123        j += 1;
124    }
125
126    // b. Wj ← P1(Wj−16 ⊕ Wj−9 ⊕ (Wj−3 ≪ 15)) ⊕ (Wj−13 ≪ 7) ⊕ Wj−6
127    j = 16;
128    while j <= 67 {
129        w[j] = p1(w[j - 16] ^ w[j - 9] ^ w[j - 3].rotate_left(15))
130            ^ w[j - 13].rotate_left(7)
131            ^ w[j - 6];
132        j += 1;
133    }
134
135    // c. Wj′ = Wj ⊕ Wj+4
136    j = 0;
137    while j <= 63 {
138        w1[j] = w[j] ^ w[j + 4];
139        j += 1;
140    }
141
142    let mut a = v_i[0];
143    let mut b = v_i[1];
144    let mut c = v_i[2];
145    let mut d = v_i[3];
146    let mut e = v_i[4];
147    let mut f = v_i[5];
148    let mut g = v_i[6];
149    let mut h = v_i[7];
150
151    for j in 0..64 {
152        let ss1 = (a
153            .rotate_left(12)
154            .wrapping_add(e)
155            .wrapping_add(t(j).rotate_left(j as u32)))
156        .rotate_left(7);
157        let ss2 = ss1 ^ (a.rotate_left(12));
158        let tt1 = ff(a, b, c, j as u32)
159            .wrapping_add(d)
160            .wrapping_add(ss2)
161            .wrapping_add(w1[j]);
162        let tt2 = gg(e, f, g, j as u32)
163            .wrapping_add(h)
164            .wrapping_add(ss1)
165            .wrapping_add(w[j]);
166        d = c;
167        c = b.rotate_left(9);
168        b = a;
169        a = tt1;
170        h = g;
171        g = f.rotate_left(19);
172        f = e;
173        e = p0(tt2);
174    }
175    v_i[0] ^= a;
176    v_i[1] ^= b;
177    v_i[2] ^= c;
178    v_i[3] ^= d;
179    v_i[4] ^= e;
180    v_i[5] ^= f;
181    v_i[6] ^= g;
182    v_i[7] ^= h;
183}
184
185fn pad(msg: &[u8]) -> Result<Vec<u8>, Sm3Error> {
186    let bit_length = (msg.len() << 3) as u64;
187    let mut msg = msg.to_vec();
188    msg.push(0x80);
189    let blocksize = 64;
190    while msg.len() % blocksize != 56 {
191        msg.push(0x00);
192    }
193    msg.push((bit_length >> 56 & 0xff) as u8);
194    msg.push((bit_length >> 48 & 0xff) as u8);
195    msg.push((bit_length >> 40 & 0xff) as u8);
196    msg.push((bit_length >> 32 & 0xff) as u8);
197    msg.push((bit_length >> 24 & 0xff) as u8);
198    msg.push((bit_length >> 16 & 0xff) as u8);
199    msg.push((bit_length >> 8 & 0xff) as u8);
200    msg.push((bit_length & 0xff) as u8);
201    if msg.len() % 64 != 0 {
202        return Err(Sm3Error::ErrorMsgLen);
203    }
204    Ok(msg)
205}
206
#[cfg(test)]
mod test {
    use super::{pad, sm3_hash, Sm3Error};

    /// Hashes `msg` and returns the digest as lowercase hex.
    fn hash_hex(msg: &[u8]) -> String {
        hex::encode(sm3_hash(msg))
    }

    #[test]
    fn test_hash_1() {
        assert_eq!(
            "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0",
            hash_hex(b"abc")
        );
    }

    #[test]
    fn test_hash_2() {
        assert_eq!(
            "debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732",
            hash_hex(b"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd")
        );
    }

    #[test]
    fn test_hash_empty() {
        assert_eq!(
            "1ab21d8355cfa17f8e61194831e81a8f22bec8c728fefb747ed035eb5082aa2b",
            hash_hex(b"")
        );
    }

    #[test]
    fn test_pad_lengths() {
        // (message length in bytes, expected padded length in bytes); covers the
        // 55/56-byte boundary where the length field spills into a new block.
        let cases = [
            (0usize, 64usize),
            (55, 64),
            (56, 128),
            (63, 128),
            (64, 128),
            (119, 128),
            (120, 192),
        ];
        for &(len, expected) in cases.iter() {
            let msg = vec![0x61u8; len];
            let padded = pad(&msg).unwrap();
            assert_eq!(padded.len(), expected, "msg len {}", len);
            assert_eq!(&padded[..len], &msg[..]);
            assert_eq!(padded[len], 0x80);
            assert!(padded[len + 1..expected - 8].iter().all(|&b| b == 0));
            assert_eq!(&padded[expected - 8..], &((len as u64) * 8).to_be_bytes()[..]);
        }
    }

    #[test]
    fn test_error_message() {
        assert_eq!(
            Sm3Error::ErrorMsgLen.to_string(),
            "SM3 Pad error: error msg len"
        );
    }
}