1use std::fmt::{Display, Formatter};
2
/// Errors reported by the SM3 implementation.
///
/// The type is a plain, field-less enum, so it is cheap to copy and compare.
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum Sm3Error {
    /// The padded message length is not a whole number of 64-byte blocks.
    ErrorMsgLen,
}

impl Sm3Error {
    /// Returns the static, human-readable description of this error.
    ///
    /// This is the single source of truth for the text shared by the
    /// `Display`, `Debug` and `&str` conversions below.
    fn message(&self) -> &'static str {
        match self {
            Sm3Error::ErrorMsgLen => "SM3 Pad error: error msg len",
        }
    }
}

/// `Debug` deliberately renders the same text as `Display`.
impl std::fmt::Debug for Sm3Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
        f.write_str(self.message())
    }
}

/// Converts the error into its static description.
impl From<Sm3Error> for &str {
    fn from(e: Sm3Error) -> Self {
        e.message()
    }
}

/// User-facing rendering of the error.
impl Display for Sm3Error {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.message())
    }
}

/// Lets `Sm3Error` participate in standard error handling, e.g. conversion
/// into `Box<dyn std::error::Error>` through `?`.
impl std::error::Error for Sm3Error {}
29
/// Round constant returned by `t` for rounds 0 through 15.
pub(crate) const T00: u32 = 0x79cc_4519;

/// Round constant returned by `t` for rounds 16 through 63.
pub(crate) const T16: u32 = 0x7a87_9d8a;

