use ruma_identifiers::DeviceId;
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{from_value, Value};
use super::{
HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
VerificationMethod,
};
use crate::{Event, EventType, InvalidInput, TryFromRaw};
/// The `m.key.verification.start` event.
///
/// This is the validated form of the event; a [`raw::StartEvent`] is turned into it
/// through [`TryFromRaw`], which checks the content's requirements.
#[derive(Debug, Clone, PartialEq)]
pub struct StartEvent {
    /// The validated event content.
    pub content: StartEventContent,
}
/// Content of an `m.key.verification.start` event, one variant per verification method.
#[derive(Debug, Clone, PartialEq)]
pub enum StartEventContent {
    /// Content for the `m.sas.v1` verification method.
    MSasV1(MSasV1Content),

    /// Placeholder keeping the enum open to additional methods; not meant to be constructed.
    #[doc(hidden)]
    __Nonexhaustive,
}
impl TryFromRaw for StartEvent {
    type Raw = raw::StartEvent;
    type Err = &'static str;

    /// Validates the raw event by validating its content.
    ///
    /// Any error from the content conversion is passed through unchanged.
    fn try_from_raw(raw: raw::StartEvent) -> Result<Self, Self::Err> {
        let content = StartEventContent::try_from_raw(raw.content)?;
        Ok(Self { content })
    }
}
impl Serialize for StartEvent {
    /// Serializes the event as a two-field struct: `content` followed by `type`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut event = serializer.serialize_struct("StartEvent", 2)?;
        event.serialize_field("content", &self.content)?;
        // The event type is not stored on the struct; it comes from `event_type()`.
        event.serialize_field("type", &self.event_type())?;
        event.end()
    }
}
// Generates the shared event plumbing for `StartEvent` (presumably the `Event` impl whose
// `event_type()` the `Serialize` impl calls), with `StartEventContent` as the content type
// and `EventType::KeyVerificationStart` as the event type — TODO confirm against the macro.
impl_event!(StartEvent, StartEventContent, EventType::KeyVerificationStart);
impl TryFromRaw for StartEventContent {
    type Raw = raw::StartEventContent;
    type Err = &'static str;

