1use std::{collections::BTreeSet, pin::Pin, sync::Weak};
114
115use as_variant::as_variant;
116use futures_core::Stream;
117use futures_util::{StreamExt, pin_mut};
118#[cfg(doc)]
119use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
120use matrix_sdk_base::{
121 crypto::{
122 store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
123 types::events::room::encrypted::EncryptedEvent,
124 },
125 deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
126 event_cache::store::EventCacheStoreLockState,
127 locks::Mutex,
128 timer,
129};
130#[cfg(doc)]
131use matrix_sdk_common::deserialized_responses::EncryptionInfo;
132use matrix_sdk_common::executor::{AbortOnDrop, JoinHandleExt, spawn};
133use ruma::{
134 OwnedEventId, OwnedRoomId, RoomId,
135 events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
136 push::Action,
137 serde::Raw,
138};
139use tokio::sync::{
140 broadcast::{self, Sender},
141 mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
142};
143use tokio_stream::wrappers::{
144 BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
145};
146use tracing::{info, instrument, trace, warn};
147
148#[cfg(doc)]
149use super::RoomEventCache;
150use super::{EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate};
151use crate::{Room, event_cache::RoomEventCacheLinkedChunkUpdate, room::PushContext};
152
/// A borrowed Megolm session ID, as found in received room keys.
type SessionId<'a> = &'a str;
/// An owned Megolm session ID.
type OwnedSessionId = String;

