1use std::{
16 collections::{BTreeMap, BTreeSet},
17 sync::Arc,
18};
19
20use futures_core::Stream;
21use futures_util::pin_mut;
22use imbl::Vector;
23use itertools::{Either, Itertools as _};
24use matrix_sdk::{
25 Client, Room,
26 crypto::store::types::RoomKeyInfo,
27 deserialized_responses::TimelineEventKind as SdkTimelineEventKind,
28 encryption::backups::BackupState,
29 event_handler::EventHandlerHandle,
30 executor::{JoinHandle, spawn},
31};
32use tokio::sync::{
33 RwLock,
34 mpsc::{self, Receiver, Sender},
35};
36use tokio_stream::{StreamExt as _, wrappers::errors::BroadcastStreamRecvError};
37use tracing::{Instrument as _, debug, error, field, info, info_span, warn};
38
39use crate::timeline::{
40 EncryptedMessage, EventTimelineItem, TimelineController, TimelineItem, TimelineItemKind,
41 controller::{TimelineSettings, TimelineState},
42 event_item::EventTimelineItemKind,
43 to_device::{handle_forwarded_room_key_event, handle_room_key_event},
44 traits::{Decryptor, RoomDataProvider},
45};
46
/// Owns everything registered or spawned by [`spawn_crypto_tasks`], so that it
/// can all be torn down together when this value is dropped (see its `Drop`
/// implementation).
#[derive(Debug)]
pub(in crate::timeline) struct CryptoDropHandles {
    /// Client the event handlers were registered on; needed to remove them.
    client: Client,
    /// Handlers for room key and forwarded room key events.
    event_handler_handles: Vec<EventHandlerHandle>,
    /// Task retrying decryption when room keys arrive from the key backup.
    room_key_from_backups_join_handle: JoinHandle<()>,
    /// Task retrying decryption when room keys are received.
    room_keys_received_join_handle: JoinHandle<()>,
    /// Task retrying decryption on key backup state updates.
    room_key_backup_enabled_join_handle: JoinHandle<()>,
    /// Task running the controller's `handle_encryption_state_changes`.
    encryption_changes_handle: JoinHandle<()>,
}
58
59impl Drop for CryptoDropHandles {
60 fn drop(&mut self) {
61 for handle in self.event_handler_handles.drain(..) {
62 self.client.remove_event_handler(handle);
63 }
64
65 self.room_key_from_backups_join_handle.abort();
66 self.room_keys_received_join_handle.abort();
67 self.room_key_backup_enabled_join_handle.abort();
68 self.encryption_changes_handle.abort();
69 }
70}
71
72async fn room_keys_from_backups_task<S>(stream: S, timeline_controller: TimelineController)
74where
75 S: Stream<Item = Result<BTreeMap<String, BTreeSet<String>>, BroadcastStreamRecvError>>,
76{
77 pin_mut!(stream);
78
79 while let Some(update) = stream.next().await {
80 match update {
81 Ok(info) => {
82 let mut session_ids = BTreeSet::new();
83
84 for set in info.into_values() {
85 session_ids.extend(set);
86 }
87
88 timeline_controller.retry_event_decryption(Some(session_ids)).await;
89 }
90 Err(_) => timeline_controller.retry_event_decryption(None).await,
92 }
93 }
94}
95
96async fn backup_states_task<S>(backup_states_stream: S, timeline_controller: TimelineController)
98where
99 S: Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
100{
101 pin_mut!(backup_states_stream);
102
103 while let Some(update) = backup_states_stream.next().await {
104 match update {
105 Ok(BackupState::Enabled) | Err(_) => {
114 timeline_controller.retry_event_decryption(None).await;
115 }
116 Ok(
119 BackupState::Unknown
120 | BackupState::Creating
121 | BackupState::Resuming
122 | BackupState::Disabling
123 | BackupState::Downloading
124 | BackupState::Enabling,
125 ) => (),
126 }
127 }
128}
129
130async fn room_key_received_task<S>(
132 room_keys_received_stream: S,
133 timeline_controller: TimelineController,
134) where
135 S: Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
136{
137 pin_mut!(room_keys_received_stream);
138
139 let room_id = timeline_controller.room().room_id();
140
141 while let Some(room_keys) = room_keys_received_stream.next().await {
142 let session_ids = match room_keys {
143 Ok(room_keys) => {
144 let session_ids: BTreeSet<String> = room_keys
145 .into_iter()
146 .filter(|info| info.room_id == room_id)
147 .map(|info| info.session_id)
148 .collect();
149
150 Some(session_ids)
151 }
152 Err(BroadcastStreamRecvError::Lagged(missed_updates)) => {
153 warn!(
156 missed_updates,
157 "The room keys stream has lagged, retrying to decrypt the whole timeline"
158 );
159
160 None
161 }
162 };
163
164 timeline_controller.retry_event_decryption(session_ids).await;
165 }
166}
167
168pub(in crate::timeline) async fn spawn_crypto_tasks(
171 room: Room,
172 controller: TimelineController,
173) -> CryptoDropHandles {
174 let room_key_handle = room
175 .client()
176 .add_event_handler(handle_room_key_event(controller.clone(), room.room_id().to_owned()));
177
178 let client = room.client();
179 let forwarded_room_key_handle = client.add_event_handler(handle_forwarded_room_key_event(
180 controller.clone(),
181 room.room_id().to_owned(),
182 ));
183
184 let event_handlers = vec![room_key_handle, forwarded_room_key_handle];
185
186 let room_key_from_backups_join_handle = spawn(room_keys_from_backups_task(
190 client.encryption().backups().room_keys_for_room_stream(controller.room().room_id()),
191 controller.clone(),
192 ));
193
194 let room_key_backup_enabled_join_handle =
195 spawn(backup_states_task(client.encryption().backups().state_stream(), controller.clone()));
196
197 let room_keys_received_join_handle = {
203 spawn(room_key_received_task(
204 client.encryption().room_keys_received_stream().await.expect(
205 "We should be logged in by now, so we should have access to an `OlmMachine` \
206 to be able to listen to this stream",
207 ),
208 controller.clone(),
209 ))
210 };
211
212 CryptoDropHandles {
213 client,
214 event_handler_handles: event_handlers,
215 room_key_from_backups_join_handle,
216 room_keys_received_join_handle,
217 room_key_backup_enabled_join_handle,
218 encryption_changes_handle: spawn(async move {
219 controller.handle_encryption_state_changes().await
220 }),
221 }
222}
223
/// Handle on a background task that retries decrypting timeline events.
///
/// Requests sent through [`DecryptionRetryTask::decrypt`] are queued on a
/// bounded channel and processed one at a time by `decryption_task`. Clones
/// share the same channel and the same task.
#[derive(Clone, Debug)]
pub struct DecryptionRetryTask<P: RoomDataProvider, D: Decryptor> {
    /// Sending half of the request channel.
    sender: Sender<DecryptionRetryRequest<D>>,