    /// Checks that raw `m.sas.v1` content lists each of the method's mandatory entries.
    ///
    /// The requirements are checked in a fixed order: key agreement protocol, hash,
    /// MAC, then SAS method. The message of the first unmet requirement is returned
    /// as the error.
    ///
    /// # Panics
    ///
    /// Panics if given the hidden `__Nonexhaustive` placeholder variant.
    fn try_from_raw(raw: raw::StartEventContent) -> Result<Self, Self::Err> {
        let content = match raw {
            raw::StartEventContent::MSasV1(content) => content,
            raw::StartEventContent::__Nonexhaustive => {
                panic!("__Nonexhaustive enum variant is not intended for use.")
            }
        };

        // (requirement satisfied?, message reported when it is not)
        let requirements = [
            (
                content
                    .key_agreement_protocols
                    .contains(&KeyAgreementProtocol::Curve25519),
                "`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`",
            ),
            (
                content.hashes.contains(&HashAlgorithm::Sha256),
                "`hashes` must contain at least `HashAlgorithm::Sha256`",
            ),
            (
                content
                    .message_authentication_codes
                    .contains(&MessageAuthenticationCode::HkdfHmacSha256),
                "`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`",
            ),
            (
                content
                    .short_authentication_string
                    .contains(&ShortAuthenticationString::Decimal),
                "`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`",
            ),
        ];

        match requirements.iter().find(|(satisfied, _)| !*satisfied) {
            Some(&(_, message)) => Err(message),
            None => Ok(StartEventContent::MSasV1(content)),
        }
    }
}
impl Serialize for StartEventContent {
    /// Serializes the method-specific content directly, with no extra wrapper.
    ///
    /// # Panics
    ///
    /// Panics on the hidden `__Nonexhaustive` placeholder variant.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            StartEventContent::MSasV1(content) => content.serialize(serializer),
            StartEventContent::__Nonexhaustive => {
                panic!("Attempted to serialize __Nonexhaustive variant.")
            }
        }
    }
}
pub(crate) mod raw {
use super::*;
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct StartEvent {
pub content: StartEventContent,
}
#[derive(Clone, Debug, PartialEq)]
pub enum StartEventContent {
MSasV1(MSasV1Content),
#[doc(hidden)]
__Nonexhaustive,
}
impl<'de> Deserialize<'de> for StartEventContent {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
use serde::de::Error as _;
let value: Value = Deserialize::deserialize(deserializer)?;
let method_value = match value.get("method") {
Some(value) => value.clone(),
None => return Err(D::Error::missing_field("method")),
};
let method = match from_value::<VerificationMethod>(method_value) {
Ok(method) => method,
Err(error) => return Err(D::Error::custom(error.to_string())),
};
match method {
VerificationMethod::MSasV1 => {
let content = match from_value::<MSasV1Content>(value) {
Ok(content) => content,
Err(error) => return Err(D::Error::custom(error.to_string())),
};
Ok(StartEventContent::MSasV1(content))
}
VerificationMethod::__Nonexhaustive => Err(D::Error::custom(
"Attempted to deserialize __Nonexhaustive variant.",
)),
}
}
}
}
/// Content of an `m.key.verification.start` event using the `m.sas.v1` method.
///
/// Build it with [`MSasV1Content::new`], which validates the algorithm lists. The
/// derived `Deserialize` impl performs no such validation by itself; deserialized
/// content is validated when `StartEventContent` is converted via `TryFromRaw`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct MSasV1Content {
    /// Device ID, serialized as `from_device`.
    pub(crate) from_device: DeviceId,
    /// Serialized as `transaction_id`.
    pub(crate) transaction_id: String,
    /// Offered key agreement protocols; must include `KeyAgreementProtocol::Curve25519`.
    pub(crate) key_agreement_protocols: Vec<KeyAgreementProtocol>,
    /// Offered hash algorithms; must include `HashAlgorithm::Sha256`.
    pub(crate) hashes: Vec<HashAlgorithm>,
    /// Offered MACs; must include `MessageAuthenticationCode::HkdfHmacSha256`.
    pub(crate) message_authentication_codes: Vec<MessageAuthenticationCode>,
    /// Offered SAS methods; must include `ShortAuthenticationString::Decimal`.
    pub(crate) short_authentication_string: Vec<ShortAuthenticationString>,
}
/// Unchecked inputs for [`MSasV1Content::new`].
///
/// Each field maps one-to-one onto the field of the same name in `MSasV1Content`.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct MSasV1ContentOptions {
    /// Device ID to store as `from_device`.
    pub from_device: DeviceId,
    /// Value to store as `transaction_id`.
    pub transaction_id: String,
    /// Must include `KeyAgreementProtocol::Curve25519`.
    pub key_agreement_protocols: Vec<KeyAgreementProtocol>,
    /// Must include `HashAlgorithm::Sha256`.
    pub hashes: Vec<HashAlgorithm>,
    /// Must include `MessageAuthenticationCode::HkdfHmacSha256`.
    pub message_authentication_codes: Vec<MessageAuthenticationCode>,
    /// Must include `ShortAuthenticationString::Decimal`.
    pub short_authentication_string: Vec<ShortAuthenticationString>,
}
impl MSasV1Content {
pub fn new(options: MSasV1ContentOptions) -> Result<Self, InvalidInput> {
if !options
.key_agreement_protocols
.contains(&KeyAgreementProtocol::Curve25519)
{
return Err(InvalidInput("`key_agreement_protocols` must contain at least `KeyAgreementProtocol::Curve25519`".to_string()));
}
if !options.hashes.contains(&HashAlgorithm::Sha256) {
return Err(InvalidInput(
"`hashes` must contain at least `HashAlgorithm::Sha256`".to_string(),
));
}
if !options
.message_authentication_codes
.contains(&MessageAuthenticationCode::HkdfHmacSha256)
{
return Err(InvalidInput("`message_authentication_codes` must contain at least `MessageAuthenticationCode::HkdfHmacSha256`".to_string()));
}
if !options
.short_authentication_string
.contains(&ShortAuthenticationString::Decimal)
{
return Err(InvalidInput("`short_authentication_string` must contain at least `ShortAuthenticationString::Decimal`".to_string()));
}
Ok(Self {
from_device: options.from_device,
transaction_id: options.transaction_id,
key_agreement_protocols: options.key_agreement_protocols,
hashes: options.hashes,
message_authentication_codes: options.message_authentication_codes,
short_authentication_string: options.short_authentication_string,
})
}
}
impl Serialize for MSasV1Content {
    /// Serializes the content as a struct.
    ///
    /// A constant `"method": "m.sas.v1"` entry is inserted after `transaction_id`,
    /// because the method is implied by the type rather than stored in a field.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `len` must equal the number of `serialize_field` calls below: the six
        // stored fields plus the synthesized `method`. Serde's contract requires
        // an accurate count, and some formats may rely on it.
        let mut state = serializer.serialize_struct("MSasV1Content", 7)?;
        state.serialize_field("from_device", &self.from_device)?;
        state.serialize_field("transaction_id", &self.transaction_id)?;
        state.serialize_field("method", "m.sas.v1")?;
        state.serialize_field("key_agreement_protocols", &self.key_agreement_protocols)?;
        state.serialize_field("hashes", &self.hashes)?;
        state.serialize_field(
            "message_authentication_codes",
            &self.message_authentication_codes,
        )?;
        state.serialize_field(
            "short_authentication_string",
            &self.short_authentication_string,
        )?;
        state.end()
    }
}
#[cfg(test)]
mod tests {
    use serde_json::{from_str, to_string};

    use super::{
        HashAlgorithm, KeyAgreementProtocol, MSasV1Content, MSasV1ContentOptions,
        MessageAuthenticationCode, ShortAuthenticationString, StartEvent, StartEventContent,
    };
    use crate::EventResult;

