use std::borrow::Cow;
use as_variant::as_variant;
use ruma_common::{
EventEncryptionAlgorithm, OwnedRoomId,
serde::{Base64, JsonObject, from_raw_json_value},
};
use ruma_macros::{EventContent, StringEnum};
use serde::{Deserialize, Serialize, de};
use serde_json::{Value as JsonValue, value::RawValue as RawJsonValue};
use crate::PrivOwnedStr;
/// The content of an `m.room_key.withheld` to-device event.
///
/// Only `Serialize` is derived for this type; `Deserialize` is provided by a hand-written
/// implementation in this file.
#[derive(Clone, Debug, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.room_key.withheld", kind = ToDevice)]
pub struct ToDeviceRoomKeyWithheldEventContent {
/// The encryption algorithm, serialized under the `algorithm` key.
pub algorithm: EventEncryptionAlgorithm,
/// The withheld code together with any fields specific to that code.
///
/// Flattened, so its fields are serialized at the same level as the other fields of this
/// struct.
#[serde(flatten)]
pub code: RoomKeyWithheldCodeInfo,
/// An optional free-form string, left out of the serialized output when it is `None`.
///
/// NOTE(review): presumably a human-readable explanation of why the key was withheld —
/// confirm against the Matrix specification.
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
/// A Base64-encoded value serialized under the `sender_key` key.
///
/// NOTE(review): presumably the public key of the sending device — confirm against the Matrix
/// specification.
pub sender_key: Base64,
}
impl ToDeviceRoomKeyWithheldEventContent {
pub fn new(
algorithm: EventEncryptionAlgorithm,
code: RoomKeyWithheldCodeInfo,
sender_key: Base64,
) -> Self {
Self { algorithm, code, reason: None, sender_key }
}
}
impl<'de> Deserialize<'de> for ToDeviceRoomKeyWithheldEventContent {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: de::Deserializer<'de>,
{
#[derive(Deserialize)]
struct ToDeviceRoomKeyWithheldEventContentDeHelper {
algorithm: EventEncryptionAlgorithm,
reason: Option<String>,
sender_key: Base64,
}
let json = Box::<RawJsonValue>::deserialize(deserializer)?;
let ToDeviceRoomKeyWithheldEventContentDeHelper { algorithm, reason, sender_key } =
from_raw_json_value(&json)?;
let code = from_raw_json_value(&json)?;
Ok(Self { algorithm, code, reason, sender_key })
}
}
/// The `code` of an `m.room_key.withheld` event, along with the data that accompanies that code.
///
/// Serialized as an internally tagged enum: the variant name is written to the `code` field and
/// the fields of the variant's payload, if any, are written next to it.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(tag = "code")]
pub enum RoomKeyWithheldCodeInfo {
/// `m.blacklisted`, carrying session data.
#[serde(rename = "m.blacklisted")]
Blacklisted(Box<RoomKeyWithheldSessionData>),
/// `m.unverified`, carrying session data.
#[serde(rename = "m.unverified")]
Unverified(Box<RoomKeyWithheldSessionData>),
/// `m.unauthorised`, carrying session data.
///
/// Note that the variant name and the serialized code are spelled differently.
#[serde(rename = "m.unauthorised")]
Unauthorized(Box<RoomKeyWithheldSessionData>),
/// `m.unavailable`, carrying session data.
#[serde(rename = "m.unavailable")]
Unavailable(Box<RoomKeyWithheldSessionData>),
/// `m.no_olm`, which carries no additional data.
#[serde(rename = "m.no_olm")]
NoOlm,
// Any other code. Serialized untagged: the wrapped value writes its own `code` field together
// with its remaining data.
#[doc(hidden)]
#[serde(untagged)]
_Custom(Box<CustomRoomKeyWithheldCodeInfo>),
}
impl RoomKeyWithheldCodeInfo {
pub fn code(&self) -> RoomKeyWithheldCode {
match self {
Self::Blacklisted(_) => RoomKeyWithheldCode::Blacklisted,
Self::Unverified(_) => RoomKeyWithheldCode::Unverified,
Self::Unauthorized(_) => RoomKeyWithheldCode::Unauthorized,
Self::Unavailable(_) => RoomKeyWithheldCode::Unavailable,
Self::NoOlm => RoomKeyWithheldCode::NoOlm,
Self::_Custom(info) => info.code.as_str().into(),
}
}
}
impl<'de> Deserialize<'de> for RoomKeyWithheldCodeInfo {
    /// Deserializes the code info by first peeking at the `code` field and then parsing the same
    /// raw JSON again according to that code.
    ///
    /// Known codes that carry session data parse the JSON as that session data, `m.no_olm`
    /// needs nothing else, and any other code is kept as a custom code together with the
    /// remaining fields of the object.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        /// Helper that only reads the `code` field, borrowing it from the input when possible.
        #[derive(Debug, Deserialize)]
        struct ExtractCode<'a> {
            #[serde(borrow)]
            code: Cow<'a, str>,
        }

        // Buffer the input so that it can be parsed more than once.
        let raw = Box::<RawJsonValue>::deserialize(deserializer)?;
        let ExtractCode { code } = from_raw_json_value(&raw)?;

        let info = match &*code {
            "m.no_olm" => Self::NoOlm,
            "m.blacklisted" => Self::Blacklisted(from_raw_json_value(&raw)?),
            "m.unverified" => Self::Unverified(from_raw_json_value(&raw)?),
            "m.unauthorised" => Self::Unauthorized(from_raw_json_value(&raw)?),
            "m.unavailable" => Self::Unavailable(from_raw_json_value(&raw)?),
            _ => {
                let mut data: JsonObject = from_raw_json_value(&raw)?;

                // Drop the keys that `ToDeviceRoomKeyWithheldEventContent` stores in its own
                // fields, so that only the extra data is kept here.
                for key in ["algorithm", "sender_key", "reason"] {
                    data.remove(key);
                }

                // The code itself moves out of the map into its dedicated field.
                let code_value = data
                    .remove("code")
                    .expect("we already checked that the code field is present");
                let custom_code = as_variant!(code_value, JsonValue::String)
                    .expect("we already checked that the code is a string");

                Self::_Custom(Box::new(CustomRoomKeyWithheldCodeInfo { code: custom_code, data }))
            }
        };

        Ok(info)
    }
}
/// The session data carried by the `RoomKeyWithheldCodeInfo` variants that have a payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct RoomKeyWithheldSessionData {
/// The ID of a room, serialized under the `room_id` key.
pub room_id: OwnedRoomId,
/// The ID of a session, serialized under the `session_id` key.
pub session_id: String,
}
impl RoomKeyWithheldSessionData {
pub fn new(room_id: OwnedRoomId, session_id: String) -> Self {
Self { room_id, session_id }
}
}
// Payload of `RoomKeyWithheldCodeInfo::_Custom`: the code string plus the other fields of the
// JSON object.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct CustomRoomKeyWithheldCodeInfo {
// The custom code, serialized under the `code` key.
code: String,
// The remaining fields, flattened so that they are serialized next to `code`.
#[serde(flatten)]
data: JsonObject,
}
/// The possible values of the `code` field of an `m.room_key.withheld` event.
///
/// The string form of each variant is the variant name in snake case with an `m.` prefix, except
/// where a variant is renamed explicitly.
///
#[doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/doc/string_enum.md"))]
#[derive(Clone, StringEnum)]
#[ruma_enum(rename_all(prefix = "m.", rule = "snake_case"))]
#[non_exhaustive]
pub enum RoomKeyWithheldCode {
    /// `m.blacklisted`
    Blacklisted,