/// An event ID paired with the raw event that could not be decrypted.
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// An event ID paired with the already decrypted event.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// An event ID, its newly decrypted (or updated) content and, optionally,
/// freshly computed push actions for it.
type ResolvedUtd = (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
159
/// A request to retry decrypting, and/or to refresh the encryption info of,
/// events stored in the event cache for a single room.
///
/// Requests are submitted with [`EventCache::request_decryption`] and processed
/// asynchronously by the redecryptor task.
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
    /// The room whose cached events should be processed.
    pub room_id: OwnedRoomId,
    /// Session IDs whose undecryptable events should be decrypted again.
    pub utd_session_ids: BTreeSet<OwnedSessionId>,
    /// Session IDs whose already decrypted events should have their
    /// [`EncryptionInfo`] recomputed.
    pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
172
/// A report broadcast by the redecryptor, see
/// [`EventCache::subscribe_to_decryption_reports`].
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
    /// Events of a room were redecrypted, or had their encryption info
    /// updated, by the redecryptor.
    ResolvedUtds {
        /// The room the events belong to.
        room_id: OwnedRoomId,
        /// The IDs of the events that were processed.
        events: BTreeSet<OwnedEventId>,
    },
    /// The redecryptor lagged behind one of its input streams, or had to
    /// re-subscribe to them, so some updates may have been missed.
    Lagging,
}
187
/// The channels connecting the [`EventCache`] to its redecryptor task.
pub(super) struct RedecryptorChannels {
    /// Broadcasts [`RedecryptorReport`]s to every subscriber.
    utd_reporter: Sender<RedecryptorReport>,
    /// Sends explicit [`DecryptionRetryRequest`]s to the redecryptor task.
    pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
    /// The receiving end of `decryption_request_sender`, kept in an `Option`
    /// (presumably so the task can take it once when it starts — TODO confirm
    /// against the caller).
    pub(super) decryption_request_receiver:
        Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
194
195impl RedecryptorChannels {
196 pub(super) fn new() -> Self {
197 let (utd_reporter, _) = broadcast::channel(100);
198 let (decryption_request_sender, decryption_request_receiver) = unbounded_channel();
199
200 Self {
201 utd_reporter,
202 decryption_request_sender,
203 decryption_request_receiver: Mutex::new(Some(decryption_request_receiver)),
204 }
205 }
206}
207
208fn filter_timeline_event_to_utd(
213 event: TimelineEvent,
214) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
215 let event_id = event.event_id();
216
217 let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
220 event_id.zip(event)
223}
224
225impl EventCache {
226 async fn get_utds(
234 &self,
235 room_id: &RoomId,
236 session_id: SessionId<'_>,
237 ) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
238 let events = match self.inner.store.lock().await? {
239 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
244 guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
245 }
246 };
247
248 Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
249 }
250
251 async fn get_decrypted_events(
252 &self,
253 room_id: &RoomId,
254 session_id: SessionId<'_>,
255 ) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
256 let filter = |event: TimelineEvent| {
257 let event_id = event.event_id();
258
259 let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
260 event_id.zip(event)
263 };
264
265 let events = match self.inner.store.lock().await? {
266 EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
271 guard.get_room_events(room_id, None, Some(session_id)).await?
272 }
273 };
274
275 Ok(events.into_iter().filter_map(filter).collect())
276 }
277
278 #[instrument(skip_all, fields(room_id))]
290 async fn on_resolved_utds(
291 &self,
292 room_id: &RoomId,
293 events: Vec<ResolvedUtd>,
294 ) -> Result<(), EventCacheError> {
295 if events.is_empty() {
296 trace!("No events were redecrypted or updated, nothing to replace");
297 return Ok(());
298 }
299
300 timer!("Resolving UTDs");
301
302 let (room_cache, _drop_handles) = self.for_room(room_id).await?;
305 let mut state = room_cache.inner.state.write().await?;
306
307 let event_ids: BTreeSet<_> =
308 events.iter().cloned().map(|(event_id, _, _)| event_id).collect();
309 let mut new_events = Vec::with_capacity(events.len());
310
311 for (event_id, decrypted, actions) in events {
312 if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
316 target_event.kind = TimelineEventKind::Decrypted(decrypted);
317
318 if let Some(actions) = actions {
319 target_event.set_push_actions(actions);
320 }
321
322 state.replace_event_at(location, target_event.clone()).await?;
325 new_events.push(target_event);
326 }
327 }
328
329 state.post_process_new_events(new_events, false).await?;
330
331 let diffs = state.room_linked_chunk().updates_as_vector_diffs();
334
335 let _ = room_cache.inner.update_sender.send(RoomEventCacheUpdate::UpdateTimelineEvents {
336 diffs,
337 origin: EventsOrigin::Cache,
338 });
339
340 let report =
345 RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
346 let _ = self.inner.redecryption_channels.utd_reporter.send(report);
347
348 Ok(())
349 }
350
351 async fn decrypt_event(
353 &self,
354 room_id: &RoomId,
355 room: Option<&Room>,
356 push_context: Option<&PushContext>,
357 event: &Raw<EncryptedEvent>,
358 ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
359 if let Some(room) = room {
360 match room
361 .decrypt_event(
362 event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
363 push_context,
364 )
365 .await
366 {
367 Ok(maybe_decrypted) => {
368 let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());
369
370 if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
371 Some((decrypted, actions))
372 } else {
373 warn!(
374 "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
375 );
376 None
377 }
378 }
379 Err(e) => {
380 warn!(
381 "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
382 );
383 None
384 }
385 }
386 } else {
387 let client = self.inner.client().ok()?;
388 let machine = client.olm_machine().await;
389 let machine = machine.as_ref()?;
390
391 match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
392 Ok(decrypted) => Some((decrypted, None)),
393 Err(e) => {
394 warn!(
395 "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
396 );
397 None
398 }
399 }
400 }
401 }
402
403 #[instrument(skip_all, fields(room_id, session_id))]
406 async fn retry_decryption(
407 &self,
408 room_id: &RoomId,
409 session_id: SessionId<'_>,
410 ) -> Result<(), EventCacheError> {
411 let events = self.get_utds(room_id, session_id).await?;
413 self.retry_decryption_for_events(room_id, events).await
414 }
415
416 #[instrument(skip_all, fields(updates.linked_chunk_id))]
418 async fn retry_decryption_for_event_cache_updates(
419 &self,
420 updates: RoomEventCacheLinkedChunkUpdate,
421 ) -> Result<(), EventCacheError> {
422 let room_id = updates.linked_chunk_id.room_id();
423 let events: Vec<_> = updates
424 .updates
425 .into_iter()
426 .flat_map(|updates| updates.into_items())
427 .filter_map(filter_timeline_event_to_utd)
428 .collect();
429
430 self.retry_decryption_for_events(room_id, events).await
431 }
432
433 #[instrument(skip_all, fields(room_id, session_id))]
435 async fn retry_decryption_for_events(
436 &self,
437 room_id: &RoomId,
438 events: Vec<EventIdAndUtd>,
439 ) -> Result<(), EventCacheError> {
440 trace!("Retrying to decrypt");
441
442 if events.is_empty() {
443 trace!("No relevant events found.");
444 return Ok(());
445 }
446
447 let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
448 let push_context =
449 if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
450
451 let mut decrypted_events = Vec::with_capacity(events.len());
453
454 for (event_id, event) in events {
455 if let Some((decrypted, actions)) = self
458 .decrypt_event(
459 room_id,
460 room.as_ref(),
461 push_context.as_ref(),
462 event.cast_ref_unchecked(),
463 )
464 .await
465 {
466 decrypted_events.push((event_id, decrypted, actions));
467 }
468 }
469
470 let event_ids: BTreeSet<_> =
471 decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
472
473 if !event_ids.is_empty() {
474 trace!(?event_ids, "Successfully redecrypted events");
475 }
476
477 self.on_resolved_utds(room_id, decrypted_events).await?;
480
481 Ok(())
482 }
483
484 #[instrument(skip_all, fields(room_id, session_id))]
485 async fn update_encryption_info(
486 &self,
487 room_id: &RoomId,
488 session_id: SessionId<'_>,
489 ) -> Result<(), EventCacheError> {
490 trace!("Updating encryption info");
491
492 let Ok(client) = self.inner.client() else {
493 return Ok(());
494 };
495
496 let Some(room) = client.get_room(room_id) else {
497 return Ok(());
498 };
499
500 let events = self.get_decrypted_events(room_id, session_id).await?;
502
503 if events.is_empty() {
504 trace!("No relevant events found.");
505 return Ok(());
506 }
507
508 let mut updated_events = Vec::with_capacity(events.len());
510
511 for (event_id, mut event) in events {
512 let new_encryption_info =
513 room.get_encryption_info(session_id, &event.encryption_info.sender).await;
514
515 if let Some(new_encryption_info) = new_encryption_info
517 && event.encryption_info != new_encryption_info
518 {
519 event.encryption_info = new_encryption_info;
520 updated_events.push((event_id, event, None));
521 }
522 }
523
524 let event_ids: BTreeSet<_> =
525 updated_events.iter().map(|(event_id, _, _)| event_id).collect();
526
527 if !event_ids.is_empty() {
528 trace!(?event_ids, "Replacing the encryption info of some events");
529 }
530
531 self.on_resolved_utds(room_id, updated_events).await?;
532
533 Ok(())
534 }
535
536 pub fn request_decryption(&self, request: DecryptionRetryRequest) {
577 let _ =
578 self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
579 |_| warn!("Requesting a decryption while the redecryption task has been shut down"),
580 );
581 }
582
583 pub fn subscribe_to_decryption_reports(
626 &self,
627 ) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
628 BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
629 }
630}
631
/// Handle to the background task that redecrypts events of the event cache.
///
/// The task is aborted when this value is dropped.
pub(crate) struct Redecryptor {
    /// The spawned redecryption task, wrapped so that dropping it aborts it.
    _task: AbortOnDrop<()>,
}
641
642impl Redecryptor {
643 pub(super) fn new(
648 cache: Weak<EventCacheInner>,
649 receiver: UnboundedReceiver<DecryptionRetryRequest>,
650 linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
651 ) -> Self {
652 let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());
653
654 let task = spawn(async {
655 let request_redecryption_stream = UnboundedReceiverStream::new(receiver);
656
657 Self::listen_for_room_keys_task(
658 cache,
659 request_redecryption_stream,
660 linked_chunk_stream,
661 )
662 .await;
663 })
664 .abort_on_drop();
665
666 Self { _task: task }
667 }
668
669 async fn subscribe_to_room_key_stream(
674 cache: &Weak<EventCacheInner>,
675 ) -> Option<(
676 impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
677 impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
678 )> {
679 let event_cache = cache.upgrade()?;
680 let client = event_cache.client().ok()?;
681 let machine = client.olm_machine().await;
682
683 machine.as_ref().map(|m| {
684 (m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
685 })
686 }
687
688 #[inline(always)]
689 fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
690 cache.upgrade().map(|inner| EventCache { inner })
691 }
692
693 async fn redecryption_loop(
694 cache: &Weak<EventCacheInner>,
695 decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
696 events_stream: &mut Pin<
697 &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
698 >,
699 ) -> bool {
700 let Some((room_key_stream, withheld_stream)) =
701 Self::subscribe_to_room_key_stream(cache).await
702 else {
703 return false;
704 };
705
706 pin_mut!(room_key_stream);
707 pin_mut!(withheld_stream);
708
709 loop {
710 tokio::select! {
711 Some(request) = decryption_request_stream.next() => {
714 let Some(cache) = Self::upgrade_event_cache(cache) else {
715 break false;
716 };
717
718 trace!(?request, "Received a redecryption request");
719
720 for session_id in request.utd_session_ids {
721 let _ = cache
722 .retry_decryption(&request.room_id, &session_id)
723 .await
724 .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
725 }
726
727 for session_id in request.refresh_info_session_ids {
728 let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
729 warn!(
730 room_id = %request.room_id,
731 session_id = session_id,
732 "Unable to update the encryption info {e:?}",
733 ));
734 }
735 }
736 room_keys = room_key_stream.next() => {
739 match room_keys {
740 Some(Ok(room_keys)) => {
741 let Some(cache) = Self::upgrade_event_cache(cache) else {
745 break false;
746 };
747
748 trace!(?room_keys, "Received new room keys");
749
750 for key in &room_keys {
751 let _ = cache
752 .retry_decryption(&key.room_id, &key.session_id)
753 .await
754 .inspect_err(|e| warn!("Error redecrypting {e:?}"));
755 }
756
757 for key in room_keys {
758 let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
759 warn!(
760 room_id = %key.room_id,
761 session_id = key.session_id,
762 "Unable to update the encryption info {e:?}",
763 ));
764 }
765 },
766 Some(Err(_)) => {
767 let Some(cache) = Self::upgrade_event_cache(cache) else {
774 break false;
775 };
776
777 let message = RedecryptorReport::Lagging;
778 let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
779 },
780 None => {
783 break true
784 }
785 }
786 }
787 withheld_info = withheld_stream.next() => {
788 match withheld_info {
789 Some(infos) => {
790 let Some(cache) = Self::upgrade_event_cache(cache) else {
791 break false;
792 };
793
794 trace!(?infos, "Received new withheld infos");
795
796 for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
797 let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
798 warn!(
799 room_id = %room_id,
800 session_id = session_id,
801 "Unable to update the encryption info {e:?}",
802 ));
803 }
804 }
805 None => break true
808 }
809 }
810 Some(event_updates) = events_stream.next() => {
814 match event_updates {
815 Ok(updates) => {
816 let Some(cache) = Self::upgrade_event_cache(cache) else {
817 break false;
818 };
819
820 let linked_chunk_id = updates.linked_chunk_id.to_owned();
821
822 let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
823 warn!(
824 %linked_chunk_id,
825 "Unable to handle UTDs from event cache updates {e:?}",
826 )
827 );
828 }
829 Err(_) => {
830 let Some(cache) = Self::upgrade_event_cache(cache) else {
831 break false;
832 };
833
834 let message = RedecryptorReport::Lagging;
835 let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
836 }
837 }
838 }
839 else => break false,
840 }
841 }
842 }
843
844 async fn listen_for_room_keys_task(
845 cache: Weak<EventCacheInner>,
846 decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
847 events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
848 ) {
849 pin_mut!(decryption_request_stream);
853 pin_mut!(events_stream);
854
855 while Self::redecryption_loop(&cache, &mut decryption_request_stream, &mut events_stream)
856 .await
857 {
858 info!("Regenerating the re-decryption streams");
859
860 let Some(cache) = Self::upgrade_event_cache(&cache) else {
861 break;
862 };
863
864 let message = RedecryptorReport::Lagging;
867 let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
868 }
869
870 info!("Shutting down the event cache redecryptor");
871 }
872}
873
#[cfg(not(target_family = "wasm"))]
#[cfg(test)]
mod tests {
    use std::{
        collections::BTreeSet,
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        time::Duration,
    };

