use anyhow::{Context, Result};
use chacha20poly1305::{
aead::{Aead, KeyInit},
ChaCha20Poly1305, Nonce,
};
use rand::RngCore;
use sha2::{Digest, Sha256};
use zeroize::Zeroizing;
/// Key length in bytes for ChaCha20-Poly1305 (a 256-bit key).
const KEY_SIZE: usize = 256 / 8;
/// Nonce length in bytes for ChaCha20-Poly1305 (a 96-bit nonce); it is
/// prepended to every encrypted message.
const NONCE_SIZE: usize = 96 / 8;
/// Authenticated symmetric cipher for IPC messages, keyed from a session token.
///
/// Messages produced by [`IpcCipher::encrypt`] (and expected by
/// [`IpcCipher::decrypt`]) have the layout
/// `nonce (12 bytes) || ciphertext || Poly1305 tag (16 bytes)`.
pub struct IpcCipher {
    cipher: ChaCha20Poly1305,
}

impl IpcCipher {
    /// Builds a cipher whose key is derived deterministically from `token`.
    ///
    /// Two `IpcCipher`s created from the same token can decrypt each other's
    /// messages; different tokens yield unrelated keys.
    pub fn from_session_token(token: &str) -> Self {
        let key = derive_key_from_token(token);
        // Invariant: the derived key is a SHA-256 digest, i.e. exactly 32 bytes.
        let cipher = ChaCha20Poly1305::new_from_slice(key.as_slice())
            .expect("Key is always 32 bytes from SHA-256");
        Self { cipher }
    }

