matrix_sdk_crypto/verification/
cache.rs1use std::{collections::BTreeMap, sync::Arc};
16
17use as_variant::as_variant;
18use matrix_sdk_common::locks::RwLock as StdRwLock;
19use ruma::{DeviceId, OwnedTransactionId, OwnedUserId, TransactionId, UserId};
20#[cfg(feature = "qrcode")]
21use tracing::debug;
22use tracing::{trace, warn};
23
24use super::{event_enums::OutgoingContent, FlowId, Sas, Verification};
25use crate::types::requests::{
26 OutgoingRequest, OutgoingVerificationRequest, RoomMessageRequest, ToDeviceRequest,
27};
28#[cfg(feature = "qrcode")]
29use crate::QrVerification;
30
/// Cache of verifications and of the outgoing requests they produce.
///
/// All state lives behind a single [`Arc`], so clones of a `VerificationCache`
/// share (and observe each other's changes to) the same underlying data.
#[derive(Clone, Debug, Default)]
pub struct VerificationCache {
    inner: Arc<VerificationCacheInner>,
}
35
// SAFETY: NOTE(review): no justification for this impl is given here. The
// cache's mutable state sits behind `StdRwLock`s inside an `Arc`, but this file
// does not establish that the stored `Verification` / `OutgoingRequest` values
// may be shared across threads — TODO document the invariant that makes this
// sound. With the `test-send-sync` feature enabled this impl is compiled out,
// and `test_send_sync_for_room` checks `SendOutsideWasm + SyncOutsideWasm`
// without it.
#[cfg(not(feature = "test-send-sync"))]
unsafe impl Sync for VerificationCache {}
39
#[cfg(feature = "test-send-sync")]
#[test]
fn test_send_sync_for_room() {
    /// Compiles only if `T` satisfies both marker bounds; the call below is
    /// the whole test.
    fn assert_send_sync<T>()
    where
        T: matrix_sdk_common::SendOutsideWasm + matrix_sdk_common::SyncOutsideWasm,
    {
    }

    // Checked without the `unsafe impl Sync`, which this feature compiles out.
    assert_send_sync::<VerificationCache>();
}
51
/// State shared by every clone of a [`VerificationCache`].
#[derive(Debug, Default)]
struct VerificationCacheInner {
    /// Verifications keyed by the other user's ID, then by flow ID.
    verification: StdRwLock<BTreeMap<OwnedUserId, BTreeMap<String, Verification>>>,
    /// Requests waiting to be sent, keyed by their request (transaction) ID.
    outgoing_requests: StdRwLock<BTreeMap<OwnedTransactionId, OutgoingRequest>>,
    /// For requests queued with a [`RequestInfo`]: the recipient and flow to
    /// look up when the request with this ID is marked as sent.
    flow_ids_waiting_for_response: StdRwLock<BTreeMap<OwnedTransactionId, (OwnedUserId, FlowId)>>,
}
58
/// Ties a queued request to the verification flow that produced it, so the
/// flow can be looked up again in [`VerificationCache::mark_request_as_sent`].
#[derive(Debug)]
pub struct RequestInfo {
    /// The flow the request belongs to.
    pub flow_id: FlowId,
    /// The ID the outgoing request is queued under.
    pub request_id: OwnedTransactionId,
}
64
65impl VerificationCache {
66 pub fn new() -> Self {
67 Self::default()
68 }
69
70 #[cfg(test)]
71 #[allow(dead_code)]
72 pub fn is_empty(&self) -> bool {
73 self.inner.verification.read().values().all(|m| m.is_empty())
74 }
75
76 pub fn insert(&self, verification: impl Into<Verification>) {
80 let verification = verification.into();
81
82 let mut verification_write_guard = self.inner.verification.write();
83 let user_verifications =
84 verification_write_guard.entry(verification.other_user().to_owned()).or_default();
85
86 for old_verification in user_verifications.values() {
90 if !old_verification.is_cancelled() {
91 warn!(
92 user_id = ?verification.other_user(),
93 old_flow_id = old_verification.flow_id(),
94 new_flow_id = verification.flow_id(),
95 "Received a new verification whilst another one with \
96 the same user is ongoing. Cancelling both verifications"
97 );
98
99 if let Some(r) = old_verification.cancel() {
100 self.add_request(r.into())
101 }
102
103 if let Some(r) = verification.cancel() {
104 self.add_request(r.into())
105 }
106 }
107 }
108
109 user_verifications.insert(verification.flow_id().to_owned(), verification);
113 }
114
115 pub fn insert_sas(&self, sas: Sas) {
116 self.insert(sas);
117 }
118
119 pub fn replace_sas(&self, sas: Sas) {
120 let verification: Verification = sas.into();
121 self.replace(verification);
122 }
123
124 #[cfg(feature = "qrcode")]
125 pub fn insert_qr(&self, qr: QrVerification) {
126 debug!(
127 user_id = ?qr.other_user_id(),
128 flow_id = qr.flow_id().as_str(),
129 "Inserting new QR verification"
130 );
131 self.insert(qr)
132 }
133
134 #[cfg(feature = "qrcode")]
135 pub fn replace_qr(&self, qr: QrVerification) {
136 debug!(
137 user_id = ?qr.other_user_id(),
138 flow_id = qr.flow_id().as_str(),
139 "Replacing existing QR verification"
140 );
141 let verification: Verification = qr.into();
142 self.replace(verification);
143 }
144
145 #[cfg(feature = "qrcode")]
146 pub fn get_qr(&self, sender: &UserId, flow_id: &str) -> Option<Box<QrVerification>> {
147 self.get(sender, flow_id).and_then(as_variant!(Verification::QrV1))
148 }
149
150 pub fn replace(&self, verification: Verification) {
151 self.inner
152 .verification
153 .write()
154 .entry(verification.other_user().to_owned())
155 .or_default()
156 .insert(verification.flow_id().to_owned(), verification.clone());
157 }
158
159 pub fn get(&self, sender: &UserId, flow_id: &str) -> Option<Verification> {
160 self.inner.verification.read().get(sender)?.get(flow_id).cloned()
161 }
162
163 pub fn outgoing_requests(&self) -> Vec<OutgoingRequest> {
164 self.inner.outgoing_requests.read().values().cloned().collect()
165 }
166
167 pub fn garbage_collect(&self) -> Vec<OutgoingVerificationRequest> {
168 let verification = &mut self.inner.verification.write();
169
170 for user_verification in verification.values_mut() {
171 user_verification.retain(|_, s| !(s.is_done() || s.is_cancelled()));
172 }
173
174 verification.retain(|_, m| !m.is_empty());
175
176 verification
177 .values()
178 .flat_map(BTreeMap::values)
179 .filter_map(|s| as_variant!(s, Verification::SasV1)?.cancel_if_timed_out())
180 .collect()
181 }
182
183 pub fn get_sas(&self, user_id: &UserId, flow_id: &str) -> Option<Box<Sas>> {
184 self.get(user_id, flow_id).and_then(as_variant!(Verification::SasV1))
185 }
186
187 pub fn add_request(&self, request: OutgoingRequest) {
188 trace!("Adding an outgoing request {:?}", request);
189 self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);
190 }
191
192 pub fn add_verification_request(&self, request: OutgoingVerificationRequest) {
193 let request = OutgoingRequest {
194 request_id: request.request_id().to_owned(),
195 request: Arc::new(request.into()),
196 };
197 self.add_request(request);
198 }
199
200 pub fn queue_up_content(
201 &self,
202 recipient: &UserId,
203 recipient_device: &DeviceId,
204 content: OutgoingContent,
205 request_info: Option<RequestInfo>,
206 ) {
207 let request_id = if let Some(request_info) = request_info {
208 trace!(
209 ?recipient,
210 ?request_info,
211 "Storing the request info, waiting for the request to be marked as sent"
212 );
213
214 self.inner.flow_ids_waiting_for_response.write().insert(
215 request_info.request_id.to_owned(),
216 (recipient.to_owned(), request_info.flow_id),
217 );
218 request_info.request_id
219 } else {
220 TransactionId::new()
221 };
222
223 match content {
224 OutgoingContent::ToDevice(c) => {
225 let request = ToDeviceRequest::with_id(
226 recipient,
227 recipient_device.to_owned(),
228 &c,
229 request_id,
230 );
231 let request_id = request.txn_id.clone();
232
233 let request = OutgoingRequest {
234 request_id: request_id.clone(),
235 request: Arc::new(request.into()),
236 };
237
238 self.inner.outgoing_requests.write().insert(request_id, request);
239 }
240
241 OutgoingContent::Room(r, c) => {
242 let request = OutgoingRequest {
243 request: Arc::new(
244 RoomMessageRequest { room_id: r, txn_id: request_id.clone(), content: c }
245 .into(),
246 ),
247 request_id: request_id.clone(),
248 };
249
250 self.inner.outgoing_requests.write().insert(request_id, request);
251 }
252 }
253 }
254
255 pub fn mark_request_as_sent(&self, request_id: &TransactionId) {
256 if let Some(request_id) = self.inner.outgoing_requests.write().remove(request_id) {
257 trace!(?request_id, "Marking a verification HTTP request as sent");
258 }
259
260 if let Some((user_id, flow_id)) =
261 self.inner.flow_ids_waiting_for_response.read().get(request_id)
262 {
263 if let Some(verification) = self.get(user_id, flow_id.as_str()) {
264 match verification {
265 Verification::SasV1(s) => s.mark_request_as_sent(request_id),
266 #[cfg(feature = "qrcode")]
267 Verification::QrV1(_) => (),
268 }
269 }
270 }
271 }
272}