1use std::{collections::BTreeMap, sync::Arc, time::Duration};
16
17use futures_core::Stream;
18use futures_util::{pin_mut, StreamExt};
19#[cfg(feature = "experimental-encrypted-state-events")]
20use matrix_sdk_base::crypto::types::events::room::encrypted::{
21 EncryptedEvent, RoomEventEncryptionScheme,
22};
23use matrix_sdk_base::{
24 crypto::store::types::RoomKeyBundleInfo, InviteAcceptanceDetails, RoomState,
25};
26use matrix_sdk_common::failures_cache::FailuresCache;
27#[cfg(not(feature = "experimental-encrypted-state-events"))]
28use ruma::events::room::encrypted::{EncryptedEventScheme, OriginalSyncRoomEncryptedEvent};
29#[cfg(feature = "experimental-encrypted-state-events")]
30use ruma::serde::JsonCastable;
31use ruma::{serde::Raw, OwnedEventId, OwnedRoomId};
32use tokio::sync::{mpsc, Mutex};
33use tracing::{debug, info, instrument, trace, warn};
34
35use crate::{
36 client::WeakClient,
37 encryption::backups::UploadState,
38 executor::{spawn, JoinHandle},
39 room::shared_room_history,
40 Client, Room,
41};
42
43type DownloadCache = FailuresCache<RoomKeyInfo>;
45
46#[derive(Default)]
47pub(crate) struct ClientTasks {
48 pub(crate) upload_room_keys: Option<BackupUploadingTask>,
49 pub(crate) download_room_keys: Option<BackupDownloadTask>,
50 pub(crate) update_recovery_state_after_backup: Option<JoinHandle<()>>,
51 pub(crate) receive_historic_room_key_bundles: Option<BundleReceiverTask>,
52 pub(crate) setup_e2ee: Option<JoinHandle<()>>,
53}
54
55pub(crate) struct BackupUploadingTask {
56 sender: mpsc::UnboundedSender<()>,
57 #[allow(dead_code)]
58 join_handle: JoinHandle<()>,
59}
60
61impl Drop for BackupUploadingTask {
62 fn drop(&mut self) {
63 #[cfg(not(target_family = "wasm"))]
64 self.join_handle.abort();
65 }
66}
67
68impl BackupUploadingTask {
69 pub(crate) fn new(client: WeakClient) -> Self {
70 let (sender, receiver) = mpsc::unbounded_channel();
71
72 let join_handle = spawn(async move {
73 Self::listen(client, receiver).await;
74 });
75
76 Self { sender, join_handle }
77 }
78
79 pub(crate) fn trigger_upload(&self) {
80 let _ = self.sender.send(());
81 }
82
83 pub(crate) async fn listen(client: WeakClient, mut receiver: mpsc::UnboundedReceiver<()>) {
84 while receiver.recv().await.is_some() {
85 if let Some(client) = client.get() {
86 let upload_progress = &client.inner.e2ee.backup_state.upload_progress;
87
88 if let Err(e) = client.encryption().backups().backup_room_keys().await {
89 upload_progress.set(UploadState::Error);
90 warn!("Error backing up room keys {e:?}");
91 }
95
96 upload_progress.set(UploadState::Idle);
97 } else {
98 trace!("Client got dropped, shutting down the task");
99 break;
100 }
101 }
102 }
103}
104
105#[derive(Debug)]
108struct RoomKeyDownloadRequest {
109 room_id: OwnedRoomId,
111
112 event_id: OwnedEventId,
114
115 #[cfg(not(feature = "experimental-encrypted-state-events"))]
117 event: Raw<OriginalSyncRoomEncryptedEvent>,
118
119 #[cfg(feature = "experimental-encrypted-state-events")]
121 event: Raw<EncryptedEvent>,
122
123 megolm_session_id: String,
125}
126
127impl RoomKeyDownloadRequest {
128 pub fn to_room_key_info(&self) -> RoomKeyInfo {
129 (self.room_id.clone(), self.megolm_session_id.clone())
130 }
131}
132
133pub type RoomKeyInfo = (OwnedRoomId, String);
134
135pub(crate) struct BackupDownloadTask {
136 sender: mpsc::UnboundedSender<RoomKeyDownloadRequest>,
137 #[allow(dead_code)]
138 join_handle: JoinHandle<()>,
139}
140
141impl Drop for BackupDownloadTask {
142 fn drop(&mut self) {
143 #[cfg(not(target_family = "wasm"))]
144 self.join_handle.abort();
145 }
146}
147
148impl BackupDownloadTask {
149 #[cfg(not(test))]
150 const DOWNLOAD_DELAY_MILLIS: u64 = 100;
151
152 pub(crate) fn new(client: WeakClient) -> Self {
153 let (sender, receiver) = mpsc::unbounded_channel();
154
155 let join_handle = spawn(async move {
156 Self::listen(client, receiver).await;
157 });
158
159 Self { sender, join_handle }
160 }
161
162 #[cfg(not(feature = "experimental-encrypted-state-events"))]
168 pub(crate) fn trigger_download_for_utd_event(
169 &self,
170 room_id: OwnedRoomId,
171 event: Raw<OriginalSyncRoomEncryptedEvent>,
172 ) {
173 if let Ok(deserialized_event) = event.deserialize() {
174 if let EncryptedEventScheme::MegolmV1AesSha2(c) = deserialized_event.content.scheme {
175 let _ = self.sender.send(RoomKeyDownloadRequest {
176 room_id,
177 event_id: deserialized_event.event_id,
178 event,
179 megolm_session_id: c.session_id,
180 });
181 }
182 }
183 }
184
185 #[cfg(feature = "experimental-encrypted-state-events")]
191 pub(crate) fn trigger_download_for_utd_event<T: JsonCastable<EncryptedEvent>>(
192 &self,
193 room_id: OwnedRoomId,
194 event: Raw<T>,
195 ) {
196 if let Ok(deserialized_event) = event.deserialize_as::<EncryptedEvent>() {
197 if let RoomEventEncryptionScheme::MegolmV1AesSha2(c) = deserialized_event.content.scheme
198 {
199 let _ = self.sender.send(RoomKeyDownloadRequest {
200 room_id,
201 event_id: deserialized_event.event_id,
202 event: event.cast(),
203 megolm_session_id: c.session_id,
204 });
205 }
206 }
207 }
208
209 async fn listen(
218 client: WeakClient,
219 mut receiver: mpsc::UnboundedReceiver<RoomKeyDownloadRequest>,
220 ) {
221 let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(client)));
222
223 while let Some(room_key_download_request) = receiver.recv().await {
224 let mut state_guard = state.lock().await;
225
226 if state_guard.client.strong_count() == 0 {
227 trace!("Client got dropped, shutting down the task");
228 break;
229 }
230
231 let event_id = &room_key_download_request.event_id;
234 if !state_guard.active_tasks.contains_key(event_id) {
235 let event_id = event_id.to_owned();
236 let task =
237 spawn(Self::handle_download_request(state.clone(), room_key_download_request));
238 state_guard.active_tasks.insert(event_id, task);
239 }
240 }
241 }
242
243 async fn handle_download_request(
248 state: Arc<Mutex<BackupDownloadTaskListenerState>>,
249 download_request: RoomKeyDownloadRequest,
250 ) {
251 #[cfg(not(test))]
253 crate::sleep::sleep(Duration::from_millis(Self::DOWNLOAD_DELAY_MILLIS)).await;
254
255 let client = {
258 let mut state = state.lock().await;
259
260 let Some(client) = state.client.get() else {
261 return;
264 };
265
266 if !state.should_download(&client, &download_request).await {
268 state.active_tasks.remove(&download_request.event_id);
271 return;
272 }
273
274 state.downloaded_room_keys.insert(download_request.to_room_key_info());
277
278 client
279 };
280
281 let result = client
283 .encryption()
284 .backups()
285 .download_room_key(&download_request.room_id, &download_request.megolm_session_id)
286 .await;
287
288 {
290 let mut state = state.lock().await;
291 let room_key_info = download_request.to_room_key_info();
292
293 match result {
294 Ok(true) => {
295 state.failures_cache.remove(std::iter::once(&room_key_info))
298 }
299 Ok(false) => {
300 state.downloaded_room_keys.remove(std::iter::once(&room_key_info));
303 }
304 Err(_) => {
305 state.downloaded_room_keys.remove(std::iter::once(&room_key_info));
309 state.failures_cache.insert(room_key_info);
310 }
311 }
312
313 state.active_tasks.remove(&download_request.event_id);
314 }
315 }
316}
317
318struct BackupDownloadTaskListenerState {
320 client: WeakClient,
323
324 failures_cache: FailuresCache<RoomKeyInfo>,
326
327 active_tasks: BTreeMap<OwnedEventId, JoinHandle<()>>,
329
330 downloaded_room_keys: DownloadCache,
337}
338
339impl BackupDownloadTaskListenerState {
340 pub fn new(client: WeakClient) -> Self {
347 Self {
348 client,
349 failures_cache: FailuresCache::with_settings(Duration::from_secs(60 * 60 * 24), 60),
350 active_tasks: Default::default(),
351 downloaded_room_keys: DownloadCache::with_settings(
352 Duration::from_secs(60 * 60 * 24),
353 60,
354 ),
355 }
356 }
357
358 pub async fn should_download(
368 &self,
369 client: &Client,
370 download_request: &RoomKeyDownloadRequest,
371 ) -> bool {
372 let machine_guard = client.olm_machine().await;
374 let Some(machine) = machine_guard.as_ref() else {
375 return false;
376 };
377
378 if !client.encryption().backups().are_enabled().await {
380 debug!(
381 ?download_request,
382 "Not performing backup download because backups are not enabled"
383 );
384
385 return false;
386 }
387
388 if machine
393 .is_room_key_available(
394 #[cfg(not(feature = "experimental-encrypted-state-events"))]
395 download_request.event.cast_ref(),
396 #[cfg(feature = "experimental-encrypted-state-events")]
397 &download_request.event,
398 &download_request.room_id,
399 )
400 .await
401 .unwrap_or(false)
402 {
403 debug!(
404 ?download_request,
405 "Not performing backup download because key became available while we were sleeping"
406 );
407 return false;
408 }
409
410 let room_key_info = download_request.to_room_key_info();
413 if self.downloaded_room_keys.contains(&room_key_info) {
414 debug!(
415 ?download_request,
416 "Not performing backup download because this room key has already been downloaded recently"
417 );
418 return false;
419 }
420
421 if self.failures_cache.contains(&room_key_info) {
423 debug!(
424 ?download_request,
425 "Not performing backup download because this room key failed to download recently"
426 );
427 return false;
428 }
429
430 debug!(?download_request, "Performing backup download");
431 true
432 }
433}
434
435pub(crate) struct BundleReceiverTask {
436 _handle: JoinHandle<()>,
437}
438
439impl BundleReceiverTask {
440 pub async fn new(client: &Client) -> Self {
441 let stream = client.encryption().historic_room_key_stream().await.expect("E2EE tasks should only be initialized once we have logged in and have access to an OlmMachine");
442 let weak_client = WeakClient::from_client(client);
443 let handle = spawn(Self::listen_task(weak_client, stream));
444
445 Self { _handle: handle }
446 }
447
448 async fn listen_task(client: WeakClient, stream: impl Stream<Item = RoomKeyBundleInfo>) {
449 pin_mut!(stream);
450
451 while let Some(bundle_info) = stream.next().await {
456 let Some(client) = client.get() else {
457 break;
460 };
461
462 let Some(room) = client.get_room(&bundle_info.room_id) else {
463 warn!(room_id = %bundle_info.room_id, "Received a historic room key bundle for an unknown room");
464 continue;
465 };
466
467 Self::handle_bundle(&room, &bundle_info).await;
468 }
469 }
470
471 #[instrument(skip(room), fields(room_id = %room.room_id()))]
472 async fn handle_bundle(room: &Room, bundle_info: &RoomKeyBundleInfo) {
473 if Self::should_accept_bundle(room, bundle_info) {
474 info!("Accepting a late key bundle.");
475
476 if let Err(e) =
477 shared_room_history::maybe_accept_key_bundle(room, &bundle_info.sender).await
478 {
479 warn!("Couldn't accept a late room key bundle {e:?}");
480 }
481 } else {
482 info!("Refusing to accept a historic room key bundle.");
483 }
484 }
485
486 fn should_accept_bundle(room: &Room, bundle_info: &RoomKeyBundleInfo) -> bool {
487 const DAY: Duration = Duration::from_secs(24 * 60 * 60);
490
491 let Some(InviteAcceptanceDetails { invite_accepted_at, inviter }) =
494 room.invite_acceptance_details()
495 else {
496 return false;
497 };
498
499 let state = room.state();
500 let elapsed_since_join = invite_accepted_at.to_system_time().and_then(|t| t.elapsed().ok());
501 let bundle_sender = &bundle_info.sender;
502
503 match (state, elapsed_since_join) {
504 (RoomState::Joined, Some(elapsed_since_join)) => {
505 elapsed_since_join < DAY && bundle_sender == &inviter
506 }
507 (RoomState::Joined, None) => false,
508 (RoomState::Left | RoomState::Invited | RoomState::Knocked | RoomState::Banned, _) => {
509 false
510 }
511 }
512 }
513}
514
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::{
        async_test, event_factory::EventFactory, InvitedRoomBuilder, JoinedRoomBuilder,
    };
    #[cfg(not(feature = "experimental-encrypted-state-events"))]
    use ruma::events::room::encrypted::OriginalSyncRoomEncryptedEvent;
    use ruma::{event_id, room_id, user_id};
    use serde_json::json;
    use vodozemac::Curve25519PublicKey;
    use wiremock::MockServer;

    use super::*;
    use crate::test_utils::{logged_in_client, mocks::MatrixMockServer};

    /// When backups are disabled, handling a download request must not leave
    /// the room key marked as downloaded.
    #[async_test]
    async fn test_disabled_backup_does_not_mark_room_key_as_downloaded() {
        let room_id = room_id!("!DovneieKSTkdHKpIXy:morpheus.localhost");
        let event_id = event_id!("$JbFHtZpEJiH8uaajZjPLz0QUZc1xtBR9rPGBOjF6WFM");
        let session_id = "session_id";

        let server = MockServer::start().await;
        let client = logged_in_client(Some(server.uri())).await;
        let weak_client = WeakClient::from_client(&client);

        let event_content = json!({
            "event_id": event_id,
            "origin_server_ts": 1698579035927u64,
            "sender": "@example2:morpheus.localhost",
            "type": "m.room.encrypted",
            "content": {
                "algorithm": "m.megolm.v1.aes-sha2",
                "ciphertext": "AwgAEpABhetEzzZzyYrxtEVUtlJnZtJcURBlQUQJ9irVeklCTs06LwgTMQj61PMUS4Vy\
                               YOX+PD67+hhU40/8olOww+Ud0m2afjMjC3wFX+4fFfSkoWPVHEmRVucfcdSF1RSB4EmK\
                               PIP4eo1X6x8kCIMewBvxl2sI9j4VNvDvAN7M3zkLJfFLOFHbBviI4FN7hSFHFeM739Zg\
                               iwxEs3hIkUXEiAfrobzaMEM/zY7SDrTdyffZndgJo7CZOVhoV6vuaOhmAy4X2t4UnbuV\
                               JGJjKfV57NAhp8W+9oT7ugwO",
                "device_id": "KIUVQQSDTM",
                "sender_key": "LvryVyoCjdONdBCi2vvoSbI34yTOx7YrCFACUEKoXnc",
                "session_id": "64H7XKokIx0ASkYDHZKlT5zd/Zccz/cQspPNdvnNULA"
            }
        });

        #[cfg(not(feature = "experimental-encrypted-state-events"))]
        let event: Raw<OriginalSyncRoomEncryptedEvent> = serde_json::from_value(event_content)
            .expect("The test event should deserialize into a raw encrypted event");

        #[cfg(feature = "experimental-encrypted-state-events")]
        let event: Raw<EncryptedEvent> = serde_json::from_value(event_content)
            .expect("The test event should deserialize into a raw encrypted event");

        let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(weak_client)));
        let download_request = RoomKeyDownloadRequest {
            room_id: room_id.into(),
            megolm_session_id: session_id.to_owned(),
            event,
            event_id: event_id.into(),
        };

        assert!(
            !client.encryption().backups().are_enabled().await,
            "Backups should not be enabled."
        );

        BackupDownloadTask::handle_download_request(state.clone(), download_request).await;

        {
            let state = state.lock().await;
            assert!(
                !state.downloaded_room_keys.contains(&(room_id.to_owned(), session_id.to_owned())),
                "Backups are not enabled, we should not mark any room keys as downloaded."
            )
        }
    }

    /// A late key bundle is only accepted for a room we joined from this
    /// client, and only when it comes from the user who invited us.
    #[async_test]
    async fn test_should_accept_bundle() {
        let server = MatrixMockServer::new().await;

        let alice_user_id = user_id!("@alice:localhost");
        let bob_user_id = user_id!("@bob:localhost");
        let joined_room_id = room_id!("!joined:localhost");
        let invited_room_id = room_id!("!invited:localhost");

        let client = server
            .client_builder()
            .logged_in_with_token("ABCD".to_owned(), alice_user_id.into(), "DEVICEID".into())
            .build()
            .await;

        let event_factory = EventFactory::new().room(invited_room_id);
        let bob_member_event = event_factory.member(bob_user_id).into_raw();
        let alice_member_event =
            event_factory.member(bob_user_id).invited(alice_user_id).into_raw();

        server
            .mock_sync()
            .ok_and_run(&client, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(joined_room_id)).add_invited_room(
                    InvitedRoomBuilder::new(invited_room_id)
                        .add_state_event(bob_member_event)
                        .add_state_event(alice_member_event),
                );
            })
            .await;

        let room =
            client.get_room(joined_room_id).expect("We should have access to our joined room now");

        assert!(
            room.invite_acceptance_details().is_none(),
            "We shouldn't have any invite acceptance details if we didn't join the room on this Client"
        );

        let bundle_info = RoomKeyBundleInfo {
            sender: bob_user_id.to_owned(),
            sender_key: Curve25519PublicKey::from_bytes([0u8; 32]),
            room_id: joined_room_id.to_owned(),
        };

        assert!(
            !BundleReceiverTask::should_accept_bundle(&room, &bundle_info),
            "We should not accept a bundle if we did not join the room from this Client"
        );

        let invited_room =
            client.get_room(invited_room_id).expect("We should have access to our invited room now");

        assert!(
            !BundleReceiverTask::should_accept_bundle(&invited_room, &bundle_info),
            "We should not accept a bundle if we didn't join the room."
        );

        server.mock_room_join(invited_room_id).ok().mock_once().mount().await;

        let room = client
            .join_room_by_id(invited_room_id)
            .await
            .expect("We should be able to join the invited room");

        let details = room
            .invite_acceptance_details()
            .expect("We should have stored the invite acceptance details");
        assert_eq!(details.inviter, bob_user_id, "We should have recorded that Bob has invited us");

        assert!(
            BundleReceiverTask::should_accept_bundle(&room, &bundle_info),
            "We should accept a bundle if we just joined the room and did so from this very Client object"
        );
    }
}