    /// Encrypts `plaintext` under a freshly generated random nonce.
    ///
    /// Returns `nonce || ciphertext_with_tag`, ready to be passed to
    /// [`IpcCipher::decrypt`].
    ///
    /// # Errors
    /// Returns an error if the underlying AEAD encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
        // A nonce must never repeat under the same key; a random 96-bit nonce
        // per message keeps the collision probability negligible for IPC volumes.
        let mut nonce_bytes = [0u8; NONCE_SIZE];
        rand::thread_rng().fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let ciphertext = self
            .cipher
            .encrypt(nonce, plaintext)
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
        let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
        result.extend_from_slice(&nonce_bytes);
        result.extend_from_slice(&ciphertext);
        Ok(result)
    }

    /// Decrypts and authenticates a message produced by [`IpcCipher::encrypt`].
    ///
    /// # Errors
    /// Returns an error if `encrypted` is too short to contain a nonce and an
    /// authentication tag, or if authentication fails (wrong key, or the
    /// message was modified or truncated).
    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
        // Length of the Poly1305 tag that ChaCha20-Poly1305 appends to the ciphertext.
        const TAG_SIZE: usize = 16;
        if encrypted.len() < NONCE_SIZE {
            anyhow::bail!("Encrypted message too short (missing nonce)");
        }
        // Reject inputs that cannot possibly carry a tag up front, so the caller
        // sees a structural error instead of a misleading authentication failure.
        if encrypted.len() < NONCE_SIZE + TAG_SIZE {
            anyhow::bail!("Encrypted message too short (missing authentication tag)");
        }
        let (nonce_bytes, ciphertext) = encrypted.split_at(NONCE_SIZE);
        let nonce = Nonce::from_slice(nonce_bytes);
        let plaintext = self
            .cipher
            .decrypt(nonce, ciphertext)
            // The AEAD error is deliberately opaque; do not leak why it failed.
            .map_err(|_| anyhow::anyhow!("Decryption failed (authentication failed)"))?;
        Ok(plaintext)
    }

    /// Encrypts a UTF-8 string and returns the result as standard
    /// (padded) base64 text.
    ///
    /// # Errors
    /// Propagates any error from [`IpcCipher::encrypt`].
    pub fn encrypt_string(&self, plaintext: &str) -> Result<String> {
        let encrypted = self.encrypt(plaintext.as_bytes())?;
        Ok(base64::Engine::encode(
            &base64::engine::general_purpose::STANDARD,
            &encrypted,
        ))
    }

    /// Reverses [`IpcCipher::encrypt_string`]: base64-decodes, decrypts, and
    /// validates the plaintext as UTF-8.
    ///
    /// # Errors
    /// Returns an error if the input is not valid standard base64, if
    /// decryption/authentication fails, or if the plaintext is not UTF-8.
    pub fn decrypt_string(&self, encrypted_b64: &str) -> Result<String> {
        let encrypted = base64::Engine::decode(
            &base64::engine::general_purpose::STANDARD,
            encrypted_b64,
        )
        .context("Invalid base64 encoding")?;
        let plaintext = self.decrypt(&encrypted)?;
        String::from_utf8(plaintext).context("Decrypted data is not valid UTF-8")
    }
}
fn derive_key_from_token(token: &str) -> Zeroizing<[u8; KEY_SIZE]> {
let mut hasher = Sha256::new();
hasher.update(b"brainwires-ipc-v1:");
hasher.update(token.as_bytes());
let result = hasher.finalize();
let mut key = Zeroizing::new([0u8; KEY_SIZE]);
key.copy_from_slice(&result);
key
}
/// Returns `KEY_SIZE` random bytes drawn from `rand::thread_rng()`.
///
/// The buffer is wrapped in [`Zeroizing`] so its contents are wiped when the
/// returned value is dropped.
pub fn generate_random_key() -> Zeroizing<[u8; KEY_SIZE]> {
    let mut rng = rand::thread_rng();
    let mut fresh: Zeroizing<[u8; KEY_SIZE]> = Zeroizing::new([0; KEY_SIZE]);
    rng.fill_bytes(&mut fresh[..]);
    fresh
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the Poly1305 authentication tag appended by ChaCha20-Poly1305.
    const TAG_LEN: usize = 16;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let token = "test-session-token-12345";
        let cipher = IpcCipher::from_session_token(token);
        let plaintext = b"Hello, this is a secret message!";
        let encrypted = cipher.encrypt(plaintext).unwrap();
        assert_ne!(encrypted.as_slice(), plaintext);
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_string() {
        let token = "string-test-token";
        let cipher = IpcCipher::from_session_token(token);
        let message = "This is a JSON message: {\"key\": \"value\"}";
        let encrypted = cipher.encrypt_string(message).unwrap();
        assert!(encrypted
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '/' || c == '='));
        let decrypted = cipher.decrypt_string(&encrypted).unwrap();
        assert_eq!(decrypted, message);
    }

    #[test]
    fn test_different_tokens_different_ciphertext() {
        let cipher1 = IpcCipher::from_session_token("token1");
        let cipher2 = IpcCipher::from_session_token("token2");
        let plaintext = b"Same message";
        let encrypted1 = cipher1.encrypt(plaintext).unwrap();
        let encrypted2 = cipher2.encrypt(plaintext).unwrap();
        assert_ne!(encrypted1, encrypted2);
        assert!(cipher2.decrypt(&encrypted1).is_err());
        assert!(cipher1.decrypt(&encrypted2).is_err());
    }

    #[test]
    fn test_tamper_detection() {
        let token = "tamper-test";
        let cipher = IpcCipher::from_session_token(token);
        let plaintext = b"Original message";
        let mut encrypted = cipher.encrypt(plaintext).unwrap();
        if let Some(byte) = encrypted.last_mut() {
            *byte ^= 0xFF;
        }
        assert!(cipher.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_nonce_tamper_detection() {
        let cipher = IpcCipher::from_session_token("nonce-tamper-test");
        let mut encrypted = cipher.encrypt(b"Original message").unwrap();
        // The nonce is authenticated implicitly: changing it must break decryption.
        encrypted[0] ^= 0x01;
        assert!(cipher.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_empty_message() {
        let token = "empty-test";
        let cipher = IpcCipher::from_session_token(token);
        let plaintext = b"";
        let encrypted = cipher.encrypt(plaintext).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_large_message() {
        let token = "large-test";
        let cipher = IpcCipher::from_session_token(token);
        let plaintext: Vec<u8> = (0..1_000_000).map(|i| (i % 256) as u8).collect();
        let encrypted = cipher.encrypt(&plaintext).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_key_derivation_deterministic() {
        let token = "deterministic-test";
        let cipher1 = IpcCipher::from_session_token(token);
        let cipher2 = IpcCipher::from_session_token(token);
        let plaintext = b"Test message";
        let encrypted = cipher1.encrypt(plaintext).unwrap();
        let decrypted = cipher2.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_derive_key_from_token_stable_and_distinct() {
        assert_eq!(*derive_key_from_token("a"), *derive_key_from_token("a"));
        assert_ne!(*derive_key_from_token("a"), *derive_key_from_token("b"));
    }

    #[test]
    fn test_wire_format_length() {
        let cipher = IpcCipher::from_session_token("length-test");
        let plaintext = b"twelve bytes";
        let encrypted = cipher.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + TAG_LEN);
    }

    #[test]
    fn test_nonce_is_fresh_per_message() {
        let cipher = IpcCipher::from_session_token("nonce-test");
        let plaintext = b"Repeated message";
        let first = cipher.encrypt(plaintext).unwrap();
        let second = cipher.encrypt(plaintext).unwrap();
        // Same key and plaintext must still produce distinct nonces and ciphertexts.
        assert_ne!(&first[..NONCE_SIZE], &second[..NONCE_SIZE]);
        assert_ne!(first, second);
    }

    #[test]
    fn test_truncated_input_rejected() {
        let cipher = IpcCipher::from_session_token("truncation-test");
        assert!(cipher.decrypt(&[]).is_err());
        assert!(cipher.decrypt(&[0u8; NONCE_SIZE - 1]).is_err());
        assert!(cipher.decrypt(&[0u8; NONCE_SIZE]).is_err());
        assert!(cipher.decrypt(&[0u8; NONCE_SIZE + TAG_LEN - 1]).is_err());
        let encrypted = cipher.encrypt(b"truncate me").unwrap();
        assert!(cipher.decrypt(&encrypted[..encrypted.len() - 1]).is_err());
    }

    #[test]
    fn test_decrypt_string_rejects_invalid_base64() {
        let cipher = IpcCipher::from_session_token("base64-test");
        assert!(cipher.decrypt_string("not valid base64!!").is_err());
    }

    #[test]
    fn test_decrypt_string_rejects_non_utf8() {
        let cipher = IpcCipher::from_session_token("utf8-test");
        let encrypted = cipher.encrypt(&[0xff, 0xfe, 0xfd]).unwrap();
        let encoded =
            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &encrypted);
        assert!(cipher.decrypt_string(&encoded).is_err());
    }

    #[test]
    fn test_generate_random_key() {
        let a = generate_random_key();
        let b = generate_random_key();
        assert_eq!(a.len(), KEY_SIZE);
        assert_ne!(*a, *b);
    }
}