matrix_sdk_ui/timeline/controller/
decryption_retry_task.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeSet, sync::Arc};
16
17use imbl::Vector;
18use itertools::{Either, Itertools as _};
19use matrix_sdk::{
20    deserialized_responses::TimelineEventKind as SdkTimelineEventKind, executor::JoinHandle,
21};
22use tokio::sync::{
23    mpsc::{self, Receiver, Sender},
24    RwLock,
25};
26use tracing::{debug, error, field, info, info_span, Instrument as _};
27
28use crate::timeline::{
29    controller::{TimelineSettings, TimelineState},
30    event_item::EventTimelineItemKind,
31    traits::{Decryptor, RoomDataProvider},
32    EncryptedMessage, EventTimelineItem, TimelineItem, TimelineItemKind,
33};
34
/// Holds a long-running task that is used to retry decryption of items in the
/// timeline when new information about a session is received.
///
/// Creating an instance with [`DecryptionRetryTask::new`] creates the async
/// task, and a channel that is used to communicate with it.
///
/// This struct is `Clone`, and every clone shares the same channel. The
/// underlying async task stops soon after the *last* clone is dropped: it waits
/// for the channel to close, which happens once every sending side has been
/// dropped. A request that is already being processed at that point is
/// completed before the task notices the closed channel.
#[derive(Clone, Debug)]
pub struct DecryptionRetryTask<D: Decryptor> {
    /// The sending side of the channel that we have open to the long-running
    /// async task. Every time we want to retry decrypting some events, we
    /// send a [`DecryptionRetryRequest`] along this channel. Users of this
    /// struct call [`DecryptionRetryTask::decrypt`] to do this.
    sender: Sender<DecryptionRetryRequest<D>>,

