use std::{
cmp::max,
collections::{BTreeMap, BTreeSet},
fmt,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, RwLock as StdRwLock,
},
time::Duration,
};
use ruma::{
events::{
room::{encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility},
AnyMessageLikeEventContent,
},
serde::Raw,
DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId,
SecondsSinceUnixEpoch, TransactionId, UserId,
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, error, info};
use vodozemac::{megolm::SessionConfig, Curve25519PublicKey};
pub use vodozemac::{
megolm::{GroupSession, GroupSessionPickle, MegolmMessage, SessionKey},
olm::IdentityKeys,
PickleError,
};
use super::SessionCreationError;
#[cfg(feature = "experimental-algorithms")]
use crate::types::events::room::encrypted::MegolmV2AesSha2Content;
use crate::{
types::{
events::{
room::encrypted::{
MegolmV1AesSha2Content, RoomEncryptedEventContent, RoomEventEncryptionScheme,
},
room_key::{MegolmV1AesSha2Content as MegolmV1AesSha2RoomKeyContent, RoomKeyContent},
room_key_withheld::{RoomKeyWithheldContent, WithheldCode},
},
EventEncryptionAlgorithm,
},
ReadOnlyDevice, ToDeviceRequest,
};
/// Default time after which an outbound session is rotated: one week.
const ROTATION_PERIOD: Duration = Duration::from_secs(7 * 24 * 60 * 60);
/// Default number of encrypted messages after which an outbound session is rotated.
const ROTATION_MESSAGES: u64 = 100;
/// Whether, and how, the room key of an outbound session reached a given device.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ShareState {
    /// No share is recorded for the device, or the key was withheld from it.
    NotShared,
    /// A share is recorded, but the device's Curve25519 key no longer matches
    /// the key recorded at share time.
    SharedButChangedSenderKey,
    /// A share is recorded at the contained Megolm message index.
    Shared(u32),
}
/// Settings that control how outbound Megolm sessions of a room are created
/// and when they are rotated.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EncryptionSettings {
    /// Encryption algorithm used for the room's events.
    pub algorithm: EventEncryptionAlgorithm,
    /// How long a session may be used before it must be rotated.
    ///
    /// `OutboundGroupSession::elapsed` never uses a value below one hour.
    pub rotation_period: Duration,
    /// How many messages a session may encrypt before it must be rotated.
    ///
    /// `OutboundGroupSession::expired` clamps this to `1..=10_000`.
    pub rotation_period_msgs: u64,
    /// History visibility of the room these settings belong to.
    pub history_visibility: HistoryVisibility,
    /// Whether room keys should only be shared with trusted devices. This flag
    /// is only stored here; it is acted on outside this file.
    ///
    /// Falls back to `false` when missing from serialized data.
    #[serde(default)]
    pub only_allow_trusted_devices: bool,
}
impl Default for EncryptionSettings {
    /// Megolm v1 with the default rotation policy (`ROTATION_PERIOD`,
    /// `ROTATION_MESSAGES`), shared history visibility, and no restriction to
    /// trusted devices.
    fn default() -> Self {
        Self {
            only_allow_trusted_devices: false,
            history_visibility: HistoryVisibility::Shared,
            rotation_period_msgs: ROTATION_MESSAGES,
            rotation_period: ROTATION_PERIOD,
            algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
        }
    }
}
impl EncryptionSettings {
    /// Build settings from a room's encryption event content.
    ///
    /// Rotation parameters missing from `content` fall back to
    /// `ROTATION_PERIOD` and `ROTATION_MESSAGES` respectively.
    pub fn new(
        content: RoomEncryptionEventContent,
        history_visibility: HistoryVisibility,
        only_allow_trusted_devices: bool,
    ) -> Self {
        let rotation_period: Duration = match content.rotation_period_ms {
            Some(millis) => Duration::from_millis(millis.into()),
            None => ROTATION_PERIOD,
        };

        let rotation_period_msgs: u64 = match content.rotation_period_msgs {
            Some(count) => count.into(),
            None => ROTATION_MESSAGES,
        };

        let algorithm = EventEncryptionAlgorithm::from(content.algorithm.as_str());

        Self {
            algorithm,
            rotation_period,
            rotation_period_msgs,
            history_visibility,
            only_allow_trusted_devices,
        }
    }
}
/// An outbound Megolm session used to encrypt events for a single room.
///
/// Every piece of mutable state sits behind an `Arc`. Clones are cheap and
/// share the same session, counters, flags and share bookkeeping.
#[derive(Clone)]
pub struct OutboundGroupSession {
    /// The underlying vodozemac group session.
    inner: Arc<RwLock<GroupSession>>,
    /// Our device ID, embedded in encrypted (v1) and withheld content.
    device_id: OwnedDeviceId,
    /// Our account's identity keys; the Curve25519 key is the sender key.
    account_identity_keys: Arc<IdentityKeys>,
    /// Cached session ID of `inner`.
    session_id: Arc<str>,
    /// The room this session encrypts events for.
    room_id: OwnedRoomId,
    /// Creation timestamp, used for time-based rotation.
    pub(crate) creation_time: SecondsSinceUnixEpoch,
    /// Number of messages encrypted so far, used for count-based rotation.
    message_count: Arc<AtomicU64>,
    /// Set by `mark_as_shared`, e.g. once no to-device request is pending anymore.
    shared: Arc<AtomicBool>,
    /// Set by `invalidate_session`.
    invalidated: Arc<AtomicBool>,
    /// Settings this session was created with or restored from.
    settings: Arc<EncryptionSettings>,
    /// Share info of devices whose to-device requests were marked as sent.
    pub(crate) shared_with_set: Arc<StdRwLock<ShareInfoSet>>,
    /// Requests added by `add_request` that are not marked as sent yet, keyed
    /// by transaction ID, with the share info recorded once they are sent.
    // A map of (request, share info) tuples is inherently nested; an extra
    // alias would not make it easier to read.
    #[allow(clippy::type_complexity)]
    to_share_with_set:
        Arc<StdRwLock<BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>>>,
}
/// Per-user, per-device record of how a room key was (or is going to be) shared.
pub type ShareInfoSet = BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>;
/// Outcome of distributing a room key to one device.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ShareInfo {
    /// The key was shared with the device.
    Shared(SharedWith),
    /// The key was withheld from the device for the contained reason.
    Withheld(WithheldCode),
}
impl ShareInfo {
    /// Record a share with a device whose Curve25519 key is `sender_key`, at
    /// Megolm message index `message_index`.
    pub fn new_shared(sender_key: Curve25519PublicKey, message_index: u32) -> Self {
        let shared_with = SharedWith { sender_key, message_index };
        Self::Shared(shared_with)
    }

