#[cfg(feature = "encryption")]
use std::num::NonZeroU32;
#[cfg(feature = "encryption")]
use std::time::Instant;
use serde::Serialize;
use tracing::{debug, warn};
#[cfg(feature = "encryption")]
use tracing::info;
#[derive(Debug, thiserror::Error)]
pub enum EncryptionError {
#[error("key derivation failed: {0}")]
KeyDerivation(String),
#[error("encryption failed: {0}")]
EncryptionFailed(String),
#[error("invalid passphrase")]
InvalidPassphrase,
}
/// Encrypted payload that gets embedded in the exported HTML page.
///
/// When produced by `encrypt_content`, the three string fields hold
/// standard-alphabet base64, and `iterations` records the PBKDF2 count needed
/// to re-derive the key from the password.
#[derive(Debug, Clone, Serialize)]
pub struct EncryptedContent {
    /// Base64 of the random PBKDF2 salt.
    pub salt: String,
    /// Base64 of the AES-GCM nonce.
    pub iv: String,
    /// Base64 of the AES-GCM output (ciphertext plus authentication overhead).
    pub ciphertext: String,
    /// PBKDF2-HMAC-SHA256 iteration count used for key derivation.
    pub iterations: u32,
}

impl EncryptedContent {
    /// Serializes the payload as compact JSON.
    ///
    /// Keys appear in field declaration order (`salt`, `iv`, `ciphertext`,
    /// `iterations`). If serialization fails, returns the literal `{}`
    /// instead of panicking.
    pub fn to_json(&self) -> String {
        match serde_json::to_string(self) {
            Ok(json) => json,
            Err(_) => String::from("{}"),
        }
    }
}
/// Tuning knobs for `encrypt_content`.
#[derive(Debug, Clone)]
pub struct EncryptionParams {
    /// PBKDF2-HMAC-SHA256 iteration count; must be non-zero.
    pub iterations: u32,
    /// Number of random salt bytes; must be non-zero.
    pub salt_len: usize,
    /// Number of random nonce bytes; `encrypt_content` only accepts 12.
    pub iv_len: usize,
}

