use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use chacha20poly1305::{
aead::{Aead, AeadCore, KeyInit, OsRng},
ChaCha20Poly1305, Key, Nonce,
};
use hkdf::Hkdf;
use sha2::Sha256;
use tokio::sync::RwLock;
use x25519_dalek::{PublicKey, SharedSecret, StaticSecret};
use super::error::SecurityError;
use super::DeviceId;
/// Length in bytes of the nonce stored in front of every ciphertext.
pub const NONCE_SIZE: usize = 12;
/// Length in bytes of a symmetric key.
pub const SYMMETRIC_KEY_SIZE: usize = 32;
/// Length in bytes of an X25519 public key as accepted from a peer.
pub const X25519_PUBLIC_KEY_SIZE: usize = 32;
/// HKDF `info` string used when deriving a per-peer key from a DH shared secret.
const HKDF_INFO_PEER: &[u8] = b"peat-protocol-v1-peer";
// HKDF `info` string for group keys. Nothing in the visible part of this file
// references it yet, hence the lint suppression below.
#[allow(dead_code)]
const HKDF_INFO_GROUP: &[u8] = b"peat-protocol-v1-group";
/// X25519 keypair used for Diffie-Hellman key agreement.
///
/// The secret half lives behind an `Arc`, so cloning the keypair shares the
/// same secret allocation rather than copying the key material.
#[derive(Clone)]
pub struct EncryptionKeypair {
    secret: Arc<StaticSecret>,
    public: PublicKey,
}

impl EncryptionKeypair {
    /// Wraps an existing secret and computes the matching public key.
    fn from_static_secret(secret: StaticSecret) -> Self {
        let public = PublicKey::from(&secret);
        let secret = Arc::new(secret);
        Self { secret, public }
    }

    /// Creates a fresh keypair from the operating system's random source.
    pub fn generate() -> Self {
        Self::from_static_secret(StaticSecret::random_from_rng(OsRng))
    }

    /// Rebuilds a keypair from 32 raw secret-key bytes.
    ///
    /// The same bytes always produce the same public key.
    pub fn from_secret_bytes(bytes: &[u8; 32]) -> Self {
        Self::from_static_secret(StaticSecret::from(*bytes))
    }

    /// Borrows the public half of the keypair.
    pub fn public_key(&self) -> &PublicKey {
        &self.public
    }

    /// Returns a copy of the public key as raw bytes.
    pub fn public_key_bytes(&self) -> [u8; 32] {
        *self.public.as_bytes()
    }

    /// Runs X25519 Diffie-Hellman between our secret and `peer_public`.
    pub fn dh_exchange(&self, peer_public: &PublicKey) -> SharedSecret {
        let our_secret: &StaticSecret = &self.secret;
        our_secret.diffie_hellman(peer_public)
    }
}
/// Debug output shows the public key as hex and never prints the secret.
impl std::fmt::Debug for EncryptionKeypair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let public_hex = hex::encode(self.public.as_bytes());
        let mut out = f.debug_struct("EncryptionKeypair");
        out.field("public", &public_hex);
        out.field("secret", &"[REDACTED]");
        out.finish()
    }
}
/// A 256-bit ChaCha20-Poly1305 key.
#[derive(Clone)]
pub struct SymmetricKey {
    key: Key,
}

impl SymmetricKey {
    /// Builds a key from exactly `SYMMETRIC_KEY_SIZE` raw bytes.
    pub fn from_bytes(bytes: &[u8; SYMMETRIC_KEY_SIZE]) -> Self {
        let key: Key = *Key::from_slice(bytes);
        Self { key }
    }

    /// Derives a key from a Diffie-Hellman shared secret with HKDF-SHA256.
    ///
    /// No salt is used; `info` selects the derivation context, so different
    /// `info` values yield unrelated keys from the same shared secret.
    pub fn derive_from_shared_secret(shared_secret: &SharedSecret, info: &[u8]) -> Self {
        let mut okm = [0u8; SYMMETRIC_KEY_SIZE];
        Hkdf::<Sha256>::new(None, shared_secret.as_bytes())
            .expand(info, &mut okm)
            .expect("HKDF expand should never fail with correct output length");
        Self::from_bytes(&okm)
    }

    /// Derives the key used for a peer-to-peer channel.
    pub fn derive_for_peer(shared_secret: &SharedSecret) -> Self {
        Self::derive_from_shared_secret(shared_secret, HKDF_INFO_PEER)
    }

    /// Borrows the raw key bytes.
    pub fn as_bytes(&self) -> &[u8; SYMMETRIC_KEY_SIZE] {
        // `Key` is always SYMMETRIC_KEY_SIZE bytes long, so the conversion cannot fail.
        <&[u8; SYMMETRIC_KEY_SIZE]>::try_from(&self.key[..]).unwrap()
    }

    /// Encrypts `plaintext` under a freshly generated random nonce.
    ///
    /// # Errors
    /// Returns `SecurityError::EncryptionError` if the cipher reports a failure.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, SecurityError> {
        let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
        let sealed = ChaCha20Poly1305::new(&self.key).encrypt(&nonce, plaintext);
        match sealed {
            Ok(ciphertext) => Ok(EncryptedData {
                nonce: nonce.into(),
                ciphertext,
            }),
            Err(e) => Err(SecurityError::EncryptionError(e.to_string())),
        }
    }

    /// Decrypts and authenticates `encrypted` with this key.
    ///
    /// # Errors
    /// Returns `SecurityError::DecryptionError` if the cipher rejects the input,
    /// e.g. because it was produced under a different key.
    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, SecurityError> {
        let nonce = Nonce::from_slice(&encrypted.nonce);
        let opened = ChaCha20Poly1305::new(&self.key).decrypt(nonce, encrypted.ciphertext.as_ref());
        match opened {
            Ok(plaintext) => Ok(plaintext),
            Err(e) => Err(SecurityError::DecryptionError(e.to_string())),
        }
    }
}
/// Debug output never prints the key material.
impl std::fmt::Debug for SymmetricKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SymmetricKey");
        out.field("key", &"[REDACTED]");
        out.finish()
    }
}
/// A ciphertext together with the nonce it was produced under.
#[derive(Debug, Clone)]
pub struct EncryptedData {
    pub nonce: [u8; NONCE_SIZE],
    pub ciphertext: Vec<u8>,
}