    /// Record that the key was withheld from a device for the reason `code`.
    pub fn new_withheld(code: WithheldCode) -> Self {
        Self::Withheld(code)
    }
}
/// Details recorded when a room key was shared with a device.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SharedWith {
    /// Curve25519 key of the receiving device at share time. It is compared
    /// with the device's current key to detect key changes.
    pub sender_key: Curve25519PublicKey,
    /// Megolm message index at which the key was shared.
    pub message_index: u32,
}
impl OutboundGroupSession {
/// Map an encryption algorithm to the vodozemac session configuration for it.
///
/// Megolm v1 is always supported. Megolm v2 is only supported with the
/// `experimental-algorithms` feature.
///
/// # Errors
///
/// Returns `SessionCreationError::Algorithm`, carrying the algorithm, for
/// anything else.
pub(super) fn session_config(
    algorithm: &EventEncryptionAlgorithm,
) -> Result<SessionConfig, SessionCreationError> {
    let config = match algorithm {
        EventEncryptionAlgorithm::MegolmV1AesSha2 => SessionConfig::version_1(),
        #[cfg(feature = "experimental-algorithms")]
        EventEncryptionAlgorithm::MegolmV2AesSha2 => SessionConfig::version_2(),
        unsupported => return Err(SessionCreationError::Algorithm(unsupported.to_owned())),
    };

    Ok(config)
}
pub fn new(
device_id: OwnedDeviceId,
identity_keys: Arc<IdentityKeys>,
room_id: &RoomId,
settings: EncryptionSettings,
) -> Result<Self, SessionCreationError> {
let config = Self::session_config(&settings.algorithm)?;
let session = GroupSession::new(config);
let session_id = session.session_id();
Ok(OutboundGroupSession {
inner: RwLock::new(session).into(),
room_id: room_id.into(),
device_id,
account_identity_keys: identity_keys,
session_id: session_id.into(),
creation_time: SecondsSinceUnixEpoch::now(),
message_count: Arc::new(AtomicU64::new(0)),
shared: Arc::new(AtomicBool::new(false)),
invalidated: Arc::new(AtomicBool::new(false)),
settings: Arc::new(settings),
shared_with_set: Default::default(),
to_share_with_set: Default::default(),
})
}
/// Start tracking a to-device request that carries this session's key or a
/// withheld notice. It stays tracked until `mark_request_as_sent` is called.
///
/// `share_infos` is moved into the sent bookkeeping once the request is
/// marked as sent. A request already tracked under the same `request_id` is
/// replaced.
pub fn add_request(
    &self,
    request_id: OwnedTransactionId,
    request: Arc<ToDeviceRequest>,
    share_infos: ShareInfoSet,
) {
    let mut pending = self.to_share_with_set.write().unwrap();
    pending.insert(request_id, (request, share_infos));
}
/// Build the room-key-withheld content telling recipients that this
/// session's key is withheld from them for the reason `code`.
pub fn withheld_code(&self, code: WithheldCode) -> RoomKeyWithheldContent {
    let algorithm = self.settings.algorithm.clone();
    let room_id = self.room_id.clone();
    let session_id = self.session_id().to_owned();
    let sender_key = self.sender_key();
    let device_id = self.device_id.clone();

    RoomKeyWithheldContent::new(algorithm, code, room_id, session_id, sender_key, device_id)
}
/// Flag this session, and every clone of it, as invalidated; see `invalidated`.
pub fn invalidate_session(&self) {
    let flag = &self.invalidated;
    flag.store(true, Ordering::Relaxed);
}
/// The encryption settings of this session.
pub fn settings(&self) -> &EncryptionSettings {
    &*self.settings
}
/// Mark the pending to-device request `request_id` as sent.
///
/// The share info stored with the request (see `add_request`) is merged into
/// `shared_with_set`. If no request is pending afterwards, the session is
/// marked as shared. An unknown `request_id` is logged as an error, and an
/// empty map is returned.
///
/// Returns, for every user the request was addressed to, the devices recorded
/// as `WithheldCode::NoOlm`. The set may be empty.
pub fn mark_request_as_sent(
    &self,
    request_id: &TransactionId,
) -> BTreeMap<OwnedUserId, BTreeSet<OwnedDeviceId>> {
    let removed = self.to_share_with_set.write().unwrap().remove(request_id);

    let Some((to_device, request)) = removed else {
        let request_ids: Vec<String> =
            self.to_share_with_set.read().unwrap().keys().map(|k| k.to_string()).collect();
        error!(
            all_request_ids = ?request_ids,
            request_id = ?request_id,
            "Marking to-device request carrying a room key as sent but no \
             request found with the given id"
        );
        return BTreeMap::new();
    };

    let recipients: BTreeMap<&UserId, BTreeSet<&DeviceId>> = request
        .iter()
        .map(|(user_id, devices)| {
            (user_id.as_ref(), devices.keys().map(|id| id.as_ref()).collect())
        })
        .collect();
    info!(
        ?request_id,
        ?recipients,
        ?to_device.event_type,
        "Marking to-device request carrying a room key or a withheld as sent"
    );

    let mut no_olm_devices = BTreeMap::new();

    for (user_id, devices) in request {
        let no_olm: BTreeSet<OwnedDeviceId> = devices
            .iter()
            .filter_map(|(device_id, info)| {
                matches!(info, ShareInfo::Withheld(WithheldCode::NoOlm))
                    .then(|| device_id.clone())
            })
            .collect();
        no_olm_devices.insert(user_id.clone(), no_olm);

        // Move this user's share info into the "sent" bookkeeping.
        self.shared_with_set.write().unwrap().entry(user_id).or_default().extend(devices);
    }

    if self.to_share_with_set.read().unwrap().is_empty() {
        debug!(
            session_id = self.session_id(),
            room_id = ?self.room_id,
            "All m.room_key and withheld to-device requests were sent out, marking \
             session as shared.",
        );
        self.mark_as_shared();
    }

    no_olm_devices
}
/// Encrypt `plaintext` with the Megolm session and count the message.
///
/// The message counter is bumped while the session's write lock is held.
pub(crate) async fn encrypt_helper(&self, plaintext: String) -> MegolmMessage {
    let mut group_session = self.inner.write().await;
    self.message_count.fetch_add(1, Ordering::SeqCst);
    let message = group_session.encrypt(&plaintext);
    message
}
/// Encrypt a message-like event of type `event_type` with this session.
///
/// The plaintext is a JSON object holding the event `type`, its `content` and
/// this session's `room_id`. If `content` has an `m.relates_to` field, that
/// field is also placed in the outer encrypted content, next to the
/// ciphertext. Every call counts towards the message-based rotation limit.
///
/// # Panics
///
/// Hits `unreachable!` if the settings' algorithm is neither Megolm v1 nor,
/// with `experimental-algorithms`, Megolm v2. `new` rejects other algorithms.
/// NOTE(review): `from_pickle` does not re-check the pickled settings.
pub async fn encrypt(
    &self,
    event_type: &str,
    content: &Raw<AnyMessageLikeEventContent>,
) -> Raw<RoomEncryptedEventContent> {
    // Field order determines the JSON key order of the plaintext.
    #[derive(Serialize)]
    struct Payload<'a> {
        #[serde(rename = "type")]
        event_type: &'a str,
        content: &'a Raw<AnyMessageLikeEventContent>,
        room_id: &'a RoomId,
    }

