1use std::io::{Read, Write};
7
8use zeroize::Zeroizing;
9
10use crate::errors::{DecryptError, FileFormatError};
11use crate::{chapoly_decrypt_noise, hkdf_sha256, noise_decrypt, scrypt, PrivateKey, PublicKey};
12use crate::{AsymFileFormat, FileFormat, PassFileFormat};
13use crate::{CHUNK_SIZE, SCRYPT_N, SCRYPT_P, SCRYPT_R, TAG_SIZE};
14
15pub fn key_decrypt<T: Read, U: Write>(
17 ciphertext: &mut T,
18 plaintext: &mut U,
19 recipient: &PrivateKey,
20 recipient_public: &PublicKey,
21 file_format: AsymFileFormat,
22) -> Result<PublicKey, DecryptError> {
23 if file_format != AsymFileFormat::V1 {
24 return Err(DecryptError::Other(
25 "File format not supported. This may be your plaintext.".into(),
26 ));
27 }
28
29 let mut prologue = [0u8; 4];
30 ciphertext.read_exact(&mut prologue).map_err(read_err)?;
31 let file_format = valid_file_format(&prologue)?;
32 if file_format == FileFormat::PassV1 {
33 return Err(DecryptError::Other(
34 "This is a password encrypted file. Try password decrypt instread.".into(),
35 ));
36 }
37
38 let mut handshake_message = [0u8; 128];
39 ciphertext
40 .read_exact(&mut handshake_message)
41 .map_err(read_err)?;
42
43 let noise_message = noise_decrypt(recipient, recipient_public, &prologue, &handshake_message)
44 .map_err(|e| DecryptError::Other(e.to_string()))?;
45
46 let file_encryption_key = hkdf_sha256(
47 &[],
48 &noise_message.payload_key.as_bytes(),
49 &noise_message.handshake_hash,
50 32,
51 );
52 let file_encryption_key = Zeroizing::new(file_encryption_key);
53
54 decrypt_chunks(ciphertext, plaintext, &file_encryption_key, &[], CHUNK_SIZE)?;
55
56 let public_key = noise_message.public_key.clone();
57
58 Ok(public_key)
59}
60
61pub fn pass_decrypt<T: Read, U: Write>(
63 ciphertext: &mut T,
64 plaintext: &mut U,
65 password: &[u8],
66 file_format: PassFileFormat,
67) -> Result<(), DecryptError> {
68 if file_format != PassFileFormat::V1 {
69 return Err(DecryptError::Other(
70 "File format not supported. This may be your plaintext.".into(),
71 ));
72 }
73
74 let mut pass_magic_num = [0u8; 4];
75 ciphertext
76 .read_exact(&mut pass_magic_num)
77 .map_err(read_err)?;
78 let file_format = valid_file_format(&pass_magic_num)?;
79 if file_format == FileFormat::AsymV1 {
80 return Err(DecryptError::Other(
81 "This is a key encrypted file. Try decrypt instread.".into(),
82 ));
83 }
84
85 let mut salt = [0u8; 32];
86 ciphertext.read_exact(&mut salt).map_err(read_err)?;
87
88 let key = scrypt(password, &salt, SCRYPT_N, SCRYPT_R, SCRYPT_P, 32);
89 let key = Zeroizing::new(key);
90 let aad = &pass_magic_num[..];
91
92 decrypt_chunks(ciphertext, plaintext, &key, aad, CHUNK_SIZE)?;
93
94 Ok(())
95}
96
97fn decrypt_chunks<T: Read, U: Write>(
105 ciphertext: &mut T,
106 plaintext: &mut U,
107 key: &[u8],
108 aad: &[u8],
109 chunk_size: u32,
110) -> Result<(), DecryptError> {
111 let mut chunk_number: u64 = 0;
112 let mut done = false;
113 let cs: usize = chunk_size.try_into().unwrap();
114 let mut buffer = vec![0; cs + TAG_SIZE];
115 let mut auth_data = vec![0u8; aad.len() + 8];
116
117 loop {
118 let mut chunk_header = [0u8; 16];
119 ciphertext.read_exact(&mut chunk_header).map_err(read_err)?;
120 let last_chunk_indicator_bytes: [u8; 4] = chunk_header[8..12].try_into().unwrap();
121 let ciphertext_length_bytes: [u8; 4] = chunk_header[12..].try_into().unwrap();
122 let last_chunk_indicator = u32::from_be_bytes(last_chunk_indicator_bytes);
123 let ciphertext_length = u32::from_be_bytes(ciphertext_length_bytes);
124 if ciphertext_length > chunk_size {
125 return Err(DecryptError::ChunkLen);
126 }
127
128 let ct_len: usize = ciphertext_length.try_into().unwrap();
129 ciphertext
130 .read_exact(&mut buffer[..ct_len + TAG_SIZE])
131 .map_err(read_err)?;
132
133 let aad_len = aad.len();
134 auth_data[..aad_len].copy_from_slice(aad);
135 auth_data[aad_len..aad_len + 4].copy_from_slice(&last_chunk_indicator_bytes);
136 auth_data[aad_len + 4..].copy_from_slice(&ciphertext_length_bytes);
137
138 let ct = &buffer[..ct_len + TAG_SIZE];
139 let pt_chunk = chapoly_decrypt_noise(&key, chunk_number, auth_data.as_slice(), ct)?;
140
141 if last_chunk_indicator == 1 {
146 done = true;
147 let check = ciphertext.read(&mut [0u8; 1]).map_err(read_err)?;
153 if check != 0 {
154 return Err(DecryptError::UnexpectedData);
157 }
158 }
159
160 plaintext
161 .write_all(pt_chunk.as_slice())
162 .map_err(write_err)?;
163 plaintext.flush().map_err(write_err)?;
164
165 if done {
166 break;
167 }
168
169 chunk_number += 1;
173 }
174
175 Ok(())
176}
177
178pub fn valid_file_format(header: &[u8]) -> Result<FileFormat, FileFormatError> {
180 let asym_v1 = [0x65, 0x67, 0x6b, 0x10];
181 let pass_v1 = [0x65, 0x67, 0x6b, 0x20];
182
183 if header == asym_v1 {
184 return Ok(FileFormat::AsymV1);
185 } else if header == pass_v1 {
186 return Ok(FileFormat::PassV1);
187 }
188
189 Err(FileFormatError)
190}
191
192fn read_err(err: std::io::Error) -> DecryptError {
193 use std::io::ErrorKind;
194
195 match err.kind() {
196 ErrorKind::UnexpectedEof => DecryptError::IORead(std::io::Error::new(
197 ErrorKind::Other,
198 "Did not read enough data.",
199 )),
200 _ => DecryptError::IORead(err),
201 }
202}
203
204fn write_err(err: std::io::Error) -> DecryptError {
205 DecryptError::IOWrite(err)
206}
207
#[cfg(test)]
mod tests {
    use super::CHUNK_SIZE;
    use super::{key_decrypt, pass_decrypt};
    use super::{PrivateKey, PublicKey};
    use crate::encrypt::{key_encrypt, pass_encrypt};
    use crate::sha256;
    use crate::{AsymFileFormat, PassFileFormat, PayloadKey};