impl EncryptedData {
    /// Serializes as `nonce || ciphertext`.
    pub fn to_bytes(&self) -> Vec<u8> {
        [&self.nonce[..], &self.ciphertext[..]].concat()
    }

    /// Parses the `nonce || ciphertext` layout written by [`EncryptedData::to_bytes`].
    ///
    /// # Errors
    /// Returns `SecurityError::DecryptionError` when `bytes` is shorter than
    /// `NONCE_SIZE`. An input of exactly `NONCE_SIZE` bytes is accepted and
    /// yields an empty ciphertext.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SecurityError> {
        if bytes.len() < NONCE_SIZE {
            return Err(SecurityError::DecryptionError(
                "ciphertext too short for nonce".to_string(),
            ));
        }
        let (head, tail) = bytes.split_at(NONCE_SIZE);
        let mut nonce = [0u8; NONCE_SIZE];
        nonce.copy_from_slice(head);
        Ok(Self {
            nonce,
            ciphertext: tail.to_vec(),
        })
    }
}
/// An established encrypted channel to a single peer.
#[derive(Debug)]
pub struct SecureChannel {
    pub peer_id: DeviceId,
    key: SymmetricKey,
    pub established_at: SystemTime,
}

impl SecureChannel {
    /// Creates a channel to `peer_id` and stamps it with the current wall-clock time.
    pub fn new(peer_id: DeviceId, key: SymmetricKey) -> Self {
        let established_at = SystemTime::now();
        Self {
            peer_id,
            key,
            established_at,
        }
    }

    /// Encrypts `plaintext` with the channel key.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, SecurityError> {
        let channel_key = &self.key;
        channel_key.encrypt(plaintext)
    }

    /// Decrypts `encrypted` with the channel key.
    pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, SecurityError> {
        let channel_key = &self.key;
        channel_key.decrypt(encrypted)
    }

    /// Whole seconds elapsed since the channel was established.
    ///
    /// Reports 0 when the elapsed time cannot be computed, i.e. when the
    /// system clock now reads earlier than `established_at`.
    pub fn age_secs(&self) -> u64 {
        match self.established_at.elapsed() {
            Ok(age) => age.as_secs(),
            Err(_) => 0,
        }
    }
}
/// Symmetric key shared by the members of one cell, tagged with a generation
/// counter that increases on every rotation.
#[derive(Clone)]
pub struct GroupKey {
    pub cell_id: String,
    key: SymmetricKey,
    pub generation: u64,
    pub created_at: SystemTime,
}

impl GroupKey {
    /// Draws a new random symmetric key from the operating system's random source.
    fn random_symmetric_key() -> SymmetricKey {
        let mut raw = [0u8; SYMMETRIC_KEY_SIZE];
        OsRng.fill_bytes(&mut raw);
        SymmetricKey::from_bytes(&raw)
    }

    /// Creates a random generation-1 key for `cell_id`.
    pub fn generate(cell_id: String) -> Self {
        Self {
            cell_id,
            key: Self::random_symmetric_key(),
            generation: 1,
            created_at: SystemTime::now(),
        }
    }

    /// Reconstructs a key from known bytes and an explicit generation.
    ///
    /// `created_at` is set to the current time, not the original creation time.
    pub fn from_bytes(
        cell_id: String,
        key_bytes: &[u8; SYMMETRIC_KEY_SIZE],
        generation: u64,
    ) -> Self {
        let key = SymmetricKey::from_bytes(key_bytes);
        let created_at = SystemTime::now();
        Self {
            cell_id,
            key,
            generation,
            created_at,
        }
    }

