matrix_sdk_crypto/verification/cache.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{collections::BTreeMap, sync::Arc};
16
17use as_variant::as_variant;
18use matrix_sdk_common::locks::RwLock as StdRwLock;
19use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId};
20#[cfg(feature = "qrcode")]
21use tracing::debug;
22use tracing::{trace, warn};
23
24use super::{event_enums::OutgoingContent, FlowId, Sas, Verification};
25use crate::types::requests::{
26    OutgoingRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest,
27};
28#[cfg(feature = "qrcode")]
29use crate::QrVerification;
30
31#[derive(Clone, Debug, Default)]
32pub struct VerificationCache {
33    inner: Arc<VerificationCacheInner>,
34}
35
36// See https://github.com/matrix-org/matrix-rust-sdk/pull/3749#issuecomment-2312939823.
37#[cfg(not(feature = "test-send-sync"))]
38unsafe impl Sync for VerificationCache {}
39
#[cfg(feature = "test-send-sync")]
#[test]
// See https://github.com/matrix-org/matrix-rust-sdk/pull/3749#issuecomment-2312939823.
fn test_send_sync_for_room() {
    // Purely a compile-time check: this only type-checks when `T` satisfies
    // both bounds, so there is nothing to do at runtime.
    fn assert_send_sync<T>()
    where
        T: matrix_sdk_common::SendOutsideWasm + matrix_sdk_common::SyncOutsideWasm,
    {
    }

    // Despite the test's name, the type under test is `VerificationCache`.
    assert_send_sync::<VerificationCache>();
}
51
52#[derive(Debug, Default)]
53struct VerificationCacheInner {
54    verification: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<String, Verification>>>,
55    outgoing_requests: StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>,
56    flow_ids_waiting_for_response: StdRwLock<BTreeMap<OwnedTransactionId, (OwnedUserId, FlowId)>>,
57}
58
59#[derive(Debug)]
60pub struct RequestInfo {
61    pub flow_id: FlowId,
62    pub request_id: OwnedTransactionId,
63}
64
65impl VerificationCache {
66    pub fn new() -> Self {
67        Self::default()
68    }
69
70    #[cfg(test)]
71    #[allow(dead_code)]
72    pub fn is_empty(&self) -> bool {
73        self.inner.verification.read().values().all(|m| m.is_empty())
74    }
75
76    /// Add a new `Verification` object to the cache, this will cancel any
77    /// duplicates we have going on, including the newly inserted one, with a
78    /// given user.
79    pub fn insert(&self, verification: impl Into<Verification>) {
80        let verification = verification.into();
81
82        let mut verification_write_guard = self.inner.verification.write();
83        let user_verifications =
84            verification_write_guard.entry(verification.other_user().to_owned()).or_default();
85
86        // Cancel all the old verifications as well as the new one we have for
87        // this user if someone tries to have two verifications going on at
88        // once.
89        for old_verification in user_verifications.values() {
90            if !old_verification.is_cancelled() {
91                warn!(
92                    user_id = ?verification.other_user(),
93                    old_flow_id = old_verification.flow_id(),
94                    new_flow_id = verification.flow_id(),
95                    "Received a new verification whilst another one with \
96                    the same user is ongoing. Cancelling both verifications"
97                );
98
99                if let Some(r) = old_verification.cancel() {
100                    self.add_request(r.into())
101                }
102
103                if let Some(r) = verification.cancel() {
104                    self.add_request(r.into())
105                }
106            }
107        }
108
109        // We still want to add the new verification, in case users want to
110        // inspect the verification object a matching `m.key.verification.start`
111        // produced.
112        user_verifications.insert(verification.flow_id().to_owned(), verification);
113    }
114
115    pub fn insert_sas(&self, sas: Sas) {
116        self.insert(sas);
117    }
118
119    pub fn replace_sas(&self, sas: Sas) {
120        let verification: Verification = sas.into();
121        self.replace(verification);
122    }
123
124    #[cfg(feature = "qrcode")]
125    pub fn insert_qr(&self, qr: QrVerification) {
126        debug!(
127            user_id = ?qr.other_user_id(),
128            flow_id = qr.flow_id().as_str(),
129            "Inserting new QR verification"
130        );
131        self.insert(qr)
132    }
133
134    #[cfg(feature = "qrcode")]
135    pub fn replace_qr(&self, qr: QrVerification) {
136        debug!(
137            user_id = ?qr.other_user_id(),
138            flow_id = qr.flow_id().as_str(),
139            "Replacing existing QR verification"
140        );
141        let verification: Verification = qr.into();
142        self.replace(verification);
143    }
144
145    #[cfg(feature = "qrcode")]
146    pub fn get_qr(&self, sender: &UserId, flow_id: &str) -> Option<Box<QrVerification>> {
147        self.get(sender, flow_id).and_then(as_variant!(Verification::QrV1))
148    }
149
150    pub fn replace(&self, verification: Verification) {
151        self.inner
152            .verification
153            .write()
154            .entry(verification.other_user().to_owned())
155            .or_default()
156            .insert(verification.flow_id().to_owned(), verification.clone());
157    }
158
159    pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
160        self.inner.verification.read().get(sender)?.get(flow_id).cloned()
161    }
162
163    pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
164        self.inner.outgoing_requests.read().values().cloned().collect()
165    }
166
167    pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
168        let verification = &mut self.inner.verification.write();
169
170        for user_verification in verification.values_mut() {
171            user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
172        }
173
174        verification.retain(|_, m| !m.is_empty());
175
176        verification
177            .values()
178            .flat_map(BTreeMap::values)
179            .filter_map(|s| as_variant!(s, Verification::SasV1)?.cancel_if_timed_out())
180            .collect()
181    }
182
183    pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Box<Sas>> {
184        self.get(user_id, flow_id).and_then(as_variant!(Verification::SasV1))
185    }
186
187    pub fn add_request(&self, request: OutgoingRequest) {
188        trace!("Adding an outgoing request {:?}", request);
189        self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
190    }
191
192    pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
193        let request = OutgoingRequest {
194            request_id: request.request_id().to_owned(),
195            request: Arc::new(request.into()),
196        };
197        self.add_request(request);
198    }
199
200    pub fn queue_up_content(
201        &self,
202        recipient: &UserId,
203        recipient_device: &DeviceId,
204        content: OutgoingContent,
205        request_info: Option<RequestInfo>,
206    ) {
207        let request_id = if let Some(request_info) = request_info {
208            trace!(
209                ?recipient,
210                ?request_info,
211                "Storing the request info, waiting for the request to be marked as sent"
212            );
213
214            self.inner.flow_ids_waiting_for_response.write().insert(
215                request_info.request_id.to_owned(),
216                (recipient.to_owned(), request_info.flow_id),
217            );
218            request_info.request_id
219        } else {
220            TransactionId::new()
221        };
222
223        match content {
224            OutgoingContent::ToDevice(c) => {
225                let request = ToDeviceRequest::with_id(
226                    recipient,
227                    recipient_device.to_owned(),
228                    &c,
229                    request_id,
230                );
231                let request_id = request.txn_id.clone();
232
233                let request = OutgoingRequest {
234                    request_id: request_id.clone(),
235                    request: Arc::new(request.into()),
236                };
237
238                self.inner.outgoing_requests.write().insert(request_id, request);
239            }
240
241            OutgoingContent::Room(r, c) => {
242                let request = OutgoingRequest {
243                    request: Arc::new(
244                        RoomMessageRequest { room_id: r, txn_id: request_id.clone(), content: c }
245                            .into(),
246                    ),
247                    request_id: request_id.clone(),
248                };
249
250                self.inner.outgoing_requests.write().insert(request_id, request);
251            }
252        }
253    }
254
255    pub fn mark_request_as_sent(&self, request_id: &TransactionId) {
256        if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) {
257            trace!(?request_id, "Marking a verification HTTP request as sent");
258        }
259
260        if let Some((user_id, flow_id)) =
261            self.inner.flow_ids_waiting_for_response.read().get(request_id)
262        {
263            if let Some(verification) = self.get(user_id, flow_id.as_str()) {
264                match verification {
265                    Verification::SasV1(s) => s.mark_request_as_sent(request_id),
266                    #[cfg(feature = "qrcode")]
267                    Verification::QrV1(_) => (),
268                }
269            }
270        }
271    }
272}