1use std::{collections::BTreeSet, sync::Arc};
16
17use imbl::Vector;
18use itertools::{Either, Itertools as _};
19use matrix_sdk::{
20 deserialized_responses::TimelineEventKind as SdkTimelineEventKind, executor::JoinHandle,
21};
22use tokio::sync::{
23 mpsc::{self, Receiver, Sender},
24 RwLock,
25};
26use tracing::{debug, error, field, info, info_span, Instrument as _};
27
28use crate::timeline::{
29 controller::{TimelineSettings, TimelineState},
30 event_item::EventTimelineItemKind,
31 traits::{Decryptor, RoomDataProvider},
32 EncryptedMessage, EventTimelineItem, TimelineItem, TimelineItemKind,
33};
34
35#[derive(Clone, Debug)]
45pub struct DecryptionRetryTask<D: Decryptor> {
46 sender: Sender<DecryptionRetryRequest<D>>,
51
52 _task_handle: Arc<JoinHandle<()>>,
57}
58
59const CHANNEL_BUFFER_SIZE: usize = 100;
63
64impl<D: Decryptor> DecryptionRetryTask<D> {
65 pub(crate) fn new<P: RoomDataProvider>(
66 state: Arc<RwLock<TimelineState>>,
67 room_data_provider: P,
68 ) -> Self {
69 let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
71
72 let handle =
75 matrix_sdk::executor::spawn(decryption_task(state, room_data_provider, receiver));
76
77 Self { sender, _task_handle: Arc::new(handle) }
79 }
80
81 pub(crate) async fn decrypt(
84 &self,
85 decryptor: D,
86 session_ids: Option<BTreeSet<String>>,
87 settings: TimelineSettings,
88 ) {
89 let res =
90 self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
91
92 if let Err(error) = res {
93 error!("Failed to send decryption retry request: {error}");
94 }
95 }
96}
97
98struct DecryptionRetryRequest<D: Decryptor> {
101 decryptor: D,
102 session_ids: Option<BTreeSet<String>>,
103 settings: TimelineSettings,
104}
105
106async fn decryption_task<D: Decryptor>(
110 state: Arc<RwLock<TimelineState>>,
111 room_data_provider: impl RoomDataProvider,
112 mut receiver: Receiver<DecryptionRetryRequest<D>>,
113) {
114 debug!("Decryption task starting.");
115
116 while let Some(request) = receiver.recv().await {
117 let should_retry = |session_id: &str| {
118 if let Some(session_ids) = &request.session_ids {
119 session_ids.contains(session_id)
120 } else {
121 true
122 }
123 };
124
125 let mut state = state.write().await;
129 let (retry_decryption_indices, retry_info_indices) =
130 compute_event_indices_to_retry_decryption(&state.items, should_retry);
131
132 if !retry_info_indices.is_empty() {
134 debug!("Retrying fetching encryption info");
135 retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
136 }
137
138 if !retry_decryption_indices.is_empty() {
140 debug!("Retrying decryption");
141 decrypt_by_index(
142 &mut state,
143 &request.settings,
144 &room_data_provider,
145 request.decryptor,
146 should_retry,
147 retry_decryption_indices,
148 )
149 .await
150 }
151 }
152
153 debug!("Decryption task stopping.");
154}
155
156fn compute_event_indices_to_retry_decryption(
164 items: &Vector<Arc<TimelineItem>>,
165 should_retry: impl Fn(&str) -> bool,
166) -> (Vec<usize>, Vec<usize>) {
167 use Either::{Left, Right};
168
169 let should_retry_event = |event: &EventTimelineItem| {
171 let session_id = if let Some(encrypted_message) = event.content().as_unable_to_decrypt() {
172 encrypted_message.session_id()
174 } else {
175 event.as_remote().and_then(|remote| remote.encryption_info.as_ref()?.session_id())
178 };
179
180 if let Some(session_id) = session_id {
181 should_retry(session_id)
183 } else {
184 false
186 }
187 };
188
189 items
190 .iter()
191 .enumerate()
192 .filter_map(|(idx, item)| {
193 item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
194 })
195 .partition_map(
197 |(idx, event)| {
198 if event.content().is_unable_to_decrypt() {
199 Left(idx)
200 } else {
201 Right(idx)
202 }
203 },
204 )
205}
206
207pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
210 state: &mut TimelineState,
211 retry_indices: Vec<usize>,
212 room_data_provider: &P,
213) {
214 for idx in retry_indices {
215 let old_item = state.items.get(idx);
216 if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
217 state.items.replace(idx, new_item);
218 }
219 }
220}
221
222async fn make_replacement_for<P: RoomDataProvider>(
226 room_data_provider: &P,
227 item: Option<&Arc<TimelineItem>>,
228) -> Option<Arc<TimelineItem>> {
229 let item = item?;
230 let event = item.as_event()?;
231 let remote = event.as_remote()?;
232 let session_id = remote.encryption_info.as_ref()?.session_id()?;
233
234 let new_encryption_info =
235 room_data_provider.get_encryption_info(session_id, &event.sender).await;
236 let mut new_remote = remote.clone();
237 new_remote.encryption_info = new_encryption_info;
238 let new_item = item.with_kind(TimelineItemKind::Event(
239 event.with_kind(EventTimelineItemKind::Remote(new_remote)),
240 ));
241
242 Some(new_item)
243}
244
245async fn decrypt_by_index<D: Decryptor>(
248 state: &mut TimelineState,
249 settings: &TimelineSettings,
250 room_data_provider: &impl RoomDataProvider,
251 decryptor: D,
252 should_retry: impl Fn(&str) -> bool,
253 retry_indices: Vec<usize>,
254) {
255 let push_ctx = room_data_provider.push_context().await;
256 let push_ctx = push_ctx.as_ref();
257 let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();
258
259 let retry_one = |item: Arc<TimelineItem>| {
260 let decryptor = decryptor.clone();
261 let should_retry = &should_retry;
262 let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
263 async move {
264 let event_item = item.as_event()?;
265
266 let session_id = match event_item.content().as_unable_to_decrypt()? {
267 EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
268 if should_retry(session_id) =>
269 {
270 session_id
271 }
272 EncryptedMessage::MegolmV1AesSha2 { .. }
273 | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
274 | EncryptedMessage::Unknown => return None,
275 };
276
277 tracing::Span::current().record("session_id", session_id);
278
279 let Some(remote_event) = event_item.as_remote() else {
280 error!("Key for unable-to-decrypt timeline item is not an event ID");
281 return None;
282 };
283
284 tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));
285
286 let Some(original_json) = &remote_event.original_json else {
287 error!("UTD item must contain original JSON");
288 return None;
289 };
290
291 match decryptor.decrypt_event_impl(original_json, push_ctx).await {
292 Ok(event) => {
293 if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
294 info!(
295 "Failed to decrypt event after receiving room key: {:?}",
296 utd_info.reason
297 );
298 None
299 } else {
300 if let Some(hook) = unable_to_decrypt_hook {
302 hook.on_late_decrypt(&remote_event.event_id).await;
303 }
304
305 Some(event)
306 }
307 }
308 Err(e) => {
309 info!("Failed to decrypt event after receiving room key: {e}");
310 None
311 }
312 }
313 }
314 .instrument(info_span!(
315 "retry_one",
316 session_id = field::Empty,
317 event_id = field::Empty
318 ))
319 };
320
321 state.retry_event_decryption(retry_one, retry_indices, room_data_provider, settings).await;
322}
323
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::vector;
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id, MilliSecondsSinceUnixEpoch,
        OwnedTransactionId,
    };

    use crate::timeline::{
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
        EncryptedMessage, EventSendState, EventTimelineItem, MsgLikeContent,
        ReactionsByKeyBySender, TimelineDetails, TimelineItem, TimelineItemContent,
        TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
    };

    #[test]
    fn test_non_events_are_not_retried() {
        // Given a timeline containing only virtual items
        let timeline = vector![TimelineItem::read_marker(), date_divider()];

        // When we ask what to retry
        let (utds, decrypted) = compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then nothing is selected
        assert!(utds.is_empty());
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_non_remote_events_are_not_retried() {
        // Given a timeline containing only a local echo
        let timeline = vector![local_event()];

        // When we ask what to retry
        let (utds, decrypted) = compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then nothing is selected
        assert!(utds.is_empty());
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_utds_are_retried() {
        // Given a timeline containing a single UTD
        let timeline = vector![utd_event("session1")];

        // When we ask what to retry
        let (utds, decrypted) = compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then the UTD is selected for decryption only
        assert_eq!(utds, vec![0]);
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        // Given a timeline containing a single decrypted remote event
        let timeline = vector![decrypted_event("session1")];

        // When we ask what to retry
        let (utds, decrypted) = compute_event_indices_to_retry_decryption(&timeline, always_retry);

        // Then it is selected for an encryption info refresh only
        assert!(utds.is_empty());
        assert_eq!(decrypted, vec![0]);
    }

    #[test]
    fn test_only_required_sessions_are_retried() {
        // Given a predicate that only accepts session1
        let only_session1 = |session_id: &str| session_id == "session1";

        // And a timeline mixing virtual items, UTDs, decrypted and local events
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        // When we ask what to retry
        let (utds, decrypted) = compute_event_indices_to_retry_decryption(&timeline, only_session1);

        // Then only the session1 events are selected, in their own lists
        assert_eq!(utds, vec![1, 2]);
        assert_eq!(decrypted, vec![5, 6]);
    }

    /// Predicate accepting every session.
    fn always_retry(_: &str) -> bool {
        true
    }

    /// A virtual date divider item.
    fn date_divider() -> Arc<TimelineItem> {
        let divider = VirtualTimelineItem::DateDivider(timestamp());
        TimelineItem::new(
            TimelineItemKind::Virtual(divider),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// A local event with redacted content that has not been sent yet.
    fn local_event() -> Arc<TimelineItem> {
        let kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet,
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        event_item(TimelineItemContent::MsgLike(MsgLikeContent::redacted()), kind)
    }

    /// A remote Megolm event in `session_id` that could not be decrypted.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let megolm = MegolmV1AesSha2Content::from(MegolmV1AesSha2ContentInit {
            ciphertext: "cyf".to_owned(),
            sender_key: "sendk".to_owned(),
            device_id: owned_device_id!("DEV"),
            session_id: session_id.to_owned(),
        });
        let encrypted =
            RoomEncryptedEventContent::new(EncryptedEventScheme::MegolmV1AesSha2(megolm), None);
        let utd = EncryptedMessage::from_content(encrypted, UtdCause::Unknown);

        event_item(
            TimelineItemContent::MsgLike(MsgLikeContent::unable_to_decrypt(utd)),
            remote_kind(None),
        )
    }

    /// A remote text message whose encryption info points at `session_id`.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let encryption_info = EncryptionInfo {
            sender: owned_user_id!("@u:s.co"),
            sender_device: None,
            algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                curve25519_key: "".to_owned(),
                sender_claimed_keys: BTreeMap::new(),
                session_id: Some(session_id.to_owned()),
            },
            verification_state: VerificationState::Verified,
        };

        let message = RoomMessageEventContent::text_plain("hi");
        let content = TimelineItemContent::message(
            message.msgtype,
            message.mentions,
            ReactionsByKeyBySender::default(),
            None,
            None,
            None,
        );

        event_item(content, remote_kind(Some(Arc::new(encryption_info))))
    }

    /// Kind of a remote event received via sync, with the given encryption
    /// info and no JSON attached.
    fn remote_kind(encryption_info: Option<Arc<EncryptionInfo>>) -> EventTimelineItemKind {
        EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        })
    }

    /// Wraps `content` and `kind` into an event timeline item using the
    /// sender, profile, timestamp and unique ID shared by all event fixtures.
    fn event_item(content: TimelineItemContent, kind: EventTimelineItemKind) -> Arc<TimelineItem> {
        let event = EventTimelineItem::new(
            owned_user_id!("@u:s.to"),
            TimelineDetails::Pending,
            timestamp(),
            content,
            kind,
            true,
        );

        TimelineItem::new(TimelineItemKind::Event(event), TimelineUniqueId("local".to_owned()))
    }

    /// The Unix epoch as a Matrix timestamp.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}