cryptape_sm/sm4/
cipher_mode.rs

1// Copyright 2018 Cryptape Technology LLC.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::cipher::Sm4Cipher;
16
/// Block-cipher modes of operation supported by [`SM4CipherMode`].
///
/// Every mode here turns SM4 into a stream cipher: no padding is applied and
/// the output always has the same length as the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CipherMode {
    /// Cipher feedback: each keystream block is the encryption of the previous
    /// ciphertext block (the IV for the first block).
    Cfb,
    /// Output feedback: each keystream block is the encryption of the previous
    /// keystream block (the IV for the first block).
    Ofb,
    /// Counter mode: the IV is treated as a 128-bit big-endian counter that is
    /// encrypted and then incremented once per block.
    Ctr,
}
22
23pub struct SM4CipherMode {
24    cipher: Sm4Cipher,
25    mode: CipherMode,
26}
27
/// XORs the first 16 bytes of `a` with the first 16 bytes of `b`.
///
/// Bytes beyond index 15 in either input are ignored.
///
/// # Panics
///
/// Panics if either slice is shorter than 16 bytes.
fn block_xor(a: &[u8], b: &[u8]) -> [u8; 16] {
    let mut xored = [0u8; 16];
    for (dst, (x, y)) in xored.iter_mut().zip(a[..16].iter().zip(&b[..16])) {
        *dst = x ^ y;
    }
    xored
}
35
/// Increments the first 16 bytes of `a` as a big-endian 128-bit counter.
///
/// The counter wraps from all-`0xff` back to all-zero. Bytes after index 15
/// are left untouched.
///
/// # Panics
///
/// Panics if `a` is shorter than 16 bytes.
fn block_add_one(a: &mut [u8]) {
    // Walk from the least significant byte (index 15) towards index 0.
    for byte in a[..16].iter_mut().rev() {
        *byte = byte.wrapping_add(1);
        // A byte that did not wrap to zero produces no carry, so stop here.
        if *byte != 0 {
            break;
        }
    }
}
51
52impl SM4CipherMode {
53    pub fn new(key: &[u8], mode: CipherMode) -> SM4CipherMode {
54        let cipher = Sm4Cipher::new(key);
55        SM4CipherMode { cipher, mode }
56    }
57
58    pub fn encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
59        if iv.len() != 16 {
60            panic!("the iv of sm4 must be 16-byte long");
61        }
62        match self.mode {
63            CipherMode::Cfb => self.cfb_encrypt(data, iv),
64            CipherMode::Ofb => self.ofb_encrypt(data, iv),
65            CipherMode::Ctr => self.ctr_encrypt(data, iv),
66        }
67    }
68
69    pub fn decrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
70        if iv.len() != 16 {
71            panic!("the iv of sm4 must be 16-byte long");
72        }
73        match self.mode {
74            CipherMode::Cfb => self.cfb_decrypt(data, iv),
75            CipherMode::Ofb => self.ofb_encrypt(data, iv),
76            CipherMode::Ctr => self.ctr_encrypt(data, iv),
77        }
78    }
79
80    fn cfb_encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
81        let block_num = data.len() / 16;
82        let tail_len = data.len() - block_num * 16;
83
84        let mut out: Vec<u8> = Vec::new();
85        let mut vec_buf: Vec<u8> = vec![0; 16];
86        vec_buf.clone_from_slice(iv);
87
88        // Normal
89        for i in 0..block_num {
90            let enc = self.cipher.encrypt(&vec_buf[..]);
91            let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
92            for i in ct.iter() {
93                out.push(*i);
94            }
95            vec_buf.clone_from_slice(&ct);
96        }
97
98        // Last block
99        let enc = self.cipher.encrypt(&vec_buf[..]);
100        for i in 0..tail_len {
101            let b = data[block_num * 16 + i] ^ enc[i];
102            out.push(b);
103        }
104        out
105    }
106
107    fn cfb_decrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
108        let block_num = data.len() / 16;
109        let tail_len = data.len() - block_num * 16;
110
111        let mut out: Vec<u8> = Vec::new();
112        let mut vec_buf: Vec<u8> = vec![0; 16];
113        vec_buf.clone_from_slice(iv);
114
115        // Normal
116        for i in 0..block_num {
117            let enc = self.cipher.encrypt(&vec_buf[..]);
118            let ct = &data[i * 16..i * 16 + 16];
119            let pt = block_xor(&enc, ct);
120            for i in pt.iter() {
121                out.push(*i);
122            }
123            vec_buf.clone_from_slice(ct);
124        }
125
126        // Last block
127        let enc = self.cipher.encrypt(&vec_buf[..]);
128        for i in 0..tail_len {
129            let b = data[block_num * 16 + i] ^ enc[i];
130            out.push(b);
131        }
132        out
133    }
134
135    fn ofb_encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
136        let block_num = data.len() / 16;
137        let tail_len = data.len() - block_num * 16;
138
139        let mut out: Vec<u8> = Vec::new();
140        let mut vec_buf: Vec<u8> = vec![0; 16];
141        vec_buf.clone_from_slice(iv);
142
143        // Normal
144        for i in 0..block_num {
145            let enc = self.cipher.encrypt(&vec_buf[..]);
146            let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
147            for i in ct.iter() {
148                out.push(*i);
149            }
150            vec_buf.clone_from_slice(&enc);
151        }
152
153        // Last block
154        let enc = self.cipher.encrypt(&vec_buf[..]);
155        for i in 0..tail_len {
156            let b = data[block_num * 16 + i] ^ enc[i];
157            out.push(b);
158        }
159        out
160    }
161
162    fn ctr_encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
163        let mut vec_buf: Vec<u8> = vec![0; 16];
164        vec_buf.resize(16, 0);
165        vec_buf.clone_from_slice(iv);
166        let mut out: Vec<u8> = Vec::new();
167
168        let block_num = data.len() / 16;
169        let tail_len = data.len() - block_num * 16;
170
171        // Normal
172        for i in 0..block_num {
173            let enc = self.cipher.encrypt(&vec_buf[..]);
174            let ct = block_xor(&enc, &data[i * 16..i * 16 + 16]);
175            for i in ct.iter() {
176                out.push(*i);
177            }
178            block_add_one(&mut vec_buf[..]);
179        }
180
181        // Last block
182        let enc = self.cipher.encrypt(&vec_buf[..]);
183        for i in 0..tail_len {
184            let b = data[block_num * 16 + i] ^ enc[i];
185            out.push(b);
186        }
187        out
188    }
189}
190
191// TODO: AEAD in SM4
192// pub struct SM4Gcm;
193
194// Tests below
195
#[cfg(test)]
mod tests {
    use super::*;

