use crate::crdt::TaskListDelta;
use crate::mls::{MlsCipher, MlsError, MlsGroup, MlsKeySchedule, Result as MlsResult};
use serde::{Deserialize, Serialize};
/// A [`TaskListDelta`] encrypted for one MLS group at one epoch.
///
/// Instances are produced by `EncryptedTaskListDelta::encrypt` /
/// `encrypt_with_group` and carry the metadata needed to pick the right key
/// when decrypting.
///
/// NOTE(review): this type is serialized with serde (bincode in the tests).
/// Positional formats such as bincode depend on field order, so treat the
/// field order and types below as part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedTaskListDelta {
    /// ID of the MLS group (`group.context().group_id()`) at encryption time.
    group_id: Vec<u8>,
    /// Group epoch (`group.context().epoch()`) at encryption time.
    epoch: u64,
    /// Output of `MlsCipher::encrypt` over the bincode-serialized delta.
    ciphertext: Vec<u8>,
    /// Associated data passed to the cipher; built by `encrypt` as
    /// `b"EncryptedDelta" || group_id || epoch.to_le_bytes()`.
    aad: Vec<u8>,
}
impl EncryptedTaskListDelta {
    /// Domain-separation label that prefixes every AAD built by [`Self::build_aad`].
    const AAD_LABEL: &'static [u8] = b"EncryptedDelta";

    /// Builds the associated data that binds a ciphertext to a group and epoch.
    ///
    /// Layout: `b"EncryptedDelta" || group_id || epoch.to_le_bytes()`.
    /// Both the label and the epoch are fixed-width, so the variable-length
    /// `group_id` in the middle is unambiguously delimited.
    fn build_aad(group_id: &[u8], epoch: u64) -> Vec<u8> {
        let mut aad = Vec::with_capacity(
            Self::AAD_LABEL.len() + group_id.len() + std::mem::size_of::<u64>(),
        );
        aad.extend_from_slice(Self::AAD_LABEL);
        aad.extend_from_slice(group_id);
        aad.extend_from_slice(&epoch.to_le_bytes());
        aad
    }

    /// Derives the cipher for the group's current epoch from its key schedule.
    fn cipher_for_group(group: &MlsGroup) -> MlsResult<MlsCipher> {
        let key_schedule = MlsKeySchedule::from_group(group)?;
        Ok(MlsCipher::new(
            key_schedule.encryption_key().to_vec(),
            key_schedule.base_nonce().to_vec(),
        ))
    }

    /// Serializes `delta` with bincode and encrypts it under `cipher`.
    ///
    /// The group ID and epoch are read from `group.context()`. They are stored
    /// alongside the ciphertext and also bound to it through the AAD, so
    /// altering them makes [`Self::decrypt`] fail.
    ///
    /// NOTE(review): every call passes `0` as the last argument of
    /// `MlsCipher::encrypt`. If that argument is the per-message nonce counter,
    /// as its position suggests, then every delta encrypted under one epoch's
    /// key reuses the same nonce. That is unsafe for AEAD ciphers. Fixing it
    /// requires carrying a per-message counter in the wire format. TODO: confirm
    /// the `MlsCipher` semantics.
    ///
    /// # Errors
    ///
    /// Returns [`MlsError::EncryptionError`] if `delta` cannot be serialized,
    /// and propagates any error from `cipher.encrypt`.
    pub fn encrypt(delta: &TaskListDelta, group: &MlsGroup, cipher: &MlsCipher) -> MlsResult<Self> {
        let context = group.context();
        let group_id = context.group_id().to_vec();
        let epoch = context.epoch();
        let plaintext = bincode::serialize(delta)
            .map_err(|e| MlsError::EncryptionError(format!("delta serialization failed: {}", e)))?;
        let aad = Self::build_aad(&group_id, epoch);
        let ciphertext = cipher.encrypt(&plaintext, &aad, 0)?;
        Ok(Self {
            group_id,
            epoch,
            ciphertext,
            aad,
        })
    }

    /// Decrypts and deserializes the delta using an explicitly supplied cipher.
    ///
    /// The AAD is recomputed from the stored `group_id` and `epoch` instead of
    /// being trusted from the wire. A stored `aad` that disagrees with those
    /// fields is rejected. Because the recomputed AAD is what the cipher
    /// authenticates, tampering with `group_id` or `epoch` also makes
    /// decryption fail.
    ///
    /// This does not check that `cipher` belongs to the stored group and epoch;
    /// use [`Self::decrypt_with_group`] for that.
    ///
    /// # Errors
    ///
    /// Returns [`MlsError::DecryptionError`] if the stored AAD does not match
    /// the metadata or the plaintext cannot be deserialized. Propagates any
    /// error from `cipher.decrypt`, such as an authentication failure.
    pub fn decrypt(&self, cipher: &MlsCipher) -> MlsResult<TaskListDelta> {
        let expected_aad = Self::build_aad(&self.group_id, self.epoch);
        if self.aad != expected_aad {
            return Err(MlsError::DecryptionError(
                "AAD does not match the delta's group ID and epoch".to_string(),
            ));
        }
        let plaintext = cipher.decrypt(&self.ciphertext, &expected_aad, 0)?;
        bincode::deserialize(&plaintext)
            .map_err(|e| MlsError::DecryptionError(format!("delta deserialization failed: {}", e)))
    }