    /// `m.unverified`
    Unverified,

    /// `m.unauthorised`
    ///
    /// The wire format uses the British spelling, which the automatic renaming of the variant
    /// would not produce, so it is set explicitly. This keeps the string in sync with the code
    /// that `RoomKeyWithheldCodeInfo::Unauthorized` is (de)serialized with.
    #[ruma_enum(rename = "m.unauthorised")]
    Unauthorized,

    /// `m.unavailable`
    Unavailable,

    /// `m.no_olm`
    NoOlm,

    #[doc(hidden)]
    _Custom(PrivOwnedStr),
}
#[cfg(test)]
mod tests {
    use assert_matches2::assert_matches;
    use ruma_common::{
        EventEncryptionAlgorithm, canonical_json::assert_to_canonical_json_eq, owned_room_id,
        serde::Base64,
    };
    use serde_json::{from_value as from_json_value, json};

    use super::{
        RoomKeyWithheldCodeInfo, RoomKeyWithheldSessionData, ToDeviceRoomKeyWithheldEventContent,
    };

    /// Raw bytes used as the sender key in the tests.
    const PUBLIC_KEY: &[u8] = b"key";
    /// Base64 encoding of `PUBLIC_KEY`, as it appears in JSON.
    const BASE64_ENCODED_PUBLIC_KEY: &str = "a2V5";

