matrix_sdk_crypto/session_manager/group_sessions/mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod share_strategy;
16
17use std::{
18    collections::{BTreeMap, BTreeSet},
19    fmt::Debug,
20    iter,
21    iter::zip,
22    sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28    deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30use ruma::{
31    events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
32    serde::Raw,
33    to_device::DeviceIdOrAllDevices,
34    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
35    UserId,
36};
37use serde::Serialize;
38pub(crate) use share_strategy::CollectRecipientsResult;
39pub use share_strategy::CollectStrategy;
40use tracing::{debug, error, info, instrument, trace, warn, Instrument};
41
42use crate::{
43    error::{EventError, MegolmResult, OlmResult},
44    identities::device::MaybeEncryptedRoomKey,
45    olm::{
46        InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
47        ShareInfo, ShareState,
48    },
49    store::{Changes, CryptoStoreWrapper, Result as StoreResult, Store},
50    types::{
51        events::{
52            room::encrypted::{RoomEncryptedEventContent, ToDeviceEncryptedEventContent},
53            room_key_bundle::RoomKeyBundleContent,
54            EventType,
55        },
56        requests::ToDeviceRequest,
57    },
58    Device, DeviceData, EncryptionSettings, OlmError,
59};
60
61#[derive(Clone, Debug)]
62pub(crate) struct GroupSessionCache {
63    store: Store,
64    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
65    /// A map from the request id to the group session that the request belongs
66    /// to. Used to mark requests belonging to the session as shared.
67    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
68}
69
70impl GroupSessionCache {
71    pub(crate) fn new(store: Store) -> Self {
72        Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
73    }
74
75    pub(crate) fn insert(&self, session: OutboundGroupSession) {
76        self.sessions.write().insert(session.room_id().to_owned(), session);
77    }
78
79    /// Either get a session for the given room from the cache or load it from
80    /// the store.
81    ///
82    /// # Arguments
83    ///
84    /// * `room_id` - The id of the room this session is used for.
85    pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
86        // Get the cached session, if there isn't one load one from the store
87        // and put it in the cache.
88        if let Some(s) = self.sessions.read().get(room_id) {
89            return Some(s.clone());
90        }
91
92        match self.store.get_outbound_group_session(room_id).await {
93            Ok(Some(s)) => {
94                {
95                    let mut sessions_being_shared = self.sessions_being_shared.write();
96                    for request_id in s.pending_request_ids() {
97                        sessions_being_shared.insert(request_id, s.clone());
98                    }
99                }
100
101                self.sessions.write().insert(room_id.to_owned(), s.clone());
102
103                Some(s)
104            }
105            Ok(None) => None,
106            Err(e) => {
107                error!("Couldn't restore an outbound group session: {e:?}");
108                None
109            }
110        }
111    }
112
113    /// Get an outbound group session for a room, if one exists.
114    ///
115    /// # Arguments
116    ///
117    /// * `room_id` - The id of the room for which we should get the outbound
118    ///   group session.
119    fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
120        self.sessions.read().get(room_id).cloned()
121    }
122
123    /// Returns whether any session is withheld with the given device and code.
124    fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
125        self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
126    }
127
128    fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
129        self.sessions_being_shared.write().remove(id)
130    }
131
132    fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
133        self.sessions_being_shared.write().insert(id, session);
134    }
135}
136
137#[derive(Debug, Clone)]
138pub(crate) struct GroupSessionManager {
139    /// Store for the encryption keys.
140    /// Persists all the encryption keys so a client can resume the session
141    /// without the need to create new keys.
142    store: Store,
143    /// The currently active outbound group sessions.
144    sessions: GroupSessionCache,
145}
146
147impl GroupSessionManager {
148    const MAX_TO_DEVICE_MESSAGES: usize = 250;
149
150    pub fn new(store: Store) -> Self {
151        Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
152    }
153
154    pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
155        if let Some(s) = self.sessions.get(room_id) {
156            s.invalidate_session();
157
158            let mut changes = Changes::default();
159            changes.outbound_group_sessions.push(s.clone());
160            self.store.save_changes(changes).await?;
161
162            Ok(true)
163        } else {
164            Ok(false)
165        }
166    }
167
168    pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
169        let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
170            return Ok(());
171        };
172
173        let no_olm = session.mark_request_as_sent(request_id);
174
175        let mut changes = Changes::default();
176
177        for (user_id, devices) in &no_olm {
178            for device_id in devices {
179                let device = self.store.get_device(user_id, device_id).await;
180
181                if let Ok(Some(device)) = device {
182                    device.mark_withheld_code_as_sent();
183                    changes.devices.changed.push(device.inner.clone());
184                } else {
185                    error!(
186                        ?request_id,
187                        "Marking to-device no olm as sent but device not found, might \
188                            have been deleted?"
189                    );
190                }
191            }
192        }
193
194        changes.outbound_group_sessions.push(session.clone());
195        self.store.save_changes(changes).await
196    }
197
198    #[cfg(test)]
199    pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
200        self.sessions.get(room_id)
201    }
202
203    pub async fn encrypt(
204        &self,
205        room_id: &RoomId,
206        event_type: &str,
207        content: &Raw<AnyMessageLikeEventContent>,
208    ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
209        let session =
210            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
211
212        assert!(!session.expired(), "Session expired");
213
214        let content = session.encrypt(event_type, content).await;
215
216        let mut changes = Changes::default();
217        changes.outbound_group_sessions.push(session);
218        self.store.save_changes(changes).await?;
219
220        Ok(content)
221    }
222
223    /// Create a new outbound group session.
224    ///
225    /// This also creates a matching inbound group session.
226    pub async fn create_outbound_group_session(
227        &self,
228        room_id: &RoomId,
229        settings: EncryptionSettings,
230        own_sender_data: SenderData,
231    ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
232        let (outbound, inbound) = self
233            .store
234            .static_account()
235            .create_group_session_pair(room_id, settings, own_sender_data)
236            .await
237            .map_err(|_| EventError::UnsupportedAlgorithm)?;
238
239        self.sessions.insert(outbound.clone());
240        Ok((outbound, inbound))
241    }
242
243    pub async fn get_or_create_outbound_session(
244        &self,
245        room_id: &RoomId,
246        settings: EncryptionSettings,
247        own_sender_data: SenderData,
248    ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
249        let outbound_session = self.sessions.get_or_load(room_id).await;
250
251        // If there is no session or the session has expired or is invalid,
252        // create a new one.
253        if let Some(s) = outbound_session {
254            if s.expired() || s.invalidated() {
255                self.create_outbound_group_session(room_id, settings, own_sender_data)
256                    .await
257                    .map(|(o, i)| (o, i.into()))
258            } else {
259                Ok((s, None))
260            }
261        } else {
262            self.create_outbound_group_session(room_id, settings, own_sender_data)
263                .await
264                .map(|(o, i)| (o, i.into()))
265        }
266    }
267
268    /// Encrypt the given group session key for the given devices and create
269    /// to-device requests that sends the encrypted content to them.
270    ///
271    /// See also [`encrypt_content_for_devices`] which is similar
272    /// but is not specific to group sessions, and does not return the
273    /// [`ShareInfo`] data.
274    async fn encrypt_session_for(
275        store: Arc<CryptoStoreWrapper>,
276        group_session: OutboundGroupSession,
277        devices: Vec<DeviceData>,
278    ) -> OlmResult<(
279        EncryptForDevicesResult,
280        BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
281    )> {
282        // Use a named type instead of a tuple with rather long type name
283        pub struct DeviceResult {
284            device: DeviceData,
285            maybe_encrypted_room_key: MaybeEncryptedRoomKey,
286        }
287
288        let mut result_builder = EncryptForDevicesResultBuilder::default();
289        let mut share_infos = BTreeMap::new();
290
291        // XXX is there a way to do this that doesn't involve cloning the
292        // `Arc<CryptoStoreWrapper>` for each device?
293        let encrypt = |store: Arc<CryptoStoreWrapper>,
294                       device: DeviceData,
295                       session: OutboundGroupSession| async move {
296            let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
297
298            Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
299        };
300
301        let tasks: Vec<_> = devices
302            .iter()
303            .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
304            .collect();
305
306        let results = join_all(tasks).await;
307
308        for result in results {
309            let result = result.expect("Encryption task panicked")?;
310
311            match result.maybe_encrypted_room_key {
312                MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
313                    result_builder.on_successful_encryption(&result.device, *used_session, message);
314
315                    let user_id = result.device.user_id().to_owned();
316                    let device_id = result.device.device_id().to_owned();
317                    share_infos
318                        .entry(user_id)
319                        .or_insert_with(BTreeMap::new)
320                        .insert(device_id, *share_info);
321                }
322                MaybeEncryptedRoomKey::MissingSession => {
323                    result_builder.on_missing_session(result.device);
324                }
325            }
326        }
327
328        Ok((result_builder.into_result(), share_infos))
329    }
330
331    /// Given a list of user and an outbound session, return the list of users
332    /// and their devices that this session should be shared with.
333    ///
334    /// Returns information indicating whether the session needs to be rotated
335    /// and the list of users/devices that should receive or not the session
336    /// (with withheld reason).
337    #[instrument(skip_all)]
338    pub async fn collect_session_recipients(
339        &self,
340        users: impl Iterator<Item = &UserId>,
341        settings: &EncryptionSettings,
342        outbound: &OutboundGroupSession,
343    ) -> OlmResult<CollectRecipientsResult> {
344        share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
345    }
346
347    async fn encrypt_request(
348        store: Arc<CryptoStoreWrapper>,
349        chunk: Vec<DeviceData>,
350        outbound: OutboundGroupSession,
351        sessions: GroupSessionCache,
352    ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
353        let (result, share_infos) =
354            Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
355
356        if let Some(request) = result.to_device_request {
357            let id = request.txn_id.clone();
358            outbound.add_request(id.clone(), request.into(), share_infos);
359            sessions.mark_as_being_shared(id, outbound.clone());
360        }
361
362        Ok((result.updated_olm_sessions, result.no_olm_devices))
363    }
364
365    pub(crate) fn session_cache(&self) -> GroupSessionCache {
366        self.sessions.clone()
367    }
368
369    async fn maybe_rotate_group_session(
370        &self,
371        should_rotate: bool,
372        room_id: &RoomId,
373        outbound: OutboundGroupSession,
374        encryption_settings: EncryptionSettings,
375        changes: &mut Changes,
376        own_device: Option<Device>,
377    ) -> OlmResult<OutboundGroupSession> {
378        Ok(if should_rotate {
379            let old_session_id = outbound.session_id();
380
381            let (outbound, mut inbound) = self
382                .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
383                .await?;
384
385            // Use our own device info to populate the SenderData that validates the
386            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
387            // are sending out.
388            let own_sender_data = if let Some(device) = own_device {
389                SenderDataFinder::find_using_device_data(
390                    &self.store,
391                    device.inner.clone(),
392                    &inbound,
393                )
394                .await?
395            } else {
396                error!("Unable to find our own device!");
397                SenderData::unknown()
398            };
399            inbound.sender_data = own_sender_data;
400
401            changes.outbound_group_sessions.push(outbound.clone());
402            changes.inbound_group_sessions.push(inbound);
403
404            debug!(
405                old_session_id = old_session_id,
406                session_id = outbound.session_id(),
407                "A user or device has left the room since we last sent a \
408                message, or the encryption settings have changed. Rotating the \
409                room key.",
410            );
411
412            outbound
413        } else {
414            outbound
415        })
416    }
417
418    async fn encrypt_for_devices(
419        &self,
420        recipient_devices: Vec<DeviceData>,
421        group_session: &OutboundGroupSession,
422        changes: &mut Changes,
423    ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
424        // If we have some recipients, log them here.
425        if !recipient_devices.is_empty() {
426            let recipients = recipient_list_to_users_and_devices(&recipient_devices);
427
428            // If there are new recipients we need to persist the outbound group
429            // session as the to-device requests are persisted with the session.
430            changes.outbound_group_sessions = vec![group_session.clone()];
431
432            let message_index = group_session.message_index().await;
433
434            info!(
435                ?recipients,
436                message_index,
437                room_id = ?group_session.room_id(),
438                session_id = group_session.session_id(),
439                "Trying to encrypt a room key",
440            );
441        }
442
443        // Chunk the recipients out so each to-device request will contain a
444        // limited amount of to-device messages.
445        //
446        // Create concurrent tasks for each chunk of recipients.
447        let tasks: Vec<_> = recipient_devices
448            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
449            .map(|chunk| {
450                spawn(Self::encrypt_request(
451                    self.store.crypto_store(),
452                    chunk.to_vec(),
453                    group_session.clone(),
454                    self.sessions.clone(),
455                ))
456            })
457            .collect();
458
459        let mut withheld_devices = Vec::new();
460
461        // Wait for all the tasks to finish up and queue up the Olm session that
462        // was used to encrypt the room key to be persisted again. This is
463        // needed because each encryption step will mutate the Olm session,
464        // ratcheting its state forward.
465        for result in join_all(tasks).await {
466            let result = result.expect("Encryption task panicked");
467
468            let (used_sessions, failed_no_olm) = result?;
469
470            changes.sessions.extend(used_sessions);
471            withheld_devices.extend(failed_no_olm);
472        }
473
474        Ok(withheld_devices)
475    }
476
477    fn is_withheld_to(
478        &self,
479        group_session: &OutboundGroupSession,
480        device: &DeviceData,
481        code: &WithheldCode,
482    ) -> bool {
483        // The `m.no_olm` withheld code is special because it is supposed to be sent
484        // only once for a given device. The `Device` remembers the flag if we
485        // already sent a `m.no_olm` to this particular device so let's check
486        // that first.
487        //
488        // Keep in mind that any outbound group session might want to send this code to
489        // the device. So we need to check if any of our outbound group sessions
490        // is attempting to send the code to the device.
491        //
492        // This still has a slight race where some other thread might remove the
493        // outbound group session while a third is marking the device as having
494        // received the code.
495        //
496        // Since nothing terrible happens if we do end up sending the withheld code
497        // twice, and removing the race requires us to lock the store because the
498        // `OutboundGroupSession` and the `Device` both interact with the flag we'll
499        // leave it be.
500        if code == &WithheldCode::NoOlm {
501            device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
502        } else {
503            group_session.sharing_view().is_withheld_to(device, code)
504        }
505    }
506
507    fn handle_withheld_devices(
508        &self,
509        group_session: &OutboundGroupSession,
510        withheld_devices: Vec<(DeviceData, WithheldCode)>,
511    ) -> OlmResult<()> {
512        // Convert a withheld code for the group session into a to-device event content.
513        let to_content = |code| {
514            let content = group_session.withheld_code(code);
515            Raw::new(&content).expect("We can always serialize a withheld content info").cast()
516        };
517
518        // Helper to convert a chunk of device and withheld code pairs into a to-device
519        // request and it's accompanying share info.
520        let chunk_to_request = |chunk| {
521            let mut messages = BTreeMap::new();
522            let mut share_infos = BTreeMap::new();
523
524            for (device, code) in chunk {
525                let device: DeviceData = device;
526                let code: WithheldCode = code;
527
528                let user_id = device.user_id().to_owned();
529                let device_id = device.device_id().to_owned();
530
531                let share_info = ShareInfo::new_withheld(code.to_owned());
532                let content = to_content(code);
533
534                messages
535                    .entry(user_id.to_owned())
536                    .or_insert_with(BTreeMap::new)
537                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
538
539                share_infos
540                    .entry(user_id)
541                    .or_insert_with(BTreeMap::new)
542                    .insert(device_id, share_info);
543            }
544
545            let txn_id = TransactionId::new();
546
547            let request = ToDeviceRequest {
548                event_type: ToDeviceEventType::from("m.room_key.withheld"),
549                txn_id,
550                messages,
551            };
552
553            (request, share_infos)
554        };
555
556        let result: Vec<_> = withheld_devices
557            .into_iter()
558            .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
559            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
560            .into_iter()
561            .map(chunk_to_request)
562            .collect();
563
564        for (request, share_info) in result {
565            if !request.messages.is_empty() {
566                let txn_id = request.txn_id.to_owned();
567                group_session.add_request(txn_id.to_owned(), request.into(), share_info);
568
569                self.sessions.mark_as_being_shared(txn_id, group_session.clone());
570            }
571        }
572
573        Ok(())
574    }
575
576    fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
577        for request in requests {
578            let message_list = Self::to_device_request_to_log_list(request);
579            info!(
580                request_id = ?request.txn_id,
581                ?message_list,
582                "Created batch of to-device messages of type {}",
583                request.event_type
584            );
585        }
586    }
587
588    /// Given a to-device request, build a recipient map suitable for logging.
589    ///
590    /// Returns a list of triples of (message_id, user id, device_id).
591    fn to_device_request_to_log_list(
592        request: &Arc<ToDeviceRequest>,
593    ) -> Vec<(String, String, String)> {
594        #[derive(serde::Deserialize)]
595        struct ContentStub<'a> {
596            #[serde(borrow, default, rename = "org.matrix.msgid")]
597            message_id: Option<&'a str>,
598        }
599
600        let mut result: Vec<(String, String, String)> = Vec::new();
601
602        for (user_id, device_map) in &request.messages {
603            for (device, content) in device_map {
604                let message_id: Option<&str> = content
605                    .deserialize_as::<ContentStub<'_>>()
606                    .expect("We should be able to deserialize the content we generated")
607                    .message_id;
608
609                result.push((
610                    message_id.unwrap_or("<undefined>").to_owned(),
611                    user_id.to_string(),
612                    device.to_string(),
613                ));
614            }
615        }
616        result
617    }
618
619    /// Get to-device requests to share a room key with users in a room.
620    ///
621    /// # Arguments
622    ///
623    /// `room_id` - The room id of the room where the room key will be used.
624    ///
625    /// `users` - The list of users that should receive the room key.
626    ///
627    /// `encryption_settings` - The settings that should be used for
628    /// the room key.
629    #[instrument(skip(self, users, encryption_settings), fields(session_id))]
630    pub async fn share_room_key(
631        &self,
632        room_id: &RoomId,
633        users: impl Iterator<Item = &UserId>,
634        encryption_settings: impl Into<EncryptionSettings>,
635    ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
636        trace!("Checking if a room key needs to be shared");
637
638        let account = self.store.static_account();
639        let device = self.store.get_device(account.user_id(), account.device_id()).await?;
640
641        let encryption_settings = encryption_settings.into();
642        let mut changes = Changes::default();
643
644        // Try to get an existing session or create a new one.
645        let (outbound, inbound) = self
646            .get_or_create_outbound_session(
647                room_id,
648                encryption_settings.clone(),
649                SenderData::unknown(),
650            )
651            .await?;
652        tracing::Span::current().record("session_id", outbound.session_id());
653
654        // Having an inbound group session here means that we created a new
655        // group session pair, which we then need to store.
656        if let Some(mut inbound) = inbound {
657            // Use our own device info to populate the SenderData that validates the
658            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
659            // are sending out.
660            let own_sender_data = if let Some(device) = &device {
661                SenderDataFinder::find_using_device_data(
662                    &self.store,
663                    device.inner.clone(),
664                    &inbound,
665                )
666                .await?
667            } else {
668                error!("Unable to find our own device!");
669                SenderData::unknown()
670            };
671            inbound.sender_data = own_sender_data;
672
673            changes.outbound_group_sessions.push(outbound.clone());
674            changes.inbound_group_sessions.push(inbound);
675        }
676
677        // Collect the recipient devices and check if either the settings
678        // or the recipient list changed in a way that requires the
679        // session to be rotated.
680        let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
681            self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
682
683        let outbound = self
684            .maybe_rotate_group_session(
685                should_rotate,
686                room_id,
687                outbound,
688                encryption_settings,
689                &mut changes,
690                device,
691            )
692            .await?;
693
694        // Filter out the devices that already received this room key or have a
695        // to-device message already queued up.
696        let devices: Vec<_> = devices
697            .into_iter()
698            .flat_map(|(_, d)| {
699                d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
700                    ShareState::NotShared => true,
701                    ShareState::Shared { message_index: _, olm_wedging_index } => {
702                        // If the recipient device's Olm wedging index is higher
703                        // than the value that we stored with the session, that
704                        // means that they tried to unwedge the session since we
705                        // last shared the room key.  So we re-share it with
706                        // them in case they weren't able to decrypt the room
707                        // key the last time we shared it.
708                        olm_wedging_index < d.olm_wedging_index
709                    }
710                    _ => false,
711                })
712            })
713            .collect();
714
715        // The `encrypt_for_devices()` method adds the to-device requests that will send
716        // out the room key to the `OutboundGroupSession`. It doesn't do that
717        // for the m.room_key_withheld events since we might have more of those
718        // coming from the `collect_session_recipients()` method. Instead they get
719        // returned by the method.
720        let unable_to_encrypt_devices =
721            self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
722
723        // Merge the withheld recipients.
724        withheld_devices.extend(unable_to_encrypt_devices);
725
726        // Now handle and add the withheld recipients to the resulting requests to the
727        // `OutboundGroupSession`.
728        self.handle_withheld_devices(&outbound, withheld_devices)?;
729
730        // The to-device requests get added to the outbound group session, this
731        // way we're making sure that they are persisted and scoped to the
732        // session.
733        let requests = outbound.pending_requests();
734
735        if requests.is_empty() {
736            if !outbound.shared() {
737                debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
738
739                outbound.mark_as_shared();
740                changes.outbound_group_sessions.push(outbound.clone());
741            }
742        } else {
743            Self::log_room_key_sharing_result(&requests)
744        }
745
746        // Persist any changes we might have collected.
747        if !changes.is_empty() {
748            let session_count = changes.sessions.len();
749
750            self.store.save_changes(changes).await?;
751
752            trace!(
753                session_count = session_count,
754                "Stored the changed sessions after encrypting an room key"
755            );
756        }
757
758        Ok(requests)
759    }
760
761    /// Collect the devices belonging to the given user, and send the details of
762    /// a room key bundle to those devices.
763    ///
764    /// Returns a list of to-device requests which must be sent.
765    ///
766    /// For security reasons, only "safe" [`CollectStrategy`]s are supported, in
767    /// which the recipient must have signed their
768    /// devices. [`CollectStrategy::AllDevices`] and
769    /// [`CollectStrategy::ErrorOnVerifiedUserProblem`] are "unsafe" in this
770    /// respect,and are treated the same as
771    /// [`CollectStrategy::IdentityBasedStrategy`].
772    #[instrument(skip(self, bundle_data))]
773    pub async fn share_room_key_bundle_data(
774        &self,
775        user_id: &UserId,
776        collect_strategy: &CollectStrategy,
777        bundle_data: RoomKeyBundleContent,
778    ) -> OlmResult<Vec<ToDeviceRequest>> {
779        // Only allow conservative sharing strategies
780        let collect_strategy = match collect_strategy {
781            CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
782                warn!(
783                    "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
784                     for room key history sharing",
785                );
786                &CollectStrategy::IdentityBasedStrategy
787            }
788            CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
789                collect_strategy
790            }
791        };
792
793        let mut changes = Changes::default();
794
795        let CollectRecipientsResult { devices, .. } =
796            share_strategy::collect_recipients_for_share_strategy(
797                &self.store,
798                iter::once(user_id),
799                collect_strategy,
800                None,
801            )
802            .await?;
803
804        let devices = devices.into_values().flatten().collect();
805        let event_type = bundle_data.event_type().to_owned();
806        let (requests, _) = self
807            .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
808            .await?;
809
810        // TODO: figure out what to do with withheld devices
811
812        // Persist any changes we might have collected.
813        if !changes.is_empty() {
814            let session_count = changes.sessions.len();
815
816            self.store.save_changes(changes).await?;
817
818            trace!(
819                session_count = session_count,
820                "Stored the changed sessions after encrypting an room key"
821            );
822        }
823
824        Ok(requests)
825    }
826
827    /// Encrypt the given content for the given devices and build to-device
828    /// requests to send the encrypted content to them.
829    ///
830    /// Returns a tuple containing (1) the list of to-device requests, and (2)
831    /// the list of devices that we could not find an olm session for (so
832    /// need a withheld message).
833    pub(crate) async fn encrypt_content_for_devices(
834        &self,
835        recipient_devices: Vec<DeviceData>,
836        event_type: &str,
837        content: impl Serialize + Clone + Send + 'static,
838        changes: &mut Changes,
839    ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
840        let recipients = recipient_list_to_users_and_devices(&recipient_devices);
841        info!(?recipients, "Encrypting content of type {}", event_type);
842
843        // Chunk the recipients out so each to-device request will contain a
844        // limited amount of to-device messages.
845        //
846        // Create concurrent tasks for each chunk of recipients.
847        let tasks: Vec<_> = recipient_devices
848            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
849            .map(|chunk| {
850                spawn(
851                    encrypt_content_for_devices(
852                        self.store.crypto_store(),
853                        event_type.to_owned(),
854                        content.clone(),
855                        chunk.to_vec(),
856                    )
857                    .in_current_span(),
858                )
859            })
860            .collect();
861
862        let mut no_olm_devices = Vec::new();
863        let mut to_device_requests = Vec::new();
864
865        // Wait for all the tasks to finish up and queue up the Olm session that
866        // was used to encrypt the room key to be persisted again. This is
867        // needed because each encryption step will mutate the Olm session,
868        // ratcheting its state forward.
869        for result in join_all(tasks).await {
870            let result = result.expect("Encryption task panicked")?;
871            if let Some(request) = result.to_device_request {
872                to_device_requests.push(request);
873            }
874            changes.sessions.extend(result.updated_olm_sessions);
875            no_olm_devices.extend(result.no_olm_devices);
876        }
877
878        Ok((to_device_requests, no_olm_devices))
879    }
880}
881
882/// Helper for [`GroupSessionManager::encrypt_content_for_devices`].
883///
884/// Encrypt the given content for the given devices and build a to-device
885/// request to send the encrypted content to them.
886///
887/// See also [`GroupSessionManager::encrypt_session_for`], which is similar
888/// but applies specifically to `m.room_key` messages that hold a megolm
889/// session key.
890async fn encrypt_content_for_devices(
891    store: Arc<CryptoStoreWrapper>,
892    event_type: String,
893    content: impl Serialize + Clone + Send + 'static,
894    devices: Vec<DeviceData>,
895) -> OlmResult<EncryptForDevicesResult> {
896    let mut result_builder = EncryptForDevicesResultBuilder::default();
897
898    async fn encrypt(
899        store: Arc<CryptoStoreWrapper>,
900        device: DeviceData,
901        event_type: String,
902        bundle_data: impl Serialize,
903    ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
904        device.encrypt(store.as_ref(), &event_type, bundle_data).await
905    }
906
907    let tasks = devices.iter().map(|device| {
908        spawn(
909            encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
910                .in_current_span(),
911        )
912    });
913
914    let results = join_all(tasks).await;
915
916    for (device, result) in zip(devices, results) {
917        let encryption_result = result.expect("Encryption task panicked");
918
919        match encryption_result {
920            Ok((used_session, message)) => {
921                result_builder.on_successful_encryption(&device, used_session, message.cast());
922            }
923            Err(OlmError::MissingSession) => {
924                // There is no established Olm session for this device
925                result_builder.on_missing_session(device);
926            }
927            Err(e) => return Err(e),
928        }
929    }
930
931    Ok(result_builder.into_result())
932}
933
/// Result of [`GroupSessionManager::encrypt_session_for`] and
/// [`encrypt_content_for_devices`].
///
/// Describes the outcome of encrypting one payload for a batch of devices.
/// It is typically assembled with [`EncryptForDevicesResultBuilder`].
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// The request that sends the encrypted to-device messages.
    ///
    /// This is `None` when no message was successfully encrypted, e.g. when
    /// the batch was empty or every device lacked an Olm session.
    to_device_request: Option<ToDeviceRequest>,

    /// The devices which lack an Olm session and therefore need a withheld
    /// code, paired with that code.
    ///
    /// When recorded through
    /// [`EncryptForDevicesResultBuilder::on_missing_session`], the code is
    /// `WithheldCode::NoOlm`.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store, since encryption advances their state.
    updated_olm_sessions: Vec<Session>,
}
949
/// A helper for building an [`EncryptForDevicesResult`] incrementally, one
/// device outcome at a time.
///
/// It is finished with [`EncryptForDevicesResultBuilder::into_result`].
#[derive(Debug, Default)]
struct EncryptForDevicesResultBuilder {
    /// The encrypted payloads of the to-device messages, keyed first by the
    /// recipient's user ID and then by their device ID.
    messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,

