1use std::io::{Read, Write};
7
8use zeroize::Zeroizing;
9
10use crate::errors::{DecryptError, FileFormatError};
11use crate::{AsymFileFormat, FileFormat, PassFileFormat};
12use crate::{CHUNK_SIZE, SCRYPT_N, SCRYPT_P, SCRYPT_R, TAG_SIZE};
13use crate::{PrivateKey, PublicKey, chapoly_decrypt_noise, hkdf_sha256, noise_decrypt, scrypt};
14
15pub fn key_decrypt<T: Read, U: Write>(
17 ciphertext: &mut T,
18 plaintext: &mut U,
19 recipient: &PrivateKey,
20 recipient_public: &PublicKey,
21 file_format: AsymFileFormat,
22) -> Result<PublicKey, DecryptError> {
23 if file_format != AsymFileFormat::V1 {
24 return Err(DecryptError::Other(
25 "File format not supported. This may be your plaintext.".into(),
26 ));
27 }
28
29 let mut prologue = [0u8; 4];
30 ciphertext.read_exact(&mut prologue).map_err(read_err)?;
31 let file_format = valid_file_format(&prologue)?;
32 if file_format == FileFormat::PassV1 {
33 return Err(DecryptError::Other(
34 "This is a password encrypted file. Try password decrypt instread.".into(),
35 ));
36 }
37
38 let mut handshake_message = [0u8; 128];
39 ciphertext
40 .read_exact(&mut handshake_message)
41 .map_err(read_err)?;
42
43 let noise_message = noise_decrypt(recipient, recipient_public, &prologue, &handshake_message)
44 .map_err(|e| DecryptError::Other(e.to_string()))?;
45
46 let file_encryption_key = hkdf_sha256(
47 &[],
48 noise_message.payload_key.as_bytes(),
49 &noise_message.handshake_hash,
50 32,
51 );
52 let file_encryption_key = Zeroizing::new(file_encryption_key);
53
54 decrypt_chunks(ciphertext, plaintext, &file_encryption_key, &[], CHUNK_SIZE)?;
55
56 let public_key = noise_message.public_key.clone();
57
58 Ok(public_key)
59}
60
61pub fn pass_decrypt<T: Read, U: Write>(
63 ciphertext: &mut T,
64 plaintext: &mut U,
65 password: &[u8],
66 file_format: PassFileFormat,
67) -> Result<(), DecryptError> {
68 if file_format != PassFileFormat::V1 {
69 return Err(DecryptError::Other(
70 "File format not supported. This may be your plaintext.".into(),
71 ));
72 }
73
74 let mut pass_magic_num = [0u8; 4];
75 ciphertext
76 .read_exact(&mut pass_magic_num)
77 .map_err(read_err)?;
78 let file_format = valid_file_format(&pass_magic_num)?;
79 if file_format == FileFormat::AsymV1 {
80 return Err(DecryptError::Other(
81 "This is a key encrypted file. Try decrypt instread.".into(),
82 ));
83 }
84
85 let mut salt = [0u8; 32];
86 ciphertext.read_exact(&mut salt).map_err(read_err)?;
87
88 let key = scrypt(password, &salt, SCRYPT_N, SCRYPT_R, SCRYPT_P, 32);
89 let key = Zeroizing::new(key);
90 let aad = &pass_magic_num[..];
91
92 decrypt_chunks(ciphertext, plaintext, &key, aad, CHUNK_SIZE)?;
93
94 Ok(())
95}
96
97fn decrypt_chunks<T: Read, U: Write>(
105 ciphertext: &mut T,
106 plaintext: &mut U,
107 key: &[u8],
108 aad: &[u8],
109 chunk_size: u32,
110) -> Result<(), DecryptError> {
111 let mut chunk_number: u64 = 0;
112 let mut done = false;
113 let cs: usize = chunk_size.try_into().unwrap();
114 let mut buffer = vec![0; cs + TAG_SIZE];
115 let mut auth_data = vec![0u8; aad.len() + 8];
116
117 loop {
118 let mut chunk_header = [0u8; 16];
119 ciphertext.read_exact(&mut chunk_header).map_err(read_err)?;
120 let last_chunk_indicator_bytes: [u8; 4] = chunk_header[8..12].try_into().unwrap();
121 let ciphertext_length_bytes: [u8; 4] = chunk_header[12..].try_into().unwrap();
122 let last_chunk_indicator = u32::from_be_bytes(last_chunk_indicator_bytes);
123 let ciphertext_length = u32::from_be_bytes(ciphertext_length_bytes);
124 if ciphertext_length > chunk_size {
125 return Err(DecryptError::ChunkLen);
126 }
127
128 let ct_len: usize = ciphertext_length.try_into().unwrap();
129 ciphertext
130 .read_exact(&mut buffer[..ct_len + TAG_SIZE])
131 .map_err(read_err)?;
132
133 let aad_len = aad.len();
134 auth_data[..aad_len].copy_from_slice(aad);
135 auth_data[aad_len..aad_len + 4].copy_from_slice(&last_chunk_indicator_bytes);
136 auth_data[aad_len + 4..].copy_from_slice(&ciphertext_length_bytes);
137
138 let ct = &buffer[..ct_len + TAG_SIZE];
139 let pt_chunk = chapoly_decrypt_noise(key, chunk_number, auth_data.as_slice(), ct)?;
140
141 if last_chunk_indicator == 1 {
146 done = true;
147 let check = ciphertext.read(&mut [0u8; 1]).map_err(read_err)?;
153 if check != 0 {
154 return Err(DecryptError::UnexpectedData);
157 }
158 }
159
160 plaintext
161 .write_all(pt_chunk.as_slice())
162 .map_err(write_err)?;
163 plaintext.flush().map_err(write_err)?;
164
165 if done {
166 break;
167 }
168
169 chunk_number += 1;
173 }
174
175 Ok(())
176}
177
178pub fn valid_file_format(header: &[u8]) -> Result<FileFormat, FileFormatError> {
180 let asym_v1 = [0x65, 0x67, 0x6b, 0x10];
181 let pass_v1 = [0x65, 0x67, 0x6b, 0x20];
182
183 if header == asym_v1 {
184 return Ok(FileFormat::AsymV1);
185 } else if header == pass_v1 {
186 return Ok(FileFormat::PassV1);
187 }
188
189 Err(FileFormatError)
190}
191
192fn read_err(err: std::io::Error) -> DecryptError {
193 use std::io::ErrorKind;
194
195 match err.kind() {
196 ErrorKind::UnexpectedEof => {
197 DecryptError::IORead(std::io::Error::other("Did not read enough data."))
198 }
199 _ => DecryptError::IORead(err),
200 }
201}
202
203fn write_err(err: std::io::Error) -> DecryptError {
204 DecryptError::IOWrite(err)
205}
206
#[cfg(test)]
mod tests {
    use super::CHUNK_SIZE;
    use super::DecryptError;
    use super::{PrivateKey, PublicKey};
    use super::{key_decrypt, pass_decrypt};
    use crate::encrypt::{key_encrypt, pass_encrypt};
    use crate::sha256;
    use crate::{AsymFileFormat, PassFileFormat, PayloadKey};
    use ct_codecs::{Decoder, Hex};
    use std::io::Read;

