1use std::{collections::BTreeMap, sync::Arc, time::Duration};
16
17use futures_core::Stream;
18use futures_util::{StreamExt, pin_mut};
19use matrix_sdk_base::crypto::store::types::{RoomKeyBundleInfo, RoomPendingKeyBundleDetails};
20#[cfg(feature = "experimental-encrypted-state-events")]
21use matrix_sdk_base::crypto::types::events::room::encrypted::{
22 EncryptedEvent, RoomEventEncryptionScheme,
23};
24use matrix_sdk_common::failures_cache::FailuresCache;
25#[cfg(not(feature = "experimental-encrypted-state-events"))]
26use ruma::events::room::encrypted::{EncryptedEventScheme, OriginalSyncRoomEncryptedEvent};
27#[cfg(feature = "experimental-encrypted-state-events")]
28use ruma::serde::JsonCastable;
29use ruma::{OwnedEventId, OwnedRoomId, serde::Raw};
30use tokio::sync::{Mutex, mpsc};
31use tracing::{debug, info, instrument, trace, warn};
32
33use crate::{
34 Client, Room,
35 client::WeakClient,
36 encryption::backups::UploadState,
37 executor::{JoinHandle, spawn},
38 room::shared_room_history,
39};
40
41type DownloadCache = FailuresCache<RoomKeyInfo>;
43
44#[derive(Default)]
45pub(crate) struct ClientTasks {
46 pub(crate) upload_room_keys: Option<BackupUploadingTask>,
47 pub(crate) download_room_keys: Option<BackupDownloadTask>,
48 pub(crate) update_recovery_state_after_backup: Option<JoinHandle<()>>,
49 pub(crate) receive_historic_room_key_bundles: Option<BundleReceiverTask>,
50 pub(crate) setup_e2ee: Option<JoinHandle<()>>,
51}
52
53pub(crate) struct BackupUploadingTask {
54 sender: mpsc::UnboundedSender<()>,
55 #[allow(dead_code)]
56 join_handle: JoinHandle<()>,
57}
58
59impl Drop for BackupUploadingTask {
60 fn drop(&mut self) {
61 #[cfg(not(target_family = "wasm"))]
62 self.join_handle.abort();
63 }
64}
65
66impl BackupUploadingTask {
67 pub(crate) fn new(client: WeakClient) -> Self {
68 let (sender, receiver) = mpsc::unbounded_channel();
69
70 let join_handle = spawn(async move {
71 Self::listen(client, receiver).await;
72 });
73
74 Self { sender, join_handle }
75 }
76
77 pub(crate) fn trigger_upload(&self) {
78 let _ = self.sender.send(());
79 }
80
81 pub(crate) async fn listen(client: WeakClient, mut receiver: mpsc::UnboundedReceiver<()>) {
82 while receiver.recv().await.is_some() {
83 if let Some(client) = client.get() {
84 let upload_progress = &client.inner.e2ee.backup_state.upload_progress;
85
86 if let Err(e) = client.encryption().backups().backup_room_keys().await {
87 upload_progress.set(UploadState::Error);
88 warn!("Error backing up room keys {e:?}");
89 }
93
94 upload_progress.set(UploadState::Idle);
95 } else {
96 trace!("Client got dropped, shutting down the task");
97 break;
98 }
99 }
100 }
101}
102
103#[derive(Debug)]
106struct RoomKeyDownloadRequest {
107 room_id: OwnedRoomId,
109
110 event_id: OwnedEventId,
112
113 #[cfg(not(feature = "experimental-encrypted-state-events"))]
115 event: Raw<OriginalSyncRoomEncryptedEvent>,
116
117 #[cfg(feature = "experimental-encrypted-state-events")]
119 event: Raw<EncryptedEvent>,
120
121 megolm_session_id: String,
123}
124
125impl RoomKeyDownloadRequest {
126 pub fn to_room_key_info(&self) -> RoomKeyInfo {
127 (self.room_id.clone(), self.megolm_session_id.clone())
128 }
129}
130
131pub type RoomKeyInfo = (OwnedRoomId, String);
132
133pub(crate) struct BackupDownloadTask {
134 sender: mpsc::UnboundedSender<RoomKeyDownloadRequest>,
135 #[allow(dead_code)]
136 join_handle: JoinHandle<()>,
137}
138
139impl Drop for BackupDownloadTask {
140 fn drop(&mut self) {
141 #[cfg(not(target_family = "wasm"))]
142 self.join_handle.abort();
143 }
144}
145
146impl BackupDownloadTask {
147 #[cfg(not(test))]
148 const DOWNLOAD_DELAY_MILLIS: u64 = 100;
149
150 pub(crate) fn new(client: WeakClient) -> Self {
151 let (sender, receiver) = mpsc::unbounded_channel();
152
153 let join_handle = spawn(async move {
154 Self::listen(client, receiver).await;
155 });
156
157 Self { sender, join_handle }
158 }
159
160 #[cfg(not(feature = "experimental-encrypted-state-events"))]
166 pub(crate) fn trigger_download_for_utd_event(
167 &self,
168 room_id: OwnedRoomId,
169 event: Raw<OriginalSyncRoomEncryptedEvent>,
170 ) {
171 if let Ok(deserialized_event) = event.deserialize()
172 && let EncryptedEventScheme::MegolmV1AesSha2(c) = deserialized_event.content.scheme
173 {
174 let _ = self.sender.send(RoomKeyDownloadRequest {
175 room_id,
176 event_id: deserialized_event.event_id,
177 event,
178 megolm_session_id: c.session_id,
179 });
180 }
181 }
182
183 #[cfg(feature = "experimental-encrypted-state-events")]
189 pub(crate) fn trigger_download_for_utd_event<T: JsonCastable<EncryptedEvent>>(
190 &self,
191 room_id: OwnedRoomId,
192 event: Raw<T>,
193 ) {
194 if let Ok(deserialized_event) = event.deserialize_as::<EncryptedEvent>() {
195 if let RoomEventEncryptionScheme::MegolmV1AesSha2(c) = deserialized_event.content.scheme
196 {
197 let _ = self.sender.send(RoomKeyDownloadRequest {
198 room_id,
199 event_id: deserialized_event.event_id,
200 event: event.cast(),
201 megolm_session_id: c.session_id,
202 });
203 }
204 }
205 }
206
207 async fn listen(
216 client: WeakClient,
217 mut receiver: mpsc::UnboundedReceiver<RoomKeyDownloadRequest>,
218 ) {
219 let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(client)));
220
221 while let Some(room_key_download_request) = receiver.recv().await {
222 let mut state_guard = state.lock().await;
223
224 if state_guard.client.strong_count() == 0 {
225 trace!("Client got dropped, shutting down the task");
226 break;
227 }
228
229 let event_id = &room_key_download_request.event_id;
232 if !state_guard.active_tasks.contains_key(event_id) {
233 let event_id = event_id.to_owned();
234 let task =
235 spawn(Self::handle_download_request(state.clone(), room_key_download_request));
236 state_guard.active_tasks.insert(event_id, task);
237 }
238 }
239 }
240
241 async fn handle_download_request(
246 state: Arc<Mutex<BackupDownloadTaskListenerState>>,
247 download_request: RoomKeyDownloadRequest,
248 ) {
249 #[cfg(not(test))]
251 crate::sleep::sleep(Duration::from_millis(Self::DOWNLOAD_DELAY_MILLIS)).await;
252
253 let client = {
256 let mut state = state.lock().await;
257
258 let Some(client) = state.client.get() else {
259 return;
262 };
263
264 if !state.should_download(&client, &download_request).await {
266 state.active_tasks.remove(&download_request.event_id);
269 return;
270 }
271
272 state.downloaded_room_keys.insert(download_request.to_room_key_info());
275
276 client
277 };
278
279 let result = client
281 .encryption()
282 .backups()
283 .download_room_key(&download_request.room_id, &download_request.megolm_session_id)
284 .await;
285
286 {
288 let mut state = state.lock().await;
289 let room_key_info = download_request.to_room_key_info();
290
291 match result {
292 Ok(true) => {
293 state.failures_cache.remove(std::iter::once(&room_key_info))
296 }
297 Ok(false) => {
298 state.downloaded_room_keys.remove(std::iter::once(&room_key_info));
301 }
302 Err(_) => {
303 state.downloaded_room_keys.remove(std::iter::once(&room_key_info));
307 state.failures_cache.insert(room_key_info);
308 }
309 }
310
311 state.active_tasks.remove(&download_request.event_id);
312 }
313 }
314}
315
316struct BackupDownloadTaskListenerState {
318 client: WeakClient,
321
322 failures_cache: FailuresCache<RoomKeyInfo>,
324
325 active_tasks: BTreeMap<OwnedEventId, JoinHandle<()>>,
327
328 downloaded_room_keys: DownloadCache,
335}
336
337impl BackupDownloadTaskListenerState {
338 pub fn new(client: WeakClient) -> Self {
345 Self {
346 client,
347 failures_cache: FailuresCache::with_settings(Duration::from_secs(60 * 60 * 24), 60),
348 active_tasks: Default::default(),
349 downloaded_room_keys: DownloadCache::with_settings(
350 Duration::from_secs(60 * 60 * 24),
351 60,
352 ),
353 }
354 }
355
356 pub async fn should_download(
366 &self,
367 client: &Client,
368 download_request: &RoomKeyDownloadRequest,
369 ) -> bool {
370 let machine_guard = client.olm_machine().await;
372 let Some(machine) = machine_guard.as_ref() else {
373 return false;
374 };
375
376 if !client.encryption().backups().are_enabled().await {
378 debug!(
379 ?download_request,
380 "Not performing backup download because backups are not enabled"
381 );
382
383 return false;
384 }
385
386 if machine
391 .is_room_key_available(
392 #[cfg(not(feature = "experimental-encrypted-state-events"))]
393 download_request.event.cast_ref(),
394 #[cfg(feature = "experimental-encrypted-state-events")]
395 &download_request.event,
396 &download_request.room_id,
397 )
398 .await
399 .unwrap_or(false)
400 {
401 debug!(
402 ?download_request,
403 "Not performing backup download because key became available while we were sleeping"
404 );
405 return false;
406 }
407
408 let room_key_info = download_request.to_room_key_info();
411 if self.downloaded_room_keys.contains(&room_key_info) {
412 debug!(
413 ?download_request,
414 "Not performing backup download because this room key has already been downloaded recently"
415 );
416 return false;
417 }
418
419 if self.failures_cache.contains(&room_key_info) {
421 debug!(
422 ?download_request,
423 "Not performing backup download because this room key failed to download recently"
424 );
425 return false;
426 }
427
428 debug!(?download_request, "Performing backup download");
429 true
430 }
431}
432
433pub(crate) struct BundleReceiverTask {
434 _startup_handle: JoinHandle<()>,
435 _listen_handle: JoinHandle<()>,
436}
437
438impl BundleReceiverTask {
439 pub async fn new(client: &Client) -> Self {
440 let stream = client.encryption().historic_room_key_stream().await.expect("E2EE tasks should only be initialized once we have logged in and have access to an OlmMachine");
441 let weak_client = WeakClient::from_client(client);
442 Self {
443 _listen_handle: spawn(Self::listen_task(weak_client.clone(), stream)),
444 _startup_handle: spawn(Self::startup_task(weak_client)),
445 }
446 }
447
448 async fn listen_task(client: WeakClient, stream: impl Stream<Item = RoomKeyBundleInfo>) {
449 pin_mut!(stream);
450
451 while let Some(bundle_info) = stream.next().await {
456 let Some(client) = client.get() else {
457 break;
460 };
461
462 let Some(room) = client.get_room(&bundle_info.room_id) else {
463 warn!(room_id = %bundle_info.room_id, "Received a historic room key bundle for an unknown room");
464 continue;
465 };
466
467 Self::handle_bundle(&room, &bundle_info).await;
468 }
469 }
470
471 #[tracing::instrument(skip_all)]
476 async fn startup_task(client: WeakClient) {
477 tracing::debug!("Checking for unimported stored room key bundles...");
478
479 let Some(client) = client.get() else {
480 return;
482 };
483
484 let olm_machine = client.olm_machine().await;
485 let Some(olm_machine) = olm_machine.as_ref() else {
486 tracing::warn!("Skipping startup bundle checks because the Olm machine is unavailable");
491 return;
492 };
493
494 let room_details = match olm_machine.store().get_all_rooms_pending_key_bundles().await {
495 Ok(room_details) => room_details,
496 Err(e) => {
497 tracing::warn!("Error while fetching rooms pending key bundles: {e:?}");
498 return;
499 }
500 };
501
502 let (valid, invalid): (Vec<_>, Vec<_>) = room_details.iter().partition(|details| {
505 shared_room_history::should_process_room_pending_key_bundle_details(details)
506 });
507
508 tracing::debug!(
509 "Found {} valid and {} invalid rooms that are still pending key bundles",
510 valid.len(),
511 invalid.len(),
512 );
513
514 for RoomPendingKeyBundleDetails { room_id, inviter, .. } in valid {
518 let Some(room) = client.get_room(room_id) else {
519 tracing::trace!(?room_id, "Room not available in state store, skipping...");
521 continue;
522 };
523 let bundle =
524 match olm_machine.store().get_received_room_key_bundle_data(room_id, inviter).await
525 {
526 Ok(Some(bundle)) => bundle,
527 Ok(None) => {
528 tracing::trace!(?room_id, "No bundle available, skipping...");
531 continue;
532 }
533 Err(err) => {
534 tracing::warn!(
535 ?room_id,
536 "Failed to fetch received room key bundle data: {err:?}"
537 );
538 continue;
539 }
540 };
541 Self::handle_bundle(&room, &(&bundle).into()).await;
542 }
543
544 for RoomPendingKeyBundleDetails { room_id, .. } in &invalid {
547 tracing::trace!(?room_id, "Clearing pending flag for room");
548 if let Err(e) = olm_machine.store().clear_room_pending_key_bundle(room_id).await {
549 tracing::warn!("Error clearing room pending key bundle: {e:?}");
550 }
551 }
552 }
553
554 #[instrument(skip(room), fields(room_id = %room.room_id()))]
572 async fn handle_bundle(room: &Room, bundle_info: &RoomKeyBundleInfo) {
573 if shared_room_history::should_accept_key_bundle(room, bundle_info).await {
574 info!(room_id = %room.room_id(), "Accepting a late key bundle.");
575
576 if let Err(e) =
577 shared_room_history::maybe_accept_key_bundle(room, &bundle_info.sender).await
578 {
579 warn!("Couldn't accept a late room key bundle {e:?}");
580 }
581 } else {
582 info!("Refusing to accept a historic room key bundle.");
583 }
584 }
585
586 #[cfg(any(feature = "testing", test))]
587 pub(crate) fn abort(&self) {
588 self._startup_handle.abort();
589 self._listen_handle.abort();
590 }
591}
592
#[cfg(all(test, not(target_family = "wasm")))]
mod test {
    use matrix_sdk_test::async_test;
    #[cfg(not(feature = "experimental-encrypted-state-events"))]
    use ruma::events::room::encrypted::OriginalSyncRoomEncryptedEvent;
    use ruma::{event_id, room_id};
    use serde_json::json;
    use wiremock::MockServer;

