1use std::borrow::Cow;
6
7use ruma_common::{
8 OwnedTransactionId,
9 serde::{Base64, JsonObject},
10};
11use ruma_macros::EventContent;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value as JsonValue, from_value as from_json_value};
14
15use super::{
16 HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
17};
18use crate::relation::Reference;
19
/// The content of a to-device `m.key.verification.accept` event.
///
/// The in-room counterpart of this type is [`KeyVerificationAcceptEventContent`].
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[ruma_event(type = "m.key.verification.accept", kind = ToDevice)]
pub struct ToDeviceKeyVerificationAcceptEventContent {
    /// The identifier of the verification transaction this event belongs to.
    pub transaction_id: OwnedTransactionId,

    /// The method-specific content.
    ///
    /// This field is flattened, so its fields appear at the top level of the event content,
    /// next to `transaction_id`.
    #[serde(flatten)]
    pub method: AcceptMethod,
}
36
37impl ToDeviceKeyVerificationAcceptEventContent {
38 pub fn new(transaction_id: OwnedTransactionId, method: AcceptMethod) -> Self {
41 Self { transaction_id, method }
42 }
43}
44
/// The content of an in-room `m.key.verification.accept` event.
///
/// Compared to the to-device counterpart, [`ToDeviceKeyVerificationAcceptEventContent`], this
/// type has no transaction ID. Instead it carries an `m.relates_to` reference.
#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
#[ruma_event(type = "m.key.verification.accept", kind = MessageLike)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct KeyVerificationAcceptEventContent {
    /// The method-specific content, flattened into the top level of the event content.
    #[serde(flatten)]
    pub method: AcceptMethod,

    /// A reference relation to another event, serialized under the `m.relates_to` key.
    ///
    /// NOTE(review): presumably this points at the event that started the verification.
    /// Confirm against the Matrix spec.
    #[serde(rename = "m.relates_to")]
    pub relates_to: Reference,
}
60
61impl KeyVerificationAcceptEventContent {
62 pub fn new(method: AcceptMethod, relates_to: Reference) -> Self {
65 Self { method, relates_to }
66 }
67}
68
/// The method-specific part of an `m.key.verification.accept` event's content.
///
/// Serialization is untagged: the active variant's fields are written directly into the
/// surrounding object, without a discriminator.
///
/// `Deserialize` is implemented by hand rather than derived. The object is first tried as
/// [`SasV1Content`]. If that fails, it is kept verbatim as custom data.
#[derive(Clone, Debug, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
#[serde(untagged)]
pub enum AcceptMethod {
    /// The `m.sas.v1` verification method.
    SasV1(SasV1Content),

    /// Content that did not deserialize as any known method.
    ///
    /// Use [`AcceptMethod::data`] to inspect its fields.
    #[doc(hidden)]
    _Custom(_CustomAcceptMethodContent),
}
81
82impl AcceptMethod {
83 pub fn data(&self) -> Cow<'_, JsonObject> {
88 fn serialize<T: Serialize>(obj: T) -> JsonObject {
89 match serde_json::to_value(obj).expect("accept method serialization to succeed") {
90 JsonValue::Object(mut obj) => {
91 obj.remove("method");
92 obj
93 }
94 _ => panic!("all accept method variants must serialize to objects"),
95 }
96 }
97
98 match self {
99 Self::SasV1(c) => Cow::Owned(serialize(c)),
100 Self::_Custom(c) => Cow::Borrowed(&c.data),
101 }
102 }
103}
104
105impl<'de> Deserialize<'de> for AcceptMethod {
106 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107 where
108 D: serde::Deserializer<'de>,
109 {
110 let data = JsonObject::deserialize(deserializer)?;
111
112 Ok(match from_json_value(data.clone().into()) {
113 Ok(sas_v1_content) => AcceptMethod::SasV1(sas_v1_content),
114 Err(_) => AcceptMethod::_Custom(_CustomAcceptMethodContent { data }),
115 })
116 }
117}
118
/// The content of an accept method that this crate does not know.
///
/// This type is an implementation detail of [`AcceptMethod`]. Use [`AcceptMethod::data`] to
/// read its fields.
#[doc(hidden)]
#[derive(Clone, Debug, Serialize)]
pub struct _CustomAcceptMethodContent {
    /// The raw JSON fields of the content.
    ///
    /// These are serialized inline (flattened) into the surrounding object.
    #[serde(flatten)]
    data: JsonObject,
}
127
/// The method-specific content of an accept that uses the `m.sas.v1` verification method.
///
/// This struct may be `#[non_exhaustive]`, so code outside this crate can build it from a
/// [`SasV1ContentInit`] through the `From` conversion.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
pub struct SasV1Content {
    /// The key agreement protocol to use.
    pub key_agreement_protocol: KeyAgreementProtocol,

