1use std::{
16 collections::{BTreeMap, BTreeSet},
17 sync::Arc,
18 time::Duration,
19};
20
21use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
22use ruma::{
23 DeviceId, OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedServerName,
24 OwnedTransactionId, OwnedUserId, SecondsSinceUnixEpoch, ServerName, TransactionId, UserId,
25 api::client::keys::claim_keys::v3::{
26 Request as KeysClaimRequest, Response as KeysClaimResponse,
27 },
28 assign,
29 events::dummy::ToDeviceDummyEventContent,
30};
31use tracing::{debug, error, info, instrument, warn};
32use vodozemac::Curve25519PublicKey;
33
34use crate::{
35 DeviceData,
36 error::OlmResult,
37 gossiping::GossipMachine,
38 store::{Result as StoreResult, Store, types::Changes},
39 types::{
40 EventEncryptionAlgorithm,
41 events::EventType,
42 requests::{OutgoingRequest, ToDeviceRequest},
43 },
44};
45
/// Tracks which devices we are missing an Olm session with, builds
/// `/keys/claim` requests for them and creates outbound Olm sessions from the
/// responses. It also drives "unwedging" of broken sessions.
///
/// The `Arc`-wrapped fields are shared between clones. Whether `Store`,
/// `GossipMachine` and `FailuresCache` share their state when cloned is not
/// visible here (presumably they do — TODO confirm).
#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
    /// The crypto store, used to look up devices and sessions and to persist
    /// newly created sessions.
    store: Store,

    /// The request most recently built by `get_missing_sessions`, together
    /// with its transaction ID. It is taken out again when the response with a
    /// matching ID is handled in `handle_otk_exhaustion_failure`.
    current_key_claim_request: Arc<StdRwLock<Option<(OwnedTransactionId, KeysClaimRequest)>>>,

    /// Devices that must be part of the next `/keys/claim` request even if we
    /// already have a session with them (e.g. devices marked as wedged).
    /// The map is handed in by whoever creates the `SessionManager`, so it may
    /// be shared with other components (the tests share it with the
    /// `GossipMachine`).
    users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    /// Devices whose session was marked as wedged. When a new session with
    /// such a device is created, it is removed from here and an `m.dummy`
    /// to-device message is queued for it.
    wedged_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    /// Used to fetch devices for encryption, to retry key shares once a
    /// session exists, and to collect incoming key requests.
    key_request_machine: GossipMachine,
    /// Queued to-device requests (the `m.dummy` messages for unwedged
    /// devices), keyed by transaction ID until they are marked as sent.
    outgoing_to_device_requests: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>>,

    /// Homeservers listed under `failures` in a `/keys/claim` response (our
    /// own server excluded). Users on a server contained here are skipped by
    /// `get_missing_sessions`. A server is removed again once a response
    /// contains keys for one of its users.
    failures: FailuresCache<OwnedServerName>,

    /// Per user, the devices for which no one-time key was returned or for
    /// which creating a session failed. `get_missing_sessions` reports these
    /// devices as timed out instead of claiming keys for them. They are removed
    /// once a session with them is established.
    failed_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, FailuresCache<OwnedDeviceId>>>>,
}
77
78impl SessionManager {
79 const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
80 const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);
81
82 pub fn new(
83 users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
84 key_request_machine: GossipMachine,
85 store: Store,
86 ) -> Self {
87 Self {
88 store,
89 current_key_claim_request: Default::default(),
90 key_request_machine,
91 users_for_key_claim,
92 wedged_devices: Default::default(),
93 outgoing_to_device_requests: Default::default(),
94 failures: Default::default(),
95 failed_devices: Default::default(),
96 }
97 }
98
99 pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
101 self.outgoing_to_device_requests.write().remove(id);
102 }
103
104 pub async fn mark_device_as_wedged(
105 &self,
106 sender: &UserId,
107 curve_key: Curve25519PublicKey,
108 ) -> OlmResult<()> {
109 if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await?
110 && let Some(session) = device.get_most_recent_session().await?
111 {
112 info!(sender_key = ?curve_key, "Marking session to be unwedged");
113
114 let creation_time = Duration::from_secs(session.creation_time.get().into());
115 let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
116
117 let should_unwedge = now
118 .checked_sub(creation_time)
119 .map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
120 .unwrap_or(true);
121
122 if should_unwedge {
123 self.users_for_key_claim
124 .write()
125 .entry(device.user_id().to_owned())
126 .or_default()
127 .insert(device.device_id().into());
128 self.wedged_devices
129 .write()
130 .entry(device.user_id().to_owned())
131 .or_default()
132 .insert(device.device_id().into());
133 }
134 }
135
136 Ok(())
137 }
138
139 #[allow(dead_code)]
140 pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
141 self.wedged_devices
142 .read()
143 .get(device.user_id())
144 .is_some_and(|d| d.contains(device.device_id()))
145 }
146
147 async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
151 if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id))
152 && let Some(device) = self.store.get_device(user_id, device_id).await?
153 {
154 let (_, content, _) =
155 device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;
156
157 let event_type = content.event_type().to_owned();
158
159 let request = ToDeviceRequest::new(
160 device.user_id(),
161 device.device_id().to_owned(),
162 &event_type,
163 content.cast(),
164 );
165
166 let request = OutgoingRequest {
167 request_id: request.txn_id.clone(),
168 request: Arc::new(request.into()),
169 };
170
171 self.outgoing_to_device_requests.write().insert(request.request_id.clone(), request);
172 }
173
174 Ok(())
175 }
176
177 pub async fn get_missing_sessions(
205 &self,
206 users: impl Iterator<Item = &UserId>,
207 ) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
208 let mut missing_session_devices_by_user: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
209 let mut timed_out_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
210
211 let unfailed_users = users.filter(|u| !self.failures.contains(u.server_name()));
212
213 let devices_by_user = Box::pin(
215 self.key_request_machine
216 .identity_manager()
217 .get_user_devices_for_encryption(unfailed_users),
218 )
219 .await?;
220
221 #[derive(Debug, Default)]
222 struct UserFailedDeviceInfo {
223 non_olm_devices: BTreeMap<OwnedDeviceId, Vec<EventEncryptionAlgorithm>>,
224 bad_key_devices: BTreeSet<OwnedDeviceId>,
225 }
226
227 let mut failed_devices_by_user: BTreeMap<_, UserFailedDeviceInfo> = BTreeMap::new();
228
229 for (user_id, user_devices) in devices_by_user {
230 for (device_id, device) in user_devices {
231 if !device.supports_olm() {
232 failed_devices_by_user
233 .entry(user_id.clone())
234 .or_default()
235 .non_olm_devices
236 .insert(device_id, Vec::from(device.algorithms()));
237 } else if let Some(sender_key) = device.curve25519_key() {
238 let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;
239
240 let is_missing = if let Some(sessions) = sessions {
241 sessions.lock().await.is_empty()
242 } else {
243 true
244 };
245
246 let is_timed_out = self.is_user_timed_out(&user_id, &device_id);
247
248 if is_missing && is_timed_out {
249 timed_out_devices_by_user
250 .entry(user_id.to_owned())
251 .or_default()
252 .insert(device_id);
253 } else if is_missing && !is_timed_out {
254 missing_session_devices_by_user
255 .entry(user_id.to_owned())
256 .or_default()
257 .insert(device_id, OneTimeKeyAlgorithm::SignedCurve25519);
258 }
259 } else {
260 failed_devices_by_user
261 .entry(user_id.clone())
262 .or_default()
263 .bad_key_devices
264 .insert(device_id);
265 }
266 }
267 }
268
269 for (user, device_ids) in self.users_for_key_claim.read().iter() {
272 missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
273 device_ids
274 .iter()
275 .map(|device_id| (device_id.clone(), OneTimeKeyAlgorithm::SignedCurve25519)),
276 );
277 }
278
279 if tracing::level_enabled!(tracing::Level::DEBUG) {
280 let missing_session_devices_by_user = missing_session_devices_by_user
282 .iter()
283 .map(|(user_id, devices)| (user_id, devices.keys().collect::<BTreeSet<_>>()))
284 .collect::<BTreeMap<_, _>>();
285 debug!(
286 ?missing_session_devices_by_user,
287 ?timed_out_devices_by_user,
288 "Collected user/device pairs that are missing an Olm session"
289 );
290 }
291
292 if !failed_devices_by_user.is_empty() {
293 warn!(
294 ?failed_devices_by_user,
295 "Can't establish an Olm session with some devices due to missing Olm support or bad keys",
296 );
297 }
298
299 let result = if missing_session_devices_by_user.is_empty() {
300 None
301 } else {
302 Some((
303 TransactionId::new(),
304 assign!(KeysClaimRequest::new(missing_session_devices_by_user), {
305 timeout: Some(Self::KEY_CLAIM_TIMEOUT),
306 }),
307 ))
308 };
309
310 *(self.current_key_claim_request.write()) = result.clone();
313 Ok(result)
314 }
315
316 fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
317 self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
318 }
319
    /// Record the devices for which a `/keys/claim` response returned no
    /// one-time key.
    ///
    /// `one_time_keys` is the response's keys reduced to
    /// user -> device -> key IDs. The stored current key-claim request is only
    /// consulted (and taken out) if its transaction ID equals `request_id`. On
    /// a mismatch a warning is logged and nothing else happens.
    ///
    /// Every device that was requested but is absent from the response is
    /// added to `failed_devices`, unless its server is listed in
    /// `failed_servers`. Those servers are tracked in the server-level
    /// `failures` cache by the caller instead.
    fn handle_otk_exhaustion_failure(
        &self,
        request_id: &TransactionId,
        failed_servers: &BTreeSet<OwnedServerName>,
        one_time_keys: &BTreeMap<
            &OwnedUserId,
            BTreeMap<&OwnedDeviceId, BTreeSet<&OwnedOneTimeKeyId>>,
        >,
    ) {
        let request = {
            // Take the stored request out only if this response belongs to it.
            let mut guard = self.current_key_claim_request.write();
            let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());

            if Some(request_id) == expected_request_id {
                guard.take().map(|(_, request)| request)
            } else {
                warn!(
                    ?request_id,
                    ?expected_request_id,
                    "Received a `/keys/claim` response for the wrong request"
                );
                None
            }
        };

        if let Some(request) = request {
            // All (user, device) pairs the server returned at least an entry for.
            let devices_in_response: BTreeSet<_> = one_time_keys
                .iter()
                .flat_map(|(user_id, device_key_map)| {
                    device_key_map
                        .keys()
                        .map(|device_id| (*user_id, *device_id))
                        .collect::<BTreeSet<_>>()
                })
                .collect();

            // All (user, device) pairs we asked a one-time key for.
            let devices_in_request: BTreeSet<(_, _)> = request
                .one_time_keys
                .iter()
                .flat_map(|(user_id, device_key_map)| {
                    device_key_map
                        .keys()
                        .map(|device_id| (user_id, device_id))
                        .collect::<BTreeSet<_>>()
                })
                .collect();

            // Requested but not returned, ignoring servers that failed as a whole.
            let missing_devices: BTreeSet<_> = devices_in_request
                .difference(&devices_in_response)
                .filter(|(user_id, _)| {
                    !failed_servers.contains(user_id.server_name())
                })
                .collect();

            if !missing_devices.is_empty() {
                let mut missing_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();

                for &(user_id, device_id) in missing_devices {
                    missing_devices_by_user.entry(user_id).or_default().insert(device_id.clone());
                }

                warn!(
                    ?missing_devices_by_user,
                    "Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
                );

                // Recording them here makes `get_missing_sessions` treat them as
                // timed out until their failure entry is gone.
                let mut failed_devices_lock = self.failed_devices.write();

                for (user_id, device_set) in missing_devices_by_user {
                    failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
                }
            }
        }
    }
