1use std::io::{Read, Write};
7
8use zeroize::Zeroizing;
9
10use crate::errors::{DecryptError, FileFormatError};
11use crate::{chapoly_decrypt_noise, hkdf_sha256, noise_decrypt, scrypt, PrivateKey, PublicKey};
12use crate::{AsymFileFormat, FileFormat, PassFileFormat};
13use crate::{CHUNK_SIZE, SCRYPT_N, SCRYPT_P, SCRYPT_R, TAG_SIZE};
14
15pub fn key_decrypt<T: Read, U: Write>(
17    ciphertext: &mut T,
18    plaintext: &mut U,
19    recipient: &PrivateKey,
20    recipient_public: &PublicKey,
21    file_format: AsymFileFormat,
22) -> Result<PublicKey, DecryptError> {
23    if file_format != AsymFileFormat::V1 {
24        return Err(DecryptError::Other(
25            "File format not supported. This may be your plaintext.".into(),
26        ));
27    }
28
29    let mut prologue = [0u8; 4];
30    ciphertext.read_exact(&mut prologue).map_err(read_err)?;
31    let file_format = valid_file_format(&prologue)?;
32    if file_format == FileFormat::PassV1 {
33        return Err(DecryptError::Other(
34            "This is a password encrypted file. Try password decrypt instread.".into(),
35        ));
36    }
37
38    let mut handshake_message = [0u8; 128];
39    ciphertext
40        .read_exact(&mut handshake_message)
41        .map_err(read_err)?;
42
43    let noise_message = noise_decrypt(recipient, recipient_public, &prologue, &handshake_message)
44        .map_err(|e| DecryptError::Other(e.to_string()))?;
45
46    let file_encryption_key = hkdf_sha256(
47        &[],
48        &noise_message.payload_key.as_bytes(),
49        &noise_message.handshake_hash,
50        32,
51    );
52    let file_encryption_key = Zeroizing::new(file_encryption_key);
53
54    decrypt_chunks(ciphertext, plaintext, &file_encryption_key, &[], CHUNK_SIZE)?;
55
56    let public_key = noise_message.public_key.clone();
57
58    Ok(public_key)
59}
60
61pub fn pass_decrypt<T: Read, U: Write>(
63    ciphertext: &mut T,
64    plaintext: &mut U,
65    password: &[u8],
66    file_format: PassFileFormat,
67) -> Result<(), DecryptError> {
68    if file_format != PassFileFormat::V1 {
69        return Err(DecryptError::Other(
70            "File format not supported. This may be your plaintext.".into(),
71        ));
72    }
73
74    let mut pass_magic_num = [0u8; 4];
75    ciphertext
76        .read_exact(&mut pass_magic_num)
77        .map_err(read_err)?;
78    let file_format = valid_file_format(&pass_magic_num)?;
79    if file_format == FileFormat::AsymV1 {
80        return Err(DecryptError::Other(
81            "This is a key encrypted file. Try decrypt instread.".into(),
82        ));
83    }
84
85    let mut salt = [0u8; 32];
86    ciphertext.read_exact(&mut salt).map_err(read_err)?;
87
88    let key = scrypt(password, &salt, SCRYPT_N, SCRYPT_R, SCRYPT_P, 32);
89    let key = Zeroizing::new(key);
90    let aad = &pass_magic_num[..];
91
92    decrypt_chunks(ciphertext, plaintext, &key, aad, CHUNK_SIZE)?;
93
94    Ok(())
95}
96
97fn decrypt_chunks<T: Read, U: Write>(
105    ciphertext: &mut T,
106    plaintext: &mut U,
107    key: &[u8],
108    aad: &[u8],
109    chunk_size: u32,
110) -> Result<(), DecryptError> {
111    let mut chunk_number: u64 = 0;
112    let mut done = false;
113    let cs: usize = chunk_size.try_into().unwrap();
114    let mut buffer = vec![0; cs + TAG_SIZE];
115    let mut auth_data = vec![0u8; aad.len() + 8];
116
117    loop {
118        let mut chunk_header = [0u8; 16];
119        ciphertext.read_exact(&mut chunk_header).map_err(read_err)?;
120        let last_chunk_indicator_bytes: [u8; 4] = chunk_header[8..12].try_into().unwrap();
121        let ciphertext_length_bytes: [u8; 4] = chunk_header[12..].try_into().unwrap();
122        let last_chunk_indicator = u32::from_be_bytes(last_chunk_indicator_bytes);
123        let ciphertext_length = u32::from_be_bytes(ciphertext_length_bytes);
124        if ciphertext_length > chunk_size {
125            return Err(DecryptError::ChunkLen);
126        }
127
128        let ct_len: usize = ciphertext_length.try_into().unwrap();
129        ciphertext
130            .read_exact(&mut buffer[..ct_len + TAG_SIZE])
131            .map_err(read_err)?;
132
133        let aad_len = aad.len();
134        auth_data[..aad_len].copy_from_slice(aad);
135        auth_data[aad_len..aad_len + 4].copy_from_slice(&last_chunk_indicator_bytes);
136        auth_data[aad_len + 4..].copy_from_slice(&ciphertext_length_bytes);
137
138        let ct = &buffer[..ct_len + TAG_SIZE];
139        let pt_chunk = chapoly_decrypt_noise(&key, chunk_number, auth_data.as_slice(), ct)?;
140
141        if last_chunk_indicator == 1 {
146            done = true;
147            let check = ciphertext.read(&mut [0u8; 1]).map_err(read_err)?;
153            if check != 0 {
154                return Err(DecryptError::UnexpectedData);
157            }
158        }
159
160        plaintext
161            .write_all(pt_chunk.as_slice())
162            .map_err(write_err)?;
163        plaintext.flush().map_err(write_err)?;
164
165        if done {
166            break;
167        }
168
169        chunk_number += 1;
173    }
174
175    Ok(())
176}
177
178pub fn valid_file_format(header: &[u8]) -> Result<FileFormat, FileFormatError> {
180    let asym_v1 = [0x65, 0x67, 0x6b, 0x10];
181    let pass_v1 = [0x65, 0x67, 0x6b, 0x20];
182
183    if header == asym_v1 {
184        return Ok(FileFormat::AsymV1);
185    } else if header == pass_v1 {
186        return Ok(FileFormat::PassV1);
187    }
188
189    Err(FileFormatError)
190}
191
192fn read_err(err: std::io::Error) -> DecryptError {
193    use std::io::ErrorKind;
194
195    match err.kind() {
196        ErrorKind::UnexpectedEof => DecryptError::IORead(std::io::Error::new(
197            ErrorKind::Other,
198            "Did not read enough data.",
199        )),
200        _ => DecryptError::IORead(err),
201    }
202}
203
204fn write_err(err: std::io::Error) -> DecryptError {
205    DecryptError::IOWrite(err)
206}
207
#[cfg(test)]
mod tests {
    use super::CHUNK_SIZE;
    use super::{key_decrypt, pass_decrypt};
    use super::{PrivateKey, PublicKey};
    use crate::encrypt::{key_encrypt, pass_encrypt};
    use crate::errors::DecryptError;
    use crate::sha256;
    use crate::{AsymFileFormat, PassFileFormat, PayloadKey};

