use std::{collections::BTreeSet, pin::Pin, sync::Weak};
use as_variant::as_variant;
use futures_core::Stream;
use futures_util::{StreamExt, pin_mut};
#[cfg(doc)]
use matrix_sdk_base::{BaseClient, crypto::OlmMachine};
use matrix_sdk_base::{
crypto::{
store::types::{RoomKeyInfo, RoomKeyWithheldInfo},
types::events::room::encrypted::EncryptedEvent,
},
deserialized_responses::{DecryptedRoomEvent, TimelineEvent, TimelineEventKind},
event_cache::store::EventCacheStoreLockState,
locks::Mutex,
timer,
};
#[cfg(doc)]
use matrix_sdk_common::deserialized_responses::EncryptionInfo;
use matrix_sdk_common::executor::{AbortOnDrop, JoinHandleExt, spawn};
use ruma::{
OwnedEventId, OwnedRoomId, RoomId,
events::{AnySyncTimelineEvent, room::encrypted::OriginalSyncRoomEncryptedEvent},
push::Action,
serde::Raw,
};
use tokio::sync::{
broadcast::{self, Sender},
mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
};
use tokio_stream::wrappers::{
BroadcastStream, UnboundedReceiverStream, errors::BroadcastStreamRecvError,
};
use tracing::{info, instrument, trace, warn};
#[cfg(doc)]
use super::RoomEventCache;
use super::{EventCache, EventCacheError, EventCacheInner, EventsOrigin, RoomEventCacheUpdate};
use crate::{Room, event_cache::RoomEventCacheLinkedChunkUpdate, room::PushContext};
/// Borrowed session identifier, as passed to the session filter of the
/// store's `get_room_events` query.
type SessionId<'a> = &'a str;
/// Owned counterpart of [`SessionId`].
type OwnedSessionId = String;
/// An event ID paired with the raw JSON of an event that could not be
/// decrypted.
type EventIdAndUtd = (OwnedEventId, Raw<AnySyncTimelineEvent>);
/// An event ID paired with the decrypted form of that event.
type EventIdAndEvent = (OwnedEventId, DecryptedRoomEvent);
/// An event ID, the decrypted event that should replace the cached one, and
/// the push actions to set on it, if any were computed.
type ResolvedUtd = (OwnedEventId, DecryptedRoomEvent, Option<Vec<Action>>);
/// An explicit request, sent through [`EventCache::request_decryption`], asking
/// the redecryptor to revisit cached events of a room.
#[derive(Debug, Clone)]
pub struct DecryptionRetryRequest {
/// The room whose cached events should be revisited.
pub room_id: OwnedRoomId,
/// Session IDs for which decryption of the cached undecryptable events is
/// retried.
pub utd_session_ids: BTreeSet<OwnedSessionId>,
/// Session IDs for which the encryption info of the already decrypted
/// cached events is recomputed.
pub refresh_info_session_ids: BTreeSet<OwnedSessionId>,
}
/// A report broadcast by the redecryptor to the subscribers obtained from
/// [`EventCache::subscribe_to_decryption_reports`].
#[derive(Debug, Clone)]
pub enum RedecryptorReport {
/// A batch of events was handed over for replacement in the room's event
/// cache, either because they were redecrypted or because their
/// encryption info changed.
ResolvedUtds {
/// The room the events belong to.
room_id: OwnedRoomId,
/// The IDs of the events in the batch.
events: BTreeSet<OwnedEventId>,
},
/// One of the input streams of the redecryptor reported an error, or the
/// room key streams had to be recreated.
// NOTE(review): presumably means some inputs may have been missed and
// observers may want to issue explicit retry requests; confirm with callers.
Lagging,
}
/// The channels used to talk to, and to hear back from, the redecryptor task.
pub(super) struct RedecryptorChannels {
/// Broadcast sender used to publish [`RedecryptorReport`]s.
utd_reporter: Sender<RedecryptorReport>,
/// Sending half of the channel carrying explicit redecryption requests.
pub(super) decryption_request_sender: UnboundedSender<DecryptionRetryRequest>,
/// Receiving half of the request channel, wrapped in `Mutex<Option<_>>`.
// NOTE(review): presumably taken exactly once when the `Redecryptor` task is
// created; confirm at the construction site.
pub(super) decryption_request_receiver:
Mutex<Option<UnboundedReceiver<DecryptionRetryRequest>>>,
}
impl RedecryptorChannels {
    /// Creates the report broadcast channel (capacity 100) and the unbounded
    /// channel for redecryption requests.
    ///
    /// The initial broadcast receiver is discarded; the receiving half of the
    /// request channel is stored so that it can be taken out later.
    pub(super) fn new() -> Self {
        let (decryption_request_sender, request_receiver) = unbounded_channel();
        let decryption_request_receiver = Mutex::new(Some(request_receiver));

        let (utd_reporter, _) = broadcast::channel(100);

        Self { utd_reporter, decryption_request_sender, decryption_request_receiver }
    }
}
fn filter_timeline_event_to_utd(
event: TimelineEvent,
) -> Option<(OwnedEventId, Raw<AnySyncTimelineEvent>)> {
let event_id = event.event_id();
let event = as_variant!(event.kind, TimelineEventKind::UnableToDecrypt { event, .. } => event);
event_id.zip(event)
}
impl EventCache {
/// Loads from the store the `m.room.encrypted` events of `room_id` matching
/// `session_id`, keeping only those that are still undecryptable.
///
/// # Errors
///
/// Fails if the store lock cannot be taken or if the store query fails.
async fn get_utds(
    &self,
    room_id: &RoomId,
    session_id: SessionId<'_>,
) -> Result<Vec<EventIdAndUtd>, EventCacheError> {
    let store = self.inner.store.lock().await?;

    // Both lock states give access to the same guard.
    let cached = match store {
        EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
            guard.get_room_events(room_id, Some("m.room.encrypted"), Some(session_id)).await?
        }
    };

