Skip to main content

ruma_events/key/verification/
accept.rs

1//! Types for the [`m.key.verification.accept`] event.
2//!
3//! [`m.key.verification.accept`]: https://spec.matrix.org/v1.18/client-server-api/#mkeyverificationaccept
4
5use std::borrow::Cow;
6
7use ruma_common::{
8    OwnedTransactionId,
9    serde::{Base64, JsonObject},
10};
11use ruma_macros::EventContent;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value as JsonValue, from_value as from_json_value};
14
15use super::{
16    HashAlgorithm, KeyAgreementProtocol, MessageAuthenticationCode, ShortAuthenticationString,
17};
18use crate::relation::Reference;
19
20/// The content of a to-device `m.key.verification.accept` event.
21///
22/// Accepts a previously sent `m.key.verification.start` message.
23#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
24#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
25#[ruma_event(type = "m.key.verification.accept", kind = ToDevice)]
26pub struct ToDeviceKeyVerificationAcceptEventContent {
27    /// An opaque identifier for the verification process.
28    ///
29    /// Must be the same as the one used for the `m.key.verification.start` message.
30    pub transaction_id: OwnedTransactionId,
31
32    /// The method specific content.
33    #[serde(flatten)]
34    pub method: AcceptMethod,
35}
36
37impl ToDeviceKeyVerificationAcceptEventContent {
38    /// Creates a new `ToDeviceKeyVerificationAcceptEventContent` with the given transaction ID and
39    /// method-specific content.
40    pub fn new(transaction_id: OwnedTransactionId, method: AcceptMethod) -> Self {
41        Self { transaction_id, method }
42    }
43}
44
45/// The content of a in-room `m.key.verification.accept` event.
46///
47/// Accepts a previously sent `m.key.verification.start` message.
48#[derive(Clone, Debug, Deserialize, Serialize, EventContent)]
49#[ruma_event(type = "m.key.verification.accept", kind = MessageLike)]
50#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
51pub struct KeyVerificationAcceptEventContent {
52    /// The method specific content.
53    #[serde(flatten)]
54    pub method: AcceptMethod,
55
56    /// Information about the related event.
57    #[serde(rename = "m.relates_to")]
58    pub relates_to: Reference,
59}
60
61impl KeyVerificationAcceptEventContent {
62    /// Creates a new `KeyVerificationAcceptEventContent` with the given method-specific
63    /// content and reference.
64    pub fn new(method: AcceptMethod, relates_to: Reference) -> Self {
65        Self { method, relates_to }
66    }
67}
68
69/// An enum representing the different method specific `m.key.verification.accept` content.
70#[derive(Clone, Debug, Serialize)]
71#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
72#[serde(untagged)]
73pub enum AcceptMethod {
74    /// The `m.sas.v1` verification method.
75    SasV1(SasV1Content),
76
77    /// Any unknown accept method.
78    #[doc(hidden)]
79    _Custom(_CustomAcceptMethodContent),
80}
81
82impl AcceptMethod {
83    /// The data of this `AcceptMethod`.
84    ///
85    /// Prefer to use the public variants of `AcceptMethod` where possible; this method is meant to
86    /// be used for custom methods only.
87    pub fn data(&self) -> Cow<'_, JsonObject> {
88        fn serialize<T: Serialize>(obj: T) -> JsonObject {
89            match serde_json::to_value(obj).expect("accept method serialization to succeed") {
90                JsonValue::Object(mut obj) => {
91                    obj.remove("method");
92                    obj
93                }
94                _ => panic!("all accept method variants must serialize to objects"),
95            }
96        }
97
98        match self {
99            Self::SasV1(c) => Cow::Owned(serialize(c)),
100            Self::_Custom(c) => Cow::Borrowed(&c.data),
101        }
102    }
103}
104
105impl<'de> Deserialize<'de> for AcceptMethod {
106    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
107    where
108        D: serde::Deserializer<'de>,
109    {
110        let data = JsonObject::deserialize(deserializer)?;
111
112        Ok(match from_json_value(data.clone().into()) {
113            Ok(sas_v1_content) => AcceptMethod::SasV1(sas_v1_content),
114            Err(_) => AcceptMethod::_Custom(_CustomAcceptMethodContent { data }),
115        })
116    }
117}
118
119/// Method specific content of a unknown key verification method.
120#[doc(hidden)]
121#[derive(Clone, Debug, Serialize)]
122pub struct _CustomAcceptMethodContent {
123    /// The additional fields that the method contains.
124    #[serde(flatten)]
125    data: JsonObject,
126}
127
128/// The payload of an `m.key.verification.accept` event using the `m.sas.v1` method.
129#[derive(Clone, Debug, Deserialize, Serialize)]
130#[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
131pub struct SasV1Content {
132    /// The key agreement protocol the device is choosing to use, out of the
133    /// options in the `m.key.verification.start` message.
134    pub key_agreement_protocol: KeyAgreementProtocol,
135
136    /// The hash method the device is choosing to use, out of the options in the
137    /// `m.key.verification.start` message.
138    pub hash: HashAlgorithm,
139
140    /// The message authentication code the device is choosing to use, out of
141    /// the options in the `m.key.verification.start` message.
142    pub message_authentication_code: MessageAuthenticationCode,
143
144    /// The SAS methods both devices involved in the verification process
145    /// understand.
146    ///
147    /// Must be a subset of the options in the `m.key.verification.start`
148    /// message.
149    pub short_authentication_string: Vec<ShortAuthenticationString>,
150
151    /// The hash (encoded as unpadded base64) of the concatenation of the
152    /// device's ephemeral public key (encoded as unpadded base64) and the
153    /// canonical JSON representation of the `m.key.verification.start` message.
154    pub commitment: Base64,
155}
156
157/// Mandatory initial set of fields for creating an accept `SasV1Content`.
158#[derive(Debug)]
159#[allow(clippy::exhaustive_structs)]
160pub struct SasV1ContentInit {
161    /// The key agreement protocol the device is choosing to use, out of the
162    /// options in the `m.key.verification.start` message.
163    pub key_agreement_protocol: KeyAgreementProtocol,
164
165    /// The hash method the device is choosing to use, out of the options in the
166    /// `m.key.verification.start` message.
167    pub hash: HashAlgorithm,
168
169    /// The message authentication codes that the accepting device understands.
170    pub message_authentication_code: MessageAuthenticationCode,
171
172    /// The SAS methods both devices involved in the verification process
173    /// understand.
174    ///
175    /// Must be a subset of the options in the `m.key.verification.start`
176    /// message.
177    pub short_authentication_string: Vec<ShortAuthenticationString>,
178
179    /// The hash (encoded as unpadded base64) of the concatenation of the
180    /// device's ephemeral public key (encoded as unpadded base64) and the
181    /// canonical JSON representation of the `m.key.verification.start` message.
182    pub commitment: Base64,
183}
184
185impl From<SasV1ContentInit> for SasV1Content {
186    /// Creates a new `SasV1Content` from the given init struct.
187    fn from(init: SasV1ContentInit) -> Self {
188        SasV1Content {
189            hash: init.hash,
190            key_agreement_protocol: init.key_agreement_protocol,
191            message_authentication_code: init.message_authentication_code,
192            short_authentication_string: init.short_authentication_string,
193            commitment: init.commitment,
194        }
195    }
196}
197
#[cfg(test)]
mod tests {
    use assert_matches2::{assert_let, assert_matches};
    use ruma_common::{
        canonical_json::assert_to_canonical_json_eq,
        event_id,
        serde::{Base64, Raw},
    };
    use serde_json::{Value as JsonValue, from_value as from_json_value, json};

