use std::{fmt, sync::Arc};
use ruma::{serde::Raw, SecondsSinceUnixEpoch};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::Mutex;
use tracing::{debug, Span};
use vodozemac::{
olm::{DecryptionError, OlmMessage, Session as InnerSession, SessionConfig, SessionPickle},
Curve25519PublicKey,
};
#[cfg(feature = "experimental-algorithms")]
use crate::types::events::room::encrypted::OlmV2Curve25519AesSha2Content;
use crate::{
error::{EventError, OlmResult, SessionUnpickleError},
types::{
events::{
olm_v1::{DecryptedOlmV1Event, OlmV1Keys},
room::encrypted::{OlmV1Curve25519AesSha2Content, ToDeviceEncryptedEventContent},
EventType,
},
DeviceKeys, EventEncryptionAlgorithm,
},
DeviceData,
};
/// An Olm session together with the bookkeeping needed to use and persist it.
///
/// Clones share the same underlying session state, because `inner` is held
/// behind an `Arc`.
#[derive(Clone)]
pub struct Session {
    /// The ID of `inner`, stored separately so it can be read (see
    /// [`Session::session_id`]) without taking the lock.
    pub session_id: Arc<str>,
    /// The underlying vodozemac Olm session, behind an async mutex so that all
    /// clones operate on the same state.
    pub inner: Arc<Mutex<InnerSession>>,
    /// The Curve25519 identity key of the other side of the session; it is
    /// used as the `recipient_key` of the events this session encrypts.
    pub sender_key: Curve25519PublicKey,
    /// Our own device keys; their Curve25519 and Ed25519 keys, and a full copy
    /// of the keys, are embedded in the events this session encrypts.
    pub our_device_keys: DeviceKeys,
    /// Flag recording whether a fallback key was used to create the session;
    /// set by whoever constructs the session and round-tripped through
    /// [`PickledSession`].
    pub created_using_fallback_key: bool,
    /// When the session was created; round-tripped through [`PickledSession`].
    pub creation_time: SecondsSinceUnixEpoch,
    /// When the session was last used; refreshed by a successful
    /// [`Session::decrypt`] and by every encryption.
    pub last_use_time: SecondsSinceUnixEpoch,
}
#[cfg(not(tarpaulin_include))]
impl fmt::Debug for Session {
    /// Formats the session as `Session { session_id, sender_key }`; every other
    /// field is left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Session");
        builder.field("session_id", &self.session_id());
        builder.field("sender_key", &self.sender_key);
        builder.finish()
    }
}
impl Session {
pub async fn decrypt(&mut self, message: &OlmMessage) -> Result<String, DecryptionError> {
let mut inner = self.inner.lock().await;
Span::current().record("session_id", inner.session_id());
let plaintext = inner.decrypt(message)?;
debug!(session=?inner, "Decrypted an Olm message");
let plaintext = String::from_utf8_lossy(&plaintext).to_string();
self.last_use_time = SecondsSinceUnixEpoch::now();
Ok(plaintext)
}
pub fn sender_key(&self) -> Curve25519PublicKey {
self.sender_key
}
pub async fn session_config(&self) -> SessionConfig {
self.inner.lock().await.session_config()
}
#[allow(clippy::unused_async)] pub async fn algorithm(&self) -> EventEncryptionAlgorithm {
#[cfg(feature = "experimental-algorithms")]
if self.session_config().await.version() == 2 {
EventEncryptionAlgorithm::OlmV2Curve25519AesSha2
} else {
EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
}
#[cfg(not(feature = "experimental-algorithms"))]
EventEncryptionAlgorithm::OlmV1Curve25519AesSha2
}
pub(crate) async fn encrypt_helper(&mut self, plaintext: &str) -> OlmMessage {
let mut session = self.inner.lock().await;
let message = session.encrypt(plaintext);
self.last_use_time = SecondsSinceUnixEpoch::now();
debug!(?session, "Successfully encrypted an event");
message
}
pub async fn encrypt(
&mut self,
recipient_device: &DeviceData,
event_type: &str,
content: impl Serialize,
message_id: Option<String>,
) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
#[derive(Debug)]
struct Content<'a> {
event_type: &'a str,
content: Raw<Value>,
}
impl EventType for Content<'_> {
const EVENT_TYPE: &'static str = "";
fn event_type(&self) -> &str {
self.event_type
}
}
impl Serialize for Content<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.content.serialize(serializer)
}
}
let plaintext = {
let content = serde_json::to_value(content)?;
let content = Content { event_type, content: Raw::new(&content)? };
let recipient_signing_key =
recipient_device.ed25519_key().ok_or(EventError::MissingSigningKey)?;
let content = DecryptedOlmV1Event {
sender: self.our_device_keys.user_id.clone(),
recipient: recipient_device.user_id().into(),
keys: OlmV1Keys {
ed25519: self
.our_device_keys
.ed25519_key()
.expect("Our own device should have an Ed25519 public key"),
},
recipient_keys: OlmV1Keys { ed25519: recipient_signing_key },
sender_device_keys: Some(self.our_device_keys.clone()),
content,
};
serde_json::to_string(&content)?
};
let ciphertext = self.encrypt_helper(&plaintext).await;
let content = self.build_encrypted_event(ciphertext, message_id).await?;
let content = Raw::new(&content)?;
Ok(content)
}
pub(crate) async fn build_encrypted_event(
&self,
ciphertext: OlmMessage,
message_id: Option<String>,
) -> OlmResult<ToDeviceEncryptedEventContent> {
let content = match self.algorithm().await {
EventEncryptionAlgorithm::OlmV1Curve25519AesSha2 => OlmV1Curve25519AesSha2Content {
ciphertext,
recipient_key: self.sender_key,
sender_key: self
.our_device_keys
.curve25519_key()
.expect("Device doesn't have curve25519 key"),
message_id,
}
.into(),
#[cfg(feature = "experimental-algorithms")]
EventEncryptionAlgorithm::OlmV2Curve25519AesSha2 => OlmV2Curve25519AesSha2Content {
ciphertext,
sender_key: self
.our_device_keys
.curve25519_key()
.expect("Device doesn't have curve25519 key"),
message_id,
}
.into(),
_ => unreachable!(),
};
Ok(content)
}
pub fn session_id(&self) -> &str {
&self.session_id
}
pub async fn pickle(&self) -> PickledSession {
let pickle = self.inner.lock().await.pickle();
PickledSession {
pickle,
sender_key: self.sender_key,
created_using_fallback_key: self.created_using_fallback_key,
creation_time: self.creation_time,
last_use_time: self.last_use_time,
}
}
pub fn from_pickle(
our_device_keys: DeviceKeys,
pickle: PickledSession,
) -> Result<Self, SessionUnpickleError> {
if our_device_keys.curve25519_key().is_none() {
return Err(SessionUnpickleError::MissingIdentityKey);
}
if our_device_keys.ed25519_key().is_none() {
return Err(SessionUnpickleError::MissingSigningKey);
}
let session: vodozemac::olm::Session = pickle.pickle.into();
let session_id = session.session_id();
Ok(Session {
inner: Arc::new(Mutex::new(session)),
session_id: session_id.into(),
created_using_fallback_key: pickle.created_using_fallback_key,
sender_key: pickle.sender_key,
our_device_keys,
creation_time: pickle.creation_time,
last_use_time: pickle.last_use_time,
})
}
}
impl PartialEq for Session {
    /// Sessions compare equal exactly when their session IDs are equal; no
    /// other field takes part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        let (ours, theirs) = (self.session_id(), other.session_id());
        ours == theirs
    }
}
/// A serializable snapshot of a [`Session`], produced by [`Session::pickle`]
/// and turned back into a `Session` by [`Session::from_pickle`].
///
/// Serde writes the fields in declaration order, so reordering them changes
/// the serialized form.
// NOTE(review): `Debug` is not implemented — presumably to keep pickled
// session state out of logs; confirm before adding one.
#[derive(Serialize, Deserialize)]
#[allow(missing_debug_implementations)]
pub struct PickledSession {
    /// The pickled vodozemac Olm session.
    pub pickle: SessionPickle,
    /// The session's `sender_key`, copied from [`Session::sender_key`].
    pub sender_key: Curve25519PublicKey,
    /// Copied from the session's `created_using_fallback_key`; deserializes as
    /// `false` when the field is absent.
    #[serde(default)]
    pub created_using_fallback_key: bool,
    /// The session's creation time.
    pub creation_time: SecondsSinceUnixEpoch,
    /// The time the session was last used.
    pub last_use_time: SecondsSinceUnixEpoch,
}
#[cfg(test)]
mod tests {
    use assert_matches2::assert_let;
    use matrix_sdk_test::async_test;
    use ruma::{device_id, events::dummy::ToDeviceDummyEventContent, user_id};
    use serde_json::Value;
    use vodozemac::olm::{OlmMessage, SessionConfig};

