Skip to main content

clasp_crypto/
protocol.rs

1//! E2E encryption session — manages key exchange for one group/room/channel.
2
3use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use base64::{engine::general_purpose::STANDARD as B64, Engine};
8use zeroize::{Zeroize, Zeroizing};
9
10use crate::error::{CryptoError, Result};
11use crate::primitives;
12use crate::storage::KeyStore;
13use crate::types::{
14    E2EEnvelope, ECDHKeyPair, KeyData, KeyExchangeMessage, PublicKeyAnnouncement, TofuRecord,
15};
16
/// Callback invoked when a peer's key fingerprint differs from the one
/// recorded on first use.
///
/// Arguments are `(peer_id, old_fingerprint, new_fingerprint)`. Return `true`
/// to accept the new key (its fingerprint is then stored in place of the old
/// one), or `false` to reject it. When no callback is configured at all, key
/// changes are rejected.
pub type OnKeyChange = Arc<dyn Fn(&str, &str, &str) -> bool + Sync + Send>;
21
/// Floor applied to any configured rotation interval: one minute.
const MIN_ROTATION_INTERVAL: Duration = Duration::from_millis(60_000);

/// How far ahead of the local clock a peer's key announcement timestamp may
/// be before the announcement is rejected: thirty seconds.
const DEFAULT_MAX_ANNOUNCEMENT_FUTURE: Duration = Duration::from_millis(30_000);
27
28/// Configuration for an E2E session.
29pub struct E2ESessionConfig {
30    pub identity_id: String,
31    pub base_path: String,
32    pub store: Arc<dyn KeyStore>,
33    pub on_key_change: Option<OnKeyChange>,
34    pub password_hash: Option<String>,
35    /// Automatic key rotation interval. If set, `maybe_rotate()` will
36    /// trigger rotation when this duration has elapsed since the last rotation.
37    /// Minimum enforced: 60 seconds.
38    pub rotation_interval: Option<Duration>,
39    /// Called after automatic key rotation completes.
40    pub on_rotation: Option<Arc<dyn Fn() + Send + Sync>>,
41    /// Maximum age of a peer's public key announcement before rejection.
42    /// Default: 5 minutes. Set to `None` to disable timestamp validation.
43    pub max_announcement_age: Option<Duration>,
44}
45
46/// E2E encryption session state machine.
47///
48/// Manages ECDH key exchange, group key distribution, TOFU verification,
49/// and encrypt/decrypt operations. This is the protocol layer that works
50/// with raw bytes and messages — the CryptoClient (behind `client` feature)
51/// wires it to a CLASP client.
52pub struct E2ESession {
53    config: E2ESessionConfig,
54    group_key: Option<Vec<u8>>,
55    ecdh_key_pair: Option<ECDHKeyPair>,
56    peer_public_keys: HashMap<String, Vec<u8>>,
57    started: bool,
58    destroyed: bool,
59    /// Timestamp of the last key rotation (Unix ms).
60    last_rotation: Option<u64>,
61    /// Number of key rotations performed in this session.
62    rotation_count: u64,
63    /// Set of seen nonces ("{from_id}:{iv}") for replay protection.
64    seen_nonces: HashSet<String>,
65}
66
67impl E2ESession {
68    pub fn new(config: E2ESessionConfig) -> Self {
69        Self {
70            config,
71            group_key: None,
72            ecdh_key_pair: None,
73            peer_public_keys: HashMap::new(),
74            started: false,
75            destroyed: false,
76            last_rotation: None,
77            rotation_count: 0,
78            seen_nonces: HashSet::new(),
79        }
80    }
81
82    /// Whether this session has an active group key.
83    pub fn encrypted(&self) -> bool {
84        self.group_key.is_some()
85    }
86
87    /// The base path for this session's E2E subpaths.
88    pub fn base_path(&self) -> &str {
89        &self.config.base_path
90    }
91
92    /// Start the session: attempt to load a persisted group key.
93    pub async fn start(&mut self) -> Result<()> {
94        if self.destroyed {
95            return Err(CryptoError::SessionDestroyed);
96        }
97        if self.started {
98            return Ok(());
99        }
100        self.started = true;
101
102        // Try loading persisted key (stored as JWK)
103        let session_id = self.session_id();
104        if let Some(data) = self.config.store.load_group_key(&session_id).await? {
105            match primitives::jwk_to_group_key(&data.key) {
106                Ok(key) => {
107                    self.group_key = Some(key);
108                    // Restore rotation timestamp from persisted key
109                    self.last_rotation = Some(data.stored_at);
110                }
111                Err(_) => {
112                    self.config.store.delete_group_key(&session_id).await?;
113                }
114            }
115        }
116
117        Ok(())
118    }
119
120    /// Enable encryption: generate a new group key.
121    /// Returns a PublicKeyAnnouncement to be published via CLASP.
122    pub async fn enable_encryption(&mut self) -> Result<PublicKeyAnnouncement> {
123        if self.destroyed {
124            return Err(CryptoError::SessionDestroyed);
125        }
126
127        let mut key = Zeroizing::new(primitives::generate_group_key());
128        self.group_key = Some(key.to_vec());
129
130        let creation_time = now_ms();
131        self.last_rotation = Some(creation_time);
132
133        // Persist as JWK for JS interop
134        let jwk = primitives::group_key_to_jwk(&key)?;
135        key.zeroize();
136        self.config
137            .store
138            .save_group_key(
139                &self.session_id(),
140                KeyData {
141                    key: jwk,
142                    stored_at: creation_time,
143                },
144            )
145            .await?;
146
147        self.make_public_key_announcement()
148    }
149
150    /// Create a public key announcement (for requestGroupKey).
151    pub fn request_group_key(&mut self) -> Result<Option<PublicKeyAnnouncement>> {
152        if self.destroyed {
153            return Err(CryptoError::SessionDestroyed);
154        }
155        if self.group_key.is_some() {
156            return Ok(None);
157        }
158        self.make_public_key_announcement().map(Some)
159    }
160
161    /// Encrypt a string value into an E2EEnvelope.
162    pub fn encrypt(&self, value: &str) -> Result<E2EEnvelope> {
163        if self.destroyed {
164            return Err(CryptoError::SessionDestroyed);
165        }
166        let key = self.group_key.as_ref().ok_or(CryptoError::NoGroupKey)?;
167        let plaintext = value.as_bytes();
168        let (ciphertext, iv) = primitives::encrypt(key, plaintext)?;
169        Ok(E2EEnvelope {
170            _e2e: 1,
171            ct: B64.encode(&ciphertext),
172            iv: B64.encode(&iv),
173            v: 1,
174        })
175    }
176
177    /// Decrypt an E2EEnvelope back to a string.
178    pub async fn decrypt(&mut self, envelope: &E2EEnvelope) -> Result<String> {
179        if self.destroyed {
180            return Err(CryptoError::SessionDestroyed);
181        }
182        if envelope._e2e != 1 {
183            return Err(CryptoError::DecryptionFailed("invalid E2E marker".into()));
184        }
185        if envelope.v != 1 {
186            return Err(CryptoError::DecryptionFailed(
187                "unsupported envelope version".into(),
188            ));
189        }
190        let key = Zeroizing::new(match &self.group_key {
191            Some(k) => k.clone(),
192            None => {
193                let session_id = self.session_id();
194                match self.config.store.load_group_key(&session_id).await? {
195                    Some(data) => {
196                        let k = primitives::jwk_to_group_key(&data.key)?;
197                        self.group_key = Some(k.clone());
198                        k
199                    }
200                    None => return Err(CryptoError::NoGroupKey),
201                }
202            }
203        });
204
205        let ciphertext = B64
206            .decode(&envelope.ct)
207            .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64 ct: {e}")))?;
208        let iv = B64
209            .decode(&envelope.iv)
210            .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64 iv: {e}")))?;
211        let plaintext = primitives::decrypt(&key, &ciphertext, &iv)?;
212        String::from_utf8(plaintext)
213            .map_err(|e| CryptoError::DecryptionFailed(format!("invalid UTF-8: {e}")))
214    }
215
216    /// Handle a peer's public key announcement.
217    /// Returns a KeyExchangeMessage if we have the group key and should distribute it.
218    ///
219    /// **Password-gated sessions**: If `password_hash` is set, the caller must
220    /// verify the peer's password proof *before* calling this method. This method
221    /// does not enforce password gating — it is the caller's responsibility.
222    pub async fn handle_peer_pubkey(
223        &mut self,
224        peer_id: &str,
225        announcement: &PublicKeyAnnouncement,
226    ) -> Result<Option<KeyExchangeMessage>> {
227        if self.destroyed {
228            return Err(CryptoError::SessionDestroyed);
229        }
230        if peer_id == self.config.identity_id {
231            return Ok(None);
232        }
233
234        // Timestamp validation: reject announcements that are too old or too far in the future
235        if let Some(max_age) = self.config.max_announcement_age {
236            let now = now_ms();
237            let ts = announcement.timestamp;
238            let max_age_ms = max_age.as_millis() as u64;
239            let max_future_ms = DEFAULT_MAX_ANNOUNCEMENT_FUTURE.as_millis() as u64;
240            if ts + max_age_ms < now {
241                return Err(CryptoError::Other(format!(
242                    "announcement from {peer_id} is too old ({} ms)",
243                    now.saturating_sub(ts)
244                )));
245            }
246            if ts > now + max_future_ms {
247                return Err(CryptoError::Other(format!(
248                    "announcement from {peer_id} is too far in the future ({} ms)",
249                    ts.saturating_sub(now)
250                )));
251            }
252        }
253
254        // Convert JWK to SEC1 for internal crypto operations
255        let peer_pub_bytes = primitives::jwk_to_public_key(&announcement.public_key)?;
256
257        // TOFU verification (uses JWK fingerprint for JS interop)
258        self.verify_peer_key(peer_id, &announcement.public_key)
259            .await?;
260
261        // Cache SEC1 bytes for crypto
262        self.peer_public_keys
263            .insert(peer_id.to_string(), peer_pub_bytes.clone());
264
265        // Only distribute if we have the group key
266        let group_key = Zeroizing::new(match &self.group_key {
267            Some(k) => k.clone(),
268            None => return Ok(None),
269        });
270
271        // Derive shared key and encrypt the group key as JWK JSON (JS interop)
272        self.ensure_ecdh_key_pair();
273        let kp = self.ecdh_key_pair();
274        let shared = Zeroizing::new(primitives::derive_shared_key(
275            &kp.private_key,
276            &peer_pub_bytes,
277            None,
278        )?);
279        let group_key_jwk = primitives::group_key_to_jwk(&group_key)?;
280        let mut group_key_json = Zeroizing::new(
281            serde_json::to_string(&group_key_jwk)
282                .map_err(|e| CryptoError::Serialization(e.to_string()))?,
283        );
284        let (ct, iv) = primitives::encrypt(&shared, group_key_json.as_bytes())?;
285        group_key_json.zeroize();
286
287        let sender_pub_jwk = primitives::public_key_to_jwk(&kp.public_key)?;
288
289        Ok(Some(KeyExchangeMessage {
290            from_id: self.config.identity_id.clone(),
291            encrypted_key: B64.encode(&ct),
292            iv: B64.encode(&iv),
293            sender_public_key: sender_pub_jwk,
294        }))
295    }
296
297    /// Handle a key exchange message sent to us.
298    /// Decrypts and stores the group key.
299    pub async fn handle_key_exchange(&mut self, msg: &KeyExchangeMessage) -> Result<()> {
300        if self.destroyed {
301            return Err(CryptoError::SessionDestroyed);
302        }
303        // Reject empty sender ID — prevents TOFU bypass
304        if msg.from_id.is_empty() {
305            return Err(CryptoError::InvalidKey(
306                "key exchange message missing sender ID".into(),
307            ));
308        }
309
310        // Replay protection: reject messages with previously seen nonces
311        let nonce_key = format!("{}:{}", msg.from_id, msg.iv);
312        if !self.seen_nonces.insert(nonce_key) {
313            return Err(CryptoError::Other(format!(
314                "replayed key exchange message from {}",
315                msg.from_id
316            )));
317        }
318        // Cap nonce set size to prevent unbounded growth
319        if self.seen_nonces.len() > 10_000 {
320            self.seen_nonces.clear();
321        }
322
323        let sender_pub = primitives::jwk_to_public_key(&msg.sender_public_key)?;
324
325        // TOFU verify sender (uses JWK fingerprint)
326        self.verify_peer_key(&msg.from_id, &msg.sender_public_key)
327            .await?;
328
329        // Cache sender's public key for future key rotations
330        self.peer_public_keys
331            .insert(msg.from_id.clone(), sender_pub.clone());
332
333        self.ensure_ecdh_key_pair();
334        let kp = self.ecdh_key_pair();
335        let shared = Zeroizing::new(primitives::derive_shared_key(
336            &kp.private_key,
337            &sender_pub,
338            None,
339        )?);
340
341        let ct = B64
342            .decode(&msg.encrypted_key)
343            .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64: {e}")))?;
344        let iv = B64
345            .decode(&msg.iv)
346            .map_err(|e| CryptoError::DecryptionFailed(format!("invalid base64: {e}")))?;
347
348        let decrypted = primitives::decrypt(&shared, &ct, &iv)?;
349        let mut key_json = Zeroizing::new(
350            String::from_utf8(decrypted)
351                .map_err(|e| CryptoError::DecryptionFailed(format!("invalid UTF-8: {e}")))?,
352        );
353        let key_jwk: serde_json::Value = serde_json::from_str(&key_json)
354            .map_err(|e| CryptoError::DecryptionFailed(format!("invalid JWK JSON: {e}")))?;
355        key_json.zeroize();
356        let mut group_key = Zeroizing::new(primitives::jwk_to_group_key(&key_jwk)?);
357
358        self.group_key = Some(group_key.to_vec());
359        group_key.zeroize();
360
361        // Persist as JWK
362        self.config
363            .store
364            .save_group_key(
365                &self.session_id(),
366                KeyData {
367                    key: key_jwk,
368                    stored_at: now_ms(),
369                },
370            )
371            .await?;
372
373        Ok(())
374    }
375
376    /// Rotate the group key. Returns KeyExchangeMessages for all cached peers.
377    pub async fn rotate_key(&mut self) -> Result<Vec<(String, KeyExchangeMessage)>> {
378        if self.destroyed {
379            return Err(CryptoError::SessionDestroyed);
380        }
381        if self.group_key.is_none() {
382            return Ok(vec![]);
383        }
384
385        // Zeroize old key before replacing
386        if let Some(ref mut old_key) = self.group_key {
387            old_key.zeroize();
388        }
389        let mut new_key = Zeroizing::new(primitives::generate_group_key());
390        self.group_key = Some(new_key.to_vec());
391
392        let rotation_time = now_ms();
393        self.last_rotation = Some(rotation_time);
394        self.rotation_count += 1;
395
396        let jwk = primitives::group_key_to_jwk(&new_key)?;
397        let mut group_key_json = Zeroizing::new(
398            serde_json::to_string(&jwk).map_err(|e| CryptoError::Serialization(e.to_string()))?,
399        );
400        new_key.zeroize();
401        self.config
402            .store
403            .save_group_key(
404                &self.session_id(),
405                KeyData {
406                    key: jwk,
407                    stored_at: rotation_time,
408                },
409            )
410            .await?;
411
412        // Distribute to all cached peers
413        self.ensure_ecdh_key_pair();
414        let kp = self.ecdh_key_pair();
415        let sender_pub_jwk = primitives::public_key_to_jwk(&kp.public_key)?;
416        let mut messages = Vec::new();
417
418        for (peer_id, peer_pub) in &self.peer_public_keys {
419            if *peer_id == self.config.identity_id {
420                continue;
421            }
422            if let Ok(shared) = primitives::derive_shared_key(&kp.private_key, peer_pub, None) {
423                let mut shared = Zeroizing::new(shared);
424                if let Ok((ct, iv)) = primitives::encrypt(&shared, group_key_json.as_bytes()) {
425                    messages.push((
426                        peer_id.clone(),
427                        KeyExchangeMessage {
428                            from_id: self.config.identity_id.clone(),
429                            encrypted_key: B64.encode(&ct),
430                            iv: B64.encode(&iv),
431                            sender_public_key: sender_pub_jwk.clone(),
432                        },
433                    ));
434                }
435                shared.zeroize();
436            }
437        }
438        group_key_json.zeroize();
439
440        Ok(messages)
441    }
442
443    /// Remove a peer's cached public key.
444    pub fn remove_peer(&mut self, peer_id: &str) {
445        self.peer_public_keys.remove(peer_id);
446    }
447
448    /// Check whether automatic rotation is due.
449    pub fn should_rotate(&self) -> bool {
450        let interval = match self.config.rotation_interval {
451            Some(d) => d.max(MIN_ROTATION_INTERVAL),
452            None => return false,
453        };
454        if self.group_key.is_none() || self.destroyed {
455            return false;
456        }
457        let last = self.last_rotation.unwrap_or(0);
458        if last == 0 {
459            return false;
460        }
461        let elapsed_ms = now_ms().saturating_sub(last);
462        elapsed_ms >= interval.as_millis() as u64
463    }
464
465    /// Rotate the key if the rotation interval has elapsed.
466    /// Returns any key exchange messages to distribute, plus a new
467    /// `PublicKeyAnnouncement` so new peers can request the fresh key.
468    pub async fn maybe_rotate(
469        &mut self,
470    ) -> Result<Option<(Vec<(String, KeyExchangeMessage)>, PublicKeyAnnouncement)>> {
471        if !self.should_rotate() {
472            return Ok(None);
473        }
474        let messages = self.rotate_key().await?;
475        let announcement = self.make_public_key_announcement()?;
476        if let Some(ref cb) = self.config.on_rotation {
477            cb();
478        }
479        Ok(Some((messages, announcement)))
480    }
481
482    /// Number of key rotations performed in this session.
483    pub fn rotation_count(&self) -> u64 {
484        self.rotation_count
485    }
486
487    /// Timestamp of the last key rotation (Unix ms), if any.
488    pub fn last_rotation(&self) -> Option<u64> {
489        self.last_rotation
490    }
491
492    /// Destroy the session, zeroing all key material.
493    pub fn destroy(&mut self) {
494        self.destroyed = true;
495        if let Some(ref mut key) = self.group_key {
496            key.zeroize();
497        }
498        self.group_key = None;
499        // ECDHKeyPair implements ZeroizeOnDrop, so dropping clears it
500        self.ecdh_key_pair = None;
501        self.peer_public_keys.clear();
502        self.seen_nonces.clear();
503    }
504
505    // --- Private ---
506
507    fn session_id(&self) -> String {
508        self.config.base_path.clone()
509    }
510
511    /// Ensure the ECDH key pair is initialized. Call this before accessing
512    /// `self.ecdh_key_pair` to avoid unnecessary cloning of private key material.
513    fn ensure_ecdh_key_pair(&mut self) {
514        if self.ecdh_key_pair.is_none() {
515            self.ecdh_key_pair = Some(primitives::generate_ecdh_key_pair());
516        }
517    }
518
519    /// Access the ECDH key pair (must call ensure_ecdh_key_pair first).
520    fn ecdh_key_pair(&self) -> &ECDHKeyPair {
521        self.ecdh_key_pair.as_ref().unwrap()
522    }
523
524    fn make_public_key_announcement(&mut self) -> Result<PublicKeyAnnouncement> {
525        self.ensure_ecdh_key_pair();
526        let kp = self.ecdh_key_pair();
527        let jwk = primitives::public_key_to_jwk(&kp.public_key)?;
528        Ok(PublicKeyAnnouncement {
529            public_key: jwk,
530            timestamp: now_ms(),
531        })
532    }
533
534    /// TOFU verification using JWK fingerprint (matches JS implementation).
535    /// Always stores the record on first use. On key change, calls the
536    /// `on_key_change` callback which must return `true` to accept.
537    /// If no callback is set, key changes are rejected.
538    async fn verify_peer_key(
539        &self,
540        peer_id: &str,
541        public_key_jwk: &serde_json::Value,
542    ) -> Result<()> {
543        let fp = primitives::fingerprint_jwk(public_key_jwk);
544        let record_id = format!("{}:{}", self.config.base_path, peer_id);
545
546        let stored = self.config.store.load_tofu_record(&record_id).await?;
547
548        match stored {
549            None => {
550                // First time — trust on first use
551                self.config
552                    .store
553                    .save_tofu_record(
554                        &record_id,
555                        TofuRecord {
556                            fingerprint: fp,
557                            first_seen: now_ms(),
558                        },
559                    )
560                    .await?;
561            }
562            Some(record) => {
563                if !primitives::constant_time_eq(record.fingerprint.as_bytes(), fp.as_bytes()) {
564                    // Key changed — check if caller accepts
565                    let accepted = self
566                        .config
567                        .on_key_change
568                        .as_ref()
569                        .map(|cb| cb(peer_id, &record.fingerprint, &fp))
570                        .unwrap_or(false);
571                    if !accepted {
572                        return Err(CryptoError::TofuViolation(peer_id.to_string()));
573                    }
574                    // Update the stored record to the new fingerprint,
575                    // preserving original first_seen
576                    self.config
577                        .store
578                        .save_tofu_record(
579                            &record_id,
580                            TofuRecord {
581                                fingerprint: fp,
582                                first_seen: record.first_seen,
583                            },
584                        )
585                        .await?;
586                }
587            }
588        }
589
590        Ok(())
591    }
592}
593
594impl Drop for E2ESession {
595    fn drop(&mut self) {
596        if !self.destroyed {
597            self.destroy();
598        }
599    }
600}
601
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields 0 instead of an error.
fn now_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis() as u64
}
608
609#[cfg(test)]
610mod tests {
611    use super::*;
612    use crate::storage::MemoryKeyStore;
613
614    fn test_config(store: Arc<dyn KeyStore>) -> E2ESessionConfig {
615        E2ESessionConfig {
616            identity_id: "alice".to_string(),
617            base_path: "/test/room/1".to_string(),
618            store,
619            on_key_change: None,
620            password_hash: None,
621            rotation_interval: None,
622            on_rotation: None,
623            max_announcement_age: None,
624        }
625    }
626
627    #[tokio::test]
628    async fn session_starts_without_key() {
629        let store = Arc::new(MemoryKeyStore::new());
630        let mut session = E2ESession::new(test_config(store));
631        session.start().await.unwrap();
632        assert!(!session.encrypted());
633    }
634
635    #[tokio::test]
636    async fn enable_encryption_creates_key() {
637        let store = Arc::new(MemoryKeyStore::new());
638        let mut session = E2ESession::new(test_config(store));
639        session.start().await.unwrap();
640        session.enable_encryption().await.unwrap();
641        assert!(session.encrypted());
642    }
643
644    #[tokio::test]
645    async fn encrypt_decrypt_round_trip() {
646        let store = Arc::new(MemoryKeyStore::new());
647        let mut session = E2ESession::new(test_config(store));
648        session.start().await.unwrap();
649        session.enable_encryption().await.unwrap();
650
651        let envelope = session.encrypt("Hello, world!").unwrap();
652        assert_eq!(envelope._e2e, 1);
653        assert_eq!(envelope.v, 1);
654
655        let decrypted = session.decrypt(&envelope).await.unwrap();
656        assert_eq!(decrypted, "Hello, world!");
657    }
658
659    #[tokio::test]
660    async fn persists_and_loads_key() {
661        let store = Arc::new(MemoryKeyStore::new());
662
663        let mut session1 = E2ESession::new(test_config(store.clone()));
664        session1.start().await.unwrap();
665        session1.enable_encryption().await.unwrap();
666        let envelope = session1.encrypt("hello").unwrap();
667
668        let mut session2 = E2ESession::new(test_config(store));
669        session2.start().await.unwrap();
670        assert!(session2.encrypted());
671
672        let decrypted = session2.decrypt(&envelope).await.unwrap();
673        assert_eq!(decrypted, "hello");
674    }
675
676    #[tokio::test]
677    async fn key_exchange_between_peers() {
678        let store_a = Arc::new(MemoryKeyStore::new());
679        let store_b = Arc::new(MemoryKeyStore::new());
680
681        let mut alice = E2ESession::new(E2ESessionConfig {
682            identity_id: "alice".to_string(),
683            base_path: "/test/room/1".to_string(),
684            store: store_a,
685            on_key_change: None,
686            password_hash: None,
687            rotation_interval: None,
688            on_rotation: None,
689            max_announcement_age: None,
690        });
691        alice.start().await.unwrap();
692        let _alice_announcement = alice.enable_encryption().await.unwrap();
693
694        let mut bob = E2ESession::new(E2ESessionConfig {
695            identity_id: "bob".to_string(),
696            base_path: "/test/room/1".to_string(),
697            store: store_b,
698            on_key_change: None,
699            password_hash: None,
700            rotation_interval: None,
701            on_rotation: None,
702            max_announcement_age: None,
703        });
704        bob.start().await.unwrap();
705        let bob_announcement = bob.request_group_key().unwrap().unwrap();
706
707        let keyex = alice
708            .handle_peer_pubkey("bob", &bob_announcement)
709            .await
710            .unwrap();
711        assert!(keyex.is_some());
712
713        bob.handle_key_exchange(&keyex.unwrap()).await.unwrap();
714        assert!(bob.encrypted());
715
716        let envelope = alice.encrypt("secret message").unwrap();
717        let decrypted = bob.decrypt(&envelope).await.unwrap();
718        assert_eq!(decrypted, "secret message");
719    }
720
721    #[tokio::test]
722    async fn rotate_key_invalidates_old_messages() {
723        let store = Arc::new(MemoryKeyStore::new());
724        let mut session = E2ESession::new(test_config(store));
725        session.start().await.unwrap();
726        session.enable_encryption().await.unwrap();
727
728        let old_envelope = session.encrypt("before rotation").unwrap();
729        session.rotate_key().await.unwrap();
730
731        let result = session.decrypt(&old_envelope).await;
732        assert!(result.is_err());
733    }
734
735    #[tokio::test]
736    async fn tofu_detects_key_change_and_accepts() {
737        use std::sync::atomic::{AtomicBool, Ordering};
738
739        let store = Arc::new(MemoryKeyStore::new());
740        let changed = Arc::new(AtomicBool::new(false));
741        let changed_clone = changed.clone();
742
743        let mut session = E2ESession::new(E2ESessionConfig {
744            identity_id: "alice".to_string(),
745            base_path: "/test/room/1".to_string(),
746            store: store.clone(),
747            on_key_change: Some(Arc::new(move |_peer, _old, _new| {
748                changed_clone.store(true, Ordering::SeqCst);
749                true // accept the key change
750            })),
751            password_hash: None,
752            rotation_interval: None,
753            on_rotation: None,
754            max_announcement_age: None,
755        });
756        session.start().await.unwrap();
757        session.enable_encryption().await.unwrap();
758
759        // First key from Bob — TOFU, trusted
760        let bob_kp1 = primitives::generate_ecdh_key_pair();
761        let ann1 = PublicKeyAnnouncement {
762            public_key: primitives::public_key_to_jwk(&bob_kp1.public_key).unwrap(),
763            timestamp: now_ms(),
764        };
765        session.handle_peer_pubkey("bob", &ann1).await.unwrap();
766        assert!(!changed.load(Ordering::SeqCst));
767
768        // Different key from Bob — should trigger change, accepted by callback
769        let bob_kp2 = primitives::generate_ecdh_key_pair();
770        let ann2 = PublicKeyAnnouncement {
771            public_key: primitives::public_key_to_jwk(&bob_kp2.public_key).unwrap(),
772            timestamp: now_ms(),
773        };
774        session.handle_peer_pubkey("bob", &ann2).await.unwrap();
775        assert!(changed.load(Ordering::SeqCst));
776    }
777
778    #[tokio::test]
779    async fn tofu_rejects_key_change_without_callback() {
780        let store = Arc::new(MemoryKeyStore::new());
781
782        // No onKeyChange callback — key changes should be rejected
783        let mut session = E2ESession::new(test_config(store.clone()));
784        session.start().await.unwrap();
785        session.enable_encryption().await.unwrap();
786
787        // First key from Bob — TOFU, trusted
788        let bob_kp1 = primitives::generate_ecdh_key_pair();
789        let ann1 = PublicKeyAnnouncement {
790            public_key: primitives::public_key_to_jwk(&bob_kp1.public_key).unwrap(),
791            timestamp: now_ms(),
792        };
793        session.handle_peer_pubkey("bob", &ann1).await.unwrap();
794
795        // Different key from Bob — should be rejected (no callback)
796        let bob_kp2 = primitives::generate_ecdh_key_pair();
797        let ann2 = PublicKeyAnnouncement {
798            public_key: primitives::public_key_to_jwk(&bob_kp2.public_key).unwrap(),
799            timestamp: now_ms(),
800        };
801        let result = session.handle_peer_pubkey("bob", &ann2).await;
802        assert!(matches!(result, Err(CryptoError::TofuViolation(_))));
803    }
804
805    #[tokio::test]
806    async fn tofu_rejects_key_change_when_callback_returns_false() {
807        let store = Arc::new(MemoryKeyStore::new());
808
809        let mut session = E2ESession::new(E2ESessionConfig {
810            identity_id: "alice".to_string(),
811            base_path: "/test/room/1".to_string(),
812            store: store.clone(),
813            on_key_change: Some(Arc::new(|_peer, _old, _new| false)),
814            password_hash: None,
815            rotation_interval: None,
816            on_rotation: None,
817            max_announcement_age: None,
818        });
819        session.start().await.unwrap();
820        session.enable_encryption().await.unwrap();
821
822        let bob_kp1 = primitives::generate_ecdh_key_pair();
823        let ann1 = PublicKeyAnnouncement {
824            public_key: primitives::public_key_to_jwk(&bob_kp1.public_key).unwrap(),
825            timestamp: now_ms(),
826        };
827        session.handle_peer_pubkey("bob", &ann1).await.unwrap();
828
829        let bob_kp2 = primitives::generate_ecdh_key_pair();
830        let ann2 = PublicKeyAnnouncement {
831            public_key: primitives::public_key_to_jwk(&bob_kp2.public_key).unwrap(),
832            timestamp: now_ms(),
833        };
834        let result = session.handle_peer_pubkey("bob", &ann2).await;
835        assert!(matches!(result, Err(CryptoError::TofuViolation(_))));
836    }
837
838    #[tokio::test]
839    async fn tofu_stores_records_without_callback() {
840        let store = Arc::new(MemoryKeyStore::new());
841
842        // No onKeyChange callback
843        let mut session = E2ESession::new(test_config(store.clone()));
844        session.start().await.unwrap();
845        session.enable_encryption().await.unwrap();
846
847        let bob_kp = primitives::generate_ecdh_key_pair();
848        let ann = PublicKeyAnnouncement {
849            public_key: primitives::public_key_to_jwk(&bob_kp.public_key).unwrap(),
850            timestamp: now_ms(),
851        };
852        session.handle_peer_pubkey("bob", &ann).await.unwrap();
853
854        // TOFU record should be stored even without callback
855        let record = store.load_tofu_record("/test/room/1:bob").await.unwrap();
856        assert!(record.is_some());
857    }
858
859    #[tokio::test]
860    async fn empty_from_id_rejected() {
861        let store_a = Arc::new(MemoryKeyStore::new());
862        let store_b = Arc::new(MemoryKeyStore::new());
863
864        let mut alice = E2ESession::new(E2ESessionConfig {
865            identity_id: "alice".to_string(),
866            base_path: "/test/room/1".to_string(),
867            store: store_a,
868            on_key_change: None,
869            password_hash: None,
870            rotation_interval: None,
871            on_rotation: None,
872            max_announcement_age: None,
873        });
874        alice.start().await.unwrap();
875
876        let mut bob = E2ESession::new(E2ESessionConfig {
877            identity_id: "bob".to_string(),
878            base_path: "/test/room/1".to_string(),
879            store: store_b,
880            on_key_change: None,
881            password_hash: None,
882            rotation_interval: None,
883            on_rotation: None,
884            max_announcement_age: None,
885        });
886        bob.start().await.unwrap();
887        bob.enable_encryption().await.unwrap();
888        let bob_announcement = bob.request_group_key().unwrap();
889
890        // Craft a message with empty from_id
891        let msg = KeyExchangeMessage {
892            from_id: String::new(),
893            encrypted_key: "AAAA".to_string(),
894            iv: "BBBB".to_string(),
895            sender_public_key: serde_json::json!({}),
896        };
897        let result = alice.handle_key_exchange(&msg).await;
898        assert!(matches!(result, Err(CryptoError::InvalidKey(_))));
899
900        // Suppress unused variable warning
901        drop(bob_announcement);
902    }
903
904    #[tokio::test]
905    async fn encrypt_after_destroy_fails() {
906        let store = Arc::new(MemoryKeyStore::new());
907        let mut session = E2ESession::new(test_config(store));
908        session.start().await.unwrap();
909        session.enable_encryption().await.unwrap();
910        session.destroy();
911
912        let result = session.encrypt("test");
913        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
914    }
915
916    #[tokio::test]
917    async fn decrypt_after_destroy_fails() {
918        let store = Arc::new(MemoryKeyStore::new());
919        let mut session = E2ESession::new(test_config(store));
920        session.start().await.unwrap();
921        session.enable_encryption().await.unwrap();
922        let envelope = session.encrypt("test").unwrap();
923        session.destroy();
924
925        let result = session.decrypt(&envelope).await;
926        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
927    }
928
929    #[tokio::test]
930    async fn handle_peer_pubkey_after_destroy_fails() {
931        let store = Arc::new(MemoryKeyStore::new());
932        let mut session = E2ESession::new(test_config(store));
933        session.start().await.unwrap();
934        session.enable_encryption().await.unwrap();
935        session.destroy();
936
937        let bob_kp = primitives::generate_ecdh_key_pair();
938        let ann = PublicKeyAnnouncement {
939            public_key: primitives::public_key_to_jwk(&bob_kp.public_key).unwrap(),
940            timestamp: now_ms(),
941        };
942        let result = session.handle_peer_pubkey("bob", &ann).await;
943        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
944    }
945
946    #[tokio::test]
947    async fn handle_key_exchange_after_destroy_fails() {
948        let store = Arc::new(MemoryKeyStore::new());
949        let mut session = E2ESession::new(test_config(store));
950        session.start().await.unwrap();
951        session.destroy();
952
953        let msg = KeyExchangeMessage {
954            from_id: "bob".to_string(),
955            encrypted_key: "AAAA".to_string(),
956            iv: "BBBB".to_string(),
957            sender_public_key: serde_json::json!({}),
958        };
959        let result = session.handle_key_exchange(&msg).await;
960        assert!(matches!(result, Err(CryptoError::SessionDestroyed)));
961    }
962
963    #[tokio::test]
964    async fn decrypt_rejects_unknown_envelope_version() {
965        let store = Arc::new(MemoryKeyStore::new());
966        let mut session = E2ESession::new(test_config(store));
967        session.start().await.unwrap();
968        session.enable_encryption().await.unwrap();
969
970        let envelope = E2EEnvelope {
971            _e2e: 1,
972            ct: "AAAA".to_string(),
973            iv: "BBBB".to_string(),
974            v: 2,
975        };
976        let result = session.decrypt(&envelope).await;
977        assert!(matches!(result, Err(CryptoError::DecryptionFailed(_))));
978    }
979
980    #[tokio::test]
981    async fn should_rotate_returns_false_without_interval() {
982        let store = Arc::new(MemoryKeyStore::new());
983        let mut session = E2ESession::new(test_config(store));
984        session.start().await.unwrap();
985        session.enable_encryption().await.unwrap();
986        assert!(!session.should_rotate());
987    }
988
989    #[tokio::test]
990    async fn rotation_tracks_count_and_timestamp() {
991        let store = Arc::new(MemoryKeyStore::new());
992        let mut session = E2ESession::new(test_config(store));
993        session.start().await.unwrap();
994        session.enable_encryption().await.unwrap();
995
996        assert_eq!(session.rotation_count(), 0);
997        assert!(session.last_rotation().is_some()); // set by enable_encryption
998
999        session.rotate_key().await.unwrap();
1000        assert_eq!(session.rotation_count(), 1);
1001
1002        session.rotate_key().await.unwrap();
1003        assert_eq!(session.rotation_count(), 2);
1004    }
1005
1006    #[tokio::test]
1007    async fn maybe_rotate_triggers_when_due() {
1008        use std::sync::atomic::{AtomicU32, Ordering};
1009
1010        let store = Arc::new(MemoryKeyStore::new());
1011        let rotation_cb_count = Arc::new(AtomicU32::new(0));
1012        let cb_clone = rotation_cb_count.clone();
1013
1014        let mut session = E2ESession::new(E2ESessionConfig {
1015            identity_id: "alice".to_string(),
1016            base_path: "/test/room/1".to_string(),
1017            store,
1018            on_key_change: None,
1019            password_hash: None,
1020            rotation_interval: Some(Duration::from_secs(60)),
1021            on_rotation: Some(Arc::new(move || {
1022                cb_clone.fetch_add(1, Ordering::SeqCst);
1023            })),
1024            max_announcement_age: None,
1025        });
1026        session.start().await.unwrap();
1027        session.enable_encryption().await.unwrap();
1028
1029        // Not due yet (just created)
1030        let result = session.maybe_rotate().await.unwrap();
1031        assert!(result.is_none());
1032
1033        // Force last_rotation to the past
1034        session.last_rotation = Some(now_ms() - 120_000);
1035        let result = session.maybe_rotate().await.unwrap();
1036        assert!(result.is_some());
1037        assert_eq!(rotation_cb_count.load(Ordering::SeqCst), 1);
1038        assert_eq!(session.rotation_count(), 1);
1039    }
1040
1041    #[tokio::test]
1042    async fn timestamp_validation_rejects_old_announcement() {
1043        let store = Arc::new(MemoryKeyStore::new());
1044        let mut session = E2ESession::new(E2ESessionConfig {
1045            identity_id: "alice".to_string(),
1046            base_path: "/test/room/1".to_string(),
1047            store,
1048            on_key_change: None,
1049            password_hash: None,
1050            rotation_interval: None,
1051            on_rotation: None,
1052            max_announcement_age: Some(Duration::from_secs(300)),
1053        });
1054        session.start().await.unwrap();
1055        session.enable_encryption().await.unwrap();
1056
1057        let bob_kp = primitives::generate_ecdh_key_pair();
1058        let old_announcement = PublicKeyAnnouncement {
1059            public_key: primitives::public_key_to_jwk(&bob_kp.public_key).unwrap(),
1060            timestamp: now_ms() - 600_000, // 10 min ago
1061        };
1062        let result = session.handle_peer_pubkey("bob", &old_announcement).await;
1063        assert!(result.is_err());
1064        assert!(result.unwrap_err().to_string().contains("too old"));
1065    }
1066
1067    #[tokio::test]
1068    async fn timestamp_validation_rejects_future_announcement() {
1069        let store = Arc::new(MemoryKeyStore::new());
1070        let mut session = E2ESession::new(E2ESessionConfig {
1071            identity_id: "alice".to_string(),
1072            base_path: "/test/room/1".to_string(),
1073            store,
1074            on_key_change: None,
1075            password_hash: None,
1076            rotation_interval: None,
1077            on_rotation: None,
1078            max_announcement_age: Some(Duration::from_secs(300)),
1079        });
1080        session.start().await.unwrap();
1081        session.enable_encryption().await.unwrap();
1082
1083        let bob_kp = primitives::generate_ecdh_key_pair();
1084        let future_announcement = PublicKeyAnnouncement {
1085            public_key: primitives::public_key_to_jwk(&bob_kp.public_key).unwrap(),
1086            timestamp: now_ms() + 60_000, // 1 min in future
1087        };
1088        let result = session
1089            .handle_peer_pubkey("bob", &future_announcement)
1090            .await;
1091        assert!(result.is_err());
1092        assert!(result.unwrap_err().to_string().contains("future"));
1093    }
1094
1095    #[tokio::test]
1096    async fn replay_protection_rejects_duplicate_key_exchange() {
1097        let store_a = Arc::new(MemoryKeyStore::new());
1098        let store_b = Arc::new(MemoryKeyStore::new());
1099
1100        // Alice has the group key
1101        let mut alice = E2ESession::new(E2ESessionConfig {
1102            identity_id: "alice".to_string(),
1103            base_path: "/test/room/1".to_string(),
1104            store: store_a,
1105            on_key_change: None,
1106            password_hash: None,
1107            rotation_interval: None,
1108            on_rotation: None,
1109            max_announcement_age: None,
1110        });
1111        alice.start().await.unwrap();
1112        alice.enable_encryption().await.unwrap();
1113
1114        // Bob requests the key
1115        let mut bob = E2ESession::new(E2ESessionConfig {
1116            identity_id: "bob".to_string(),
1117            base_path: "/test/room/1".to_string(),
1118            store: store_b,
1119            on_key_change: None,
1120            password_hash: None,
1121            rotation_interval: None,
1122            on_rotation: None,
1123            max_announcement_age: None,
1124        });
1125        bob.start().await.unwrap();
1126        let bob_announcement = bob.request_group_key().unwrap().unwrap();
1127
1128        // Alice distributes key to Bob
1129        let keyex = alice
1130            .handle_peer_pubkey("bob", &bob_announcement)
1131            .await
1132            .unwrap()
1133            .unwrap();
1134
1135        // First receive should succeed
1136        bob.handle_key_exchange(&keyex).await.unwrap();
1137
1138        // Replay of the same message should fail
1139        let result = bob.handle_key_exchange(&keyex).await;
1140        assert!(result.is_err());
1141        assert!(result.unwrap_err().to_string().contains("replayed"));
1142    }
1143
1144    #[tokio::test]
1145    async fn persisted_key_restores_last_rotation() {
1146        let store = Arc::new(MemoryKeyStore::new());
1147
1148        let mut session1 = E2ESession::new(test_config(store.clone()));
1149        session1.start().await.unwrap();
1150        session1.enable_encryption().await.unwrap();
1151        let rotation_ts = session1.last_rotation().unwrap();
1152        assert!(rotation_ts > 0);
1153
1154        let mut session2 = E2ESession::new(test_config(store));
1155        session2.start().await.unwrap();
1156        assert_eq!(session2.last_rotation(), Some(rotation_ts));
1157    }
1158}