use std::{
collections::{BTreeMap, BTreeSet},
pin::Pin,
sync::Weak,
};
use as_variant::as_variant;
use futures_core::Stream;
use futures_util::{StreamExt, future::join_all, pin_mut};
#[cfg(doc)]
use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
use matrix_sdk_base::{
crypto::{
store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
types::events::room::encrypted::EncryptedEvent,
},
deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
event_cache::store::EventCacheStoreLockState,
locks::Mutex,
task_monitor::BackgroundTaskHandle,
timer,
};
#[cfg(doc)]
use matrix_sdk_common::deserialized_responses::EncryptionInfo;
use ruma::{
OwnedEventId, OwnedRoomId, RoomId,
events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
push::Action,
serde::Raw,
};
use tokio::sync::{
broadcast::{self, Sender},
mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
};
use tokio_stream::wrappers::{
BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
};
use tracing::{info, instrument, trace, warn};
#[cfg(doc)]
use super::RoomEventCache;
use super::{
EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheGenericUpdate,
RoomEventCacheUpdate, TimelineVectorDiffs,
caches::room::{PostProcessingOrigin, RoomEventCacheLinkedChunkUpdate},
};
use crate::{Client, Result, Room, encryption::backups::BackupState, room::PushContext};
/// A borrowed Megolm session ID.
type SessionId<'a> = &'a str;
/// An owned Megolm session ID.
type OwnedSessionId = String;
/// An event ID paired with the raw form of an event we could not decrypt.
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// An event ID paired with the decrypted form of the event.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// A formerly-undecryptable event that has now been decrypted, together with
/// any push actions that were computed for it.
pub(in crate::event_cache) type ResolvedUtd =
    (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
/// A request, sent to the redecryption task, to retry decrypting events (or
/// to refresh their encryption info) for a set of Megolm sessions in a room.
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
    // The room the sessions below belong to.
    pub room_id: OwnedRoomId,
    // Sessions whose events are currently UTDs and should be retried.
    pub utd_session_ids: BTreeSet<OwnedSessionId>,
    // Sessions whose already-decrypted events should get their encryption
    // info recomputed.
    pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
/// Reports the redecryption task broadcasts to its subscribers.
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
    /// Some previously-undecryptable events were resolved (decrypted, or had
    /// their encryption info refreshed).
    ResolvedUtds {
        // The room the resolved events belong to.
        room_id: OwnedRoomId,
        // The IDs of the events that were resolved.
        events: BTreeSet<OwnedEventId>,
    },
    /// One of the streams the task listens to lagged or had to be recreated;
    /// some updates may have been missed.
    Lagging,
    /// A key backup became available (reached the `Enabled` state).
    BackupAvailable,
}
/// The channel endpoints the event cache uses to communicate with the
/// redecryption task.
pub(super) struct RedecryptorChannels {
    // Broadcast channel over which `RedecryptorReport`s are published.
    utd_reporter: Sender<RedecryptorReport>,
    // Sender half used by `EventCache::request_decryption`.
    pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
    // Receiver half, parked here until the redecryption task is spawned and
    // takes it — hence the `Mutex<Option<…>>`.
    pub(super) decryption_request_receiver:
        Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
impl RedecryptorChannels {
    /// Create a fresh set of redecryptor channels.
    ///
    /// The report broadcast channel has a capacity of 100 messages; the
    /// request channel is unbounded. The request receiver is stashed behind a
    /// mutex so the redecryption task can claim it later.
    pub(super) fn new() -> Self {
        let (utd_reporter, _initial_receiver) = broadcast::channel(100);
        let (request_sender, request_receiver) = unbounded_channel();

        Self {
            utd_reporter,
            decryption_request_sender: request_sender,
            decryption_request_receiver: Mutex::new(Some(request_receiver)),
        }
    }
}
fn filter_timeline_event_to_utd(
event: TimelineEvent,
) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
let event_id = event.event_id();
let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
event_id.zip(event)
}
fn filter_timeline_event_to_decrypted(
event: TimelineEvent,
) -> Option<(OwnedEventId, DecryptedRoomEvent)> {
let event_id = event.event_id();
let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
event_id.zip(event)
}
impl EventCache {
async fn get_utds(
&self,
room_id: &RoomId,
session_id: SessionId<'_>,
) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
let events = match self.inner.store.lock().await? {
EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
}
};
Ok(events.into_iter().filter_map(filter_timeline_event_to_utd).collect())
}
async fn get_utds_from_memory(&self) -> BTreeMap<OwnedRoomId, Vec<EventIdAndUtd>> {
let mut utds = BTreeMap::new();
for (room_id, caches) in self.inner.by_room.read().await.iter() {
let room_utds: Vec<_> = caches
.all_events()
.await
.into_iter()
.flatten()
.filter_map(filter_timeline_event_to_utd)
.collect();
utds.insert(room_id.to_owned(), room_utds);
}
utds
}
async fn get_decrypted_events(
&self,
room_id: &RoomId,
session_id: SessionId<'_>,
) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
let events = match self.inner.store.lock().await? {
EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
guard.get_room_events(room_id, None, Some(session_id)).await?
}
};
Ok(events.into_iter().filter_map(filter_timeline_event_to_decrypted).collect())
}
async fn get_decrypted_events_from_memory(
&self,
) -> BTreeMap<OwnedRoomId, Vec<EventIdAndEvent>> {
let mut decrypted_events = BTreeMap::new();
for (room_id, caches) in self.inner.by_room.read().await.iter() {
let room_utds: Vec<_> = caches
.all_events()
.await
.into_iter()
.flatten()
.filter_map(filter_timeline_event_to_decrypted)
.collect();
decrypted_events.insert(room_id.to_owned(), room_utds);
}
decrypted_events
}
    /// Replace a set of resolved (redecrypted or info-refreshed) events in
    /// the room's caches and notify all listeners.
    ///
    /// This swaps the events in the room's linked chunk, re-runs
    /// post-processing for the replaced events, publishes the resulting
    /// timeline diffs, propagates the replacement to the pinned-events and
    /// event-focused caches, and finally broadcasts a
    /// [`RedecryptorReport::ResolvedUtds`] report.
    #[instrument(skip_all, fields(room_id))]
    async fn on_resolved_utds(
        &self,
        room_id: &RoomId,
        events: Vec<ResolvedUtd>,
    ) -> Result<(), EventCacheError> {
        if events.is_empty() {
            trace!("No events were redecrypted or updated, nothing to replace");
            return Ok(());
        }

        timer!("Resolving UTDs");

        let (room_cache, _drop_handles) = self.for_room(room_id).await?;

        // IDs of every resolved event, for the final report.
        let event_ids: BTreeSet<_> =
            events.iter().cloned().map(|(event_id, _, _)| event_id).collect();

        let (pinned_cache, ef_caches) = {
            // Hold the state write lock so the in-place replacement and the
            // diff computation below are atomic.
            let mut state = room_cache.state().write().await?;

            let pinned_cache = state.pinned_event_cache().cloned();
            let ef_caches: Vec<_> = state.event_focused_caches().cloned().collect();

            let mut new_events = Vec::with_capacity(events.len());

            for (event_id, decrypted, actions) in &events {
                // Only replace events still present in the cache; events that
                // disappeared in the meantime are silently skipped.
                if let Some((location, mut target_event)) = state.find_event(event_id).await? {
                    target_event.kind = TimelineEventKind::Decrypted(decrypted.clone());

                    if let Some(actions) = actions {
                        target_event.set_push_actions(actions.clone());
                    }

                    state.replace_event_at(location, target_event.clone()).await?;
                    new_events.push(target_event);
                }
            }

            // No read-receipt event accompanies a redecryption.
            let receipt_event = None;

            state
                .post_process_new_events(
                    new_events,
                    PostProcessingOrigin::Redecryption,
                    receipt_event,
                )
                .await?;

            room_cache.update_sender().send(
                RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs {
                    diffs: state.room_linked_chunk_mut().updates_as_vector_diffs(),
                    origin: EventsOrigin::Cache,
                }),
                Some(RoomEventCacheGenericUpdate { room_id: room_id.to_owned() }),
            );

            (pinned_cache, ef_caches)
        };

        if let Some(pinned_cache) = pinned_cache {
            pinned_cache.replace_utds(&events).await?;
        }

        // Results of the event-focused caches are deliberately ignored: the
        // `join_all` output is dropped.
        join_all(ef_caches.iter().map(|cache| cache.replace_utds(&events))).await;

        let report =
            RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
        let _ = self.inner.redecryption_channels.utd_reporter.send(report);

        Ok(())
    }
    /// Attempt to decrypt a single encrypted event.
    ///
    /// If a [`Room`] is available, the room's decryption path is used, which
    /// can also compute push actions (using `push_context` when given).
    /// Otherwise decryption falls back to the bare [`OlmMachine`], in which
    /// case no push actions are computed.
    ///
    /// Returns `None` when decryption fails (a warning is logged), or when
    /// the client / Olm machine is no longer available.
    async fn decrypt_event(
        &self,
        room_id: &RoomId,
        room: Option<&Room>,
        push_context: Option<&PushContext>,
        event: &Raw<EncryptedEvent>,
    ) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
        if let Some(room) = room {
            match room
                .decrypt_event(
                    event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>(),
                    push_context,
                )
                .await
            {
                Ok(maybe_decrypted) => {
                    let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());

                    if let TimelineEventKind::Decrypted(decrypted) = maybe_decrypted.kind {
                        Some((decrypted, actions))
                    } else {
                        // `decrypt_event` succeeded but still didn't yield a
                        // decrypted event (e.g. it stayed a UTD).
                        warn!(
                            "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
                        );
                        None
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
                    );
                    None
                }
            }
        } else {
            // No `Room` available: decrypt directly with the Olm machine.
            let client = self.inner.client().ok()?;
            let machine = client.olm_machine().await;
            let machine = machine.as_ref()?;

            match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await {
                Ok(decrypted) => Some((decrypted, None)),
                Err(e) => {
                    warn!(
                        "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
                    );
                    None
                }
            }
        }
    }