    /// Encrypts `delta` with a cipher derived from `group`'s current key schedule.
    ///
    /// # Errors
    ///
    /// Propagates key-schedule derivation errors and any error from
    /// [`Self::encrypt`].
    pub fn encrypt_with_group(delta: &TaskListDelta, group: &MlsGroup) -> MlsResult<Self> {
        let cipher = Self::cipher_for_group(group)?;
        Self::encrypt(delta, group, &cipher)
    }

    /// Decrypts using `group`'s current key schedule, after checking that the
    /// delta was produced for this group at its current epoch.
    ///
    /// # Errors
    ///
    /// - [`MlsError::EpochMismatch`] if the group's epoch differs from the stored one.
    /// - [`MlsError::MlsOperation`] if the group ID differs from the stored one.
    /// - Key-schedule errors, plus any error from [`Self::decrypt`].
    pub fn decrypt_with_group(&self, group: &MlsGroup) -> MlsResult<TaskListDelta> {
        let context = group.context();
        if context.epoch() != self.epoch {
            return Err(MlsError::EpochMismatch {
                current: context.epoch(),
                received: self.epoch,
            });
        }
        if context.group_id() != self.group_id {
            return Err(MlsError::MlsOperation(format!(
                "group ID mismatch: expected {:?}, got {:?}",
                context.group_id(),
                self.group_id
            )));
        }
        let cipher = Self::cipher_for_group(group)?;
        self.decrypt(&cipher)
    }

    /// ID of the group the delta was encrypted for.
    #[must_use]
    pub fn group_id(&self) -> &[u8] {
        &self.group_id
    }

    /// Group epoch at encryption time.
    #[must_use]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Raw ciphertext bytes.
    #[must_use]
    pub fn ciphertext(&self) -> &[u8] {
        &self.ciphertext
    }

    /// Associated data that was passed to the cipher.
    #[must_use]
    pub fn aad(&self) -> &[u8] {
        &self.aad
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip, metadata-binding and tamper-detection tests for
    //! [`EncryptedTaskListDelta`].

    use super::*;
    use crate::crdt::{TaskId, TaskItem, TaskMetadata};
    use crate::identity::Identity;
    use crate::mls::MlsGroup;
    use saorsa_gossip_types::PeerId;

    /// Creates a fresh group with a newly generated identity and returns it
    /// together with its group ID.
    async fn create_test_group() -> (MlsGroup, Vec<u8>) {
        let identity = Identity::generate().expect("identity generation failed");
        let agent_id = identity.agent_id();
        let group_id = b"test-encryption-group".to_vec();
        let group = MlsGroup::new(group_id.clone(), agent_id)
            .await
            .expect("group creation failed");
        (group, group_id)
    }

    /// Builds a version-1 delta containing a single added task.
    fn create_test_delta() -> TaskListDelta {
        let mut delta = TaskListDelta::new(1);
        let identity = Identity::generate().expect("identity generation failed");
        let agent_id = identity.agent_id();
        let timestamp = 1000;
        let task_id = TaskId::new("Test task", &agent_id, timestamp);
        let metadata = TaskMetadata {
            title: "Test task".to_string(),
            description: "Test description".to_string(),
            priority: 128,
            created_by: agent_id,
            owner: None,
            created_at: timestamp,
            tags: vec![],
        };
        let peer_id = PeerId::new(*agent_id.as_bytes());
        let task = TaskItem::new(task_id, metadata, peer_id);
        let tag = (peer_id, 1);
        delta.added_tasks.insert(task_id, (task, tag));
        delta
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_roundtrip() {
        let (group, _group_id) = create_test_group().await;
        let delta = create_test_delta();
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        assert_eq!(encrypted.group_id(), group.context().group_id());
        assert_eq!(encrypted.epoch(), group.current_epoch());
        assert!(!encrypted.ciphertext().is_empty());
        let decrypted = encrypted
            .decrypt_with_group(&group)
            .expect("decryption failed");
        assert_eq!(decrypted.version, delta.version);
        assert_eq!(decrypted.added_tasks.len(), delta.added_tasks.len());
    }

    #[tokio::test]
    async fn test_encrypted_delta_includes_group_metadata() {
        let (group, group_id) = create_test_group().await;
        let delta = create_test_delta();
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        assert_eq!(encrypted.group_id(), &group_id);
        assert_eq!(encrypted.epoch(), 0);
    }

    #[tokio::test]
    async fn test_decryption_fails_with_wrong_epoch() {
        let (mut group, _) = create_test_group().await;
        let delta = create_test_delta();
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let commit = group.commit().expect("commit failed");
        group.apply_commit(&commit).expect("apply failed");
        let result = encrypted.decrypt_with_group(&group);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            MlsError::EpochMismatch { .. }
        ));
    }