    let plaintext = serde_json::to_string(&Payload { event_type, content, room_id: &self.room_id })
        .expect("payload serialization never fails");

    let relates_to = content
        .get_field::<serde_json::Value>("m.relates_to")
        .expect("serde_json::Value deserialization with valid JSON input never fails");

    let ciphertext = self.encrypt_helper(plaintext).await;

    let scheme: RoomEventEncryptionScheme = match self.settings.algorithm {
        EventEncryptionAlgorithm::MegolmV1AesSha2 => {
            let v1 = MegolmV1AesSha2Content {
                ciphertext,
                sender_key: self.sender_key(),
                session_id: self.session_id().to_owned(),
                device_id: self.device_id.clone(),
            };
            v1.into()
        }
        #[cfg(feature = "experimental-algorithms")]
        EventEncryptionAlgorithm::MegolmV2AesSha2 => {
            let v2 = MegolmV2AesSha2Content { ciphertext, session_id: self.session_id().to_owned() };
            v2.into()
        }
        _ => unreachable!(
            "An outbound group session is always using one of the supported algorithms"
        ),
    };

    let encrypted = RoomEncryptedEventContent { scheme, relates_to, other: Default::default() };
    Raw::new(&encrypted).expect("m.room.encrypted event content can always be serialized")
}
/// Whether this session is old enough to be rotated.
///
/// The configured rotation period is never allowed to drop below one hour. If
/// the clock reads earlier than `creation_time`, the session counts as elapsed.
fn elapsed(&self) -> bool {
    const MINIMUM_ROTATION_PERIOD: Duration = Duration::from_secs(60 * 60);

    let created_at = Duration::from_secs(self.creation_time.get().into());
    let now = Duration::from_secs(SecondsSinceUnixEpoch::now().get().into());
    let rotation_period = max(self.settings.rotation_period, MINIMUM_ROTATION_PERIOD);

    match now.checked_sub(created_at) {
        Some(age) => age >= rotation_period,
        None => true,
    }
}
/// Whether this session should be rotated. That is the case when it has
/// encrypted too many messages or has become too old (see `elapsed`).
///
/// The configured message limit is clamped to `1..=10_000`. Every session
/// may therefore encrypt at least one message and never more than 10 000.
pub fn expired(&self) -> bool {
    let limit = self.settings.rotation_period_msgs.clamp(1, 10_000);
    let encrypted_so_far = self.message_count.load(Ordering::SeqCst);

    encrypted_so_far >= limit || self.elapsed()
}
/// Whether this session was flagged via `invalidate_session`, or restored
/// from a pickle carrying that flag.
pub fn invalidated(&self) -> bool {
    let flag = &self.invalidated;
    flag.load(Ordering::Relaxed)
}
/// Flag this session, and every clone of it, as shared.
pub fn mark_as_shared(&self) {
    let flag = &self.shared;
    flag.store(true, Ordering::Relaxed);
}
/// Whether this session has been flagged as shared (see `mark_as_shared`).
pub fn shared(&self) -> bool {
    let flag = &self.shared;
    flag.load(Ordering::Relaxed)
}
/// The current session key of the underlying Megolm session.
pub async fn session_key(&self) -> SessionKey {
    self.inner.read().await.session_key()
}
pub fn sender_key(&self) -> Curve25519PublicKey {
self.account_identity_keys.as_ref().curve25519.to_owned()
}
/// The room this session encrypts events for.
pub fn room_id(&self) -> &RoomId {
    self.room_id.as_ref()
}
/// The (cached) ID of the underlying Megolm session.
pub fn session_id(&self) -> &str {
    &*self.session_id
}
/// The current message index of the underlying Megolm session.
pub async fn message_index(&self) -> u32 {
    self.inner.read().await.message_index()
}
/// Build the room-key content that shares this session's current key.
///
/// NOTE(review): this always produces Megolm v1 room-key content, whatever
/// `settings.algorithm` is.
pub(crate) async fn as_content(&self) -> RoomKeyContent {
    let session_key = self.session_key().await;
    let megolm_v1 = MegolmV1AesSha2RoomKeyContent::new(
        self.room_id.clone(),
        self.session_id().to_owned(),
        session_key,
    );

    RoomKeyContent::MegolmV1AesSha2(megolm_v1.into())
}
pub(crate) fn is_shared_with(&self, device: &ReadOnlyDevice) -> ShareState {
let shared_state =
self.shared_with_set.read().unwrap().get(device.user_id()).and_then(|d| {
d.get(device.device_id()).map(|s| match s {
ShareInfo::Shared(s) => {
if device.curve25519_key() == Some(s.sender_key) {
ShareState::Shared(s.message_index)
} else {
ShareState::SharedButChangedSenderKey
}
}
ShareInfo::Withheld(_) => ShareState::NotShared,
})
});
if let Some(state) = shared_state {
state
} else {
let shared =
self.to_share_with_set.read().unwrap().values().find_map(|(_, share_info)| {
let d = share_info.get(device.user_id())?;
let info = d.get(device.device_id())?;
Some(match info {
ShareInfo::Shared(info) => {
if device.curve25519_key() == Some(info.sender_key) {
ShareState::Shared(info.message_index)
} else {
ShareState::SharedButChangedSenderKey
}
}
ShareInfo::Withheld(_) => ShareState::NotShared,
})
});
shared.unwrap_or(ShareState::NotShared)
}
}
pub(crate) fn is_withheld_to(&self, device: &ReadOnlyDevice, code: &WithheldCode) -> bool {
self.shared_with_set
.read()
.unwrap()
.get(device.user_id())
.and_then(|d| {
let info = d.get(device.device_id())?;
Some(matches!(info, ShareInfo::Withheld(c) if c == code))
})
.unwrap_or_else(|| {
self.to_share_with_set.read().unwrap().values().any(|(_, share_info)| {
share_info
.get(device.user_id())
.and_then(|d| d.get(device.device_id()))
.is_some_and(|info| matches!(info, ShareInfo::Withheld(c) if c == code))
})
})
}
/// Test helper: record the key as shared with `user_id`/`device_id` at
/// message index `index`, with `sender_key` as the device's Curve25519 key.
#[cfg(test)]
pub fn mark_shared_with_from_index(
    &self,
    user_id: &UserId,
    device_id: &DeviceId,
    sender_key: Curve25519PublicKey,
    index: u32,
) {
    let info = ShareInfo::new_shared(sender_key, index);
    let mut shared_with = self.shared_with_set.write().unwrap();
    shared_with.entry(user_id.to_owned()).or_default().insert(device_id.to_owned(), info);
}
/// Test helper: record the key as shared with `user_id`/`device_id` at the
/// session's current message index.
#[cfg(test)]
pub async fn mark_shared_with(
    &self,
    user_id: &UserId,
    device_id: &DeviceId,
    sender_key: Curve25519PublicKey,
) {
    // Fetch the index before taking the (synchronous) bookkeeping lock so no
    // std guard is held across the await.
    let index = self.message_index().await;
    self.mark_shared_with_from_index(user_id, device_id, sender_key, index);
}
/// The to-device requests that have not been marked as sent yet.
pub(crate) fn pending_requests(&self) -> Vec<Arc<ToDeviceRequest>> {
    let pending = self.to_share_with_set.read().unwrap();
    pending.values().map(|(request, _)| Arc::clone(request)).collect()
}
/// Transaction IDs of the to-device requests that have not been marked as sent yet.
pub(crate) fn pending_request_ids(&self) -> Vec<OwnedTransactionId> {
    let pending = self.to_share_with_set.read().unwrap();
    pending.keys().map(ToOwned::to_owned).collect()
}
pub fn from_pickle(
device_id: OwnedDeviceId,
identity_keys: Arc<IdentityKeys>,
pickle: PickledOutboundGroupSession,
) -> Result<Self, PickleError> {
let inner: GroupSession = pickle.pickle.into();
let session_id = inner.session_id();
Ok(Self {
inner: Arc::new(RwLock::new(inner)),
device_id,
account_identity_keys: identity_keys,
session_id: session_id.into(),
room_id: pickle.room_id,
creation_time: pickle.creation_time,
message_count: AtomicU64::from(pickle.message_count).into(),
shared: AtomicBool::from(pickle.shared).into(),
invalidated: AtomicBool::from(pickle.invalidated).into(),
settings: pickle.settings,
shared_with_set: Arc::new(StdRwLock::new(pickle.shared_with_set)),
to_share_with_set: Arc::new(StdRwLock::new(pickle.requests)),
})
}
pub async fn pickle(&self) -> PickledOutboundGroupSession {
let pickle = self.inner.read().await.pickle();
PickledOutboundGroupSession {
pickle,
room_id: self.room_id.clone(),
settings: self.settings.clone(),
creation_time: self.creation_time,
message_count: self.message_count.load(Ordering::SeqCst),
shared: self.shared(),
invalidated: self.invalidated(),
shared_with_set: self.shared_with_set.read().unwrap().clone(),
requests: self.to_share_with_set.read().unwrap().clone(),
}
}
}
/// A pickled outbound group session held as a plain string.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OutboundGroupSessionPickle(String);
impl From<String> for OutboundGroupSessionPickle {
fn from(p: String) -> Self {
Self(p)
}
}
/// Debug output lists identifying metadata only. The Megolm session, the
/// identity keys, the settings and the share bookkeeping are left out.
#[cfg(not(tarpaulin_include))]
impl std::fmt::Debug for OutboundGroupSession {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("OutboundGroupSession");
        out.field("session_id", &self.session_id);
        out.field("room_id", &self.room_id);
        out.field("creation_time", &self.creation_time);
        out.field("message_count", &self.message_count);
        out.finish()
    }
}
/// Serializable snapshot of an `OutboundGroupSession`. It is produced by
/// `OutboundGroupSession::pickle` and consumed by
/// `OutboundGroupSession::from_pickle`.
#[derive(Deserialize, Serialize)]
// No Debug impl, presumably so the session pickle is never printed.
// NOTE(review): confirm this is the intent.
#[allow(missing_debug_implementations)]
pub struct PickledOutboundGroupSession {
    /// The pickled vodozemac group session.
    pub pickle: GroupSessionPickle,
    /// Settings of the session.
    pub settings: Arc<EncryptionSettings>,
    /// Room the session encrypts events for.
    pub room_id: OwnedRoomId,
    /// Creation timestamp of the session.
    pub creation_time: SecondsSinceUnixEpoch,
    /// Number of messages encrypted so far.
    pub message_count: u64,
    /// Whether the session was flagged as shared.
    pub shared: bool,
    /// Whether the session was flagged as invalidated.
    pub invalidated: bool,
    /// Share info of requests that were marked as sent.
    pub shared_with_set: ShareInfoSet,
    /// Pending to-device requests with the share info they will record.
    pub requests: BTreeMap<OwnedTransactionId, (Arc<ToDeviceRequest>, ShareInfoSet)>,
}
#[cfg(test)]
mod tests {
    use std::{sync::atomic::Ordering, time::Duration};