    /// Builds the sender key shared by all the tests.
    fn sender_key() -> Base64 {
        Base64::new(PUBLIC_KEY.to_owned())
    }

    /// A code without session data serializes to just the three basic fields.
    #[test]
    fn serialization_no_olm() {
        let event = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::NoOlm,
            sender_key(),
        );

        let expected = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.no_olm",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
        });

        assert_to_canonical_json_eq!(event, expected);
    }

    /// A code with session data serializes the session fields next to the basic fields.
    #[test]
    fn serialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let session =
            RoomKeyWithheldSessionData::new(room_id.clone(), String::from("unique_id"));

        let event = ToDeviceRoomKeyWithheldEventContent::new(
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            RoomKeyWithheldCodeInfo::Blacklisted(Box::new(session)),
            sender_key(),
        );

        let expected = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.blacklisted",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        assert_to_canonical_json_eq!(event, expected);
    }

    /// `m.no_olm` deserializes without session data and keeps the optional reason.
    #[test]
    fn deserialization_no_olm() {
        let payload = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.no_olm",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "reason": "Could not find an olm session",
        });

        let parsed = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(payload).unwrap();

        assert_eq!(parsed.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(parsed.sender_key, sender_key());
        assert_eq!(parsed.reason.as_deref(), Some("Could not find an olm session"));
        assert_matches!(parsed.code, RoomKeyWithheldCodeInfo::NoOlm);
    }

    /// `m.blacklisted` deserializes together with its session data.
    #[test]
    fn deserialization_blacklisted() {
        let room_id = owned_room_id!("!roomid:localhost");
        let payload = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "m.blacklisted",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "session_id": "unique_id",
        });

        let parsed = from_json_value::<ToDeviceRoomKeyWithheldEventContent>(payload).unwrap();

        assert_eq!(parsed.algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(parsed.sender_key, sender_key());
        assert_eq!(parsed.reason, None);
        assert_matches!(parsed.code, RoomKeyWithheldCodeInfo::Blacklisted(session_data));
        assert_eq!(session_data.room_id, room_id);
        assert_eq!(session_data.session_id, "unique_id");
    }

    /// An unknown code keeps its extra fields and serializes back to the JSON it came from.
    #[test]
    fn custom_room_key_withheld_code_info_round_trip() {
        let room_id = owned_room_id!("!roomid:localhost");
        let payload = json!({
            "algorithm": "m.megolm.v1.aes-sha2",
            "code": "dev.ruma.custom_code",
            "sender_key": BASE64_ENCODED_PUBLIC_KEY,
            "room_id": room_id,
            "key": "value",
        });

        let parsed =
            from_json_value::<ToDeviceRoomKeyWithheldEventContent>(payload.clone()).unwrap();

        assert_eq!(parsed.code.code().as_str(), "dev.ruma.custom_code");
        assert_to_canonical_json_eq!(parsed, payload);
    }
}