use std::{collections::BTreeSet, sync::Arc};
use futures_core::Stream;
use futures_util::pin_mut;
use imbl::Vector;
use itertools::{Either, Itertools as _};
use matrix_sdk::{
Room,
encryption::backups::BackupState,
event_cache::RedecryptorReport,
executor::{JoinHandle, spawn},
};
use tokio_stream::{StreamExt as _, wrappers::errors::BroadcastStreamRecvError};
use crate::timeline::{TimelineController, TimelineItem};
/// Join handles of the background tasks started by `spawn_crypto_tasks`.
///
/// The `Drop` implementation of this type aborts every handle stored here, so
/// the tasks stop when this value is dropped.
#[derive(Debug)]
pub(in crate::timeline) struct CryptoDropHandles {
/// Handle of the task running `redecryption_report_task`.
redecryption_report_join_handle: JoinHandle<()>,
/// Handle of the task running `backup_states_task`.
room_key_backup_enabled_join_handle: JoinHandle<()>,
/// Handle of the task awaiting
/// `TimelineController::handle_encryption_state_changes`.
encryption_changes_handle: JoinHandle<()>,
}
impl Drop for CryptoDropHandles {
    /// Aborts all the background tasks tracked by this value.
    fn drop(&mut self) {
        // Destructure without `..` so that a newly added handle that is not
        // aborted here is a compile error rather than a silently leaked task.
        let Self {
            redecryption_report_join_handle,
            room_key_backup_enabled_join_handle,
            encryption_changes_handle,
        } = self;

        redecryption_report_join_handle.abort();
        room_key_backup_enabled_join_handle.abort();
        encryption_changes_handle.abort();
    }
}
/// Collects the session IDs referenced by the event items of a timeline,
/// split by whether the item is still an unable-to-decrypt (UTD) event.
///
/// Items that are not events, and events for which no session ID can be
/// found, are ignored. For each remaining event the session ID is taken from
/// its encryption info when available, and otherwise from its
/// unable-to-decrypt content.
///
/// # Returns
///
/// A pair of sets `(utd_sessions, other_sessions)`:
///
/// * the first set holds the session IDs of events whose content is
///   unable-to-decrypt;
/// * the second set holds the session IDs of all other events that carry a
///   session ID.
///
/// The same session ID may appear in both sets. Since the results are sets,
/// each ID is reported at most once per side.
pub(super) fn compute_redecryption_candidates(
    timeline_items: &Vector<Arc<TimelineItem>>,
) -> (BTreeSet<String>, BTreeSet<String>) {
    timeline_items
        .iter()
        .filter_map(|item| {
            // Virtual items (read marker, date dividers, ...) have no session.
            let event = item.as_event()?;

            // Prefer the session recorded in the encryption info and only
            // fall back to the one in the UTD content when it is missing.
            let session_id = event
                .encryption_info()
                .and_then(|info| info.session_id())
                .or_else(|| event.content.as_unable_to_decrypt().and_then(|utd| utd.session_id()))?;

            Some((session_id.to_owned(), event))
        })
        .partition_map(|(session_id, event)| {
            if event.content.is_unable_to_decrypt() {
                Either::Left(session_id)
            } else {
                Either::Right(session_id)
            }
        })
}
/// Long-running task reacting to the decryption reports published by the
/// event cache of the controller's client.
///
/// * When a report lists resolved UTDs, every listed event ID is forwarded to
///   the unable-to-decrypt hook stored in the timeline state's metadata, if
///   one is set.
/// * When the report is `Lagging`, or the report stream yields an error, the
///   controller is asked to retry event decryption (with `None` as argument).
///
/// The task ends when the report stream ends.
async fn redecryption_report_task(timeline_controller: TimelineController) {
    let client = timeline_controller.room().client();
    let reports = client.event_cache().subscribe_to_decryption_reports();
    pin_mut!(reports);

    while let Some(report) = reports.next().await {
        let resolved_events = match report {
            Ok(RedecryptorReport::ResolvedUtds { events, .. }) => events,
            Ok(RedecryptorReport::Lagging) | Err(_) => {
                // We may have missed reports: retry decryption instead of
                // relying on the individual notifications.
                timeline_controller.retry_event_decryption(None).await;
                continue;
            }
        };

        // The read guard is released at the end of this iteration, before the
        // next report is awaited.
        let state = timeline_controller.state.read().await;
        if let Some(utd_hook) = &state.meta.unable_to_decrypt_hook {
            for event_id in resolved_events {
                utd_hook.on_late_decrypt(&event_id).await;
            }
        }
    }
}
/// Long-running task that asks the controller to retry event decryption
/// whenever the backup state stream reports `BackupState::Enabled` or yields
/// an error.
///
/// Every other backup state is ignored. The task ends when the stream ends.
async fn backup_states_task<S>(backup_states_stream: S, timeline_controller: TimelineController)
where
    S: Stream<Item = Result<BackupState, BroadcastStreamRecvError>>,
{
    pin_mut!(backup_states_stream);

    while let Some(update) = backup_states_stream.next().await {
        // All variants are spelled out (no wildcard) so that a new
        // `BackupState` variant has to be classified explicitly.
        let should_retry = match update {
            Ok(BackupState::Enabled) | Err(_) => true,
            Ok(
                BackupState::Unknown
                | BackupState::Creating
                | BackupState::Resuming
                | BackupState::Disabling
                | BackupState::Downloading
                | BackupState::Enabling,
            ) => false,
        };

        if should_retry {
            timeline_controller.retry_event_decryption(None).await;
        }
    }
}
/// Spawns the crypto-related background tasks of a timeline and returns
/// their handles.
///
/// Three tasks are spawned, in this order:
///
/// 1. `backup_states_task`, fed with the backup state stream of the room's
///    client;
/// 2. `redecryption_report_task`;
/// 3. a task awaiting `TimelineController::handle_encryption_state_changes`.
///
/// The returned [`CryptoDropHandles`] aborts all of them when dropped. Note
/// that, although this function is `async`, its body does not await anything.
pub(in crate::timeline) async fn spawn_crypto_tasks(
    room: Room,
    controller: TimelineController,
) -> CryptoDropHandles {
    let client = room.client();

    let backup_states = client.encryption().backups().state_stream();
    let room_key_backup_enabled_join_handle =
        spawn(backup_states_task(backup_states, controller.clone()));

    let redecryption_report_join_handle = spawn(redecryption_report_task(controller.clone()));

    // The last task takes ownership of the controller itself.
    let encryption_changes_handle =
        spawn(async move { controller.handle_encryption_state_changes().await });

    CryptoDropHandles {
        redecryption_report_join_handle,
        room_key_backup_enabled_join_handle,
        encryption_changes_handle,
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::BTreeMap, sync::Arc, time::SystemTime};

    use imbl::vector;
    use matrix_sdk::deserialized_responses::{AlgorithmInfo, EncryptionInfo, VerificationState};
    use matrix_sdk_base::crypto::types::events::UtdCause;
    use ruma::{
        MilliSecondsSinceUnixEpoch, OwnedTransactionId,
        events::room::{
            encrypted::{
                EncryptedEventScheme, MegolmV1AesSha2Content, MegolmV1AesSha2ContentInit,
                RoomEncryptedEventContent,
            },
            message::RoomMessageEventContent,
        },
        owned_device_id, owned_event_id, owned_user_id,
    };

    use crate::timeline::{
        EncryptedMessage, EventSendState, EventTimelineItem, MsgLikeContent,
        ReactionsByKeyBySender, TimelineDetails, TimelineItem, TimelineItemContent,
        TimelineItemKind, TimelineUniqueId, VirtualTimelineItem,
        controller::decryption_retry_task::compute_redecryption_candidates,
        event_item::{
            EventTimelineItemKind, LocalEventTimelineItem, RemoteEventOrigin,
            RemoteEventTimelineItem,
        },
    };

    /// Virtual items (read marker, date divider) yield no candidate session.
    #[test]
    fn test_non_events_are_not_retried() {
        let timeline = vector![TimelineItem::read_marker(), date_divider()];

        let (utd_sessions, decrypted_sessions) = compute_redecryption_candidates(&timeline);

        assert!(utd_sessions.is_empty());
        assert!(decrypted_sessions.is_empty());
    }

    /// A local event carries no session ID, so it yields no candidate.
    #[test]
    fn test_non_remote_events_are_not_retried() {
        let timeline = vector![local_event()];

        let (utd_sessions, decrypted_sessions) = compute_redecryption_candidates(&timeline);

        assert!(utd_sessions.is_empty());
        assert!(decrypted_sessions.is_empty());
    }

    /// The session of a UTD event lands in the first returned set only.
    #[test]
    fn test_utds_are_retried() {
        let timeline = vector![utd_event("session1")];

        let (utd_sessions, decrypted_sessions) = compute_redecryption_candidates(&timeline);

        assert_eq!(utd_sessions.first().map(String::as_str), Some("session1"));
        assert!(decrypted_sessions.is_empty());
    }

    /// The session of a decrypted event lands in the second returned set only.
    #[test]
    fn test_remote_decrypted_info_is_refetched() {
        let timeline = vector![decrypted_event("session1")];

        let (utd_sessions, decrypted_sessions) = compute_redecryption_candidates(&timeline);

        assert!(utd_sessions.is_empty());
        assert_eq!(decrypted_sessions.first().map(String::as_str), Some("session1"));
    }

    /// In a mixed timeline, every session is reported on the side(s) where it
    /// is used, whatever the number of events using it.
    #[test]
    fn test_only_required_sessions_are_retried() {
        let timeline = vector![
            TimelineItem::read_marker(),
            utd_event("session1"),
            utd_event("session1"),
            date_divider(),
            utd_event("session2"),
            decrypted_event("session1"),
            decrypted_event("session1"),
            decrypted_event("session2"),
            local_event(),
        ];

        let (utd_sessions, decrypted_sessions) = compute_redecryption_candidates(&timeline);

        for session in ["session1", "session2"] {
            assert!(utd_sessions.contains(session));
        }
        for session in ["session1", "session2"] {
            assert!(decrypted_sessions.contains(session));
        }
    }

    /// Builds a virtual date divider item.
    fn date_divider() -> Arc<TimelineItem> {
        let divider = VirtualTimelineItem::DateDivider(timestamp());
        TimelineItem::new(
            TimelineItemKind::Virtual(divider),
            TimelineUniqueId("datething".to_owned()),
        )
    }

    /// Builds a local (not yet sent) event item with redacted content.
    fn local_event() -> Arc<TimelineItem> {
        let kind = EventTimelineItemKind::Local(LocalEventTimelineItem {
            send_state: EventSendState::NotSentYet { progress: None },
            transaction_id: OwnedTransactionId::from("trans"),
            send_handle: None,
        });

        event_item(TimelineItemContent::MsgLike(MsgLikeContent::redacted()), kind)
    }

    /// Builds a remote event item that could not be decrypted and whose
    /// encrypted content references `session_id`.
    fn utd_event(session_id: &str) -> Arc<TimelineItem> {
        let megolm = MegolmV1AesSha2Content::from(MegolmV1AesSha2ContentInit {
            ciphertext: "cyf".to_owned(),
            sender_key: "sendk".to_owned(),
            device_id: owned_device_id!("DEV"),
            session_id: session_id.to_owned(),
        });
        let encrypted =
            RoomEncryptedEventContent::new(EncryptedEventScheme::MegolmV1AesSha2(megolm), None);
        let content = TimelineItemContent::MsgLike(MsgLikeContent::unable_to_decrypt(
            EncryptedMessage::from_content(encrypted, UtdCause::Unknown),
        ));

        // A UTD has no encryption info: the session is only in its content.
        event_item(content, remote_kind(None))
    }

    /// Builds a remote, decrypted text message whose encryption info
    /// references `session_id`.
    fn decrypted_event(session_id: &str) -> Arc<TimelineItem> {
        let encryption_info = EncryptionInfo {
            sender: owned_user_id!("@u:s.co"),
            sender_device: None,
            algorithm_info: AlgorithmInfo::MegolmV1AesSha2 {
                curve25519_key: "".to_owned(),
                sender_claimed_keys: BTreeMap::new(),
                session_id: Some(session_id.to_owned()),
            },
            verification_state: VerificationState::Verified,
        };

        let message = RoomMessageEventContent::text_plain("hi");
        let content = TimelineItemContent::message(
            message.msgtype,
            message.mentions,
            ReactionsByKeyBySender::default(),
            None,
            None,
            None,
        );

        event_item(content, remote_kind(Some(Arc::new(encryption_info))))
    }

    /// Builds the remote event kind shared by the remote fixtures; only the
    /// encryption info differs between them.
    fn remote_kind(encryption_info: Option<Arc<EncryptionInfo>>) -> EventTimelineItemKind {
        EventTimelineItemKind::Remote(RemoteEventTimelineItem {
            event_id: owned_event_id!("$local"),
            transaction_id: None,
            read_receipts: Default::default(),
            is_own: false,
            is_highlighted: false,
            encryption_info,
            original_json: None,
            latest_edit_json: None,
            origin: RemoteEventOrigin::Sync,
        })
    }

    /// Wraps `content` and `kind` into an event timeline item, using the
    /// sender, pending profile, timestamp and unique ID common to all the
    /// event fixtures of this module.
    fn event_item(content: TimelineItemContent, kind: EventTimelineItemKind) -> Arc<TimelineItem> {
        let event = EventTimelineItem::new(
            owned_user_id!("@u:s.to"),
            TimelineDetails::Pending,
            timestamp(),
            content,
            kind,
            true,
        );

        TimelineItem::new(TimelineItemKind::Event(event), TimelineUniqueId("local".to_owned()))
    }

    /// Returns the Unix epoch as a Matrix timestamp.
    fn timestamp() -> MilliSecondsSinceUnixEpoch {
        MilliSecondsSinceUnixEpoch::from_system_time(SystemTime::UNIX_EPOCH).unwrap()
    }
}