    use matrix_sdk_test::async_test;
    use ruma::{
        device_id,
        events::room::{
            encryption::RoomEncryptionEventContent, history_visibility::HistoryVisibility,
            message::RoomMessageEventContent,
        },
        room_id, uint, user_id, EventEncryptionAlgorithm,
    };

    use super::{EncryptionSettings, ROTATION_MESSAGES, ROTATION_PERIOD};
    use crate::{Account, MegolmError};

    #[test]
    fn test_encryption_settings_conversion() {
        // Missing rotation parameters fall back to the defaults.
        let mut content =
            RoomEncryptionEventContent::new(EventEncryptionAlgorithm::MegolmV1AesSha2);
        let defaults = EncryptionSettings::new(content.clone(), HistoryVisibility::Joined, false);
        assert_eq!(defaults.rotation_period, ROTATION_PERIOD);
        assert_eq!(defaults.rotation_period_msgs, ROTATION_MESSAGES);

        // Explicit rotation parameters are taken over verbatim.
        content.rotation_period_ms = Some(uint!(3600));
        content.rotation_period_msgs = Some(uint!(500));
        let custom = EncryptionSettings::new(content, HistoryVisibility::Shared, false);
        assert_eq!(custom.rotation_period, Duration::from_millis(3600));
        assert_eq!(custom.rotation_period_msgs, 500);
    }