/// Initial chaining value: the eight state words `sm3_hash` starts from
/// before any message block has been compressed.
pub(crate) static IV: [u32; 8] = [
    0x7380_166f,
    0x4914_b2b9,
    0x1724_42d7,
    0xda8a_0600,
    0xa96f_30bc,
    0x1631_38aa,
    0xe38d_ee4d,
    0xb0fb_0e4e,
];
39
/// Permutation applied to `tt2` in each round of `cf`:
/// `x ^ (x <<< 9) ^ (x <<< 17)`, where `<<<` is a 32-bit left rotation.
fn p0(x: u32) -> u32 {
    let rot9 = x.rotate_left(9);
    let rot17 = x.rotate_left(17);
    x ^ rot9 ^ rot17
}
44
/// Permutation used by the message expansion in `cf`:
/// `x ^ (x <<< 15) ^ (x <<< 23)`, where `<<<` is a 32-bit left rotation.
fn p1(x: u32) -> u32 {
    let rot15 = x.rotate_left(15);
    let rot23 = x.rotate_left(23);
    x ^ rot15 ^ rot23
}
49
/// Round-dependent boolean function `FF_j` used when computing `tt1` in `cf`.
///
/// * rounds 0..=15  — bitwise XOR of the three words;
/// * rounds 16..=63 — bitwise majority of the three words;
/// * any other `j`  — returns 0.
fn ff(x: u32, y: u32, z: u32, j: u32) -> u32 {
    match j {
        0..=15 => x ^ y ^ z,
        16..=63 => (x & y) | (x & z) | (y & z),
        _ => 0,
    }
}
58
/// Round-dependent boolean function `GG_j` used when computing `tt2` in `cf`.
///
/// * rounds 0..=15  — bitwise XOR of the three words;
/// * rounds 16..=63 — bitwise choice: each bit of `x` selects the matching
///   bit of `y` (when set) or of `z` (when clear);
/// * any other `j`  — returns 0.
fn gg(x: u32, y: u32, z: u32, j: u32) -> u32 {
    match j {
        0..=15 => x ^ y ^ z,
        16..=63 => (x & y) | (!x & z),
        _ => 0,
    }
}
67
68fn t(j: usize) -> u32 {
69 if j <= 15 {
70 return T00;
71 } else if j >= 16 && j <= 63 {
72 return T16;
73 }
74 0
75}
76
77pub fn sm3_hash(msg: &[u8]) -> [u8; 32] {
89 let msg = pad(msg).unwrap();
90 let len = msg.len();
91 let mut b_i: [u8; 64] = [0; 64];
92 let mut count_group: usize = 0;
93 let mut v_i = IV;
94 while count_group * 64 != len {
95 for i in (count_group * 64)..(count_group * 64 + 64) {
96 b_i[i - count_group * 64] = msg[i];
97 }
98 cf(&mut v_i, b_i);
99 count_group += 1;
100 }
101 let mut output: [u8; 32] = [0; 32];
102 for i in 0..8 {
103 output[i * 4] = (v_i[i] >> 24) as u8;
104 output[i * 4 + 1] = (v_i[i] >> 16) as u8;
105 output[i * 4 + 2] = (v_i[i] >> 8) as u8;
106 output[i * 4 + 3] = v_i[i] as u8;
107 }
108 output
109}
110
111fn cf(v_i: &mut [u32; 8], b_i: [u8; 64]) {
112 let mut w: [u32; 68] = [0; 68];
114 let mut w1: [u32; 64] = [0; 64];
115
116 let mut j = 0;
118 while j <= 15 {
119 w[j] = u32::from(b_i[j * 4]) << 24
120 | u32::from(b_i[j * 4 + 1]) << 16
121 | u32::from(b_i[j * 4 + 2]) << 8
122 | u32::from(b_i[j * 4 + 3]);
123 j += 1;
124 }
125
126 j = 16;
128 while j <= 67 {
129 w[j] = p1(w[j - 16] ^ w[j - 9] ^ w[j - 3].rotate_left(15))
130 ^ w[j - 13].rotate_left(7)
131 ^ w[j - 6];
132 j += 1;
133 }
134
135 j = 0;
137 while j <= 63 {
138 w1[j] = w[j] ^ w[j + 4];
139 j += 1;
140 }
141
142 let mut a = v_i[0];
143 let mut b = v_i[1];
144 let mut c = v_i[2];
145 let mut d = v_i[3];
146 let mut e = v_i[4];
147 let mut f = v_i[5];
148 let mut g = v_i[6];
149 let mut h = v_i[7];
150
151 for j in 0..64 {
152 let ss1 = (a
153 .rotate_left(12)
154 .wrapping_add(e)
155 .wrapping_add(t(j).rotate_left(j as u32)))
156 .rotate_left(7);
157 let ss2 = ss1 ^ (a.rotate_left(12));
158 let tt1 = ff(a, b, c, j as u32)
159 .wrapping_add(d)
160 .wrapping_add(ss2)
161 .wrapping_add(w1[j]);
162 let tt2 = gg(e, f, g, j as u32)
163 .wrapping_add(h)
164 .wrapping_add(ss1)
165 .wrapping_add(w[j]);
166 d = c;
167 c = b.rotate_left(9);
168 b = a;
169 a = tt1;
170 h = g;
171 g = f.rotate_left(19);
172 f = e;
173 e = p0(tt2);
174 }
175 v_i[0] ^= a;
176 v_i[1] ^= b;
177 v_i[2] ^= c;
178 v_i[3] ^= d;
179 v_i[4] ^= e;
180 v_i[5] ^= f;
181 v_i[6] ^= g;
182 v_i[7] ^= h;
183}
184
185fn pad(msg: &[u8]) -> Result<Vec<u8>, Sm3Error> {
186 let bit_length = (msg.len() << 3) as u64;
187 let mut msg = msg.to_vec();
188 msg.push(0x80);
189 let blocksize = 64;
190 while msg.len() % blocksize != 56 {
191 msg.push(0x00);
192 }
193 msg.push((bit_length >> 56 & 0xff) as u8);
194 msg.push((bit_length >> 48 & 0xff) as u8);
195 msg.push((bit_length >> 40 & 0xff) as u8);
196 msg.push((bit_length >> 32 & 0xff) as u8);
197 msg.push((bit_length >> 24 & 0xff) as u8);
198 msg.push((bit_length >> 16 & 0xff) as u8);
199 msg.push((bit_length >> 8 & 0xff) as u8);
200 msg.push((bit_length & 0xff) as u8);
201 if msg.len() % 64 != 0 {
202 return Err(Sm3Error::ErrorMsgLen);
203 }
204 Ok(msg)
205}
206
#[cfg(test)]
mod test {
    use crate::sm3::sm3_hash;

    /// Hashes `input` with `sm3_hash` and returns the digest as a hex string.
    fn digest_hex(input: &[u8]) -> String {
        hex::encode(sm3_hash(input))
    }

    /// Known-answer test for a short, single-block message.
    #[test]
    fn test_hash_1() {
        assert_eq!(
            "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0",
            digest_hex(b"abc")
        );
    }

    /// Known-answer test for a 64-byte message, which needs a second block
    /// for the padding.
    #[test]
    fn test_hash_2() {
        assert_eq!(
            "debe9ff92275b8a138604889c18e5a4d6fdb70e5387e5765293dcba39c0c5732",
            digest_hex(b"abcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd")
        );
    }
}