    use assert_matches2::assert_matches;
    use async_trait::async_trait;
    use eyeball_im::VectorDiff;
    use matrix_sdk_base::{
        cross_process_lock::CrossProcessLockGeneration,
        crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
        deserialized_responses::{TimelineEventKind, VerificationState},
        event_cache::{
            Event, Gap,
            store::{EventCacheStore, EventCacheStoreError, MemoryStore},
        },
        linked_chunk::{
            ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
            RawChunk, Update,
        },
        locks::Mutex,
        sleep::sleep,
        store::StoreConfig,
    };
    use matrix_sdk_test::{
        JoinedRoomBuilder, StateTestEvent, async_test, event_factory::EventFactory,
    };
    use ruma::{
        EventId, OwnedEventId, RoomId, device_id, event_id,
        events::{AnySyncTimelineEvent, relation::RelationType},
        room_id,
        serde::Raw,
        user_id,
    };
    use serde_json::json;
    use tokio::sync::oneshot::{self, Sender};
    use tracing::{Instrument, info};

    use crate::{
        Client, assert_let_timeout,
        encryption::EncryptionSettings,
        event_cache::{DecryptionRetryRequest, RoomEventCacheUpdate},
        test_utils::mocks::MatrixMockServer,
    };

    /// An [`EventCacheStore`] forwarding everything to a [`MemoryStore`], except
    /// that `handle_linked_chunk_updates` blocks until
    /// [`DelayingStore::stop_delaying`] is called.
    ///
    /// This lets a test control when linked chunk updates get persisted.
    #[derive(Debug, Clone)]
    struct DelayingStore {
        /// The store every operation is forwarded to.
        memory_store: MemoryStore,
        /// While `true`, `handle_linked_chunk_updates` keeps sleeping.
        delaying: Arc<AtomicBool>,
        /// One-shot notifier installed by `stop_delaying`. The first
        /// `handle_linked_chunk_updates` call that gets past the delay loop
        /// takes it and fires it once its updates have been applied.
        foo: Arc<Mutex<Option<Sender<()>>>>,
    }