    #[tokio::test]
    async fn test_decryption_fails_with_wrong_group() {
        let identity1 = Identity::generate().expect("identity generation failed");
        let agent_id1 = identity1.agent_id();
        let group_id1 = b"test-group-1".to_vec();
        let group1 = MlsGroup::new(group_id1, agent_id1)
            .await
            .expect("group creation failed");

        let identity2 = Identity::generate().expect("identity generation failed");
        let agent_id2 = identity2.agent_id();
        let group_id2 = b"test-group-2".to_vec();
        let group2 = MlsGroup::new(group_id2, agent_id2)
            .await
            .expect("group creation failed");

        let delta = create_test_delta();
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group1).expect("encryption failed");
        let result = encrypted.decrypt_with_group(&group2);
        assert!(result.is_err());
        match result.unwrap_err() {
            MlsError::MlsOperation(msg) => assert!(msg.contains("group ID mismatch")),
            _ => panic!("Expected MlsOperation error for group ID mismatch"),
        }
    }

    #[tokio::test]
    async fn test_authentication_prevents_tampering() {
        let (group, _) = create_test_group().await;
        let delta = create_test_delta();
        let mut encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        encrypted.ciphertext[0] ^= 1;
        let result = encrypted.decrypt_with_group(&group);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), MlsError::DecryptionError(_)));
    }

    /// Flipping a bit in the carried AAD must make decryption fail rather than
    /// silently succeed.
    #[tokio::test]
    async fn test_tampered_aad_is_rejected() {
        let (group, _) = create_test_group().await;
        let delta = create_test_delta();
        let mut encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let last = encrypted.aad.len() - 1;
        encrypted.aad[last] ^= 1;
        let result = encrypted.decrypt_with_group(&group);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
    }

    #[tokio::test]
    async fn test_different_epochs_produce_different_ciphertexts() {
        let (mut group, _) = create_test_group().await;
        let delta = create_test_delta();
        let encrypted1 =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let commit = group.commit().expect("commit failed");
        group.apply_commit(&commit).expect("apply failed");
        let encrypted2 =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        assert_ne!(encrypted1.ciphertext(), encrypted2.ciphertext());
        assert_ne!(encrypted1.epoch(), encrypted2.epoch());
    }

    #[tokio::test]
    async fn test_empty_delta_encryption() {
        let (group, _) = create_test_group().await;
        let delta = TaskListDelta::new(1);
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let decrypted = encrypted
            .decrypt_with_group(&group)
            .expect("decryption failed");
        assert_eq!(decrypted.version, delta.version);
        assert!(decrypted.added_tasks.is_empty());
    }

    #[tokio::test]
    async fn test_large_delta_encryption() {
        let (group, _) = create_test_group().await;
        let mut delta = TaskListDelta::new(1);
        let identity = Identity::generate().expect("identity generation failed");
        let agent_id = identity.agent_id();
        let peer_id = PeerId::new(*agent_id.as_bytes());
        for i in 0..100 {
            let task_id = TaskId::new(&format!("Task {}", i), &agent_id, 1000 + i);
            let metadata = TaskMetadata {
                title: format!("Task {}", i),
                description: format!("Description {}", i),
                priority: 128,
                created_by: agent_id,
                owner: None,
                created_at: 1000 + i,
                tags: vec![],
            };
            let task = TaskItem::new(task_id, metadata, peer_id);
            let tag = (peer_id, i);
            delta.added_tasks.insert(task_id, (task, tag));
        }
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let decrypted = encrypted
            .decrypt_with_group(&group)
            .expect("decryption failed");
        assert_eq!(decrypted.added_tasks.len(), 100);
    }

    /// A bincode round-trip must preserve every field and remain decryptable.
    #[tokio::test]
    async fn test_encrypted_delta_serialization() {
        let (group, _) = create_test_group().await;
        let delta = create_test_delta();
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let serialized = bincode::serialize(&encrypted).expect("serialization failed");
        let deserialized: EncryptedTaskListDelta =
            bincode::deserialize(&serialized).expect("deserialization failed");
        assert_eq!(deserialized.group_id(), encrypted.group_id());
        assert_eq!(deserialized.epoch(), encrypted.epoch());
        assert_eq!(deserialized.ciphertext(), encrypted.ciphertext());
        assert_eq!(deserialized.aad(), encrypted.aad());
        let decrypted = deserialized
            .decrypt_with_group(&group)
            .expect("decryption after round-trip failed");
        assert_eq!(decrypted.version, delta.version);
    }

    /// The AAD must be exactly `b"EncryptedDelta" || group_id || epoch (LE)`.
    #[tokio::test]
    async fn test_aad_includes_group_and_epoch() {
        let (group, group_id) = create_test_group().await;
        let delta = create_test_delta();
        let encrypted =
            EncryptedTaskListDelta::encrypt_with_group(&delta, &group).expect("encryption failed");
        let mut expected = b"EncryptedDelta".to_vec();
        expected.extend_from_slice(&group_id);
        expected.extend_from_slice(&encrypted.epoch().to_le_bytes());
        assert_eq!(encrypted.aad(), expected.as_slice());
    }
}