    /// Handle of the spawned task, shared between clones.
    // NOTE(review): nothing here aborts the task; it stops once every `sender`
    // is dropped and `receiver.recv()` returns `None` in `decryption_task`.
    _task_handle: Arc<JoinHandle<()>>,

    /// Binds the room data provider type `P` without storing a value of it.
    _phantom: std::marker::PhantomData<P>,
}
249
/// Capacity of the decryption retry request channel. Once it is full,
/// `DecryptionRetryTask::decrypt` waits until the task has consumed a request.
const CHANNEL_BUFFER_SIZE: usize = 100;
254
255impl<P: RoomDataProvider, D: Decryptor> DecryptionRetryTask<P, D> {
256 pub(crate) fn new(state: Arc<RwLock<TimelineState<P>>>, room_data_provider: P) -> Self {
257 let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
259
260 let handle = spawn(decryption_task(state, room_data_provider, receiver));
263
264 Self { sender, _task_handle: Arc::new(handle), _phantom: Default::default() }
266 }
267
268 pub(crate) async fn decrypt(
271 &self,
272 decryptor: D,
273 session_ids: Option<BTreeSet<String>>,
274 settings: TimelineSettings,
275 ) {
276 let res =
277 self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
278
279 if let Err(error) = res {
280 error!("Failed to send decryption retry request: {error}");
281 }
282 }
283}
284
/// A single retry request handled by `decryption_task`.
struct DecryptionRetryRequest<D: Decryptor> {
    /// Used to re-decrypt unable-to-decrypt events.
    decryptor: D,
    /// Sessions to retry; `None` means every session.
    session_ids: Option<BTreeSet<String>>,
    /// Timeline settings forwarded to the state's retry logic.
    settings: TimelineSettings,
}
292
293async fn decryption_task<P: RoomDataProvider, D: Decryptor>(
297 state: Arc<RwLock<TimelineState<P>>>,
298 room_data_provider: P,
299 mut receiver: Receiver<DecryptionRetryRequest<D>>,
300) {
301 debug!("Decryption task starting.");
302
303 while let Some(request) = receiver.recv().await {
304 let should_retry = |session_id: &str| {
305 if let Some(session_ids) = &request.session_ids {
306 session_ids.contains(session_id)
307 } else {
308 true
309 }
310 };
311
312 let mut state = state.write().await;
316 let (retry_decryption_indices, retry_info_indices) =
317 compute_event_indices_to_retry_decryption(&state.items, should_retry);
318
319 if !retry_info_indices.is_empty() {
321 debug!("Retrying fetching encryption info");
322 retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
323 }
324
325 if !retry_decryption_indices.is_empty() {
327 debug!("Retrying decryption");
328 decrypt_by_index(
329 &mut state,
330 &request.settings,
331 &room_data_provider,
332 request.decryptor,
333 should_retry,
334 retry_decryption_indices,
335 )
336 .await
337 }
338 }
339
340 debug!("Decryption task stopping.");
341}
342
343fn compute_event_indices_to_retry_decryption(
351 items: &Vector<Arc<TimelineItem>>,
352 should_retry: impl Fn(&str) -> bool,
353) -> (Vec<usize>, Vec<usize>) {
354 use Either::{Left, Right};
355
356 let should_retry_event = |event: &EventTimelineItem| {
358 let session_id = if let Some(encrypted_message) = event.content().as_unable_to_decrypt() {
359 encrypted_message.session_id()
361 } else {
362 event.as_remote().and_then(|remote| remote.encryption_info.as_ref()?.session_id())
365 };
366
367 if let Some(session_id) = session_id {
368 should_retry(session_id)
370 } else {
371 false
373 }
374 };
375
376 items
377 .iter()
378 .enumerate()
379 .filter_map(|(idx, item)| {
380 item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
381 })
382 .partition_map(
384 |(idx, event)| {
385 if event.content().is_unable_to_decrypt() { Left(idx) } else { Right(idx) }
386 },
387 )
388}
389
390pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
393 state: &mut TimelineState<P>,
394 retry_indices: Vec<usize>,
395 room_data_provider: &P,
396) {
397 for idx in retry_indices {
398 let old_item = state.items.get(idx);
399 if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
400 state.items.replace(idx, new_item);
401 }
402 }
403}
404
405async fn make_replacement_for<P: RoomDataProvider>(
409 room_data_provider: &P,
410 item: Option<&Arc<TimelineItem>>,
411) -> Option<Arc<TimelineItem>> {
412 let item = item?;
413 let event = item.as_event()?;
414 let remote = event.as_remote()?;
415 let session_id = remote.encryption_info.as_ref()?.session_id()?;
416
417 let new_encryption_info =
418 room_data_provider.get_encryption_info(session_id, &event.sender).await;
419 let mut new_remote = remote.clone();
420 new_remote.encryption_info = new_encryption_info;
421 let new_item = item.with_kind(TimelineItemKind::Event(
422 event.with_kind(EventTimelineItemKind::Remote(new_remote)),
423 ));
424
425 Some(new_item)
426}
427
/// Attempts to decrypt again the unable-to-decrypt (UTD) items at
/// `retry_indices`.
///
/// The per-item `retry_one` closure below:
/// - ignores items that are not UTD events, Megolm UTDs whose session ID is
///   rejected by `should_retry`, and all Olm / unknown-algorithm UTDs;
/// - re-runs `decryptor.decrypt_event_impl` on the event's original JSON;
/// - on success, notifies the unable-to-decrypt hook (if one is set) through
///   `on_late_decrypt` and yields the decrypted event; on failure, logs and
///   yields `None`.
///
/// What is done with the yielded events is up to
/// `TimelineState::retry_event_decryption`, which receives the closure.
async fn decrypt_by_index<P: RoomDataProvider, D: Decryptor>(
    state: &mut TimelineState<P>,
    settings: &TimelineSettings,
    room_data_provider: &P,
    decryptor: D,
    should_retry: impl Fn(&str) -> bool,
    retry_indices: Vec<usize>,
) {
    // Fetched once and shared by every retry attempt.
    let push_ctx = room_data_provider.push_context().await;
    let push_ctx = push_ctx.as_ref();
    let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();

    let retry_one = |item: Arc<TimelineItem>| {
        // Each returned future owns its own clones so it can outlive this call of the closure.
        let decryptor = decryptor.clone();
        let should_retry = &should_retry;
        let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
        async move {
            let event_item = item.as_event()?;

            // Only Megolm UTDs from a requested session are worth retrying.
            let session_id = match event_item.content().as_unable_to_decrypt()? {
                EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
                    if should_retry(session_id) =>
                {
                    session_id
                }
                EncryptedMessage::MegolmV1AesSha2 { .. }
                | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
                | EncryptedMessage::Unknown => return None,
            };

            // Fill in the `retry_one` span fields declared as `Empty` below.
            tracing::Span::current().record("session_id", session_id);

            let Some(remote_event) = event_item.as_remote() else {
                error!("Key for unable-to-decrypt timeline item is not an event ID");
                return None;
            };

            tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));

            let Some(original_json) = &remote_event.original_json else {
                error!("UTD item must contain original JSON");
                return None;
            };

            match decryptor.decrypt_event_impl(original_json, push_ctx).await {
                Ok(event) => {
                    // The decryptor can succeed yet still hand back a UTD.
                    if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
                        info!(
                            "Failed to decrypt event after receiving room key: {:?}",
                            utd_info.reason
                        );
                        None
                    } else {
                        if let Some(hook) = unable_to_decrypt_hook {
                            hook.on_late_decrypt(&remote_event.event_id).await;
                        }

                        Some(event)
                    }
                }
                Err(e) => {
                    info!("Failed to decrypt event after receiving room key: {e}");
                    None
                }
            }
        }
        .instrument(info_span!(
            "retry_one",
            session_id = field::Empty,
            event_id = field::Empty
        ))
    };

    state.retry_event_decryption(retry_one, retry_indices, room_data_provider, settings).await;
}
506
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::vector;
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        MilliSecondsSinceUnixEpoch, OwnedTransactionId,
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id,
    };

    use crate::timeline::{
        EncryptedMessage, EventSendState, EventTimelineItem, MsgLikeContent,
        ReactionsByKeyBySender, TimelineDetails, TimelineItem, TimelineItemContent,
        TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
    };

    /// Virtual items (read marker, date divider) are never selected.
    #[test]
    fn test_non_events_are_not_retried() {
        let timeline = vector![TimelineItem::read_marker(), date_divider()];
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Nothing to decrypt...
        assert!(answer.0.is_empty());
        // ...and no encryption info to re-fetch.
        assert!(answer.1.is_empty());
    }

    /// Local echoes carry no session ID, so they are never selected.
    #[test]
    fn test_non_remote_events_are_not_retried() {
        let timeline = vector![local_event()];
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        assert!(answer.0.is_empty());
        assert!(answer.1.is_empty());
    }

    /// A UTD lands in the "decrypt again" list.
    #[test]
    fn test_utds_are_retried() {
        let timeline = vector![utd_event("session1")];
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        assert_eq!(answer.0, vec![0]);
        assert!(answer.1.is_empty());
    }

    /// A decrypted remote event lands in the "re-fetch encryption info" list.
    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        let timeline = vector![decrypted_event("session1")];
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        assert!(answer.0.is_empty());
        assert_eq!(answer.1, vec![0]);
    }

    /// Only items from the sessions accepted by the predicate are selected,
    /// and the returned indices refer to positions in the whole timeline.
    #[test]
    fn test_only_required_sessions_are_retried() {
        fn retry(s: &str) -> bool {
            s == "session1"
        }

        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        let answer = compute_event_indices_to_retry_decryption(&timeline, retry);

        assert_eq!(answer.0, vec![1, 2]);
        assert_eq!(answer.1, vec![5, 6]);
    }

    /// Predicate accepting every session.
    fn always_retry(_: &str) -> bool {
        true
    }

    /// A virtual date divider item.
    fn date_divider() -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Virtual(VirtualTimelineItem::DateDivider(timestamp())),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// A local echo that has not been sent yet, with redacted content.
    fn local_event() -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet { progress: None },
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::MsgLike(MsgLikeContent::redacted()),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote UTD event whose Megolm payload references `session_id`.
    /// It has no encryption info, so only the payload provides the session.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: None,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::MsgLike(MsgLikeContent::unable_to_decrypt(
                    EncryptedMessage::from_content(
                        RoomEncryptedEventContent::new(
                            EncryptedEventScheme::MegolmV1AesSha2(MegolmV1AesSha2Content::from(
                                MegolmV1AesSha2ContentInit {
                                    ciphertext: "cyf".to_owned(),
                                    sender_key: "sendk".to_owned(),
                                    device_id: owned_device_id!("DEV"),
                                    session_id: session_id.to_owned(),
                                },
                            )),
                            None,
                        ),
                        UtdCause::Unknown,
                    ),
                )),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote, successfully decrypted text message whose Megolm encryption
    /// info references `session_id`.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info: Some(Arc::new(EncryptionInfo {
                sender: owned_user_id!("@u:s.co"),
                sender_device: None,
                algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                    curve25519_key: "".to_owned(),
                    sender_claimed_keys: BTreeMap::new(),
                    session_id: Some(session_id.to_owned()),
                },
                verification_state: VerificationState::Verified,
            })),
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        });

        let content = RoomMessageEventContent::text_plain("hi");

        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                TimelineItemContent::message(
                    content.msgtype,
                    content.mentions,
                    ReactionsByKeyBySender::default(),
                    None,
                    None,
                    None,
                ),
                event_kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A fixed timestamp (the Unix epoch) for deterministic fixtures.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}