use anyhow::{Context, Result};
use argon2::{Argon2, PasswordHasher};
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Nonce,
};
use rand::RngCore;
/// Length in bytes of the random salt produced by `generate_salt`.
pub const SALT_LEN: usize = 16;
/// Length in bytes of the per-session nonce; it fills the leading bytes of
/// every 12-byte chunk nonce, the remaining bytes carrying the chunk index.
pub const SESSION_NONCE_LEN: usize = 8;
/// Chunk-oriented encryption state: a ChaCha20-Poly1305 cipher keyed from a
/// password and salt, plus the session nonce that prefixes every chunk nonce.
#[derive(Clone)]
pub struct EncryptionContext {
// Cipher instance built from the password-derived key.
cipher: ChaCha20Poly1305,
// Session-wide nonce prefix; combined with the chunk index to form each
// chunk's 12-byte nonce.
session_nonce: [u8; SESSION_NONCE_LEN],
}
impl EncryptionContext {
    /// Creates a context with a freshly generated random session nonce.
    ///
    /// The key is derived from `password` and `salt`; the session nonce is
    /// filled from `rand::thread_rng()` and can be read back with
    /// [`session_nonce`](Self::session_nonce).
    ///
    /// # Errors
    /// Returns an error if key derivation or cipher construction fails.
    pub fn new(password: &str, salt: &[u8]) -> Result<Self> {
        let mut session_nonce = [0u8; SESSION_NONCE_LEN];
        rand::thread_rng().fill_bytes(&mut session_nonce);
        Self::with_session_nonce(password, salt, session_nonce)
    }

    /// Creates a context that uses the caller-supplied `session_nonce`.
    ///
    /// A context built with the same password, salt and session nonce as
    /// another one derives the same chunk nonces, so it can decrypt chunks
    /// that the other context encrypted.
    ///
    /// # Errors
    /// Returns an error if key derivation or cipher construction fails.
    pub fn with_session_nonce(
        password: &str,
        salt: &[u8],
        session_nonce: [u8; SESSION_NONCE_LEN],
    ) -> Result<Self> {
        let key = derive_key(password, salt)?;
        let cipher = ChaCha20Poly1305::new_from_slice(&key)
            .map_err(|e| anyhow::anyhow!("Failed to create cipher: {}", e))?;
        Ok(Self {
            cipher,
            session_nonce,
        })
    }

    /// Returns the session nonce this context prefixes to every chunk nonce.
    pub fn session_nonce(&self) -> &[u8; SESSION_NONCE_LEN] {
        &self.session_nonce
    }