    /// Encrypts `plaintext` and labels the result with this key's cell ID and generation.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedCellMessage, SecurityError> {
        self.key
            .encrypt(plaintext)
            .map(|encrypted| EncryptedCellMessage {
                cell_id: self.cell_id.clone(),
                generation: self.generation,
                encrypted,
            })
    }

    /// Decrypts `message` after checking that it targets this cell and generation.
    ///
    /// # Errors
    /// Returns `SecurityError::DecryptionError` on a cell ID mismatch, on a
    /// generation mismatch (checked in that order), or when decryption itself fails.
    pub fn decrypt(&self, message: &EncryptedCellMessage) -> Result<Vec<u8>, SecurityError> {
        let mismatch = if message.cell_id != self.cell_id {
            Some(format!(
                "cell ID mismatch: expected {}, got {}",
                self.cell_id, message.cell_id
            ))
        } else if message.generation != self.generation {
            Some(format!(
                "key generation mismatch: expected {}, got {}",
                self.generation, message.generation
            ))
        } else {
            None
        };
        match mismatch {
            Some(reason) => Err(SecurityError::DecryptionError(reason)),
            None => self.key.decrypt(&message.encrypted),
        }
    }

    /// Produces the successor key: same cell, fresh random bytes, generation + 1.
    pub fn rotate(&self) -> Self {
        Self {
            cell_id: self.cell_id.clone(),
            key: Self::random_symmetric_key(),
            generation: self.generation + 1,
            created_at: SystemTime::now(),
        }
    }

    /// Returns a copy of the raw key bytes.
    pub fn key_bytes(&self) -> [u8; SYMMETRIC_KEY_SIZE] {
        let raw: &[u8; SYMMETRIC_KEY_SIZE] = self.key.as_bytes();
        *raw
    }
}
/// Debug output shows the cell ID and generation but never the key material.
impl std::fmt::Debug for GroupKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("GroupKey");
        out.field("cell_id", &self.cell_id);
        out.field("generation", &self.generation);
        out.field("key", &"[REDACTED]");
        out.finish()
    }
}
#[derive(Debug, Clone)]
pub struct EncryptedCellMessage {
pub cell_id: String,
pub generation: u64,
pub encrypted: EncryptedData,
}
impl EncryptedCellMessage {
pub fn to_bytes(&self) -> Vec<u8> {
let cell_id_bytes = self.cell_id.as_bytes();
let encrypted_bytes = self.encrypted.to_bytes();
let mut bytes = Vec::new();
bytes.extend_from_slice(&(cell_id_bytes.len() as u32).to_le_bytes());
bytes.extend_from_slice(cell_id_bytes);
bytes.extend_from_slice(&self.generation.to_le_bytes());
bytes.extend_from_slice(&encrypted_bytes);
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, SecurityError> {
if bytes.len() < 12 {
return Err(SecurityError::DecryptionError(
"message too short".to_string(),
));
}
let cell_id_len = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize;
if bytes.len() < 4 + cell_id_len + 8 {
return Err(SecurityError::DecryptionError(
"message truncated".to_string(),
));
}
let cell_id = String::from_utf8(bytes[4..4 + cell_id_len].to_vec())
.map_err(|e| SecurityError::DecryptionError(e.to_string()))?;
let generation = u64::from_le_bytes(
bytes[4 + cell_id_len..4 + cell_id_len + 8]
.try_into()
.unwrap(),
);
let encrypted = EncryptedData::from_bytes(&bytes[4 + cell_id_len + 8..])?;
Ok(Self {
cell_id,
generation,
encrypted,
})
}
}
/// Encrypted document payload plus metadata about who encrypted it and when.
#[derive(Debug, Clone)]
pub struct EncryptedDocument {
    pub encrypted: EncryptedData,
    pub encrypted_by: DeviceId,
    /// Seconds since the Unix epoch at construction time.
    pub encrypted_at: u64,
}

impl EncryptedDocument {
    /// Wraps `encrypted`, recording `device_id` and the current Unix time.
    ///
    /// `encrypted_at` falls back to 0 if the system clock reads before the epoch.
    pub fn new(encrypted: EncryptedData, device_id: DeviceId) -> Self {
        let encrypted_at = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };
        Self {
            encrypted,
            encrypted_by: device_id,
            encrypted_at,
        }
    }
}
/// Owns this device's encryption state: its X25519 keypair, the per-peer
/// secure channels, the per-cell group keys and the device-local document key.
///
/// Channels and cell keys sit behind async `RwLock`s, so all methods take `&self`.
pub struct EncryptionManager {
    keypair: EncryptionKeypair,
    device_id: DeviceId,
    peer_channels: Arc<RwLock<HashMap<DeviceId, SecureChannel>>>,
    cell_keys: Arc<RwLock<HashMap<String, GroupKey>>>,
    device_key: SymmetricKey,
}

impl EncryptionManager {
    /// Creates a manager for `device_id` with no channels and no cell keys.
    ///
    /// The document key is derived with HKDF-SHA256 from the keypair's
    /// *secret* key. It used to be derived from the public key, which let
    /// anyone who knew the public key recompute it and decrypt documents.
    ///
    /// Compatibility: documents encrypted under the old public-key-derived key
    /// will not decrypt with this key and must be re-encrypted.
    pub fn new(keypair: EncryptionKeypair, device_id: DeviceId) -> Self {
        // NOTE: this is a temporary stack copy of the secret; it is not zeroized.
        let secret_bytes = keypair.secret.to_bytes();
        let hk = Hkdf::<Sha256>::new(None, &secret_bytes);
        let mut device_key_bytes = [0u8; SYMMETRIC_KEY_SIZE];
        hk.expand(b"peat-protocol-v1-device", &mut device_key_bytes)
            .expect("HKDF expand should never fail");
        Self {
            keypair,
            device_id,
            peer_channels: Arc::new(RwLock::new(HashMap::new())),
            cell_keys: Arc::new(RwLock::new(HashMap::new())),
            device_key: SymmetricKey::from_bytes(&device_key_bytes),
        }
    }

    /// Borrows this device's public key.
    pub fn public_key(&self) -> &PublicKey {
        self.keypair.public_key()
    }

    /// Returns this device's public key as raw bytes.
    pub fn public_key_bytes(&self) -> [u8; 32] {
        self.keypair.public_key_bytes()
    }

    /// Performs X25519 with `peer_public_key`, derives the peer channel key and
    /// stores the channel, replacing any existing channel for `peer_id`.
    ///
    /// # Errors
    /// Returns `SecurityError::EncryptionError` when the exchange is
    /// non-contributory (the peer supplied a low-order point, which forces an
    /// all-zero shared secret and therefore a predictable channel key). No
    /// channel is stored in that case.
    pub async fn establish_channel(
        &self,
        peer_id: DeviceId,
        peer_public_key: &[u8; X25519_PUBLIC_KEY_SIZE],
    ) -> Result<(), SecurityError> {
        let peer_public = PublicKey::from(*peer_public_key);
        let shared_secret = self.keypair.dh_exchange(&peer_public);
        if !shared_secret.was_contributory() {
            return Err(SecurityError::EncryptionError(format!(
                "rejected non-contributory public key from peer: {}",
                peer_id
            )));
        }
        let symmetric_key = SymmetricKey::derive_for_peer(&shared_secret);
        let channel = SecureChannel::new(peer_id, symmetric_key);
        self.peer_channels.write().await.insert(peer_id, channel);
        Ok(())
    }