417
418 #[instrument(skip(self, response))]
428 pub async fn receive_keys_claim_response(
429 &self,
430 request_id: &TransactionId,
431 response: &KeysClaimResponse,
432 ) -> OlmResult<()> {
433 let one_time_keys: BTreeMap<_, BTreeMap<_, BTreeSet<_>>> = response
435 .one_time_keys
436 .iter()
437 .map(|(user_id, device_map)| {
438 (
439 user_id,
440 device_map
441 .iter()
442 .map(|(device_id, key_map)| {
443 (device_id, key_map.keys().collect::<BTreeSet<_>>())
444 })
445 .collect::<BTreeMap<_, _>>(),
446 )
447 })
448 .collect();
449
450 debug!(?request_id, ?one_time_keys, failures = ?response.failures, "Received a `/keys/claim` response");
451
452 let failed_servers: BTreeSet<_> = response
454 .failures
455 .keys()
456 .filter_map(|s| ServerName::parse(s).ok())
457 .filter(|s| s != self.store.static_account().user_id.server_name())
458 .collect();
459 let successful_servers = response.one_time_keys.keys().map(|u| u.server_name());
460
461 self.handle_otk_exhaustion_failure(request_id, &failed_servers, &one_time_keys);
464 self.failures.extend(failed_servers);
466 self.failures.remove(successful_servers);
468
469 self.create_sessions(response).await
471 }
472
473 pub(crate) async fn create_sessions(&self, response: &KeysClaimResponse) -> OlmResult<()> {
480 struct SessionInfo {
481 session_id: String,
482 algorithm: EventEncryptionAlgorithm,
483 fallback_key_used: bool,
484 }
485
486 #[cfg(not(tarpaulin_include))]
487 impl std::fmt::Debug for SessionInfo {
488 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489 write!(
490 f,
491 "session_id: {}, algorithm: {}, fallback_key_used: {}",
492 self.session_id, self.algorithm, self.fallback_key_used
493 )
494 }
495 }
496
497 let mut changes = Changes::default();
498 let mut new_sessions: BTreeMap<&UserId, BTreeMap<&DeviceId, SessionInfo>> = BTreeMap::new();
499 let mut store_transaction = self.store.transaction().await;
500
501 for (user_id, user_devices) in &response.one_time_keys {
502 for (device_id, key_map) in user_devices {
503 let device = match self.store.get_device_data(user_id, device_id).await {
504 Ok(Some(d)) => d,
505 Ok(None) => {
506 warn!(
507 ?user_id,
508 ?device_id,
509 "Tried to create an Olm session but the device is unknown",
510 );
511 continue;
512 }
513 Err(e) => {
514 warn!(
515 ?user_id, ?device_id, error = ?e,
516 "Tried to create an Olm session, but we can't \
517 fetch the device from the store",
518 );
519 continue;
520 }
521 };
522
523 let account = store_transaction.account().await?;
524 let device_keys = self.store.get_own_device().await?.as_device_keys().clone();
525 let session = match account.create_outbound_session(&device, key_map, device_keys) {
526 Ok(s) => s,
527 Err(e) => {
528 warn!(
529 ?user_id, ?device_id, error = ?e,
530 "Error creating Olm session"
531 );
532
533 self.failed_devices
534 .write()
535 .entry(user_id.to_owned())
536 .or_default()
537 .insert(device_id.to_owned());
538
539 continue;
540 }
541 };
542
543 self.key_request_machine.retry_keyshare(user_id, device_id);
544
545 if let Err(e) = self.check_if_unwedged(user_id, device_id).await {
546 error!(?user_id, ?device_id, "Error while treating an unwedged device: {e:?}");
547 }
548
549 let session_info = SessionInfo {
550 session_id: session.session_id().to_owned(),
551 algorithm: session.algorithm().await,
552 fallback_key_used: session.created_using_fallback_key,
553 };
554
555 changes.sessions.push(session);
556 new_sessions.entry(user_id).or_default().insert(device_id, session_info);
557 }
558 }
559
560 store_transaction.commit().await?;
561 self.store.save_changes(changes).await?;
562 info!(sessions = ?new_sessions, "Established new Olm sessions");
563
564 for (user, device_map) in new_sessions {
565 if let Some(user_cache) = self.failed_devices.read().get(user) {
566 user_cache.remove(device_map.into_keys());
567 }
568 }
569
570 let store_cache = self.store.cache().await?;
571 match self.key_request_machine.collect_incoming_key_requests(&store_cache).await {
572 Ok(sessions) => {
573 let changes = Changes { sessions, ..Default::default() };
574 self.store.save_changes(changes).await?
575 }
576 Err(e) => {
579 warn!(error = ?e, "Error while trying to collect the incoming secret requests")
580 }
581 }
582
583 Ok(())
584 }
585}
586
587#[cfg(test)]
588mod tests {
589 use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};
590
591 use matrix_sdk_common::{executor::spawn, locks::RwLock as StdRwLock};
592 use matrix_sdk_test::{async_test, ruma_response_from_json};
593 use ruma::{
594 DeviceId, OwnedUserId, UserId,
595 api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
596 owned_server_name, user_id,
597 };
598 use serde_json::json;
599 use tokio::sync::Mutex;
600 use tracing::info;
601
602 use super::SessionManager;
603 use crate::{
604 gossiping::GossipMachine,
605 identities::{DeviceData, IdentityManager},
606 olm::{Account, PrivateCrossSigningIdentity},
607 session_manager::GroupSessionCache,
608 store::{
609 CryptoStoreWrapper, MemoryStore, Store,
610 types::{Changes, DeviceChanges, PendingChanges},
611 },
612 verification::VerificationMachine,
613 };
614
615 fn user_id() -> &'static UserId {
616 user_id!("@example:localhost")
617 }
618
619 fn device_id() -> &'static DeviceId {
620 device_id!("DEVICEID")
621 }
622
623 fn bob_account() -> Account {
624 Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"))
625 }
626
627 fn keys_claim_with_failure() -> KeyClaimResponse {
628 let response = json!({
629 "one_time_keys": {},
630 "failures": {
631 "example.org": {
632 "errcode": "M_RESOURCE_LIMIT_EXCEEDED",
633 "error": "Not yet ready to retry",
634 }
635 }
636 });
637 ruma_response_from_json(&response)
638 }
639
640 fn keys_claim_without_failure() -> KeyClaimResponse {
641 let response = json!({
642 "one_time_keys": {
643 "@alice:example.org": {},
644 },
645 "failures": {},
646 });
647 ruma_response_from_json(&response)
648 }
649
    /// Build a `SessionManager` backed by an in-memory store.
    ///
    /// The store already contains our own account and device. The returned
    /// `IdentityManager` uses the same store.
    async fn session_manager_test_helper() -> (SessionManager, IdentityManager) {
        let user_id = user_id();
        let device_id = device_id();

        let account = Account::with_device_id(user_id, device_id);
        let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, MemoryStore::new()));
        let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
        let verification = VerificationMachine::new(
            account.static_data().clone(),
            identity.clone(),
            store.clone(),
        );

        let store = Store::new(account.static_data().clone(), identity, store, verification);
        // Derive our own device before `account` is moved into the store.
        let device = DeviceData::from_account(&account);
        store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
        store
            .save_changes(Changes {
                devices: DeviceChanges { new: vec![device], ..Default::default() },
                ..Default::default()
            })
            .await
            .unwrap();

        let session_cache = GroupSessionCache::new(store.clone());
        let identity_manager = IdentityManager::new(store.clone());

        // Shared between the gossip machine and the session manager.
        let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
        let key_request = GossipMachine::new(
            store.clone(),
            identity_manager.clone(),
            session_cache,
            users_for_key_claim.clone(),
        );

        (SessionManager::new(users_for_key_claim, key_request, store), identity_manager)
    }