    /// Fixed key pairs for the two parties used by the round-trip tests:
    /// Alice encrypts, Bob decrypts.
    #[allow(dead_code)]
    struct KeyData {
        alice_private: PrivateKey,
        alice_public: PublicKey,
        bob_private: PrivateKey,
        bob_public: PublicKey,
    }

    /// Decodes a hex string into a `PrivateKey`, panicking on bad input.
    fn private_key_from_hex(encoded: &str) -> PrivateKey {
        let raw = hex::decode(encoded).unwrap();
        PrivateKey::try_from(raw.as_slice()).unwrap()
    }

    /// Decodes a hex string into a `PublicKey`, panicking on bad input.
    fn public_key_from_hex(encoded: &str) -> PublicKey {
        let raw = hex::decode(encoded).unwrap();
        PublicKey::try_from(raw.as_slice()).unwrap()
    }

    /// Decodes a hex string into a `PayloadKey`, panicking on bad hex.
    fn payload_key_from_hex(encoded: &str) -> PayloadKey {
        let raw = hex::decode(encoded).unwrap();
        PayloadKey::new(raw.as_slice())
    }

    /// Builds the fixed Alice/Bob key material shared by the tests.
    fn get_key_data() -> KeyData {
        KeyData {
            alice_private: private_key_from_hex(
                "46acb4ad2a6ffb9d70245798634ad0d5caf7a9738e5f3b60905dee7a7b973bd5",
            ),
            alice_public: public_key_from_hex(
                "3cf3637b4dfdc4596544a936b3983fca09324505f39568d4b8537bc01a92cf6d",
            ),
            bob_private: private_key_from_hex(
                "461299525a53333e8597a2b065703ec751356f8462d2704e630c108037567bd4",
            ),
            bob_public: public_key_from_hex(
                "98459724b39e6b9e90b60d214df2887093e224b163714e07e527a4d37edc2d03",
            ),
        }
    }