#[instrument(skip_all, fields(room_id, session_id))]
async fn retry_decryption(
&self,
room_id: &RoomId,
session_id: SessionId<'_>,
) -> Result<(), EventCacheError> {
let events = self.get_utds(room_id, session_id).await?;
self.retry_decryption_for_events(room_id, events).await
}
#[instrument(skip_all, fields(updates.linked_chunk_id))]
async fn retry_decryption_for_event_cache_updates(
&self,
updates: RoomEventCacheLinkedChunkUpdate,
) -> Result<(), EventCacheError> {
let room_id = updates.linked_chunk_id.room_id();
let events: Vec<_> = updates
.updates
.into_iter()
.flat_map(|updates| updates.into_items())
.filter_map(filter_timeline_event_to_utd)
.collect();
self.retry_decryption_for_events(room_id, events).await
}
async fn retry_decryption_for_in_memory_events(&self) {
let utds = self.get_utds_from_memory().await;
for (room_id, utds) in utds.into_iter() {
if let Err(e) = self.retry_decryption_for_events(&room_id, utds).await {
warn!(%room_id, "Failed to redecrypt in-memory events {e:?}");
}
}
}
#[instrument(skip_all, fields(room_id, session_id))]
async fn retry_decryption_for_events(
&self,
room_id: &RoomId,
events: Vec<EventIdAndUtd>,
) -> Result<(), EventCacheError> {
trace!("Retrying to decrypt");
if events.is_empty() {
trace!("No relevant events found.");
return Ok(());
}
let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
let push_context =
if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };
let mut decrypted_events = Vec::with_capacity(events.len());
for (event_id, event) in events {
if let Some((decrypted, actions)) = self
.decrypt_event(
room_id,
room.as_ref(),
push_context.as_ref(),
event.cast_ref_unchecked(),
)
.await
{
decrypted_events.push((event_id, decrypted, actions));
}
}
let event_ids: BTreeSet<_> =
decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
if !event_ids.is_empty() {
trace!(?event_ids, "Successfully redecrypted events");
}
self.on_resolved_utds(room_id, decrypted_events).await?;
Ok(())
}
    /// Recompute the encryption info of already-decrypted events and replace,
    /// in the caches, those events whose info actually changed.
    ///
    /// Useful when e.g. the sender's verification state changed after the
    /// event was first decrypted.
    async fn update_encryption_info_for_events(
        &self,
        room: &Room,
        events: Vec<EventIdAndEvent>,
    ) -> Result<(), EventCacheError> {
        let mut updated_events = Vec::with_capacity(events.len());

        for (event_id, mut event) in events {
            if let Some(session_id) = event.encryption_info.session_id() {
                let new_encryption_info =
                    room.get_encryption_info(session_id, &event.encryption_info.sender).await;

                // Only replace the event when new info exists and actually
                // differs, to avoid spurious cache updates.
                if let Some(new_encryption_info) = new_encryption_info
                    && event.encryption_info != new_encryption_info
                {
                    event.encryption_info = new_encryption_info;
                    // Push actions are unchanged by an info refresh, hence
                    // `None`.
                    updated_events.push((event_id, event, None));
                }
            }
        }

        let event_ids: BTreeSet<_> =
            updated_events.iter().map(|(event_id, _, _)| event_id).collect();

        if !event_ids.is_empty() {
            trace!(?event_ids, "Replacing the encryption info of some events");
        }

        self.on_resolved_utds(room.room_id(), updated_events).await
    }