impl Default for EncryptionParams {
    /// 600 000 PBKDF2 iterations, a 16-byte salt and a 12-byte nonce.
    fn default() -> Self {
        const DEFAULT_ITERATIONS: u32 = 600_000;
        const DEFAULT_SALT_LEN: usize = 16;
        const DEFAULT_IV_LEN: usize = 12;

        EncryptionParams {
            iv_len: DEFAULT_IV_LEN,
            salt_len: DEFAULT_SALT_LEN,
            iterations: DEFAULT_ITERATIONS,
        }
    }
}
/// Encrypts `plaintext` with AES-256-GCM under a key derived from `password`
/// via PBKDF2-HMAC-SHA256.
///
/// A fresh salt (`params.salt_len` bytes) and nonce (`params.iv_len` bytes)
/// are generated on every call. Salt, nonce and ciphertext are returned
/// base64-encoded, together with the iteration count, in an
/// [`EncryptedContent`].
///
/// # Errors
///
/// - [`EncryptionError::InvalidPassphrase`] if `password` is empty.
/// - [`EncryptionError::KeyDerivation`] if `params.iterations` is zero,
///   `params.salt_len` is zero, or `params.iv_len` is not 12.
/// - [`EncryptionError::EncryptionFailed`] if the cipher cannot be constructed
///   or encryption fails.
#[cfg(feature = "encryption")]
pub fn encrypt_content(
    plaintext: &str,
    password: &str,
    params: &EncryptionParams,
) -> Result<EncryptedContent, EncryptionError> {
    use aes_gcm::{
        Aes256Gcm, Nonce,
        aead::{Aead, KeyInit},
    };
    use ring::pbkdf2;

    // AES-GCM here uses a 96-bit nonce; `Nonce::from_slice` would panic on any
    // other length, so it is rejected during validation instead.
    const AES_GCM_NONCE_LEN: usize = 12;

    if password.is_empty() {
        warn!(
            component = "encryption",
            operation = "validate_password",
            "Rejected empty password"
        );
        return Err(EncryptionError::InvalidPassphrase);
    }
    // Validate the iteration count exactly once and keep the non-zero value
    // for PBKDF2, rather than re-checking it later.
    let iterations = NonZeroU32::new(params.iterations).ok_or_else(|| {
        EncryptionError::KeyDerivation("iterations must be greater than zero".to_string())
    })?;
    if params.salt_len == 0 {
        return Err(EncryptionError::KeyDerivation(
            "salt length must be greater than zero".to_string(),
        ));
    }
    if params.iv_len != AES_GCM_NONCE_LEN {
        return Err(EncryptionError::KeyDerivation(
            "iv length must be 12 bytes for AES-GCM".to_string(),
        ));
    }

    let started = Instant::now();
    info!(
        component = "encryption",
        operation = "encrypt_payload",
        plaintext_bytes = plaintext.len(),
        iterations = params.iterations,
        salt_len = params.salt_len,
        iv_len = params.iv_len,
        "Starting encryption"
    );

    let mut salt = vec![0u8; params.salt_len];
    let mut iv = vec![0u8; params.iv_len];
    fill_encryption_random("salt", &mut salt);
    fill_encryption_random("iv", &mut iv);

    let derive_started = Instant::now();
    // `Zeroizing` wipes the derived key bytes when `key` is dropped.
    let mut key = zeroize::Zeroizing::new([0u8; 32]);
    pbkdf2::derive(
        pbkdf2::PBKDF2_HMAC_SHA256,
        iterations,
        &salt,
        password.as_bytes(),
        &mut *key,
    );
    debug!(
        component = "encryption",
        operation = "derive_key",
        duration_ms = derive_started.elapsed().as_millis(),
        "Derived key via PBKDF2"
    );

    let cipher = Aes256Gcm::new_from_slice(key.as_ref())
        .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
    let nonce = Nonce::from_slice(&iv);
    let ciphertext = cipher
        .encrypt(nonce, plaintext.as_bytes())
        .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;

    let encrypted = EncryptedContent {
        salt: base64_encode(&salt),
        iv: base64_encode(&iv),
        ciphertext: base64_encode(&ciphertext),
        iterations: params.iterations,
    };
    info!(
        component = "encryption",
        operation = "encrypt_complete",
        // Raw byte count; the base64 field in `encrypted` is ~4/3 larger.
        ciphertext_bytes = ciphertext.len(),
        duration_ms = started.elapsed().as_millis(),
        "Encryption complete"
    );
    Ok(encrypted)
}
/// Fills `output` with random bytes for the value named by `label` (the call
/// sites in this file pass `"salt"` and `"iv"`).
///
/// If `deterministic_test_bytes` yields bytes for this label, those are
/// copied in. That happens only in `debug_assertions` builds with the
/// golden-bytes env var set, presumably for reproducible golden-output tests.
/// Otherwise the buffer is filled from `OsRng`.
#[cfg(feature = "encryption")]
fn fill_encryption_random(label: &str, output: &mut [u8]) {
    match deterministic_test_bytes(label, output.len()) {
        Some(fixed) => output.copy_from_slice(&fixed),
        None => {
            use aes_gcm::aead::{OsRng, rand_core::RngCore};
            OsRng.fill_bytes(output);
        }
    }
}
/// Produces `len` reproducible pseudo-random bytes for `entropy_label`, or
/// `None` when deterministic output has not been requested.
///
/// Only builds with `debug_assertions` honour the request. In those builds,
/// `CASS_HTML_EXPORT_GOLDEN_BYTES_LABEL` (read via `dotenvy::var`) must be set
/// to a non-empty value. The byte stream is the concatenation of BLAKE3
/// digests over:
/// - a fixed domain-separation prefix,
/// - the golden label,
/// - `entropy_label`,
/// - a little-endian `u64` block counter starting at 0,
///
/// truncated to exactly `len` bytes. Builds without `debug_assertions` always
/// return `None`.
#[cfg(feature = "encryption")]
fn deterministic_test_bytes(entropy_label: &str, len: usize) -> Option<Vec<u8>> {
    #[cfg(debug_assertions)]
    {
        let golden_label = dotenvy::var("CASS_HTML_EXPORT_GOLDEN_BYTES_LABEL").ok()?;
        if golden_label.is_empty() {
            return None;
        }
        // One 32-byte BLAKE3 digest per counter value; blocks are generated
        // lazily and the stream stops after `len` bytes.
        let block = |counter: u64| {
            let mut hasher = blake3::Hasher::new();
            hasher.update(b"cass-html-export-deterministic-encryption-v1");
            hasher.update(golden_label.as_bytes());
            hasher.update(entropy_label.as_bytes());
            hasher.update(&counter.to_le_bytes());
            *hasher.finalize().as_bytes()
        };
        Some((0u64..).flat_map(block).take(len).collect())
    }
    #[cfg(not(debug_assertions))]
    {
        let _ = (entropy_label, len);
        None
    }
}
/// Stand-in used when the crate is built without the `encryption` feature.
///
/// Ignores its arguments, logs a warning, and always fails.
///
/// # Errors
///
/// Always returns [`EncryptionError::EncryptionFailed`] explaining how to
/// enable the feature.
#[cfg(not(feature = "encryption"))]
pub fn encrypt_content(
    _plaintext: &str,
    _password: &str,
    _params: &EncryptionParams,
) -> Result<EncryptedContent, EncryptionError> {
    const DISABLED_REASON: &str =
        "encryption feature not enabled - compile with --features encryption";

    warn!(
        component = "encryption",
        operation = "encrypt_payload",
        "Encryption feature not enabled"
    );
    Err(EncryptionError::EncryptionFailed(String::from(DISABLED_REASON)))
}
/// Encodes `data` as standard-alphabet base64 with `=` padding.
#[cfg(feature = "encryption")]
fn base64_encode(data: &[u8]) -> String {
    use base64::{Engine as _, prelude::BASE64_STANDARD};
    BASE64_STANDARD.encode(data)
}
/// Renders the HTML fragment that carries `encrypted` inside the page.
///
/// The payload's JSON form is passed through `html_escape_for_content` and
/// becomes the text of a hidden `<div id="encrypted-content">`. A visible
/// notice asking for the password follows it. Decryption is not performed
/// here; presumably a page script reads the hidden div (not shown in this
/// file).
pub fn render_encrypted_placeholder(encrypted: &EncryptedContent) -> String {
    debug!(
        component = "encryption",
        operation = "render_placeholder",
        ciphertext_bytes = encrypted.ciphertext.len(),
        "Rendering encrypted placeholder"
    );
    let payload = html_escape_for_content(&encrypted.to_json());
    format!(
        r###" <!-- Encrypted content - requires password to decrypt -->
<div id="encrypted-content" hidden>{payload}</div>
<div class="encrypted-notice">
<p>This conversation is encrypted. Enter the password above to view.</p>
</div>"###
    )
}
/// Escapes `s` for use as HTML element *text content*.
///
/// Replaces `&`, `<` and `>` with `&amp;`, `&lt;` and `&gt;`. This stops the
/// text from opening tags (e.g. a premature `</div>` or `<script>`) or
/// entities inside an element body; a browser's `textContent` decodes them
/// back to the original characters. Quotes are left untouched, so the result
/// is NOT safe to place inside attribute values.
fn html_escape_for_content(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            _ => escaped.push(c),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encryption_error_display_strings() {
        assert_eq!(
            EncryptionError::KeyDerivation("bad params".to_string()).to_string(),
            "key derivation failed: bad params"
        );
        assert_eq!(
            EncryptionError::EncryptionFailed("cipher failed".to_string()).to_string(),
            "encryption failed: cipher failed"
        );
        assert_eq!(
            EncryptionError::InvalidPassphrase.to_string(),
            "invalid passphrase"
        );
    }

    #[test]
    #[cfg(feature = "encryption")]
    fn test_base64_encode() {
        // RFC 4648 section 10 test vectors.
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn test_encrypted_content_to_json() {
        let content = EncryptedContent {
            salt: "abc123".to_string(),
            iv: "xyz789".to_string(),
            ciphertext: "encrypted_data".to_string(),
            iterations: 123_456,
        };
        let json = content.to_json();
        assert!(json.contains("\"salt\":\"abc123\""));
        assert!(json.contains("\"iv\":\"xyz789\""));
        assert!(json.contains("\"ciphertext\":\"encrypted_data\""));
        assert!(json.contains("\"iterations\":123456"));
    }

    #[test]
    fn test_encryption_params_default() {
        let params = EncryptionParams::default();
        assert_eq!(params.iterations, 600_000);
        assert_eq!(params.salt_len, 16);
        assert_eq!(params.iv_len, 12);
    }

    #[test]
    #[cfg(feature = "encryption")]
    fn test_encrypt_content_roundtrip() {
        use aes_gcm::{
            Aes256Gcm, Nonce,
            aead::{Aead, KeyInit},
        };
        use base64::Engine;
        use base64::prelude::BASE64_STANDARD;
        use ring::pbkdf2;

        let params = EncryptionParams {
            iterations: 1_000,
            salt_len: 16,
            iv_len: 12,
        };
        let plaintext = "Hello 🌍";
        let test_phrase = ["unit", "test", "phrase"].join(" ");
        let encrypted = encrypt_content(plaintext, &test_phrase, &params).expect("encrypt");
        assert_eq!(encrypted.iterations, params.iterations);

        let salt = BASE64_STANDARD
            .decode(encrypted.salt.as_bytes())
            .expect("salt b64");
        let iv = BASE64_STANDARD
            .decode(encrypted.iv.as_bytes())
            .expect("iv b64");
        let ciphertext = BASE64_STANDARD
            .decode(encrypted.ciphertext.as_bytes())
            .expect("ciphertext b64");

        // Re-derive the key independently and decrypt to prove the roundtrip.
        let mut key = [0u8; 32];
        pbkdf2::derive(
            pbkdf2::PBKDF2_HMAC_SHA256,
            NonZeroU32::new(params.iterations).expect("test iterations should be non-zero"),
            &salt,
            test_phrase.as_bytes(),
            &mut key,
        );
        let cipher = Aes256Gcm::new_from_slice(&key).expect("cipher");
        let nonce = Nonce::from_slice(&iv);
        let decrypted = cipher.decrypt(nonce, ciphertext.as_ref()).expect("decrypt");
        assert_eq!(plaintext, String::from_utf8(decrypted).expect("utf8"));
    }

    #[test]
    #[cfg(feature = "encryption")]
    fn test_encrypt_content_produces_authenticated_ciphertext() {
        let params = EncryptionParams {
            iterations: 1_000,
            salt_len: 16,
            iv_len: 12,
        };
        let result = encrypt_content(
            "sensitive data",
            "authenticated encryption fixture",
            &params,
        )
        .expect("feature-enabled encrypt_content should produce ciphertext");
        assert!(!result.salt.is_empty(), "salt must be generated");
        assert!(!result.iv.is_empty(), "iv must be generated");
        assert_ne!(
            result.ciphertext, "sensitive data",
            "ciphertext must differ from plaintext"
        );
        assert!(
            result.ciphertext.len() > "sensitive data".len(),
            "ciphertext should include authenticated-encryption overhead"
        );
        assert_eq!(result.iterations, params.iterations);
    }

    #[test]
    #[cfg(feature = "encryption")]
    fn test_encrypt_rejects_empty_password() {
        let params = EncryptionParams {
            iterations: 1_000,
            salt_len: 16,
            iv_len: 12,
        };
        let result = encrypt_content("hello", "", &params);
        assert!(matches!(result, Err(EncryptionError::InvalidPassphrase)));
    }

    #[test]
    #[cfg(feature = "encryption")]
    fn test_encrypt_rejects_invalid_params() {
        let mut params = EncryptionParams {
            iterations: 1_000,
            salt_len: 16,
            iv_len: 12,
        };

        params.iterations = 0;
        let result = encrypt_content("hello", "pw", &params);
        assert!(matches!(result, Err(EncryptionError::KeyDerivation(_))));

        params.iterations = 1_000;
        params.salt_len = 0;
        let result = encrypt_content("hello", "pw", &params);
        assert!(matches!(result, Err(EncryptionError::KeyDerivation(_))));

        params.salt_len = 16;
        params.iv_len = 8;
        let result = encrypt_content("hello", "pw", &params);
        assert!(matches!(result, Err(EncryptionError::KeyDerivation(_))));
    }

    #[test]
    #[cfg(not(feature = "encryption"))]
    fn test_encrypt_without_feature_returns_error() {
        let phrase = ["disabled", "feature", "phrase"].join(" ");
        let result = encrypt_content("test", &phrase, &EncryptionParams::default());
        assert!(result.is_err());
    }
}