    /// The hash algorithm to use.
    pub hash: HashAlgorithm,

    /// The message authentication code algorithm to use.
    pub message_authentication_code: MessageAuthenticationCode,

    /// The short authentication string methods to use.
    ///
    /// NOTE(review): presumably this must be a subset of the methods offered when the
    /// verification was started. Confirm against the Matrix spec.
    pub short_authentication_string: Vec<ShortAuthenticationString>,

    /// The commitment value, serialized as unpadded base64 (`b"hello"` becomes `"aGVsbG8"`).
    ///
    /// NOTE(review): presumably a hash over the accepting device's ephemeral public key and
    /// the start message. Confirm against the Matrix spec.
    pub commitment: Base64,
}
156
/// The mandatory set of fields for creating a [`SasV1Content`].
///
/// Unlike `SasV1Content`, this struct is always exhaustive. It can therefore be built with a
/// struct literal anywhere, then converted with `From`/`Into`.
#[derive(Debug)]
// Exhaustive on purpose: the point of this type is struct-literal construction by callers.
#[allow(clippy::exhaustive_structs)]
pub struct SasV1ContentInit {
    /// The key agreement protocol to use.
    pub key_agreement_protocol: KeyAgreementProtocol,

    /// The hash algorithm to use.
    pub hash: HashAlgorithm,

    /// The message authentication code algorithm to use.
    pub message_authentication_code: MessageAuthenticationCode,

    /// The short authentication string methods to use.
    ///
    /// NOTE(review): presumably this must be a subset of the methods offered when the
    /// verification was started. Confirm against the Matrix spec.
    pub short_authentication_string: Vec<ShortAuthenticationString>,

    /// The commitment value, serialized as unpadded base64.
    ///
    /// NOTE(review): presumably a hash over the accepting device's ephemeral public key and
    /// the start message. Confirm against the Matrix spec.
    pub commitment: Base64,
}
184
185impl From<SasV1ContentInit> for SasV1Content {
186 fn from(init: SasV1ContentInit) -> Self {
188 SasV1Content {
189 hash: init.hash,
190 key_agreement_protocol: init.key_agreement_protocol,
191 message_authentication_code: init.message_authentication_code,
192 short_authentication_string: init.short_authentication_string,
193 commitment: init.commitment,
194 }
195 }
196}
197
#[cfg(test)]
mod tests {
    // Serialization and deserialization tests for both the to-device and in-room forms of
    // `m.key.verification.accept`. They cover the known `m.sas.v1` method and the custom
    // fallback.

    use assert_matches2::{assert_let, assert_matches};
    use ruma_common::{
        canonical_json::assert_to_canonical_json_eq,
        event_id,
        serde::{Base64, Raw},
    };
    use serde_json::{Value as JsonValue, from_value as from_json_value, json};

    use super::{
        AcceptMethod, HashAlgorithm, KeyAgreementProtocol, KeyVerificationAcceptEventContent,
        MessageAuthenticationCode, SasV1Content, ShortAuthenticationString,
        ToDeviceKeyVerificationAcceptEventContent,
    };
    use crate::{ToDeviceEvent, relation::Reference};