    // NOTE(review): every field appears to be read by the tests below; confirm
    // whether this allow is still needed.
    #[allow(dead_code)]
    struct KeyData {
        alice_private: PrivateKey,
        alice_public: PublicKey,
        bob_private: PrivateKey,
        bob_public: PublicKey,
    }

    #[test]
    fn test_decrypt_small() {
        let expected_plaintext = b"Hello, world!";
        let key_data = get_key_data();
        let expected_sender = key_data.alice_public;
        let recipient = key_data.bob_private;
        let recipient_public = key_data.bob_public;
        let ciphertext = encrypt_small_util();
        let mut plaintext = Vec::new();
        let sender_public = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &recipient,
            &recipient_public,
            AsymFileFormat::V1,
        )
        .unwrap();

        assert_eq!(&expected_plaintext[..], plaintext.as_slice());
        assert_eq!(expected_sender.as_bytes(), sender_public.as_bytes());
    }

    /// Bytes appended after the final chunk must be rejected, and the final
    /// chunk's plaintext must not be released.
    #[test]
    fn test_decrypt_rejects_trailing_data() {
        let key_data = get_key_data();
        let mut ciphertext = encrypt_small_util();
        ciphertext.push(0x00);
        let mut plaintext = Vec::new();
        let result = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &key_data.bob_private,
            &key_data.bob_public,
            AsymFileFormat::V1,
        );

        assert!(matches!(result, Err(DecryptError::UnexpectedData)));
        assert!(plaintext.is_empty());
    }

    /// A stream cut short inside the final chunk must fail.
    #[test]
    fn test_decrypt_truncated_fails() {
        let key_data = get_key_data();
        let mut ciphertext = encrypt_small_util();
        ciphertext.pop();
        let mut plaintext = Vec::new();
        let result = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &key_data.bob_private,
            &key_data.bob_public,
            AsymFileFormat::V1,
        );

        assert!(result.is_err());
        assert!(plaintext.is_empty());
    }

    /// Flipping a bit in the final chunk's tag must fail authentication.
    #[test]
    fn test_decrypt_tampered_fails() {
        let key_data = get_key_data();
        let mut ciphertext = encrypt_small_util();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;
        let mut plaintext = Vec::new();
        let result = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &key_data.bob_private,
            &key_data.bob_public,
            AsymFileFormat::V1,
        );

        assert!(result.is_err());
        assert!(plaintext.is_empty());
    }

    fn encrypt_small_util() -> Vec<u8> {
        let ephemeral_private = Hex::decode_to_vec(
            "fdbc28d8f4c2a97013e460836cece7a4bdf59df0cb4b3a185146d13615884f38",
            None,
        )
        .unwrap();
        let payload_key = Hex::decode_to_vec(
            "a9f9ddef54d0432ec067b75aef26c3db5419ade3b016339743ca1812d89188b2",
            None,
        )
        .unwrap();
        let key_data = get_key_data();

        let sender = PrivateKey::try_from(key_data.alice_private.as_bytes()).unwrap();
        let sender_public = sender.to_public().unwrap();
        let recipient = PublicKey::try_from(key_data.bob_public.as_bytes()).unwrap();
        let ephemeral = PrivateKey::try_from(ephemeral_private.as_slice()).unwrap();
        let ephemeral_public = ephemeral.to_public().unwrap();
        let payload_key = PayloadKey::new(payload_key.as_slice());

        let plaintext_data = b"Hello, world!";
        let mut plaintext = Vec::new();
        plaintext.extend_from_slice(plaintext_data);
        let mut ciphertext = Vec::new();

        key_encrypt(
            &mut plaintext.as_slice(),
            &mut ciphertext,
            &sender,
            &sender_public,
            &recipient,
            Some(&ephemeral),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_decrypt_one_chunk() {
        let expected_hash = Hex::decode_to_vec(
            "916b144867c340614f515c7b0e5415c74832d899c05264ded2a277a6e81d81ff",
            None,
        )
        .unwrap();
        let key_data = get_key_data();
        let expected_sender = key_data.alice_public;
        let recipient = key_data.bob_private;
        let recipient_public = key_data.bob_public;
        let ciphertext = encrypt_one_chunk();
        let mut plaintext = Vec::new();
        let sender_public = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &recipient,
            &recipient_public,
            AsymFileFormat::V1,
        )
        .unwrap();
        let got_hash = sha256(plaintext.as_slice());

        assert_eq!(expected_hash.as_slice(), &got_hash[..]);
        assert_eq!(expected_sender.as_bytes(), sender_public.as_bytes());
    }

    fn encrypt_one_chunk() -> Vec<u8> {
        let ephemeral_private = Hex::decode_to_vec(
            "fdf2b46d965e4bb85d856971d657fdd6dc1fe8993f27587980e4f07f6409927f",
            None,
        )
        .unwrap();
        let ephemeral_private = PrivateKey::try_from(ephemeral_private.as_slice()).unwrap();
        let ephemeral_public = ephemeral_private.to_public().unwrap();
        let payload_key = Hex::decode_to_vec(
            "a300f423e416610a5dd87442f4edc21325f2b3211c4c69f0e0c541cf6cf4eca6",
            None,
        )
        .unwrap();
        let payload_key = PayloadKey::new(payload_key.as_slice());
        let key_data = get_key_data();

        let chunk_size: usize = CHUNK_SIZE.try_into().unwrap();
        let mut plaintext = vec![0; chunk_size];
        std::io::repeat(0x01).read_exact(&mut plaintext).unwrap();
        let mut ciphertext = Vec::new();

        key_encrypt(
            &mut plaintext.as_slice(),
            &mut ciphertext,
            &key_data.alice_private,
            &key_data.alice_public,
            &key_data.bob_public,
            Some(&ephemeral_private),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_decrypt_two_chunks() {
        let expected_hash = Hex::decode_to_vec(
            "6cb0ccb39028c57dd7db638d27c88fd1acc1794c8582fefe0949c091a2035ac7",
            None,
        )
        .unwrap();
        let key_data = get_key_data();
        let expected_sender = key_data.alice_public;
        let recipient = key_data.bob_private;
        let recipient_public = key_data.bob_public;
        let ciphertext = encrypt_two_chunks();
        let mut plaintext = Vec::new();
        let sender_public = key_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            &recipient,
            &recipient_public,
            AsymFileFormat::V1,
        )
        .unwrap();
        let got_hash = sha256(plaintext.as_slice());

        assert_eq!(expected_hash.as_slice(), &got_hash[..]);
        assert_eq!(expected_sender.as_bytes(), sender_public.as_bytes());
    }

    fn encrypt_two_chunks() -> Vec<u8> {
        let ephemeral_private = Hex::decode_to_vec(
            "90ecf9d1dca6ed1e6997585228513a73d4db36bd7dd7c758acb55a6d333bb2fb",
            None,
        )
        .unwrap();
        let ephemeral_private = PrivateKey::try_from(ephemeral_private.as_slice()).unwrap();
        let ephemeral_public = ephemeral_private.to_public().unwrap();
        let payload_key = Hex::decode_to_vec(
            "d3387376438daeb6f7543e815cbde249810e341c1ccab192025b909b9ea4ebe7",
            None,
        )
        .unwrap();
        let payload_key = PayloadKey::new(payload_key.as_slice());
        let key_data = get_key_data();

        let chunk_size: usize = CHUNK_SIZE.try_into().unwrap();
        let mut plaintext = vec![0; chunk_size + 1];
        std::io::repeat(0x02).read_exact(&mut plaintext).unwrap();
        let mut ciphertext = Vec::new();

        key_encrypt(
            &mut plaintext.as_slice(),
            &mut ciphertext,
            &key_data.alice_private,
            &key_data.alice_public,
            &key_data.bob_public,
            Some(&ephemeral_private),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_pass_decrypt() {
        let expected_pt = b"Be sure to drink your Ovaltine";
        let pass = b"hackme";

        let ciphertext = pass_encrypt_util();
        let mut plaintext = Vec::new();
        pass_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            pass,
            PassFileFormat::V1,
        )
        .unwrap();

        assert_eq!(&expected_pt[..], plaintext.as_slice());
    }

    /// A wrong password must fail and must not emit any plaintext.
    #[test]
    fn test_pass_decrypt_wrong_password_fails() {
        let ciphertext = pass_encrypt_util();
        let mut plaintext = Vec::new();
        let result = pass_decrypt(
            &mut ciphertext.as_slice(),
            &mut plaintext,
            b"wrong password",
            PassFileFormat::V1,
        );

        assert!(result.is_err());
        assert!(plaintext.is_empty());
    }

    /// Each decrypt entry point must refuse the other mode's file format.
    #[test]
    fn test_format_mismatch_is_rejected() {
        let key_data = get_key_data();
        let mut plaintext = Vec::new();

        let key_ciphertext = encrypt_small_util();
        let result = pass_decrypt(
            &mut key_ciphertext.as_slice(),
            &mut plaintext,
            b"hackme",
            PassFileFormat::V1,
        );
        assert!(result.is_err());

        let pass_ciphertext = pass_encrypt_util();
        let result = key_decrypt(
            &mut pass_ciphertext.as_slice(),
            &mut plaintext,
            &key_data.bob_private,
            &key_data.bob_public,
            AsymFileFormat::V1,
        );
        assert!(result.is_err());

        assert!(plaintext.is_empty());
    }

    fn pass_encrypt_util() -> Vec<u8> {
        let salt = Hex::decode_to_vec(
            "b3e94eb6bba5bc462aab92fd86eb9d9f939320a60ae46e690907918ef2ee3aec",
            None,
        )
        .unwrap();
        let salt: [u8; 32] = salt.try_into().unwrap();
        let pass = b"hackme";
        let plaintext = b"Be sure to drink your Ovaltine";
        let mut pt = Vec::new();
        pt.extend_from_slice(plaintext);
        let mut ciphertext = Vec::new();

        pass_encrypt(
            &mut pt.as_slice(),
            &mut ciphertext,
            pass,
            salt,
            PassFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    fn get_key_data() -> KeyData {
        let alice_private = Hex::decode_to_vec(
            "46acb4ad2a6ffb9d70245798634ad0d5caf7a9738e5f3b60905dee7a7b973bd5",
            None,
        )
        .unwrap();
        let alice_private = PrivateKey::try_from(alice_private.as_slice()).unwrap();
        let alice_public = Hex::decode_to_vec(
            "3cf3637b4dfdc4596544a936b3983fca09324505f39568d4b8537bc01a92cf6d",
            None,
        )
        .unwrap();
        let alice_public = PublicKey::try_from(alice_public.as_slice()).unwrap();

        let bob_private = Hex::decode_to_vec(
            "461299525a53333e8597a2b065703ec751356f8462d2704e630c108037567bd4",
            None,
        )
        .unwrap();
        let bob_private = PrivateKey::try_from(bob_private.as_slice()).unwrap();
        let bob_public = Hex::decode_to_vec(
            "98459724b39e6b9e90b60d214df2887093e224b163714e07e527a4d37edc2d03",
            None,
        )
        .unwrap();
        let bob_public = PublicKey::try_from(bob_public.as_slice()).unwrap();

        KeyData {
            alice_private,
            alice_public,
            bob_private,
            bob_public,
        }
    }
}