    /// Returns a copy of the channel for `peer_id`, if one exists.
    ///
    /// The copy keeps the original `established_at` timestamp.
    pub async fn get_channel(&self, peer_id: &DeviceId) -> Option<SecureChannel> {
        let channels = self.peer_channels.read().await;
        // `SecureChannel` is not `Clone`, so the copy is assembled field by field.
        channels.get(peer_id).map(|c| SecureChannel {
            peer_id: c.peer_id,
            key: c.key.clone(),
            established_at: c.established_at,
        })
    }

    /// Reports whether a channel to `peer_id` exists.
    pub async fn has_channel(&self, peer_id: &DeviceId) -> bool {
        self.peer_channels.read().await.contains_key(peer_id)
    }

    /// Drops the channel to `peer_id`; does nothing if there is none.
    pub async fn remove_channel(&self, peer_id: &DeviceId) {
        self.peer_channels.write().await.remove(peer_id);
    }

    /// Encrypts `plaintext` with the channel key for `peer_id`.
    ///
    /// # Errors
    /// Returns `SecurityError::EncryptionError` when no channel exists for the
    /// peer or when encryption fails.
    pub async fn encrypt_for_peer(
        &self,
        peer_id: &DeviceId,
        plaintext: &[u8],
    ) -> Result<EncryptedData, SecurityError> {
        let channels = self.peer_channels.read().await;
        let channel = channels.get(peer_id).ok_or_else(|| {
            SecurityError::EncryptionError(format!("no channel for peer: {}", peer_id))
        })?;
        channel.encrypt(plaintext)
    }

    /// Decrypts `encrypted` with the channel key for `peer_id`.
    ///
    /// # Errors
    /// Returns `SecurityError::DecryptionError` when no channel exists for the
    /// peer or when decryption fails.
    pub async fn decrypt_from_peer(
        &self,
        peer_id: &DeviceId,
        encrypted: &EncryptedData,
    ) -> Result<Vec<u8>, SecurityError> {
        let channels = self.peer_channels.read().await;
        let channel = channels.get(peer_id).ok_or_else(|| {
            SecurityError::DecryptionError(format!("no channel for peer: {}", peer_id))
        })?;
        channel.decrypt(encrypted)
    }

    /// Returns the stored key for `cell_id`, generating and storing a fresh
    /// generation-1 key first if the cell has none.
    ///
    /// The write lock is held for the whole lookup-or-insert, so concurrent
    /// callers observe the same key.
    pub async fn get_or_create_cell_key(&self, cell_id: &str) -> GroupKey {
        let mut keys = self.cell_keys.write().await;
        if let Some(existing) = keys.get(cell_id) {
            return existing.clone();
        }
        let key = GroupKey::generate(cell_id.to_string());
        keys.insert(cell_id.to_string(), key.clone());
        key
    }

    /// Stores `key` under its own `cell_id`, replacing any existing key for that cell.
    pub async fn set_cell_key(&self, key: GroupKey) {
        self.cell_keys
            .write()
            .await
            .insert(key.cell_id.clone(), key);
    }

    /// Returns a copy of the stored key for `cell_id`, if any.
    pub async fn get_cell_key(&self, cell_id: &str) -> Option<GroupKey> {
        self.cell_keys.read().await.get(cell_id).cloned()
    }

    /// Replaces the key for `cell_id` with its rotated successor and returns it.
    ///
    /// # Errors
    /// Returns `SecurityError::EncryptionError` when the cell has no key.
    pub async fn rotate_cell_key(&self, cell_id: &str) -> Result<GroupKey, SecurityError> {
        let mut keys = self.cell_keys.write().await;
        let old_key = keys.get(cell_id).ok_or_else(|| {
            SecurityError::EncryptionError(format!("no key for cell: {}", cell_id))
        })?;
        let new_key = old_key.rotate();
        keys.insert(cell_id.to_string(), new_key.clone());
        Ok(new_key)
    }

    /// Forgets the key for `cell_id`; does nothing if there is none.
    pub async fn remove_cell_key(&self, cell_id: &str) {
        self.cell_keys.write().await.remove(cell_id);
    }

    /// Encrypts `plaintext` with the current key for `cell_id`.
    ///
    /// # Errors
    /// Returns `SecurityError::EncryptionError` when the cell has no key or
    /// when encryption fails.
    pub async fn encrypt_for_cell(
        &self,
        cell_id: &str,
        plaintext: &[u8],
    ) -> Result<EncryptedCellMessage, SecurityError> {
        let keys = self.cell_keys.read().await;
        let key = keys.get(cell_id).ok_or_else(|| {
            SecurityError::EncryptionError(format!("no key for cell: {}", cell_id))
        })?;
        key.encrypt(plaintext)
    }

    /// Decrypts `message` with the stored key for the cell it names.
    ///
    /// # Errors
    /// Returns `SecurityError::DecryptionError` when the cell has no key, or
    /// when `GroupKey::decrypt` rejects the message (which includes a
    /// generation different from the stored key's).
    pub async fn decrypt_cell_message(
        &self,
        message: &EncryptedCellMessage,
    ) -> Result<Vec<u8>, SecurityError> {
        let keys = self.cell_keys.read().await;
        let key = keys.get(&message.cell_id).ok_or_else(|| {
            SecurityError::DecryptionError(format!("no key for cell: {}", message.cell_id))
        })?;
        key.decrypt(message)
    }

    /// Encrypts `plaintext` with the device-local document key and tags the
    /// result with this device's ID.
    pub fn encrypt_document(&self, plaintext: &[u8]) -> Result<EncryptedDocument, SecurityError> {
        let encrypted = self.device_key.encrypt(plaintext)?;
        Ok(EncryptedDocument::new(encrypted, self.device_id))
    }