    /// Decrypts `ciphertext` with Bob's key pair, returning the plaintext and
    /// the public key reported by `key_decrypt`.
    fn decrypt_as_bob(ciphertext: &[u8]) -> (Vec<u8>, PublicKey) {
        let KeyData {
            bob_private,
            bob_public,
            ..
        } = get_key_data();

        let mut input = ciphertext;
        let mut output = Vec::new();
        let reported_sender = key_decrypt(
            &mut input,
            &mut output,
            &bob_private,
            &bob_public,
            AsymFileFormat::V1,
        )
        .unwrap();

        (output, reported_sender)
    }

    /// Encrypts `plaintext` from Alice to Bob using the fixed key pairs and
    /// the given deterministic ephemeral private key and payload key.
    fn encrypt_from_alice_to_bob(
        plaintext: &[u8],
        ephemeral_private_hex: &str,
        payload_key_hex: &str,
    ) -> Vec<u8> {
        let ephemeral_private = private_key_from_hex(ephemeral_private_hex);
        let ephemeral_public = ephemeral_private.to_public().unwrap();
        let payload_key = payload_key_from_hex(payload_key_hex);
        let keys = get_key_data();

        let mut input = plaintext;
        let mut ciphertext = Vec::new();
        key_encrypt(
            &mut input,
            &mut ciphertext,
            &keys.alice_private,
            &keys.alice_public,
            &keys.bob_public,
            Some(&ephemeral_private),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_decrypt_small() {
        let (plaintext, sender_public) = decrypt_as_bob(&encrypt_small_util());

        assert_eq!(&b"Hello, world!"[..], plaintext.as_slice());
        assert_eq!(
            get_key_data().alice_public.as_bytes(),
            sender_public.as_bytes()
        );
    }

    /// Encrypts a short message. Unlike the chunk-sized helpers, the sender's
    /// public key is derived from the private key and the key objects are
    /// rebuilt from their byte representations.
    fn encrypt_small_util() -> Vec<u8> {
        let keys = get_key_data();

        let sender = PrivateKey::try_from(keys.alice_private.as_bytes()).unwrap();
        let sender_public = sender.to_public().unwrap();
        let recipient = PublicKey::try_from(keys.bob_public.as_bytes()).unwrap();
        let ephemeral = private_key_from_hex(
            "fdbc28d8f4c2a97013e460836cece7a4bdf59df0cb4b3a185146d13615884f38",
        );
        let ephemeral_public = ephemeral.to_public().unwrap();
        let payload_key = payload_key_from_hex(
            "a9f9ddef54d0432ec067b75aef26c3db5419ade3b016339743ca1812d89188b2",
        );

        let message = b"Hello, world!".to_vec();
        let mut ciphertext = Vec::new();
        key_encrypt(
            &mut message.as_slice(),
            &mut ciphertext,
            &sender,
            &sender_public,
            &recipient,
            Some(&ephemeral),
            Some(&ephemeral_public),
            Some(&payload_key),
            AsymFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }

    #[test]
    fn test_decrypt_one_chunk() {
        let expected_hash =
            hex::decode("916b144867c340614f515c7b0e5415c74832d899c05264ded2a277a6e81d81ff")
                .unwrap();

        let (plaintext, sender_public) = decrypt_as_bob(&encrypt_one_chunk());
        let got_hash = sha256(plaintext.as_slice());

        assert_eq!(expected_hash.as_slice(), &got_hash[..]);
        assert_eq!(
            get_key_data().alice_public.as_bytes(),
            sender_public.as_bytes()
        );
    }

    /// Encrypts exactly `CHUNK_SIZE` bytes of `0x01`.
    fn encrypt_one_chunk() -> Vec<u8> {
        let chunk_size: usize = CHUNK_SIZE.try_into().unwrap();
        let plaintext = vec![0x01u8; chunk_size];

        encrypt_from_alice_to_bob(
            &plaintext,
            "fdf2b46d965e4bb85d856971d657fdd6dc1fe8993f27587980e4f07f6409927f",
            "a300f423e416610a5dd87442f4edc21325f2b3211c4c69f0e0c541cf6cf4eca6",
        )
    }

    #[test]
    fn test_decrypt_two_chunks() {
        let expected_hash =
            hex::decode("6cb0ccb39028c57dd7db638d27c88fd1acc1794c8582fefe0949c091a2035ac7")
                .unwrap();

        let (plaintext, sender_public) = decrypt_as_bob(&encrypt_two_chunks());
        let got_hash = sha256(plaintext.as_slice());

        assert_eq!(expected_hash.as_slice(), &got_hash[..]);
        assert_eq!(
            get_key_data().alice_public.as_bytes(),
            sender_public.as_bytes()
        );
    }

    /// Encrypts `CHUNK_SIZE + 1` bytes of `0x02`, one byte more than fits in
    /// a single chunk.
    fn encrypt_two_chunks() -> Vec<u8> {
        let chunk_size: usize = CHUNK_SIZE.try_into().unwrap();
        let plaintext = vec![0x02u8; chunk_size + 1];

        encrypt_from_alice_to_bob(
            &plaintext,
            "90ecf9d1dca6ed1e6997585228513a73d4db36bd7dd7c758acb55a6d333bb2fb",
            "d3387376438daeb6f7543e815cbde249810e341c1ccab192025b909b9ea4ebe7",
        )
    }

    #[test]
    fn test_pass_decrypt() {
        let encrypted = pass_encrypt_util();

        let mut decrypted = Vec::new();
        pass_decrypt(
            &mut encrypted.as_slice(),
            &mut decrypted,
            b"hackme",
            PassFileFormat::V1,
        )
        .unwrap();

        assert_eq!(&b"Be sure to drink your Ovaltine"[..], decrypted.as_slice());
    }

    /// Password-encrypts a fixed message with a fixed salt.
    fn pass_encrypt_util() -> Vec<u8> {
        let salt: [u8; 32] =
            hex::decode("b3e94eb6bba5bc462aab92fd86eb9d9f939320a60ae46e690907918ef2ee3aec")
                .unwrap()
                .try_into()
                .unwrap();

        let message = b"Be sure to drink your Ovaltine".to_vec();
        let mut ciphertext = Vec::new();
        pass_encrypt(
            &mut message.as_slice(),
            &mut ciphertext,
            b"hackme",
            salt,
            PassFileFormat::V1,
        )
        .unwrap();

        ciphertext
    }
}