    use super::*;
    use crate::test_utils::logged_in_client;

    /// A download request made while backups are disabled must be refused
    /// without marking the room key as downloaded. Otherwise the
    /// "downloaded recently" cache would block a later download once backups
    /// are enabled.
    #[async_test]
    async fn test_disabled_backup_does_not_mark_room_key_as_downloaded() {
        let room_id = room_id!("!DovneieKSTkdHKpIXy:morpheus.localhost");
        let event_id = event_id!("$JbFHtZpEJiH8uaajZjPLz0QUZc1xtBR9rPGBOjF6WFM");
        let session_id = "session_id";

        let server = MockServer::start().await;
        let client = logged_in_client(Some(server.uri())).await;
        let weak_client = WeakClient::from_client(&client);

        let event_content = json!({
            "event_id": event_id,
            "origin_server_ts": 1698579035927u64,
            "sender": "@example2:morpheus.localhost",
            "type": "m.room.encrypted",
            "content": {
                "algorithm": "m.megolm.v1.aes-sha2",
                "ciphertext": "AwgAEpABhetEzzZzyYrxtEVUtlJnZtJcURBlQUQJ9irVeklCTs06LwgTMQj61PMUS4Vy\
                               YOX+PD67+hhU40/8olOww+Ud0m2afjMjC3wFX+4fFfSkoWPVHEmRVucfcdSF1RSB4EmK\
                               PIP4eo1X6x8kCIMewBvxl2sI9j4VNvDvAN7M3zkLJfFLOFHbBviI4FN7hSFHFeM739Zg\
                               iwxEs3hIkUXEiAfrobzaMEM/zY7SDrTdyffZndgJo7CZOVhoV6vuaOhmAy4X2t4UnbuV\
                               JGJjKfV57NAhp8W+9oT7ugwO",
                "device_id": "KIUVQQSDTM",
                "sender_key": "LvryVyoCjdONdBCi2vvoSbI34yTOx7YrCFACUEKoXnc",
                "session_id": "64H7XKokIx0ASkYDHZKlT5zd/Zccz/cQspPNdvnNULA"
            }
        });

        #[cfg(not(feature = "experimental-encrypted-state-events"))]
        let event: Raw<OriginalSyncRoomEncryptedEvent> = serde_json::from_value(event_content)
            .expect("the test event JSON should deserialize into a raw encrypted event");

        #[cfg(feature = "experimental-encrypted-state-events")]
        let event: Raw<EncryptedEvent> = serde_json::from_value(event_content)
            .expect("the test event JSON should deserialize into a raw encrypted event");

        let state = Arc::new(Mutex::new(BackupDownloadTaskListenerState::new(weak_client)));
        let download_request = RoomKeyDownloadRequest {
            room_id: room_id.into(),
            megolm_session_id: session_id.to_owned(),
            event,
            event_id: event_id.into(),
        };

        assert!(
            !client.encryption().backups().are_enabled().await,
            "Backups should not be enabled."
        );

        BackupDownloadTask::handle_download_request(state.clone(), download_request).await;

        {
            let state = state.lock().await;
            assert!(
                !state.downloaded_room_keys.contains(&(room_id.to_owned(), session_id.to_owned())),
                "Backups are not enabled, we should not mark any room keys as downloaded."
            )
        }
    }
}