    #[async_test]
    #[cfg(any(target_os = "linux", target_os = "macos", target_arch = "wasm32"))]
    async fn test_expiration() -> Result<(), MegolmError> {
        use ruma::{serde::Raw, SecondsSinceUnixEpoch};

        let account =
            Account::with_device_id(user_id!("@alice:example.org"), device_id!("DEVICEID"))
                .static_data;
        let room = room_id!("!test_room:example.org");
        let message: Raw<ruma::events::AnyMessageLikeEventContent> =
            Raw::new(&RoomMessageEventContent::text_plain("Test message"))?.cast();

        // A one-message limit expires the session after a single encryption.
        let settings = EncryptionSettings { rotation_period_msgs: 1, ..Default::default() };
        let (session, _) = account.create_group_session_pair(room, settings).await.unwrap();
        assert!(!session.expired());
        let _ = session.encrypt("m.room.message", &message).await;
        assert!(session.expired());

        // A tiny rotation period is raised to one hour; an hour-old session expires.
        let settings = EncryptionSettings {
            rotation_period: Duration::from_millis(100),
            ..Default::default()
        };
        let (mut session, _) = account.create_group_session_pair(room, settings).await.unwrap();
        assert!(!session.expired());
        let now = SecondsSinceUnixEpoch::now();
        session.creation_time = SecondsSinceUnixEpoch(now.get() - uint!(3600));
        assert!(session.expired());

        // A message limit of zero is clamped up to one.
        let settings = EncryptionSettings { rotation_period_msgs: 0, ..Default::default() };
        let (session, _) = account.create_group_session_pair(room, settings).await.unwrap();
        assert!(!session.expired());
        let _ = session.encrypt("m.room.message", &message).await;
        assert!(session.expired());

        // A message limit above 10 000 is clamped down to 10 000.
        let settings = EncryptionSettings { rotation_period_msgs: 100_000, ..Default::default() };
        let (session, _) = account.create_group_session_pair(room, settings).await.unwrap();
        assert!(!session.expired());
        for (count, expected_expired) in [(1000, false), (9999, false), (10_000, true)] {
            session.message_count.store(count, Ordering::SeqCst);
            assert_eq!(session.expired(), expected_expired);
        }

        Ok(())
    }
}