1use crate::{Error, Result};
2use vodozemac::olm::Account as OlmAccount;
3use vodozemac::megolm::{GroupSession, InboundGroupSession, MegolmMessage};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use zeroize::Zeroizing;
7use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
8
9const PBKDF2_ITERATIONS: u32 = 600_000; const SALT_LENGTH: usize = 32; const MAX_ONE_TIME_KEYS: usize = 100; const MAX_GROUP_SESSIONS: usize = 1000; const MIN_PASSPHRASE_LENGTH: usize = 12; fn derive_key_from_passphrase(passphrase: &str, salt: &[u8]) -> Result<[u8; 32]> {
18 use pbkdf2::pbkdf2_hmac;
19 use sha2::Sha256;
20
21 if passphrase.len() < MIN_PASSPHRASE_LENGTH {
22 return Err(Error::Crypto(format!(
23 "Passphrase must be at least {} characters",
24 MIN_PASSPHRASE_LENGTH
25 )));
26 }
27
28 let passphrase_bytes = Zeroizing::new(passphrase.as_bytes().to_vec());
29 let mut key = Zeroizing::new([0u8; 32]);
30
31 pbkdf2_hmac::<Sha256>(
32 &passphrase_bytes,
33 salt,
34 PBKDF2_ITERATIONS,
35 &mut *key,
36 );
37
38 Ok(*key)
39}
40
41fn generate_salt() -> [u8; SALT_LENGTH] {
43 use rand::RngCore;
44 let mut salt = [0u8; SALT_LENGTH];
45 rand::thread_rng().fill_bytes(&mut salt);
46 salt
47}
48
49fn validate_room_id(room_id: &str) -> Result<()> {
51 if !room_id.starts_with('!') || !room_id.contains(':') {
52 return Err(Error::Crypto(format!("Invalid room ID format: {}", room_id)));
53 }
54 if room_id.len() > 255 {
55 return Err(Error::Crypto("Room ID too long".to_string()));
56 }
57 Ok(())
58}
59
60pub struct CryptoManager {
62 account: OlmAccount,
63 group_sessions: HashMap<String, GroupSession>,
64 inbound_group_sessions: HashMap<String, InboundGroupSession>,
65}
66
67impl CryptoManager {
68 pub fn new() -> Self {
70 Self {
71 account: OlmAccount::new(),
72 group_sessions: HashMap::new(),
73 inbound_group_sessions: HashMap::new(),
74 }
75 }
76
77 pub fn identity_keys(&self) -> IdentityKeys {
79 let keys = self.account.identity_keys();
80
81 IdentityKeys {
82 curve25519: keys.curve25519.to_base64(),
83 ed25519: keys.ed25519.to_base64(),
84 }
85 }
86
87 pub fn generate_one_time_keys(&mut self, count: usize) -> Result<()> {
89 if count > MAX_ONE_TIME_KEYS {
90 return Err(Error::Crypto(format!(
91 "Cannot generate more than {} one-time keys at once",
92 MAX_ONE_TIME_KEYS
93 )));
94 }
95 self.account.generate_one_time_keys(count);
96 Ok(())
97 }
98
99 pub fn one_time_keys(&self) -> HashMap<String, String> {
101 self.account
102 .one_time_keys()
103 .iter()
104 .enumerate()
105 .map(|(idx, (_, key))| {
106 (format!("key_{}", idx), key.to_base64())
107 })
108 .collect()
109 }
110
111 pub fn mark_keys_as_published(&mut self) {
113 self.account.mark_keys_as_published();
114 }
115
116 pub fn create_group_session(&mut self, room_id: &str) -> Result<String> {
118 validate_room_id(room_id)?;
120
121 if self.group_sessions.len() >= MAX_GROUP_SESSIONS {
123 return Err(Error::Crypto(format!(
124 "Maximum group sessions limit ({}) reached. Consider cleaning up old sessions.",
125 MAX_GROUP_SESSIONS
126 )));
127 }
128
129 if self.group_sessions.contains_key(room_id) {
131 tracing::warn!("Overwriting existing group session for room: {}", room_id);
132 }
133
134 let session = GroupSession::new(Default::default());
135 let session_id = session.session_id();
136 let session_id_str = session_id.to_owned();
137 self.group_sessions.insert(room_id.to_string(), session);
138 Ok(session_id_str)
139 }
140
141 pub fn encrypt_room_message(&mut self, room_id: &str, plaintext: &str) -> Result<EncryptedMessage> {
143 let session = self.group_sessions
144 .get_mut(room_id)
145 .ok_or_else(|| Error::Crypto("No group session for room".to_string()))?;
146
147 let ciphertext = session.encrypt(plaintext.as_bytes());
148
149 Ok(EncryptedMessage {
150 algorithm: "m.megolm.v1.aes-sha2".to_string(),
151 sender_key: self.account.identity_keys().curve25519.to_base64(),
152 ciphertext: ciphertext.to_base64(),
153 session_id: session.session_id().to_owned(),
154 device_id: String::new(), })
156 }
157
158 pub fn add_inbound_group_session(&mut self, session_key_base64: &str) -> Result<String> {
160 use vodozemac::megolm::{SessionKey, SessionConfig};
161
162 let session_key = SessionKey::from_base64(session_key_base64)
164 .map_err(|e| Error::Crypto(format!("Invalid session key: {}", e)))?;
165
166 let session = InboundGroupSession::new(&session_key, SessionConfig::default());
167
168 let session_id = session.session_id().to_owned();
169 self.inbound_group_sessions.insert(session_id.clone(), session);
170 Ok(session_id)
171 }
172
173 pub fn decrypt_room_message(&mut self, session_id: &str, ciphertext: &str) -> Result<String> {
175 let session = self.inbound_group_sessions
176 .get_mut(session_id)
177 .ok_or_else(|| Error::Crypto("Unknown session".to_string()))?;
178
179 let message = MegolmMessage::from_base64(ciphertext)
180 .map_err(|e| Error::Crypto(format!("Invalid ciphertext: {}", e)))?;
181
182 let decrypted = session.decrypt(&message)
183 .map_err(|e| Error::Crypto(format!("Decryption failed: {}", e)))?;
184
185 String::from_utf8(decrypted.plaintext)
186 .map_err(|e| Error::Crypto(format!("Invalid UTF-8: {}", e)))
187 }
188
189 pub fn export_account(&self, passphrase: &str) -> Result<String> {
192 let salt = generate_salt();
194
195 let key = Zeroizing::new(derive_key_from_passphrase(passphrase, &salt)?);
197
198 let pickle_str = self.account.pickle().encrypt(&*key);
200
201 let salt_b64 = BASE64.encode(&salt);
204 Ok(format!("{}:{}", salt_b64, pickle_str))
205 }
206
207 pub fn import_account(pickle_with_salt: &str, passphrase: &str) -> Result<Self> {
209 use vodozemac::olm::AccountPickle;
210
211 let parts: Vec<&str> = pickle_with_salt.splitn(2, ':').collect();
213 if parts.len() != 2 {
214 return Err(Error::Crypto("Invalid pickle format (missing salt)".to_string()));
215 }
216
217 let salt_b64 = parts[0];
218 let pickle_str = parts[1];
219
220 let salt = BASE64.decode(salt_b64)
222 .map_err(|e| Error::Crypto(format!("Invalid salt encoding: {}", e)))?;
223
224 if salt.len() != SALT_LENGTH {
225 return Err(Error::Crypto(format!(
226 "Invalid salt length: expected {}, got {}",
227 SALT_LENGTH,
228 salt.len()
229 )));
230 }
231
232 let key = Zeroizing::new(derive_key_from_passphrase(passphrase, &salt)?);
234
235 let pickle = AccountPickle::from_encrypted(pickle_str, &*key)
237 .map_err(|e| Error::Crypto(format!("Failed to decrypt pickle: {}", e)))?;
238
239 let account = OlmAccount::from_pickle(pickle);
240
241 Ok(Self {
242 account,
243 group_sessions: HashMap::new(),
244 inbound_group_sessions: HashMap::new(),
245 })
246 }
247
248 pub fn export_group_session(&self, room_id: &str) -> Result<String> {
250 let session = self.group_sessions
251 .get(room_id)
252 .ok_or_else(|| Error::Crypto("No group session for room".to_string()))?;
253
254 Ok(session.session_key().to_base64())
256 }
257
258}
259
260impl Default for CryptoManager {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct IdentityKeys {
269 pub curve25519: String,
270 pub ed25519: String,
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct EncryptedMessage {
276 pub algorithm: String,
277 pub sender_key: String,
278 pub ciphertext: String,
279 pub session_id: String,
280 pub device_id: String,
281}
282
#[cfg(test)]
mod tests {
    // Unit tests for `CryptoManager`: key generation, group-session round trips, account
    // export/import, and input validation.
    use super::*;

    /// Creates a sender with a group session for `room_id` and a receiver that has imported
    /// it. Returns both managers plus the session ID.
    fn paired_managers(room_id: &str) -> (CryptoManager, CryptoManager, String) {
        let mut sender = CryptoManager::new();
        let mut receiver = CryptoManager::new();
        let session_id = sender.create_group_session(room_id).unwrap();
        let session_key = sender.export_group_session(room_id).unwrap();
        receiver.add_inbound_group_session(&session_key).unwrap();
        (sender, receiver, session_id)
    }

    fn is_base64_char(c: char) -> bool {
        c.is_alphanumeric() || c == '+' || c == '/' || c == '='
    }

    #[test]
    fn test_crypto_manager_creation() {
        let manager = CryptoManager::new();
        let keys = manager.identity_keys();
        assert!(!keys.curve25519.is_empty());
        assert!(!keys.ed25519.is_empty());
    }

    #[test]
    fn test_one_time_keys() {
        let mut manager = CryptoManager::new();
        manager.generate_one_time_keys(5).unwrap();
        let keys = manager.one_time_keys();
        assert!(!keys.is_empty());
        assert!(keys.len() <= 5);
    }

    #[test]
    fn test_group_session() {
        let mut manager = CryptoManager::new();
        let result = manager.create_group_session("!test:example.com");
        assert!(result.is_ok());
        let session_id = result.unwrap();
        assert!(!session_id.is_empty());
    }

    #[test]
    fn test_encrypt_decrypt() {
        let mut sender = CryptoManager::new();
        let mut receiver = CryptoManager::new();

        let room_id = "!test:example.com";
        let session_id = sender.create_group_session(room_id).unwrap();

        let session_key = sender.export_group_session(room_id).unwrap();
        let imported_session_id = receiver.add_inbound_group_session(&session_key).unwrap();
        assert_eq!(session_id, imported_session_id);

        let plaintext = "Hello, encrypted world!";
        let encrypted = sender.encrypt_room_message(room_id, plaintext).unwrap();

        let decrypted = receiver
            .decrypt_room_message(&session_id, &encrypted.ciphertext)
            .unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_account_export_import() {
        let manager1 = CryptoManager::new();
        let keys1 = manager1.identity_keys();

        let passphrase = "test_password_12345";
        let pickle = manager1.export_account(passphrase).unwrap();

        let manager2 = CryptoManager::import_account(&pickle, passphrase).unwrap();
        let keys2 = manager2.identity_keys();

        assert_eq!(keys1.curve25519, keys2.curve25519);
        assert_eq!(keys1.ed25519, keys2.ed25519);
    }

    #[test]
    fn test_multiple_messages() {
        let room_id = "!test:example.com";
        let (mut sender, mut receiver, session_id) = paired_managers(room_id);

        let messages = ["First message", "Second message", "Third message"];

        for msg in messages {
            let encrypted = sender.encrypt_room_message(room_id, msg).unwrap();
            let decrypted = receiver
                .decrypt_room_message(&session_id, &encrypted.ciphertext)
                .unwrap();
            assert_eq!(msg, decrypted);
        }
    }

    #[test]
    fn test_encrypt_without_group_session() {
        let mut manager = CryptoManager::new();
        let result = manager.encrypt_room_message("!test:example.com", "test");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No group session for room"));
    }

    #[test]
    fn test_decrypt_with_unknown_session() {
        let mut manager = CryptoManager::new();
        let result = manager.decrypt_room_message("unknown_session_id", "invalid_ciphertext");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Unknown session"));
    }

    #[test]
    fn test_decrypt_invalid_ciphertext() {
        let room_id = "!test:example.com";
        let (_sender, mut receiver, session_id) = paired_managers(room_id);

        let result = receiver.decrypt_room_message(&session_id, "not_valid_base64!!!");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid ciphertext"));
    }

    #[test]
    fn test_import_invalid_session_key() {
        let mut manager = CryptoManager::new();
        let result = manager.add_inbound_group_session("invalid_base64_key!!!");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid session key"));
    }

    #[test]
    fn test_export_nonexistent_group_session() {
        let manager = CryptoManager::new();
        let result = manager.export_group_session("!nonexistent:example.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("No group session for room"));
    }

    #[test]
    fn test_multiple_rooms() {
        let mut sender = CryptoManager::new();

        let room1 = "!room1:example.com";
        let room2 = "!room2:example.com";

        let session1_id = sender.create_group_session(room1).unwrap();
        let session2_id = sender.create_group_session(room2).unwrap();

        assert_ne!(session1_id, session2_id);

        let encrypted1 = sender.encrypt_room_message(room1, "message for room 1").unwrap();
        let encrypted2 = sender.encrypt_room_message(room2, "message for room 2").unwrap();

        assert_eq!(encrypted1.session_id, session1_id);
        assert_eq!(encrypted2.session_id, session2_id);
    }

    #[test]
    fn test_account_import_wrong_passphrase() {
        let manager = CryptoManager::new();
        let correct_passphrase = "correct_password_123";
        let wrong_passphrase = "wrong_password_456";

        let pickle = manager.export_account(correct_passphrase).unwrap();

        // `CryptoManager` is not `Debug`, so `unwrap_err` is unavailable; go through `err()`.
        let err = CryptoManager::import_account(&pickle, wrong_passphrase)
            .err()
            .expect("import with the wrong passphrase must fail");
        assert!(err.to_string().contains("Failed to decrypt pickle"));
    }

    #[test]
    fn test_account_import_invalid_pickle() {
        let result = CryptoManager::import_account("invalid_pickle_data", "any_password_123");
        assert!(result.is_err());
    }

    #[test]
    fn test_unicode_messages() {
        let room_id = "!test:example.com";
        let (mut sender, mut receiver, session_id) = paired_managers(room_id);

        let unicode_messages = [
            "Hello 世界",
            "Здравствуй мир",
            "مرحبا بالعالم",
            "🔐🌍🚀",
            "Emoji test: 😀😎🎉",
        ];

        for msg in unicode_messages {
            let encrypted = sender.encrypt_room_message(room_id, msg).unwrap();
            let decrypted = receiver
                .decrypt_room_message(&session_id, &encrypted.ciphertext)
                .unwrap();
            assert_eq!(msg, decrypted);
        }
    }

    #[test]
    fn test_large_message() {
        let room_id = "!test:example.com";
        let (mut sender, mut receiver, session_id) = paired_managers(room_id);

        let large_message = "A".repeat(10_000);
        let encrypted = sender.encrypt_room_message(room_id, &large_message).unwrap();
        let decrypted = receiver
            .decrypt_room_message(&session_id, &encrypted.ciphertext)
            .unwrap();
        assert_eq!(large_message, decrypted);
    }

    #[test]
    fn test_empty_message() {
        let room_id = "!test:example.com";
        let (mut sender, mut receiver, session_id) = paired_managers(room_id);

        let empty_message = "";
        let encrypted = sender.encrypt_room_message(room_id, empty_message).unwrap();
        let decrypted = receiver
            .decrypt_room_message(&session_id, &encrypted.ciphertext)
            .unwrap();
        assert_eq!(empty_message, decrypted);
    }

    #[test]
    fn test_multi_device_scenario() {
        let mut sender = CryptoManager::new();
        let mut devices = [CryptoManager::new(), CryptoManager::new(), CryptoManager::new()];

        let room_id = "!test:example.com";
        let session_id = sender.create_group_session(room_id).unwrap();
        let session_key = sender.export_group_session(room_id).unwrap();

        for device in devices.iter_mut() {
            device.add_inbound_group_session(&session_key).unwrap();
        }

        let message = "Broadcast to all devices";
        let encrypted = sender.encrypt_room_message(room_id, message).unwrap();

        for device in devices.iter_mut() {
            let decrypted = device
                .decrypt_room_message(&session_id, &encrypted.ciphertext)
                .unwrap();
            assert_eq!(message, decrypted);
        }
    }

    #[test]
    fn test_one_time_keys_generation() {
        let mut manager = CryptoManager::new();

        manager.generate_one_time_keys(10).unwrap();
        let keys = manager.one_time_keys();

        assert!(!keys.is_empty());
        assert!(keys.len() <= 10);

        for (key_id, key_value) in keys {
            assert!(!key_id.is_empty());
            assert!(!key_value.is_empty());
            assert!(key_value.chars().all(is_base64_char));
        }
    }

    #[test]
    fn test_mark_keys_as_published() {
        let mut manager = CryptoManager::new();

        manager.generate_one_time_keys(5).unwrap();
        let keys_before = manager.one_time_keys();
        assert!(!keys_before.is_empty());

        manager.mark_keys_as_published();
        let keys_after = manager.one_time_keys();

        assert!(keys_after.len() <= keys_before.len());
    }

    #[test]
    fn test_encrypted_message_structure() {
        let mut sender = CryptoManager::new();
        let room_id = "!test:example.com";

        sender.create_group_session(room_id).unwrap();
        let encrypted = sender.encrypt_room_message(room_id, "test").unwrap();

        assert_eq!(encrypted.algorithm, "m.megolm.v1.aes-sha2");
        assert!(!encrypted.sender_key.is_empty());
        assert!(!encrypted.ciphertext.is_empty());
        assert!(!encrypted.session_id.is_empty());

        assert!(encrypted.sender_key.chars().all(is_base64_char));
        assert!(encrypted.ciphertext.chars().all(is_base64_char));
    }

    #[test]
    fn test_rate_limit_one_time_keys() {
        let mut manager = CryptoManager::new();
        let result = manager.generate_one_time_keys(MAX_ONE_TIME_KEYS + 1);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Cannot generate more than"));
    }

    #[test]
    fn test_passphrase_too_short() {
        let manager = CryptoManager::new();
        let weak_passphrase = "short";
        let result = manager.export_account(weak_passphrase);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 12 characters"));
    }

    #[test]
    fn test_invalid_room_id_format() {
        let mut manager = CryptoManager::new();

        // Missing the leading '!'.
        assert!(manager.create_group_session("test:example.com").is_err());

        // Missing the ':' separator.
        assert!(manager.create_group_session("!testexample.com").is_err());

        // Longer than 255 bytes.
        let long_room_id = format!("!{}:example.com", "a".repeat(300));
        assert!(manager.create_group_session(&long_room_id).is_err());
    }

    #[test]
    fn test_pickle_format_with_salt() {
        // `base64::decode` was deprecated in base64 0.21 and removed in 0.22; use the Engine API
        // that the module itself uses.
        use base64::Engine as _;

        let manager = CryptoManager::new();
        let passphrase = "secure_password_123";
        let pickle = manager.export_account(passphrase).unwrap();

        assert!(pickle.contains(':'));
        let parts: Vec<&str> = pickle.splitn(2, ':').collect();
        assert_eq!(parts.len(), 2);

        let salt_bytes = base64::engine::general_purpose::STANDARD
            .decode(parts[0])
            .unwrap();
        assert_eq!(salt_bytes.len(), SALT_LENGTH);
    }
}