1use std::{
16 collections::{BTreeMap, BTreeSet},
17 sync::Arc,
18 time::Duration,
19};
20
21use matrix_sdk_common::{failures_cache::FailuresCache, locks::RwLock as StdRwLock};
22use ruma::{
23 api::client::keys::claim_keys::v3::{
24 Request as KeysClaimRequest, Response as KeysClaimResponse,
25 },
26 assign,
27 events::dummy::ToDeviceDummyEventContent,
28 DeviceId, OneTimeKeyAlgorithm, OwnedDeviceId, OwnedOneTimeKeyId, OwnedServerName,
29 OwnedTransactionId, OwnedUserId, SecondsSinceUnixEpoch, ServerName, TransactionId, UserId,
30};
31use tracing::{debug, error, info, instrument, warn};
32use vodozemac::Curve25519PublicKey;
33
34use crate::{
35 error::OlmResult,
36 gossiping::GossipMachine,
37 store::{types::Changes, Result as StoreResult, Store},
38 types::{
39 events::EventType,
40 requests::{OutgoingRequest, ToDeviceRequest},
41 EventEncryptionAlgorithm,
42 },
43 DeviceData,
44};
45
/// Bookkeeping needed to establish 1-to-1 Olm sessions with other devices
/// through `/keys/claim` requests.
#[derive(Debug, Clone)]
pub(crate) struct SessionManager {
    /// Store used to look up devices and sessions and to persist newly
    /// created sessions.
    store: Store,

    /// The `/keys/claim` request most recently produced by
    /// `get_missing_sessions`, paired with its transaction ID. It is taken
    /// out again once a response carrying the same transaction ID is handled.
    current_key_claim_request: Arc<StdRwLock<Option<(OwnedTransactionId, KeysClaimRequest)>>>,

    /// Devices, grouped by user, that are added to every `/keys/claim`
    /// request regardless of whether a session with them already exists.
    /// The handle is supplied by the caller of `new`.
    users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    /// Devices, grouped by user, that were marked as wedged and have not yet
    /// received a replacement session.
    wedged_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
    /// Gossip machine; gives access to the identity manager and is told to
    /// retry key shares whenever a new session has been created.
    key_request_machine: GossipMachine,
    /// Encrypted `m.dummy` to-device requests queued after a wedged device
    /// got a new session, keyed by transaction ID, until marked as sent.
    outgoing_to_device_requests: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>>,

    /// Servers reported in the `failures` of a `/keys/claim` response. Users
    /// on these servers are skipped by `get_missing_sessions`.
    failures: FailuresCache<OwnedServerName>,