    use super::{
        AcceptMethod, HashAlgorithm, KeyAgreementProtocol, KeyVerificationAcceptEventContent,
        MessageAuthenticationCode, SasV1Content, ShortAuthenticationString,
        ToDeviceKeyVerificationAcceptEventContent,
    };
    use crate::{ToDeviceEvent, relation::Reference};

    /// Builds the `m.sas.v1` content shared by the serialization tests.
    fn sas_v1_content() -> SasV1Content {
        SasV1Content {
            hash: HashAlgorithm::Sha256,
            key_agreement_protocol: KeyAgreementProtocol::Curve25519,
            message_authentication_code: MessageAuthenticationCode::HkdfHmacSha256V2,
            short_authentication_string: vec![ShortAuthenticationString::Decimal],
            commitment: Base64::new(b"hello".to_vec()),
        }
    }

    #[test]
    fn to_device_serialization() {
        let key_verification_accept_content = ToDeviceKeyVerificationAcceptEventContent {
            transaction_id: "456".into(),
            method: AcceptMethod::SasV1(sas_v1_content()),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "transaction_id": "456",
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
            }),
        );
    }

    #[test]
    fn in_room_serialization() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let key_verification_accept_content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(sas_v1_content()),
        };

        assert_to_canonical_json_eq!(
            key_verification_accept_content,
            json!({
                "commitment": "aGVsbG8",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"],
                "m.relates_to": {
                    "rel_type": "m.reference",
                    "event_id": event_id,
                },
            }),
        );
    }

    #[test]
    fn to_device_deserialization() {
        let json = json!({
            "transaction_id": "456",
            "commitment": "aGVsbG8",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"]
        });

        // Deserialize the content struct on its own first, then as part of a full event below.
        let content = from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.transaction_id, "456");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);

        let json = json!({
            "content": {
                "commitment": "aGVsbG8",
                "transaction_id": "456",
                "key_agreement_protocol": "curve25519",
                "hash": "sha256",
                "message_authentication_code": "hkdf-hmac-sha256.v2",
                "short_authentication_string": ["decimal"]
            },
            "type": "m.key.verification.accept",
            "sender": "@example:localhost",
        });

        let ev = from_json_value::<ToDeviceEvent<ToDeviceKeyVerificationAcceptEventContent>>(json)
            .unwrap();
        assert_eq!(ev.content.transaction_id, "456");
        assert_eq!(ev.sender, "@example:localhost");

        assert_matches!(ev.content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    #[test]
    fn in_room_deserialization() {
        let json = json!({
            "commitment": "aGVsbG8",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"],
            "m.relates_to": {
                "rel_type": "m.reference",
                "event_id": "$1598361704261elfgc:localhost",
            }
        });

        // Deserialize the content struct on its own, outside of a full event.
        let content = from_json_value::<KeyVerificationAcceptEventContent>(json).unwrap();
        assert_eq!(content.relates_to.event_id, "$1598361704261elfgc:localhost");

        assert_matches!(content.method, AcceptMethod::SasV1(sas));
        assert_eq!(sas.commitment.encode(), "aGVsbG8");
        assert_eq!(sas.hash, HashAlgorithm::Sha256);
        assert_eq!(sas.key_agreement_protocol, KeyAgreementProtocol::Curve25519);
        assert_eq!(sas.message_authentication_code, MessageAuthenticationCode::HkdfHmacSha256V2);
        assert_eq!(sas.short_authentication_string, vec![ShortAuthenticationString::Decimal]);
    }

    #[test]
    fn in_room_serialization_roundtrip() {
        let event_id = event_id!("$1598361704261elfgc:localhost");

        let content = KeyVerificationAcceptEventContent {
            relates_to: Reference { event_id: event_id.to_owned() },
            method: AcceptMethod::SasV1(sas_v1_content()),
        };

        let json_content = Raw::new(&content).unwrap();
        let deser_content = json_content.deserialize().unwrap();

        assert_matches!(deser_content.method, AcceptMethod::SasV1(_));
        assert_eq!(deser_content.relates_to.event_id, event_id);
    }

    #[test]
    fn custom_to_device_serialization_roundtrip() {
        let json = json!({
            "transaction_id": "456",
            "test": "field",
        });

        let content =
            from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json.clone()).unwrap();

        assert_eq!(content.transaction_id, "456");
        assert!(matches!(content.method, AcceptMethod::_Custom(_)));
        let data = &*content.method.data();
        assert_eq!(data.len(), 1);
        assert_let!(Some(JsonValue::String(value)) = data.get("test"));
        assert_eq!(value, "field");

        assert_to_canonical_json_eq!(content, json);
    }

    #[test]
    fn sas_v1_data() {
        let method = AcceptMethod::SasV1(sas_v1_content());
        let data = method.data();

        assert_eq!(data.len(), 5);
        assert_eq!(data.get("hash"), Some(&JsonValue::from("sha256")));
        assert_eq!(data.get("commitment"), Some(&JsonValue::from("aGVsbG8")));
        assert!(!data.contains_key("method"));
    }

    #[test]
    fn incomplete_sas_v1_deserializes_as_custom() {
        // `commitment` is missing, so this cannot be `m.sas.v1` content.
        let json = json!({
            "transaction_id": "456",
            "hash": "sha256",
            "key_agreement_protocol": "curve25519",
            "message_authentication_code": "hkdf-hmac-sha256.v2",
            "short_authentication_string": ["decimal"],
        });

        let content = from_json_value::<ToDeviceKeyVerificationAcceptEventContent>(json).unwrap();

        assert_eq!(content.transaction_id, "456");
        assert!(matches!(content.method, AcceptMethod::_Custom(_)));
        assert_eq!(content.method.data().len(), 4);
    }
}