#[instrument(skip_all, fields(room_id, session_id))]
async fn update_encryption_info(
&self,
room_id: &RoomId,
session_id: SessionId<'_>,
) -> Result<(), EventCacheError> {
trace!("Updating encryption info");
let Ok(client) = self.inner.client() else {
return Ok(());
};
let Some(room) = client.get_room(room_id) else {
return Ok(());
};
let events = self.get_decrypted_events(room_id, session_id).await?;
if events.is_empty() {
trace!("No relevant events found.");
return Ok(());
}
self.update_encryption_info_for_events(&room, events).await
}
async fn retry_update_encryption_info_for_in_memory_events(&self) {
let decrypted_events = self.get_decrypted_events_from_memory().await;
for (room_id, events) in decrypted_events.into_iter() {
let Some(room) = self.inner.client().ok().and_then(|c| c.get_room(&room_id)) else {
continue;
};
if let Err(e) = self.update_encryption_info_for_events(&room, events).await {
warn!(
%room_id,
"Failed to replace the encryption info for in-memory events {e:?}"
);
}
}
}
    /// Retry decrypting all in-memory UTDs, then refresh the encryption info
    /// of all in-memory decrypted events.
    async fn retry_in_memory_events(&self) {
        self.retry_decryption_for_in_memory_events().await;
        self.retry_update_encryption_info_for_in_memory_events().await;
    }
pub fn request_decryption(&self, request: DecryptionRetryRequest) {
let _ =
self.inner.redecryption_channels.decryption_request_sender.send(request).inspect_err(
|_| warn!("Requesting a decryption while the redecryption task has been shut down"),
);
}
pub fn subscribe_to_decryption_reports(
&self,
) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
BroadcastStream::new(self.inner.redecryption_channels.utd_reporter.subscribe())
}
}
/// Try to upgrade a weak reference to the event cache internals into a full
/// [`EventCache`] handle; `None` means the cache has been dropped.
#[inline(always)]
fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
    let inner = cache.upgrade()?;

    Some(EventCache { inner })
}
/// Retry all in-memory events, then broadcast the given report.
///
/// Returns `Err(())` when the event cache has been dropped, signalling the
/// caller that the background task should terminate.
async fn send_report_and_retry_memory_events(
    cache: &Weak<EventCacheInner>,
    report: RedecryptorReport,
) -> Result<(), ()> {
    let cache = upgrade_event_cache(cache).ok_or(())?;

    cache.retry_in_memory_events().await;

    // Nobody listening is fine; ignore send errors.
    let _ = cache.inner.redecryption_channels.utd_reporter.send(report);

    Ok(())
}
/// Handle to the background task which listens for new room keys, withheld
/// notices, cache updates and explicit requests, and retries decrypting the
/// event cache's UTDs.
pub(crate) struct Redecryptor {
    // The spawned task; configured to abort when this handle is dropped.
    _task: BackgroundTaskHandle,
}
impl Redecryptor {
    /// Spawn the redecryption background task.
    ///
    /// `receiver` is the receiving half of the decryption-request channel;
    /// `linked_chunk_update_sender` is subscribed to so newly cached events
    /// can be retried immediately.
    pub(super) fn new(
        client: &Client,
        cache: Weak<EventCacheInner>,
        receiver: UnboundedReceiver<DecryptionRetryRequest>,
        linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
    ) -> Self {
        let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());
        let backup_state_stream = client.encryption().backups().state_stream();