    use crate::{
        identities::DeviceData,
        olm::Account,
        types::events::{
            dummy::DummyEventContent, olm_v1::DecryptedOlmV1Event,
            room::encrypted::ToDeviceEncryptedEventContent,
        },
    };

    // Round trip: Alice opens an outbound session with Bob's one-time key and
    // encrypts a dummy event; Bob builds the matching inbound session from the
    // resulting pre-key message and checks the decrypted payload.
    #[async_test]
    async fn test_encryption_and_decryption() {
        let alice =
            Account::with_device_id(user_id!("@alice:localhost"), device_id!("ALICEDEVICE"));
        let mut bob = Account::with_device_id(user_id!("@bob:localhost"), device_id!("BOBDEVICE"));

        // Bob offers exactly one one-time key for Alice to use.
        bob.generate_one_time_keys(1);
        let bob_one_time_key = *bob.one_time_keys().values().next().unwrap();
        let bob_identity_key = bob.identity_keys().curve25519;

        let mut outbound = alice.create_outbound_session_helper(
            SessionConfig::default(),
            bob_identity_key,
            bob_one_time_key,
            false,
            alice.device_keys(),
        );

        // The recipient device only feeds the plaintext's recipient fields,
        // which this test does not inspect, so Alice's own device is used.
        let alice_device = DeviceData::from_account(&alice);
        let encrypted = outbound
            .encrypt(&alice_device, "m.dummy", ToDeviceDummyEventContent::new(), None)
            .await
            .unwrap()
            .deserialize()
            .unwrap();

        // Which content variant comes out depends on the algorithms feature set.
        #[cfg(feature = "experimental-algorithms")]
        assert_let!(ToDeviceEncryptedEventContent::OlmV2Curve25519AesSha2(content) = encrypted);
        #[cfg(not(feature = "experimental-algorithms"))]
        assert_let!(ToDeviceEncryptedEventContent::OlmV1Curve25519AesSha2(content) = encrypted);

        // The first message of a fresh session must be a pre-key message.
        let OlmMessage::PreKey(prekey) = content.ciphertext else {
            panic!("Wrong Olm message type");
        };

        let inbound = bob
            .create_inbound_session(
                alice_device.curve25519_key().unwrap(),
                bob.device_keys(),
                &prekey,
            )
            .unwrap();

        // Check the raw JSON first, then the fully typed event.
        let json: Value = serde_json::from_str(&inbound.plaintext).unwrap();
        assert_eq!(json["sender_device_keys"]["user_id"].as_str(), Some("@alice:localhost"));

        let event: DecryptedOlmV1Event<DummyEventContent> =
            serde_json::from_str(&inbound.plaintext).unwrap();
        assert_eq!(event.sender_device_keys.unwrap(), alice.device_keys());
    }
}