use std::{borrow::Cow, fmt};
use as_variant::as_variant;
use ruma_common::{
OwnedDeviceId, OwnedTransactionId,
serde::{Base64, JsonObject},
};
use ruma_macros::EventContent;
use serde::{Deserialize, Deserializer, Serialize, de};
use serde_json::{Value as JsonValue, from_value as from_json_value};
use super::{
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
};
use crate::relation::Reference;
/// The content of a to-device event of type `m.key.verification.start`.
///
/// The method specific content is flattened, so its keys (including `method`) are
/// (de)serialized at the same level as `from_device` and `transaction_id`.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.start", kind = ToDevice)]
pub struct ToDeviceKeyVerificationStartEventContent {
/// The device ID carried in the `from_device` field.
///
/// NOTE(review): presumably the device that initiates the verification — confirm against
/// the Matrix specification.
pub from_device: OwnedDeviceId,
/// The transaction ID carried in the `transaction_id` field.
pub transaction_id: OwnedTransactionId,
/// Method specific content, flattened into the surrounding object.
#[serde(flatten)]
pub method: StartMethod,
}
impl ToDeviceKeyVerificationStartEventContent {
pub fn new(
from_device: OwnedDeviceId,
transaction_id: OwnedTransactionId,
method: StartMethod,
) -> Self {
Self { from_device, transaction_id, method }
}
}
/// The content of a message-like (in-room) event of type `m.key.verification.start`.
///
/// The method specific content is flattened, so its keys (including `method`) are
/// (de)serialized at the same level as `from_device` and `m.relates_to`.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.start", kind = MessageLike)]
pub struct KeyVerificationStartEventContent {
/// The device ID carried in the `from_device` field.
///
/// NOTE(review): presumably the device that initiates the verification — confirm against
/// the Matrix specification.
pub from_device: OwnedDeviceId,
/// Method specific content, flattened into the surrounding object.
#[serde(flatten)]
pub method: StartMethod,
/// Reference to another event, (de)serialized under the `m.relates_to` key.
#[serde(rename = "m.relates_to")]
pub relates_to: Reference,
}
impl KeyVerificationStartEventContent {
pub fn new(from_device: OwnedDeviceId, method: StartMethod, relates_to: Reference) -> Self {
Self { from_device, method, relates_to }
}
}
/// Method specific content of an `m.key.verification.start` event.
///
/// Serialization is untagged: the content of the active variant is written out directly, and
/// that content is itself responsible for emitting the `method` key. Deserialization is
/// implemented by hand and selects the variant from the value of the `method` field.
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(untagged)]
pub enum StartMethod {
/// Content for the `m.sas.v1` method.
SasV1(SasV1Content),
/// Content for the `m.reciprocate.v1` method.
ReciprocateV1(ReciprocateV1Content),
/// Content for any other method: the method name plus the remaining fields as raw JSON.
#[doc(hidden)]
_Custom(_CustomStartMethodContent),
}
impl StartMethod {
    /// Returns the value of the `method` field for this content.
    ///
    /// This is a fixed string for the known variants and the stored method name for the
    /// custom variant.
    pub fn method(&self) -> &str {
        match self {
            Self::_Custom(custom) => custom.method.as_str(),
            Self::ReciprocateV1(_) => "m.reciprocate.v1",
            Self::SasV1(_) => "m.sas.v1",
        }
    }

    /// Returns the method specific fields as a JSON object, without the `method` key.
    ///
    /// For the custom variant the stored object is borrowed as-is. For the known variants the
    /// content is serialized into a fresh object on every call.
    ///
    /// # Panics
    ///
    /// Panics if a known variant fails to serialize or does not serialize to a JSON object.
    pub fn data(&self) -> Cow<'_, JsonObject> {
        /// Serializes `content` to a JSON object and strips its `method` entry.
        fn fields_without_method<T: Serialize>(content: &T) -> JsonObject {
            let serialized =
                serde_json::to_value(content).expect("start method serialization to succeed");

            let JsonValue::Object(mut fields) = serialized else {
                panic!("all start method variants must serialize to objects");
            };

            // The method name is exposed separately through `method()`.
            fields.remove("method");
            fields
        }

