1use std::cell::RefCell;
4use std::io::{Read, Seek, Write};
5
6use corecrypto::cipher::Ciphers;
7use corecrypto::header::{HashingAlgorithm, Header, HeaderType, Keyslot};
8use corecrypto::primitives::{Mode, ENCRYPTED_MASTER_KEY_LEN};
9use corecrypto::protected::Protected;
10use corecrypto::stream::EncryptionStreams;
11
12use crate::utils::{gen_master_key, gen_nonce, gen_salt};
13
/// Failures that can occur while encrypting a file with [`execute`].
///
/// Every variant is unit-like; the underlying cause reported by the failing
/// operation is not retained.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The reader, the writer or the header writer could not be rewound to its start.
    ResetCursorPosition,
    /// Deriving the encryption key from the raw key and salt failed.
    HashKey,
    /// The master key could not be encrypted with the derived key.
    EncryptMasterKey,
    /// Encrypting the payload from the reader into the writer failed.
    EncryptFile,
    /// The header could not be serialized or written.
    WriteHeader,
    /// The encryption streams could not be initialized from the master key.
    InitializeStreams,
    /// The cipher could not be initialized from the derived key.
    ///
    /// The misspelled name is kept for API compatibility.
    InitializeChiphers,
    /// The additional authenticated data could not be created from the header.
    CreateAad,
}
25
26impl std::fmt::Display for Error {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            Error::ResetCursorPosition => f.write_str("Unable to reset cursor position"),
30            Error::HashKey => f.write_str("Cannot hash raw key"),
31            Error::EncryptMasterKey => f.write_str("Cannot encrypt master key"),
32            Error::EncryptFile => f.write_str("Cannot encrypt file"),
33            Error::WriteHeader => f.write_str("Cannot write header"),
34            Error::InitializeStreams => f.write_str("Cannot initialize streams"),
35            Error::InitializeChiphers => f.write_str("Cannot initialize chiphers"),
36            Error::CreateAad => f.write_str("Cannot create AAD"),
37        }
38    }
39}
40
41impl std::error::Error for Error {}
42
43pub struct Request<'a, R, W>
44where
45    R: Read + Seek,
46    W: Write + Seek,
47{
48    pub reader: &'a RefCell<R>,
49    pub writer: &'a RefCell<W>,
50    pub header_writer: Option<&'a RefCell<W>>,
51    pub raw_key: Protected<Vec<u8>>,
52    pub header_type: HeaderType,
54    pub hashing_algorithm: HashingAlgorithm,
55}
56
57pub fn execute<R, W>(req: Request<'_, R, W>) -> Result<(), Error>
58where
59    R: Read + Seek,
60    W: Write + Seek,
61{
62    let salt = gen_salt();
64
65    let key = req
67        .hashing_algorithm
68        .hash(req.raw_key, &salt)
69        .map_err(|_| Error::HashKey)?;
70
71    let cipher = Ciphers::initialize(key, &req.header_type.algorithm)
73        .map_err(|_| Error::InitializeChiphers)?;
74
75    let master_key = gen_master_key();
77
78    let master_key_nonce = gen_nonce(&req.header_type.algorithm, &Mode::MemoryMode);
79
80    let master_key_encrypted = {
82        let encrypted_key = cipher
83            .encrypt(master_key_nonce.as_slice(), master_key.as_slice())
84            .map_err(|_| Error::EncryptMasterKey)?;
85
86        let mut encrypted_key_arr = [0u8; ENCRYPTED_MASTER_KEY_LEN];
87        let len = ENCRYPTED_MASTER_KEY_LEN.min(encrypted_key.len());
88        encrypted_key_arr[..len].copy_from_slice(&encrypted_key[..len]);
89
90        encrypted_key_arr
91    };
92
93    let keyslot = Keyslot {
94        encrypted_key: master_key_encrypted,
95        nonce: master_key_nonce,
96        hash_algorithm: req.hashing_algorithm,
97        salt,
98    };
99
100    let keyslots = vec![keyslot];
101
102    let header_nonce = gen_nonce(&req.header_type.algorithm, &req.header_type.mode);
103    let streams =
104        EncryptionStreams::initialize(master_key, &header_nonce, &req.header_type.algorithm)
105            .map_err(|_| Error::InitializeStreams)?;
106
107    let header = Header {
108        header_type: req.header_type,
109        nonce: header_nonce,
110        salt: None,
111        keyslots: Some(keyslots),
112    };
113
114    req.writer
115        .borrow_mut()
116        .rewind()
117        .map_err(|_| Error::ResetCursorPosition)?;
118
119    match req.header_writer {
120        None => {
121            req.writer
122                .borrow_mut()
123                .write(&header.serialize().map_err(|_| Error::WriteHeader)?)
124                .map_err(|_| Error::WriteHeader)?;
125        }
126        Some(header_writer) => {
127            header_writer
128                .borrow_mut()
129                .rewind()
130                .map_err(|_| Error::ResetCursorPosition)?;
131
132            header_writer
133                .borrow_mut()
134                .write(&header.serialize().map_err(|_| Error::WriteHeader)?)
135                .map_err(|_| Error::WriteHeader)?;
136        }
137    }
138
139    let aad = header.create_aad().map_err(|_| Error::CreateAad)?;
140
141    let mut reader = req.reader.borrow_mut();
142    reader.rewind().map_err(|_| Error::ResetCursorPosition)?;
143
144    let mut writer = req.writer.borrow_mut();
145    streams
146        .encrypt_file(&mut *reader, &mut *writer, &aad)
147        .map_err(|_| Error::EncryptFile)?;
148
149    Ok(())
150}
151
/// Encryption tests and fixed test vectors.
///
/// The exact-byte vectors only match if `gen_salt`, `gen_nonce` and `gen_master_key`
/// are deterministic in test builds. That is presumably the case; confirm in
/// `crate::utils`. The module is `pub` so other tests can reuse the constants.
#[cfg(test)]
pub mod tests {
    use std::io::Cursor;

    use corecrypto::header::HeaderVersion;
    use corecrypto::primitives::Algorithm;

    use super::*;

    /// Raw key used by every encryption test in this module.
    pub const PASSWORD: &[u8; 8] = b"12345678";

    /// Expected output for `b"Hello world"` with a V4 header written inline.
    /// Parameters: XChaCha20Poly1305, StreamMode, Blake3Balloon(4).
    pub const V4_ENCRYPTED_CONTENT: [u8; 155] = [
        222, 4, 14, 1, 12, 1, 58, 206, 16, 183, 233, 128, 23, 223, 81, 30, 214, 132, 32, 104, 51,
        119, 173, 240, 60, 45, 230, 243, 58, 160, 69, 50, 217, 192, 66, 223, 124, 190, 148, 91, 92,
        129, 0, 0, 0, 0, 0, 0, 147, 32, 67, 18, 249, 211, 189, 86, 187, 159, 234, 160, 94, 80, 72,
        68, 231, 114, 132, 105, 164, 177, 26, 217, 46, 168, 97, 110, 34, 27, 13, 16, 14, 111, 3,
        109, 218, 232, 212, 78, 188, 55, 91, 106, 97, 74, 238, 210, 173, 240, 60, 45, 230, 243, 58,
        160, 69, 50, 217, 192, 66, 223, 124, 190, 148, 91, 92, 129, 50, 126, 110, 254, 0, 0, 0, 0,
        0, 0, 0, 0, 14, 110, 105, 217, 74, 171, 173, 103, 11, 136, 119, 98, 145, 17, 70, 84, 144,
        143, 154, 244, 82, 201, 85, 13, 187, 85, 89,
    ];

    /// Expected output for `b"Hello world"` with a V5 header written inline.
    /// Parameters: XChaCha20Poly1305, StreamMode, Blake3Balloon(5).
    pub const V5_ENCRYPTED_CONTENT: [u8; 443] = [
        222, 5, 14, 1, 12, 1, 173, 240, 60, 45, 230, 243, 58, 160, 69, 50, 217, 192, 66, 223, 124,
        190, 148, 91, 92, 129, 0, 0, 0, 0, 0, 0, 223, 181, 71, 240, 140, 106, 41, 36, 82, 150, 105,
        215, 159, 108, 234, 246, 25, 19, 65, 206, 177, 146, 15, 174, 209, 129, 82, 2, 62, 76, 129,
        34, 136, 189, 11, 98, 105, 54, 146, 71, 102, 166, 97, 177, 207, 62, 194, 132, 38, 87, 173,
        240, 60, 45, 230, 243, 58, 160, 69, 50, 217, 192, 66, 223, 124, 190, 148, 91, 92, 129, 50,
        126, 110, 254, 58, 206, 16, 183, 233, 128, 23, 223, 81, 30, 214, 132, 32, 104, 51, 119, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 110, 105, 217, 74,
        171, 173, 103, 11, 136, 119, 172, 145, 72, 239, 74, 217, 63, 245, 222, 31, 164, 139, 146,
        71, 165, 91,
    ];

    /// Expected writer contents when the V5 header is detached to a separate header writer.
    /// In that case the writer holds only the ciphertext.
    pub const V5_ENCRYPTED_FULL_DETACHED_CONTENT: [u8; 27] = [
        14, 110, 105, 217, 74, 171, 173, 103, 11, 136, 119, 172, 145, 72, 239, 74, 217, 63, 245,
        222, 31, 164, 139, 146, 71, 165, 91,
    ];
    /// V5 ciphertext placed after a 416-byte zeroed region where the header would be.
    ///
    /// Not used in this module; presumably consumed by other tests. Confirm before
    /// removing.
    pub const V5_ENCRYPTED_DETACHED_CONTENT: [u8; 443] = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 110, 105,
        217, 74, 171, 173, 103, 11, 136, 119, 172, 145, 72, 239, 74, 217, 63, 245, 222, 31, 164,
        139, 146, 71, 165, 91,
    ];
    /// Expected contents of the separate header writer for the detached V5 test.
    pub const V5_ENCRYPTED_DETACHED_HEADER: [u8; 416] = [
        222, 5, 14, 1, 12, 1, 173, 240, 60, 45, 230, 243, 58, 160, 69, 50, 217, 192, 66, 223, 124,
        190, 148, 91, 92, 129, 0, 0, 0, 0, 0, 0, 223, 181, 71, 240, 140, 106, 41, 36, 82, 150, 105,
        215, 159, 108, 234, 246, 25, 19, 65, 206, 177, 146, 15, 174, 209, 129, 82, 2, 62, 76, 129,
        34, 136, 189, 11, 98, 105, 54, 146, 71, 102, 166, 97, 177, 207, 62, 194, 132, 38, 87, 173,
        240, 60, 45, 230, 243, 58, 160, 69, 50, 217, 192, 66, 223, 124, 190, 148, 91, 92, 129, 50,
        126, 110, 254, 58, 206, 16, 183, 233, 128, 23, 223, 81, 30, 214, 132, 32, 104, 51, 119, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ];

    #[test]
    fn should_encrypt_content_with_v4_version() {
        let mut input_content = b"Hello world";
        let input_cur = RefCell::new(Cursor::new(&mut input_content));

        let mut output_content = vec![];
        let output_cur = RefCell::new(Cursor::new(&mut output_content));

        let req = Request {
            reader: &input_cur,
            writer: &output_cur,
            header_writer: None,
            raw_key: Protected::new(PASSWORD.to_vec()),
            header_type: HeaderType {
                version: HeaderVersion::V4,
                algorithm: Algorithm::XChaCha20Poly1305,
                mode: Mode::StreamMode,
            },
            hashing_algorithm: HashingAlgorithm::Blake3Balloon(4),
        };

        // A failure here is a test failure, not an unreachable state; `expect`
        // reports the error in the panic message.
        execute(req).expect("V4 encryption should succeed");
        assert_eq!(output_content, V4_ENCRYPTED_CONTENT.to_vec());
    }

    #[test]
    fn should_encrypt_content_with_v5_version() {
        let mut input_content = b"Hello world";
        let input_cur = RefCell::new(Cursor::new(&mut input_content));

        let mut output_content = vec![];
        let output_cur = RefCell::new(Cursor::new(&mut output_content));

        let req = Request {
            reader: &input_cur,
            writer: &output_cur,
            header_writer: None,
            raw_key: Protected::new(PASSWORD.to_vec()),
            header_type: HeaderType {
                version: HeaderVersion::V5,
                algorithm: Algorithm::XChaCha20Poly1305,
                mode: Mode::StreamMode,
            },
            hashing_algorithm: HashingAlgorithm::Blake3Balloon(5),
        };

        execute(req).expect("V5 encryption should succeed");
        assert_eq!(output_content, V5_ENCRYPTED_CONTENT.to_vec());
    }

    #[test]
    fn should_save_header_separately() {
        let mut input_content = b"Hello world";
        let input_cur = RefCell::new(Cursor::new(&mut input_content));

        let mut output_content = vec![];
        let output_cur = RefCell::new(Cursor::new(&mut output_content));

        let mut output_header = vec![];
        let output_header_cur = RefCell::new(Cursor::new(&mut output_header));

        let req = Request {
            reader: &input_cur,
            writer: &output_cur,
            header_writer: Some(&output_header_cur),
            raw_key: Protected::new(PASSWORD.to_vec()),
            header_type: HeaderType {
                version: HeaderVersion::V5,
                algorithm: Algorithm::XChaCha20Poly1305,
                mode: Mode::StreamMode,
            },
            hashing_algorithm: HashingAlgorithm::Blake3Balloon(5),
        };

        execute(req).expect("detached-header encryption should succeed");
        assert_eq!(output_content, V5_ENCRYPTED_FULL_DETACHED_CONTENT.to_vec());
        assert_eq!(output_header, V5_ENCRYPTED_DETACHED_HEADER.to_vec());
    }
}