    /// The join handle of the task. We don't actually use this, since the task
    /// will end soon after all senders are dropped (the task sees that the
    /// channel closed). We hold on to the handle to indicate that we own the
    /// task; the `Arc` lets every clone of this struct share the one handle.
    _task_handle: Arc<JoinHandle<()>>,
}
58
/// How many retry requests can be queued before
/// [`DecryptionRetryTask::decrypt`] has to wait for space in the channel.
/// Requests are handled one at a time by the background task, and we don't
/// normally expect more than one or two to be queued at once, so waiting
/// should be a rare occurrence.
const CHANNEL_BUFFER_SIZE: usize = 100;
63
64impl<D: Decryptor> DecryptionRetryTask<D> {
65    pub(crate) fn new<P: RoomDataProvider>(
66        state: Arc<RwLock<TimelineState>>,
67        room_data_provider: P,
68    ) -> Self {
69        // We will send decryption requests down this channel to the long-running task
70        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
71
72        // Spawn the long-running task, providing the receiver so we can listen for
73        // decryption requests
74        let handle =
75            matrix_sdk::executor::spawn(decryption_task(state, room_data_provider, receiver));
76
77        // Keep hold of the sender so we can send off decryption requests to the task.
78        Self { sender, _task_handle: Arc::new(handle) }
79    }
80
81    /// Use the supplied decryptor to attempt redecryption of the events
82    /// associated with the supplied session IDs.
83    pub(crate) async fn decrypt(
84        &self,
85        decryptor: D,
86        session_ids: Option<BTreeSet<String>>,
87        settings: TimelineSettings,
88    ) {
89        let res =
90            self.sender.send(DecryptionRetryRequest { decryptor, session_ids, settings }).await;
91
92        if let Err(error) = res {
93            error!("Failed to send decryption retry request: {error}");
94        }
95    }
96}
97
/// The information sent across the channel to the long-running task requesting
/// that the supplied set of sessions be retried.
struct DecryptionRetryRequest<D: Decryptor> {
    /// Decryptor used to retry decryption of unable-to-decrypt events.
    decryptor: D,
    /// Session IDs whose events should be retried, or `None` to retry events
    /// from every session.
    session_ids: Option<BTreeSet<String>>,
    /// Timeline settings, passed on to the decryption retry.
    settings: TimelineSettings,
}
105
106/// Long-running task that waits for decryption requests to come through the
107/// supplied channel `receiver` and act on them. Stops when the channel is
108/// closed, i.e. when the sender side is dropped.
109async fn decryption_task<D: Decryptor>(
110    state: Arc<RwLock<TimelineState>>,
111    room_data_provider: impl RoomDataProvider,
112    mut receiver: Receiver<DecryptionRetryRequest<D>>,
113) {
114    debug!("Decryption task starting.");
115
116    while let Some(request) = receiver.recv().await {
117        let should_retry = |session_id: &str| {
118            if let Some(session_ids) = &request.session_ids {
119                session_ids.contains(session_id)
120            } else {
121                true
122            }
123        };
124
125        // Find the indices of events that are in the supplied sessions, distinguishing
126        // between UTDs which we need to decrypt, and already-decrypted events where we
127        // only need to re-fetch encryption info.
128        let mut state = state.write().await;
129        let (retry_decryption_indices, retry_info_indices) =
130            compute_event_indices_to_retry_decryption(&state.items, should_retry);
131
132        // Retry fetching encryption info for events that are already decrypted
133        if !retry_info_indices.is_empty() {
134            debug!("Retrying fetching encryption info");
135            retry_fetch_encryption_info(&mut state, retry_info_indices, &room_data_provider).await;
136        }
137
138        // Retry decrypting any unable-to-decrypt messages
139        if !retry_decryption_indices.is_empty() {
140            debug!("Retrying decryption");
141            decrypt_by_index(
142                &mut state,
143                &request.settings,
144                &room_data_provider,
145                request.decryptor,
146                should_retry,
147                retry_decryption_indices,
148            )
149            .await
150        }
151    }
152
153    debug!("Decryption task stopping.");
154}
155
156/// Decide which events should be retried, either for re-decryption, or, if they
157/// are already decrypted, for re-checking their encryption info.
158///
159/// Returns a tuple `(retry_decryption_indices, retry_info_indices)` where
160/// `retry_decryption_indices` is a list of the indices of UTDs to try
161/// decrypting, and retry_info_indices is a list of the indices of
162/// already-decrypted events whose encryption info we can re-fetch.
163fn compute_event_indices_to_retry_decryption(
164    items: &Vector<Arc<TimelineItem>>,
165    should_retry: impl Fn(&str) -> bool,
166) -> (Vec<usize>, Vec<usize>) {
167    use Either::{Left, Right};
168
169    // We retry an event if its session ID should be retried
170    let should_retry_event = |event: &EventTimelineItem| {
171        let session_id = if let Some(encrypted_message) = event.content().as_unable_to_decrypt() {
172            // UTDs carry their session ID inside the content
173            encrypted_message.session_id()
174        } else {
175            // Non-UTDs only have a session ID if they are remote and have it in the
176            // EncryptionInfo
177            event.as_remote().and_then(|remote| remote.encryption_info.as_ref()?.session_id())
178        };
179
180        if let Some(session_id) = session_id {
181            // Should we retry this session ID?
182            should_retry(session_id)
183        } else {
184            // No session ID: don't retry this event
185            false
186        }
187    };
188
189    items
190        .iter()
191        .enumerate()
192        .filter_map(|(idx, item)| {
193            item.as_event().filter(|e| should_retry_event(e)).map(|event| (idx, event))
194        })
195        // Break the result into 2 lists: (utds, decrypted)
196        .partition_map(
197            |(idx, event)| {
198                if event.content().is_unable_to_decrypt() {
199                    Left(idx)
200                } else {
201                    Right(idx)
202                }
203            },
204        )
205}
206
207/// Try to fetch [`EncryptionInfo`] for the events with the supplied
208/// indices, and update them where we succeed.
209pub(super) async fn retry_fetch_encryption_info<P: RoomDataProvider>(
210    state: &mut TimelineState,
211    retry_indices: Vec<usize>,
212    room_data_provider: &P,
213) {
214    for idx in retry_indices {
215        let old_item = state.items.get(idx);
216        if let Some(new_item) = make_replacement_for(room_data_provider, old_item).await {
217            state.items.replace(idx, new_item);
218        }
219    }
220}
221
222/// Create a replacement TimelineItem for the supplied one, with new
223/// [`EncryptionInfo`] from the supplied `room_data_provider`. Returns None if
224/// the supplied item is not a remote event, or if it doesn't have a session ID.
225async fn make_replacement_for<P: RoomDataProvider>(
226    room_data_provider: &P,
227    item: Option<&Arc<TimelineItem>>,
228) -> Option<Arc<TimelineItem>> {
229    let item = item?;
230    let event = item.as_event()?;
231    let remote = event.as_remote()?;
232    let session_id = remote.encryption_info.as_ref()?.session_id()?;
233
234    let new_encryption_info =
235        room_data_provider.get_encryption_info(session_id, &event.sender).await;
236    let mut new_remote = remote.clone();
237    new_remote.encryption_info = new_encryption_info;
238    let new_item = item.with_kind(TimelineItemKind::Event(
239        event.with_kind(EventTimelineItemKind::Remote(new_remote)),
240    ));
241
242    Some(new_item)
243}
244
/// Attempt to re-decrypt the unable-to-decrypt items at `retry_indices` in
/// `state`, using `decryptor`.
///
/// Only items whose content is a Megolm UTD with a session ID accepted by
/// `should_retry` are actually re-decrypted. `retry_one` yields `None` for any
/// other item, and for items that still fail to decrypt. Whenever an item is
/// decrypted, the timeline's unable-to-decrypt hook (if one is set) is
/// notified via `on_late_decrypt`. Applying the results to the timeline is
/// delegated to `TimelineState::retry_event_decryption`.
async fn decrypt_by_index<D: Decryptor>(
    state: &mut TimelineState,
    settings: &TimelineSettings,
    room_data_provider: &impl RoomDataProvider,
    decryptor: D,
    should_retry: impl Fn(&str) -> bool,
    retry_indices: Vec<usize>,
) {
    // Fetched once and shared by every retry below.
    let push_ctx = room_data_provider.push_context().await;
    let push_ctx = push_ctx.as_ref();
    let unable_to_decrypt_hook = state.meta.unable_to_decrypt_hook.clone();

    // Builds the future that retries one timeline item. It resolves to the
    // newly decrypted event, or `None` if the item was skipped or still cannot
    // be decrypted.
    let retry_one = |item: Arc<TimelineItem>| {
        // The closure may be called several times, so each future gets its own
        // clone of the decryptor and hook to move into the `async move` block.
        let decryptor = decryptor.clone();
        let should_retry = &should_retry;
        let unable_to_decrypt_hook = unable_to_decrypt_hook.clone();
        async move {
            let event_item = item.as_event()?;

            // Only Megolm UTDs from a session we were asked to retry are eligible.
            let session_id = match event_item.content().as_unable_to_decrypt()? {
                EncryptedMessage::MegolmV1AesSha2 { session_id, .. }
                    if should_retry(session_id) =>
                {
                    session_id
                }
                EncryptedMessage::MegolmV1AesSha2 { .. }
                | EncryptedMessage::OlmV1Curve25519AesSha2 { .. }
                | EncryptedMessage::Unknown => return None,
            };

            // Fill in the span fields declared (empty) by the `info_span!` below.
            tracing::Span::current().record("session_id", session_id);

            let Some(remote_event) = event_item.as_remote() else {
                error!("Key for unable-to-decrypt timeline item is not an event ID");
                return None;
            };

            tracing::Span::current().record("event_id", field::debug(&remote_event.event_id));

            // The raw event is needed to attempt decryption again.
            let Some(original_json) = &remote_event.original_json else {
                error!("UTD item must contain original JSON");
                return None;
            };

            match decryptor.decrypt_event_impl(original_json, push_ctx).await {
                Ok(event) => {
                    // The decryptor may return a UTD event rather than an error;
                    // treat that as a failure too.
                    if let SdkTimelineEventKind::UnableToDecrypt { utd_info, .. } = event.kind {
                        info!(
                            "Failed to decrypt event after receiving room key: {:?}",
                            utd_info.reason
                        );
                        None
                    } else {
                        // Notify observers that we managed to eventually decrypt an event.
                        if let Some(hook) = unable_to_decrypt_hook {
                            hook.on_late_decrypt(&remote_event.event_id).await;
                        }

                        Some(event)
                    }
                }
                Err(e) => {
                    info!("Failed to decrypt event after receiving room key: {e}");
                    None
                }
            }
        }
        .instrument(info_span!(
            "retry_one",
            session_id = field::Empty,
            event_id = field::Empty
        ))
    };

    state.retry_event_decryption(retry_one, retry_indices, room_data_provider, settings).await;
}
323
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::vector;
    use matrix_sdk::{
        crypto::types::events::UtdCause,
        deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState},
    };
    use ruma::{
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id, MilliSecondsSinceUnixEpoch,
        OwnedTransactionId,
    };

    use crate::timeline::{
        controller::decryption_retry_task::compute_event_indices_to_retry_decryption,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
        EncryptedMessage, EventSendState, EventTimelineItem, MsgLikeContent,
        ReactionsByKeyBySender, TimelineDetails, TimelineItem, TimelineItemContent,
        TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
    };

    #[test]
    fn test_non_events_are_not_retried() {
        // Given a timeline with only non-events
        let timeline = vector![TimelineItem::read_marker(), date_divider()];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry nothing
        assert!(answer.0.is_empty());
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_non_remote_events_are_not_retried() {
        // Given a timeline with only local events
        let timeline = vector![local_event()];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry nothing
        assert!(answer.0.is_empty());
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_unencrypted_remote_events_are_not_retried() {
        // Given a timeline with a remote event that has no encryption info
        let timeline = vector![unencrypted_event()];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry nothing, because there is no session ID to match
        assert!(answer.0.is_empty());
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_utds_are_retried() {
        // Given a timeline with a UTD
        let timeline = vector![utd_event("session1")];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we retry decrypting it, and don't refetch any encryption info
        assert_eq!(answer.0, vec![0]);
        assert!(answer.1.is_empty());
    }

    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        // Given a timeline with a decrypted event
        let timeline = vector![decrypted_event("session1")];
        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, always_retry);
        // Then we don't need to decrypt anything, but we do refetch the encryption info
        assert!(answer.0.is_empty());
        assert_eq!(answer.1, vec![0]);
    }

    #[test]
    fn test_only_required_sessions_are_retried() {
        // Given we want to retry everything in session1 only

        fn retry(s: &str) -> bool {
            s == "session1"
        }

        // And we have a timeline containing non-events, local events, UTDs and
        // decrypted events
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        // When we ask what to retry
        let answer = compute_event_indices_to_retry_decryption(&timeline, retry);

        // Then we re-decrypt the UTDs, and refetch the decrypted events' info
        assert_eq!(answer.0, vec![1, 2]);
        assert_eq!(answer.1, vec![5, 6]);
    }

    fn always_retry(_: &str) -> bool {
        true
    }

    fn date_divider() -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Virtual(VirtualTimelineItem::DateDivider(timestamp())),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// Wrap `content` and `kind` in an event timeline item sent by `@u:s.to`.
    fn event_item(content: TimelineItemContent, kind: EventTimelineItemKind) -> Arc<TimelineItem> {
        TimelineItem::new(
            TimelineItemKind::Event(EventTimelineItem::new(
                owned_user_id!("@u:s.to"),
                TimelineDetails::Pending,
                timestamp(),
                content,
                kind,
                true,
            )),
            TimelineUniqueId("local".to_owned()),
        )
    }

    /// A remote event kind carrying the supplied encryption info.
    fn remote_event_kind(encryption_info: Option<Arc<EncryptionInfo>>) -> EventTimelineItemKind {
        EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        })
    }

    /// Plain-text message content.
    fn text_content() -> TimelineItemContent {
        let content = RoomMessageEventContent::text_plain("hi");
        TimelineItemContent::message(
            content.msgtype,
            content.mentions,
            ReactionsByKeyBySender::default(),
            None,
            None,
            None,
        )
    }

    fn local_event() -> Arc<TimelineItem> {
        let event_kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet,
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        event_item(TimelineItemContent::MsgLike(MsgLikeContent::redacted()), event_kind)
    }

    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let content = TimelineItemContent::MsgLike(MsgLikeContent::unable_to_decrypt(
            EncryptedMessage::from_content(
                RoomEncryptedEventContent::new(
                    EncryptedEventScheme::MegolmV1AesSha2(MegolmV1AesSha2Content::from(
                        MegolmV1AesSha2ContentInit {
                            ciphertext: "cyf".to_owned(),
                            sender_key: "sendk".to_owned(),
                            device_id: owned_device_id!("DEV"),
                            session_id: session_id.to_owned(),
                        },
                    )),
                    None,
                ),
                UtdCause::Unknown,
            ),
        ));

        event_item(content, remote_event_kind(None))
    }

    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let encryption_info = EncryptionInfo {
            // Same sender as the event itself (see `event_item`).
            sender: owned_user_id!("@u:s.to"),
            sender_device: None,
            algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                curve25519_key: "".to_owned(),
                sender_claimed_keys: BTreeMap::new(),
                session_id: Some(session_id.to_owned()),
            },
            verification_state: VerificationState::Verified,
        };

        event_item(text_content(), remote_event_kind(Some(Arc::new(encryption_info))))
    }

    /// A remote, non-UTD event without any encryption info.
    fn unencrypted_event() -> Arc<TimelineItem> {
        event_item(text_content(), remote_event_kind(None))
    }

    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}