    /// Per-user cache of devices for which session creation failed: either
    /// the response held no key for a requested device, or building the
    /// session from the returned key failed.
    failed_devices: Arc<StdRwLock<BTreeMap<OwnedUserId, FailuresCache<OwnedDeviceId>>>>,
}
77
78impl SessionManager {
    /// Value sent as the `timeout` of every `/keys/claim` request built by
    /// `get_missing_sessions`.
    const KEY_CLAIM_TIMEOUT: Duration = Duration::from_secs(10);
    /// Age (one hour) that the most recent session with a device must exceed
    /// before `mark_device_as_wedged` schedules a replacement session.
    const UNWEDGING_INTERVAL: Duration = Duration::from_secs(60 * 60);
81
82 pub fn new(
83 users_for_key_claim: Arc<StdRwLock<BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>>>>,
84 key_request_machine: GossipMachine,
85 store: Store,
86 ) -> Self {
87 Self {
88 store,
89 current_key_claim_request: Default::default(),
90 key_request_machine,
91 users_for_key_claim,
92 wedged_devices: Default::default(),
93 outgoing_to_device_requests: Default::default(),
94 failures: Default::default(),
95 failed_devices: Default::default(),
96 }
97 }
98
99 pub fn mark_outgoing_request_as_sent(&self, id: &TransactionId) {
101 self.outgoing_to_device_requests.write().remove(id);
102 }
103
104 pub async fn mark_device_as_wedged(
105 &self,
106 sender: &UserId,
107 curve_key: Curve25519PublicKey,
108 ) -> OlmResult<()> {
109 if let Some(device) = self.store.get_device_from_curve_key(sender, curve_key).await? {
110 if let Some(session) = device.get_most_recent_session().await? {
111 info!(sender_key = ?curve_key, "Marking session to be unwedged");
112
113 let creation_time = Duration::from_secs(session.creation_time.get().into());
114 let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
115
116 let should_unwedge = now
117 .checked_sub(creation_time)
118 .map(|elapsed| elapsed > Self::UNWEDGING_INTERVAL)
119 .unwrap_or(true);
120
121 if should_unwedge {
122 self.users_for_key_claim
123 .write()
124 .entry(device.user_id().to_owned())
125 .or_default()
126 .insert(device.device_id().into());
127 self.wedged_devices
128 .write()
129 .entry(device.user_id().to_owned())
130 .or_default()
131 .insert(device.device_id().into());
132 }
133 }
134 }
135
136 Ok(())
137 }
138
139 #[allow(dead_code)]
140 pub fn is_device_wedged(&self, device: &DeviceData) -> bool {
141 self.wedged_devices
142 .read()
143 .get(device.user_id())
144 .is_some_and(|d| d.contains(device.device_id()))
145 }
146
147 async fn check_if_unwedged(&self, user_id: &UserId, device_id: &DeviceId) -> OlmResult<()> {
151 if self.wedged_devices.write().get_mut(user_id).is_some_and(|d| d.remove(device_id)) {
152 if let Some(device) = self.store.get_device(user_id, device_id).await? {
153 let (_, content) =
154 device.encrypt("m.dummy", ToDeviceDummyEventContent::new()).await?;
155
156 let event_type = content.event_type().to_owned();
157
158 let request = ToDeviceRequest::new(
159 device.user_id(),
160 device.device_id().to_owned(),
161 &event_type,
162 content.cast(),
163 );
164
165 let request = OutgoingRequest {
166 request_id: request.txn_id.clone(),
167 request: Arc::new(request.into()),
168 };
169
170 self.outgoing_to_device_requests
171 .write()
172 .insert(request.request_id.clone(), request);
173 }
174 }
175
176 Ok(())
177 }
178
179 pub async fn get_missing_sessions(
207 &self,
208 users: impl Iterator<Item = &UserId>,
209 ) -> StoreResult<Option<(OwnedTransactionId, KeysClaimRequest)>> {
210 let mut missing_session_devices_by_user: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
211 let mut timed_out_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
212
213 let unfailed_users = users.filter(|u| !self.failures.contains(u.server_name()));
214
215 let devices_by_user = Box::pin(
217 self.key_request_machine
218 .identity_manager()
219 .get_user_devices_for_encryption(unfailed_users),
220 )
221 .await?;
222
223 #[derive(Debug, Default)]
224 struct UserFailedDeviceInfo {
225 non_olm_devices: BTreeMap<OwnedDeviceId, Vec<EventEncryptionAlgorithm>>,
226 bad_key_devices: BTreeSet<OwnedDeviceId>,
227 }
228
229 let mut failed_devices_by_user: BTreeMap<_, UserFailedDeviceInfo> = BTreeMap::new();
230
231 for (user_id, user_devices) in devices_by_user {
232 for (device_id, device) in user_devices {
233 if !device.supports_olm() {
234 failed_devices_by_user
235 .entry(user_id.clone())
236 .or_default()
237 .non_olm_devices
238 .insert(device_id, Vec::from(device.algorithms()));
239 } else if let Some(sender_key) = device.curve25519_key() {
240 let sessions = self.store.get_sessions(&sender_key.to_base64()).await?;
241
242 let is_missing = if let Some(sessions) = sessions {
243 sessions.lock().await.is_empty()
244 } else {
245 true
246 };
247
248 let is_timed_out = self.is_user_timed_out(&user_id, &device_id);
249
250 if is_missing && is_timed_out {
251 timed_out_devices_by_user
252 .entry(user_id.to_owned())
253 .or_default()
254 .insert(device_id);
255 } else if is_missing && !is_timed_out {
256 missing_session_devices_by_user
257 .entry(user_id.to_owned())
258 .or_default()
259 .insert(device_id, OneTimeKeyAlgorithm::SignedCurve25519);
260 }
261 } else {
262 failed_devices_by_user
263 .entry(user_id.clone())
264 .or_default()
265 .bad_key_devices
266 .insert(device_id);
267 }
268 }
269 }
270
271 for (user, device_ids) in self.users_for_key_claim.read().iter() {
274 missing_session_devices_by_user.entry(user.to_owned()).or_default().extend(
275 device_ids
276 .iter()
277 .map(|device_id| (device_id.clone(), OneTimeKeyAlgorithm::SignedCurve25519)),
278 );
279 }
280
281 if tracing::level_enabled!(tracing::Level::DEBUG) {
282 let missing_session_devices_by_user = missing_session_devices_by_user
284 .iter()
285 .map(|(user_id, devices)| (user_id, devices.keys().collect::<BTreeSet<_>>()))
286 .collect::<BTreeMap<_, _>>();
287 debug!(
288 ?missing_session_devices_by_user,
289 ?timed_out_devices_by_user,
290 "Collected user/device pairs that are missing an Olm session"
291 );
292 }
293
294 if !failed_devices_by_user.is_empty() {
295 warn!(
296 ?failed_devices_by_user,
297 "Can't establish an Olm session with some devices due to missing Olm support or bad keys",
298 );
299 }
300
301 let result = if missing_session_devices_by_user.is_empty() {
302 None
303 } else {
304 Some((
305 TransactionId::new(),
306 assign!(KeysClaimRequest::new(missing_session_devices_by_user), {
307 timeout: Some(Self::KEY_CLAIM_TIMEOUT),
308 }),
309 ))
310 };
311
312 *(self.current_key_claim_request.write()) = result.clone();
315 Ok(result)
316 }
317
318 fn is_user_timed_out(&self, user_id: &UserId, device_id: &DeviceId) -> bool {
319 self.failed_devices.read().get(user_id).is_some_and(|d| d.contains(device_id))
320 }
321
322 fn handle_otk_exhaustion_failure(
339 &self,
340 request_id: &TransactionId,
341 failed_servers: &BTreeSet<OwnedServerName>,
342 one_time_keys: &BTreeMap<
343 &OwnedUserId,
344 BTreeMap<&OwnedDeviceId, BTreeSet<&OwnedOneTimeKeyId>>,
345 >,
346 ) {
347 let request = {
349 let mut guard = self.current_key_claim_request.write();
350 let expected_request_id = guard.as_ref().map(|e| e.0.as_ref());
351
352 if Some(request_id) == expected_request_id {
353 guard.take().map(|(_, request)| request)
356 } else {
357 warn!(
358 ?request_id,
359 ?expected_request_id,
360 "Received a `/keys/claim` response for the wrong request"
361 );
362 None
363 }
364 };
365
366 if let Some(request) = request {
369 let devices_in_response: BTreeSet<_> = one_time_keys
370 .iter()
371 .flat_map(|(user_id, device_key_map)| {
372 device_key_map
373 .keys()
374 .map(|device_id| (*user_id, *device_id))
375 .collect::<BTreeSet<_>>()
376 })
377 .collect();
378
379 let devices_in_request: BTreeSet<(_, _)> = request
380 .one_time_keys
381 .iter()
382 .flat_map(|(user_id, device_key_map)| {
383 device_key_map
384 .keys()
385 .map(|device_id| (user_id, device_id))
386 .collect::<BTreeSet<_>>()
387 })
388 .collect();
389
390 let missing_devices: BTreeSet<_> = devices_in_request
391 .difference(&devices_in_response)
392 .filter(|(user_id, _)| {
393 !failed_servers.contains(user_id.server_name())
396 })
397 .collect();
398
399 if !missing_devices.is_empty() {
400 let mut missing_devices_by_user: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
401
402 for &(user_id, device_id) in missing_devices {
403 missing_devices_by_user.entry(user_id).or_default().insert(device_id.clone());
404 }
405
406 warn!(
407 ?missing_devices_by_user,
408 "Tried to create new Olm sessions, but the signed one-time key was missing for some devices",
409 );
410
411 let mut failed_devices_lock = self.failed_devices.write();
412
413 for (user_id, device_set) in missing_devices_by_user {
414 failed_devices_lock.entry(user_id.clone()).or_default().extend(device_set);
415 }
416 }
417 }
418 }
419
420 #[instrument(skip(self, response))]
430 pub async fn receive_keys_claim_response(
431 &self,
432 request_id: &TransactionId,
433 response: &KeysClaimResponse,
434 ) -> OlmResult<()> {
435 let one_time_keys: BTreeMap<_, BTreeMap<_, BTreeSet<_>>> = response
437 .one_time_keys
438 .iter()
439 .map(|(user_id, device_map)| {
440 (
441 user_id,
442 device_map
443 .iter()
444 .map(|(device_id, key_map)| {
445 (device_id, key_map.keys().collect::<BTreeSet<_>>())
446 })
447 .collect::<BTreeMap<_, _>>(),
448 )
449 })
450 .collect();
451
452 debug!(?request_id, ?one_time_keys, failures = ?response.failures, "Received a `/keys/claim` response");
453
454 let failed_servers: BTreeSet<_> = response
456 .failures
457 .keys()
458 .filter_map(|s| ServerName::parse(s).ok())
459 .filter(|s| s != self.store.static_account().user_id.server_name())
460 .collect();
461 let successful_servers = response.one_time_keys.keys().map(|u| u.server_name());
462
463 self.handle_otk_exhaustion_failure(request_id, &failed_servers, &one_time_keys);
466 self.failures.extend(failed_servers);
468 self.failures.remove(successful_servers);
470
471 self.create_sessions(response).await
473 }
474
475 pub(crate) async fn create_sessions(&self, response: &KeysClaimResponse) -> OlmResult<()> {
482 struct SessionInfo {
483 session_id: String,
484 algorithm: EventEncryptionAlgorithm,
485 fallback_key_used: bool,
486 }
487
488 #[cfg(not(tarpaulin_include))]
489 impl std::fmt::Debug for SessionInfo {
490 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491 write!(
492 f,
493 "session_id: {}, algorithm: {}, fallback_key_used: {}",
494 self.session_id, self.algorithm, self.fallback_key_used
495 )
496 }
497 }
498
499 let mut changes = Changes::default();
500 let mut new_sessions: BTreeMap<&UserId, BTreeMap<&DeviceId, SessionInfo>> = BTreeMap::new();
501 let mut store_transaction = self.store.transaction().await;
502
503 for (user_id, user_devices) in &response.one_time_keys {
504 for (device_id, key_map) in user_devices {
505 let device = match self.store.get_device_data(user_id, device_id).await {
506 Ok(Some(d)) => d,
507 Ok(None) => {
508 warn!(
509 ?user_id,
510 ?device_id,
511 "Tried to create an Olm session but the device is unknown",
512 );
513 continue;
514 }
515 Err(e) => {
516 warn!(
517 ?user_id, ?device_id, error = ?e,
518 "Tried to create an Olm session, but we can't \
519 fetch the device from the store",
520 );
521 continue;
522 }
523 };
524
525 let account = store_transaction.account().await?;
526 let device_keys = self.store.get_own_device().await?.as_device_keys().clone();
527 let session = match account.create_outbound_session(&device, key_map, device_keys) {
528 Ok(s) => s,
529 Err(e) => {
530 warn!(
531 ?user_id, ?device_id, error = ?e,
532 "Error creating Olm session"
533 );
534
535 self.failed_devices
536 .write()
537 .entry(user_id.to_owned())
538 .or_default()
539 .insert(device_id.to_owned());
540
541 continue;
542 }
543 };
544
545 self.key_request_machine.retry_keyshare(user_id, device_id);
546
547 if let Err(e) = self.check_if_unwedged(user_id, device_id).await {
548 error!(?user_id, ?device_id, "Error while treating an unwedged device: {e:?}");
549 }
550
551 let session_info = SessionInfo {
552 session_id: session.session_id().to_owned(),
553 algorithm: session.algorithm().await,
554 fallback_key_used: session.created_using_fallback_key,
555 };
556
557 changes.sessions.push(session);
558 new_sessions.entry(user_id).or_default().insert(device_id, session_info);
559 }
560 }
561
562 store_transaction.commit().await?;
563 self.store.save_changes(changes).await?;
564 info!(sessions = ?new_sessions, "Established new Olm sessions");
565
566 for (user, device_map) in new_sessions {
567 if let Some(user_cache) = self.failed_devices.read().get(user) {
568 user_cache.remove(device_map.into_keys());
569 }
570 }
571
572 let store_cache = self.store.cache().await?;
573 match self.key_request_machine.collect_incoming_key_requests(&store_cache).await {
574 Ok(sessions) => {
575 let changes = Changes { sessions, ..Default::default() };
576 self.store.save_changes(changes).await?
577 }
578 Err(e) => {
581 warn!(error = ?e, "Error while trying to collect the incoming secret requests")
582 }
583 }
584
585 Ok(())
586 }
587}
588
589#[cfg(test)]
590mod tests {
591 use std::{collections::BTreeMap, iter, ops::Deref, sync::Arc, time::Duration};
592
593 use matrix_sdk_common::{executor::spawn, locks::RwLock as StdRwLock};
594 use matrix_sdk_test::{async_test, ruma_response_from_json};
595 use ruma::{
596 api::client::keys::claim_keys::v3::Response as KeyClaimResponse, device_id,
597 owned_server_name, user_id, DeviceId, OwnedUserId, UserId,
598 };
599 use serde_json::json;
600 use tokio::sync::Mutex;
601 use tracing::info;
602
603 use super::SessionManager;
604 use crate::{
605 gossiping::GossipMachine,
606 identities::{DeviceData, IdentityManager},
607 olm::{Account, PrivateCrossSigningIdentity},
608 session_manager::GroupSessionCache,
609 store::{
610 types::{Changes, DeviceChanges, PendingChanges},
611 CryptoStoreWrapper, MemoryStore, Store,
612 },
613 verification::VerificationMachine,
614 };
615
    /// User ID of the account that owns the `SessionManager` under test.
    fn user_id() -> &'static UserId {
        user_id!("@example:localhost")
    }
