use std::borrow::Cow;
use ruma_common::{
OwnedTransactionId,
serde::{Base64, JsonObject},
};
use ruma_macros::EventContent;
use serde::{Deserialize, Serialize};
use serde_json::{Value as JsonValue, from_value as from_json_value};
use super::{
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
};
use crate::relation::Reference;
/// The content of a to-device `m.key.verification.accept` event.
///
/// The method-specific fields are flattened into the same JSON object as `transaction_id`.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.accept", kind = ToDevice)]
pub struct ToDeviceKeyVerificationAcceptEventContent {
/// The identifier of the verification transaction this event belongs to.
///
/// NOTE(review): presumably the same ID as the one used by the matching start event — confirm
/// against the Matrix specification.
pub transaction_id: OwnedTransactionId,
/// The method-specific content, serialized inline (flattened) rather than under its own key.
#[serde(flatten)]
pub method: AcceptMethod,
}
impl ToDeviceKeyVerificationAcceptEventContent {
pub fn new(transaction_id: OwnedTransactionId, method: AcceptMethod) -> Self {
Self { transaction_id, method }
}
}
/// The content of a message-like (in-room) `m.key.verification.accept` event.
///
/// The method-specific fields are flattened into the same JSON object as `m.relates_to`.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[ruma_event(type = "m.key.verification.accept", kind = MessageLike)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct KeyVerificationAcceptEventContent {
/// The method-specific content, serialized inline (flattened) rather than under its own key.
#[serde(flatten)]
pub method: AcceptMethod,
/// The reference relation, stored under the `m.relates_to` key.
///
/// NOTE(review): presumably points at the event that started the verification — confirm
/// against the Matrix specification.
#[serde(rename = "m.relates_to")]
pub relates_to: Reference,
}
impl KeyVerificationAcceptEventContent {
pub fn new(method: AcceptMethod, relates_to: Reference) -> Self {
Self { method, relates_to }
}
}
/// The method-specific content of an `m.key.verification.accept` event.
///
/// Serialization is untagged: the content of the active variant is written directly, without a
/// discriminator field. Deserialization is implemented by hand rather than derived.
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(untagged)]
pub enum AcceptMethod {
/// Content that matches the shape of [`SasV1Content`].
SasV1(SasV1Content),
/// Any other content, kept as a raw JSON object.
#[doc(hidden)]
_Custom(_CustomAcceptMethodContent),
}
impl AcceptMethod {
    /// Returns the JSON object holding this method's content.
    ///
    /// For `SasV1` the content is serialized into a freshly built object (returned as
    /// `Cow::Owned`); for `_Custom` the stored object is borrowed as-is.
    ///
    /// # Panics
    ///
    /// Panics if serializing the `SasV1` content fails or does not yield a JSON object.
    pub fn data(&self) -> Cow<'_, JsonObject> {
        /// Serializes `content` into a JSON object and drops any top-level `method` key.
        fn to_object<T>(content: T) -> JsonObject
        where
            T: Serialize,
        {
            let value =
                serde_json::to_value(content).expect("accept method serialization to succeed");
            let JsonValue::Object(mut fields) = value else {
                panic!("all accept method variants must serialize to objects");
            };
            // Removing a key that is absent is a no-op.
            fields.remove("method");
            fields
        }