        let task = client
            .task_monitor()
            .spawn_infinite_task("event_cache::redecryptor", async {
                let request_redecryption_stream = UnboundedReceiverStream::new(receiver);

                Self::listen_for_room_keys_task(
                    cache,
                    request_redecryption_stream,
                    linked_chunk_stream,
                    backup_state_stream,
                )
                .await;
            })
            // Dropping the `Redecryptor` aborts the task.
            .abort_on_drop();

        Self { _task: task }
    }
async fn subscribe_to_room_key_stream(
cache: &Weak<EventCacheInner>,
) -> Option<(
impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
)> {
let event_cache = cache.upgrade()?;
let client = event_cache.client().ok()?;
let machine = client.olm_machine().await;
machine.as_ref().map(|m| {
(m.store().room_keys_received_stream(), m.store().room_keys_withheld_received_stream())
})
}
    /// One round of the redecryption select loop.
    ///
    /// Listens concurrently for:
    /// * explicit [`DecryptionRetryRequest`]s,
    /// * newly received room keys,
    /// * newly received withheld-key notices,
    /// * event cache linked-chunk updates (new events which may be UTDs),
    /// * backup state changes.
    ///
    /// Returns `true` when the room key streams closed and need to be
    /// re-created (e.g. the Olm machine was regenerated), `false` when the
    /// task should shut down (the event cache or all streams are gone).
    async fn redecryption_loop(
        cache: &Weak<EventCacheInner>,
        decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
        events_stream: &mut Pin<
            &mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
        >,
        backup_state_stream: &mut Pin<
            &mut impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
        >,
    ) -> bool {
        // The room key streams are (re-)created per round; if we can't get
        // them, the machinery behind them is gone and we shut down.
        let Some((room_key_stream, withheld_stream)) =
            Self::subscribe_to_room_key_stream(cache).await
        else {
            return false;
        };

        pin_mut!(room_key_stream);
        pin_mut!(withheld_stream);

        loop {
            tokio::select! {
                // An explicit request to retry decryption and/or refresh
                // encryption info for certain sessions.
                Some(request) = decryption_request_stream.next() => {
                    let Some(cache) = upgrade_event_cache(cache) else {
                        break false;
                    };

                    trace!(?request, "Received a redecryption request");

                    for session_id in request.utd_session_ids {
                        let _ = cache
                            .retry_decryption(&request.room_id, &session_id)
                            .await
                            .inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
                    }

                    for session_id in request.refresh_info_session_ids {
                        let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
                            warn!(
                                room_id = %request.room_id,
                                session_id = session_id,
                                "Unable to update the encryption info {e:?}",
                            ));
                    }
                }

                // New room keys arrived: retry decryption and refresh the
                // encryption info for the affected sessions.
                room_keys = room_key_stream.next() => {
                    match room_keys {
                        Some(Ok(room_keys)) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?room_keys, "Received new room keys");

                            for key in &room_keys {
                                let _ = cache
                                    .retry_decryption(&key.room_id, &key.session_id)
                                    .await
                                    .inspect_err(|e| warn!("Error redecrypting {e:?}"));
                            }

                            for key in room_keys {
                                let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %key.room_id,
                                        session_id = key.session_id,
                                        "Unable to update the encryption info {e:?}",
                                    ));
                            }
                        },
                        Some(Err(_)) => {
                            // The broadcast channel lagged: we may have
                            // missed keys, so retry everything in memory and
                            // tell listeners.
                            warn!("The room key stream lagged, reporting the lag to our listeners");

                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        },
                        None => {
                            // Stream closed — ask the caller to re-create the
                            // streams.
                            break true;
                        }
                    }
                }

                // A withheld notice changes what we can say about an event's
                // encryption, so refresh the affected sessions' info.
                withheld_info = withheld_stream.next() => {
                    match withheld_info {
                        Some(infos) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            trace!(?infos, "Received new withheld infos");

                            for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
                                let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
                                    warn!(
                                        room_id = %room_id,
                                        session_id = session_id,
                                        "Unable to update the encryption info {e:?}",
                                    ));
                            }
                        }
                        None => break true,
                    }
                }

                // New events entered the cache: retry any UTDs among them
                // right away.
                Some(event_updates) = events_stream.next() => {
                    match event_updates {
                        Ok(updates) => {
                            let Some(cache) = upgrade_event_cache(cache) else {
                                break false;
                            };

                            let linked_chunk_id = updates.linked_chunk_id.to_owned();

                            let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
                                warn!(
                                    %linked_chunk_id,
                                    "Unable to handle UTDs from event cache updates {e:?}",
                                )
                            );
                        }
                        Err(_) => {
                            // Lagged: same recovery as for the key stream.
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        }
                    }
                }

                // Backup state changes: only `Enabled` is interesting — keys
                // may now be downloadable for existing UTDs.
                Some(backup_state_update) = backup_state_stream.next() => {
                    match backup_state_update {
                        Ok(state) => {
                            match state {
                                BackupState::Unknown |
                                BackupState::Creating |
                                BackupState::Enabling |
                                BackupState::Resuming |
                                BackupState::Downloading |
                                BackupState::Disabling =>{
                                }
                                BackupState::Enabled => {
                                    if send_report_and_retry_memory_events(cache, RedecryptorReport::BackupAvailable).await.is_err() {
                                        break false;
                                    }
                                }
                            }
                        }
                        Err(_) => {
                            if send_report_and_retry_memory_events(cache, RedecryptorReport::Lagging).await.is_err() {
                                break false;
                            }
                        }
                    }
                }

                // All branches disabled/exhausted: shut down.
                else => break false,
            }
        }
    }
    /// Main loop of the redecryption task.
    ///
    /// Runs [`Self::redecryption_loop`] repeatedly. Each time the loop exits
    /// with `true` (the Olm machine's key streams closed, e.g. after the
    /// machine was regenerated), the streams are re-created and listeners are
    /// told to treat the gap as a lag. Exits for good when the loop returns
    /// `false` or the lag report cannot be delivered (event cache dropped).
    async fn listen_for_room_keys_task(
        cache: Weak<EventCacheInner>,
        decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
        events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
        backup_state_stream: impl Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
    ) {
        pin_mut!(decryption_request_stream);
        pin_mut!(events_stream);
        pin_mut!(backup_state_stream);

        while Self::redecryption_loop(
            &cache,
            &mut decryption_request_stream,
            &mut events_stream,
            &mut backup_state_stream,
        )
        .await
        {
            info!("Regenerating the re-decryption streams");

            // We may have missed keys while the streams were down; report it
            // as a lag and retry what's in memory.
            if send_report_and_retry_memory_events(&cache, RedecryptorReport::Lagging)
                .await
                .is_err()
            {
                break;
            }
        }

        info!("Shutting down the event cache redecryptor");
    }
}
#[cfg(not(target_family = "wasm"))]
#[cfg(test)]
mod tests {
use std::{
collections::BTreeSet,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use assert_matches2::assert_matches;
use async_trait::async_trait;
use eyeball_im::VectorDiff;
use matrix_sdk_base::{
cross_process_lock::CrossProcessLockGeneration,
crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
deserialized_responses::{TimelineEventKind, VerificationState},
event_cache::{
Event, Gap,
store::{EventCacheStore, EventCacheStoreError, MemoryStore},
},
linked_chunk::{
ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
RawChunk, Update,
},
locks::Mutex,
sleep::sleep,
store::StoreConfig,
timeout::timeout,
};
use matrix_sdk_common::cross_process_lock::CrossProcessLockConfig;
use matrix_sdk_test::{JoinedRoomBuilder, async_test, event_factory::EventFactory};
use ruma::{
EventId, OwnedEventId, RoomId, RoomVersionId, device_id, event_id,
events::{AnySyncTimelineEvent, relation::RelationType},
room_id,
serde::Raw,
user_id,
};
use serde_json::json;
use tokio::sync::oneshot::{self, Sender};
use tracing::{Instrument, info};
use crate::{
Client, assert_let_timeout,
encryption::EncryptionSettings,
event_cache::{
DecryptionRetryRequest, RoomEventCacheGenericUpdate, RoomEventCacheUpdate,
TimelineVectorDiffs,
},
test_utils::mocks::MatrixMockServer,
};
    /// An [`EventCacheStore`] wrapper around [`MemoryStore`] that delays
    /// `handle_linked_chunk_updates` until told to proceed, used to simulate
    /// a room key arriving while an event is still being processed.
    #[derive(Debug, Clone)]
    struct DelayingStore {
        // The real store every operation is delegated to.
        memory_store: MemoryStore,
        // While `true`, `handle_linked_chunk_updates` keeps sleeping instead
        // of completing.
        delaying: Arc<AtomicBool>,
        // One-shot sender used to signal `stop_delaying` once the delayed
        // storage operation finishes. NOTE(review): the name `foo` is
        // uninformative — consider renaming (requires touching all users).
        foo: Arc<Mutex<Option<Sender<()>>>>,
    }
    impl DelayingStore {
        /// Create a store that starts out in the delaying state.
        fn new() -> Self {
            Self {
                memory_store: MemoryStore::new(),
                delaying: AtomicBool::new(true).into(),
                foo: Arc::new(Mutex::new(None)),
            }
        }

        /// Stop delaying `handle_linked_chunk_updates` and wait until the
        /// currently delayed storage operation has completed.
        async fn stop_delaying(&self) {
            let (sender, receiver) = oneshot::channel();

            // Scope the lock so it is released before we flip the flag.
            {
                *self.foo.lock() = Some(sender);
            }

            self.delaying.store(false, Ordering::SeqCst);

            receiver.await.expect("We should be able to receive a response")
        }
    }
    // Trait implementation that delegates every operation to the inner
    // `MemoryStore`, except `handle_linked_chunk_updates`, which first spins
    // until `stop_delaying` clears the flag.
    #[cfg_attr(target_family = "wasm", async_trait(?Send))]
    #[cfg_attr(not(target_family = "wasm"), async_trait)]
    impl EventCacheStore for DelayingStore {
        type Error = EventCacheStoreError;

        async fn try_take_leased_lock(
            &self,
            lease_duration_ms: u32,
            key: &str,
            holder: &str,
        ) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
            self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
        }

        async fn handle_linked_chunk_updates(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            updates: Vec<Update<Event, Gap>>,
        ) -> Result<(), Self::Error> {
            // Busy-wait (with sleeps) until the test calls `stop_delaying`.
            while self.delaying.load(Ordering::SeqCst) {
                sleep(Duration::from_millis(10)).await;
            }

            // Take the notifier (if any) before performing the real write.
            let sender = self.foo.lock().take();

            let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;

            // Tell `stop_delaying` that the delayed operation has completed.
            if let Some(sender) = sender {
                sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
            }

            ret
        }

        async fn load_all_chunks(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_all_chunks(linked_chunk_id).await
        }

        async fn load_all_chunks_metadata(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<Vec<ChunkMetadata>, Self::Error> {
            self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
        }

        async fn load_last_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
        ) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
            self.memory_store.load_last_chunk(linked_chunk_id).await
        }

