sevenz_rust2/
aes256sha256.rs

1use std::io::{Read, Seek, Write};
2
3#[cfg(feature = "compress")]
4pub use self::enc::*;
5use crate::Password;
6use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit, generic_array::GenericArray};
7use lzma_rust2::CountingWriter;
8use sha2::Digest;
9
10type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
11
12#[cfg_attr(docsrs, doc(cfg(feature = "aes256")))]
13pub struct Aes256Sha256Decoder<R> {
14    cipher: Cipher,
15    input: R,
16    done: bool,
17    obuffer: Vec<u8>,
18    ostart: usize,
19    ofinish: usize,
20    pos: usize,
21}
22
23impl<R: Read> Aes256Sha256Decoder<R> {
24    pub fn new(input: R, properties: &[u8], password: &[u8]) -> Result<Self, crate::Error> {
25        let cipher = Cipher::from_properties(properties, password)?;
26        Ok(Self {
27            input,
28            cipher,
29            done: false,
30            obuffer: Default::default(),
31            ostart: 0,
32            ofinish: 0,
33            pos: 0,
34        })
35    }
36
37    fn get_more_data(&mut self) -> std::io::Result<usize> {
38        if self.done {
39            Ok(0)
40        } else {
41            self.ofinish = 0;
42            self.ostart = 0;
43            self.obuffer.clear();
44            let mut ibuffer = [0; 512];
45            let readin = self.input.read(&mut ibuffer)?;
46            if readin == 0 {
47                self.done = true;
48                self.ofinish = self.cipher.do_final(&mut self.obuffer)?;
49                Ok(self.ofinish)
50            } else {
51                let n = self
52                    .cipher
53                    .update(&mut ibuffer[..readin], &mut self.obuffer)?;
54                self.ofinish = n;
55                Ok(n)
56            }
57        }
58    }
59}
60
61impl<R: Read> Read for Aes256Sha256Decoder<R> {
62    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
63        if self.ostart >= self.ofinish {
64            let mut n: usize;
65            n = self.get_more_data()?;
66            while n == 0 && !self.done {
67                n = self.get_more_data()?;
68            }
69            if n == 0 {
70                return Ok(0);
71            }
72        }
73
74        if buf.is_empty() {
75            return Ok(0);
76        }
77        let buf_len = self.ofinish - self.ostart;
78        let size = buf_len.min(buf.len());
79        buf[..size].copy_from_slice(&self.obuffer[self.ostart..self.ostart + size]);
80        self.ostart += size;
81        self.pos += size;
82        Ok(size)
83    }
84}
85
86impl<R: Read + Seek> Seek for Aes256Sha256Decoder<R> {
87    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
88        let len = self.ofinish - self.ostart;
89        match pos {
90            std::io::SeekFrom::Start(p) => {
91                let n = (p as i64 - self.pos as i64).min(len as i64);
92
93                if n < 0 {
94                    Ok(0)
95                } else {
96                    self.ostart += n as usize;
97                    Ok(p)
98                }
99            }
100            std::io::SeekFrom::End(_) => Err(std::io::Error::new(
101                std::io::ErrorKind::Unsupported,
102                "Aes256 decoder unsupport seek from end",
103            )),
104            std::io::SeekFrom::Current(n) => {
105                let n = n.min(len as i64);
106                if n < 0 {
107                    Ok(0)
108                } else {
109                    self.ostart += n as usize;
110                    Ok(self.pos as u64 + n as u64)
111                }
112            }
113        }
114    }
115}
116
117fn get_aes_key(properties: &[u8], password: &[u8]) -> Result<([u8; 32], [u8; 16]), crate::Error> {
118    if properties.len() < 2 {
119        return Err(crate::Error::other("AES256 properties too shart"));
120    }
121    let b0 = properties[0];
122    let num_cycles_power = b0 & 63;
123    let b1 = properties[1];
124    let iv_size = (((b0 >> 6) & 1) + (b1 & 15)) as usize;
125    let salt_size = (((b0 >> 7) & 1) + (b1 >> 4)) as usize;
126    if 2 + salt_size + iv_size > properties.len() {
127        return Err(crate::Error::other("Salt size + IV size too long"));
128    }
129    let mut salt = vec![0u8; salt_size];
130    salt.copy_from_slice(&properties[2..(2 + salt_size)]);
131    let mut iv = [0u8; 16];
132    iv[0..iv_size].copy_from_slice(&properties[(2 + salt_size)..(2 + salt_size + iv_size)]);
133    if password.is_empty() {
134        return Err(crate::Error::PasswordRequired);
135    }
136    let aes_key = if num_cycles_power == 0x3f {
137        let mut aes_key = [0u8; 32];
138        aes_key.copy_from_slice(&salt[..salt_size]);
139        let n = password.len().min(aes_key.len() - salt_size);
140        aes_key[salt_size..n + salt_size].copy_from_slice(&password[0..n]);
141        aes_key
142    } else {
143        let mut sha = sha2::Sha256::default();
144        let mut extra = [0u8; 8];
145        for _ in 0..(1u32 << num_cycles_power) {
146            sha.update(&salt);
147            sha.update(password);
148            sha.update(extra);
149            for item in &mut extra {
150                *item = item.wrapping_add(1);
151                if *item != 0 {
152                    break;
153                }
154            }
155        }
156        sha.finalize().into()
157    };
158    Ok((aes_key, iv))
159}
160
161struct Cipher {
162    dec: Aes256CbcDec,
163    buf: Vec<u8>,
164}
165
166impl Cipher {
167    fn from_properties(properties: &[u8], password: &[u8]) -> Result<Self, crate::Error> {
168        let (aes_key, iv) = get_aes_key(properties, password)?;
169        Ok(Self {
170            dec: Aes256CbcDec::new(&GenericArray::from(aes_key), &iv.into()),
171            buf: Default::default(),
172        })
173    }
174
175    fn update<W: Write>(&mut self, mut data: &mut [u8], mut output: W) -> std::io::Result<usize> {
176        let mut n = 0;
177        if !self.buf.is_empty() {
178            assert!(self.buf.len() < 16);
179            let end = 16 - self.buf.len();
180            self.buf.extend_from_slice(&data[..end]);
181            data = &mut data[end..];
182            let block = GenericArray::from_mut_slice(&mut self.buf);
183            self.dec.decrypt_block_mut(block);
184            let out = block.as_slice();
185            output.write_all(out)?;
186            n += out.len();
187            self.buf.clear();
188        }
189
190        for a in data.chunks_mut(16) {
191            if a.len() < 16 {
192                self.buf.extend_from_slice(a);
193                break;
194            }
195            let block = GenericArray::from_mut_slice(a);
196            self.dec.decrypt_block_mut(block);
197            let out = block.as_slice();
198            output.write_all(out)?;
199            n += out.len();
200        }
201        Ok(n)
202    }
203
204    fn do_final(&mut self, output: &mut Vec<u8>) -> std::io::Result<usize> {
205        if self.buf.is_empty() {
206            output.clear();
207            Ok(0)
208        } else {
209            Err(std::io::Error::new(
210                std::io::ErrorKind::InvalidData,
211                "IllegalBlockSize",
212            ))
213        }
214    }
215}
#[cfg(feature = "compress")]
mod enc {
    type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
    use super::*;