    /// The SAS fields are flattened next to `transaction_id`.
    #[test]
    fn to_device_serialization() {
        let key_verification_accept_content = ToDeviceKeyVerificationAcceptEventContent {
            transaction_id: "456".into(),
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "transaction_id": "456",
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
            }),
        );
    }

    /// The SAS fields are flattened next to `m.relates_to`.
    #[test]
    fn in_room_serialization() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let key_verification_accept_content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
                "m.relates_to": {
                    "rel_type": "m.reference",
                    "event_id": event_id,
                },
            }),
        );
    }

    /// Valid SAS content is recognized both as bare content and inside a full event.
    #[test]
    fn to_device_deserialization() {
        let json = json!({
            "transaction_id": "456",
            "commitment": "aGVsbG8",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"]
        });

        let content = from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.transaction_id, "456");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        let json = json!({
            "content": {
                "commitment": "aGVsbG8",
                "transaction_id": "456",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"]
            },
            "type": "m.key.verification.accept",
            "sender": "@example:localhost",
        });

        let ev = from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationAcceptEventContent>>(json)
            .unwrap();
        assert_eq!(ev.content.transaction_id, "456");
        assert_eq!(ev.sender, "@example:localhost");

        assert_matches!(ev.content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    /// `m.relates_to` is split off and the remaining fields parse as SAS.
    #[test]
    fn in_room_deserialization() {
        let json = json!({
            "commitment": "aGVsbG8",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"],
            "m.relates_to": {
                "rel_type": "m.reference",
                "event_id": "$1598361704261elfgc:localhost",
            }
        });

        let content = from_json_value::<KeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.relates_to.event_id, "$1598361704261elfgc:localhost");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    /// Serializing then deserializing in-room content keeps the SAS variant and relation.
    #[test]
    fn in_room_serialization_roundtrip() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(SasV1Content {
                hash: HashAlgorithm::Sha256,
                key_agreement_protocol: KeyAgreementProtocol::Curve25519,
                message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
                short_authentication_string: vec![ShortAuthenticationString::Decimal],
                commitment: Base64::new(b"hello".to_vec()),
            }),
        };

        let json_content = Raw::new(&content).unwrap();
        let deser_content = json_content.deserialize().unwrap();

        assert_matches!(deser_content.method, AcceptMethod::SasV1(_));
        assert_eq!(deser_content.relates_to.event_id, event_id);
    }

    /// Unknown content is kept verbatim and serializes back to the same JSON.
    #[test]
    fn custom_to_device_serialization_roundtrip() {
        let json = json!({
            "transaction_id": "456",
            "test": "field",
        });

        let content =
            from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json.clone()).unwrap();

        assert_eq!(content.transaction_id, "456");
        let data = &*content.method.data();
        assert_eq!(data.len(), 1);
        assert_let!(Some(JsonValue::String(value)) = data.get("test"));
        assert_eq!(value, "field");

        assert_to_canonical_json_eq!(content, json);
    }

    /// `AcceptMethod::data` on the SAS variant exposes exactly the five SAS fields.
    #[test]
    fn sas_v1_method_data() {
        let method = AcceptMethod::SasV1(SasV1Content {
            hash: HashAlgorithm::Sha256,
            key_agreement_protocol: KeyAgreementProtocol::Curve25519,
            message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
            short_authentication_string: vec![ShortAuthenticationString::Decimal],
            commitment: Base64::new(b"hello".to_vec()),
        });

        let data = method.data().into_owned();
        assert_eq!(
            JsonValue::Object(data),
            json!({
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
            }),
        );
    }

    /// Content that looks like SAS but is malformed falls back to custom data without
    /// losing any fields, and serializes back unchanged.
    #[test]
    fn malformed_sas_v1_falls_back_to_custom() {
        // `short_authentication_string` must be an array, so this is not valid `m.sas.v1`
        // content.
        let json = json!({
            "transaction_id": "456",
            "commitment": "aGVsbG8",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": "decimal",
        });

        let content =
            from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json.clone()).unwrap();

        assert_eq!(content.transaction_id, "456");
        assert!(!matches!(content.method, AcceptMethod::SasV1(_)));

        let data = &*content.method.data();
        assert_eq!(data.len(), 5);
        assert_let!(Some(JsonValue::String(value)) = data.get("short_authentication_string"));
        assert_eq!(value, "decimal");

        assert_to_canonical_json_eq!(content, json);
    }
}