    use rand::os::OsRng;
    use rand::Rng;

    /// Returns 16 random bytes from the OS RNG (used as a key or IV).
    fn rand_block() -> [u8; 16] {
        let mut rng = OsRng::new().unwrap();
        let mut block: [u8; 16] = [0; 16];
        rng.fill_bytes(&mut block[..]);
        block
    }

    /// Returns `len` random bytes from the OS RNG.
    fn rand_data(len: usize) -> Vec<u8> {
        let mut rng = OsRng::new().unwrap();
        let mut dat = vec![0u8; len];
        rng.fill_bytes(&mut dat[..]);
        dat
    }

    #[test]
    fn test_driver() {
        test_ciphermode(CipherMode::Ctr);
        test_ciphermode(CipherMode::Cfb);
        test_ciphermode(CipherMode::Ofb);
    }

    /// Round-trips random data of many lengths through `mode`.
    ///
    /// The lengths cover empty input, sub-block input, exact multiples of the
    /// block size, and inputs with a partial trailing block.
    fn test_ciphermode(mode: CipherMode) {
        let key = rand_block();
        let iv = rand_block();

        let cmode = SM4CipherMode::new(&key, mode);

        for &len in [0usize, 1, 10, 15, 16, 17, 32, 100, 1000].iter() {
            let pt = rand_data(len);
            let ct = cmode.encrypt(&pt[..], &iv);
            // These modes never pad, so ciphertext and plaintext lengths match.
            assert_eq!(ct.len(), pt.len());

            let new_pt = cmode.decrypt(&ct[..], &iv);
            assert_eq!(pt, new_pt);

            // Stream property: encrypting a prefix yields the same prefix of ciphertext.
            let half = len / 2;
            let prefix_ct = cmode.encrypt(&pt[..half], &iv);
            assert_eq!(&prefix_ct[..], &ct[..half]);
        }
    }

    /// Known-answer test built on the standard SM4 single-block vector.
    ///
    /// With one all-zero block of input, every supported mode outputs E_K(IV)
    /// as its first block. Choosing IV = the standard plaintext must therefore
    /// yield the standard ciphertext.
    #[test]
    fn test_known_answer() {
        let key: [u8; 16] = [
            0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
            0x32, 0x10,
        ];
        let iv = key;
        let expected: [u8; 16] = [
            0x68, 0x1e, 0xdf, 0x34, 0xd2, 0x06, 0x96, 0x5e, 0x86, 0xb3, 0xe9, 0x4f, 0x53, 0x6e,
            0x42, 0x46,
        ];

        for mode in vec![CipherMode::Cfb, CipherMode::Ofb, CipherMode::Ctr] {
            let cmode = SM4CipherMode::new(&key, mode);
            let ct = cmode.encrypt(&[0u8; 16], &iv);
            assert_eq!(&ct[..], &expected[..]);
            assert_eq!(cmode.decrypt(&ct[..], &iv), vec![0u8; 16]);
        }
    }

    #[test]
    #[should_panic(expected = "the iv of sm4 must be 16-byte long")]
    fn test_bad_iv_panics() {
        let cmode = SM4CipherMode::new(&rand_block(), CipherMode::Ctr);
        let _ = cmode.encrypt(&[0u8; 4], &[0u8; 15]);
    }
}