    impl DelayingStore {
        /// Creates a store which starts out delaying.
        fn new() -> Self {
            Self {
                memory_store: MemoryStore::new(),
                delaying: AtomicBool::new(true).into(),
                foo: Arc::new(Mutex::new(None)),
            }
        }

        /// Lifts the delay, then waits until a pending (or the next)
        /// `handle_linked_chunk_updates` call has applied its updates.
        async fn stop_delaying(&self) {
            let (sender, receiver) = oneshot::channel();

            // Install the notifier before lifting the delay, so the unblocked
            // call is guaranteed to see it.
            {
                *self.foo.lock() = Some(sender);
            }

            self.delaying.store(false, Ordering::SeqCst);

            receiver.await.expect("We should be able to receive a response")
        }
    }

    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl EventCacheStore for DelayingStore {
        type Error = EventCacheStoreError;

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
            self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        async fn handle_linked_chunk_updates(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            updates: Vec<Update<Event, Gap>>,
        ) -> Result<(), Self::Error> {
            // Poll every 10ms until the test calls `stop_delaying()`.
            while self.delaying.load(Ordering::SeqCst) {
                sleep(Duration::from_millis(10)).await;
            }

            // Take the notifier (if any), apply the updates, then notify.
            let sender = self.foo.lock().take();
            let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;

            if let Some(sender) = sender {
                sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
            }

            ret
        }

        async fn load_all_chunks(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_all_chunks(linked_chunk_id).await
        }

        async fn load_all_chunks_metadata(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<ChunkMetadata>, Self::Error> {
            self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
        }

        async fn load_last_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
            self.memory_store.load_last_chunk(linked_chunk_id).await
        }

        async fn load_previous_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            before_chunk_identifier: ChunkIdentifier,
        ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
        }