        async fn load_previous_chunk(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            before_chunk_identifier: ChunkIdentifier,
        ) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
            self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
        }

        async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
            self.memory_store.clear_all_linked_chunks().await
        }

        async fn filter_duplicated_events(
            &self,
            linked_chunk_id: LinkedChunkId<'_>,
            events: Vec<OwnedEventId>,
        ) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
            self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
        }

        async fn find_event(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
        ) -> Result<Option<Event>, Self::Error> {
            self.memory_store.find_event(room_id, event_id).await
        }

        async fn find_event_relations(
            &self,
            room_id: &RoomId,
            event_id: &EventId,
            filters: Option<&[RelationType]>,
        ) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
            self.memory_store.find_event_relations(room_id, event_id, filters).await
        }

        async fn get_room_events(
            &self,
            room_id: &RoomId,
            event_type: Option<&str>,
            session_id: Option<&str>,
        ) -> Result<Vec<Event>, Self::Error> {
            self.memory_store.get_room_events(room_id, event_type, session_id).await
        }

        async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
            self.memory_store.save_event(room_id, event).await
        }

        async fn optimize(&self) -> Result<(), Self::Error> {
            self.memory_store.optimize().await
        }

        async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
            self.memory_store.get_size().await
        }
    }
    /// Set up two end-to-end-encrypted clients, Alice and Bob, joined to
    /// `room_id` on a mock server.
    ///
    /// Bob has the event cache enabled. When `use_delayed_store` is set, Bob
    /// uses a [`DelayingStore`] (returned as the fourth tuple element) so
    /// tests can control when his cache writes complete. When
    /// `alice_enables_cross_signing` is set, Alice auto-enables
    /// cross-signing.
    async fn set_up_clients(
        room_id: &RoomId,
        alice_enables_cross_signing: bool,
        use_delayed_store: bool,
    ) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
        let alice_span = tracing::info_span!("alice");
        let bob_span = tracing::info_span!("bob");

        let alice_user_id = user_id!("@alice:localhost");
        let alice_device_id = device_id!("ALICEDEVICE");
        let bob_user_id = user_id!("@bob:localhost");
        let bob_device_id = device_id!("BOBDEVICE");

        let matrix_mock_server = MatrixMockServer::new().await;
        matrix_mock_server.mock_crypto_endpoints_preset().await;

        let encryption_settings = EncryptionSettings {
            auto_enable_cross_signing: alice_enables_cross_signing,
            ..Default::default()
        };

        let alice = matrix_mock_server
            .client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
            })
            .build()
            .instrument(alice_span.clone())
            .await;

        let encryption_settings =
            EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };

        // Optionally wire the delaying store into Bob's store config.
        let (store_config, store) = if use_delayed_store {
            let store = DelayingStore::new();

            (
                StoreConfig::new(CrossProcessLockConfig::multi_process(
                    "delayed_store_event_cache_test",
                ))
                .event_cache_store(store.clone()),
                Some(store),
            )
        } else {
            (
                StoreConfig::new(CrossProcessLockConfig::multi_process(
                    "normal_store_event_cache_test",
                )),
                None,
            )
        };

        let bob = matrix_mock_server
            .client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
            .on_builder(|builder| {
                builder
                    .with_enable_share_history_on_invite(true)
                    .with_encryption_settings(encryption_settings)
                    .store_config(store_config)
            })
            .build()
            .instrument(bob_span.clone())
            .await;

        bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");

        matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;

        // Both clients sync the same joined, encrypted room.
        let event_factory = EventFactory::new().room(room_id).sender(alice_user_id);

        let room_builder = JoinedRoomBuilder::new(room_id)
            .add_state_event(event_factory.create(alice_user_id, RoomVersionId::V1))
            .add_state_event(event_factory.room_encryption());

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&alice, |builder| {
                builder.add_joined_room(room_builder.clone());
            })
            .instrument(alice_span)
            .await;

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(room_builder);
            })
            .instrument(bob_span)
            .await;

        (alice, bob, matrix_mock_server, store)
    }
    /// Have Alice send an encrypted message into the room, capturing both the
    /// encrypted event and the to-device room key she sends out.
    ///
    /// Returns the captured event and room key so tests can feed them to Bob
    /// in whatever order they need.
    async fn prepare_room(
        matrix_mock_server: &MatrixMockServer,
        event_factory: &EventFactory,
        alice: &Client,
        bob: &Client,
        room_id: &RoomId,
    ) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
        let alice_user_id = alice.user_id().unwrap();
        let bob_user_id = bob.user_id().unwrap();

        let alice_member_event = event_factory.member(alice_user_id).into_raw();
        let bob_member_event = event_factory.member(bob_user_id).into_raw();

        let room = alice
            .get_room(room_id)
            .expect("Alice should have access to the room now that we synced");

        let event_type = "m.room.message";
        let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});
        let event_id = event_id!("$some_id");

        // Capture the encrypted event Alice uploads, and the room key she
        // sends to Bob via to-device messaging.
        let (event_receiver, mock) =
            matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
        let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;

        {
            let _guard = mock.mock_once().mount_as_scoped().await;

            // Sending an encrypted message triggers a members request first.
            matrix_mock_server
                .mock_get_members()
                .ok(vec![alice_member_event.clone(), bob_member_event.clone()])
                .mock_once()
                .mount()
                .await;

            room.send_raw(event_type, content)
                .await
                .expect("We should be able to send an initial message");
        };

        let event = event_receiver.await.expect("Alice should have sent the event by now");
        let room_key = room_key.await;

        (event, room_key)
    }
    /// Bob receives an encrypted event without its room key (→ UTD), then the
    /// key arrives via to-device; the redecryptor should replace the UTD with
    /// the decrypted event in place.
    #[async_test]
    async fn test_redecryptor() {
        let room_id = room_id!("!test:localhost");
        let event_factory = EventFactory::new().room(room_id);

        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let event_cache = bob.event_cache();

        let (room_cache, _) = event_cache
            .for_room(room_id)
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
        let mut generic_stream = event_cache.subscribe_to_room_generic_updates();

        // Throw away Bob's Olm machine so the room key Alice already sent is
        // lost and the event below cannot be decrypted on arrival.
        bob.inner
            .base_client
            .regenerate_olm(None)
            .await
            .expect("We should be able to regenerate the Olm machine");

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .await;

        // The event arrives as a UTD, appended to the timeline.
        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());

        // Now deliver the room key.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .await;

        // The redecryptor should replace the UTD at index 0 with the
        // decrypted event.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index, value });
        assert_eq!(*index, 0);
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());
    }
    /// After an event has been decrypted, a change in the sender's identity
    /// (cross-signing keys appearing) plus an explicit refresh request should
    /// make the redecryptor replace the event's encryption info.
    #[async_test]
    async fn test_redecryptor_updating_encryption_info() {
        let bob_span = tracing::info_span!("bob");
        let room_id = room_id!("!test:localhost");
        let event_factory = EventFactory::new().room(room_id);

        // Alice does NOT enable cross-signing yet — that happens later in the
        // test.
        let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;

        let (event, room_key) =
            prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

        let event_cache = bob.event_cache();

        let (room_cache, _) = event_cache
            .for_room(room_id)
            .instrument(bob_span.clone())
            .await
            .expect("We should be able to get to the event cache for a specific room");

        let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
        let mut generic_stream = event_cache.subscribe_to_room_generic_updates();

        // The event arrives before its room key, so it starts out as a UTD.
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
            })
            .instrument(bob_span.clone())
            .await;

        assert_let_timeout!(
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Append { values });
        assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());

        // Deliver the room key; the event gets decrypted but the sender is
        // still unverified (no cross-signing identity yet).
        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_to_device_event(
                    room_key
                        .deserialize_as()
                        .expect("We should be able to deserialize the room key"),
                );
            })
            .instrument(bob_span.clone())
            .await;

        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        let encryption_info = value.encryption_info().unwrap();
        assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());

        let session_id = encryption_info.session_id().unwrap().to_owned();
        let alice_user_id = alice.user_id().unwrap();

        // Alice now creates her cross-signing identity, and Bob learns about
        // the device change.
        alice
            .encryption()
            .bootstrap_cross_signing(None)
            .await
            .expect("Alice should be able to create the cross-signing keys");

        bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;

        matrix_mock_server
            .mock_sync()
            .ok_and_run(&bob, |builder| {
                builder.add_change_device(alice_user_id);
            })
            .instrument(bob_span.clone())
            .await;

        // Explicitly request a refresh of the encryption info for the
        // session.
        bob.event_cache().request_decryption(DecryptionRetryRequest {
            room_id: room_id.into(),
            utd_session_ids: BTreeSet::new(),
            refresh_info_session_ids: BTreeSet::from([session_id]),
        });

        // The event should be replaced again, with refreshed encryption info.
        assert_let_timeout!(
            Duration::from_secs(1),
            Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
                subscriber.recv()
        );
        assert_eq!(diffs.len(), 1);
        assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
        assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

        let encryption_info = value.encryption_info().unwrap();
        assert_matches!(
            &encryption_info.verification_state,
            VerificationState::Unverified(_),
            "The event should now know about the identity but still be unverified"
        );

        assert_let_timeout!(
            Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
        );
        assert_eq!(expected_room_id, room_id);
        assert!(generic_stream.is_empty());
    }