    /// Writer that AES-256-CBC encrypts data for a 7z archive.
    ///
    /// Plaintext is encrypted 16 bytes at a time as it arrives. A trailing partial
    /// block is held back until the stream is finished by writing an empty slice
    /// (`write(&[])`). It is then zero-padded to a full block and written out.
    #[cfg_attr(docsrs, doc(cfg(feature = "aes256")))]
    pub struct Aes256Sha256Encoder<W> {
        /// Destination for the ciphertext.
        output: CountingWriter<W>,
        /// Incremental CBC encryptor.
        enc: Aes256CbcEnc,
        /// Plaintext of an incomplete block (always fewer than 16 bytes).
        buffer: Vec<u8>,
        /// Set once the stream has been finished by an empty `write`.
        done: bool,
        /// Ciphertext bytes written so far.
        ///
        /// A `u64`, so that streams larger than 4 GiB do not overflow the counter
        /// (a `u32` counter panics in debug builds).
        write_size: u64,
    }

    /// Parameters for [`Aes256Sha256Encoder`]: password, IV, salt and key-derivation cost.
    #[cfg_attr(docsrs, doc(cfg(feature = "aes256")))]
    #[derive(Debug, Clone)]
    pub struct AesEncoderOptions {
        /// Password the key is derived from; key derivation fails if it is empty.
        pub password: Password,
        /// CBC initialization vector, stored in the coder properties.
        pub iv: [u8; 16],
        /// Key-derivation salt, stored in the coder properties.
        pub salt: [u8; 16],
        /// Base-2 logarithm of the number of SHA-256 key-derivation rounds.
        ///
        /// Only the low 6 bits are stored. `0x3f` selects the direct salt + password
        /// key instead of hashing.
        pub num_cycles_power: u8,
    }

    impl AesEncoderOptions {
        /// Creates options for `password` with a freshly generated random IV and salt.
        ///
        /// `num_cycles_power` is set to 8, i.e. 2^8 = 256 hashing rounds.
        ///
        /// # Panics
        /// Panics if `getrandom` cannot produce random bytes.
        pub fn new(password: Password) -> Self {
            let mut iv = [0; 16];
            getrandom::fill(&mut iv).expect("Can't generate IV");

            let mut salt = [0; 16];
            getrandom::fill(&mut salt).expect("Can't generate salt");

            Self {
                password,
                iv,
                salt,
                num_cycles_power: 8,
            }
        }

        /// Returns the 34-byte coder properties describing these options.
        pub fn properties(&self) -> [u8; 34] {
            let mut props = [0u8; 34];
            self.write_properties(&mut props);
            props
        }

