use crate::{decrypt, decrypt_stream, encrypt, encrypt_stream, generate_keypair};
use anyhow::{Context, Result};
use bincode::decode_from_slice;
use chacha20poly1305::aead::OsRng;
use chacha20poly1305::aead::rand_core::RngCore;
use std::error::Error;
use std::fs::File;
use std::io::{BufReader, Cursor, Read, Write};
use tempfile::tempfile;
#[test]
fn test_round_trip() -> Result<(), Box<dyn Error>> {
let keypair = generate_keypair()?;
let mut sink = Vec::new();
encrypt_stream(
keypair.public_key,
&mut Cursor::new(b"secret message"),
&mut sink,
)?;
let mut ciphertext_bytes = Vec::new();
decrypt_stream(
&keypair.private_key,
&mut Cursor::new(sink),
&mut ciphertext_bytes,
)?;
assert_eq!(String::from_utf8_lossy(&ciphertext_bytes), "secret message");
Ok(())
}
#[test]
fn test_big_round_trip() -> Result<(), Box<dyn Error>> {
let mut file = tempfile()?;
const TOTAL_SIZE: usize = 10 * 1024 * 1024; const BUFFER_SIZE: usize = 1024 * 1024; let mut buffer = vec![0u8; BUFFER_SIZE];
let mut remaining_bytes = TOTAL_SIZE;
while remaining_bytes > 0 {
let bytes_to_write = BUFFER_SIZE.min(remaining_bytes);
OsRng.fill_bytes(&mut buffer[0..bytes_to_write]);
file.write_all(&buffer[0..bytes_to_write])?;
remaining_bytes -= bytes_to_write;
}
file.flush()?;
use std::io::Seek;
file.seek(std::io::SeekFrom::Start(0))?;
let keypair = generate_keypair()?;
let mut encrypted_file = tempfile()?;
encrypt_stream(keypair.public_key, &mut file, &mut encrypted_file)?;
encrypted_file.seek(std::io::SeekFrom::Start(0))?;
let mut decrypted_file = tempfile()?;
decrypt_stream(
&keypair.private_key,
&mut encrypted_file,
&mut decrypted_file,
)?;
file.seek(std::io::SeekFrom::Start(0))?;
decrypted_file.seek(std::io::SeekFrom::Start(0))?;
let mut buffer1 = vec![0u8; BUFFER_SIZE];
let mut buffer2 = vec![0u8; BUFFER_SIZE];
let mut reader1 = BufReader::new(file);
let mut reader2 = BufReader::new(decrypted_file);
let mut comp = true;
loop {
let bytes1 = reader1.read(&mut buffer1)?;
let bytes2 = reader2.read(&mut buffer2)?;
if bytes1 != bytes2 {
comp = false;
break;
}
if bytes1 == 0 {
break;
}
if buffer1[..bytes1] != buffer2[..bytes2] {
comp = false;
break;
}
}
assert!(comp);
Ok(())
}
#[test]
fn legacy_functions() -> Result<(), Box<dyn Error>> {
let keypair = generate_keypair()?;
let ciphertext = encrypt(keypair.public_key, b"secret message")?;
let plaintext = decrypt(&ciphertext, &keypair.private_key)?;
assert_eq!(plaintext, "secret message");
Ok(())
}
#[test]
fn test_tampered_ciphertext() {
let server_kp = generate_keypair().unwrap();
let mut encrypted = Vec::new();
encrypt_stream(
server_kp.public_key,
&mut Cursor::new("test".as_bytes()),
&mut encrypted,
)
.unwrap();
let mut corrupted = encrypted.clone();
if corrupted.len() > 10 {
corrupted[10] ^= 0x01; }
let mut output = Vec::new();
let result = decrypt_stream(
&server_kp.private_key,
&mut Cursor::new(corrupted),
&mut output,
);
assert!(result.is_err());
}
#[test]
fn test_tampered_nonce() {
let server_kp = generate_keypair().unwrap();
let mut encrypted = Vec::new();
encrypt_stream(
server_kp.public_key,
&mut Cursor::new("test".as_bytes()),
&mut encrypted,
)
.unwrap();
let mut corrupted = encrypted.clone();
if corrupted.len() > 50 {
let nonce_start = corrupted.len().saturating_sub(50); corrupted[nonce_start] ^= 0x01;
}
let mut output = Vec::new();
let result = decrypt_stream(
&server_kp.private_key,
&mut Cursor::new(corrupted),
&mut output,
);
assert!(result.is_err());
}
#[test]
fn test_empty_message() -> Result<()> {
let server_kp = generate_keypair()?;
let msg = "";
let mut encrypted = Vec::new();
encrypt_stream(
server_kp.public_key,
&mut Cursor::new(msg.as_bytes()),
&mut encrypted,
)?;
let mut decrypted = Vec::new();
decrypt_stream(
&server_kp.private_key,
&mut Cursor::new(encrypted),
&mut decrypted,
)?;
assert_eq!(String::from_utf8_lossy(&decrypted), msg);
Ok(())
}
#[test]
fn test_large_message() -> Result<()> {
let server_kp = generate_keypair()?;
let msg = "A".repeat(10_000);
let mut encrypted = Vec::new();
encrypt_stream(
server_kp.public_key,
&mut Cursor::new(msg.as_bytes()),
&mut encrypted,
)?;
let mut decrypted = Vec::new();
decrypt_stream(
&server_kp.private_key,
&mut Cursor::new(encrypted),
&mut decrypted,
)?;
assert_eq!(String::from_utf8_lossy(&decrypted), msg);
Ok(())
}
#[test]
fn test_wrong_key_decryption() {
let server_kp = generate_keypair().unwrap();
let attacker_kp = generate_keypair().unwrap();
let msg = "Confidential message.";
let mut encrypted = Vec::new();
encrypt_stream(
server_kp.public_key,
&mut Cursor::new(msg.as_bytes()),
&mut encrypted,
)
.unwrap();
let mut output = Vec::new();
let result = decrypt_stream(
&attacker_kp.private_key,
&mut Cursor::new(encrypted),
&mut output,
);
assert!(result.is_err());
}
#[test]
fn test_known_vector() -> Result<()> {
let path = "tests.bin";
let config = bincode::config::standard()
.with_big_endian()
.with_variable_int_encoding();
let metadata = std::fs::metadata(path).context(format!(
"File '{}' not found. Run `cargo run --bin main` first",
path
))?;
let min_size = 100;
if metadata.len() < min_size {
anyhow::bail!(
"Corrupted file: size {} < minimum expected {}",
metadata.len(),
min_size
);
}
let bytes = std::fs::read(path)?;
let (_empty, _): ((), usize) = decode_from_slice(&bytes, config)?;
Ok(())
}
#[test]
fn file_round_trip() -> Result<(), Box<dyn Error>> {
let keypair = generate_keypair()?;
let mut encrypted_data = Vec::new();
encrypt_stream(
keypair.public_key,
&mut Cursor::new(b"file content example"),
&mut encrypted_data,
)?;
{
let mut file = File::create("encrypted.bin")?;
file.write_all(&encrypted_data)?;
}
let mut file = File::open("encrypted.bin")?;
let mut decrypted = Vec::new();
decrypt_stream(&keypair.private_key, &mut file, &mut decrypted)?;
assert_eq!(String::from_utf8_lossy(&decrypted), "file content example");
Ok(())
}