#[async_test]
async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
    // Regression test for a race: the room key arrives *while* the encrypted
    // event is still being processed/persisted by the event cache. We simulate
    // the slow processing with a store wrapper that delays writes
    // (`delayed_store`), deliver both the UTD event and the room key while the
    // delay is active, then lift the delay and check that the redecryptor
    // still picks the event up and replaces the UTD with a decrypted event.
    let room_id = room_id!("!test:localhost");
    let event_factory = EventFactory::new().room(room_id);

    // Third/fourth flags request an encrypted room setup and the delaying
    // store wrapper, respectively (see `set_up_clients`).
    let (alice, bob, matrix_mock_server, delayed_store) =
        set_up_clients(room_id, true, true).await;
    let delayed_store = delayed_store.unwrap();

    // Alice prepares an encrypted event plus the to-device room key that can
    // decrypt it; neither has been delivered to Bob yet.
    let (event, room_key) =
        prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;

    let event_cache = bob.event_cache();
    let (room_cache, _) = event_cache
        .for_room(room_id)
        .await
        .expect("We should be able to get to the event cache for a specific room");

    // Subscribe *before* syncing so we observe every timeline update in order.
    let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
    let mut generic_stream = event_cache.subscribe_to_room_generic_updates();

    // Sync 1: deliver the encrypted event. Bob has no key yet, and the store
    // delay is still active, so processing of this event is in flight.
    matrix_mock_server
        .mock_sync()
        .ok_and_run(&bob, |builder| {
            builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
        })
        .await;

    // Sync 2: deliver the room key while the event is (still) being processed.
    matrix_mock_server
        .mock_sync()
        .ok_and_run(&bob, |builder| {
            builder.add_to_device_event(
                room_key
                    .deserialize_as()
                    .expect("We should be able to deserialize the room key"),
            );
        })
        .await;

    // Now let the delayed store writes complete; only from this point can the
    // cache surface updates.
    info!("Stopping the delay");
    delayed_store.stop_delaying().await;

    // First timeline update: the event lands as a UTD (it was processed
    // before the key could be used).
    assert_let_timeout!(
        Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
            subscriber.recv()
    );
    assert_eq!(diffs.len(), 1);
    assert_matches!(&diffs[0], VectorDiff::Append { values });
    assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

    // The generic stream mirrors the room-level update.
    assert_let_timeout!(
        Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
    );
    assert_eq!(expected_room_id, room_id);

    // Second timeline update: the redecryptor notices the key and replaces
    // (Set at index 0) the UTD with the decrypted event.
    assert_let_timeout!(
        Duration::from_secs(1),
        Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
            subscriber.recv()
    );
    assert_eq!(diffs.len(), 1);
    assert_matches!(&diffs[0], VectorDiff::Set { index, value });
    assert_eq!(*index, 0);
    assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });

    assert_let_timeout!(
        Ok(RoomEventCacheGenericUpdate { room_id: expected_room_id }) = generic_stream.recv()
    );
    assert_eq!(expected_room_id, room_id);
    // No further updates are expected.
    assert!(generic_stream.is_empty());
}
#[async_test]
async fn test_redecryptor_no_deadlock_with_event_focused_cache_pagination() {
    // Regression test: a slow backwards pagination on an *event-focused* cache
    // must not deadlock the redecryptor (or anyone else needing the room
    // cache) when a room key arrives mid-pagination. The /messages response is
    // delayed by 5 seconds to keep the pagination in flight, and the final
    // probe asserts that subscribing to the room cache completes well within
    // 100ms — i.e. no lock is held across the whole pagination.
    use crate::{
        event_cache::EventFocusThreadMode,
        test_utils::mocks::{RoomContextResponseTemplate, RoomMessagesResponseTemplate},
    };

    let room_id = room_id!("!test:localhost");
    let f = EventFactory::new().room(room_id);
    // No delayed store needed here (last flag `false`).
    let (alice, bob, server, _) = set_up_clients(room_id, true, false).await;
    let (encrypted_event, room_key) = prepare_room(&server, &f, &alice, &bob, room_id).await;

    let event_cache = bob.event_cache();
    let (room_cache, _drop_handles) = event_cache
        .for_room(room_id)
        .await
        .expect("Bob should have an event cache for the room");
    let (_initial_events, mut subscriber) = room_cache.subscribe().await.unwrap();

    // Deliver the encrypted event; Bob has no key yet, so it shows up as UTD.
    server
        .mock_sync()
        .ok_and_run(&bob, |builder| {
            builder.add_joined_room(
                JoinedRoomBuilder::new(room_id).add_timeline_event(encrypted_event),
            );
        })
        .await;

    assert_let_timeout!(
        Ok(RoomEventCacheUpdate::UpdateTimelineEvents(TimelineVectorDiffs { diffs, .. })) =
            subscriber.recv()
    );
    assert_eq!(diffs.len(), 1);
    assert_matches!(&diffs[0], VectorDiff::Append { values });
    assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });

    // Set up an event-focused cache around an unrelated plaintext event; the
    // /context endpoint seeds it and provides a back-pagination token.
    let focused_event_id = event_id!("$focused");
    let bob_user_id = bob.user_id().unwrap();

    server
        .mock_room_event_context()
        .expect_any_access_token()
        .ok(RoomContextResponseTemplate::new(
            f.text_msg("focused msg")
                .sender(bob_user_id)
                .event_id(focused_event_id)
                .into_event(),
        )
        .start("back-token"))
        .mock_once()
        .mount()
        .await;

    let event_focused_cache = room_cache
        .get_or_create_event_focused_cache(
            focused_event_id.to_owned(),
            20,
            EventFocusThreadMode::Automatic,
        )
        .await
        .unwrap();

    // The /messages response is delayed by 5s so the pagination stays in
    // flight for the remainder of the test.
    server
        .mock_room_messages()
        .expect_any_access_token()
        .ok(RoomMessagesResponseTemplate::default().with_delay(Duration::from_secs(5)))
        .mock_once()
        .mount()
        .await;

    // Kick off the slow backwards pagination on a separate task.
    let event_focused_cache_clone = event_focused_cache.clone();
    let pagination_task = tokio::spawn(async move {
        let _ = event_focused_cache_clone.paginate_backwards(20).await;
    });

    // NOTE(review): sleep-based sequencing — 200ms is assumed long enough for
    // the pagination task to reach the delayed HTTP request. A flakiness risk
    // on slow CI; a deterministic hook would be preferable if one exists.
    sleep(Duration::from_millis(200)).await;

    // With the pagination still blocked, deliver the room key; this triggers
    // the redecryptor while the event-focused cache is busy.
    server
        .mock_sync()
        .ok_and_run(&bob, |builder| {
            builder.add_to_device_event(
                room_key
                    .deserialize_as()
                    .expect("We should be able to deserialize the room key"),
            );
        })
        .await;

    // Give the redecryptor time to run (still inside the 5s pagination delay).
    sleep(Duration::from_secs(1)).await;

    // The actual deadlock probe: subscribing must complete quickly even
    // though the pagination is still in flight.
    let (_events, _subscriber) = timeout(room_cache.subscribe(), Duration::from_millis(100))
        .await
        .expect("subscribing shouldn't timeout")
        .expect("subscribing should succeed");

    // Don't wait out the remaining pagination delay.
    pagination_task.abort();
}
}