        async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
            self.memory_store.clear_all_linked_chunks().await
        }

        async fn filter_duplicated_events(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            events: Vec<OwnedEventId>,
        ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
            self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
        }

        async fn find_event(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
        ) -> Result<Option<Event>, Self::Error> {
            self.memory_store.find_event(room_id, event_id).await
        }

        async fn find_event_relations(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
            filters: Option<&[RelationType]>,
        ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
            self.memory_store.find_event_relations(room_id, event_id, filters).await
        }

        async fn get_room_events(
            &self,
            room_id: &RoomId,
            event_type: Option<&str>,
            session_id: Option<&str>,
        ) -> Result<Vec<Event>, Self::Error> {
            self.memory_store.get_room_events(room_id, event_type, session_id).await
        }

        async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
            self.memory_store.save_event(room_id, event).await
        }

        async fn optimize(&self) -> Result<(), Self::Error> {
            self.memory_store.optimize().await
        }

        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
            self.memory_store.get_size().await
        }
    }

    /// Creates two E2EE-enabled clients, Alice and Bob, sharing a mock server.
    /// Their identities are exchanged and both have joined the encrypted room
    /// `room_id`.
    ///
    /// Bob has the event cache enabled. With `use_delayed_store`, his event
    /// cache store is a [`DelayingStore`], which is also returned so the test
    /// can control it.
    async fn set_up_clients(
        room_id: &RoomId,
        alice_enables_cross_signing: bool,
        use_delayed_store: bool,
    ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
        let alice_span = tracing::info_span!("alice");
        let bob_span = tracing::info_span!("bob");

        let alice_user_id = user_id!("@alice:localhost");
        let alice_device_id = device_id!("ALICEDEVICE");
        let bob_user_id = user_id!("@bob:localhost");
        let bob_device_id = device_id!("BOBDEVICE");

        let matrix_mock_server = MatrixMockServer::new().await;
        matrix_mock_server.mock_crypto_endpoints_preset().await;

        let encryption_settings = EncryptionSettings {
            auto_enable_cross_signing: alice_enables_cross_signing,
            ..Default::default()
        };

        let alice = matrix_mock_server
            .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
            })
            .build()
            .instrument(alice_span.clone())
            .await;

        // Bob always enables cross-signing.
        let encryption_settings =
            EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };

        let (store_config, store) = if use_delayed_store {
            let store = DelayingStore::new();

            (
                StoreConfig::new("delayed_store_event_cache_test".into())
                    .event_cache_store(store.clone()),
                Some(store),
            )
        } else {
            (StoreConfig::new("normal_store_event_cache_test".into()), None)
        };

        let bob = matrix_mock_server
            .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
                    .store_config(store_config)
            })
            .build()
            .instrument(bob_span.clone())
            .await;

        bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");

        matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;

        // An encrypted room, joined by both clients.
        let room_builder = JoinedRoomBuilder::new(room_id)
            .add_state_event(StateTestEvent::Create)
            .add_state_event(StateTestEvent::Encryption);

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&alice, |builder| {
                builder.add_joined_room(room_builder.clone());
            })
            .instrument(alice_span)
            .await;

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(room_builder);
            })
            .instrument(bob_span)
            .await;

        (alice, bob, matrix_mock_server, store)
    }

    /// Makes Alice send an encrypted message in `room_id`.
    ///
    /// Returns the encrypted event as captured by the mock server, together
    /// with the to-device event carrying the room key. The room key is not
    /// delivered to Bob; tests do that themselves.
    async fn prepare_room(
        matrix_mock_server: &MatrixMockServer,
        event_factory: &EventFactory,
        alice: &Client,
        bob: &Client,
        room_id: &RoomId,
    ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
        let alice_user_id = alice.user_id().unwrap();
        let bob_user_id = bob.user_id().unwrap();

        let alice_member_event = event_factory.member(alice_user_id).into_raw();
        let bob_member_event = event_factory.member(bob_user_id).into_raw();

        let room = alice
            .get_room(room_id)
            .expect("Alice should have access to the room now that we synced");

        let event_type = "m.room.message";
        let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});

        let event_id = event_id!("$some_id");
        let (event_receiver, mock) =
            matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
        let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;

        {
            let _guard = mock.mock_once().mount_as_scoped().await;

            // Alice needs the member list to know who to share the room key with.
            matrix_mock_server
                .mock_get_members()
                .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
                .mock_once()
                .mount()
                .await;

            room.send_raw(event_type, content)
                .await
                .expect("We should be able to send an initial message");
        };

        let event = event_receiver.await.expect("Alice should have sent the event by now");
        let room_key = room_key.await;

        (event, room_key)
    }

    /// An event received before its room key is stored as a UTD, and gets
    /// redecrypted in place once the key arrives.
    #[async_test]
    async fn test_redecryptor() {
        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let (room_cache, _) = bob
            .event_cache()
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();

        // Regenerate Bob's Olm machine; presumably this checks that the
        // redecryptor picks up the new machine's room key streams.
        bob.inner
            .base_client
            .regenerate_olm(None)
            .await
            .expect("We should be able to regenerate the Olm machine");

        // Bob receives the encrypted event without the room key...
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        // ...so it's appended as a UTD.
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        // Now the room key arrives.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        // The UTD has been replaced in place by its decrypted version.
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
    }

    /// An explicit `DecryptionRetryRequest` with `refresh_info_session_ids`
    /// makes the redecryptor recompute the encryption info of already
    /// decrypted events.
    #[async_test]
    async fn test_redecryptor_updating_encryption_info() {
        let bob_span = tracing::info_span!("bob");

        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        // Alice does NOT enable cross-signing at first.
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let (room_cache, _) = bob
            .event_cache()
            .for_room(room_id)
            .instrument(bob_span.clone())
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .instrument(bob_span.clone())
            .await;

        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .instrument(bob_span.clone())
            .await;

        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        let encryption_info = value.encryption_info().unwrap();
        assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));
        let session_id = encryption_info.session_id().unwrap().to_owned();

        let alice_user_id = alice.user_id().unwrap();

        // Alice now creates her cross-signing identity, and Bob learns about it.
        alice
            .encryption()
            .bootstrap_cross_signing(None)
            .await
            .expect("Alice should be able to create the cross-signing keys");

        bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_change_device(alice_user_id);
            })
            .instrument(bob_span.clone())
            .await;

        // Only refresh the encryption info; there are no UTDs to retry.
        bob.event_cache().request_decryption(DecryptionRetryRequest {
            room_id: room_id.into(),
            utd_session_ids: BTreeSet::new(),
            refresh_info_session_ids: BTreeSet::from([session_id]),
        });

        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
        let encryption_info = value.encryption_info().unwrap();

        assert_matches!(
            &encryption_info.verification_state,
            VerificationState::Unverified(_),
            "The event should now know about the identity but still be unverified"
        );
    }

    /// The room key arrives while the UTD is still being persisted (the
    /// delaying store blocks the write); the event must still end up
    /// redecrypted.
    #[async_test]
    async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
        let room_id = room_id!("!test:localhost");

        let event_factory = EventFactory::new().room(room_id);
        let (alice, bob, matrix_mock_server, delayed_store) =
            set_up_clients(room_id, true, true).await;

        let delayed_store = delayed_store.unwrap();

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let (room_cache, _) = bob
            .event_cache()
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();

        // Bob receives the UTD; persisting it is blocked by the delaying store.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        // The room key arrives while the UTD is still in flight.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        info!("Stopping the delay");
        delayed_store.stop_delaying().await;

        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        // First, the UTD is appended...
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
        );

        // ...then replaced in place by its decrypted version.
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
    }
}