687
    /// Claiming a one-time key for a device without a session and feeding the
    /// response back creates a session, after which nothing is missing.
    #[async_test]
    async fn test_session_creation() {
        let (manager, _identity_manager) = session_manager_test_helper().await;
        let mut bob = bob_account();

        let bob_device = DeviceData::from_account(&bob);

        manager.store.save_device_data(&[bob_device]).await.unwrap();

        // Bob's device has no session yet, so a key claim is requested for it.
        let (txn_id, request) =
            manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();

        assert!(request.one_time_keys.contains_key(bob.user_id()));

        // Build a `/keys/claim` response containing one fresh one-time key of Bob's.
        bob.generate_one_time_keys(1);
        let one_time = bob.signed_one_time_keys();
        assert!(!one_time.is_empty());
        bob.mark_keys_as_published();

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(bob.user_id().to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(bob.device_id().to_owned(), one_time);

        let response = KeyClaimResponse::new(one_time_keys);

        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        // The session now exists, so no further key claim is needed.
        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
    }
719
    /// `get_missing_sessions` for a tracked user whose device keys are still
    /// outdated waits until the `/keys/query` for that user has been answered,
    /// and then claims a key for the newly learned device.
    #[async_test]
    async fn test_session_creation_waits_for_keys_query() {
        let (manager, identity_manager) = session_manager_test_helper().await;

        // The first pending key query (presumably for our own user).
        let (key_query_txn_id, key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        info!("Initial key query: {:?}", key_query_request);

        let bob = bob_account();
        let bob_device = DeviceData::from_account(&bob);
        {
            // Start tracking Bob so that his devices need to be queried.
            let cache = manager.store.cache().await.unwrap();
            identity_manager
                .key_query_manager
                .synced(&cache)
                .await
                .unwrap()
                .update_tracked_users(iter::once(bob.user_id()))
                .await
                .unwrap();
        }

        // Run `get_missing_sessions` in the background; it must not finish
        // before Bob's keys have been received below.
        let missing_sessions_task = {
            let manager = manager.clone();
            let bob_user_id = bob.user_id().to_owned();

            #[allow(unknown_lints, clippy::redundant_async_block)]
            spawn(
                async move { manager.get_missing_sessions(iter::once(bob_user_id.deref())).await },
            )
        };

        // Answer the initial query for our own user with no devices.
        let response_json =
            json!({ "device_keys": { manager.store.static_account().user_id.to_owned(): {}}});
        let response = ruma_response_from_json(&response_json);
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // The next pending query is expected to cover Bob.
        let (key_query_txn_id, key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        info!("Second key query: {:?}", key_query_request);

        let response_json = json!({ "device_keys": { bob.user_id(): {
            bob_device.device_id(): bob_device.as_device_keys()
        }}});
        let response = ruma_response_from_json(&response_json);
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // With Bob's device known, the claim request must include it.
        let (_, keys_claim_request) = missing_sessions_task.await.unwrap().unwrap().unwrap();
        info!("Key claim request: {:?}", keys_claim_request.one_time_keys);
        let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap();
        assert!(bob_key_claims.contains_key(bob_device.device_id()));
    }
781
    /// When the `/keys/query` for a user failed because their server was
    /// unreachable, `get_missing_sessions` must return promptly (with nothing
    /// to claim) instead of waiting for that user's keys.
    #[async_test]
    async fn test_session_creation_does_not_wait_for_keys_query_on_failed_server() {
        let (manager, identity_manager) = session_manager_test_helper().await;

        let other_user_id = OwnedUserId::try_from("@bob:example.com").unwrap();
        {
            let cache = manager.store.cache().await.unwrap();
            identity_manager
                .key_query_manager
                .synced(&cache)
                .await
                .unwrap()
                .update_tracked_users(iter::once(other_user_id.as_ref()))
                .await
                .unwrap();
        }

        // Answer the key query with a failure for Bob's server.
        let (key_query_txn_id, _key_query_request) =
            identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
        let response = ruma_response_from_json(
            &json!({ "device_keys": {}, "failures": { other_user_id.server_name(): "unreachable" }}),
        );
        identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();

        // A short timeout detects whether the call blocks.
        let result = tokio::time::timeout(
            Duration::from_millis(10),
            manager.get_missing_sessions(iter::once(other_user_id.as_ref())),
        )
        .await
        .expect("get_missing_sessions blocked rather than completing quickly")
        .expect("get_missing_sessions returned an error");

        assert!(result.is_none(), "get_missing_sessions returned Some(...)");
    }
820
    /// Marking a device with an old (>1h) session as wedged queues it for a key
    /// claim. Creating the new session clears the wedged flag and queues a
    /// to-device (`m.dummy`) request.
    // NOTE(review): the reason for restricting this test to Linux is not
    // visible here.
    #[async_test]
    #[cfg(target_os = "linux")]
    async fn test_session_unwedging() {
        use ruma::{SecondsSinceUnixEpoch, time::SystemTime};

        let (manager, _identity_manager) = session_manager_test_helper().await;
        let mut bob = bob_account();

        let (_, mut session) = manager
            .store
            .with_transaction(async |tr| {
                let manager_account = tr.account().await.unwrap();
                let res = bob.create_session_for_test_helper(manager_account).await;
                Ok(res)
            })
            .await
            .unwrap();

        let bob_device = DeviceData::from_account(&bob);
        // Backdate the session past `UNWEDGING_INTERVAL` so it qualifies for unwedging.
        let time = SystemTime::now() - Duration::from_secs(3601);
        session.creation_time = SecondsSinceUnixEpoch::from_system_time(time).unwrap();

        let devices = std::slice::from_ref(&bob_device);
        manager.store.save_device_data(devices).await.unwrap();
        manager.store.save_sessions(&[session]).await.unwrap();

        // A session exists, so nothing is missing yet.
        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());

        let curve_key = bob_device.curve25519_key().unwrap();

        assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
        assert!(!manager.is_device_wedged(&bob_device));
        manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
        assert!(manager.is_device_wedged(&bob_device));
        assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));

        // The wedged device is now claimed despite the existing session.
        let (txn_id, request) =
            manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();

        assert!(request.one_time_keys.contains_key(bob.user_id()));

        bob.generate_one_time_keys(1);
        let one_time = bob.signed_one_time_keys();
        assert!(!one_time.is_empty());
        bob.mark_keys_as_published();

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(bob.user_id().to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(bob.device_id().to_owned(), one_time);

        let response = KeyClaimResponse::new(one_time_keys);

        assert!(manager.outgoing_to_device_requests.read().is_empty());

        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        // The new session unwedged the device and queued the dummy message.
        assert!(!manager.is_device_wedged(&bob_device));
        assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
        assert!(!manager.outgoing_to_device_requests.read().is_empty())
    }
885
886 #[async_test]
887 async fn test_failure_handling() {
888 let alice = user_id!("@alice:example.org");
889 let alice_account = Account::with_device_id(alice, "DEVICEID".into());
890 let alice_device = DeviceData::from_account(&alice_account);
891
892 let (manager, _identity_manager) = session_manager_test_helper().await;
893
894 manager.store.save_device_data(&[alice_device]).await.unwrap();
895
896 let (txn_id, users_for_key_claim) =
897 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
898 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
899
900 manager.receive_keys_claim_response(&txn_id, &keys_claim_with_failure()).await.unwrap();
901 assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());
902
903 manager.failures.expire(&owned_server_name!("example.org"));
905
906 let (txn_id, users_for_key_claim) =
907 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
908 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
909
910 manager.receive_keys_claim_response(&txn_id, &keys_claim_without_failure()).await.unwrap();
911 }
912
    /// Runs `test_invalid_claim_response` against several responses that do
    /// not yield a usable session for Alice's device:
    /// - no keys at all;
    /// - an empty user entry;
    /// - an empty device entry;
    /// - a key whose signature is from an unrelated device.
    #[async_test]
    async fn test_failed_devices_handling() {
        // No entry for Alice at all.
        test_invalid_claim_response(json!({
            "one_time_keys": {},
            "failures": {},
        }))
        .await;

        // An entry for Alice without devices.
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {}
            },
            "failures": {},
        }))
        .await;

        // An entry for Alice's device without any keys.
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {
                    "DEVICEID": {}
                }
            },
            "failures": {},
        }))
        .await;

        // A key signed by some other user/device, so session creation fails.
        test_invalid_claim_response(json!({
            "one_time_keys": {
                "@alice:example.org": {
                    "DEVICEID": {
                        "signed_curve25519:AAAAAA": {
                            "fallback": true,
                            "key": "1sra5GVo1ONz478aQybxSEeHTSo2xq0Z+Q3Yzqvp3A4",
                            "signatures": {
                                "@example:morpheus.localhost": {
                                    "ed25519:YAFLBLXAUK": "Zwk90fJhZWOYGNOgtOswZ6RSOGeTjTi/h2dMpyB0CR6EVtvTra0WJtp32ntifrxtwD710y2F3pe5Oyrm7jngCQ"
                                }
                            }
                        }
                    }
                }
            },
            "failures": {},
        })).await;
    }
