1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use base64::{engine::general_purpose::STANDARD as B64, Engine};
8use zeroize::{Zeroize, Zeroizing};
9
10use crate::error::{CryptoError, Result};
11use crate::primitives;
12use crate::storage::KeyStore;
13use crate::types::{
14 E2EEnvelope, ECDHKeyPair, KeyData, KeyExchangeMessage, PublicKeyAnnouncement, TofuRecord,
15};
16
/// Callback consulted when a known peer presents a public key whose fingerprint
/// differs from the one stored for it.
///
/// Arguments are, in order: the peer id, the previously stored fingerprint and
/// the fingerprint of the newly presented key. Returning `true` accepts the new
/// key (the stored record is overwritten); returning `false` rejects it and the
/// operation fails with `CryptoError::TofuViolation`.
pub type OnKeyChange = Arc<dyn Fn(&str, &str, &str) -> bool + Send + Sync>;
21
/// Lower bound applied to the configured rotation interval: `should_rotate`
/// raises any shorter `rotation_interval` to this value.
const MIN_ROTATION_INTERVAL: Duration = Duration::from_secs(60);
24
/// How far ahead of the local clock an announcement timestamp may lie before
/// `handle_peer_pubkey` rejects it (only checked when `max_announcement_age`
/// is configured).
const DEFAULT_MAX_ANNOUNCEMENT_FUTURE: Duration = Duration::from_secs(30);
27
/// Construction-time settings for an [`E2ESession`].
pub struct E2ESessionConfig {
    /// Identifier of the local participant. Announcements carrying this id are
    /// ignored, and it is sent as `from_id` in key-exchange messages.
    pub identity_id: String,
    /// Identifies the session: used as the key-store id for the group key and
    /// as the prefix of TOFU record ids (`"{base_path}:{peer_id}"`).
    pub base_path: String,
    /// Persistence backend for the group key and TOFU records.
    pub store: Arc<dyn KeyStore>,
    /// Decides whether a changed peer key is accepted. When `None`, every key
    /// change is rejected.
    pub on_key_change: Option<OnKeyChange>,
    /// NOTE(review): not read anywhere in this file; purpose cannot be
    /// determined from here.
    pub password_hash: Option<String>,
    /// Interval after which `should_rotate` reports a rotation as due.
    /// `None` disables automatic rotation; values below
    /// `MIN_ROTATION_INTERVAL` are raised to it.
    pub rotation_interval: Option<Duration>,
    /// Invoked by `maybe_rotate` after a rotation has been carried out.
    pub on_rotation: Option<Arc<dyn Fn() + Send + Sync>>,
    /// Maximum accepted age of a peer announcement. `None` disables all
    /// timestamp checks in `handle_peer_pubkey` (both "too old" and "future").
    pub max_announcement_age: Option<Duration>,
}
45
/// State of one end-to-end-encrypted session: the shared group key, the local
/// ECDH key pair used to exchange it, and bookkeeping for rotation and replay
/// detection.
pub struct E2ESession {
    /// Settings supplied at construction.
    config: E2ESessionConfig,
    /// Raw group key bytes; `None` until a key is generated, loaded from the
    /// store, or received from a peer.
    group_key: Option<Vec<u8>>,
    /// Local ECDH key pair, generated on first use.
    ecdh_key_pair: Option<ECDHKeyPair>,
    /// Public keys of peers that passed TOFU verification, keyed by peer id.
    peer_public_keys: HashMap<String, Vec<u8>>,
    /// Set by `start`; makes repeated `start` calls a no-op.
    started: bool,
    /// Set by `destroy`; once set, most operations fail with `SessionDestroyed`.
    destroyed: bool,
    /// Time of the last key creation/rotation, in milliseconds since the Unix
    /// epoch (the same clock as `now_ms`).
    last_rotation: Option<u64>,
    /// Number of rotations performed by this instance.
    rotation_count: u64,
    /// `"{from_id}:{iv}"` strings of key-exchange messages already processed.
    seen_nonces: HashSet<String>,
}
66
67impl E2ESession {
68 pub fn new(config: E2ESessionConfig) -> Self {
69 Self {
70 config,
71 group_key: None,
72 ecdh_key_pair: None,
73 peer_public_keys: HashMap::new(),
74 started: false,
75 destroyed: false,
76 last_rotation: None,
77 rotation_count: 0,
78 seen_nonces: HashSet::new(),
79 }
80 }
81
82 pub fn encrypted(&self) -> bool {
84 self.group_key.is_some()
85 }
86
87 pub fn base_path(&self) -> &str {
89 &self.config.base_path
90 }
91
92 pub async fn start(&mut self) -> Result<()> {
94 if self.destroyed {
95 return Err(CryptoError::SessionDestroyed);
96 }
97 if self.started {
98 return Ok(());
99 }
100 self.started = true;
101
102 let session_id = self.session_id();
104 if let Some(data) = self.config.store.load_group_key(&session_id).await? {
105 match primitives::jwk_to_group_key(&data.key) {
106 Ok(key) => {
107 self.group_key = Some(key);
108 self.last_rotation = Some(data.stored_at);
110 }
111 Err(_) => {
112 self.config.store.delete_group_key(&session_id).await?;
113 }
114 }
115 }
116
117 Ok(())
118 }
119
120 pub async fn enable_encryption(&mut self) -> Result<PublicKeyAnnouncement> {
123 if self.destroyed {
124 return Err(CryptoError::SessionDestroyed);
125 }
126
127 let mut key = Zeroizing::new(primitives::generate_group_key());
128 self.group_key = Some(key.to_vec());
129
130 let creation_time = now_ms();
131 self.last_rotation = Some(creation_time);
132
133 let jwk = primitives::group_key_to_jwk(&key)?;
135 key.zeroize();
136 self.config
137 .store
138 .save_group_key(
139 &self.session_id(),
140 KeyData {
141 key: jwk,
142 stored_at: creation_time,
143 },
144 )
145 .await?;
146
147 self.make_public_key_announcement()
148 }
149
150 pub fn request_group_key(&mut self) -> Result<Option<PublicKeyAnnouncement>> {
152 if self.destroyed {
153 return Err(CryptoError::SessionDestroyed);
154 }
155 if self.group_key.is_some() {
156 return Ok(None);
157 }
158 self.make_public_key_announcement().map(Some)
159 }
160
161 pub fn encrypt(&self, value: &str) -> Result<E2EEnvelope> {
163 if self.destroyed {
164 return Err(CryptoError::SessionDestroyed);
165 }
166 let key = self.group_key.as_ref().ok_or(CryptoError::NoGroupKey)?;
167 let plaintext = value.as_bytes();
168 let (ciphertext, iv) = primitives::encrypt(key, plaintext)?;
169 Ok(E2EEnvelope {
170 _e2e: 1,
171 ct: B64.encode(&ciphertext),
172 iv: B64.encode(&iv),
173 v: 1,
174 })
175 }
176
177 pub async fn decrypt(&mut self, envelope: &E2EEnvelope) -> Result<String> {
179 if self.destroyed {
180 return Err(CryptoError::SessionDestroyed);
181 }
182 if envelope._e2e != 1 {
183 return Err(CryptoError::DecryptionFailed("invalid E2E marker".into()));
184 }
185 if envelope.v != 1 {
186 return Err(CryptoError::DecryptionFailed(
187 "unsupported envelope version".into(),
188 ));
189 }
190 let key = Zeroizing::new(match &self.group_key {
191 Some(k) => k.clone(),
192 None => {
193 let session_id = self.session_id();
194 match self.config.store.load_group_key(&session_id).await? {
195 Some(data) => {
196 let k = primitives::jwk_to_group_key(&data.key)?;
197 self.group_key = Some(k.clone());
198 k
199 }
200 None => return Err(CryptoError::NoGroupKey),
201 }
202 }
203 });
204
205 let ciphertext = B64
206 .decode(&envelope.ct)
207 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64 ct: {e}")))?;
208 let iv = B64
209 .decode(&envelope.iv)
210 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64 iv: {e}")))?;
211 let plaintext = primitives::decrypt(&key, &ciphertext, &iv)?;
212 String::from_utf8(plaintext)
213 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid UTF-8: {e}")))
214 }
215
216 pub async fn handle_peer_pubkey(
223 &mut self,
224 peer_id: &str,
225 announcement: &PublicKeyAnnouncement,
226 ) -> Result<Option<KeyExchangeMessage>> {
227 if self.destroyed {
228 return Err(CryptoError::SessionDestroyed);
229 }
230 if peer_id == self.config.identity_id {
231 return Ok(None);
232 }
233
234 if let Some(max_age) = self.config.max_announcement_age {
236 let now = now_ms();
237 let ts = announcement.timestamp;
238 let max_age_ms = max_age.as_millis() as u64;
239 let max_future_ms = DEFAULT_MAX_ANNOUNCEMENT_FUTURE.as_millis() as u64;
240 if ts + max_age_ms < now {
241 return Err(CryptoError::Other(format!(
242 "announcement from {peer_id} is too old ({} ms)",
243 now.saturating_sub(ts)
244 )));
245 }
246 if ts > now + max_future_ms {
247 return Err(CryptoError::Other(format!(
248 "announcement from {peer_id} is too far in the future ({} ms)",
249 ts.saturating_sub(now)
250 )));
251 }
252 }
253
254 let peer_pub_bytes = primitives::jwk_to_public_key(&announcement.public_key)?;
256
257 self.verify_peer_key(peer_id, &announcement.public_key)
259 .await?;
260
261 self.peer_public_keys
263 .insert(peer_id.to_string(), peer_pub_bytes.clone());
264
265 let group_key = Zeroizing::new(match &self.group_key {
267 Some(k) => k.clone(),
268 None => return Ok(None),
269 });
270
271 self.ensure_ecdh_key_pair();
273 let kp = self.ecdh_key_pair();
274 let shared = Zeroizing::new(primitives::derive_shared_key(
275 &kp.private_key,
276 &peer_pub_bytes,
277 None,
278 )?);
279 let group_key_jwk = primitives::group_key_to_jwk(&group_key)?;
280 let mut group_key_json = Zeroizing::new(
281 serde_json::to_string(&group_key_jwk)
282 .map_err(|e| CryptoError::Serialization(e.to_string()))?,
283 );
284 let (ct, iv) = primitives::encrypt(&shared, group_key_json.as_bytes())?;
285 group_key_json.zeroize();
286
287 let sender_pub_jwk = primitives::public_key_to_jwk(&kp.public_key)?;
288
289 Ok(Some(KeyExchangeMessage {
290 from_id: self.config.identity_id.clone(),
291 encrypted_key: B64.encode(&ct),
292 iv: B64.encode(&iv),
293 sender_public_key: sender_pub_jwk,
294 }))
295 }
296
297 pub async fn handle_key_exchange(&mut self, msg: &KeyExchangeMessage) -> Result<()> {
300 if self.destroyed {
301 return Err(CryptoError::SessionDestroyed);
302 }
303 if msg.from_id.is_empty() {
305 return Err(CryptoError::InvalidKey(
306 "key exchange message missing sender ID".into(),
307 ));
308 }
309
310 let nonce_key = format!("{}:{}", msg.from_id, msg.iv);
312 if !self.seen_nonces.insert(nonce_key) {
313 return Err(CryptoError::Other(format!(
314 "replayed key exchange message from {}",
315 msg.from_id
316 )));
317 }
318 if self.seen_nonces.len() > 10_000 {
320 self.seen_nonces.clear();
321 }
322
323 let sender_pub = primitives::jwk_to_public_key(&msg.sender_public_key)?;
324
325 self.verify_peer_key(&msg.from_id, &msg.sender_public_key)
327 .await?;
328
329 self.peer_public_keys
331 .insert(msg.from_id.clone(), sender_pub.clone());
332
333 self.ensure_ecdh_key_pair();
334 let kp = self.ecdh_key_pair();
335 let shared = Zeroizing::new(primitives::derive_shared_key(
336 &kp.private_key,
337 &sender_pub,
338 None,
339 )?);
340
341 let ct = B64
342 .decode(&msg.encrypted_key)
343 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64: {e}")))?;
344 let iv = B64
345 .decode(&msg.iv)
346 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64: {e}")))?;
347
348 let decrypted = primitives::decrypt(&shared, &ct, &iv)?;
349 let mut key_json = Zeroizing::new(
350 String::from_utf8(decrypted)
351 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid UTF-8: {e}")))?,
352 );
353 let key_jwk: serde_json::Value = serde_json::from_str(&key_json)
354 .map_err(|e| CryptoError::DecryptionFailed(format!("invalid JWK JSON: {e}")))?;
355 key_json.zeroize();
356 let mut group_key = Zeroizing::new(primitives::jwk_to_group_key(&key_jwk)?);
357
358 self.group_key = Some(group_key.to_vec());
359 group_key.zeroize();
360
361 self.config
363 .store
364 .save_group_key(
365 &self.session_id(),
366 KeyData {
367 key: key_jwk,
368 stored_at: now_ms(),
369 },
370 )
371 .await?;
372
373 Ok(())
374 }
375
376 pub async fn rotate_key(&mut self) -> Result<Vec<(String, KeyExchangeMessage)>> {
378 if self.destroyed {
379 return Err(CryptoError::SessionDestroyed);
380 }
381 if self.group_key.is_none() {
382 return Ok(vec![]);
383 }
384
385 if let Some(ref mut old_key) = self.group_key {
387 old_key.zeroize();
388 }
389 let mut new_key = Zeroizing::new(primitives::generate_group_key());
390 self.group_key = Some(new_key.to_vec());
391
392 let rotation_time = now_ms();
393 self.last_rotation = Some(rotation_time);
394 self.rotation_count += 1;
395
396 let jwk = primitives::group_key_to_jwk(&new_key)?;
397 let mut group_key_json = Zeroizing::new(
398 serde_json::to_string(&jwk).map_err(|e| CryptoError::Serialization(e.to_string()))?,
399 );
400 new_key.zeroize();
401 self.config
402 .store
403 .save_group_key(
404 &self.session_id(),
405 KeyData {
406 key: jwk,
407 stored_at: rotation_time,
408 },
409 )
410 .await?;
411
412 self.ensure_ecdh_key_pair();
414 let kp = self.ecdh_key_pair();
415 let sender_pub_jwk = primitives::public_key_to_jwk(&kp.public_key)?;
416 let mut messages = Vec::new();
417
418 for (peer_id, peer_pub) in &self.peer_public_keys {
419 if *peer_id == self.config.identity_id {
420 continue;
421 }
422 if let Ok(shared) = primitives::derive_shared_key(&kp.private_key, peer_pub, None) {
423 let mut shared = Zeroizing::new(shared);
424 if let Ok((ct, iv)) = primitives::encrypt(&shared, group_key_json.as_bytes()) {
425 messages.push((
426 peer_id.clone(),
427 KeyExchangeMessage {
428 from_id: self.config.identity_id.clone(),
429 encrypted_key: B64.encode(&ct),
430 iv: B64.encode(&iv),
431 sender_public_key: sender_pub_jwk.clone(),
432 },
433 ));
434 }
435 shared.zeroize();
436 }
437 }
438 group_key_json.zeroize();
439
440 Ok(messages)
441 }
442
443 pub fn remove_peer(&mut self, peer_id: &str) {
445 self.peer_public_keys.remove(peer_id);
446 }
447
448 pub fn should_rotate(&self) -> bool {
450 let interval = match self.config.rotation_interval {
451 Some(d) => d.max(MIN_ROTATION_INTERVAL),
452 None => return false,
453 };
454 if self.group_key.is_none() || self.destroyed {
455 return false;
456 }
457 let last = self.last_rotation.unwrap_or(0);
458 if last == 0 {
459 return false;
460 }
461 let elapsed_ms = now_ms().saturating_sub(last);
462 elapsed_ms >= interval.as_millis() as u64
463 }
464
465 pub async fn maybe_rotate(
469 &mut self,
470 ) -> Result<Option<(Vec<(String, KeyExchangeMessage)>, PublicKeyAnnouncement)>> {
471 if !self.should_rotate() {
472 return Ok(None);
473 }
474 let messages = self.rotate_key().await?;
475 let announcement = self.make_public_key_announcement()?;
476 if let Some(ref cb) = self.config.on_rotation {
477 cb();
478 }
479 Ok(Some((messages, announcement)))
480 }
481
482 pub fn rotation_count(&self) -> u64 {
484 self.rotation_count
485 }
486
487 pub fn last_rotation(&self) -> Option<u64> {
489 self.last_rotation
490 }
491
492 pub fn destroy(&mut self) {
494 self.destroyed = true;
495 if let Some(ref mut key) = self.group_key {
496 key.zeroize();
497 }
498 self.group_key = None;
499 self.ecdh_key_pair = None;
501 self.peer_public_keys.clear();
502 self.seen_nonces.clear();
503 }
504
505 fn session_id(&self) -> String {
508 self.config.base_path.clone()
509 }
510
511 fn ensure_ecdh_key_pair(&mut self) {
514 if self.ecdh_key_pair.is_none() {
515 self.ecdh_key_pair = Some(primitives::generate_ecdh_key_pair());
516 }
517 }
518
519 fn ecdh_key_pair(&self) -> &ECDHKeyPair {
521 self.ecdh_key_pair.as_ref().unwrap()
522 }
523
524 fn make_public_key_announcement(&mut self) -> Result<PublicKeyAnnouncement> {
525 self.ensure_ecdh_key_pair();
526 let kp = self.ecdh_key_pair();
527 let jwk = primitives::public_key_to_jwk(&kp.public_key)?;
528 Ok(PublicKeyAnnouncement {
529 public_key: jwk,
530 timestamp: now_ms(),
531 })
532 }
533
534 async fn verify_peer_key(
539 &self,
540 peer_id: &str,
541 public_key_jwk: &serde_json::Value,
542 ) -> Result<()> {
543 let fp = primitives::fingerprint_jwk(public_key_jwk);
544 let record_id = format!("{}:{}", self.config.base_path, peer_id);
545
546 let stored = self.config.store.load_tofu_record(&record_id).await?;
547
548 match stored {
549 None => {
550 self.config
552 .store
553 .save_tofu_record(
554 &record_id,
555 TofuRecord {
556 fingerprint: fp,
557 first_seen: now_ms(),
558 },
559 )
560 .await?;
561 }
562 Some(record) => {
563 if !primitives::constant_time_eq(record.fingerprint.as_bytes(), fp.as_bytes()) {
564 let accepted = self
566 .config
567 .on_key_change
568 .as_ref()
569 .map(|cb| cb(peer_id, &record.fingerprint, &fp))
570 .unwrap_or(false);
571 if !accepted {
572 return Err(CryptoError::TofuViolation(peer_id.to_string()));
573 }
574 self.config
577 .store
578 .save_tofu_record(
579 &record_id,
580 TofuRecord {
581 fingerprint: fp,
582 first_seen: record.first_seen,
583 },
584 )
585 .await?;
586 }
587 }
588 }
589
590 Ok(())
591 }
592}
593
594impl Drop for E2ESession {
595 fn drop(&mut self) {
596 if !self.destroyed {
597 self.destroy();
598 }
599 }
600}
601
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
}
608
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryKeyStore;

    /// Baseline config for `identity` in the shared test room; every optional
    /// feature is switched off.
    fn config_for(identity: &str, store: Arc<dyn KeyStore>) -> E2ESessionConfig {
        E2ESessionConfig {
            identity_id: identity.to_string(),
            base_path: "/test/room/1".to_string(),
            store,
            on_key_change: None,
            password_hash: None,
            rotation_interval: None,
            on_rotation: None,
            max_announcement_age: None,
        }
    }

    /// Baseline config for the default identity, "alice".
    fn test_config(store: Arc<dyn KeyStore>) -> E2ESessionConfig {
        config_for("alice", store)
    }

    /// Fresh in-memory key store.
    fn memory_store() -> Arc<MemoryKeyStore> {
        Arc::new(MemoryKeyStore::new())
    }

    /// Session that has been started but holds no freshly generated key.
    async fn started_session(config: E2ESessionConfig) -> E2ESession {
        let mut session = E2ESession::new(config);
        session.start().await.unwrap();
        session
    }

    /// Session that has been started and had encryption enabled.
    async fn encrypted_session(config: E2ESessionConfig) -> E2ESession {
        let mut session = started_session(config).await;
        session.enable_encryption().await.unwrap();
        session
    }

    /// Announcement for a newly generated peer key pair with the given timestamp.
    fn fresh_announcement(timestamp: u64) -> PublicKeyAnnouncement {
        let kp = primitives::generate_ecdh_key_pair();
        PublicKeyAnnouncement {
            public_key: primitives::public_key_to_jwk(&kp.public_key).unwrap(),
            timestamp,
        }
    }

    /// Syntactically valid but meaningless key-exchange message from `from_id`.
    fn bogus_key_exchange(from_id: &str) -> KeyExchangeMessage {
        KeyExchangeMessage {
            from_id: from_id.to_string(),
            encrypted_key: "AAAA".to_string(),
            iv: "BBBB".to_string(),
            sender_public_key: serde_json::json!({}),
        }
    }

    #[tokio::test]
    async fn session_starts_without_key() {
        let session = started_session(test_config(memory_store())).await;
        assert!(!session.encrypted());
    }

    #[tokio::test]
    async fn enable_encryption_creates_key() {
        let session = encrypted_session(test_config(memory_store())).await;
        assert!(session.encrypted());
    }

    #[tokio::test]
    async fn encrypt_decrypt_round_trip() {
        let mut session = encrypted_session(test_config(memory_store())).await;

        let envelope = session.encrypt("Hello, world!").unwrap();
        assert_eq!(envelope._e2e, 1);
        assert_eq!(envelope.v, 1);

        let decrypted = session.decrypt(&envelope).await.unwrap();
        assert_eq!(decrypted, "Hello, world!");
    }

    #[tokio::test]
    async fn persists_and_loads_key() {
        let store = memory_store();

        let session1 = encrypted_session(test_config(store.clone())).await;
        let envelope = session1.encrypt("hello").unwrap();

        let mut session2 = started_session(test_config(store)).await;
        assert!(session2.encrypted());

        let decrypted = session2.decrypt(&envelope).await.unwrap();
        assert_eq!(decrypted, "hello");
    }

    #[tokio::test]
    async fn key_exchange_between_peers() {
        let mut alice = encrypted_session(config_for("alice", memory_store())).await;

        let mut bob = started_session(config_for("bob", memory_store())).await;
        let bob_announcement = bob.request_group_key().unwrap().unwrap();

        let keyex = alice
            .handle_peer_pubkey("bob", &bob_announcement)
            .await
            .unwrap();
        assert!(keyex.is_some());

        bob.handle_key_exchange(&keyex.unwrap()).await.unwrap();
        assert!(bob.encrypted());

        let envelope = alice.encrypt("secret message").unwrap();
        let decrypted = bob.decrypt(&envelope).await.unwrap();
        assert_eq!(decrypted, "secret message");
    }

    #[tokio::test]
    async fn rotate_key_invalidates_old_messages() {
        let mut session = encrypted_session(test_config(memory_store())).await;

        let old_envelope = session.encrypt("before rotation").unwrap();
        session.rotate_key().await.unwrap();

        let result = session.decrypt(&old_envelope).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn tofu_detects_key_change_and_accepts() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let changed = Arc::new(AtomicBool::new(false));
        let changed_clone = changed.clone();

        let mut session = encrypted_session(E2ESessionConfig {
            on_key_change: Some(Arc::new(move |_peer, _old, _new| {
                changed_clone.store(true, Ordering::SeqCst);
                true
            })),
            ..test_config(memory_store())
        })
        .await;

        // First key for "bob": recorded silently.
        let ann1 = fresh_announcement(now_ms());
        session.handle_peer_pubkey("bob", &ann1).await.unwrap();
        assert!(!changed.load(Ordering::SeqCst));

        // Different key for "bob": callback fires and accepts.
        let ann2 = fresh_announcement(now_ms());
        session.handle_peer_pubkey("bob", &ann2).await.unwrap();
        assert!(changed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn tofu_rejects_key_change_without_callback() {
        let mut session = encrypted_session(test_config(memory_store())).await;

        let ann1 = fresh_announcement(now_ms());
        session.handle_peer_pubkey("bob", &ann1).await.unwrap();

        let ann2 = fresh_announcement(now_ms());
        let result = session.handle_peer_pubkey("bob", &ann2).await;
        assert!(matches!(result, Err(CryptoError::TofuViolation(_))));
    }

    #[tokio::test]
    async fn tofu_rejects_key_change_when_callback_returns_false() {
        let mut session = encrypted_session(E2ESessionConfig {
            on_key_change: Some(Arc::new(|_peer, _old, _new| false)),
            ..test_config(memory_store())
        })
        .await;

        let ann1 = fresh_announcement(now_ms());
        session.handle_peer_pubkey("bob", &ann1).await.unwrap();

        let ann2 = fresh_announcement(now_ms());
        let result = session.handle_peer_pubkey("bob", &ann2).await;
        assert!(matches!(result, Err(CryptoError::TofuViolation(_))));
    }

    #[tokio::test]
    async fn tofu_stores_records_without_callback() {
        let store = memory_store();
        let mut session = encrypted_session(test_config(store.clone())).await;

        let ann = fresh_announcement(now_ms());
        session.handle_peer_pubkey("bob", &ann).await.unwrap();

        let record = store.load_tofu_record("/test/room/1:bob").await.unwrap();
        assert!(record.is_some());
    }

    #[tokio::test]
    async fn empty_from_id_rejected() {
        let mut alice = started_session(config_for("alice", memory_store())).await;

        let msg = bogus_key_exchange("");
        let result = alice.handle_key_exchange(&msg).await;
        assert!(matches!(result, Err(CryptoError::InvalidKey(_))));
    }

    #[tokio::test]
    async fn encrypt_after_destroy_fails() {
        let mut session = encrypted_session(test_config(memory_store())).await;
        session.destroy();

        let result = session.encrypt("test");
        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
    }

    #[tokio::test]
    async fn decrypt_after_destroy_fails() {
        let mut session = encrypted_session(test_config(memory_store())).await;
        let envelope = session.encrypt("test").unwrap();
        session.destroy();

        let result = session.decrypt(&envelope).await;
        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
    }

    #[tokio::test]
    async fn handle_peer_pubkey_after_destroy_fails() {
        let mut session = encrypted_session(test_config(memory_store())).await;
        session.destroy();

        let ann = fresh_announcement(now_ms());
        let result = session.handle_peer_pubkey("bob", &ann).await;
        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
    }

    #[tokio::test]
    async fn handle_key_exchange_after_destroy_fails() {
        let mut session = started_session(test_config(memory_store())).await;
        session.destroy();

        let msg = bogus_key_exchange("bob");
        let result = session.handle_key_exchange(&msg).await;
        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
    }

    #[tokio::test]
    async fn decrypt_rejects_unknown_envelope_version() {
        let mut session = encrypted_session(test_config(memory_store())).await;

        let envelope = E2EEnvelope {
            _e2e: 1,
            ct: "AAAA".to_string(),
            iv: "BBBB".to_string(),
            v: 2,
        };
        let result = session.decrypt(&envelope).await;
        assert!(matches!(result, Err(CryptoError::DecryptionFailed(_))));
    }

    #[tokio::test]
    async fn should_rotate_returns_false_without_interval() {
        let session = encrypted_session(test_config(memory_store())).await;
        assert!(!session.should_rotate());
    }

    #[tokio::test]
    async fn rotation_tracks_count_and_timestamp() {
        let mut session = encrypted_session(test_config(memory_store())).await;

        assert_eq!(session.rotation_count(), 0);
        assert!(session.last_rotation().is_some());

        session.rotate_key().await.unwrap();
        assert_eq!(session.rotation_count(), 1);

        session.rotate_key().await.unwrap();
        assert_eq!(session.rotation_count(), 2);
    }

    #[tokio::test]
    async fn maybe_rotate_triggers_when_due() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let rotation_cb_count = Arc::new(AtomicU32::new(0));
        let cb_clone = rotation_cb_count.clone();

        let mut session = encrypted_session(E2ESessionConfig {
            rotation_interval: Some(Duration::from_secs(60)),
            on_rotation: Some(Arc::new(move || {
                cb_clone.fetch_add(1, Ordering::SeqCst);
            })),
            ..test_config(memory_store())
        })
        .await;

        // The key was just created, so nothing is due yet.
        let result = session.maybe_rotate().await.unwrap();
        assert!(result.is_none());

        // Pretend the last rotation happened two minutes ago.
        session.last_rotation = Some(now_ms() - 120_000);
        let result = session.maybe_rotate().await.unwrap();
        assert!(result.is_some());
        assert_eq!(rotation_cb_count.load(Ordering::SeqCst), 1);
        assert_eq!(session.rotation_count(), 1);
    }

    #[tokio::test]
    async fn timestamp_validation_rejects_old_announcement() {
        let mut session = encrypted_session(E2ESessionConfig {
            max_announcement_age: Some(Duration::from_secs(300)),
            ..test_config(memory_store())
        })
        .await;

        // Ten minutes old against a five-minute limit.
        let old_announcement = fresh_announcement(now_ms() - 600_000);
        let result = session.handle_peer_pubkey("bob", &old_announcement).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too old"));
    }

    #[tokio::test]
    async fn timestamp_validation_rejects_future_announcement() {
        let mut session = encrypted_session(E2ESessionConfig {
            max_announcement_age: Some(Duration::from_secs(300)),
            ..test_config(memory_store())
        })
        .await;

        // One minute ahead against a thirty-second tolerance.
        let future_announcement = fresh_announcement(now_ms() + 60_000);
        let result = session
            .handle_peer_pubkey("bob", &future_announcement)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("future"));
    }

    #[tokio::test]
    async fn replay_protection_rejects_duplicate_key_exchange() {
        let mut alice = encrypted_session(config_for("alice", memory_store())).await;

        let mut bob = started_session(config_for("bob", memory_store())).await;
        let bob_announcement = bob.request_group_key().unwrap().unwrap();

        let keyex = alice
            .handle_peer_pubkey("bob", &bob_announcement)
            .await
            .unwrap()
            .unwrap();

        // First delivery is accepted.
        bob.handle_key_exchange(&keyex).await.unwrap();

        // Delivering the identical message again must be rejected.
        let result = bob.handle_key_exchange(&keyex).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("replayed"));
    }

    #[tokio::test]
    async fn persisted_key_restores_last_rotation() {
        let store = memory_store();

        let session1 = encrypted_session(test_config(store.clone())).await;
        let rotation_ts = session1.last_rotation().unwrap();
        assert!(rotation_ts > 0);

        let session2 = started_session(test_config(store)).await;
        assert_eq!(session2.last_rotation(), Some(rotation_ts));
    }
}