    /// Fixed key pairs for the test vectors: Alice is the sender, Bob is the recipient.
    #[allow(dead_code)]
    struct KeyData {
        alice_private: PrivateKey,
        alice_public: PublicKey,
        bob_private: PrivateKey,
        bob_public: PublicKey,
    }

    #[test]
    fn test_decrypt_small() {
        let expected_plaintext = b"Hello, world!";
        let key_data = get_key_data();
        let expected_sender = key_data.alice_public;
        let recipient = key_data.bob_private;
        let recipient_public = key_data.bob_public;
        let ciphertext = encrypt_small_util();
        let mut plaintext = Vec::new();
        let sender_public = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &recipient,
            &recipient_public,
            AsymFileFormat::V1,
        )
        .unwrap();

        assert_eq!(&expected_plaintext[..], plaintext.as_slice());
        assert_eq!(expected_sender.as_bytes(), sender_public.as_bytes());
    }

    #[test]
    fn test_decrypt_rejects_trailing_data() {
        let key_data = get_key_data();
        let mut ciphertext = encrypt_small_util();
        ciphertext.push(0x00);
        let mut plaintext = Vec::new();
        let result = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &key_data.bob_private,
            &key_data.bob_public,
            AsymFileFormat::V1,
        );

        assert!(matches!(result, Err(DecryptError::UnexpectedData)));
    }

    #[test]
    fn test_decrypt_rejects_truncated_ciphertext() {
        let key_data = get_key_data();
        let mut ciphertext = encrypt_small_util();
        ciphertext.pop();
        let mut plaintext = Vec::new();
        let result = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &key_data.bob_private,
            &key_data.bob_public,
            AsymFileFormat::V1,
        );

        assert!(result.is_err());
    }

    /// Encrypts "Hello, world!" from Alice to Bob with a fixed ephemeral key and payload key.
    fn encrypt_small_util() -> Vec<u8> {
        let ephemeral_private =
            hex::decode("fdbc28d8f4c2a97013e460836cece7a4bdf59df0cb4b3a185146d13615884f38")
                .unwrap();
        let payload_key =
            hex::decode("a9f9ddef54d0432ec067b75aef26c3db5419ade3b016339743ca1812d89188b2")
                .unwrap();
        let key_data = get_key_data();

        let sender = PrivateKey::try_from(key_data.alice_private.as_bytes()).unwrap();
        let sender_public = sender.to_public().unwrap();
        let recipient = PublicKey::try_from(key_data.bob_public.as_bytes()).unwrap();
        let ephemeral = PrivateKey::try_from(ephemeral_private.as_slice()).unwrap();
        let ephemeral_public = ephemeral.to_public().unwrap();
        let payload_key = PayloadKey::new(payload_key.as_slice());

        let plaintext = b"Hello, world!".to_vec();
        let mut ciphertext = Vec::new();

        key_encrypt(
            &mut plaintext.as_slice(),
            &mut ciphertext,
            &sender,
            &sender_public,
            &recipient,
            Some(&ephemeral),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_decrypt_one_chunk() {
        let expected_hash =
            hex::decode("916b144867c340614f515c7b0e5415c74832d899c05264ded2a277a6e81d81ff")
                .unwrap();
        let key_data = get_key_data();
        let expected_sender = key_data.alice_public;
        let recipient = key_data.bob_private;
        let recipient_public = key_data.bob_public;
        let ciphertext = encrypt_one_chunk();
        let mut plaintext = Vec::new();
        let sender_public = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &recipient,
            &recipient_public,
            AsymFileFormat::V1,
        )
        .unwrap();
        let got_hash = sha256(plaintext.as_slice());

        assert_eq!(expected_hash.as_slice(), &got_hash[..]);
        assert_eq!(expected_sender.as_bytes(), sender_public.as_bytes());
    }

    /// Encrypts exactly `CHUNK_SIZE` bytes of 0x01 from Alice to Bob.
    fn encrypt_one_chunk() -> Vec<u8> {
        let ephemeral_private =
            hex::decode("fdf2b46d965e4bb85d856971d657fdd6dc1fe8993f27587980e4f07f6409927f")
                .unwrap();
        let ephemeral_private = PrivateKey::try_from(ephemeral_private.as_slice()).unwrap();
        let ephemeral_public = ephemeral_private.to_public().unwrap();
        let payload_key =
            hex::decode("a300f423e416610a5dd87442f4edc21325f2b3211c4c69f0e0c541cf6cf4eca6")
                .unwrap();
        let payload_key = PayloadKey::new(payload_key.as_slice());
        let key_data = get_key_data();

        let chunk_size: usize = CHUNK_SIZE.try_into().unwrap();
        let plaintext = vec![0x01u8; chunk_size];
        let mut ciphertext = Vec::new();

        key_encrypt(
            &mut plaintext.as_slice(),
            &mut ciphertext,
            &key_data.alice_private,
            &key_data.alice_public,
            &key_data.bob_public,
            Some(&ephemeral_private),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_decrypt_two_chunks() {
        let expected_hash =
            hex::decode("6cb0ccb39028c57dd7db638d27c88fd1acc1794c8582fefe0949c091a2035ac7")
                .unwrap();
        let key_data = get_key_data();
        let expected_sender = key_data.alice_public;
        let recipient = key_data.bob_private;
        let recipient_public = key_data.bob_public;
        let ciphertext = encrypt_two_chunks();
        let mut plaintext = Vec::new();
        let sender_public = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &recipient,
            &recipient_public,
            AsymFileFormat::V1,
        )
        .unwrap();
        let got_hash = sha256(plaintext.as_slice());

        assert_eq!(expected_hash.as_slice(), &got_hash[..]);
        assert_eq!(expected_sender.as_bytes(), sender_public.as_bytes());
    }

    /// Encrypts `CHUNK_SIZE + 1` bytes of 0x02 from Alice to Bob.
    fn encrypt_two_chunks() -> Vec<u8> {
        let ephemeral_private =
            hex::decode("90ecf9d1dca6ed1e6997585228513a73d4db36bd7dd7c758acb55a6d333bb2fb")
                .unwrap();
        let ephemeral_private = PrivateKey::try_from(ephemeral_private.as_slice()).unwrap();
        let ephemeral_public = ephemeral_private.to_public().unwrap();
        let payload_key =
            hex::decode("d3387376438daeb6f7543e815cbde249810e341c1ccab192025b909b9ea4ebe7")
                .unwrap();
        let payload_key = PayloadKey::new(payload_key.as_slice());
        let key_data = get_key_data();

        let chunk_size: usize = CHUNK_SIZE.try_into().unwrap();
        let plaintext = vec![0x02u8; chunk_size + 1];
        let mut ciphertext = Vec::new();

        key_encrypt(
            &mut plaintext.as_slice(),
            &mut ciphertext,
            &key_data.alice_private,
            &key_data.alice_public,
            &key_data.bob_public,
            Some(&ephemeral_private),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_pass_decrypt() {
        let expected_pt = b"Be sure to drink your Ovaltine";
        let pass = b"hackme";

        let ciphertext = pass_encrypt_util();
        let mut plaintext = Vec::new();
        pass_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            pass,
            PassFileFormat::V1,
        )
        .unwrap();

        assert_eq!(&expected_pt[..], plaintext.as_slice());
    }

    #[test]
    fn test_pass_decrypt_rejects_wrong_password() {
        let ciphertext = pass_encrypt_util();
        let mut plaintext = Vec::new();
        let result = pass_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            b"wrong password",
            PassFileFormat::V1,
        );

        assert!(result.is_err());
        // Nothing may be released when the first chunk fails to decrypt.
        assert!(plaintext.is_empty());
    }

    /// Password-encrypts a fixed message with password "hackme" and a fixed salt.
    fn pass_encrypt_util() -> Vec<u8> {
        let salt = hex::decode("b3e94eb6bba5bc462aab92fd86eb9d9f939320a60ae46e690907918ef2ee3aec")
            .unwrap();
        let salt: [u8; 32] = salt.try_into().unwrap();
        let pass = b"hackme";
        let pt = b"Be sure to drink your Ovaltine".to_vec();
        let mut ciphertext = Vec::new();

        pass_encrypt(
            &mut pt.as_slice(),
            &mut ciphertext,
            pass,
            salt,
            PassFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    fn get_key_data() -> KeyData {
        let alice_private =
            hex::decode("46acb4ad2a6ffb9d70245798634ad0d5caf7a9738e5f3b60905dee7a7b973bd5")
                .unwrap();
        let alice_private = PrivateKey::try_from(alice_private.as_slice()).unwrap();
        let alice_public =
            hex::decode("3cf3637b4dfdc4596544a936b3983fca09324505f39568d4b8537bc01a92cf6d")
                .unwrap();
        let alice_public = PublicKey::try_from(alice_public.as_slice()).unwrap();

        let bob_private =
            hex::decode("461299525a53333e8597a2b065703ec751356f8462d2704e630c108037567bd4")
                .unwrap();
        let bob_private = PrivateKey::try_from(bob_private.as_slice()).unwrap();
        let bob_public =
            hex::decode("98459724b39e6b9e90b60d214df2887093e224b163714e07e527a4d37edc2d03")
                .unwrap();
        let bob_public = PublicKey::try_from(bob_public.as_slice()).unwrap();

        KeyData {
            alice_private,
            alice_public,
            bob_private,
            bob_public,
        }
    }
}