    /// Encrypts `data` as chunk number `chunk_index`.
    ///
    /// # Errors
    /// Returns an error if `chunk_index` does not fit in the 32-bit counter
    /// part of the nonce, or if the cipher reports a failure.
    pub fn encrypt_chunk(&self, chunk_index: u64, data: &[u8]) -> Result<Vec<u8>> {
        let nonce = self.derive_chunk_nonce(chunk_index)?;
        self.cipher
            .encrypt(&nonce, data)
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))
    }

    /// Decrypts `ciphertext` that was encrypted as chunk number `chunk_index`.
    ///
    /// # Errors
    /// Returns an error if `chunk_index` does not fit in the 32-bit counter
    /// part of the nonce, or if the cipher rejects the input (for example
    /// because the password, session nonce or chunk index differ from the
    /// ones used for encryption, or the data was modified).
    pub fn decrypt_chunk(&self, chunk_index: u64, ciphertext: &[u8]) -> Result<Vec<u8>> {
        let nonce = self.derive_chunk_nonce(chunk_index)?;
        self.cipher
            .decrypt(&nonce, ciphertext)
            .map_err(|_| anyhow::anyhow!("Decryption failed (wrong password or corrupted data)"))
    }

    /// Builds the 12-byte nonce for a chunk: the session nonce followed by the
    /// chunk index as a little-endian `u32`.
    ///
    /// Only 4 bytes remain for the index, so indices above `u32::MAX` are
    /// rejected. Truncating them instead would make two different chunks share
    /// a nonce under the same key.
    fn derive_chunk_nonce(&self, chunk_index: u64) -> Result<Nonce> {
        if chunk_index > u64::from(u32::MAX) {
            return Err(anyhow::anyhow!(
                "Chunk index {} exceeds the maximum supported index {}",
                chunk_index,
                u32::MAX
            ));
        }
        // Lossless: the range check above guarantees the value fits in 32 bits.
        let counter = chunk_index as u32;

        let mut nonce_bytes = [0u8; 12];
        nonce_bytes[..SESSION_NONCE_LEN].copy_from_slice(&self.session_nonce);
        nonce_bytes[SESSION_NONCE_LEN..].copy_from_slice(&counter.to_le_bytes());
        Ok(Nonce::from(nonce_bytes))
    }
}
/// Derives a 32-byte key from `password` and `salt` using Argon2 with its
/// default configuration.
///
/// The raw salt is encoded as unpadded standard base64 before being handed to
/// `SaltString`. Encodings shorter than 4 characters are right-padded with
/// `'0'`; encodings longer than 64 characters are cut to their first 64.
///
/// # Errors
/// Returns an error if the encoded salt is rejected, if hashing fails, or if
/// the resulting password hash carries no output.
///
/// # Panics
/// Panics if the hash output is shorter than 32 bytes.
fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; 32]> {
    use argon2::password_hash::SaltString;

    let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD_NO_PAD, salt);
    // Bring the encoded salt into the 4..=64 character window.
    let normalized = match encoded.len() {
        0..=3 => format!("{:0<4}", encoded),
        4..=64 => encoded,
        _ => encoded[..64].to_string(),
    };
    let salt_string =
        SaltString::from_b64(&normalized).map_err(|e| anyhow::anyhow!("Invalid salt: {}", e))?;

    let hasher = Argon2::default();
    let password_hash = hasher
        .hash_password(password.as_bytes(), &salt_string)
        .map_err(|e| anyhow::anyhow!("Key derivation failed: {}", e))?;
    let output = password_hash
        .hash
        .context("No hash output from Argon2")?;

    // Keep only the leading 32 bytes of the hash output as the key.
    let mut key = [0u8; 32];
    key.copy_from_slice(&output.as_bytes()[..32]);
    Ok(key)
}
/// Returns `SALT_LEN` random bytes drawn from `rand::thread_rng()`, for use as
/// a key-derivation salt.
pub fn generate_salt() -> [u8; SALT_LEN] {
    let mut rng = rand::thread_rng();
    let mut bytes = [0u8; SALT_LEN];
    rng.fill_bytes(&mut bytes);
    bytes
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a second context that shares `origin`'s session nonce, the way a
    /// decrypting side would after receiving the salt and session nonce.
    fn peer_context(origin: &EncryptionContext, password: &str, salt: &[u8]) -> EncryptionContext {
        EncryptionContext::with_session_nonce(password, salt, *origin.session_nonce()).unwrap()
    }

    /// A chunk encrypted by one context decrypts to the same bytes in a peer.
    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let password = "test_password_123";
        let salt = generate_salt();
        let sender = EncryptionContext::new(password, &salt).unwrap();
        let message = b"Hello, World! This is test data for encryption.";

        let sealed = sender.encrypt_chunk(0, message).unwrap();
        assert!(sealed.len() > message.len());

        let receiver = peer_context(&sender, password, &salt);
        let opened = receiver.decrypt_chunk(0, &sealed).unwrap();
        assert_eq!(opened, message);
    }

    /// Identical plaintext under different chunk indices yields different output.
    #[test]
    fn test_different_chunks_different_ciphertext() {
        let salt = generate_salt();
        let ctx = EncryptionContext::new("test_password", &salt).unwrap();
        let payload = b"Same data";

        let first = ctx.encrypt_chunk(0, payload).unwrap();
        let second = ctx.encrypt_chunk(1, payload).unwrap();
        assert_ne!(first, second);
    }

    /// Decryption with a key derived from another password is rejected.
    #[test]
    fn test_wrong_password_fails() {
        let salt = generate_salt();
        let sender = EncryptionContext::new("correct_password", &salt).unwrap();
        let sealed = sender.encrypt_chunk(0, b"Secret data").unwrap();

        let intruder = peer_context(&sender, "wrong_password", &salt);
        assert!(intruder.decrypt_chunk(0, &sealed).is_err());
    }

    /// Decryption under a chunk index other than the one used to encrypt fails.
    #[test]
    fn test_wrong_chunk_index_fails() {
        let password = "test_password";
        let salt = generate_salt();
        let sender = EncryptionContext::new(password, &salt).unwrap();
        let sealed = sender.encrypt_chunk(0, b"Test data").unwrap();

        let receiver = peer_context(&sender, password, &salt);
        assert!(receiver.decrypt_chunk(1, &sealed).is_err());
    }

    /// Two consecutively generated salts differ.
    #[test]
    fn test_generate_salt_uniqueness() {
        let first = generate_salt();
        let second = generate_salt();
        assert_ne!(first, second);
    }

    /// A 30,000-byte chunk survives an encrypt/decrypt round trip.
    #[test]
    fn test_large_chunk_encryption() {
        let password = "test_password";
        let salt = generate_salt();
        let sender = EncryptionContext::new(password, &salt).unwrap();
        // Repeating 0, 1, ..., 255 pattern, 30,000 bytes long.
        let payload: Vec<u8> = (0u8..=255).cycle().take(30_000).collect();

        let sealed = sender.encrypt_chunk(0, &payload).unwrap();

        let receiver = peer_context(&sender, password, &salt);
        let opened = receiver.decrypt_chunk(0, &sealed).unwrap();
        assert_eq!(opened, payload);
    }

    /// An empty chunk still produces non-empty output and round-trips to empty.
    #[test]
    fn test_empty_data_encryption() {
        let password = "test_password";
        let salt = generate_salt();
        let sender = EncryptionContext::new(password, &salt).unwrap();
        let nothing = b"";

        let sealed = sender.encrypt_chunk(0, nothing).unwrap();
        assert!(!sealed.is_empty());

        let receiver = peer_context(&sender, password, &salt);
        let opened = receiver.decrypt_chunk(0, &sealed).unwrap();
        assert_eq!(opened, nothing);
    }
}