        match self {
            Self::_Custom(custom) => Cow::Borrowed(&custom.data),
            Self::SasV1(sas) => Cow::Owned(fields_without_method(sas)),
            Self::ReciprocateV1(reciprocate) => Cow::Owned(fields_without_method(reciprocate)),
        }
    }
}
impl<'de> Deserialize<'de> for StartMethod {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let mut data = JsonObject::deserialize(deserializer)?;
let method = data
.get("method")
.and_then(|value| as_variant!(value, JsonValue::String))
.ok_or_else(|| de::Error::missing_field("method"))?;
match method.as_ref() {
"m.sas.v1" => from_json_value(data.into()).map(Self::SasV1),
"m.reciprocate.v1" => from_json_value(data.into()).map(Self::ReciprocateV1),
_ => {
let method = as_variant!(
data.remove("method")
.expect("we already checked that the method field is present"),
JsonValue::String
)
.expect("we already checked that the method is a string");
Ok(Self::_Custom(_CustomStartMethodContent { method, data }))
}
}
.map_err(de::Error::custom)
}
}
/// Method specific content of a start event whose `method` is neither `m.sas.v1` nor
/// `m.reciprocate.v1`.
///
/// Serializes to a single JSON object containing `method` followed by the entries of `data`.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct _CustomStartMethodContent {
/// The method name, as found in the `method` field.
method: String,
/// The other fields that were handed to the `StartMethod` deserializer, kept as raw JSON and
/// flattened next to `method` on serialization.
#[serde(flatten)]
data: JsonObject,
}
/// Method specific content for the `m.reciprocate.v1` method.
///
/// Serializes with an additional `method` field whose value is `m.reciprocate.v1`.
///
/// `Debug` is implemented manually for this type and does not print `secret`.
#[derive(Clone, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(rename = "m.reciprocate.v1", tag = "method")]
pub struct ReciprocateV1Content {
/// The secret, (de)serialized through the `Base64` wrapper under the `secret` key.
///
/// NOTE(review): presumably the shared secret exchanged out of band — confirm against the
/// Matrix specification.
pub secret: Base64,
}
impl ReciprocateV1Content {
pub fn new(secret: Base64) -> Self {
Self { secret }
}
}
impl fmt::Debug for ReciprocateV1Content {
    /// Formats the value as `ReciprocateV1Content { .. }`, deliberately leaving out `secret`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut redacted = f.debug_struct("ReciprocateV1Content");
        // No field is registered, so the secret never reaches the output.
        redacted.finish_non_exhaustive()
    }
}
/// Method specific content for the `m.sas.v1` method.
///
/// Serializes with an additional `method` field whose value is `m.sas.v1`.
///
/// Unless `ruma_unstable_exhaustive_types` is set this struct is non-exhaustive; it can be
/// created by converting a `SasV1ContentInit`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(rename = "m.sas.v1", tag = "method")]
pub struct SasV1Content {
/// The list of key agreement protocols, under the `key_agreement_protocols` key.
pub key_agreement_protocols: Vec<KeyAgreementProtocol>,
/// The list of hash algorithms, under the `hashes` key.
pub hashes: Vec<HashAlgorithm>,
/// The list of message authentication codes, under the `message_authentication_codes` key.
pub message_authentication_codes: Vec<MessageAuthenticationCode>,
/// The list of short authentication string methods, under the
/// `short_authentication_string` key.
pub short_authentication_string: Vec<ShortAuthenticationString>,
}
/// Initial set of fields for creating an `SasV1Content`.
///
/// This struct is always exhaustive, so it can be built with a struct literal and then turned
/// into an `SasV1Content` through the `From` implementation.
#[derive(Debug)]
#[allow(clippy::exhaustive_structs)]
pub struct SasV1ContentInit {
/// The list of key agreement protocols.
pub key_agreement_protocols: Vec<KeyAgreementProtocol>,
/// The list of hash algorithms.
pub hashes: Vec<HashAlgorithm>,
/// The list of message authentication codes.
pub message_authentication_codes: Vec<MessageAuthenticationCode>,
/// The list of short authentication string methods.
pub short_authentication_string: Vec<ShortAuthenticationString>,
}
impl From<SasV1ContentInit> for SasV1Content {
fn from(init: SasV1ContentInit) -> Self {
Self {
key_agreement_protocols: init.key_agreement_protocols,
hashes: init.hashes,
message_authentication_codes: init.message_authentication_codes,
short_authentication_string: init.short_authentication_string,
}
}
}
#[cfg(test)]
mod tests {
use assert_matches2::{assert_let, assert_matches};
use ruma_common::{canonical_json::assert_to_canonical_json_eq, event_id, serde::Base64};
use serde_json::{Value as JsonValue, from_value as from_json_value, json};
use super::{
HashAlgorithm, KeyAgreementProtocol, KeyVerificationStartEventContent,
MessageAuthenticationCode, ReciprocateV1Content, SasV1ContentInit,
ShortAuthenticationString, StartMethod, ToDeviceKeyVerificationStartEventContent,
};
use crate::{ToDeviceEvent, relation::Reference};
#[test]
fn to_device_serialization() {
    // Case 1: SAS content is flattened next to the to-device fields.
    let sas_init = SasV1ContentInit {
        hashes: vec![HashAlgorithm::Sha256],
        key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519],
        message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256V2],
        short_authentication_string: vec![ShortAuthenticationString::Decimal],
    };
    let sas_start = ToDeviceKeyVerificationStartEventContent {
        from_device: "123".into(),
        transaction_id: "456".into(),
        method: StartMethod::SasV1(sas_init.into()),
    };
    let expected_sas = json!({
        "from_device": "123",
        "transaction_id": "456",
        "method": "m.sas.v1",
        "key_agreement_protocols": ["curve25519"],
        "hashes": ["sha256"],
        "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
        "short_authentication_string": ["decimal"],
    });
    assert_to_canonical_json_eq!(sas_start, expected_sas);

    // Case 2: reciprocate content carries only the secret besides the method name.
    let secret = Base64::new(b"This is a secret to everybody".to_vec());
    let reciprocate_start = ToDeviceKeyVerificationStartEventContent {
        from_device: "123".into(),
        transaction_id: "456".into(),
        method: StartMethod::ReciprocateV1(ReciprocateV1Content::new(secret.clone())),
    };
    let expected_reciprocate = json!({
        "from_device": "123",
        "method": "m.reciprocate.v1",
        "secret": secret,
        "transaction_id": "456",
    });
    assert_to_canonical_json_eq!(reciprocate_start, expected_reciprocate);
}
#[test]
fn in_room_serialization() {
    let related_event = event_id!("$1598361704261elfgc:localhost");

    // Case 1: SAS content is flattened next to `from_device` and `m.relates_to`.
    let sas_init = SasV1ContentInit {
        hashes: vec![HashAlgorithm::Sha256],
        key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519],
        message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256V2],
        short_authentication_string: vec![ShortAuthenticationString::Decimal],
    };
    let sas_start = KeyVerificationStartEventContent {
        from_device: "123".into(),
        relates_to: Reference { event_id: related_event.to_owned() },
        method: StartMethod::SasV1(sas_init.into()),
    };
    let expected_sas = json!({
        "from_device": "123",
        "method": "m.sas.v1",
        "key_agreement_protocols": ["curve25519"],
        "hashes": ["sha256"],
        "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
        "short_authentication_string": ["decimal"],
        "m.relates_to": {
            "rel_type": "m.reference",
            "event_id": related_event,
        },
    });
    assert_to_canonical_json_eq!(sas_start, expected_sas);

    // Case 2: reciprocate content carries only the secret besides the method name.
    let secret = Base64::new(b"This is a secret to everybody".to_vec());
    let reciprocate_start = KeyVerificationStartEventContent {
        from_device: "123".into(),
        relates_to: Reference { event_id: related_event.to_owned() },
        method: StartMethod::ReciprocateV1(ReciprocateV1Content::new(secret.clone())),
    };
    let expected_reciprocate = json!({
        "from_device": "123",
        "method": "m.reciprocate.v1",
        "secret": secret,
        "m.relates_to": {
            "rel_type": "m.reference",
            "event_id": related_event,
        },
    });
    assert_to_canonical_json_eq!(reciprocate_start, expected_reciprocate);
}
#[test]
fn to_device_deserialization() {
    // Case 1: bare SAS content.
    let raw_content = json!({
        "from_device": "123",
        "transaction_id": "456",
        "method": "m.sas.v1",
        "hashes": ["sha256"],
        "key_agreement_protocols": ["curve25519"],
        "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
        "short_authentication_string": ["decimal"]
    });
    let parsed: ToDeviceKeyVerificationStartEventContent =
        from_json_value(raw_content).unwrap();

    assert_eq!(parsed.from_device, "123");
    assert_eq!(parsed.transaction_id, "456");
    assert_matches!(parsed.method, StartMethod::SasV1(sas_content));
    assert_eq!(sas_content.hashes, vec![HashAlgorithm::Sha256]);
    assert_eq!(sas_content.key_agreement_protocols, vec![KeyAgreementProtocol::Curve25519]);
    assert_eq!(
        sas_content.message_authentication_codes,
        vec![MessageAuthenticationCode::HkdfHmacSha256V2]
    );
    assert_eq!(
        sas_content.short_authentication_string,
        vec![ShortAuthenticationString::Decimal]
    );

    // Case 2: SAS content wrapped in a full to-device event.
    let raw_sas_event = json!({
        "content": {
            "from_device": "123",
            "transaction_id": "456",
            "method": "m.sas.v1",
            "key_agreement_protocols": ["curve25519"],
            "hashes": ["sha256"],
            "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
            "short_authentication_string": ["decimal"]
        },
        "type": "m.key.verification.start",
        "sender": "@example:localhost",
    });
    let sas_event: ToDeviceEvent<ToDeviceKeyVerificationStartEventContent> =
        from_json_value(raw_sas_event).unwrap();

    assert_eq!(sas_event.sender, "@example:localhost");
    assert_eq!(sas_event.content.from_device, "123");
    assert_eq!(sas_event.content.transaction_id, "456");
    assert_matches!(sas_event.content.method, StartMethod::SasV1(event_sas));
    assert_eq!(event_sas.hashes, vec![HashAlgorithm::Sha256]);
    assert_eq!(event_sas.key_agreement_protocols, vec![KeyAgreementProtocol::Curve25519]);
    assert_eq!(
        event_sas.message_authentication_codes,
        vec![MessageAuthenticationCode::HkdfHmacSha256V2]
    );
    assert_eq!(event_sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

    // Case 3: reciprocate content wrapped in a full to-device event.
    let raw_reciprocate_event = json!({
        "content": {
            "from_device": "123",
            "method": "m.reciprocate.v1",
            "secret": "c2VjcmV0Cg",
            "transaction_id": "456",
        },
        "type": "m.key.verification.start",
        "sender": "@example:localhost",
    });
    let reciprocate_event: ToDeviceEvent<ToDeviceKeyVerificationStartEventContent> =
        from_json_value(raw_reciprocate_event).unwrap();

    assert_eq!(reciprocate_event.sender, "@example:localhost");
    assert_eq!(reciprocate_event.content.from_device, "123");
    assert_eq!(reciprocate_event.content.transaction_id, "456");
    assert_matches!(reciprocate_event.content.method, StartMethod::ReciprocateV1(reciprocate));
    assert_eq!(reciprocate.secret.encode(), "c2VjcmV0Cg");
}
#[test]
fn in_room_deserialization() {
    // Case 1: SAS content with a reference relation.
    let raw_sas = json!({
        "from_device": "123",
        "method": "m.sas.v1",
        "hashes": ["sha256"],
        "key_agreement_protocols": ["curve25519"],
        "message_authentication_codes": ["hkdf-hmac-sha256.v2"],
        "short_authentication_string": ["decimal"],
        "m.relates_to": {
            "rel_type": "m.reference",
            "event_id": "$1598361704261elfgc:localhost",
        }
    });
    let sas_start: KeyVerificationStartEventContent = from_json_value(raw_sas).unwrap();

    assert_eq!(sas_start.from_device, "123");
    assert_eq!(sas_start.relates_to.event_id, "$1598361704261elfgc:localhost");
    assert_matches!(sas_start.method, StartMethod::SasV1(sas_content));
    assert_eq!(sas_content.hashes, vec![HashAlgorithm::Sha256]);
    assert_eq!(sas_content.key_agreement_protocols, vec![KeyAgreementProtocol::Curve25519]);
    assert_eq!(
        sas_content.message_authentication_codes,
        vec![MessageAuthenticationCode::HkdfHmacSha256V2]
    );
    assert_eq!(
        sas_content.short_authentication_string,
        vec![ShortAuthenticationString::Decimal]
    );

    // Case 2: reciprocate content with a reference relation.
    let raw_reciprocate = json!({
        "from_device": "123",
        "method": "m.reciprocate.v1",
        "secret": "c2VjcmV0Cg",
        "m.relates_to": {
            "rel_type": "m.reference",
            "event_id": "$1598361704261elfgc:localhost",
        }
    });
    let reciprocate_start: KeyVerificationStartEventContent =
        from_json_value(raw_reciprocate).unwrap();

    assert_eq!(reciprocate_start.from_device, "123");
    assert_eq!(reciprocate_start.relates_to.event_id, "$1598361704261elfgc:localhost");
    assert_matches!(reciprocate_start.method, StartMethod::ReciprocateV1(reciprocate));
    assert_eq!(reciprocate.secret.encode(), "c2VjcmV0Cg");
}
#[test]
fn custom_to_device_serialization_roundtrip() {
    // A method name this module has no dedicated variant for, plus one extra field.
    let original = json!({
        "from_device": "123",
        "transaction_id": "456",
        "method": "m.sas.custom",
        "test": "field",
    });

    let parsed: ToDeviceKeyVerificationStartEventContent =
        from_json_value(original.clone()).unwrap();

    assert_eq!(parsed.from_device, "123");
    assert_eq!(parsed.transaction_id, "456");
    assert_eq!(parsed.method.method(), "m.sas.custom");

    // Only the unknown field remains in the data: `method` is stored separately.
    let custom_fields = &*parsed.method.data();
    assert_eq!(custom_fields.len(), 1);
    assert_let!(Some(JsonValue::String(field_value)) = custom_fields.get("test"));
    assert_eq!(field_value, "field");

    // Serializing the parsed content again yields the original JSON.
    assert_to_canonical_json_eq!(parsed, original);
}
}