1use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::{SystemTime, UNIX_EPOCH};
19
20use chacha20poly1305::{
21 aead::{Aead, AeadCore, KeyInit, OsRng},
22 ChaCha20Poly1305, Key, Nonce,
23};
24use hkdf::Hkdf;
25use sha2::Sha256;
26use tokio::sync::RwLock;
27use x25519_dalek::{PublicKey, SharedSecret, StaticSecret};
28
29use super::error::SecurityError;
30use super::DeviceId;
31
/// Size in bytes of the nonce stored in [`EncryptedData`].
pub const NONCE_SIZE: usize = 12;

/// Size in bytes of the raw key material accepted by [`SymmetricKey::from_bytes`].
pub const SYMMETRIC_KEY_SIZE: usize = 32;

/// Size in bytes of a serialized X25519 public key.
pub const X25519_PUBLIC_KEY_SIZE: usize = 32;

/// HKDF `info` label used when deriving a peer-to-peer key from a DH shared secret.
const HKDF_INFO_PEER: &[u8] = b"peat-protocol-v1-peer";

// Nothing visible in this file references this label, hence the lint suppression.
// NOTE(review): presumably reserved for HKDF-derived group keys — confirm before removing.
#[allow(dead_code)]
const HKDF_INFO_GROUP: &[u8] = b"peat-protocol-v1-group";
47
/// An X25519 keypair used for Diffie-Hellman key agreement.
///
/// The secret half sits behind an [`Arc`], so cloning the keypair shares the
/// same secret allocation instead of duplicating the key material.
#[derive(Clone)]
pub struct EncryptionKeypair {
    /// Static X25519 secret; never exposed through the public API of this type.
    secret: Arc<StaticSecret>,
    /// Public key computed from `secret` when the keypair is constructed.
    public: PublicKey,
}
56
57impl EncryptionKeypair {
58 pub fn generate() -> Self {
60 let secret = StaticSecret::random_from_rng(OsRng);
61 let public = PublicKey::from(&secret);
62 Self {
63 secret: Arc::new(secret),
64 public,
65 }
66 }
67
68 pub fn from_secret_bytes(bytes: &[u8; 32]) -> Self {
70 let secret = StaticSecret::from(*bytes);
71 let public = PublicKey::from(&secret);
72 Self {
73 secret: Arc::new(secret),
74 public,
75 }
76 }
77
78 pub fn public_key(&self) -> &PublicKey {
80 &self.public
81 }
82
83 pub fn public_key_bytes(&self) -> [u8; 32] {
85 self.public.to_bytes()
86 }
87
88 pub fn dh_exchange(&self, peer_public: &PublicKey) -> SharedSecret {
90 self.secret.diffie_hellman(peer_public)
91 }
92}
93
94impl std::fmt::Debug for EncryptionKeypair {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 f.debug_struct("EncryptionKeypair")
97 .field("public", &hex::encode(self.public.as_bytes()))
98 .field("secret", &"[REDACTED]")
99 .finish()
100 }
101}
102
/// A 32-byte symmetric key for ChaCha20-Poly1305 authenticated encryption.
#[derive(Clone)]
pub struct SymmetricKey {
    /// Raw cipher key; kept private so it only leaves through `as_bytes`.
    key: Key,
}
108
109impl SymmetricKey {
110 pub fn from_bytes(bytes: &[u8; SYMMETRIC_KEY_SIZE]) -> Self {
112 Self {
113 key: *Key::from_slice(bytes),
114 }
115 }
116
117 pub fn derive_from_shared_secret(shared_secret: &SharedSecret, info: &[u8]) -> Self {
119 let hk = Hkdf::<Sha256>::new(None, shared_secret.as_bytes());
120 let mut key_bytes = [0u8; SYMMETRIC_KEY_SIZE];
121 hk.expand(info, &mut key_bytes)
122 .expect("HKDF expand should never fail with correct output length");
123 Self::from_bytes(&key_bytes)
124 }
125
126 pub fn derive_for_peer(shared_secret: &SharedSecret) -> Self {
128 Self::derive_from_shared_secret(shared_secret, HKDF_INFO_PEER)
129 }
130
131 pub fn as_bytes(&self) -> &[u8; SYMMETRIC_KEY_SIZE] {
133 self.key[..].try_into().unwrap()
135 }
136
137 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, SecurityError> {
139 let cipher = ChaCha20Poly1305::new(&self.key);
140 let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
141
142 let ciphertext = cipher
143 .encrypt(&nonce, plaintext)
144 .map_err(|e| SecurityError::EncryptionError(e.to_string()))?;
145
146 Ok(EncryptedData {
147 nonce: nonce.into(),
148 ciphertext,
149 })
150 }
151
152 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, SecurityError> {
154 let cipher = ChaCha20Poly1305::new(&self.key);
155 let nonce = Nonce::from_slice(&encrypted.nonce);
156
157 cipher
158 .decrypt(nonce, encrypted.ciphertext.as_ref())
159 .map_err(|e| SecurityError::DecryptionError(e.to_string()))
160 }
161}
162
163impl std::fmt::Debug for SymmetricKey {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("SymmetricKey")
166 .field("key", &"[REDACTED]")
167 .finish()
168 }
169}
170
/// Output of one symmetric encryption: the nonce plus the cipher's output bytes.
#[derive(Debug, Clone)]
pub struct EncryptedData {
    /// Nonce that was used for this encryption.
    pub nonce: [u8; NONCE_SIZE],
    /// Bytes returned by the AEAD cipher for this message.
    pub ciphertext: Vec<u8>,
}
179
180impl EncryptedData {
181 pub fn to_bytes(&self) -> Vec<u8> {
183 let mut bytes = Vec::with_capacity(NONCE_SIZE + self.ciphertext.len());
184 bytes.extend_from_slice(&self.nonce);
185 bytes.extend_from_slice(&self.ciphertext);
186 bytes
187 }
188
189 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SecurityError> {
191 if bytes.len() < NONCE_SIZE {
192 return Err(SecurityError::DecryptionError(
193 "ciphertext too short for nonce".to_string(),
194 ));
195 }
196
197 let mut nonce = [0u8; NONCE_SIZE];
198 nonce.copy_from_slice(&bytes[..NONCE_SIZE]);
199 let ciphertext = bytes[NONCE_SIZE..].to_vec();
200
201 Ok(Self { nonce, ciphertext })
202 }
203}
204
/// A symmetric-key channel to a single peer.
///
/// `Debug` is derived; the key field prints through `SymmetricKey`'s own
/// `Debug` implementation.
#[derive(Debug)]
pub struct SecureChannel {
    /// Identifier of the peer on the other end of the channel.
    pub peer_id: DeviceId,
    /// Key used for both encrypting to and decrypting from the peer.
    key: SymmetricKey,
    /// Wall-clock time at which the channel value was created.
    pub established_at: SystemTime,
}
215
216impl SecureChannel {
217 pub fn new(peer_id: DeviceId, key: SymmetricKey) -> Self {
219 Self {
220 peer_id,
221 key,
222 established_at: SystemTime::now(),
223 }
224 }
225
226 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedData, SecurityError> {
228 self.key.encrypt(plaintext)
229 }
230
231 pub fn decrypt(&self, encrypted: &EncryptedData) -> Result<Vec<u8>, SecurityError> {
233 self.key.decrypt(encrypted)
234 }
235
236 pub fn age_secs(&self) -> u64 {
238 self.established_at
239 .elapsed()
240 .map(|d| d.as_secs())
241 .unwrap_or(0)
242 }
243}
244
/// A symmetric key shared by the members of one cell.
#[derive(Clone)]
pub struct GroupKey {
    /// Identifier of the cell this key belongs to.
    pub cell_id: String,
    /// Key material used to encrypt and decrypt cell messages.
    key: SymmetricKey,
    /// Generation counter: 1 for a generated key, incremented by each rotation.
    pub generation: u64,
    /// Wall-clock time at which this key value was constructed.
    pub created_at: SystemTime,
}
257
258impl GroupKey {
259 pub fn generate(cell_id: String) -> Self {
261 let mut key_bytes = [0u8; SYMMETRIC_KEY_SIZE];
262 OsRng.fill_bytes(&mut key_bytes);
263 Self {
264 cell_id,
265 key: SymmetricKey::from_bytes(&key_bytes),
266 generation: 1,
267 created_at: SystemTime::now(),
268 }
269 }
270
271 pub fn from_bytes(
273 cell_id: String,
274 key_bytes: &[u8; SYMMETRIC_KEY_SIZE],
275 generation: u64,
276 ) -> Self {
277 Self {
278 cell_id,
279 key: SymmetricKey::from_bytes(key_bytes),
280 generation,
281 created_at: SystemTime::now(),
282 }
283 }
284
285 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedCellMessage, SecurityError> {
287 let encrypted = self.key.encrypt(plaintext)?;
288 Ok(EncryptedCellMessage {
289 cell_id: self.cell_id.clone(),
290 generation: self.generation,
291 encrypted,
292 })
293 }
294
295 pub fn decrypt(&self, message: &EncryptedCellMessage) -> Result<Vec<u8>, SecurityError> {
297 if message.cell_id != self.cell_id {
298 return Err(SecurityError::DecryptionError(format!(
299 "cell ID mismatch: expected {}, got {}",
300 self.cell_id, message.cell_id
301 )));
302 }
303 if message.generation != self.generation {
304 return Err(SecurityError::DecryptionError(format!(
305 "key generation mismatch: expected {}, got {}",
306 self.generation, message.generation
307 )));
308 }
309 self.key.decrypt(&message.encrypted)
310 }
311
312 pub fn rotate(&self) -> Self {
314 let mut key_bytes = [0u8; SYMMETRIC_KEY_SIZE];
315 OsRng.fill_bytes(&mut key_bytes);
316 Self {
317 cell_id: self.cell_id.clone(),
318 key: SymmetricKey::from_bytes(&key_bytes),
319 generation: self.generation + 1,
320 created_at: SystemTime::now(),
321 }
322 }
323
324 pub fn key_bytes(&self) -> [u8; SYMMETRIC_KEY_SIZE] {
326 *self.key.as_bytes()
327 }
328}
329
330impl std::fmt::Debug for GroupKey {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("GroupKey")
333 .field("cell_id", &self.cell_id)
334 .field("generation", &self.generation)
335 .field("key", &"[REDACTED]")
336 .finish()
337 }
338}
339
/// An encrypted message addressed to a cell, tagged with the key generation used.
#[derive(Debug, Clone)]
pub struct EncryptedCellMessage {
    /// Identifier of the cell the message was encrypted for.
    pub cell_id: String,
    /// Generation of the group key that produced `encrypted`.
    pub generation: u64,
    /// Nonce and ciphertext of the message body.
    pub encrypted: EncryptedData,
}
350
351impl EncryptedCellMessage {
352 pub fn to_bytes(&self) -> Vec<u8> {
354 let cell_id_bytes = self.cell_id.as_bytes();
355 let encrypted_bytes = self.encrypted.to_bytes();
356
357 let mut bytes = Vec::new();
358 bytes.extend_from_slice(&(cell_id_bytes.len() as u32).to_le_bytes());
360 bytes.extend_from_slice(cell_id_bytes);
361 bytes.extend_from_slice(&self.generation.to_le_bytes());
362 bytes.extend_from_slice(&encrypted_bytes);
363 bytes
364 }
365
366 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SecurityError> {
368 if bytes.len() < 12 {
369 return Err(SecurityError::DecryptionError(
370 "message too short".to_string(),
371 ));
372 }
373
374 let cell_id_len = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize;
375 if bytes.len() < 4 + cell_id_len + 8 {
376 return Err(SecurityError::DecryptionError(
377 "message truncated".to_string(),
378 ));
379 }
380
381 let cell_id = String::from_utf8(bytes[4..4 + cell_id_len].to_vec())
382 .map_err(|e| SecurityError::DecryptionError(e.to_string()))?;
383 let generation = u64::from_le_bytes(
384 bytes[4 + cell_id_len..4 + cell_id_len + 8]
385 .try_into()
386 .unwrap(),
387 );
388 let encrypted = EncryptedData::from_bytes(&bytes[4 + cell_id_len + 8..])?;
389
390 Ok(Self {
391 cell_id,
392 generation,
393 encrypted,
394 })
395 }
396}
397
/// An encrypted document together with who encrypted it and when.
#[derive(Debug, Clone)]
pub struct EncryptedDocument {
    /// Nonce and ciphertext of the document body.
    pub encrypted: EncryptedData,
    /// Identifier of the device that performed the encryption.
    pub encrypted_by: DeviceId,
    /// Encryption time in whole seconds since the Unix epoch.
    pub encrypted_at: u64,
}
408
409impl EncryptedDocument {
410 pub fn new(encrypted: EncryptedData, device_id: DeviceId) -> Self {
412 let encrypted_at = SystemTime::now()
413 .duration_since(UNIX_EPOCH)
414 .map(|d| d.as_secs())
415 .unwrap_or(0);
416 Self {
417 encrypted,
418 encrypted_by: device_id,
419 encrypted_at,
420 }
421 }
422}
423
/// Holds this device's keypair and the symmetric keys it uses for peers,
/// cells and documents.
pub struct EncryptionManager {
    /// X25519 keypair used for the DH exchange when establishing peer channels.
    keypair: EncryptionKeypair,
    /// Identifier recorded on documents this manager encrypts.
    device_id: DeviceId,
    /// Established channels, keyed by peer device id.
    peer_channels: Arc<RwLock<HashMap<DeviceId, SecureChannel>>>,
    /// Current group key for each cell id.
    cell_keys: Arc<RwLock<HashMap<String, GroupKey>>>,
    /// Key used to encrypt and decrypt documents.
    device_key: SymmetricKey,
}
442
443impl EncryptionManager {
444 pub fn new(keypair: EncryptionKeypair, device_id: DeviceId) -> Self {
446 let hk = Hkdf::<Sha256>::new(None, keypair.public_key_bytes().as_ref());
448 let mut device_key_bytes = [0u8; SYMMETRIC_KEY_SIZE];
449 hk.expand(b"peat-protocol-v1-device", &mut device_key_bytes)
450 .expect("HKDF expand should never fail");
451
452 Self {
453 keypair,
454 device_id,
455 peer_channels: Arc::new(RwLock::new(HashMap::new())),
456 cell_keys: Arc::new(RwLock::new(HashMap::new())),
457 device_key: SymmetricKey::from_bytes(&device_key_bytes),
458 }
459 }
460
461 pub fn public_key(&self) -> &PublicKey {
463 self.keypair.public_key()
464 }
465
466 pub fn public_key_bytes(&self) -> [u8; 32] {
468 self.keypair.public_key_bytes()
469 }
470
471 pub async fn establish_channel(
473 &self,
474 peer_id: DeviceId,
475 peer_public_key: &[u8; X25519_PUBLIC_KEY_SIZE],
476 ) -> Result<(), SecurityError> {
477 let peer_public = PublicKey::from(*peer_public_key);
478 let shared_secret = self.keypair.dh_exchange(&peer_public);
479 let symmetric_key = SymmetricKey::derive_for_peer(&shared_secret);
480
481 let channel = SecureChannel::new(peer_id, symmetric_key);
482 self.peer_channels.write().await.insert(peer_id, channel);
483
484 Ok(())
485 }
486
487 pub async fn get_channel(&self, peer_id: &DeviceId) -> Option<SecureChannel> {
489 let channels = self.peer_channels.read().await;
490 channels.get(peer_id).map(|c| SecureChannel {
491 peer_id: c.peer_id,
492 key: c.key.clone(),
493 established_at: c.established_at,
494 })
495 }
496
497 pub async fn has_channel(&self, peer_id: &DeviceId) -> bool {
499 self.peer_channels.read().await.contains_key(peer_id)
500 }
501
502 pub async fn remove_channel(&self, peer_id: &DeviceId) {
504 self.peer_channels.write().await.remove(peer_id);
505 }
506
507 pub async fn encrypt_for_peer(
509 &self,
510 peer_id: &DeviceId,
511 plaintext: &[u8],
512 ) -> Result<EncryptedData, SecurityError> {
513 let channels = self.peer_channels.read().await;
514 let channel = channels.get(peer_id).ok_or_else(|| {
515 SecurityError::EncryptionError(format!("no channel for peer: {}", peer_id))
516 })?;
517 channel.encrypt(plaintext)
518 }
519
520 pub async fn decrypt_from_peer(
522 &self,
523 peer_id: &DeviceId,
524 encrypted: &EncryptedData,
525 ) -> Result<Vec<u8>, SecurityError> {
526 let channels = self.peer_channels.read().await;
527 let channel = channels.get(peer_id).ok_or_else(|| {
528 SecurityError::DecryptionError(format!("no channel for peer: {}", peer_id))
529 })?;
530 channel.decrypt(encrypted)
531 }
532
533 pub async fn get_or_create_cell_key(&self, cell_id: &str) -> GroupKey {
535 let mut keys = self.cell_keys.write().await;
536 if let Some(key) = keys.get(cell_id) {
537 key.clone()
538 } else {
539 let key = GroupKey::generate(cell_id.to_string());
540 keys.insert(cell_id.to_string(), key.clone());
541 key
542 }
543 }
544
545 pub async fn set_cell_key(&self, key: GroupKey) {
547 self.cell_keys
548 .write()
549 .await
550 .insert(key.cell_id.clone(), key);
551 }
552
553 pub async fn get_cell_key(&self, cell_id: &str) -> Option<GroupKey> {
555 self.cell_keys.read().await.get(cell_id).cloned()
556 }
557
558 pub async fn rotate_cell_key(&self, cell_id: &str) -> Result<GroupKey, SecurityError> {
560 let mut keys = self.cell_keys.write().await;
561 let old_key = keys.get(cell_id).ok_or_else(|| {
562 SecurityError::EncryptionError(format!("no key for cell: {}", cell_id))
563 })?;
564 let new_key = old_key.rotate();
565 keys.insert(cell_id.to_string(), new_key.clone());
566 Ok(new_key)
567 }
568
569 pub async fn remove_cell_key(&self, cell_id: &str) {
571 self.cell_keys.write().await.remove(cell_id);
572 }
573
574 pub async fn encrypt_for_cell(
576 &self,
577 cell_id: &str,
578 plaintext: &[u8],
579 ) -> Result<EncryptedCellMessage, SecurityError> {
580 let keys = self.cell_keys.read().await;
581 let key = keys.get(cell_id).ok_or_else(|| {
582 SecurityError::EncryptionError(format!("no key for cell: {}", cell_id))
583 })?;
584 key.encrypt(plaintext)
585 }
586
587 pub async fn decrypt_cell_message(
589 &self,
590 message: &EncryptedCellMessage,
591 ) -> Result<Vec<u8>, SecurityError> {
592 let keys = self.cell_keys.read().await;
593 let key = keys.get(&message.cell_id).ok_or_else(|| {
594 SecurityError::DecryptionError(format!("no key for cell: {}", message.cell_id))
595 })?;
596 key.decrypt(message)
597 }
598
599 pub fn encrypt_document(&self, plaintext: &[u8]) -> Result<EncryptedDocument, SecurityError> {
601 let encrypted = self.device_key.encrypt(plaintext)?;
602 Ok(EncryptedDocument::new(encrypted, self.device_id))
603 }
604
605 pub fn decrypt_document(&self, document: &EncryptedDocument) -> Result<Vec<u8>, SecurityError> {
607 if document.encrypted_by != self.device_id {
608 return Err(SecurityError::DecryptionError(
609 "document encrypted by different device".to_string(),
610 ));
611 }
612 self.device_key.decrypt(&document.encrypted)
613 }
614
615 pub async fn peer_channel_count(&self) -> usize {
617 self.peer_channels.read().await.len()
618 }
619
620 pub async fn cell_key_count(&self) -> usize {
622 self.cell_keys.read().await.len()
623 }
624}
625
626impl std::fmt::Debug for EncryptionManager {
627 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
628 f.debug_struct("EncryptionManager")
629 .field("device_id", &self.device_id)
630 .field("public_key", &hex::encode(self.keypair.public_key_bytes()))
631 .finish()
632 }
633}
634
635use rand_core::RngCore;
636
// Unit tests for the key, channel, message and manager types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // Two independently generated keypairs must have different public keys.
    #[test]
    fn test_keypair_generation() {
        let kp1 = EncryptionKeypair::generate();
        let kp2 = EncryptionKeypair::generate();

        assert_ne!(kp1.public_key_bytes(), kp2.public_key_bytes());
    }

    // Rebuilding from the same secret bytes is deterministic and differs from
    // a randomly generated keypair.
    #[test]
    fn test_keypair_from_bytes() {
        let kp1 = EncryptionKeypair::generate();
        let secret_bytes = [42u8; 32];
        let kp2 = EncryptionKeypair::from_secret_bytes(&secret_bytes);
        let kp3 = EncryptionKeypair::from_secret_bytes(&secret_bytes);

        assert_eq!(kp2.public_key_bytes(), kp3.public_key_bytes());
        assert_ne!(kp1.public_key_bytes(), kp2.public_key_bytes());
    }

    // Both sides of a DH exchange arrive at the same shared secret.
    #[test]
    fn test_dh_key_exchange() {
        let alice = EncryptionKeypair::generate();
        let bob = EncryptionKeypair::generate();

        let alice_shared = alice.dh_exchange(bob.public_key());
        let bob_shared = bob.dh_exchange(alice.public_key());

        assert_eq!(alice_shared.as_bytes(), bob_shared.as_bytes());
    }

    // Encrypt followed by decrypt with the same key returns the plaintext.
    #[test]
    fn test_symmetric_key_encrypt_decrypt() {
        let key = SymmetricKey::from_bytes(&[42u8; 32]);
        let plaintext = b"Hello, World!";

        let encrypted = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    // Each encryption uses a fresh nonce, so equal plaintexts give different
    // ciphertexts that both still decrypt.
    #[test]
    fn test_symmetric_key_different_nonces() {
        let key = SymmetricKey::from_bytes(&[42u8; 32]);
        let plaintext = b"Hello, World!";

        let encrypted1 = key.encrypt(plaintext).unwrap();
        let encrypted2 = key.encrypt(plaintext).unwrap();

        assert_ne!(encrypted1.nonce, encrypted2.nonce);
        assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);

        assert_eq!(key.decrypt(&encrypted1).unwrap(), plaintext);
        assert_eq!(key.decrypt(&encrypted2).unwrap(), plaintext);
    }

    // Decrypting with a different key must fail.
    #[test]
    fn test_wrong_key_decryption_fails() {
        let key1 = SymmetricKey::from_bytes(&[42u8; 32]);
        let key2 = SymmetricKey::from_bytes(&[43u8; 32]);
        let plaintext = b"Hello, World!";

        let encrypted = key1.encrypt(plaintext).unwrap();
        let result = key2.decrypt(&encrypted);

        assert!(result.is_err());
    }

    // EncryptedData survives a to_bytes / from_bytes round trip.
    #[test]
    fn test_encrypted_data_serialization() {
        let key = SymmetricKey::from_bytes(&[42u8; 32]);
        let plaintext = b"Hello, World!";

        let encrypted = key.encrypt(plaintext).unwrap();
        let bytes = encrypted.to_bytes();
        let restored = EncryptedData::from_bytes(&bytes).unwrap();

        assert_eq!(encrypted.nonce, restored.nonce);
        assert_eq!(encrypted.ciphertext, restored.ciphertext);

        let decrypted = key.decrypt(&restored).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    // Channels built from each side's derived key can talk in both directions.
    #[test]
    fn test_secure_channel() {
        let alice = EncryptionKeypair::generate();
        let bob = EncryptionKeypair::generate();

        let alice_shared = alice.dh_exchange(bob.public_key());
        let bob_shared = bob.dh_exchange(alice.public_key());

        let alice_key = SymmetricKey::derive_for_peer(&alice_shared);
        let bob_key = SymmetricKey::derive_for_peer(&bob_shared);

        let alice_id = DeviceId::from_bytes([1u8; 16]);
        let bob_id = DeviceId::from_bytes([2u8; 16]);

        let alice_channel = SecureChannel::new(bob_id, alice_key);
        let bob_channel = SecureChannel::new(alice_id, bob_key);

        let message = b"Secret message from Alice";
        let encrypted = alice_channel.encrypt(message).unwrap();
        let decrypted = bob_channel.decrypt(&encrypted).unwrap();

        assert_eq!(message.as_slice(), decrypted.as_slice());

        let reply = b"Reply from Bob";
        let encrypted_reply = bob_channel.encrypt(reply).unwrap();
        let decrypted_reply = alice_channel.decrypt(&encrypted_reply).unwrap();

        assert_eq!(reply.as_slice(), decrypted_reply.as_slice());
    }

    // A generated group key round-trips a message and tags it with cell id and
    // generation 1.
    #[test]
    fn test_group_key() {
        let key = GroupKey::generate("cell-1".to_string());
        let plaintext = b"Broadcast message";

        let encrypted = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        assert_eq!(encrypted.cell_id, "cell-1");
        assert_eq!(encrypted.generation, 1);
    }

    // Rotation keeps the cell id, bumps the generation, changes the key bytes,
    // and the old key rejects messages from the new one.
    #[test]
    fn test_group_key_rotation() {
        let key1 = GroupKey::generate("cell-1".to_string());
        let key2 = key1.rotate();

        assert_eq!(key1.cell_id, key2.cell_id);
        assert_eq!(key1.generation + 1, key2.generation);
        assert_ne!(key1.key_bytes(), key2.key_bytes());

        let message = b"New message";
        let encrypted = key2.encrypt(message).unwrap();
        assert!(key1.decrypt(&encrypted).is_err());

        let decrypted = key2.decrypt(&encrypted).unwrap();
        assert_eq!(message.as_slice(), decrypted.as_slice());
    }

    // EncryptedCellMessage survives a to_bytes / from_bytes round trip.
    #[test]
    fn test_encrypted_cell_message_serialization() {
        let key = GroupKey::generate("cell-1".to_string());
        let plaintext = b"Cell broadcast";

        let encrypted = key.encrypt(plaintext).unwrap();
        let bytes = encrypted.to_bytes();
        let restored = EncryptedCellMessage::from_bytes(&bytes).unwrap();

        assert_eq!(encrypted.cell_id, restored.cell_id);
        assert_eq!(encrypted.generation, restored.generation);

        let decrypted = key.decrypt(&restored).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    // Two managers that establish channels with each other can exchange a message.
    #[tokio::test]
    async fn test_encryption_manager_peer_channels() {
        let alice_kp = EncryptionKeypair::generate();
        let bob_kp = EncryptionKeypair::generate();

        let alice_id = DeviceId::from_bytes([1u8; 16]);
        let bob_id = DeviceId::from_bytes([2u8; 16]);

        let alice_mgr = EncryptionManager::new(alice_kp.clone(), alice_id);
        let bob_mgr = EncryptionManager::new(bob_kp.clone(), bob_id);

        alice_mgr
            .establish_channel(bob_id, &bob_mgr.public_key_bytes())
            .await
            .unwrap();
        bob_mgr
            .establish_channel(alice_id, &alice_mgr.public_key_bytes())
            .await
            .unwrap();

        assert!(alice_mgr.has_channel(&bob_id).await);
        assert!(bob_mgr.has_channel(&alice_id).await);

        let message = b"Hello Bob!";
        let encrypted = alice_mgr.encrypt_for_peer(&bob_id, message).await.unwrap();
        let decrypted = bob_mgr
            .decrypt_from_peer(&alice_id, &encrypted)
            .await
            .unwrap();

        assert_eq!(message.as_slice(), decrypted.as_slice());
    }

    // get_or_create returns the same key on repeat calls; the key round-trips a
    // message; rotation yields generation 2.
    #[tokio::test]
    async fn test_encryption_manager_cell_keys() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let key = mgr.get_or_create_cell_key("cell-1").await;
        assert_eq!(key.cell_id, "cell-1");
        assert_eq!(key.generation, 1);

        let key2 = mgr.get_or_create_cell_key("cell-1").await;
        assert_eq!(key.generation, key2.generation);

        let message = b"Cell message";
        let encrypted = mgr.encrypt_for_cell("cell-1", message).await.unwrap();
        let decrypted = mgr.decrypt_cell_message(&encrypted).await.unwrap();

        assert_eq!(message.as_slice(), decrypted.as_slice());

        let new_key = mgr.rotate_cell_key("cell-1").await.unwrap();
        assert_eq!(new_key.generation, 2);
    }

    // A document encrypted by a manager decrypts with the same manager.
    #[test]
    fn test_document_encryption() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let document = b"Sensitive document content";
        let encrypted = mgr.encrypt_document(document).unwrap();
        let decrypted = mgr.decrypt_document(&encrypted).unwrap();

        assert_eq!(document.as_slice(), decrypted.as_slice());
    }

    // A manager for a different device refuses to decrypt the document.
    #[test]
    fn test_document_wrong_device_fails() {
        let kp1 = EncryptionKeypair::generate();
        let kp2 = EncryptionKeypair::generate();
        let device_id1 = DeviceId::from_bytes([1u8; 16]);
        let device_id2 = DeviceId::from_bytes([2u8; 16]);

        let mgr1 = EncryptionManager::new(kp1, device_id1);
        let mgr2 = EncryptionManager::new(kp2, device_id2);

        let document = b"Sensitive document";
        let encrypted = mgr1.encrypt_document(document).unwrap();

        let result = mgr2.decrypt_document(&encrypted);
        assert!(result.is_err());
    }

    // A freshly created channel reports an age close to zero.
    #[test]
    fn test_secure_channel_age_secs() {
        let key = SymmetricKey::from_bytes(&[42u8; 32]);
        let peer_id = DeviceId::from_bytes([1u8; 16]);
        let channel = SecureChannel::new(peer_id, key);

        let age = channel.age_secs();
        assert!(age < 2, "Channel age should be near zero, got {}", age);
    }

    // Input shorter than the nonce is rejected with a "too short" error.
    #[test]
    fn test_encrypted_data_from_bytes_too_short() {
        let short_data = vec![0u8; 5];
        let result = EncryptedData::from_bytes(&short_data);
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("too short"));
    }

    // Exactly NONCE_SIZE bytes parses to a nonce with an empty ciphertext.
    #[test]
    fn test_encrypted_data_from_bytes_exact_nonce_size() {
        let data = vec![0u8; NONCE_SIZE];
        let result = EncryptedData::from_bytes(&data);
        assert!(result.is_ok());
        let ed = result.unwrap();
        assert_eq!(ed.nonce, [0u8; NONCE_SIZE]);
        assert!(ed.ciphertext.is_empty());
    }

    // Fewer than 12 bytes is rejected with a "too short" error.
    #[test]
    fn test_encrypted_cell_message_from_bytes_too_short() {
        let short_data = vec![0u8; 8];
        let result = EncryptedCellMessage::from_bytes(&short_data);
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("too short"));
    }

    // A length prefix of 100 followed by only 8 bytes is reported as truncated.
    #[test]
    fn test_encrypted_cell_message_from_bytes_truncated() {
        let mut data = Vec::new();
        data.extend_from_slice(&100u32.to_le_bytes());
        data.extend_from_slice(&[0u8; 8]);
        let result = EncryptedCellMessage::from_bytes(&data);
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("truncated"));
    }

    // A well-framed message whose cell id is not valid UTF-8 is rejected.
    #[test]
    fn test_encrypted_cell_message_from_bytes_invalid_utf8() {
        let mut data = Vec::new();
        let bad_utf8 = [0xFF, 0xFE];
        data.extend_from_slice(&(bad_utf8.len() as u32).to_le_bytes());
        data.extend_from_slice(&bad_utf8);
        data.extend_from_slice(&1u64.to_le_bytes());
        data.extend_from_slice(&[0u8; NONCE_SIZE]);
        let result = EncryptedCellMessage::from_bytes(&data);
        assert!(result.is_err());
    }

    // Encrypting for a peer without a channel fails with a "no channel" error.
    #[tokio::test]
    async fn test_encryption_manager_encrypt_for_peer_no_channel() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let nonexistent_peer = DeviceId::from_bytes([99u8; 16]);
        let result = mgr.encrypt_for_peer(&nonexistent_peer, b"hello").await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("no channel"));
    }

    // Decrypting from a peer without a channel fails with a "no channel" error.
    #[tokio::test]
    async fn test_encryption_manager_decrypt_from_peer_no_channel() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let nonexistent_peer = DeviceId::from_bytes([99u8; 16]);
        let fake_encrypted = EncryptedData {
            nonce: [0u8; NONCE_SIZE],
            ciphertext: vec![1, 2, 3],
        };
        let result = mgr
            .decrypt_from_peer(&nonexistent_peer, &fake_encrypted)
            .await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("no channel"));
    }

    // Encrypting for a cell without a stored key fails with "no key for cell".
    #[tokio::test]
    async fn test_encryption_manager_encrypt_for_cell_no_key() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let result = mgr.encrypt_for_cell("nonexistent-cell", b"data").await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("no key for cell"));
    }

    // Decrypting a message for a cell without a stored key fails with
    // "no key for cell".
    #[tokio::test]
    async fn test_encryption_manager_decrypt_cell_message_no_key() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let fake_message = EncryptedCellMessage {
            cell_id: "nonexistent-cell".to_string(),
            generation: 1,
            encrypted: EncryptedData {
                nonce: [0u8; NONCE_SIZE],
                ciphertext: vec![1, 2, 3],
            },
        };
        let result = mgr.decrypt_cell_message(&fake_message).await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("no key for cell"));
    }

    // Rotating a cell without a stored key fails with "no key for cell".
    #[tokio::test]
    async fn test_encryption_manager_rotate_cell_key_no_key() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let result = mgr.rotate_cell_key("nonexistent").await;
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("no key for cell"));
    }

    // After rotation the stored key is generation 2 and is the one used for
    // new messages.
    #[tokio::test]
    async fn test_encryption_manager_rotate_cell_key_success() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let key1 = mgr.get_or_create_cell_key("cell-1").await;
        assert_eq!(key1.generation, 1);

        let key2 = mgr.rotate_cell_key("cell-1").await.unwrap();
        assert_eq!(key2.generation, 2);
        assert_eq!(key2.cell_id, "cell-1");

        let stored = mgr.get_cell_key("cell-1").await.unwrap();
        assert_eq!(stored.generation, 2);

        let msg = b"test message";
        let encrypted = mgr.encrypt_for_cell("cell-1", msg).await.unwrap();
        assert_eq!(encrypted.generation, 2);
        let decrypted = mgr.decrypt_cell_message(&encrypted).await.unwrap();
        assert_eq!(decrypted, msg);
    }

    // Removing a channel makes has_channel false and drops the count to zero.
    #[tokio::test]
    async fn test_encryption_manager_remove_channel() {
        let alice_kp = EncryptionKeypair::generate();
        let bob_kp = EncryptionKeypair::generate();

        let alice_id = DeviceId::from_bytes([1u8; 16]);
        let bob_id = DeviceId::from_bytes([2u8; 16]);

        let alice_mgr = EncryptionManager::new(alice_kp, alice_id);

        alice_mgr
            .establish_channel(bob_id, &bob_kp.public_key_bytes())
            .await
            .unwrap();
        assert!(alice_mgr.has_channel(&bob_id).await);
        assert_eq!(alice_mgr.peer_channel_count().await, 1);

        alice_mgr.remove_channel(&bob_id).await;
        assert!(!alice_mgr.has_channel(&bob_id).await);
        assert_eq!(alice_mgr.peer_channel_count().await, 0);
    }

    // Removing a cell key drops the count to zero and makes lookups return None.
    #[tokio::test]
    async fn test_encryption_manager_remove_cell_key() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        mgr.get_or_create_cell_key("cell-1").await;
        assert_eq!(mgr.cell_key_count().await, 1);

        mgr.remove_cell_key("cell-1").await;
        assert_eq!(mgr.cell_key_count().await, 0);
        assert!(mgr.get_cell_key("cell-1").await.is_none());
    }

    // A key installed with set_cell_key is returned with its cell id and
    // generation intact.
    #[tokio::test]
    async fn test_encryption_manager_set_cell_key() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);

        let key = GroupKey::from_bytes("cell-99".to_string(), &[7u8; SYMMETRIC_KEY_SIZE], 5);
        mgr.set_cell_key(key).await;

        let stored = mgr.get_cell_key("cell-99").await;
        assert!(stored.is_some());
        let stored = stored.unwrap();
        assert_eq!(stored.cell_id, "cell-99");
        assert_eq!(stored.generation, 5);
    }

    // get_channel returns the established channel and None for unknown peers.
    #[tokio::test]
    async fn test_encryption_manager_get_channel() {
        let alice_kp = EncryptionKeypair::generate();
        let bob_kp = EncryptionKeypair::generate();
        let alice_id = DeviceId::from_bytes([1u8; 16]);
        let bob_id = DeviceId::from_bytes([2u8; 16]);

        let mgr = EncryptionManager::new(alice_kp, alice_id);
        mgr.establish_channel(bob_id, &bob_kp.public_key_bytes())
            .await
            .unwrap();

        let channel = mgr.get_channel(&bob_id).await;
        assert!(channel.is_some());
        let channel = channel.unwrap();
        assert_eq!(channel.peer_id, bob_id);

        let missing = DeviceId::from_bytes([99u8; 16]);
        assert!(mgr.get_channel(&missing).await.is_none());
    }

    // A message relabelled with another cell id is rejected before decryption.
    #[test]
    fn test_group_key_decrypt_cell_id_mismatch() {
        let key = GroupKey::generate("cell-1".to_string());
        let encrypted = key.encrypt(b"test").unwrap();

        let wrong_msg = EncryptedCellMessage {
            cell_id: "cell-WRONG".to_string(),
            generation: encrypted.generation,
            encrypted: encrypted.encrypted.clone(),
        };
        let result = key.decrypt(&wrong_msg);
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("cell ID mismatch"));
    }

    // A message relabelled with another generation is rejected before decryption.
    #[test]
    fn test_group_key_decrypt_generation_mismatch() {
        let key = GroupKey::generate("cell-1".to_string());
        let encrypted = key.encrypt(b"test").unwrap();

        let wrong_msg = EncryptedCellMessage {
            cell_id: "cell-1".to_string(),
            generation: 999,
            encrypted: encrypted.encrypted.clone(),
        };
        let result = key.decrypt(&wrong_msg);
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("key generation mismatch"));
    }

    // GroupKey::from_bytes preserves cell id, generation and key bytes, and the
    // resulting key works.
    #[test]
    fn test_group_key_from_bytes() {
        let key_bytes = [42u8; SYMMETRIC_KEY_SIZE];
        let key = GroupKey::from_bytes("cell-x".to_string(), &key_bytes, 10);
        assert_eq!(key.cell_id, "cell-x");
        assert_eq!(key.generation, 10);
        assert_eq!(key.key_bytes(), key_bytes);

        let plaintext = b"from bytes key test";
        let encrypted = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    // Repeated rotation keeps the cell id, counts generations 1, 2, 3, and
    // changes the key bytes each time.
    #[test]
    fn test_group_key_rotation_preserves_cell_id() {
        let key1 = GroupKey::generate("my-cell".to_string());
        let key2 = key1.rotate();
        let key3 = key2.rotate();

        assert_eq!(key1.cell_id, "my-cell");
        assert_eq!(key2.cell_id, "my-cell");
        assert_eq!(key3.cell_id, "my-cell");
        assert_eq!(key1.generation, 1);
        assert_eq!(key2.generation, 2);
        assert_eq!(key3.generation, 3);

        assert_ne!(key1.key_bytes(), key2.key_bytes());
        assert_ne!(key2.key_bytes(), key3.key_bytes());
    }

    // Debug output of a keypair contains the redaction marker.
    // NOTE(review): the second assertion implies the first; neither checks
    // that the secret bytes themselves are absent from the output.
    #[test]
    fn test_keypair_debug_redacts_secret() {
        let kp = EncryptionKeypair::generate();
        let debug_str = format!("{:?}", kp);
        assert!(debug_str.contains("REDACTED"));
        assert!(debug_str.contains("[REDACTED]"));
    }

    // Debug output of a symmetric key contains the redaction marker.
    #[test]
    fn test_symmetric_key_debug_redacts() {
        let key = SymmetricKey::from_bytes(&[42u8; 32]);
        let debug_str = format!("{:?}", key);
        assert!(debug_str.contains("REDACTED"));
    }

    // Debug output of a group key shows the cell id and the redaction marker.
    #[test]
    fn test_group_key_debug_redacts() {
        let key = GroupKey::generate("cell-1".to_string());
        let debug_str = format!("{:?}", key);
        assert!(debug_str.contains("REDACTED"));
        assert!(debug_str.contains("cell-1"));
    }

    // Debug output of the manager names the type and its two printed fields.
    #[test]
    fn test_encryption_manager_debug() {
        let kp = EncryptionKeypair::generate();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let mgr = EncryptionManager::new(kp, device_id);
        let debug_str = format!("{:?}", mgr);
        assert!(debug_str.contains("EncryptionManager"));
        assert!(debug_str.contains("device_id"));
        assert!(debug_str.contains("public_key"));
    }

    // EncryptedDocument::new records the device id and a non-zero timestamp.
    #[test]
    fn test_encrypted_document_new() {
        let key = SymmetricKey::from_bytes(&[42u8; 32]);
        let encrypted = key.encrypt(b"doc data").unwrap();
        let device_id = DeviceId::from_bytes([1u8; 16]);
        let doc = EncryptedDocument::new(encrypted, device_id);

        assert_eq!(doc.encrypted_by, device_id);
        assert!(doc.encrypted_at > 0);
    }

    // A key derived from a DH shared secret can encrypt and decrypt.
    #[test]
    fn test_symmetric_key_derive_from_shared_secret() {
        let alice = EncryptionKeypair::generate();
        let bob = EncryptionKeypair::generate();

        let shared = alice.dh_exchange(bob.public_key());
        let key = SymmetricKey::derive_for_peer(&shared);

        let plaintext = b"derived key test";
        let encrypted = key.encrypt(plaintext).unwrap();
        let decrypted = key.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    // as_bytes returns exactly the bytes passed to from_bytes.
    #[test]
    fn test_symmetric_key_as_bytes_roundtrip() {
        let original_bytes = [99u8; SYMMETRIC_KEY_SIZE];
        let key = SymmetricKey::from_bytes(&original_bytes);
        let extracted = key.as_bytes();
        assert_eq!(&original_bytes, extracted);
    }
}