    /// The devices which lack an Olm session and therefore need a withheld
    /// code, paired with that code.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store.
    updated_olm_sessions: Vec<Session>,
}
963
964impl EncryptForDevicesResultBuilder {
965    /// Record a successful encryption. The encrypted message is added to the
966    /// list to be sent, and the olm session is added to the list of those
967    /// that have been modified.
968    pub fn on_successful_encryption(
969        &mut self,
970        device: &DeviceData,
971        used_session: Session,
972        message: Raw<AnyToDeviceEventContent>,
973    ) {
974        self.updated_olm_sessions.push(used_session);
975
976        self.messages
977            .entry(device.user_id().to_owned())
978            .or_default()
979            .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
980    }
981
982    /// Record a device which didn't have an active Olm session.
983    pub fn on_missing_session(&mut self, device: DeviceData) {
984        self.no_olm_devices.push((device, WithheldCode::NoOlm));
985    }
986
987    /// Transform the accumulated results into an [`EncryptForDevicesResult`],
988    /// wrapping the messages, if any, into a `ToDeviceRequest`.
989    pub fn into_result(self) -> EncryptForDevicesResult {
990        let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
991            self;
992
993        let mut encrypt_for_devices_result = EncryptForDevicesResult {
994            to_device_request: None,
995            updated_olm_sessions,
996            no_olm_devices,
997        };
998
999        if !messages.is_empty() {
1000            let request = ToDeviceRequest {
1001                event_type: ToDeviceEventType::RoomEncrypted,
1002                txn_id: TransactionId::new(),
1003                messages,
1004            };
1005            trace!(
1006                recipient_count = request.message_count(),
1007                transaction_id = ?request.txn_id,
1008                "Created a to-device request carrying room keys",
1009            );
1010            encrypt_for_devices_result.to_device_request = Some(request);
1011        };
1012
1013        encrypt_for_devices_result
1014    }
1015}
1016
1017fn recipient_list_to_users_and_devices(
1018    recipient_devices: &[DeviceData],
1019) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1020    #[allow(unknown_lints, clippy::unwrap_or_default)] // false positive
1021    recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1022        acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1023        acc
1024    })
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029    use std::{
1030        collections::{BTreeMap, BTreeSet},
1031        iter,
1032        ops::Deref,
1033        sync::Arc,
1034    };
1035
1036    use assert_matches2::assert_let;
1037    use matrix_sdk_common::deserialized_responses::WithheldCode;
1038    use matrix_sdk_test::{async_test, ruma_response_from_json};
1039    use ruma::{
1040        api::client::{
1041            keys::{claim_keys, get_keys, upload_keys},
1042            to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1043        },
1044        device_id,
1045        events::room::{
1046            history_visibility::HistoryVisibility, EncryptedFileInit, JsonWebKey, JsonWebKeyInit,
1047        },
1048        owned_room_id, room_id,
1049        serde::Base64,
1050        to_device::DeviceIdOrAllDevices,
1051        user_id, DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1052    };
1053    use serde_json::{json, Value};
1054
1055    use crate::{
1056        identities::DeviceData,
1057        machine::{
1058            test_helpers::get_machine_pair_with_setup_sessions_test_helper, EncryptionSyncChanges,
1059        },
1060        olm::{Account, SenderData},
1061        session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
1062        types::{
1063            events::{
1064                room::encrypted::EncryptedToDeviceEvent,
1065                room_key_bundle::RoomKeyBundleContent,
1066                room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1067            },
1068            requests::ToDeviceRequest,
1069            DeviceKeys, EventEncryptionAlgorithm, ProcessedToDeviceEvent,
1070        },
1071        EncryptionSettings, LocalTrust, OlmMachine,
1072    };
1073
1074    fn alice_id() -> &'static UserId {
1075        user_id!("@alice:example.org")
1076    }
1077
1078    fn alice_device_id() -> &'static DeviceId {
1079        device_id!("JLAFKJWSCS")
1080    }
1081
1082    /// Returns a /keys/query response for user "@example:localhost"
1083    fn keys_query_response() -> get_keys::v3::Response {
1084        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1085        let data: Value = serde_json::from_slice(data).unwrap();
1086        ruma_response_from_json(&data)
1087    }
1088
1089    fn bob_keys_query_response() -> get_keys::v3::Response {
1090        let data = json!({
1091            "device_keys": {
1092                "@bob:localhost": {
1093                    "BOBDEVICE": {
1094                        "user_id": "@bob:localhost",
1095                        "device_id": "BOBDEVICE",
1096                        "algorithms": [
1097                            "m.olm.v1.curve25519-aes-sha2",
1098                            "m.megolm.v1.aes-sha2",
1099                            "m.megolm.v2.aes-sha2"
1100                        ],
1101                        "keys": {
1102                            "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1103                            "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1104                        },
1105                        "signatures": {
1106                            "@bob:localhost": {
1107                                "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1108                            }
1109                        }
1110                    }
1111                }
1112            }
1113        });
1114        ruma_response_from_json(&data)
1115    }
1116
1117    /// Returns a keys claim response for device `BOBDEVICE` of user
1118    /// `@bob:localhost`.
1119    fn bob_one_time_key() -> claim_keys::v3::Response {
1120        let data = json!({
1121            "failures": {},
1122            "one_time_keys":{
1123                "@bob:localhost":{
1124                    "BOBDEVICE":{
1125                      "signed_curve25519:AAAAAAAAAAA": {
1126                          "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1127                          "signatures":{
1128                              "@bob:localhost":{
1129                                  "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1130                              }
1131                          }
1132                      }
1133                    }
1134                }
1135            }
1136        });
1137        ruma_response_from_json(&data)
1138    }
1139
1140    /// Returns a key claim response for device `NMMBNBUSNR` of user
1141    /// `@example2:localhost`
1142    fn keys_claim_response() -> claim_keys::v3::Response {
1143        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1144        let data: Value = serde_json::from_slice(data).unwrap();
1145        ruma_response_from_json(&data)
1146    }
1147
1148    async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1149        let keys_query = keys_query_response();
1150        let txn_id = TransactionId::new();
1151
1152        let machine = OlmMachine::new(user_id, device_id).await;
1153
1154        // complete a /keys/query and /keys/claim for @example:localhost
1155        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1156        let (txn_id, _keys_claim_request) = machine
1157            .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1158            .await
1159            .unwrap()
1160            .unwrap();
1161        let keys_claim = keys_claim_response();
1162        machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1163
1164        // complete a /keys/query and /keys/claim for @bob:localhost
1165        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1166        let (txn_id, _keys_claim_request) = machine
1167            .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1168            .await
1169            .unwrap()
1170            .unwrap();
1171        machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1172
1173        machine
1174    }
1175
1176    async fn machine() -> OlmMachine {
1177        machine_with_user_test_helper(alice_id(), alice_device_id()).await
1178    }
1179
1180    async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1181        let machine = machine().await;
1182        let room_id = room_id!("!test:localhost");
1183        let keys_claim = keys_claim_response();
1184
1185        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1186        let requests =
1187            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1188
1189        let outbound =
1190            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1191
1192        assert!(!outbound.pending_requests().is_empty());
1193        assert!(!outbound.shared());
1194
1195        let response = ToDeviceResponse::new();
1196        for request in requests {
1197            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1198        }
1199
1200        assert!(outbound.shared());
1201        assert!(outbound.pending_requests().is_empty());
1202
1203        machine
1204    }
1205
1206    #[async_test]
1207    async fn test_sharing() {
1208        let machine = machine().await;
1209        let room_id = room_id!("!test:localhost");
1210        let keys_claim = keys_claim_response();
1211
1212        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1213
1214        let requests =
1215            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1216
1217        let event_count: usize = requests
1218            .iter()
1219            .filter(|r| r.event_type == "m.room.encrypted".into())
1220            .map(|r| r.message_count())
1221            .sum();
1222
1223        // The keys claim response has a couple of one-time keys with invalid
1224        // signatures, thus only 148 sessions are actually created, we check
1225        // that all 148 valid sessions get an room key.
1226        assert_eq!(event_count, 148);
1227
1228        let withheld_count: usize = requests
1229            .iter()
1230            .filter(|r| r.event_type == "m.room_key.withheld".into())
1231            .map(|r| r.message_count())
1232            .sum();
1233        assert_eq!(withheld_count, 2);
1234    }
1235
1236    fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1237        requests
1238            .iter()
1239            .filter(|r| r.event_type == "m.room_key.withheld".into())
1240            .map(|r| {
1241                let mut count = 0;
1242                // count targets
1243                for message in r.messages.values() {
1244                    message.iter().for_each(|(_, content)| {
1245                        let withheld: RoomKeyWithheldContent =
1246                            content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
1247
1248                        if let MegolmV1AesSha2(content) = withheld {
1249                            if content.withheld_code() == code {
1250                                count += 1;
1251                            }
1252                        }
1253                    })
1254                }
1255                count
1256            })
1257            .sum()
1258    }
1259
1260    #[async_test]
1261    async fn test_no_olm_sent_once() {
1262        let machine = machine().await;
1263        let keys_claim = keys_claim_response();
1264
1265        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1266
1267        let first_room_id = room_id!("!test:localhost");
1268
1269        let requests = machine
1270            .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1271            .await
1272            .unwrap();
1273
1274        // there will be two no_olm
1275        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1276        assert_eq!(withheld_count, 2);
1277
1278        // Re-sharing same session while request has not been sent should not produces
1279        // withheld
1280        let new_requests = machine
1281            .share_room_key(first_room_id, users, EncryptionSettings::default())
1282            .await
1283            .unwrap();
1284        let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1285        // No additional request was added, still the 2 already pending
1286        assert_eq!(withheld_count, 2);
1287
1288        let response = ToDeviceResponse::new();
1289        for request in requests {
1290            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1291        }
1292
1293        // The fact that an olm was sent should be remembered even if sharing another
1294        // session in an other room.
1295        let second_room_id = room_id!("!other:localhost");
1296        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1297        let requests = machine
1298            .share_room_key(second_room_id, users, EncryptionSettings::default())
1299            .await
1300            .unwrap();
1301
1302        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1303        assert_eq!(withheld_count, 0);
1304
1305        // Help how do I simulate the creation of a new session for the device
1306        // with no session now?
1307    }
1308
1309    #[async_test]
1310    async fn test_ratcheted_sharing() {
1311        let machine = machine_with_shared_room_key_test_helper().await;
1312
1313        let room_id = room_id!("!test:localhost");
1314        let late_joiner = user_id!("@bob:localhost");
1315        let keys_claim = keys_claim_response();
1316
1317        let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1318        users.insert(late_joiner);
1319
1320        let requests = machine
1321            .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1322            .await
1323            .unwrap();
1324
1325        let event_count: usize = requests
1326            .iter()
1327            .filter(|r| r.event_type == "m.room.encrypted".into())
1328            .map(|r| r.message_count())
1329            .sum();
1330        let outbound =
1331            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1332
1333        assert_eq!(event_count, 1);
1334        assert!(!outbound.pending_requests().is_empty());
1335    }
1336
1337    #[async_test]
1338    async fn test_changing_encryption_settings() {
1339        let machine = machine_with_shared_room_key_test_helper().await;
1340        let room_id = room_id!("!test:localhost");
1341        let keys_claim = keys_claim_response();
1342
1343        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1344        let outbound =
1345            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1346
1347        let CollectRecipientsResult { should_rotate, .. } = machine
1348            .inner
1349            .group_session_manager
1350            .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1351            .await
1352            .unwrap();
1353
1354        assert!(!should_rotate);
1355
1356        let settings = EncryptionSettings {
1357            history_visibility: HistoryVisibility::Invited,
1358            ..Default::default()
1359        };
1360
1361        let CollectRecipientsResult { should_rotate, .. } = machine
1362            .inner
1363            .group_session_manager
1364            .collect_session_recipients(users.clone(), &settings, &outbound)
1365            .await
1366            .unwrap();
1367
1368        assert!(should_rotate);
1369
1370        let settings = EncryptionSettings {
1371            algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1372            ..Default::default()
1373        };
1374
1375        let CollectRecipientsResult { should_rotate, .. } = machine
1376            .inner
1377            .group_session_manager
1378            .collect_session_recipients(users, &settings, &outbound)
1379            .await
1380            .unwrap();
1381
1382        assert!(should_rotate);
1383    }
1384
1385    #[async_test]
1386    async fn test_key_recipient_collecting() {
1387        // The user id comes from the fact that the keys_query.json file uses
1388        // this one.
1389        let user_id = user_id!("@example:localhost");
1390        let device_id = device_id!("TESTDEVICE");
1391        let room_id = room_id!("!test:localhost");
1392
1393        let machine = machine_with_user_test_helper(user_id, device_id).await;
1394
1395        let (outbound, _) = machine
1396            .inner
1397            .group_session_manager
1398            .get_or_create_outbound_session(
1399                room_id,
1400                EncryptionSettings::default(),
1401                SenderData::unknown(),
1402            )
1403            .await
1404            .expect("We should be able to create a new session");
1405        let history_visibility = HistoryVisibility::Joined;
1406        let settings = EncryptionSettings { history_visibility, ..Default::default() };
1407
1408        let users = [user_id].into_iter();
1409
1410        let CollectRecipientsResult { devices: recipients, .. } = machine
1411            .inner
1412            .group_session_manager
1413            .collect_session_recipients(users, &settings, &outbound)
1414            .await
1415            .expect("We should be able to collect the session recipients");
1416
1417        assert!(!recipients[user_id].is_empty());
1418
1419        // Make sure that our own device isn't part of the recipients.
1420        assert!(!recipients[user_id]
1421            .iter()
1422            .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1423
1424        let settings = EncryptionSettings {
1425            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1426            ..Default::default()
1427        };
1428        let users = [user_id].into_iter();
1429
1430        let CollectRecipientsResult { devices: recipients, .. } = machine
1431            .inner
1432            .group_session_manager
1433            .collect_session_recipients(users, &settings, &outbound)
1434            .await
1435            .expect("We should be able to collect the session recipients");
1436
1437        assert!(recipients[user_id].is_empty());
1438
1439        let device_id = "AFGUOBTZWM".into();
1440        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1441        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1442        let users = [user_id].into_iter();
1443
1444        let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1445            machine
1446                .inner
1447                .group_session_manager
1448                .collect_session_recipients(users, &settings, &outbound)
1449                .await
1450                .expect("We should be able to collect the session recipients");
1451
1452        assert!(recipients[user_id]
1453            .iter()
1454            .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1455
1456        let devices = machine.get_user_devices(user_id, None).await.unwrap();
1457        devices
1458            .devices()
1459            // Ignore our own device
1460            .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1461            .for_each(|d| {
1462                if d.is_blacklisted() {
1463                    assert!(withheld.iter().any(|(dev, w)| {
1464                        dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1465                    }));
1466                } else if !d.is_verified() {
1467                    // the device should then be in the list of withhelds
1468                    assert!(withheld.iter().any(|(dev, w)| {
1469                        dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1470                    }));
1471                }
1472            });
1473
1474        assert_eq!(149, withheld.len());
1475    }
1476
1477    #[async_test]
1478    async fn test_sharing_withheld_only_trusted() {
1479        let machine = machine().await;
1480        let room_id = room_id!("!test:localhost");
1481        let keys_claim = keys_claim_response();
1482
1483        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1484        let settings = EncryptionSettings {
1485            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1486            ..Default::default()
1487        };
1488
1489        // Trust only one
1490        let user_id = user_id!("@example:localhost");
1491        let device_id = "MWFXPINOAO".into();
1492        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1493        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1494        machine
1495            .get_device(user_id, "MWVTUXDNNM".into(), None)
1496            .await
1497            .unwrap()
1498            .unwrap()
1499            .set_local_trust(LocalTrust::BlackListed)
1500            .await
1501            .unwrap();
1502
1503        let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1504
1505        // One room key should be sent
1506        let room_key_count =
1507            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1508
1509        assert_eq!(1, room_key_count);
1510
1511        let withheld_count =
1512            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1513        // Can be send in one batch
1514        assert_eq!(1, withheld_count);
1515
1516        let event_count: usize = requests
1517            .iter()
1518            .filter(|r| r.event_type == "m.room_key.withheld".into())
1519            .map(|r| r.message_count())
1520            .sum();
1521
1522        // withhelds are sent in clear so all device should be counted (even if no OTK)
1523        assert_eq!(event_count, 149);
1524
1525        // One should be blacklisted
1526        let has_blacklist =
1527            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1528                let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1529                let content = &r.messages[user_id][&device_key];
1530                let withheld: RoomKeyWithheldContent =
1531                    content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
1532                if let MegolmV1AesSha2(content) = withheld {
1533                    content.withheld_code() == WithheldCode::Blacklisted
1534                } else {
1535                    false
1536                }
1537            });
1538
1539        assert!(has_blacklist);
1540    }
1541
1542    #[async_test]
1543    async fn test_no_olm_withheld_only_sent_once() {
1544        let keys_query = keys_query_response();
1545        let txn_id = TransactionId::new();
1546
1547        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1548
1549        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1550        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1551
1552        let first_room = room_id!("!test:localhost");
1553        let second_room = room_id!("!test2:localhost");
1554        let bob_id = user_id!("@bob:localhost");
1555
1556        let settings = EncryptionSettings::default();
1557        let users = [bob_id];
1558
1559        let requests = machine
1560            .share_room_key(first_room, users.into_iter(), settings.to_owned())
1561            .await
1562            .unwrap();
1563
1564        // One withheld request should be sent.
1565        let withheld_count =
1566            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1567
1568        assert_eq!(withheld_count, 1);
1569        assert_eq!(requests.len(), 1);
1570
1571        // On the second room key share attempt we're not sending another `m.no_olm`
1572        // code since the first one is taking care of this.
1573        let second_requests =
1574            machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1575
1576        let withheld_count =
1577            second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1578
1579        assert_eq!(withheld_count, 0);
1580        assert_eq!(second_requests.len(), 0);
1581
1582        let response = ToDeviceResponse::new();
1583
1584        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1585
1586        // The device should be marked as having the `m.no_olm` code received only after
1587        // the request has been marked as sent.
1588        assert!(!device.was_withheld_code_sent());
1589
1590        for request in requests {
1591            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1592        }
1593
1594        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1595
1596        assert!(device.was_withheld_code_sent());
1597    }
1598
1599    #[async_test]
1600    async fn test_resend_session_after_unwedging() {
1601        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1602        assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
1603        let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
1604            OneTimeKeyAlgorithm::SignedCurve25519,
1605            UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
1606        )]));
1607        machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
1608
1609        let room_id = room_id!("!test:localhost");
1610
1611        let bob_id = user_id!("@bob:localhost");
1612        let bob_account = Account::new(bob_id);
1613        let keys_query_data = json!({
1614            "device_keys": {
1615                "@bob:localhost": {
1616                    bob_account.device_id.clone(): bob_account.device_keys()
1617                }
1618            }
1619        });
1620        let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
1621        let txn_id = TransactionId::new();
1622        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1623
1624        let alice_device_keys =
1625            device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
1626        let mut alice_otks = device_keys_request.one_time_keys.iter();
1627        let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
1628
1629        {
1630            // Bob creates an Olm session with Alice and encrypts a message to her
1631            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1632            let mut session = bob_account
1633                .create_outbound_session(
1634                    &alice_device,
1635                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1636                    bob_account.device_keys(),
1637                )
1638                .unwrap();
1639            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1640
1641            let to_device =
1642                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1643
1644            // Alice decrypts the message
1645            let sync_changes = EncryptionSyncChanges {
1646                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1647                changed_devices: &Default::default(),
1648                one_time_keys_counts: &Default::default(),
1649                unused_fallback_keys: None,
1650                next_batch_token: None,
1651            };
1652            let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();
1653
1654            assert_eq!(1, decrypted.len());
1655        }
1656
1657        // Alice shares the room key with Bob
1658        {
1659            let requests = machine
1660                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1661                .await
1662                .unwrap();
1663
1664            // We should have had one to-device event
1665            let event_count: usize = requests
1666                .iter()
1667                .filter(|r| r.event_type == "m.room.encrypted".into())
1668                .map(|r| r.message_count())
1669                .sum();
1670            assert_eq!(event_count, 1);
1671
1672            let response = ToDeviceResponse::new();
1673            for request in requests {
1674                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1675            }
1676        }
1677
1678        // When Alice shares the room key again, there shouldn't be any
1679        // to-device events, since we already shared with Bob
1680        {
1681            let requests = machine
1682                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1683                .await
1684                .unwrap();
1685
1686            let event_count: usize = requests
1687                .iter()
1688                .filter(|r| r.event_type == "m.room.encrypted".into())
1689                .map(|r| r.message_count())
1690                .sum();
1691            assert_eq!(event_count, 0);
1692        }
1693
1694        // Pretend that Bob wasn't able to decrypt, so he tries to unwedge
1695        {
1696            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1697            let mut session = bob_account
1698                .create_outbound_session(
1699                    &alice_device,
1700                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1701                    bob_account.device_keys(),
1702                )
1703                .unwrap();
1704            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1705
1706            let to_device =
1707                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1708
1709            // Alice decrypts the unwedge message
1710            let sync_changes = EncryptionSyncChanges {
1711                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1712                changed_devices: &Default::default(),
1713                one_time_keys_counts: &Default::default(),
1714                unused_fallback_keys: None,
1715                next_batch_token: None,
1716            };
1717            let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();
1718
1719            assert_eq!(1, decrypted.len());
1720        }
1721
1722        // When Alice shares the room key again, it should be re-shared with Bob
1723        {
1724            let requests = machine
1725                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1726                .await
1727                .unwrap();
1728
1729            let event_count: usize = requests
1730                .iter()
1731                .filter(|r| r.event_type == "m.room.encrypted".into())
1732                .map(|r| r.message_count())
1733                .sum();
1734            assert_eq!(event_count, 1);
1735
1736            let response = ToDeviceResponse::new();
1737            for request in requests {
1738                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1739            }
1740        }
1741
1742        // When Alice shares the room key yet again, there shouldn't be any
1743        // to-device events
1744        {
1745            let requests = machine
1746                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1747                .await
1748                .unwrap();
1749
1750            let event_count: usize = requests
1751                .iter()
1752                .filter(|r| r.event_type == "m.room.encrypted".into())
1753                .map(|r| r.message_count())
1754                .sum();
1755            assert_eq!(event_count, 0);
1756        }
1757    }
1758
1759    #[async_test]
1760    async fn test_room_key_bundle_sharing() {
1761        let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1762            user_id!("@alice:localhost"),
1763            user_id!("@bob:localhost"),
1764            false,
1765        )
1766        .await;
1767
1768        // Alice trusts Bob's device
1769        let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1770        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1771
1772        let content = RoomKeyBundleContent {
1773            room_id: owned_room_id!("!room:id"),
1774            file: (EncryptedFileInit {
1775                url: OwnedMxcUri::from("test"),
1776                key: JsonWebKey::from(JsonWebKeyInit {
1777                    kty: "oct".to_owned(),
1778                    key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
1779                    alg: "A256CTR".to_owned(),
1780                    #[allow(clippy::unnecessary_to_owned)]
1781                    k: Base64::new(vec![0u8; 0]),
1782                    ext: true,
1783                }),
1784                iv: Base64::new(vec![0u8; 0]),
1785                hashes: Default::default(),
1786                v: "".to_owned(),
1787            })
1788            .into(),
1789        };
1790
1791        let requests = alice
1792            .share_room_key_bundle_data(
1793                bob.user_id(),
1794                &CollectStrategy::OnlyTrustedDevices,
1795                content,
1796            )
1797            .await
1798            .unwrap();
1799
1800        // There should be exactly one message
1801        let requests: Vec<_> =
1802            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1803        let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1804        assert_eq!(message_count, 1);
1805
1806        // Bob decrypts the message
1807        let bob_message = requests[0]
1808            .messages
1809            .get(bob.user_id())
1810            .unwrap()
1811            .get(&(bob.device_id().to_owned().into()))
1812            .unwrap();
1813        let to_device = EncryptedToDeviceEvent::new(
1814            alice.user_id().to_owned(),
1815            bob_message.cast_ref().deserialize().unwrap(),
1816        );
1817
1818        let sync_changes = EncryptionSyncChanges {
1819            to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1820            changed_devices: &Default::default(),
1821            one_time_keys_counts: &Default::default(),
1822            unused_fallback_keys: None,
1823            next_batch_token: None,
1824        };
1825        let (decrypted, _) = bob.receive_sync_changes(sync_changes).await.unwrap();
1826        assert_eq!(1, decrypted.len());
1827        use crate::types::events::EventType;
1828        assert_let!(
1829            ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1830        );
1831        assert_eq!(
1832            raw.get_field::<String>("type").unwrap().unwrap(),
1833            RoomKeyBundleContent::EVENT_TYPE,
1834        );
1835    }
1836}