    /// Options satisfying every `m.sas.v1` requirement. Individual tests override a
    /// single field to provoke one specific validation error.
    fn valid_options() -> MSasV1ContentOptions {
        MSasV1ContentOptions {
            from_device: "123".to_string(),
            transaction_id: "456".to_string(),
            hashes: vec![HashAlgorithm::Sha256],
            key_agreement_protocols: vec![KeyAgreementProtocol::Curve25519],
            message_authentication_codes: vec![MessageAuthenticationCode::HkdfHmacSha256],
            short_authentication_string: vec![ShortAuthenticationString::Decimal],
        }
    }

    /// Validated content built from `valid_options`.
    fn valid_content() -> StartEventContent {
        StartEventContent::MSasV1(MSasV1Content::new(valid_options()).unwrap())
    }

    /// Runs `MSasV1Content::new` on `options` and returns the rendered error text.
    fn construction_error(options: MSasV1ContentOptions) -> String {
        MSasV1Content::new(options).err().unwrap().to_string()
    }

    /// Asserts that `json` parses but fails content validation with a message naming `field`.
    fn assert_content_validation_error(json: &str, field: &str) {
        let error = from_str::<EventResult<StartEventContent>>(json)
            .unwrap()
            .into_result()
            .unwrap_err();
        assert!(error.message().contains(field));
        assert!(error.is_validation());
    }

    #[test]
    fn invalid_m_sas_v1_content_missing_required_key_agreement_protocols() {
        let error = construction_error(MSasV1ContentOptions {
            key_agreement_protocols: vec![],
            ..valid_options()
        });
        assert!(error.contains("key_agreement_protocols"));
    }

    #[test]
    fn invalid_m_sas_v1_content_missing_required_hashes() {
        let error = construction_error(MSasV1ContentOptions {
            hashes: vec![],
            ..valid_options()
        });
        assert!(error.contains("hashes"));
    }

    #[test]
    fn invalid_m_sas_v1_content_missing_required_message_authentication_codes() {
        let error = construction_error(MSasV1ContentOptions {
            message_authentication_codes: vec![],
            ..valid_options()
        });
        assert!(error.contains("message_authentication_codes"));
    }

    #[test]
    fn invalid_m_sas_v1_content_missing_required_short_authentication_string() {
        let error = construction_error(MSasV1ContentOptions {
            short_authentication_string: vec![],
            ..valid_options()
        });
        assert!(error.contains("short_authentication_string"));
    }

    #[test]
    fn serialization() {
        let key_verification_start = StartEvent {
            content: valid_content(),
        };
        assert_eq!(
            to_string(&key_verification_start).unwrap(),
            r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"#
        );
    }

    #[test]
    fn deserialization() {
        let parsed_content = from_str::<EventResult<StartEventContent>>(
            r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","hashes":["sha256"],"key_agreement_protocols":["curve25519"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"#,
        )
        .unwrap()
        .into_result()
        .unwrap();
        assert_eq!(parsed_content, valid_content());

        let parsed_event = from_str::<EventResult<StartEvent>>(
            r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"#,
        )
        .unwrap()
        .into_result()
        .unwrap();
        assert_eq!(
            parsed_event,
            StartEvent {
                content: valid_content(),
            }
        );
    }

    #[test]
    fn deserialization_failure() {
        assert!(from_str::<EventResult<StartEventContent>>("{").is_err());
    }

    #[test]
    fn deserialization_structure_mismatch() {
        let error = from_str::<EventResult<StartEventContent>>(r#"{"from_device":"123"}"#)
            .unwrap()
            .into_result()
            .unwrap_err();
        assert!(error.message().contains("missing field"));
        assert!(error.is_deserialization());
    }

    #[test]
    fn deserialization_validation_missing_required_key_agreement_protocols() {
        assert_content_validation_error(
            r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"#,
            "key_agreement_protocols",
        );
    }

    #[test]
    fn deserialization_validation_missing_required_hashes() {
        assert_content_validation_error(
            r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":[],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]}"#,
            "hashes",
        );
    }

    #[test]
    fn deserialization_validation_missing_required_message_authentication_codes() {
        assert_content_validation_error(
            r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":[],"short_authentication_string":["decimal"]}"#,
            "message_authentication_codes",
        );
    }

    #[test]
    fn deserialization_validation_missing_required_short_authentication_string() {
        assert_content_validation_error(
            r#"{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":["curve25519"],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":[]}"#,
            "short_authentication_string",
        );
    }

    #[test]
    fn deserialization_of_event_validates_content() {
        let error = from_str::<EventResult<StartEvent>>(
            r#"{"content":{"from_device":"123","transaction_id":"456","method":"m.sas.v1","key_agreement_protocols":[],"hashes":["sha256"],"message_authentication_codes":["hkdf-hmac-sha256"],"short_authentication_string":["decimal"]},"type":"m.key.verification.start"}"#,
        )
        .unwrap()
        .into_result()
        .unwrap_err();
        assert!(error.message().contains("key_agreement_protocols"));
        assert!(error.is_validation());
    }
}