619
    /// Device ID of the account that owns the `SessionManager` under test.
    fn device_id() -> &'static DeviceId {
        device_id!("DEVICEID")
    }
623
    /// Fresh account for `@bob:localhost` / `BOBDEVICE`, used as the remote
    /// party in the tests.
    fn bob_account() -> Account {
        Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"))
    }
627
628 fn keys_claim_with_failure() -> KeyClaimResponse {
629 let response = json!({
630 "one_time_keys": {},
631 "failures": {
632 "example.org": {
633 "errcode": "M_RESOURCE_LIMIT_EXCEEDED",
634 "error": "Not yet ready to retry",
635 }
636 }
637 });
638 ruma_response_from_json(&response)
639 }
640
641 fn keys_claim_without_failure() -> KeyClaimResponse {
642 let response = json!({
643 "one_time_keys": {
644 "@alice:example.org": {},
645 },
646 "failures": {},
647 });
648 ruma_response_from_json(&response)
649 }
650
651 async fn session_manager_test_helper() -> (SessionManager, IdentityManager) {
652 let user_id = user_id();
653 let device_id = device_id();
654
655 let account = Account::with_device_id(user_id, device_id);
656 let store = Arc::new(CryptoStoreWrapper::new(user_id, device_id, MemoryStore::new()));
657 let identity = Arc::new(Mutex::new(PrivateCrossSigningIdentity::empty(user_id)));
658 let verification = VerificationMachine::new(
659 account.static_data().clone(),
660 identity.clone(),
661 store.clone(),
662 );
663
664 let store = Store::new(account.static_data().clone(), identity, store, verification);
665 let device = DeviceData::from_account(&account);
666 store.save_pending_changes(PendingChanges { account: Some(account) }).await.unwrap();
667 store
668 .save_changes(Changes {
669 devices: DeviceChanges { new: vec![device], ..Default::default() },
670 ..Default::default()
671 })
672 .await
673 .unwrap();
674
675 let session_cache = GroupSessionCache::new(store.clone());
676 let identity_manager = IdentityManager::new(store.clone());
677
678 let users_for_key_claim = Arc::new(StdRwLock::new(BTreeMap::new()));
679 let key_request = GossipMachine::new(
680 store.clone(),
681 identity_manager.clone(),
682 session_cache,
683 users_for_key_claim.clone(),
684 );
685
686 (SessionManager::new(users_for_key_claim, key_request, store), identity_manager)
687 }
688
689 #[async_test]
690 async fn test_session_creation() {
691 let (manager, _identity_manager) = session_manager_test_helper().await;
692 let mut bob = bob_account();
693
694 let bob_device = DeviceData::from_account(&bob);
695
696 manager.store.save_device_data(&[bob_device]).await.unwrap();
697
698 let (txn_id, request) =
699 manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();
700
701 assert!(request.one_time_keys.contains_key(bob.user_id()));
702
703 bob.generate_one_time_keys(1);
704 let one_time = bob.signed_one_time_keys();
705 assert!(!one_time.is_empty());
706 bob.mark_keys_as_published();
707
708 let mut one_time_keys = BTreeMap::new();
709 one_time_keys
710 .entry(bob.user_id().to_owned())
711 .or_insert_with(BTreeMap::new)
712 .insert(bob.device_id().to_owned(), one_time);
713
714 let response = KeyClaimResponse::new(one_time_keys);
715
716 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
717
718 assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
719 }
720
721 #[async_test]
722 async fn test_session_creation_waits_for_keys_query() {
723 let (manager, identity_manager) = session_manager_test_helper().await;
724
725 let (key_query_txn_id, key_query_request) =
728 identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
729 info!("Initial key query: {:?}", key_query_request);
730
731 let bob = bob_account();
733 let bob_device = DeviceData::from_account(&bob);
734 {
735 let cache = manager.store.cache().await.unwrap();
736 identity_manager
737 .key_query_manager
738 .synced(&cache)
739 .await
740 .unwrap()
741 .update_tracked_users(iter::once(bob.user_id()))
742 .await
743 .unwrap();
744 }
745
746 let missing_sessions_task = {
749 let manager = manager.clone();
750 let bob_user_id = bob.user_id().to_owned();
751
752 #[allow(unknown_lints, clippy::redundant_async_block)] spawn(
754 async move { manager.get_missing_sessions(iter::once(bob_user_id.deref())).await },
755 )
756 };
757
758 let response_json =
760 json!({ "device_keys": { manager.store.static_account().user_id.to_owned(): {}}});
761 let response = ruma_response_from_json(&response_json);
762 identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
763
764 let (key_query_txn_id, key_query_request) =
765 identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
766 info!("Second key query: {:?}", key_query_request);
767
768 let response_json = json!({ "device_keys": { bob.user_id(): {
770 bob_device.device_id(): bob_device.as_device_keys()
771 }}});
772 let response = ruma_response_from_json(&response_json);
773 identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
774
775 let (_, keys_claim_request) = missing_sessions_task.await.unwrap().unwrap().unwrap();
778 info!("Key claim request: {:?}", keys_claim_request.one_time_keys);
779 let bob_key_claims = keys_claim_request.one_time_keys.get(bob.user_id()).unwrap();
780 assert!(bob_key_claims.contains_key(bob_device.device_id()));
781 }
782
783 #[async_test]
784 async fn test_session_creation_does_not_wait_for_keys_query_on_failed_server() {
785 let (manager, identity_manager) = session_manager_test_helper().await;
786
787 let other_user_id = OwnedUserId::try_from("@bob:example.com").unwrap();
789 {
790 let cache = manager.store.cache().await.unwrap();
791 identity_manager
792 .key_query_manager
793 .synced(&cache)
794 .await
795 .unwrap()
796 .update_tracked_users(iter::once(other_user_id.as_ref()))
797 .await
798 .unwrap();
799 }
800
801 let (key_query_txn_id, _key_query_request) =
803 identity_manager.users_for_key_query().await.unwrap().pop_first().unwrap();
804 let response = ruma_response_from_json(
805 &json!({ "device_keys": {}, "failures": { other_user_id.server_name(): "unreachable" }}),
806 );
807 identity_manager.receive_keys_query_response(&key_query_txn_id, &response).await.unwrap();
808
809 let result = tokio::time::timeout(
812 Duration::from_millis(10),
813 manager.get_missing_sessions(iter::once(other_user_id.as_ref())),
814 )
815 .await
816 .expect("get_missing_sessions blocked rather than completing quickly")
817 .expect("get_missing_sessions returned an error");
818
819 assert!(result.is_none(), "get_missing_sessions returned Some(...)");
820 }
821
822 #[async_test]
825 #[cfg(target_os = "linux")]
826 async fn test_session_unwedging() {
827 use ruma::{time::SystemTime, SecondsSinceUnixEpoch};
828
829 let (manager, _identity_manager) = session_manager_test_helper().await;
830 let mut bob = bob_account();
831
832 let (_, mut session) = manager
833 .store
834 .with_transaction(|mut tr| async {
835 let manager_account = tr.account().await.unwrap();
836 let res = bob.create_session_for_test_helper(manager_account).await;
837 Ok((tr, res))
838 })
839 .await
840 .unwrap();
841
842 let bob_device = DeviceData::from_account(&bob);
843 let time = SystemTime::now() - Duration::from_secs(3601);
844 session.creation_time = SecondsSinceUnixEpoch::from_system_time(time).unwrap();
845
846 let devices = std::slice::from_ref(&bob_device);
847 manager.store.save_device_data(devices).await.unwrap();
848 manager.store.save_sessions(&[session]).await.unwrap();
849
850 assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
851
852 let curve_key = bob_device.curve25519_key().unwrap();
853
854 assert!(!manager.users_for_key_claim.read().contains_key(bob.user_id()));
855 assert!(!manager.is_device_wedged(&bob_device));
856 manager.mark_device_as_wedged(bob_device.user_id(), curve_key).await.unwrap();
857 assert!(manager.is_device_wedged(&bob_device));
858 assert!(manager.users_for_key_claim.read().contains_key(bob.user_id()));
859
860 let (txn_id, request) =
861 manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().unwrap();
862
863 assert!(request.one_time_keys.contains_key(bob.user_id()));
864
865 bob.generate_one_time_keys(1);
866 let one_time = bob.signed_one_time_keys();
867 assert!(!one_time.is_empty());
868 bob.mark_keys_as_published();
869
870 let mut one_time_keys = BTreeMap::new();
871 one_time_keys
872 .entry(bob.user_id().to_owned())
873 .or_insert_with(BTreeMap::new)
874 .insert(bob.device_id().to_owned(), one_time);
875
876 let response = KeyClaimResponse::new(one_time_keys);
877
878 assert!(manager.outgoing_to_device_requests.read().is_empty());
879
880 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
881
882 assert!(!manager.is_device_wedged(&bob_device));
883 assert!(manager.get_missing_sessions(iter::once(bob.user_id())).await.unwrap().is_none());
884 assert!(!manager.outgoing_to_device_requests.read().is_empty())
885 }
886
887 #[async_test]
888 async fn test_failure_handling() {
889 let alice = user_id!("@alice:example.org");
890 let alice_account = Account::with_device_id(alice, "DEVICEID".into());
891 let alice_device = DeviceData::from_account(&alice_account);
892
893 let (manager, _identity_manager) = session_manager_test_helper().await;
894
895 manager.store.save_device_data(&[alice_device]).await.unwrap();
896
897 let (txn_id, users_for_key_claim) =
898 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
899 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
900
901 manager.receive_keys_claim_response(&txn_id, &keys_claim_with_failure()).await.unwrap();
902 assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());
903
904 manager.failures.expire(&owned_server_name!("example.org"));
906
907 let (txn_id, users_for_key_claim) =
908 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
909 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
910
911 manager.receive_keys_claim_response(&txn_id, &keys_claim_without_failure()).await.unwrap();
912 }
913
914 #[async_test]
915 async fn test_failed_devices_handling() {
916 test_invalid_claim_response(json!({
918 "one_time_keys": {},
919 "failures": {},
920 }))
921 .await;
922
923 test_invalid_claim_response(json!({
925 "one_time_keys": {
926 "@alice:example.org": {}
927 },
928 "failures": {},
929 }))
930 .await;
931
932 test_invalid_claim_response(json!({
934 "one_time_keys": {
935 "@alice:example.org": {
936 "DEVICEID": {}
937 }
938 },
939 "failures": {},
940 }))
941 .await;
942
943 test_invalid_claim_response(json!({
945 "one_time_keys": {
946 "@alice:example.org": {
947 "DEVICEID": {
948 "signed_curve25519:AAAAAA": {
949 "fallback": true,
950 "key": "1sra5GVo1ONz478aQybxSEeHTSo2xq0Z+Q3Yzqvp3A4",
951 "signatures": {
952 "@example:morpheus.localhost": {
953 "ed25519:YAFLBLXAUK": "Zwk90fJhZWOYGNOgtOswZ6RSOGeTjTi/h2dMpyB0CR6EVtvTra0WJtp32ntifrxtwD710y2F3pe5Oyrm7jngCQ"
954 }
955 }
956 }
957 }
958 }
959 },
960 "failures": {},
961 })).await;
962 }
963
964 async fn test_invalid_claim_response(response_json: serde_json::Value) {
970 let response = ruma_response_from_json(&response_json);
971
972 let alice = user_id!("@alice:example.org");
973 let mut alice_account = Account::with_device_id(alice, "DEVICEID".into());
974 let alice_device = DeviceData::from_account(&alice_account);
975
976 let (manager, _identity_manager) = session_manager_test_helper().await;
977 manager.store.save_device_data(&[alice_device]).await.unwrap();
978
979 let (txn_id, users_for_key_claim) =
982 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
983 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
984
985 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
988 assert!(manager.get_missing_sessions(iter::once(alice)).await.unwrap().is_none());
990
991 alice_account.generate_one_time_keys(1);
992 let one_time = alice_account.signed_one_time_keys();
993 assert!(!one_time.is_empty());
994
995 let mut one_time_keys = BTreeMap::new();
996 one_time_keys
997 .entry(alice.to_owned())
998 .or_insert_with(BTreeMap::new)
999 .insert(alice_account.device_id().to_owned(), one_time);
1000
1001 manager
1003 .failed_devices
1004 .write()
1005 .get(alice)
1006 .unwrap()
1007 .expire(&alice_account.device_id().to_owned());
1008 let (txn_id, users_for_key_claim) =
1009 manager.get_missing_sessions(iter::once(alice)).await.unwrap().unwrap();
1010 assert!(users_for_key_claim.one_time_keys.contains_key(alice));
1011
1012 let response = KeyClaimResponse::new(one_time_keys);
1013 manager.receive_keys_claim_response(&txn_id, &response).await.unwrap();
1014
1015 assert!(manager
1017 .failed_devices
1018 .read()
1019 .get(alice)
1020 .unwrap()
1021 .failure_count(alice_account.device_id())
1022 .is_none());
1023 }
1024}