962
    /// Helper, not a test by itself.
    ///
    /// After `response_json` is handled as the answer to a key claim for
    /// Alice's device, that device is recorded as failed and no longer claimed.
    /// Once the failure is expired and a valid key is supplied, a session is
    /// created and the failure record for the device is gone.
    async fn test_invalid_claim_response(response_json: serde_json::Value) {
        let response = ruma_response_from_json(&response_json);

        let alice = user_id!("@alice:example.org");
        let mut alice_account = Account::with_device_id(alice, "DEVICEID".into());
        let alice_device = DeviceData::from_account(&alice_account);

        let (manager, _identity_manager) = session_manager_test_helper().await;
        manager.store.save_device_data(&[alice_device]).await.unwrap();

        // Alice's device has no session, so a key is claimed for it.
        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        // The invalid response marks the device as failed, so it isn't claimed again.
        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
        assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());

        alice_account.generate_one_time_keys(1);
        let one_time = alice_account.signed_one_time_keys();
        assert!(!one_time.is_empty());

        let mut one_time_keys = BTreeMap::new();
        one_time_keys
            .entry(alice.to_owned())
            .or_insert_with(BTreeMap::new)
            .insert(alice_account.device_id().to_owned(), one_time);

        // Expire the failure so the device is claimed again.
        manager
            .failed_devices
            .write()
            .get(alice)
            .unwrap()
            .expire(&alice_account.device_id().to_owned());
        let (txn_id, users_for_key_claim) =
            manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
        assert!(users_for_key_claim.one_time_keys.contains_key(alice));

        let response = KeyClaimResponse::new(one_time_keys);
        manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();

        // A successful session creation removes the failure record.
        assert!(
            manager
                .failed_devices
                .read()
                .get(alice)
                .unwrap()
                .failure_count(alice_account.device_id())
                .is_none()
        );
    }
1025}