    /// Decrypts a document previously produced by this device.
    ///
    /// # Errors
    /// Returns `SecurityError::DecryptionError` when the document is tagged
    /// with a different device ID or when decryption fails.
    pub fn decrypt_document(&self, document: &EncryptedDocument) -> Result<Vec<u8>, SecurityError> {
        if document.encrypted_by != self.device_id {
            return Err(SecurityError::DecryptionError(
                "document encrypted by different device".to_string(),
            ));
        }
        self.device_key.decrypt(&document.encrypted)
    }

    /// Number of peer channels currently stored.
    pub async fn peer_channel_count(&self) -> usize {
        self.peer_channels.read().await.len()
    }

    /// Number of cell keys currently stored.
    pub async fn cell_key_count(&self) -> usize {
        self.cell_keys.read().await.len()
    }
}
/// Debug output is limited to the device ID and the hex-encoded public key;
/// channels, cell keys and the document key are left out.
impl std::fmt::Debug for EncryptionManager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let public_key_hex = hex::encode(self.keypair.public_key_bytes());
        let mut out = f.debug_struct("EncryptionManager");
        out.field("device_id", &self.device_id);
        out.field("public_key", &public_key_hex);
        out.finish()
    }
}
use rand_core::RngCore;
// Unit tests for the keypair, symmetric key, channel, group key, wire
// formats and `EncryptionManager` defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Two generated keypairs have different public keys.
#[test]
fn test_keypair_generation() {
let kp1 = EncryptionKeypair::generate();
let kp2 = EncryptionKeypair::generate();
assert_ne!(kp1.public_key_bytes(), kp2.public_key_bytes());
}
// The same secret bytes always yield the same public key.
#[test]
fn test_keypair_from_bytes() {
let kp1 = EncryptionKeypair::generate();
let secret_bytes = [42u8; 32]; let kp2 = EncryptionKeypair::from_secret_bytes(&secret_bytes);
let kp3 = EncryptionKeypair::from_secret_bytes(&secret_bytes);
assert_eq!(kp2.public_key_bytes(), kp3.public_key_bytes());
assert_ne!(kp1.public_key_bytes(), kp2.public_key_bytes());
}
// Both sides of a DH exchange compute the same shared secret.
#[test]
fn test_dh_key_exchange() {
let alice = EncryptionKeypair::generate();
let bob = EncryptionKeypair::generate();
let alice_shared = alice.dh_exchange(bob.public_key());
let bob_shared = bob.dh_exchange(alice.public_key());
assert_eq!(alice_shared.as_bytes(), bob_shared.as_bytes());
}
// Encrypt followed by decrypt returns the original plaintext.
#[test]
fn test_symmetric_key_encrypt_decrypt() {
let key = SymmetricKey::from_bytes(&[42u8; 32]);
let plaintext = b"Hello, World!";
let encrypted = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&encrypted).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
}
// Each encryption uses a new nonce, so equal plaintexts give different ciphertexts.
#[test]
fn test_symmetric_key_different_nonces() {
let key = SymmetricKey::from_bytes(&[42u8; 32]);
let plaintext = b"Hello, World!";
let encrypted1 = key.encrypt(plaintext).unwrap();
let encrypted2 = key.encrypt(plaintext).unwrap();
assert_ne!(encrypted1.nonce, encrypted2.nonce);
assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
assert_eq!(key.decrypt(&encrypted1).unwrap(), plaintext);
assert_eq!(key.decrypt(&encrypted2).unwrap(), plaintext);
}
// Decrypting with a different key fails.
#[test]
fn test_wrong_key_decryption_fails() {
let key1 = SymmetricKey::from_bytes(&[42u8; 32]);
let key2 = SymmetricKey::from_bytes(&[43u8; 32]);
let plaintext = b"Hello, World!";
let encrypted = key1.encrypt(plaintext).unwrap();
let result = key2.decrypt(&encrypted);
assert!(result.is_err());
}
// `EncryptedData` survives a to_bytes/from_bytes round trip and still decrypts.
#[test]
fn test_encrypted_data_serialization() {
let key = SymmetricKey::from_bytes(&[42u8; 32]);
let plaintext = b"Hello, World!";
let encrypted = key.encrypt(plaintext).unwrap();
let bytes = encrypted.to_bytes();
let restored = EncryptedData::from_bytes(&bytes).unwrap();
assert_eq!(encrypted.nonce, restored.nonce);
assert_eq!(encrypted.ciphertext, restored.ciphertext);
let decrypted = key.decrypt(&restored).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
}
// Channels built from the two halves of a DH exchange can talk in both directions.
#[test]
fn test_secure_channel() {
let alice = EncryptionKeypair::generate();
let bob = EncryptionKeypair::generate();
let alice_shared = alice.dh_exchange(bob.public_key());
let bob_shared = bob.dh_exchange(alice.public_key());
let alice_key = SymmetricKey::derive_for_peer(&alice_shared);
let bob_key = SymmetricKey::derive_for_peer(&bob_shared);
let alice_id = DeviceId::from_bytes([1u8; 16]);
let bob_id = DeviceId::from_bytes([2u8; 16]);
let alice_channel = SecureChannel::new(bob_id, alice_key);
let bob_channel = SecureChannel::new(alice_id, bob_key);
let message = b"Secret message from Alice";
let encrypted = alice_channel.encrypt(message).unwrap();
let decrypted = bob_channel.decrypt(&encrypted).unwrap();
assert_eq!(message.as_slice(), decrypted.as_slice());
let reply = b"Reply from Bob";
let encrypted_reply = bob_channel.encrypt(reply).unwrap();
let decrypted_reply = alice_channel.decrypt(&encrypted_reply).unwrap();
assert_eq!(reply.as_slice(), decrypted_reply.as_slice());
}
// A generated group key round-trips a message and labels it with cell ID and generation 1.
#[test]
fn test_group_key() {
let key = GroupKey::generate("cell-1".to_string());
let plaintext = b"Broadcast message";
let encrypted = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&encrypted).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
assert_eq!(encrypted.cell_id, "cell-1");
assert_eq!(encrypted.generation, 1);
}
// Rotation bumps the generation, changes the key bytes, and locks out the old key.
#[test]
fn test_group_key_rotation() {
let key1 = GroupKey::generate("cell-1".to_string());
let key2 = key1.rotate();
assert_eq!(key1.cell_id, key2.cell_id);
assert_eq!(key1.generation + 1, key2.generation);
assert_ne!(key1.key_bytes(), key2.key_bytes());
let message = b"New message";
let encrypted = key2.encrypt(message).unwrap();
assert!(key1.decrypt(&encrypted).is_err());
let decrypted = key2.decrypt(&encrypted).unwrap();
assert_eq!(message.as_slice(), decrypted.as_slice());
}
// `EncryptedCellMessage` survives a to_bytes/from_bytes round trip and still decrypts.
#[test]
fn test_encrypted_cell_message_serialization() {
let key = GroupKey::generate("cell-1".to_string());
let plaintext = b"Cell broadcast";
let encrypted = key.encrypt(plaintext).unwrap();
let bytes = encrypted.to_bytes();
let restored = EncryptedCellMessage::from_bytes(&bytes).unwrap();
assert_eq!(encrypted.cell_id, restored.cell_id);
assert_eq!(encrypted.generation, restored.generation);
let decrypted = key.decrypt(&restored).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
}
// Two managers that exchange public keys can send an encrypted message to each other.
#[tokio::test]
async fn test_encryption_manager_peer_channels() {
let alice_kp = EncryptionKeypair::generate();
let bob_kp = EncryptionKeypair::generate();
let alice_id = DeviceId::from_bytes([1u8; 16]);
let bob_id = DeviceId::from_bytes([2u8; 16]);
let alice_mgr = EncryptionManager::new(alice_kp.clone(), alice_id);
let bob_mgr = EncryptionManager::new(bob_kp.clone(), bob_id);
alice_mgr
.establish_channel(bob_id, &bob_mgr.public_key_bytes())
.await
.unwrap();
bob_mgr
.establish_channel(alice_id, &alice_mgr.public_key_bytes())
.await
.unwrap();
assert!(alice_mgr.has_channel(&bob_id).await);
assert!(bob_mgr.has_channel(&alice_id).await);
let message = b"Hello Bob!";
let encrypted = alice_mgr.encrypt_for_peer(&bob_id, message).await.unwrap();
let decrypted = bob_mgr
.decrypt_from_peer(&alice_id, &encrypted)
.await
.unwrap();
assert_eq!(message.as_slice(), decrypted.as_slice());
}
// get_or_create is idempotent; cell encryption round-trips; rotation yields generation 2.
#[tokio::test]
async fn test_encryption_manager_cell_keys() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let key = mgr.get_or_create_cell_key("cell-1").await;
assert_eq!(key.cell_id, "cell-1");
assert_eq!(key.generation, 1);
let key2 = mgr.get_or_create_cell_key("cell-1").await;
assert_eq!(key.generation, key2.generation);
let message = b"Cell message";
let encrypted = mgr.encrypt_for_cell("cell-1", message).await.unwrap();
let decrypted = mgr.decrypt_cell_message(&encrypted).await.unwrap();
assert_eq!(message.as_slice(), decrypted.as_slice());
let new_key = mgr.rotate_cell_key("cell-1").await.unwrap();
assert_eq!(new_key.generation, 2);
}
// A document encrypted by a manager can be decrypted by the same manager.
#[test]
fn test_document_encryption() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let document = b"Sensitive document content";
let encrypted = mgr.encrypt_document(document).unwrap();
let decrypted = mgr.decrypt_document(&encrypted).unwrap();
assert_eq!(document.as_slice(), decrypted.as_slice());
}
// A manager for a different device refuses to decrypt the document.
#[test]
fn test_document_wrong_device_fails() {
let kp1 = EncryptionKeypair::generate();
let kp2 = EncryptionKeypair::generate();
let device_id1 = DeviceId::from_bytes([1u8; 16]);
let device_id2 = DeviceId::from_bytes([2u8; 16]);
let mgr1 = EncryptionManager::new(kp1, device_id1);
let mgr2 = EncryptionManager::new(kp2, device_id2);
let document = b"Sensitive document";
let encrypted = mgr1.encrypt_document(document).unwrap();
let result = mgr2.decrypt_document(&encrypted);
assert!(result.is_err());
}
// A channel that was just created reports an age of under two seconds.
#[test]
fn test_secure_channel_age_secs() {
let key = SymmetricKey::from_bytes(&[42u8; 32]);
let peer_id = DeviceId::from_bytes([1u8; 16]);
let channel = SecureChannel::new(peer_id, key);
let age = channel.age_secs();
assert!(age < 2, "Channel age should be near zero, got {}", age);
}
// Input shorter than a nonce is rejected with a "too short" error.
#[test]
fn test_encrypted_data_from_bytes_too_short() {
let short_data = vec![0u8; 5];
let result = EncryptedData::from_bytes(&short_data);
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("too short"));
}
// Input of exactly NONCE_SIZE bytes parses to a nonce with an empty ciphertext.
#[test]
fn test_encrypted_data_from_bytes_exact_nonce_size() {
let data = vec![0u8; NONCE_SIZE];
let result = EncryptedData::from_bytes(&data);
assert!(result.is_ok());
let ed = result.unwrap();
assert_eq!(ed.nonce, [0u8; NONCE_SIZE]);
assert!(ed.ciphertext.is_empty());
}
// A cell message shorter than the fixed header is rejected with "too short".
#[test]
fn test_encrypted_cell_message_from_bytes_too_short() {
let short_data = vec![0u8; 8];
let result = EncryptedCellMessage::from_bytes(&short_data);
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("too short"));
}
// A declared cell-ID length of 100 with only 8 bytes following is reported as "truncated".
#[test]
fn test_encrypted_cell_message_from_bytes_truncated() {
let mut data = Vec::new();
data.extend_from_slice(&100u32.to_le_bytes()); data.extend_from_slice(&[0u8; 8]); let result = EncryptedCellMessage::from_bytes(&data);
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("truncated"));
}
// A cell ID that is not valid UTF-8 is rejected even though all lengths are consistent.
#[test]
fn test_encrypted_cell_message_from_bytes_invalid_utf8() {
let mut data = Vec::new();
let bad_utf8 = [0xFF, 0xFE]; data.extend_from_slice(&(bad_utf8.len() as u32).to_le_bytes());
data.extend_from_slice(&bad_utf8);
data.extend_from_slice(&1u64.to_le_bytes()); data.extend_from_slice(&[0u8; NONCE_SIZE]);
let result = EncryptedCellMessage::from_bytes(&data);
assert!(result.is_err());
}
// Encrypting for a peer without a channel fails with a "no channel" error.
#[tokio::test]
async fn test_encryption_manager_encrypt_for_peer_no_channel() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let nonexistent_peer = DeviceId::from_bytes([99u8; 16]);
let result = mgr.encrypt_for_peer(&nonexistent_peer, b"hello").await;
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("no channel"));
}
// Decrypting from a peer without a channel fails with a "no channel" error.
#[tokio::test]
async fn test_encryption_manager_decrypt_from_peer_no_channel() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let nonexistent_peer = DeviceId::from_bytes([99u8; 16]);
let fake_encrypted = EncryptedData {
nonce: [0u8; NONCE_SIZE],
ciphertext: vec![1, 2, 3],
};
let result = mgr
.decrypt_from_peer(&nonexistent_peer, &fake_encrypted)
.await;
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("no channel"));
}
// Encrypting for a cell without a key fails with a "no key for cell" error.
#[tokio::test]
async fn test_encryption_manager_encrypt_for_cell_no_key() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let result = mgr.encrypt_for_cell("nonexistent-cell", b"data").await;
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("no key for cell"));
}
// Decrypting a message for a cell without a key fails with a "no key for cell" error.
#[tokio::test]
async fn test_encryption_manager_decrypt_cell_message_no_key() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let fake_message = EncryptedCellMessage {
cell_id: "nonexistent-cell".to_string(),
generation: 1,
encrypted: EncryptedData {
nonce: [0u8; NONCE_SIZE],
ciphertext: vec![1, 2, 3],
},
};
let result = mgr.decrypt_cell_message(&fake_message).await;
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("no key for cell"));
}
// Rotating the key of an unknown cell fails with a "no key for cell" error.
#[tokio::test]
async fn test_encryption_manager_rotate_cell_key_no_key() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let result = mgr.rotate_cell_key("nonexistent").await;
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("no key for cell"));
}
// After rotation the stored key is generation 2 and is the one used for new messages.
#[tokio::test]
async fn test_encryption_manager_rotate_cell_key_success() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let key1 = mgr.get_or_create_cell_key("cell-1").await;
assert_eq!(key1.generation, 1);
let key2 = mgr.rotate_cell_key("cell-1").await.unwrap();
assert_eq!(key2.generation, 2);
assert_eq!(key2.cell_id, "cell-1");
let stored = mgr.get_cell_key("cell-1").await.unwrap();
assert_eq!(stored.generation, 2);
let msg = b"test message";
let encrypted = mgr.encrypt_for_cell("cell-1", msg).await.unwrap();
assert_eq!(encrypted.generation, 2);
let decrypted = mgr.decrypt_cell_message(&encrypted).await.unwrap();
assert_eq!(decrypted, msg);
}
// Removing a channel makes has_channel false and drops the count back to zero.
#[tokio::test]
async fn test_encryption_manager_remove_channel() {
let alice_kp = EncryptionKeypair::generate();
let bob_kp = EncryptionKeypair::generate();
let alice_id = DeviceId::from_bytes([1u8; 16]);
let bob_id = DeviceId::from_bytes([2u8; 16]);
let alice_mgr = EncryptionManager::new(alice_kp, alice_id);
alice_mgr
.establish_channel(bob_id, &bob_kp.public_key_bytes())
.await
.unwrap();
assert!(alice_mgr.has_channel(&bob_id).await);
assert_eq!(alice_mgr.peer_channel_count().await, 1);
alice_mgr.remove_channel(&bob_id).await;
assert!(!alice_mgr.has_channel(&bob_id).await);
assert_eq!(alice_mgr.peer_channel_count().await, 0);
}
// Removing a cell key drops the count to zero and makes lookups return None.
#[tokio::test]
async fn test_encryption_manager_remove_cell_key() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
mgr.get_or_create_cell_key("cell-1").await;
assert_eq!(mgr.cell_key_count().await, 1);
mgr.remove_cell_key("cell-1").await;
assert_eq!(mgr.cell_key_count().await, 0);
assert!(mgr.get_cell_key("cell-1").await.is_none());
}
// A key installed with set_cell_key is returned unchanged by get_cell_key.
#[tokio::test]
async fn test_encryption_manager_set_cell_key() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let key = GroupKey::from_bytes("cell-99".to_string(), &[7u8; SYMMETRIC_KEY_SIZE], 5);
mgr.set_cell_key(key).await;
let stored = mgr.get_cell_key("cell-99").await;
assert!(stored.is_some());
let stored = stored.unwrap();
assert_eq!(stored.cell_id, "cell-99");
assert_eq!(stored.generation, 5);
}
// get_channel returns the channel for a known peer and None for an unknown one.
#[tokio::test]
async fn test_encryption_manager_get_channel() {
let alice_kp = EncryptionKeypair::generate();
let bob_kp = EncryptionKeypair::generate();
let alice_id = DeviceId::from_bytes([1u8; 16]);
let bob_id = DeviceId::from_bytes([2u8; 16]);
let mgr = EncryptionManager::new(alice_kp, alice_id);
mgr.establish_channel(bob_id, &bob_kp.public_key_bytes())
.await
.unwrap();
let channel = mgr.get_channel(&bob_id).await;
assert!(channel.is_some());
let channel = channel.unwrap();
assert_eq!(channel.peer_id, bob_id);
let missing = DeviceId::from_bytes([99u8; 16]);
assert!(mgr.get_channel(&missing).await.is_none());
}
// A message relabelled with another cell ID is rejected before decryption.
#[test]
fn test_group_key_decrypt_cell_id_mismatch() {
let key = GroupKey::generate("cell-1".to_string());
let encrypted = key.encrypt(b"test").unwrap();
let wrong_msg = EncryptedCellMessage {
cell_id: "cell-WRONG".to_string(),
generation: encrypted.generation,
encrypted: encrypted.encrypted.clone(),
};
let result = key.decrypt(&wrong_msg);
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("cell ID mismatch"));
}
// A message relabelled with another generation is rejected before decryption.
#[test]
fn test_group_key_decrypt_generation_mismatch() {
let key = GroupKey::generate("cell-1".to_string());
let encrypted = key.encrypt(b"test").unwrap();
let wrong_msg = EncryptedCellMessage {
cell_id: "cell-1".to_string(),
generation: 999,
encrypted: encrypted.encrypted.clone(),
};
let result = key.decrypt(&wrong_msg);
assert!(result.is_err());
let err_msg = format!("{}", result.unwrap_err());
assert!(err_msg.contains("key generation mismatch"));
}
// GroupKey::from_bytes keeps the given cell ID, generation and key bytes, and the key works.
#[test]
fn test_group_key_from_bytes() {
let key_bytes = [42u8; SYMMETRIC_KEY_SIZE];
let key = GroupKey::from_bytes("cell-x".to_string(), &key_bytes, 10);
assert_eq!(key.cell_id, "cell-x");
assert_eq!(key.generation, 10);
assert_eq!(key.key_bytes(), key_bytes);
let plaintext = b"from bytes key test";
let encrypted = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&encrypted).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
}
// Repeated rotation keeps the cell ID, counts generations 1, 2, 3 and changes the key each time.
#[test]
fn test_group_key_rotation_preserves_cell_id() {
let key1 = GroupKey::generate("my-cell".to_string());
let key2 = key1.rotate();
let key3 = key2.rotate();
assert_eq!(key1.cell_id, "my-cell");
assert_eq!(key2.cell_id, "my-cell");
assert_eq!(key3.cell_id, "my-cell");
assert_eq!(key1.generation, 1);
assert_eq!(key2.generation, 2);
assert_eq!(key3.generation, 3);
assert_ne!(key1.key_bytes(), key2.key_bytes());
assert_ne!(key2.key_bytes(), key3.key_bytes());
}
// Debug output of a keypair contains the redaction marker.
#[test]
fn test_keypair_debug_redacts_secret() {
let kp = EncryptionKeypair::generate();
let debug_str = format!("{:?}", kp);
assert!(debug_str.contains("REDACTED"));
assert!(debug_str.contains("[REDACTED]"));
}
// Debug output of a symmetric key contains the redaction marker.
#[test]
fn test_symmetric_key_debug_redacts() {
let key = SymmetricKey::from_bytes(&[42u8; 32]);
let debug_str = format!("{:?}", key);
assert!(debug_str.contains("REDACTED"));
}
// Debug output of a group key shows the cell ID and the redaction marker.
#[test]
fn test_group_key_debug_redacts() {
let key = GroupKey::generate("cell-1".to_string());
let debug_str = format!("{:?}", key);
assert!(debug_str.contains("REDACTED"));
assert!(debug_str.contains("cell-1"));
}
// Debug output of the manager names the type and its device_id and public_key fields.
#[test]
fn test_encryption_manager_debug() {
let kp = EncryptionKeypair::generate();
let device_id = DeviceId::from_bytes([1u8; 16]);
let mgr = EncryptionManager::new(kp, device_id);
let debug_str = format!("{:?}", mgr);
assert!(debug_str.contains("EncryptionManager"));
assert!(debug_str.contains("device_id"));
assert!(debug_str.contains("public_key"));
}
// EncryptedDocument::new records the device ID and a non-zero timestamp.
#[test]
fn test_encrypted_document_new() {
let key = SymmetricKey::from_bytes(&[42u8; 32]);
let encrypted = key.encrypt(b"doc data").unwrap();
let device_id = DeviceId::from_bytes([1u8; 16]);
let doc = EncryptedDocument::new(encrypted, device_id);
assert_eq!(doc.encrypted_by, device_id);
assert!(doc.encrypted_at > 0);
}
// A key derived from a DH shared secret can encrypt and decrypt.
#[test]
fn test_symmetric_key_derive_from_shared_secret() {
let alice = EncryptionKeypair::generate();
let bob = EncryptionKeypair::generate();
let shared = alice.dh_exchange(bob.public_key());
let key = SymmetricKey::derive_for_peer(&shared);
let plaintext = b"derived key test";
let encrypted = key.encrypt(plaintext).unwrap();
let decrypted = key.decrypt(&encrypted).unwrap();
assert_eq!(plaintext.as_slice(), decrypted.as_slice());
}
// as_bytes returns exactly the bytes the key was built from.
#[test]
fn test_symmetric_key_as_bytes_roundtrip() {
let original_bytes = [99u8; SYMMETRIC_KEY_SIZE];
let key = SymmetricKey::from_bytes(&original_bytes);
let extracted = key.as_bytes();
assert_eq!(&original_bytes, extracted);
}
}