        /// Writes the coder properties into `props[..34]`.
        ///
        /// Layout:
        /// - byte 0: `num_cycles_power` plus two flag bits;
        /// - byte 1: size byte (together with the flags it declares a 16-byte salt and a 16-byte IV);
        /// - bytes 2..18: salt;
        /// - bytes 18..34: IV.
        ///
        /// # Panics
        /// Panics if `props` is shorter than 34 bytes.
        #[inline]
        pub fn write_properties(&self, props: &mut [u8]) {
            assert!(props.len() >= 34);
            props[0] = (self.num_cycles_power & 0x3f) | 0xc0;
            props[1] = 0xff;
            props[2..18].copy_from_slice(&self.salt);
            props[18..34].copy_from_slice(&self.iv);
        }
    }

    impl<W> Aes256Sha256Encoder<W> {
        /// Creates an encoder that writes ciphertext to `output`, keyed from `options`.
        ///
        /// # Errors
        /// Returns an error if the key cannot be derived, e.g. when the password is empty.
        pub fn new(
            output: CountingWriter<W>,
            options: &AesEncoderOptions,
        ) -> Result<Self, crate::Error> {
            let (key, iv) = get_aes_key(&options.properties(), options.password.as_slice())?;

            Ok(Self {
                output,
                enc: Aes256CbcEnc::new(&GenericArray::from(key), &iv.into()),
                buffer: Default::default(),
                done: false,
                write_size: 0,
            })
        }

        /// Encrypts one 16-byte block in place and writes it to the output.
        #[inline]
        fn write_block(&mut self, block: &mut [u8]) -> std::io::Result<()>
        where
            W: Write,
        {
            let cipher_block = GenericArray::from_mut_slice(block);
            self.enc.encrypt_block_mut(cipher_block);
            self.output.write_all(block)?;
            self.write_size += block.len() as u64;
            Ok(())
        }
    }

    impl<W: Write> Write for Aes256Sha256Encoder<W> {
        /// Encrypts `buf` and holds back any trailing partial block.
        ///
        /// An empty `buf` finishes the stream: the pending partial block, if any, is
        /// zero-padded and written, and the empty write is forwarded to the output.
        /// Later non-empty writes return `Ok(0)`.
        fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
            if buf.is_empty() {
                self.done = true;
                self.flush()?;
                return self.output.write(buf);
            }
            if self.done {
                return Ok(0);
            }
            let len = buf.len();
            if !self.buffer.is_empty() {
                assert!(self.buffer.len() < 16);
                let missing = 16 - self.buffer.len();
                if buf.len() < missing {
                    // Still not a full block; keep accumulating.
                    self.buffer.extend_from_slice(buf);
                    return Ok(len);
                }
                let mut block = [0u8; 16];
                block[..self.buffer.len()].copy_from_slice(&self.buffer);
                block[self.buffer.len()..].copy_from_slice(&buf[..missing]);
                self.write_block(&mut block)?;
                self.buffer.clear();
                buf = &buf[missing..];
            }

            let mut chunks = buf.chunks_exact(16);
            for chunk in &mut chunks {
                let mut block = [0u8; 16];
                block.copy_from_slice(chunk);
                self.write_block(&mut block)?;
            }
            self.buffer.extend_from_slice(chunks.remainder());

            Ok(len)
        }

        /// Once the stream is finished, zero-pads and writes the pending partial block.
        ///
        /// Before that, the partial block is kept, because CBC can only encrypt whole
        /// blocks.
        /// NOTE(review): this does not flush the underlying writer. Confirm that this
        /// is intentional.
        fn flush(&mut self) -> std::io::Result<()> {
            if self.done && !self.buffer.is_empty() {
                assert!(self.buffer.len() < 16);
                let mut block = [0u8; 16];
                block[..self.buffer.len()].copy_from_slice(&self.buffer);
                self.write_block(&mut block)?;
                self.buffer.clear();
            }
            Ok(())
        }
    }
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips this source file through the AES encoder and decoder.
    #[allow(clippy::unused_io_amount)]
    #[cfg(feature = "compress")]
    #[test]
    fn test_aes_codec() {
        let plaintext = include_bytes!("./aes256sha256.rs");
        let password: Password = "1234".into();
        let options = AesEncoderOptions::new(password.clone());

        // Encrypt everything, then finish the stream with an empty write.
        let mut ciphertext = vec![];
        {
            let sink = CountingWriter::new(&mut ciphertext);
            let mut encoder = Aes256Sha256Encoder::new(sink, &options).unwrap();
            encoder.write_all(plaintext).expect("encode data");
            encoder.write(&[]).unwrap();
        }

        let props = options.properties();
        let mut source = &ciphertext[..];
        let mut decoder =
            Aes256Sha256Decoder::new(&mut source, &props, password.as_slice()).unwrap();

        let mut roundtrip = Vec::new();
        std::io::copy(&mut decoder, &mut roundtrip).unwrap();
        // The encoder zero-pads the final block, so compare only the original length.
        assert_eq!(&roundtrip[..plaintext.len()], &plaintext[..]);
    }
}
381}