    let utds: Vec<EventIdAndUtd> =
        cached.into_iter().filter_map(filter_timeline_event_to_utd).collect();

    Ok(utds)
}
async fn get_decrypted_events(
&self,
room_id: &RoomId,
session_id: SessionId<'_>,
) -> Result<Vec<EventIdAndEvent>, EventCacheError> {
let filter = |event: TimelineEvent| {
let event_id = event.event_id();
let event = as_variant!(event.kind, TimelineEventKind::Decrypted(event) => event);
event_id.zip(event)
};
let events = match self.inner.store.lock().await? {
EventCacheStoreLockState::Clean(guard) | EventCacheStoreLockState::Dirty(guard) => {
guard.get_room_events(room_id, None, Some(session_id)).await?
}
};
Ok(events.into_iter().filter_map(filter).collect())
}
/// Replaces events in the event cache of `room_id` with their resolved
/// (redecrypted or re-annotated) versions and notifies the observers.
///
/// For every entry of `events` that can be found in the room's cache state,
/// the cached event gets the new decrypted content and, when provided, the new
/// push actions. Afterwards a timeline update is sent to the room's
/// subscribers and a [`RedecryptorReport::ResolvedUtds`] listing every event
/// ID of `events` is broadcast.
///
/// Does nothing if `events` is empty.
///
/// # Errors
///
/// Fails if the room's event cache or its state lock cannot be obtained, or
/// if looking up, replacing or post-processing an event fails.
// A bare field name in `fields(..)` only declares an empty placeholder, so the
// room ID has to be recorded explicitly to show up in the span.
#[instrument(skip_all, fields(room_id = %room_id))]
async fn on_resolved_utds(
    &self,
    room_id: &RoomId,
    events: Vec<ResolvedUtd>,
) -> Result<(), EventCacheError> {
    if events.is_empty() {
        trace!("No events were redecrypted or updated, nothing to replace");
        return Ok(());
    }

    timer!("Resolving UTDs");

    // The state stays write-locked while the events are being replaced.
    let (room_cache, _drop_handles) = self.for_room(room_id).await?;
    let mut state = room_cache.inner.state.write().await?;

    // Only the IDs are needed for the report: clone those instead of cloning
    // every whole `(id, decrypted event, actions)` tuple.
    let event_ids: BTreeSet<_> = events.iter().map(|(event_id, _, _)| event_id.clone()).collect();
    let mut new_events = Vec::with_capacity(events.len());

    for (event_id, decrypted, actions) in events {
        // Events which can't be found in the cache are skipped.
        if let Some((location, mut target_event)) = state.find_event(&event_id).await? {
            target_event.kind = TimelineEventKind::Decrypted(decrypted);

            if let Some(actions) = actions {
                target_event.set_push_actions(actions);
            }

            state.replace_event_at(location, target_event.clone()).await?;
            new_events.push(target_event);
        }
    }

    state.post_process_new_events(new_events, false).await?;

    // Failing to send only means nobody is listening, so the results are
    // deliberately ignored.
    let diffs = state.room_linked_chunk().updates_as_vector_diffs();
    let _ = room_cache.inner.update_sender.send(RoomEventCacheUpdate::UpdateTimelineEvents {
        diffs,
        origin: EventsOrigin::Cache,
    });

    let report =
        RedecryptorReport::ResolvedUtds { room_id: room_id.to_owned(), events: event_ids };
    let _ = self.inner.redecryption_channels.utd_reporter.send(report);

    Ok(())
}
/// Attempts to decrypt a single encrypted event.
///
/// When a `room` is available the decryption goes through
/// `Room::decrypt_event`, which also yields the push actions of the event.
/// Otherwise the Olm machine of the client is used directly and no push
/// actions are returned.
///
/// Returns `None`, after logging a warning, when decryption fails, and `None`
/// without a warning when the client or the Olm machine is unavailable.
async fn decrypt_event(
    &self,
    room_id: &RoomId,
    room: Option<&Room>,
    push_context: Option<&PushContext>,
    event: &Raw<EncryptedEvent>,
) -> Option<(DecryptedRoomEvent, Option<Vec<Action>>)> {
    let Some(room) = room else {
        // No room object: fall back to the Olm machine.
        let client = self.inner.client().ok()?;
        let machine = client.olm_machine().await;
        let machine = machine.as_ref()?;

        return match machine.decrypt_room_event(event, room_id, client.decryption_settings()).await
        {
            Ok(decrypted) => Some((decrypted, None)),
            Err(e) => {
                warn!(
                    "Failed to redecrypt an event despite receiving a room key or a request to redecrypt {e:?}"
                );
                None
            }
        };
    };

    let encrypted = event.cast_ref_unchecked::<OriginalSyncRoomEncryptedEvent>();

    match room.decrypt_event(encrypted, push_context).await {
        Ok(maybe_decrypted) => {
            let actions = maybe_decrypted.push_actions().map(|a| a.to_vec());

            match maybe_decrypted.kind {
                TimelineEventKind::Decrypted(decrypted) => Some((decrypted, actions)),
                _ => {
                    warn!(
                        "Failed to redecrypt an event despite receiving a room key or request to redecrypt"
                    );
                    None
                }
            }
        }
        Err(e) => {
            warn!(
                "Failed to redecrypt an event despite receiving a room key or request to redecrypt {e:?}"
            );
            None
        }
    }
}
/// Retries the decryption of every cached undecryptable event of `room_id`
/// that matches `session_id`.
///
/// # Errors
///
/// Fails if the undecryptable events cannot be loaded from the store or if
/// the redecrypted events cannot be put back into the event cache.
// Bare field names in `fields(..)` only declare empty placeholders, so the
// values are recorded explicitly to make them show up in the span.
#[instrument(skip_all, fields(room_id = %room_id, session_id = session_id))]
async fn retry_decryption(
    &self,
    room_id: &RoomId,
    session_id: SessionId<'_>,
) -> Result<(), EventCacheError> {
    let events = self.get_utds(room_id, session_id).await?;
    self.retry_decryption_for_events(room_id, events).await
}
/// Retries the decryption of the undecryptable events contained in a batch of
/// linked chunk updates.
///
/// All items carried by the updates are inspected; only those that are still
/// undecryptable are handed to the redecryption routine.
///
/// # Errors
///
/// Fails if the redecrypted events cannot be put back into the event cache.
// A bare field path in `fields(..)` only declares an empty placeholder, so the
// linked chunk ID is recorded explicitly to make it show up in the span.
#[instrument(skip_all, fields(linked_chunk_id = %updates.linked_chunk_id))]
async fn retry_decryption_for_event_cache_updates(
    &self,
    updates: RoomEventCacheLinkedChunkUpdate,
) -> Result<(), EventCacheError> {
    let room_id = updates.linked_chunk_id.room_id();

    let events: Vec<_> = updates
        .updates
        .into_iter()
        .flat_map(|updates| updates.into_items())
        .filter_map(filter_timeline_event_to_utd)
        .collect();

    self.retry_decryption_for_events(room_id, events).await
}
/// Tries to decrypt each of the given undecryptable events and replaces the
/// ones that could be decrypted in the event cache of `room_id`.
///
/// The room and its push context are looked up once; when they are
/// unavailable, decryption falls back to the path without push actions.
/// Events that still fail to decrypt are left untouched.
///
/// # Errors
///
/// Fails if the redecrypted events cannot be put back into the event cache.
// A bare field name in `fields(..)` only declares an empty placeholder, so the
// room ID is recorded explicitly. The former `session_id` placeholder is gone:
// this function has no session ID to record.
#[instrument(skip_all, fields(room_id = %room_id))]
async fn retry_decryption_for_events(
    &self,
    room_id: &RoomId,
    events: Vec<EventIdAndUtd>,
) -> Result<(), EventCacheError> {
    trace!("Retrying to decrypt");

    if events.is_empty() {
        trace!("No relevant events found.");
        return Ok(());
    }

    let room = self.inner.client().ok().and_then(|client| client.get_room(room_id));
    let push_context =
        if let Some(room) = &room { room.push_context().await.ok().flatten() } else { None };

    let mut decrypted_events = Vec::with_capacity(events.len());

    for (event_id, event) in events {
        if let Some((decrypted, actions)) = self
            .decrypt_event(
                room_id,
                room.as_ref(),
                push_context.as_ref(),
                event.cast_ref_unchecked(),
            )
            .await
        {
            decrypted_events.push((event_id, decrypted, actions));
        }
    }

    // The set of IDs is only needed for the trace below, so don't build it
    // when nothing was decrypted.
    if !decrypted_events.is_empty() {
        let event_ids: BTreeSet<_> =
            decrypted_events.iter().map(|(event_id, _, _)| event_id).collect();
        trace!(?event_ids, "Successfully redecrypted events");
    }

    self.on_resolved_utds(room_id, decrypted_events).await
}
/// Recomputes the encryption info of the cached, already decrypted events of
/// `room_id` matching `session_id`, and replaces the events whose encryption
/// info changed.
///
/// Returns early with `Ok(())` when the client or the room is unavailable, or
/// when no matching event exists.
///
/// # Errors
///
/// Fails if the decrypted events cannot be loaded from the store or if the
/// updated events cannot be put back into the event cache.
// Bare field names in `fields(..)` only declare empty placeholders, so the
// values are recorded explicitly to make them show up in the span.
#[instrument(skip_all, fields(room_id = %room_id, session_id = session_id))]
async fn update_encryption_info(
    &self,
    room_id: &RoomId,
    session_id: SessionId<'_>,
) -> Result<(), EventCacheError> {
    trace!("Updating encryption info");

    let Ok(client) = self.inner.client() else {
        return Ok(());
    };

    let Some(room) = client.get_room(room_id) else {
        return Ok(());
    };

    let events = self.get_decrypted_events(room_id, session_id).await?;

    if events.is_empty() {
        trace!("No relevant events found.");
        return Ok(());
    }

    let mut updated_events = Vec::with_capacity(events.len());

    for (event_id, mut event) in events {
        let new_encryption_info =
            room.get_encryption_info(session_id, &event.encryption_info.sender).await;

        // Only events whose encryption info actually changed are replaced.
        if let Some(new_encryption_info) = new_encryption_info
            && event.encryption_info != new_encryption_info
        {
            event.encryption_info = new_encryption_info;
            updated_events.push((event_id, event, None));
        }
    }

    // The set of IDs is only needed for the trace below, so don't build it
    // when nothing changed.
    if !updated_events.is_empty() {
        let event_ids: BTreeSet<_> =
            updated_events.iter().map(|(event_id, _, _)| event_id).collect();
        trace!(?event_ids, "Replacing the encryption info of some events");
    }

    self.on_resolved_utds(room_id, updated_events).await
}
/// Asks the redecryptor task to process the given [`DecryptionRetryRequest`].
///
/// The request is queued on an unbounded channel. If the receiving side is
/// gone, the request is dropped and a warning is logged.
pub fn request_decryption(&self, request: DecryptionRetryRequest) {
    let sender = &self.inner.redecryption_channels.decryption_request_sender;

    if sender.send(request).is_err() {
        warn!("Requesting a decryption while the redecryption task has been shut down");
    }
}
/// Returns a stream of the [`RedecryptorReport`]s broadcast by the
/// redecryptor.
///
/// The stream yields an error item when the underlying broadcast receiver
/// reports one.
pub fn subscribe_to_decryption_reports(
    &self,
) -> impl Stream<Item = Result<RedecryptorReport, BroadcastStreamRecvError>> {
    let receiver = self.inner.redecryption_channels.utd_reporter.subscribe();

    BroadcastStream::new(receiver)
}
}
/// Owner of the background task that retries decryption of cached events.
pub(crate) struct Redecryptor {
/// Handle of the spawned task; the task is aborted when this is dropped.
_task: AbortOnDrop<()>,
}
impl Redecryptor {
/// Spawns the redecryptor task.
///
/// The subscription to the linked chunk updates is created before the task is
/// spawned. The task itself listens to explicit requests coming from
/// `receiver`, to the linked chunk updates and to the room key streams, and is
/// aborted once the returned [`Redecryptor`] is dropped.
pub(super) fn new(
    cache: Weak<EventCacheInner>,
    receiver: UnboundedReceiver<DecryptionRetryRequest>,
    linked_chunk_update_sender: &Sender<RoomEventCacheLinkedChunkUpdate>,
) -> Self {
    let linked_chunk_stream = BroadcastStream::new(linked_chunk_update_sender.subscribe());

    let listener = async {
        let request_redecryption_stream = UnboundedReceiverStream::new(receiver);

        Self::listen_for_room_keys_task(cache, request_redecryption_stream, linked_chunk_stream)
            .await;
    };

    Self { _task: spawn(listener).abort_on_drop() }
}
/// Subscribes to the streams of received room keys and of received withheld
/// infos exposed by the store of the Olm machine.
///
/// Returns `None` if the event cache was dropped, or if the client or the Olm
/// machine is unavailable.
async fn subscribe_to_room_key_stream(
    cache: &Weak<EventCacheInner>,
) -> Option<(
    impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>>,
    impl Stream<Item = Vec<RoomKeyWithheldInfo>>,
)> {
    let inner = cache.upgrade()?;
    let client = inner.client().ok()?;

    let olm_machine = client.olm_machine().await;
    let machine = olm_machine.as_ref()?;

    let room_keys = machine.store().room_keys_received_stream();
    let withheld_infos = machine.store().room_keys_withheld_received_stream();

    Some((room_keys, withheld_infos))
}
/// Turns the weak reference into a usable [`EventCache`], or returns `None`
/// if the event cache has been dropped in the meantime.
#[inline(always)]
fn upgrade_event_cache(cache: &Weak<EventCacheInner>) -> Option<EventCache> {
    let inner = cache.upgrade()?;

    Some(EventCache { inner })
}
/// Runs one generation of the redecryptor's event loop.
///
/// The loop subscribes to the room key and withheld-info streams and then
/// multiplexes four inputs: explicit redecryption requests, received room
/// keys, received withheld infos, and linked chunk updates of the event cache.
///
/// The event cache is only held weakly: it is upgraded each time an input is
/// handled, so that the loop never keeps the cache alive by itself.
///
/// Returns `true` when the room key stream or the withheld-info stream ended,
/// which tells the caller to call this function again so that the streams get
/// recreated. Returns `false` when the loop should not be restarted: the
/// streams could not be subscribed to, the event cache was dropped, or no
/// `select!` branch was left enabled.
async fn redecryption_loop(
cache: &Weak<EventCacheInner>,
decryption_request_stream: &mut Pin<&mut impl Stream<Item = DecryptionRetryRequest>>,
events_stream: &mut Pin<
&mut impl Stream<Item = Result<RoomEventCacheLinkedChunkUpdate, BroadcastStreamRecvError>>,
>,
) -> bool {
// Without the room key streams there is nothing to listen to.
let Some((room_key_stream, withheld_stream)) =
Self::subscribe_to_room_key_stream(cache).await
else {
return false;
};
pin_mut!(room_key_stream);
pin_mut!(withheld_stream);
loop {
tokio::select! {
// An explicit request: retry decryption for the requested sessions, then
// refresh the encryption info for the other requested sessions.
Some(request) = decryption_request_stream.next() => {
let Some(cache) = Self::upgrade_event_cache(cache) else {
break false;
};
trace!(?request, "Received a redecryption request");
for session_id in request.utd_session_ids {
let _ = cache
.retry_decryption(&request.room_id, &session_id)
.await
.inspect_err(|e| warn!("Error redecrypting after an explicit request was received {e:?}"));
}
for session_id in request.refresh_info_session_ids {
let _ = cache.update_encryption_info(&request.room_id, &session_id).await.inspect_err(|e|
warn!(
room_id = %request.room_id,
session_id = session_id,
"Unable to update the encryption info {e:?}",
));
}
}
// New room keys: first retry the decryption of the matching UTDs, then
// refresh the encryption info of the matching decrypted events.
room_keys = room_key_stream.next() => {
match room_keys {
Some(Ok(room_keys)) => {
let Some(cache) = Self::upgrade_event_cache(cache) else {
break false;
};
trace!(?room_keys, "Received new room keys");
for key in &room_keys {
let _ = cache
.retry_decryption(&key.room_id, &key.session_id)
.await
.inspect_err(|e| warn!("Error redecrypting {e:?}"));
}
for key in room_keys {
let _ = cache.update_encryption_info(&key.room_id, &key.session_id).await.inspect_err(|e|
warn!(
room_id = %key.room_id,
session_id = key.session_id,
"Unable to update the encryption info {e:?}",
));
}
},
// The stream reported an error: tell the observers that we're lagging.
Some(Err(_)) => {
let Some(cache) = Self::upgrade_event_cache(cache) else {
break false;
};
let message = RedecryptorReport::Lagging;
let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
},
// The stream ended: ask the caller to restart the loop.
None => {
break true
}
}
}
// New withheld infos: only the encryption info is refreshed.
withheld_info = withheld_stream.next() => {
match withheld_info {
Some(infos) => {
let Some(cache) = Self::upgrade_event_cache(cache) else {
break false;
};
trace!(?infos, "Received new withheld infos");
for RoomKeyWithheldInfo { room_id, session_id, .. } in &infos {
let _ = cache.update_encryption_info(room_id, session_id).await.inspect_err(|e|
warn!(
room_id = %room_id,
session_id = session_id,
"Unable to update the encryption info {e:?}",
));
}
}
// The stream ended: ask the caller to restart the loop.
None => break true
}
}
// Linked chunk updates: retry decryption of the UTDs they carry.
Some(event_updates) = events_stream.next() => {
match event_updates {
Ok(updates) => {
let Some(cache) = Self::upgrade_event_cache(cache) else {
break false;
};
let linked_chunk_id = updates.linked_chunk_id.to_owned();
let _ = cache.retry_decryption_for_event_cache_updates(updates).await.inspect_err(|e|
warn!(
%linked_chunk_id,
"Unable to handle UTDs from event cache updates {e:?}",
)
);
}
// The stream reported an error: tell the observers that we're lagging.
Err(_) => {
let Some(cache) = Self::upgrade_event_cache(cache) else {
break false;
};
let message = RedecryptorReport::Lagging;
let _ = cache.inner.redecryption_channels.utd_reporter.send(message);
}
}
}
// Only taken when every branch above is disabled.
else => break false,
}
}
}
/// Body of the redecryptor task.
///
/// Runs [`Self::redecryption_loop`] again and again for as long as it asks to
/// be restarted. Before each restart a [`RedecryptorReport::Lagging`] report
/// is broadcast. The task ends when the loop declines a restart or when the
/// event cache has been dropped.
async fn listen_for_room_keys_task(
    cache: Weak<EventCacheInner>,
    decryption_request_stream: UnboundedReceiverStream<DecryptionRetryRequest>,
    events_stream: BroadcastStream<RoomEventCacheLinkedChunkUpdate>,
) {
    pin_mut!(decryption_request_stream);
    pin_mut!(events_stream);

    loop {
        let should_restart =
            Self::redecryption_loop(&cache, &mut decryption_request_stream, &mut events_stream)
                .await;

        if !should_restart {
            break;
        }

        info!("Regenerating the re-decryption streams");

        match Self::upgrade_event_cache(&cache) {
            Some(cache) => {
                let _ =
                    cache.inner.redecryption_channels.utd_reporter.send(RedecryptorReport::Lagging);
            }
            None => break,
        }
    }

    info!("Shutting down the event cache redecryptor");
}
}
#[cfg(not(target_family = "wasm"))]
#[cfg(test)]
mod tests {
use std::{
collections::BTreeSet,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use assert_matches2::assert_matches;
use async_trait::async_trait;
use eyeball_im::VectorDiff;
use matrix_sdk_base::{
cross_process_lock::CrossProcessLockGeneration,
crypto::types::events::{ToDeviceEvent, room::encrypted::ToDeviceEncryptedEventContent},
deserialized_responses::{TimelineEventKind, VerificationState},
event_cache::{
Event, Gap,
store::{EventCacheStore, EventCacheStoreError, MemoryStore},
},
linked_chunk::{
ChunkIdentifier, ChunkIdentifierGenerator, ChunkMetadata, LinkedChunkId, Position,
RawChunk, Update,
},
locks::Mutex,
sleep::sleep,
store::StoreConfig,
};
use matrix_sdk_test::{
JoinedRoomBuilder, StateTestEvent, async_test, event_factory::EventFactory,
};
use ruma::{
EventId, OwnedEventId, RoomId, device_id, event_id,
events::{AnySyncTimelineEvent, relation::RelationType},
room_id,
serde::Raw,
user_id,
};
use serde_json::json;
use tokio::sync::oneshot::{self, Sender};
use tracing::{Instrument, info};
use crate::{
Client, assert_let_timeout,
encryption::EncryptionSettings,
event_cache::{DecryptionRetryRequest, RoomEventCacheUpdate},
test_utils::mocks::MatrixMockServer,
};
/// A test-only event cache store wrapping a [`MemoryStore`], which can hold
/// back `handle_linked_chunk_updates` calls until told otherwise.
#[derive(Debug, Clone)]
struct DelayingStore {
// The store every operation is forwarded to.
memory_store: MemoryStore,
// While `true`, `handle_linked_chunk_updates` keeps waiting.
delaying: Arc<AtomicBool>,
// One-shot sender used to signal that a delayed
// `handle_linked_chunk_updates` call has completed.
foo: Arc<Mutex<Option<Sender<()>>>>,
}
impl DelayingStore {
    /// Creates a store which starts out in the delaying state.
    fn new() -> Self {
        let delaying = Arc::new(AtomicBool::new(true));
        let foo = Arc::new(Mutex::new(None));

        Self { memory_store: MemoryStore::new(), delaying, foo }
    }

    /// Stops the delaying and waits until a `handle_linked_chunk_updates`
    /// call signals that it has completed.
    async fn stop_delaying(&self) {
        let (sender, receiver) = oneshot::channel();

        // Install the sender before lifting the delay, so that the released
        // storage operation is guaranteed to find it.
        self.foo.lock().replace(sender);
        self.delaying.store(false, Ordering::SeqCst);

        receiver.await.expect("We should be able to receive a response")
    }
}
// `EventCacheStore` implementation forwarding everything to the wrapped
// `MemoryStore`. Only `handle_linked_chunk_updates` adds behavior: it waits
// for the delay to be lifted first.
//
// NOTE(review): the enclosing module is compiled only when
// `not(target_family = "wasm")`, so the wasm `cfg_attr` below never applies.
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
#[cfg_attr(not(target_family = "wasm"), async_trait)]
impl EventCacheStore for DelayingStore {
type Error = EventCacheStoreError;
async fn try_take_leased_lock(
&self,
lease_duration_ms: u32,
key: &str,
holder: &str,
) -> Result<Option<CrossProcessLockGeneration>, Self::Error> {
self.memory_store.try_take_leased_lock(lease_duration_ms, key, holder).await
}
// Waits, polling every 10 ms, until the delay has been lifted, then forwards
// the updates to the memory store and finally notifies whoever is waiting in
// `stop_delaying()`, if anyone.
async fn handle_linked_chunk_updates(
&self,
linked_chunk_id: LinkedChunkId<'_>,
updates: Vec<Update<Event, Gap>>,
) -> Result<(), Self::Error> {
while self.delaying.load(Ordering::SeqCst) {
sleep(Duration::from_millis(10)).await;
}
// Take the sender out before the storage operation; the lock guard is
// released at the end of this statement, so it isn't held across the await.
let sender = self.foo.lock().take();
let ret = self.memory_store.handle_linked_chunk_updates(linked_chunk_id, updates).await;
if let Some(sender) = sender {
sender.send(()).expect("We should be able to notify the other side that we're done with the storage operation");
}
ret
}
async fn load_all_chunks(
&self,
linked_chunk_id: LinkedChunkId<'_>,
) -> Result<Vec<RawChunk<Event, Gap>>, Self::Error> {
self.memory_store.load_all_chunks(linked_chunk_id).await
}
async fn load_all_chunks_metadata(
&self,
linked_chunk_id: LinkedChunkId<'_>,
) -> Result<Vec<ChunkMetadata>, Self::Error> {
self.memory_store.load_all_chunks_metadata(linked_chunk_id).await
}
async fn load_last_chunk(
&self,
linked_chunk_id: LinkedChunkId<'_>,
) -> Result<(Option<RawChunk<Event, Gap>>, ChunkIdentifierGenerator), Self::Error> {
self.memory_store.load_last_chunk(linked_chunk_id).await
}
async fn load_previous_chunk(
&self,
linked_chunk_id: LinkedChunkId<'_>,
before_chunk_identifier: ChunkIdentifier,
) -> Result<Option<RawChunk<Event, Gap>>, Self::Error> {
self.memory_store.load_previous_chunk(linked_chunk_id, before_chunk_identifier).await
}
async fn clear_all_linked_chunks(&self) -> Result<(), Self::Error> {
self.memory_store.clear_all_linked_chunks().await
}
async fn filter_duplicated_events(
&self,
linked_chunk_id: LinkedChunkId<'_>,
events: Vec<OwnedEventId>,
) -> Result<Vec<(OwnedEventId, Position)>, Self::Error> {
self.memory_store.filter_duplicated_events(linked_chunk_id, events).await
}
async fn find_event(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<Option<Event>, Self::Error> {
self.memory_store.find_event(room_id, event_id).await
}
async fn find_event_relations(
&self,
room_id: &RoomId,
event_id: &EventId,
filters: Option<&[RelationType]>,
) -> Result<Vec<(Event, Option<Position>)>, Self::Error> {
self.memory_store.find_event_relations(room_id, event_id, filters).await
}
async fn get_room_events(
&self,
room_id: &RoomId,
event_type: Option<&str>,
session_id: Option<&str>,
) -> Result<Vec<Event>, Self::Error> {
self.memory_store.get_room_events(room_id, event_type, session_id).await
}
async fn save_event(&self, room_id: &RoomId, event: Event) -> Result<(), Self::Error> {
self.memory_store.save_event(room_id, event).await
}
async fn optimize(&self) -> Result<(), Self::Error> {
self.memory_store.optimize().await
}
async fn get_size(&self) -> Result<Option<usize>, Self::Error> {
self.memory_store.get_size().await
}
}
// Creates two end-to-end encryption capable clients, Alice and Bob, against a
// fresh mock server, and has both of them join the encrypted room `room_id`.
//
// - `alice_enables_cross_signing`: whether Alice automatically enables
//   cross-signing; Bob always does.
// - `use_delayed_store`: whether Bob's event cache store is a `DelayingStore`,
//   in which case the store is returned as the last element of the tuple.
//
// Bob's event cache is subscribed, and the two clients have exchanged their
// E2EE identities, before the function returns.
async fn set_up_clients(
room_id: &RoomId,
alice_enables_cross_signing: bool,
use_delayed_store: bool,
) -> (Client, Client, MatrixMockServer, Option<DelayingStore>) {
let alice_span = tracing::info_span!("alice");
let bob_span = tracing::info_span!("bob");
let alice_user_id = user_id!("@alice:localhost");
let alice_device_id = device_id!("ALICEDEVICE");
let bob_user_id = user_id!("@bob:localhost");
let bob_device_id = device_id!("BOBDEVICE");
let matrix_mock_server = MatrixMockServer::new().await;
matrix_mock_server.mock_crypto_endpoints_preset().await;
// Alice's cross-signing setup is controlled by the caller.
let encryption_settings = EncryptionSettings {
auto_enable_cross_signing: alice_enables_cross_signing,
..Default::default()
};
let alice = matrix_mock_server
.client_builder_for_crypto_end_to_end(alice_user_id, alice_device_id)
.on_builder(|builder| {
builder
.with_enable_share_history_on_invite(true)
.with_encryption_settings(encryption_settings)
})
.build()
.instrument(alice_span.clone())
.await;
// Bob always enables cross-signing.
let encryption_settings =
EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() };
// Optionally back Bob's event cache with the delaying store, keeping a clone
// so that the test can lift the delay later on.
let (store_config, store) = if use_delayed_store {
let store = DelayingStore::new();
(
StoreConfig::new("delayed_store_event_cache_test".into())
.event_cache_store(store.clone()),
Some(store),
)
} else {
(StoreConfig::new("normal_store_event_cache_test".into()), None)
};
let bob = matrix_mock_server
.client_builder_for_crypto_end_to_end(bob_user_id, bob_device_id)
.on_builder(|builder| {
builder
.with_enable_share_history_on_invite(true)
.with_encryption_settings(encryption_settings)
.store_config(store_config)
})
.build()
.instrument(bob_span.clone())
.await;
bob.event_cache().subscribe().expect("Bob should be able to enable the event cache");
matrix_mock_server.exchange_e2ee_identities(&alice, &bob).await;
// Both clients sync the same joined, encrypted room.
let room_builder = JoinedRoomBuilder::new(room_id)
.add_state_event(StateTestEvent::Create)
.add_state_event(StateTestEvent::Encryption);
matrix_mock_server
.mock_sync()
.ok_and_run(&alice, |builder| {
builder.add_joined_room(room_builder.clone());
})
.instrument(alice_span)
.await;
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_joined_room(room_builder);
})
.instrument(bob_span)
.await;
(alice, bob, matrix_mock_server, store)
}
// Has Alice send a message into the room and captures what she sent out.
//
// Returns the captured room event together with the captured to-device
// event, which the tests feed to Bob as the room key.
async fn prepare_room(
matrix_mock_server: &MatrixMockServer,
event_factory: &EventFactory,
alice: &Client,
bob: &Client,
room_id: &RoomId,
) -> (Raw<AnySyncTimelineEvent>, Raw<ToDeviceEvent<ToDeviceEncryptedEventContent>>) {
let alice_user_id = alice.user_id().unwrap();
let bob_user_id = bob.user_id().unwrap();
let alice_member_event = event_factory.member(alice_user_id).into_raw();
let bob_member_event = event_factory.member(bob_user_id).into_raw();
let room = alice
.get_room(room_id)
.expect("Alice should have access to the room now that we synced");
let event_type = "m.room.message";
let content = json!({"body": "It's a secret to everybody", "msgtype": "m.text"});
let event_id = event_id!("$some_id");
// Capture the event Alice sends to the room, as well as the to-device
// message she sends out.
let (event_receiver, mock) =
matrix_mock_server.mock_room_send().ok_with_capture(event_id, alice_user_id);
let (_guard, room_key) = matrix_mock_server.mock_capture_put_to_device(alice_user_id).await;
{
// The send mock only stays mounted for the duration of this scope.
let _guard = mock.mock_once().mount_as_scoped().await;
matrix_mock_server
.mock_get_members()
.ok(vec![alice_member_event.clone(), bob_member_event.clone()])
.mock_once()
.mount()
.await;
room.send_raw(event_type, content)
.await
.expect("We should be able to send an initial message");
};
let event = event_receiver.await.expect("Alice should have sent the event by now");
let room_key = room_key.await;
(event, room_key)
}
// An event received before its room key is first cached as a UTD and gets
// replaced by its decrypted version once the room key arrives.
#[async_test]
async fn test_redecryptor() {
let room_id = room_id!("!test:localhost");
let event_factory = EventFactory::new().room(room_id);
let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, true, false).await;
let (event, room_key) =
prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
let (room_cache, _) = bob
.event_cache()
.for_room(room_id)
.await
.expect("We should be able to get to the event cache for a specific room");
let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
// NOTE(review): presumably regenerating the Olm machine here exercises the
// path where the redecryptor has to recreate its room key streams; confirm.
bob.inner
.base_client
.regenerate_olm(None)
.await
.expect("We should be able to regenerate the Olm machine");
// Bob receives the event without having the room key.
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
})
.await;
// The event is appended to the cache as a UTD.
assert_let_timeout!(
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Append { values });
assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
// Now the room key arrives as a to-device event.
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_to_device_event(
room_key
.deserialize_as()
.expect("We should be able to deserialize the room key"),
);
})
.await;
// The UTD at index 0 is replaced by the decrypted event.
assert_let_timeout!(
Duration::from_secs(1),
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Set { index, value });
assert_eq!(*index, 0);
assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
}
// An explicit request to refresh the encryption info of a session makes the
// redecryptor replace an already decrypted event once the sender's identity
// changed.
#[async_test]
async fn test_redecryptor_updating_encryption_info() {
let bob_span = tracing::info_span!("bob");
let room_id = room_id!("!test:localhost");
let event_factory = EventFactory::new().room(room_id);
// Alice starts out without cross-signing.
let (alice, bob, matrix_mock_server, _) = set_up_clients(room_id, false, false).await;
let (event, room_key) =
prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
let (room_cache, _) = bob
.event_cache()
.for_room(room_id)
.instrument(bob_span.clone())
.await
.expect("We should be able to get to the event cache for a specific room");
let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
// Bob receives the event without having the room key.
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
})
.instrument(bob_span.clone())
.await;
// The event is appended to the cache as a UTD.
assert_let_timeout!(
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Append { values });
assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
// Now the room key arrives as a to-device event.
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_to_device_event(
room_key
.deserialize_as()
.expect("We should be able to deserialize the room key"),
);
})
.instrument(bob_span.clone())
.await;
// The UTD is replaced by the decrypted event, which is unverified.
assert_let_timeout!(
Duration::from_secs(1),
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
let encryption_info = value.encryption_info().unwrap();
assert_matches!(&encryption_info.verification_state, VerificationState::Unverified(_));
// Remember the session so that a refresh of its encryption info can be
// requested below.
let session_id = encryption_info.session_id().unwrap().to_owned();
let alice_user_id = alice.user_id().unwrap();
// Alice now creates her cross-signing keys, and Bob learns about the change.
alice
.encryption()
.bootstrap_cross_signing(None)
.await
.expect("Alice should be able to create the cross-signing keys");
bob.update_tracked_users_for_testing([alice_user_id]).instrument(bob_span.clone()).await;
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_change_device(alice_user_id);
})
.instrument(bob_span.clone())
.await;
// Explicitly ask for the encryption info of the session to be refreshed.
bob.event_cache().request_decryption(DecryptionRetryRequest {
room_id: room_id.into(),
utd_session_ids: BTreeSet::new(),
refresh_info_session_ids: BTreeSet::from([session_id]),
});
// The event gets replaced once more, which shows that its encryption info
// changed; the verification state is still `Unverified`.
assert_let_timeout!(
Duration::from_secs(1),
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Set { index: 0, value });
assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
let encryption_info = value.encryption_info().unwrap();
assert_matches!(
&encryption_info.verification_state,
VerificationState::Unverified(_),
"The event should now know about the identity but still be unverified"
);
}
// The room key is synced while the event is still being persisted by the
// (delayed) store. The event must nevertheless end up decrypted once the
// storage operation completes.
#[async_test]
async fn test_event_is_redecrypted_even_if_key_arrives_while_event_processing() {
let room_id = room_id!("!test:localhost");
let event_factory = EventFactory::new().room(room_id);
// Bob uses the delaying store for his event cache.
let (alice, bob, matrix_mock_server, delayed_store) =
set_up_clients(room_id, true, true).await;
let delayed_store = delayed_store.unwrap();
let (event, room_key) =
prepare_room(&matrix_mock_server, &event_factory, &alice, &bob, room_id).await;
let (room_cache, _) = bob
.event_cache()
.for_room(room_id)
.await
.expect("We should be able to get to the event cache for a specific room");
let (_, mut subscriber) = room_cache.subscribe().await.unwrap();
// Bob receives the event without having the room key; the store is still
// delaying linked chunk updates at this point.
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_joined_room(JoinedRoomBuilder::new(room_id).add_timeline_event(event));
})
.await;
// The room key is synced before the delay is lifted.
matrix_mock_server
.mock_sync()
.ok_and_run(&bob, |builder| {
builder.add_to_device_event(
room_key
.deserialize_as()
.expect("We should be able to deserialize the room key"),
);
})
.await;
// Let the storage operation go through and wait for it to complete.
info!("Stopping the delay");
delayed_store.stop_delaying().await;
// First the event shows up as a UTD...
assert_let_timeout!(
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Append { values });
assert_matches!(&values[0].kind, TimelineEventKind::UnableToDecrypt { .. });
// ...and then it gets replaced by its decrypted version.
assert_let_timeout!(
Duration::from_secs(1),
Ok(RoomEventCacheUpdate::UpdateTimelineEvents { diffs, .. }) = subscriber.recv()
);
assert_eq!(diffs.len(), 1);
assert_matches!(&diffs[0], VectorDiff::Set { index, value });
assert_eq!(*index, 0);
assert_matches!(&value.kind, TimelineEventKind::Decrypted { .. });
}
}