        match self {
            Self::_Custom(custom) => Cow::Borrowed(&custom.data),
            Self::SasV1(content) => Cow::Owned(to_object(content)),
        }
    }
}
impl<'de> Deserialize<'de> for AcceptMethod {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let data = JsonObject::deserialize(deserializer)?;
Ok(match from_json_value(data.clone().into()) {
Ok(sas_v1_content) => AcceptMethod::SasV1(sas_v1_content),
Err(_) => AcceptMethod::_Custom(_CustomAcceptMethodContent { data }),
})
}
}
/// Method-specific content that is stored as a raw JSON object.
///
/// The object is flattened on serialization, so its entries are written inline.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct _CustomAcceptMethodContent {
/// The raw JSON object holding the content.
#[serde(flatten)]
data: JsonObject,
}
/// The fields carried by the `SasV1` variant of [`AcceptMethod`].
///
/// NOTE(review): field meanings below are kept minimal; the precise semantics are defined by the
/// Matrix specification for `m.key.verification.accept` — confirm there.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct SasV1Content {
/// The key agreement protocol.
pub key_agreement_protocol: KeyAgreementProtocol,
/// The hash algorithm.
pub hash: HashAlgorithm,
/// The message authentication code method.
pub message_authentication_code: MessageAuthenticationCode,
/// The list of short authentication string methods.
pub short_authentication_string: Vec<ShortAuthenticationString>,
/// The commitment, held as Base64 data.
pub commitment: Base64,
}
/// Initial set of fields for building a [`SasV1Content`].
///
/// Every field is public and the struct is not marked `non_exhaustive`, so it can be written as a
/// struct literal and then converted with `SasV1Content::from` / `.into()`.
#[derive(Debug)]
#[allow(clippy::exhaustive_structs)]
pub struct SasV1ContentInit {
/// The key agreement protocol.
pub key_agreement_protocol: KeyAgreementProtocol,
/// The hash algorithm.
pub hash: HashAlgorithm,
/// The message authentication code method.
pub message_authentication_code: MessageAuthenticationCode,
/// The list of short authentication string methods.
pub short_authentication_string: Vec<ShortAuthenticationString>,
/// The commitment, held as Base64 data.
pub commitment: Base64,
}
impl From<SasV1ContentInit> for SasV1Content {
fn from(init: SasV1ContentInit) -> Self {
SasV1Content {
hash: init.hash,
key_agreement_protocol: init.key_agreement_protocol,
message_authentication_code: init.message_authentication_code,
short_authentication_string: init.short_authentication_string,
commitment: init.commitment,
}
}
}
#[cfg(test)]
mod tests {
use assert_matches2::{assert_let, assert_matches};
use ruma_common::{
canonical_json::assert_to_canonical_json_eq,
event_id,
serde::{Base64, Raw},
};
use serde_json::{Value as JsonValue, from_value as from_json_value, json};
use super::{
AcceptMethod, HashAlgorithm, KeyAgreementProtocol, KeyVerificationAcceptEventContent,
MessageAuthenticationCode, SasV1Content, ShortAuthenticationString,
ToDeviceKeyVerificationAcceptEventContent,
};
use crate::{ToDeviceEvent, relation::Reference};
#[test]
fn to_device_serialization() {
    // Build the method content first, then wrap it in the to-device event content.
    let sas = SasV1Content {
        hash: HashAlgorithm::Sha256,
        key_agreement_protocol: KeyAgreementProtocol::Curve25519,
        message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
        short_authentication_string: vec![ShortAuthenticationString::Decimal],
        commitment: Base64::new(b"hello".to_vec()),
    };
    let content = ToDeviceKeyVerificationAcceptEventContent {
        method: AcceptMethod::SasV1(sas),
        transaction_id: "456".into(),
    };

    // The method fields are expected inline, next to `transaction_id`.
    let expected = json!({
        "transaction_id": "456",
        "commitment": "aGVsbG8",
        "key_agreement_protocol": "curve25519",
        "hash": "sha256",
        "message_authentication_code": "hkdf-hmac-sha256.v2",
        "short_authentication_string": ["decimal"],
    });

    assert_to_canonical_json_eq!(content, expected);
}
#[test]
fn in_room_serialization() {
    let event_id = event_id!("$1598361704261elfgc:localhost");

    // Build the method content first, then wrap it in the in-room event content.
    let sas = SasV1Content {
        hash: HashAlgorithm::Sha256,
        key_agreement_protocol: KeyAgreementProtocol::Curve25519,
        message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
        short_authentication_string: vec![ShortAuthenticationString::Decimal],
        commitment: Base64::new(b"hello".to_vec()),
    };
    let content = KeyVerificationAcceptEventContent {
        method: AcceptMethod::SasV1(sas),
        relates_to: Reference { event_id: event_id.to_owned() },
    };

    // The method fields are expected inline, next to the `m.relates_to` object.
    let expected = json!({
        "commitment": "aGVsbG8",
        "key_agreement_protocol": "curve25519",
        "hash": "sha256",
        "message_authentication_code": "hkdf-hmac-sha256.v2",
        "short_authentication_string": ["decimal"],
        "m.relates_to": {
            "rel_type": "m.reference",
            "event_id": event_id,
        },
    });

    assert_to_canonical_json_eq!(content, expected);
}
#[test]
fn to_device_deserialization() {
    /// Checks every field of the SAS content used by both inputs below.
    fn check_sas(sas: &SasV1Content) {
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    // First: the bare event content.
    let content_json = json!({
        "transaction_id": "456",
        "commitment": "aGVsbG8",
        "hash": "sha256",
        "key_agreement_protocol": "curve25519",
        "message_authentication_code": "hkdf-hmac-sha256.v2",
        "short_authentication_string": ["decimal"]
    });

    let content =
        from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(content_json).unwrap();
    assert_eq!(content.transaction_id, "456");

    assert_matches!(content.method, AcceptMethod::SasV1(content_sas));
    check_sas(&content_sas);

    // Second: the same content wrapped in a full to-device event.
    let event_json = json!({
        "content": {
            "commitment": "aGVsbG8",
            "transaction_id": "456",
            "key_agreement_protocol": "curve25519",
            "hash": "sha256",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"]
        },
        "type": "m.key.verification.accept",
        "sender": "@example:localhost",
    });

    let ev =
        from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationAcceptEventContent>>(event_json)
            .unwrap();
    assert_eq!(ev.content.transaction_id, "456");
    assert_eq!(ev.sender, "@example:localhost");

    assert_matches!(ev.content.method, AcceptMethod::SasV1(event_sas));
    check_sas(&event_sas);
}
#[test]
fn in_room_deserialization() {
    // In-room content: method fields inline plus the `m.relates_to` reference.
    let input = json!({
        "commitment": "aGVsbG8",
        "hash": "sha256",
        "key_agreement_protocol": "curve25519",
        "message_authentication_code": "hkdf-hmac-sha256.v2",
        "short_authentication_string": ["decimal"],
        "m.relates_to": {
            "rel_type": "m.reference",
            "event_id": "$1598361704261elfgc:localhost",
        }
    });

    let parsed = from_json_value::<KeyVerificationAcceptEventContent>(input).unwrap();
    assert_eq!(parsed.relates_to.event_id, "$1598361704261elfgc:localhost");

    // The input has every SAS field, so it must land in the typed variant.
    assert_matches!(parsed.method, AcceptMethod::SasV1(sas));
    assert_eq!(sas.commitment.encode(), "aGVsbG8");
    assert_eq!(sas.hash, HashAlgorithm::Sha256);
    assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
    assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
    assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
}
#[test]
fn in_room_serialization_roundtrip() {
    let event_id = event_id!("$1598361704261elfgc:localhost");

    let sas = SasV1Content {
        hash: HashAlgorithm::Sha256,
        key_agreement_protocol: KeyAgreementProtocol::Curve25519,
        message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
        short_authentication_string: vec![ShortAuthenticationString::Decimal],
        commitment: Base64::new(b"hello".to_vec()),
    };
    let content = KeyVerificationAcceptEventContent {
        method: AcceptMethod::SasV1(sas),
        relates_to: Reference { event_id: event_id.to_owned() },
    };

    // Serialize into a `Raw` and read it back; the typed variant and the relation must survive.
    let raw = Raw::new(&content).unwrap();
    let roundtripped = raw.deserialize().unwrap();

    assert_matches!(roundtripped.method, AcceptMethod::SasV1(_));
    assert_eq!(roundtripped.relates_to.event_id, event_id);
}
#[test]
fn custom_to_device_serialization_roundtrip() {
    // Content without the SAS fields: it cannot be read as `SasV1Content`.
    let json = json!({
        "transaction_id": "456",
        "test": "field",
    });

    let content =
        from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json.clone()).unwrap();
    assert_eq!(content.transaction_id, "456");

    // Only the unknown field should remain in the method data.
    let data = content.method.data();
    assert_eq!(data.len(), 1);
    assert_let!(Some(JsonValue::String(value)) = data.get("test"));
    assert_eq!(value, "field");

    // Serializing again must reproduce the original input.
    assert_to_canonical_json_eq!(content, json);
}
}