cryptape_sm/sm4/
cipher_mode.rs1use super::cipher::Sm4Cipher;
16
/// Mode of operation selected when constructing an `SM4CipherMode`.
///
/// The variant decides which private routine `SM4CipherMode::encrypt` and
/// `SM4CipherMode::decrypt` dispatch to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CipherMode {
    /// Feeds each full ciphertext block back in as the next cipher input.
    Cfb,
    /// Feeds each cipher output block back in as the next cipher input.
    Ofb,
    /// Uses the IV as a 16-byte counter that is incremented after every full block.
    Ctr,
}
22
/// An SM4 cipher instance paired with the mode of operation it is used in.
///
/// Constructed with `SM4CipherMode::new`; `encrypt` and `decrypt` then
/// dispatch on the stored mode.
pub struct SM4CipherMode {
    /// Block cipher built from the key passed to `new`.
    cipher: Sm4Cipher,
    /// Mode that `encrypt` / `decrypt` match on.
    mode: CipherMode,
}
27
/// XORs the first 16 bytes of `a` with the first 16 bytes of `b`.
///
/// Bytes beyond index 15 in either slice are ignored.
///
/// # Panics
///
/// Panics if either slice holds fewer than 16 bytes.
fn block_xor(a: &[u8], b: &[u8]) -> [u8; 16] {
    let mut mixed = [0u8; 16];
    for (idx, slot) in mixed.iter_mut().enumerate() {
        // Plain indexing is kept on purpose so short inputs still panic.
        *slot = a[idx] ^ b[idx];
    }
    mixed
}
35
/// Adds one, in place, to the first 16 bytes of `a` read as a big-endian
/// integer (index 15 is the least significant byte).
///
/// An all-`0xff` counter wraps around to all zeros. Bytes past index 15 are
/// never touched.
///
/// # Panics
///
/// Panics if `a` holds fewer than 16 bytes.
fn block_add_one(a: &mut [u8]) {
    for byte in a[..16].iter_mut().rev() {
        let (sum, wrapped) = byte.overflowing_add(1);
        *byte = sum;
        if !wrapped {
            // No carry left to propagate into the more significant bytes.
            break;
        }
    }
}
51
52impl SM4CipherMode {
53 pub fn new(key: &[u8], mode: CipherMode) -> SM4CipherMode {
54 let cipher = Sm4Cipher::new(key);
55 SM4CipherMode { cipher, mode }
56 }
57
58 pub fn encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
59 if iv.len() != 16 {
60 panic!("the iv of sm4 must be 16-byte long");
61 }
62 match self.mode {
63 CipherMode::Cfb => self.cfb_encrypt(data, iv),
64 CipherMode::Ofb => self.ofb_encrypt(data, iv),
65 CipherMode::Ctr => self.ctr_encrypt(data, iv),
66 }
67 }
68
69 pub fn decrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
70 if iv.len() != 16 {
71 panic!("the iv of sm4 must be 16-byte long");
72 }
73 match self.mode {
74 CipherMode::Cfb => self.cfb_decrypt(data, iv),
75 CipherMode::Ofb => self.ofb_encrypt(data, iv),
76 CipherMode::Ctr => self.ctr_encrypt(data, iv),
77 }
78 }
79
80 fn cfb_encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
81 let block_num = data.len() / 16;
82 let tail_len = data.len() - block_num * 16;
83
84 let mut out: Vec<u8> = Vec::new();
85 let mut vec_buf: Vec<u8> = vec![0; 16];
86 vec_buf.clone_from_slice(iv);
87
88 for i in 0..block_num {
90 let enc = self.cipher.encrypt(&vec_buf[..]);
91 let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
92 for i in ct.iter() {
93 out.push(*i);
94 }
95 vec_buf.clone_from_slice(&ct);
96 }
97
98 let enc = self.cipher.encrypt(&vec_buf[..]);
100 for i in 0..tail_len {
101 let b = data[block_num * 16 + i] ^ enc[i];
102 out.push(b);
103 }
104 out
105 }
106
107 fn cfb_decrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
108 let block_num = data.len() / 16;
109 let tail_len = data.len() - block_num * 16;
110
111 let mut out: Vec<u8> = Vec::new();
112 let mut vec_buf: Vec<u8> = vec![0; 16];
113 vec_buf.clone_from_slice(iv);
114
115 for i in 0..block_num {
117 let enc = self.cipher.encrypt(&vec_buf[..]);
118 let ct = &data[i * 16..i * 16 + 16];
119 let pt = block_xor(&enc, ct);
120 for i in pt.iter() {
121 out.push(*i);
122 }
123 vec_buf.clone_from_slice(ct);
124 }
125
126 let enc = self.cipher.encrypt(&vec_buf[..]);
128 for i in 0..tail_len {
129 let b = data[block_num * 16 + i] ^ enc[i];
130 out.push(b);
131 }
132 out
133 }
134
135 fn ofb_encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
136 let block_num = data.len() / 16;
137 let tail_len = data.len() - block_num * 16;
138
139 let mut out: Vec<u8> = Vec::new();
140 let mut vec_buf: Vec<u8> = vec![0; 16];
141 vec_buf.clone_from_slice(iv);
142
143 for i in 0..block_num {
145 let enc = self.cipher.encrypt(&vec_buf[..]);
146 let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
147 for i in ct.iter() {
148 out.push(*i);
149 }
150 vec_buf.clone_from_slice(&enc);
151 }
152
153 let enc = self.cipher.encrypt(&vec_buf[..]);
155 for i in 0..tail_len {
156 let b = data[block_num * 16 + i] ^ enc[i];
157 out.push(b);
158 }
159 out
160 }
161
162 fn ctr_encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
163 let mut vec_buf: Vec<u8> = vec![0; 16];
164 vec_buf.resize(16, 0);
165 vec_buf.clone_from_slice(iv);
166 let mut out: Vec<u8> = Vec::new();
167
168 let block_num = data.len() / 16;
169 let tail_len = data.len() - block_num * 16;
170
171 for i in 0..block_num {
173 let enc = self.cipher.encrypt(&vec_buf[..]);
174 let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
175 for i in ct.iter() {
176 out.push(*i);
177 }
178 block_add_one(&mut vec_buf[..]);
179 }
180
181 let enc = self.cipher.encrypt(&vec_buf[..]);
183 for i in 0..tail_len {
184 let b = data[block_num * 16 + i] ^ enc[i];
185 out.push(b);
186 }
187 out
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    use rand::os::OsRng;
    use rand::Rng;

    /// Returns 16 bytes drawn from the OS random number generator; used for
    /// both keys and IVs below.
    fn rand_block() -> [u8; 16] {
        let mut rng = OsRng::new().unwrap();
        let mut block: [u8; 16] = [0; 16];
        rng.fill_bytes(&mut block[..]);
        block
    }

    /// Returns `len` bytes drawn from the OS random number generator.
    fn rand_data(len: usize) -> Vec<u8> {
        let mut rng = OsRng::new().unwrap();
        let mut dat = vec![0u8; len];
        rng.fill_bytes(&mut dat[..]);
        dat
    }

    #[test]
    fn test_driver() {
        test_ciphermode(CipherMode::Ctr);
        test_ciphermode(CipherMode::Cfb);
        test_ciphermode(CipherMode::Ofb);
    }

    /// Round-trips random plaintexts of several lengths through `mode`.
    ///
    /// The lengths deliberately include the empty input, exact multiples of
    /// the 16-byte block size, and values just below/above a block boundary,
    /// so both the full-block loop and the partial-tail path are exercised.
    fn test_ciphermode(mode: CipherMode) {
        let key = rand_block();
        let iv = rand_block();

        let cmode = SM4CipherMode::new(&key, mode);

        for &len in [0usize, 1, 10, 15, 16, 17, 32, 100, 1000].iter() {
            let pt = rand_data(len);
            let ct = cmode.encrypt(&pt[..], &iv);
            // These modes never pad: ciphertext length equals plaintext length.
            assert_eq!(ct.len(), pt.len(), "length changed for len {}", len);
            let new_pt = cmode.decrypt(&ct[..], &iv);
            assert_eq!(pt, new_pt, "round trip failed for len {}", len);
        }
    }
}