Skip to main content

matrix_sdk_crypto/session_manager/group_sessions/
mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod share_strategy;
16
17use std::{
18    collections::{BTreeMap, BTreeSet},
19    fmt::Debug,
20    iter,
21    iter::zip,
22    sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28    deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30#[cfg(feature = "experimental-encrypted-state-events")]
31use ruma::events::AnyStateEventContent;
32use ruma::{
33    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
34    UserId,
35    events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
36    serde::Raw,
37    to_device::DeviceIdOrAllDevices,
38};
39use serde::Serialize;
40pub use share_strategy::CollectStrategy;
41#[cfg(feature = "experimental-send-custom-to-device")]
42pub(crate) use share_strategy::split_devices_for_share_strategy;
43pub(crate) use share_strategy::{
44    CollectRecipientsResult, withheld_code_for_device_for_share_strategy,
45};
46use tracing::{Instrument, debug, error, info, instrument, trace, warn};
47
48#[cfg(feature = "experimental-encrypted-state-events")]
49use crate::types::events::room::encrypted::RoomEncryptedEventContent;
50use crate::{
51    Device, DeviceData, EncryptionSettings, OlmError,
52    error::{EventError, MegolmResult, OlmResult},
53    identities::device::MaybeEncryptedRoomKey,
54    olm::{
55        InboundGroupSession, OutboundGroupSession, OutboundGroupSessionEncryptionResult,
56        SenderData, SenderDataFinder, Session, ShareInfo, ShareState,
57    },
58    store::{CryptoStoreWrapper, Result as StoreResult, Store, types::Changes},
59    types::{
60        events::{
61            EventType, room::encrypted::ToDeviceEncryptedEventContent,
62            room_key_bundle::RoomKeyBundleContent,
63        },
64        requests::ToDeviceRequest,
65    },
66};
67
68#[derive(Clone, Debug)]
69pub(crate) struct GroupSessionCache {
70    store: Store,
71    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
72    /// A map from the request id to the group session that the request belongs
73    /// to. Used to mark requests belonging to the session as shared.
74    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
75}
76
77impl GroupSessionCache {
78    pub(crate) fn new(store: Store) -> Self {
79        Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
80    }
81
82    pub(crate) fn insert(&self, session: OutboundGroupSession) {
83        self.sessions.write().insert(session.room_id().to_owned(), session);
84    }
85
86    /// Either get a session for the given room from the cache or load it from
87    /// the store.
88    ///
89    /// # Arguments
90    ///
91    /// * `room_id` - The id of the room this session is used for.
92    pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
93        // Get the cached session, if there isn't one load one from the store
94        // and put it in the cache.
95        if let Some(s) = self.sessions.read().get(room_id) {
96            return Some(s.clone());
97        }
98
99        match self.store.get_outbound_group_session(room_id).await {
100            Ok(Some(s)) => {
101                {
102                    let mut sessions_being_shared = self.sessions_being_shared.write();
103                    for request_id in s.pending_request_ids() {
104                        sessions_being_shared.insert(request_id, s.clone());
105                    }
106                }
107
108                self.sessions.write().insert(room_id.to_owned(), s.clone());
109
110                Some(s)
111            }
112            Ok(None) => None,
113            Err(e) => {
114                error!("Couldn't restore an outbound group session: {e:?}");
115                None
116            }
117        }
118    }
119
120    /// Get an outbound group session for a room, if one exists.
121    ///
122    /// # Arguments
123    ///
124    /// * `room_id` - The id of the room for which we should get the outbound
125    ///   group session.
126    #[cfg(test)]
127    fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
128        self.sessions.read().get(room_id).cloned()
129    }
130
131    /// Returns whether any session is withheld with the given device and code.
132    fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
133        self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
134    }
135
136    fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
137        self.sessions_being_shared.write().remove(id)
138    }
139
140    fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
141        self.sessions_being_shared.write().insert(id, session);
142    }
143}
144
/// Creates, rotates, persists and shares the outbound group (Megolm)
/// sessions used to encrypt room events.
#[derive(Debug, Clone)]
pub(crate) struct GroupSessionManager {
    /// Store for the encryption keys.
    /// Persists all the encryption keys so a client can resume the session
    /// without the need to create new keys.
    store: Store,
    /// The currently active outbound group sessions.
    ///
    /// The cache keeps its maps behind `Arc`s, so clones of the manager share
    /// the same cached sessions.
    sessions: GroupSessionCache,
}
154
155impl GroupSessionManager {
156    const MAX_TO_DEVICE_MESSAGES: usize = 250;
157
158    pub fn new(store: Store) -> Self {
159        Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
160    }
161
162    pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
163        if let Some(s) = self.sessions.get_or_load(room_id).await {
164            s.invalidate_session();
165
166            let mut changes = Changes::default();
167            changes.outbound_group_sessions.push(s.clone());
168            self.store.save_changes(changes).await?;
169
170            Ok(true)
171        } else {
172            Ok(false)
173        }
174    }
175
176    pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
177        let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
178            return Ok(());
179        };
180
181        let no_olm = session.mark_request_as_sent(request_id);
182
183        let mut changes = Changes::default();
184
185        for (user_id, devices) in &no_olm {
186            for device_id in devices {
187                let device = self.store.get_device(user_id, device_id).await;
188
189                if let Ok(Some(device)) = device {
190                    device.mark_withheld_code_as_sent();
191                    changes.devices.changed.push(device.inner.clone());
192                } else {
193                    error!(
194                        ?request_id,
195                        "Marking to-device no olm as sent but device not found, might \
196                            have been deleted?"
197                    );
198                }
199            }
200        }
201
202        changes.outbound_group_sessions.push(session.clone());
203        self.store.save_changes(changes).await
204    }
205
206    #[cfg(test)]
207    pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
208        self.sessions.get(room_id)
209    }
210
211    pub async fn encrypt(
212        &self,
213        room_id: &RoomId,
214        event_type: &str,
215        content: &Raw<AnyMessageLikeEventContent>,
216    ) -> MegolmResult<OutboundGroupSessionEncryptionResult> {
217        let session =
218            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
219
220        assert!(!session.expired(), "Session expired");
221
222        let result = session.encrypt(event_type, content).await;
223
224        let mut changes = Changes::default();
225        changes.outbound_group_sessions.push(session);
226        self.store.save_changes(changes).await?;
227
228        Ok(result)
229    }
230
231    /// Encrypts a state event for the given room using its outbound group
232    /// session.
233    ///
234    /// # Arguments
235    ///
236    /// * `room_id` - The ID of the room where the state event will be sent.
237    /// * `event_type` - The type of the state event to encrypt.
238    /// * `state_key` - The state key associated with the event.
239    /// * `content` - The raw content of the state event to encrypt.
240    ///
241    /// # Returns
242    ///
243    /// Returns the raw encrypted state event content.
244    ///
245    /// # Errors
246    ///
247    /// Returns an error if saving changes to the store fails.
248    ///
249    /// # Panics
250    ///
251    /// Panics if no session exists for the given room ID, or the session
252    /// has expired.
253    #[cfg(feature = "experimental-encrypted-state-events")]
254    pub async fn encrypt_state(
255        &self,
256        room_id: &RoomId,
257        event_type: &str,
258        state_key: &str,
259        content: &Raw<AnyStateEventContent>,
260    ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
261        let session =
262            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
263
264        assert!(!session.expired(), "Session expired");
265
266        let content = session.encrypt_state(event_type, state_key, content).await;
267
268        let mut changes = Changes::default();
269        changes.outbound_group_sessions.push(session);
270        self.store.save_changes(changes).await?;
271
272        Ok(content)
273    }
274
275    /// Create a new outbound group session.
276    ///
277    /// This also creates a matching inbound group session.
278    pub async fn create_outbound_group_session(
279        &self,
280        room_id: &RoomId,
281        settings: EncryptionSettings,
282        own_sender_data: SenderData,
283    ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
284        let (outbound, inbound) = self
285            .store
286            .static_account()
287            .create_group_session_pair(room_id, settings, own_sender_data)
288            .await
289            .map_err(|_| EventError::UnsupportedAlgorithm)?;
290
291        self.sessions.insert(outbound.clone());
292        Ok((outbound, inbound))
293    }
294
295    pub async fn get_or_create_outbound_session(
296        &self,
297        room_id: &RoomId,
298        settings: EncryptionSettings,
299        own_sender_data: SenderData,
300    ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
301        let outbound_session = self.sessions.get_or_load(room_id).await;
302
303        // If there is no session or the session has expired or is invalid,
304        // create a new one.
305        if let Some(s) = outbound_session {
306            if s.expired() || s.invalidated() {
307                self.create_outbound_group_session(room_id, settings, own_sender_data)
308                    .await
309                    .map(|(o, i)| (o, i.into()))
310            } else {
311                Ok((s, None))
312            }
313        } else {
314            self.create_outbound_group_session(room_id, settings, own_sender_data)
315                .await
316                .map(|(o, i)| (o, i.into()))
317        }
318    }
319
320    /// Encrypt the given group session key for the given devices and create
321    /// to-device requests that sends the encrypted content to them.
322    ///
323    /// See also [`encrypt_content_for_devices`] which is similar
324    /// but is not specific to group sessions, and does not return the
325    /// [`ShareInfo`] data.
326    async fn encrypt_session_for(
327        store: Arc<CryptoStoreWrapper>,
328        group_session: OutboundGroupSession,
329        devices: Vec<DeviceData>,
330    ) -> OlmResult<(
331        EncryptForDevicesResult,
332        BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
333    )> {
334        // Use a named type instead of a tuple with rather long type name
335        pub struct DeviceResult {
336            device: DeviceData,
337            maybe_encrypted_room_key: MaybeEncryptedRoomKey,
338        }
339
340        let mut result_builder = EncryptForDevicesResultBuilder::default();
341        let mut share_infos = BTreeMap::new();
342
343        // XXX is there a way to do this that doesn't involve cloning the
344        // `Arc<CryptoStoreWrapper>` for each device?
345        let encrypt = |store: Arc<CryptoStoreWrapper>,
346                       device: DeviceData,
347                       session: OutboundGroupSession| async move {
348            let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
349
350            Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
351        };
352
353        let tasks: Vec<_> = devices
354            .iter()
355            .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
356            .collect();
357
358        let results = join_all(tasks).await;
359
360        for result in results {
361            let result = result.expect("Encryption task panicked")?;
362
363            match result.maybe_encrypted_room_key {
364                MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
365                    result_builder.on_successful_encryption(&result.device, *used_session, message);
366
367                    let user_id = result.device.user_id().to_owned();
368                    let device_id = result.device.device_id().to_owned();
369                    share_infos
370                        .entry(user_id)
371                        .or_insert_with(BTreeMap::new)
372                        .insert(device_id, *share_info);
373                }
374                MaybeEncryptedRoomKey::MissingSession => {
375                    result_builder.on_missing_session(result.device);
376                }
377            }
378        }
379
380        Ok((result_builder.into_result(), share_infos))
381    }
382
383    /// Given a list of user and an outbound session, return the list of users
384    /// and their devices that this session should be shared with.
385    ///
386    /// Returns information indicating whether the session needs to be rotated
387    /// and the list of users/devices that should receive or not the session
388    /// (with withheld reason).
389    #[instrument(skip_all)]
390    pub async fn collect_session_recipients(
391        &self,
392        users: impl Iterator<Item = &UserId>,
393        settings: &EncryptionSettings,
394        outbound: &OutboundGroupSession,
395    ) -> OlmResult<CollectRecipientsResult> {
396        share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
397    }
398
399    async fn encrypt_request(
400        store: Arc<CryptoStoreWrapper>,
401        chunk: Vec<DeviceData>,
402        outbound: OutboundGroupSession,
403        sessions: GroupSessionCache,
404    ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
405        let (result, share_infos) =
406            Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
407
408        if let Some(request) = result.to_device_request {
409            let id = request.txn_id.clone();
410            outbound.add_request(id.clone(), request.into(), share_infos);
411            sessions.mark_as_being_shared(id, outbound.clone());
412        }
413
414        Ok((result.updated_olm_sessions, result.no_olm_devices))
415    }
416
417    pub(crate) fn session_cache(&self) -> GroupSessionCache {
418        self.sessions.clone()
419    }
420
421    async fn maybe_rotate_group_session(
422        &self,
423        should_rotate: bool,
424        room_id: &RoomId,
425        outbound: OutboundGroupSession,
426        encryption_settings: EncryptionSettings,
427        changes: &mut Changes,
428        own_device: Option<Device>,
429    ) -> OlmResult<OutboundGroupSession> {
430        Ok(if should_rotate {
431            let old_session_id = outbound.session_id();
432
433            let (outbound, mut inbound) = self
434                .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
435                .await?;
436
437            // Use our own device info to populate the SenderData that validates the
438            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
439            // are sending out.
440            let own_sender_data = if let Some(device) = own_device {
441                SenderDataFinder::find_using_device_data(
442                    &self.store,
443                    device.inner.clone(),
444                    &inbound,
445                )
446                .await?
447            } else {
448                error!("Unable to find our own device!");
449                SenderData::unknown()
450            };
451            inbound.sender_data = own_sender_data;
452
453            changes.outbound_group_sessions.push(outbound.clone());
454            changes.inbound_group_sessions.push(inbound);
455
456            debug!(
457                old_session_id = old_session_id,
458                session_id = outbound.session_id(),
459                "A user or device has left the room since we last sent a \
460                message, or the encryption settings have changed. Rotating the \
461                room key.",
462            );
463
464            outbound
465        } else {
466            outbound
467        })
468    }
469
470    async fn encrypt_for_devices(
471        &self,
472        recipient_devices: Vec<DeviceData>,
473        group_session: &OutboundGroupSession,
474        changes: &mut Changes,
475    ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
476        // If we have some recipients, log them here.
477        if !recipient_devices.is_empty() {
478            let recipients = recipient_list_to_users_and_devices(&recipient_devices);
479
480            // If there are new recipients we need to persist the outbound group
481            // session as the to-device requests are persisted with the session.
482            changes.outbound_group_sessions = vec![group_session.clone()];
483
484            let message_index = group_session.message_index().await;
485
486            info!(
487                ?recipients,
488                message_index,
489                room_id = ?group_session.room_id(),
490                session_id = group_session.session_id(),
491                "Trying to encrypt a room key",
492            );
493        }
494
495        // Chunk the recipients out so each to-device request will contain a
496        // limited amount of to-device messages.
497        //
498        // Create concurrent tasks for each chunk of recipients.
499        let tasks: Vec<_> = recipient_devices
500            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
501            .map(|chunk| {
502                spawn(Self::encrypt_request(
503                    self.store.crypto_store(),
504                    chunk.to_vec(),
505                    group_session.clone(),
506                    self.sessions.clone(),
507                ))
508            })
509            .collect();
510
511        let mut withheld_devices = Vec::new();
512
513        // Wait for all the tasks to finish up and queue up the Olm session that
514        // was used to encrypt the room key to be persisted again. This is
515        // needed because each encryption step will mutate the Olm session,
516        // ratcheting its state forward.
517        for result in join_all(tasks).await {
518            let result = result.expect("Encryption task panicked");
519
520            let (used_sessions, failed_no_olm) = result?;
521
522            changes.sessions.extend(used_sessions);
523            withheld_devices.extend(failed_no_olm);
524        }
525
526        Ok(withheld_devices)
527    }
528
529    fn is_withheld_to(
530        &self,
531        group_session: &OutboundGroupSession,
532        device: &DeviceData,
533        code: &WithheldCode,
534    ) -> bool {
535        // The `m.no_olm` withheld code is special because it is supposed to be sent
536        // only once for a given device. The `Device` remembers the flag if we
537        // already sent a `m.no_olm` to this particular device so let's check
538        // that first.
539        //
540        // Keep in mind that any outbound group session might want to send this code to
541        // the device. So we need to check if any of our outbound group sessions
542        // is attempting to send the code to the device.
543        //
544        // This still has a slight race where some other thread might remove the
545        // outbound group session while a third is marking the device as having
546        // received the code.
547        //
548        // Since nothing terrible happens if we do end up sending the withheld code
549        // twice, and removing the race requires us to lock the store because the
550        // `OutboundGroupSession` and the `Device` both interact with the flag we'll
551        // leave it be.
552        if code == &WithheldCode::NoOlm {
553            device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
554        } else {
555            group_session.sharing_view().is_withheld_to(device, code)
556        }
557    }
558
559    fn handle_withheld_devices(
560        &self,
561        group_session: &OutboundGroupSession,
562        withheld_devices: Vec<(DeviceData, WithheldCode)>,
563    ) -> OlmResult<()> {
564        // Convert a withheld code for the group session into a to-device event content.
565        let to_content = |code| {
566            let content = group_session.withheld_code(code);
567            Raw::new(&content).expect("We can always serialize a withheld content info").cast()
568        };
569
570        // Helper to convert a chunk of device and withheld code pairs into a to-device
571        // request and it's accompanying share info.
572        let chunk_to_request = |chunk| {
573            let mut messages = BTreeMap::new();
574            let mut share_infos = BTreeMap::new();
575
576            for (device, code) in chunk {
577                let device: DeviceData = device;
578                let code: WithheldCode = code;
579
580                let user_id = device.user_id().to_owned();
581                let device_id = device.device_id().to_owned();
582
583                let share_info = ShareInfo::new_withheld(code.to_owned());
584                let content = to_content(code);
585
586                messages
587                    .entry(user_id.to_owned())
588                    .or_insert_with(BTreeMap::new)
589                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
590
591                share_infos
592                    .entry(user_id)
593                    .or_insert_with(BTreeMap::new)
594                    .insert(device_id, share_info);
595            }
596
597            let txn_id = TransactionId::new();
598
599            let request = ToDeviceRequest {
600                event_type: ToDeviceEventType::from("m.room_key.withheld"),
601                txn_id,
602                messages,
603            };
604
605            (request, share_infos)
606        };
607
608        let result: Vec<_> = withheld_devices
609            .into_iter()
610            .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
611            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
612            .into_iter()
613            .map(chunk_to_request)
614            .collect();
615
616        for (request, share_info) in result {
617            if !request.messages.is_empty() {
618                let txn_id = request.txn_id.to_owned();
619                group_session.add_request(txn_id.to_owned(), request.into(), share_info);
620
621                self.sessions.mark_as_being_shared(txn_id, group_session.clone());
622            }
623        }
624
625        Ok(())
626    }
627
628    fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
629        for request in requests {
630            let message_list = Self::to_device_request_to_log_list(request);
631            info!(
632                request_id = ?request.txn_id,
633                ?message_list,
634                "Created batch of to-device messages of type {}",
635                request.event_type
636            );
637        }
638    }
639
640    /// Given a to-device request, build a recipient map suitable for logging.
641    ///
642    /// Returns a list of triples of (message_id, user id, device_id).
643    fn to_device_request_to_log_list(
644        request: &Arc<ToDeviceRequest>,
645    ) -> Vec<(String, String, String)> {
646        #[derive(serde::Deserialize)]
647        struct ContentStub<'a> {
648            #[serde(borrow, default, rename = "org.matrix.msgid")]
649            message_id: Option<&'a str>,
650        }
651
652        let mut result: Vec<(String, String, String)> = Vec::new();
653
654        for (user_id, device_map) in &request.messages {
655            for (device, content) in device_map {
656                let message_id: Option<&str> = content
657                    .deserialize_as_unchecked::<ContentStub<'_>>()
658                    .expect("We should be able to deserialize the content we generated")
659                    .message_id;
660
661                result.push((
662                    message_id.unwrap_or("<undefined>").to_owned(),
663                    user_id.to_string(),
664                    device.to_string(),
665                ));
666            }
667        }
668        result
669    }
670
    /// Get to-device requests to share a room key with users in a room.
    ///
    /// Steps:
    /// 1. Reuse the room's outbound group session. A new one is created if
    ///    there is none, or if the existing one has expired or been
    ///    invalidated.
    /// 2. Rotate the session if recipient collection reports that this is
    ///    needed.
    /// 3. Encrypt the room key for every recipient device that has not
    ///    received it yet, or that has unwedged its Olm session since.
    /// 4. Queue `m.room_key.withheld` messages for devices that should not, or
    ///    could not, receive the key.
    ///
    /// All these requests are attached to the outbound group session, and the
    /// collected changes are persisted.
    ///
    /// # Arguments
    ///
    /// `room_id` - The room id of the room where the room key will be used.
    ///
    /// `users` - The list of users that should receive the room key.
    ///
    /// `encryption_settings` - The settings that should be used for
    /// the room key.
    ///
    /// # Returns
    ///
    /// The pending to-device requests of the (possibly rotated) outbound
    /// session, as returned by its `pending_requests()`. When that list is
    /// empty, the session is marked as shared if it was not already.
    ///
    /// # Errors
    ///
    /// Returns an error if any of these fail: a store lookup or save, the
    /// sender data lookup, session creation, recipient collection, or
    /// encryption of the room key.
    #[instrument(skip(self, users, encryption_settings), fields(session_id))]
    pub async fn share_room_key(
        &self,
        room_id: &RoomId,
        users: impl Iterator<Item = &UserId>,
        encryption_settings: impl Into<EncryptionSettings>,
    ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
        trace!("Checking if a room key needs to be shared");

        // Our own device; used to fill in the sender data of any inbound
        // session we create below.
        let account = self.store.static_account();
        let device = self.store.get_device(account.user_id(), account.device_id()).await?;

        let encryption_settings = encryption_settings.into();
        let mut changes = Changes::default();

        // Try to get an existing session or create a new one.
        let (outbound, inbound) = self
            .get_or_create_outbound_session(
                room_id,
                encryption_settings.clone(),
                SenderData::unknown(),
            )
            .await?;
        // Fill in the `session_id` field declared by `#[instrument]` above.
        tracing::Span::current().record("session_id", outbound.session_id());

        // Having an inbound group session here means that we created a new
        // group session pair, which we then need to store.
        if let Some(mut inbound) = inbound {
            // Use our own device info to populate the SenderData that validates the
            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
            // are sending out.
            let own_sender_data = if let Some(device) = &device {
                SenderDataFinder::find_using_device_data(
                    &self.store,
                    device.inner.clone(),
                    &inbound,
                )
                .await?
            } else {
                error!("Unable to find our own device!");
                SenderData::unknown()
            };
            inbound.sender_data = own_sender_data;

            changes.outbound_group_sessions.push(outbound.clone());
            changes.inbound_group_sessions.push(inbound);
        }

        // Collect the recipient devices and check if either the settings
        // or the recipient list changed in a way that requires the
        // session to be rotated.
        let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
            self.collect_session_recipients(users, &encryption_settings, &outbound).await?;

        // From here on `outbound` is the session actually used for sharing:
        // either the original one, or its replacement when rotated.
        let outbound = self
            .maybe_rotate_group_session(
                should_rotate,
                room_id,
                outbound,
                encryption_settings,
                &mut changes,
                device,
            )
            .await?;

        // Filter out the devices that already received this room key or have a
        // to-device message already queued up.
        let devices: Vec<_> = devices
            .into_values()
            .flat_map(|d| {
                d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
                    ShareState::NotShared => true,
                    ShareState::Shared { message_index: _, olm_wedging_index } => {
                        // If the recipient device's Olm wedging index is higher
                        // than the value that we stored with the session, that
                        // means that they tried to unwedge the session since we
                        // last shared the room key.  So we re-share it with
                        // them in case they weren't able to decrypt the room
                        // key the last time we shared it.
                        olm_wedging_index < d.olm_wedging_index
                    }
                    // Any other share state: do not (re-)send the key to this
                    // device here.
                    _ => false,
                })
            })
            .collect();

        // The `encrypt_for_devices()` method adds the to-device requests that will send
        // out the room key to the `OutboundGroupSession`. It doesn't do that
        // for the m.room_key_withheld events since we might have more of those
        // coming from the `collect_session_recipients()` method. Instead they get
        // returned by the method.
        let unable_to_encrypt_devices =
            self.encrypt_for_devices(devices, &outbound, &mut changes).await?;

        // Merge the withheld recipients.
        withheld_devices.extend(unable_to_encrypt_devices);

        // Now handle and add the withheld recipients to the resulting requests to the
        // `OutboundGroupSession`.
        self.handle_withheld_devices(&outbound, withheld_devices)?;

        // The to-device requests get added to the outbound group session, this
        // way we're making sure that they are persisted and scoped to the
        // session.
        let requests = outbound.pending_requests();

        if requests.is_empty() {
            // Nothing to send: a session that was never marked as shared is
            // marked now and queued for saving.
            if !outbound.shared() {
                debug!("The room key doesn't need to be shared with anyone. Marking as shared.");

                outbound.mark_as_shared();
                changes.outbound_group_sessions.push(outbound.clone());
            }
        } else {
            Self::log_room_key_sharing_result(&requests)
        }

        // Persist any changes we might have collected.
        if !changes.is_empty() {
            let session_count = changes.sessions.len();

            self.store.save_changes(changes).await?;

            trace!(
                session_count = session_count,
                "Stored the changed sessions after encrypting an room key"
            );
        }

        Ok(requests)
    }
812
813    /// Collect the devices belonging to the given user, and send the details of
814    /// a room key bundle to those devices.
815    ///
816    /// Returns a list of to-device requests which must be sent.
817    ///
818    /// For security reasons, only "safe" [`CollectStrategy`]s are supported, in
819    /// which the recipient must have signed their
820    /// devices. [`CollectStrategy::AllDevices`] and
821    /// [`CollectStrategy::ErrorOnVerifiedUserProblem`] are "unsafe" in this
822    /// respect,and are treated the same as
823    /// [`CollectStrategy::IdentityBasedStrategy`].
824    #[instrument(skip(self, bundle_data))]
825    pub async fn share_room_key_bundle_data(
826        &self,
827        user_id: &UserId,
828        collect_strategy: &CollectStrategy,
829        bundle_data: RoomKeyBundleContent,
830    ) -> OlmResult<Vec<ToDeviceRequest>> {
831        // Only allow conservative sharing strategies
832        let collect_strategy = match collect_strategy {
833            CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
834                warn!(
835                    "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
836                     for room key history sharing",
837                );
838                &CollectStrategy::IdentityBasedStrategy
839            }
840            CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
841                collect_strategy
842            }
843        };
844
845        let mut changes = Changes::default();
846
847        let CollectRecipientsResult { devices, .. } =
848            share_strategy::collect_recipients_for_share_strategy(
849                &self.store,
850                iter::once(user_id),
851                collect_strategy,
852                None,
853            )
854            .await?;
855
856        let devices = devices.into_values().flatten().collect();
857        let event_type = bundle_data.event_type().to_owned();
858        let (requests, _) = self
859            .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
860            .await?;
861
862        // TODO: figure out what to do with withheld devices
863
864        // Persist any changes we might have collected.
865        if !changes.is_empty() {
866            let session_count = changes.sessions.len();
867
868            self.store.save_changes(changes).await?;
869
870            trace!(
871                session_count = session_count,
872                "Stored the changed sessions after encrypting an room key"
873            );
874        }
875
876        Ok(requests)
877    }
878
879    /// Encrypt the given content for the given devices and build to-device
880    /// requests to send the encrypted content to them.
881    ///
882    /// Returns a tuple containing (1) the list of to-device requests, and (2)
883    /// the list of devices that we could not find an olm session for (so
884    /// need a withheld message).
885    pub(crate) async fn encrypt_content_for_devices(
886        &self,
887        recipient_devices: Vec<DeviceData>,
888        event_type: &str,
889        content: impl Serialize + Clone + Send + 'static,
890        changes: &mut Changes,
891    ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
892        let recipients = recipient_list_to_users_and_devices(&recipient_devices);
893        info!(?recipients, "Encrypting content of type {}", event_type);
894
895        // Chunk the recipients out so each to-device request will contain a
896        // limited amount of to-device messages.
897        //
898        // Create concurrent tasks for each chunk of recipients.
899        let tasks: Vec<_> = recipient_devices
900            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
901            .map(|chunk| {
902                spawn(
903                    encrypt_content_for_devices(
904                        self.store.crypto_store(),
905                        event_type.to_owned(),
906                        content.clone(),
907                        chunk.to_vec(),
908                    )
909                    .in_current_span(),
910                )
911            })
912            .collect();
913
914        let mut no_olm_devices = Vec::new();
915        let mut to_device_requests = Vec::new();
916
917        // Wait for all the tasks to finish up and queue up the Olm session that
918        // was used to encrypt the room key to be persisted again. This is
919        // needed because each encryption step will mutate the Olm session,
920        // ratcheting its state forward.
921        for result in join_all(tasks).await {
922            let result = result.expect("Encryption task panicked")?;
923            if let Some(request) = result.to_device_request {
924                to_device_requests.push(request);
925            }
926            changes.sessions.extend(result.updated_olm_sessions);
927            no_olm_devices.extend(result.no_olm_devices);
928        }
929
930        Ok((to_device_requests, no_olm_devices))
931    }
932}
933
934/// Helper for [`GroupSessionManager::encrypt_content_for_devices`].
935///
936/// Encrypt the given content for the given devices and build a to-device
937/// request to send the encrypted content to them.
938///
939/// See also [`GroupSessionManager::encrypt_session_for`], which is similar
940/// but applies specifically to `m.room_key` messages that hold a megolm
941/// session key.
942async fn encrypt_content_for_devices(
943    store: Arc<CryptoStoreWrapper>,
944    event_type: String,
945    content: impl Serialize + Clone + Send + 'static,
946    devices: Vec<DeviceData>,
947) -> OlmResult<EncryptForDevicesResult> {
948    let mut result_builder = EncryptForDevicesResultBuilder::default();
949
950    async fn encrypt(
951        store: Arc<CryptoStoreWrapper>,
952        device: DeviceData,
953        event_type: String,
954        bundle_data: impl Serialize,
955    ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
956        device
957            .encrypt(store.as_ref(), &event_type, bundle_data)
958            .await
959            .map(|(session, message, _message_id)| (session, message))
960    }
961
962    let tasks = devices.iter().map(|device| {
963        spawn(
964            encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
965                .in_current_span(),
966        )
967    });
968
969    let results = join_all(tasks).await;
970
971    for (device, result) in zip(devices, results) {
972        let encryption_result = result.expect("Encryption task panicked");
973
974        match encryption_result {
975            Ok((used_session, message)) => {
976                result_builder.on_successful_encryption(&device, used_session, message.cast());
977            }
978            Err(OlmError::MissingSession) => {
979                // There is no established Olm session for this device
980                result_builder.on_missing_session(device);
981            }
982            Err(e) => return Err(e),
983        }
984    }
985
986    Ok(result_builder.into_result())
987}
988
/// Result of [`GroupSessionManager::encrypt_session_for`] and
/// [`encrypt_content_for_devices`].
///
/// Usually produced by [`EncryptForDevicesResultBuilder::into_result`].
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// The request to send the to-device messages containing the encrypted
    /// payload. When built by [`EncryptForDevicesResultBuilder::into_result`],
    /// this is `None` if no device was successfully encrypted for.
    to_device_request: Option<ToDeviceRequest>,

    /// The devices which lack an Olm session and therefore need a withheld
    /// code, each paired with the [`WithheldCode`] to report for it.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store (encryption mutates their state).
    updated_olm_sessions: Vec<Session>,
}
1004
/// A helper for building [`EncryptForDevicesResult`]
///
/// Start from `Default::default()`, record each per-device outcome with
/// `on_successful_encryption` or `on_missing_session`, then call
/// `into_result`.
#[derive(Debug, Default)]
struct EncryptForDevicesResultBuilder {
    /// The payloads of the to-device messages, keyed by recipient user id and
    /// then by recipient device id.
    messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,

    /// The devices which lack an Olm session and therefore need a withheld code
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store.
    updated_olm_sessions: Vec<Session>,
}
1018
1019impl EncryptForDevicesResultBuilder {
1020    /// Record a successful encryption. The encrypted message is added to the
1021    /// list to be sent, and the olm session is added to the list of those
1022    /// that have been modified.
1023    pub fn on_successful_encryption(
1024        &mut self,
1025        device: &DeviceData,
1026        used_session: Session,
1027        message: Raw<AnyToDeviceEventContent>,
1028    ) {
1029        self.updated_olm_sessions.push(used_session);
1030
1031        self.messages
1032            .entry(device.user_id().to_owned())
1033            .or_default()
1034            .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
1035    }
1036
1037    /// Record a device which didn't have an active Olm session.
1038    pub fn on_missing_session(&mut self, device: DeviceData) {
1039        self.no_olm_devices.push((device, WithheldCode::NoOlm));
1040    }
1041
1042    /// Transform the accumulated results into an [`EncryptForDevicesResult`],
1043    /// wrapping the messages, if any, into a `ToDeviceRequest`.
1044    pub fn into_result(self) -> EncryptForDevicesResult {
1045        let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
1046            self;
1047
1048        let mut encrypt_for_devices_result = EncryptForDevicesResult {
1049            to_device_request: None,
1050            updated_olm_sessions,
1051            no_olm_devices,
1052        };
1053
1054        if !messages.is_empty() {
1055            let request = ToDeviceRequest {
1056                event_type: ToDeviceEventType::RoomEncrypted,
1057                txn_id: TransactionId::new(),
1058                messages,
1059            };
1060            trace!(
1061                recipient_count = request.message_count(),
1062                transaction_id = ?request.txn_id,
1063                "Created a to-device request carrying room keys",
1064            );
1065            encrypt_for_devices_result.to_device_request = Some(request);
1066        }
1067
1068        encrypt_for_devices_result
1069    }
1070}
1071
1072fn recipient_list_to_users_and_devices(
1073    recipient_devices: &[DeviceData],
1074) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1075    #[allow(unknown_lints, clippy::unwrap_or_default)] // false positive
1076    recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1077        acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1078        acc
1079    })
1080}
1081
1082#[cfg(test)]
1083mod tests {
1084    use std::{
1085        collections::{BTreeMap, BTreeSet},
1086        iter,
1087        ops::Deref,
1088        sync::Arc,
1089    };
1090
1091    use assert_matches2::assert_let;
1092    use matrix_sdk_common::deserialized_responses::{ProcessedToDeviceEvent, WithheldCode};
1093    use matrix_sdk_test::{async_test, ruma_response_from_json};
1094    use ruma::{
1095        DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1096        api::client::{
1097            keys::{claim_keys, get_keys, upload_keys},
1098            to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1099        },
1100        device_id,
1101        events::room::{EncryptedFile, V2EncryptedFileInfo, history_visibility::HistoryVisibility},
1102        owned_device_id, owned_room_id, room_id,
1103        to_device::DeviceIdOrAllDevices,
1104        user_id,
1105    };
1106    use serde_json::{Value, json};
1107
1108    use crate::{
1109        DecryptionSettings, EncryptionSettings, LocalTrust, OlmMachine, TrustRequirement,
1110        identities::DeviceData,
1111        machine::{
1112            EncryptionSyncChanges, test_helpers::get_machine_pair_with_setup_sessions_test_helper,
1113        },
1114        olm::{Account, SenderData},
1115        session_manager::{CollectStrategy, group_sessions::CollectRecipientsResult},
1116        types::{
1117            DeviceKeys, EventEncryptionAlgorithm,
1118            events::{
1119                room::encrypted::EncryptedToDeviceEvent,
1120                room_key_bundle::RoomKeyBundleContent,
1121                room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1122            },
1123            requests::ToDeviceRequest,
1124        },
1125    };
1126
1127    fn alice_id() -> &'static UserId {
1128        user_id!("@alice:example.org")
1129    }
1130
1131    fn alice_device_id() -> &'static DeviceId {
1132        device_id!("JLAFKJWSCS")
1133    }
1134
1135    /// Returns a /keys/query response for user "@example:localhost"
1136    fn keys_query_response() -> get_keys::v3::Response {
1137        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1138        let data: Value = serde_json::from_slice(data).unwrap();
1139        ruma_response_from_json(&data)
1140    }
1141
1142    fn bob_keys_query_response() -> get_keys::v3::Response {
1143        let data = json!({
1144            "device_keys": {
1145                "@bob:localhost": {
1146                    "BOBDEVICE": {
1147                        "user_id": "@bob:localhost",
1148                        "device_id": "BOBDEVICE",
1149                        "algorithms": [
1150                            "m.olm.v1.curve25519-aes-sha2",
1151                            "m.megolm.v1.aes-sha2",
1152                            "m.megolm.v2.aes-sha2"
1153                        ],
1154                        "keys": {
1155                            "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1156                            "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1157                        },
1158                        "signatures": {
1159                            "@bob:localhost": {
1160                                "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1161                            }
1162                        }
1163                    }
1164                }
1165            }
1166        });
1167        ruma_response_from_json(&data)
1168    }
1169
1170    /// Returns a keys claim response for device `BOBDEVICE` of user
1171    /// `@bob:localhost`.
1172    fn bob_one_time_key() -> claim_keys::v3::Response {
1173        let data = json!({
1174            "failures": {},
1175            "one_time_keys":{
1176                "@bob:localhost":{
1177                    "BOBDEVICE":{
1178                      "signed_curve25519:AAAAAAAAAAA": {
1179                          "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1180                          "signatures":{
1181                              "@bob:localhost":{
1182                                  "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1183                              }
1184                          }
1185                      }
1186                    }
1187                }
1188            }
1189        });
1190        ruma_response_from_json(&data)
1191    }
1192
1193    /// Returns a key claim response for device `NMMBNBUSNR` of user
1194    /// `@example2:localhost`
1195    fn keys_claim_response() -> claim_keys::v3::Response {
1196        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1197        let data: Value = serde_json::from_slice(data).unwrap();
1198        ruma_response_from_json(&data)
1199    }
1200
1201    async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1202        let keys_query = keys_query_response();
1203        let txn_id = TransactionId::new();
1204
1205        let machine = OlmMachine::new(user_id, device_id).await;
1206
1207        // complete a /keys/query and /keys/claim for @example:localhost
1208        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1209        let (txn_id, _keys_claim_request) = machine
1210            .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1211            .await
1212            .unwrap()
1213            .unwrap();
1214        let keys_claim = keys_claim_response();
1215        machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1216
1217        // complete a /keys/query and /keys/claim for @bob:localhost
1218        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1219        let (txn_id, _keys_claim_request) = machine
1220            .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1221            .await
1222            .unwrap()
1223            .unwrap();
1224        machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1225
1226        machine
1227    }
1228
1229    async fn machine() -> OlmMachine {
1230        machine_with_user_test_helper(alice_id(), alice_device_id()).await
1231    }
1232
1233    async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1234        let machine = machine().await;
1235        let room_id = room_id!("!test:localhost");
1236        let keys_claim = keys_claim_response();
1237
1238        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1239        let requests =
1240            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1241
1242        let outbound =
1243            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1244
1245        assert!(!outbound.pending_requests().is_empty());
1246        assert!(!outbound.shared());
1247
1248        let response = ToDeviceResponse::new();
1249        for request in requests {
1250            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1251        }
1252
1253        assert!(outbound.shared());
1254        assert!(outbound.pending_requests().is_empty());
1255
1256        machine
1257    }
1258
1259    #[async_test]
1260    async fn test_sharing() {
1261        let machine = machine().await;
1262        let room_id = room_id!("!test:localhost");
1263        let keys_claim = keys_claim_response();
1264
1265        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1266
1267        let requests =
1268            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1269
1270        let event_count: usize = requests
1271            .iter()
1272            .filter(|r| r.event_type == "m.room.encrypted".into())
1273            .map(|r| r.message_count())
1274            .sum();
1275
1276        // The keys claim response has a couple of one-time keys with invalid
1277        // signatures, thus only 148 sessions are actually created, we check
1278        // that all 148 valid sessions get an room key.
1279        assert_eq!(event_count, 148);
1280
1281        let withheld_count: usize = requests
1282            .iter()
1283            .filter(|r| r.event_type == "m.room_key.withheld".into())
1284            .map(|r| r.message_count())
1285            .sum();
1286        assert_eq!(withheld_count, 2);
1287    }
1288
1289    fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1290        requests
1291            .iter()
1292            .filter(|r| r.event_type == "m.room_key.withheld".into())
1293            .map(|r| {
1294                let mut count = 0;
1295                // count targets
1296                for message in r.messages.values() {
1297                    message.iter().for_each(|(_, content)| {
1298                        let withheld: RoomKeyWithheldContent =
1299                            content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1300
1301                        if let MegolmV1AesSha2(content) = withheld
1302                            && content.withheld_code() == code
1303                        {
1304                            count += 1;
1305                        }
1306                    })
1307                }
1308                count
1309            })
1310            .sum()
1311    }
1312
1313    #[async_test]
1314    async fn test_no_olm_sent_once() {
1315        let machine = machine().await;
1316        let keys_claim = keys_claim_response();
1317
1318        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1319
1320        let first_room_id = room_id!("!test:localhost");
1321
1322        let requests = machine
1323            .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1324            .await
1325            .unwrap();
1326
1327        // there will be two no_olm
1328        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1329        assert_eq!(withheld_count, 2);
1330
1331        // Re-sharing same session while request has not been sent should not produces
1332        // withheld
1333        let new_requests = machine
1334            .share_room_key(first_room_id, users, EncryptionSettings::default())
1335            .await
1336            .unwrap();
1337        let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1338        // No additional request was added, still the 2 already pending
1339        assert_eq!(withheld_count, 2);
1340
1341        let response = ToDeviceResponse::new();
1342        for request in requests {
1343            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1344        }
1345
1346        // The fact that an olm was sent should be remembered even if sharing another
1347        // session in an other room.
1348        let second_room_id = room_id!("!other:localhost");
1349        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1350        let requests = machine
1351            .share_room_key(second_room_id, users, EncryptionSettings::default())
1352            .await
1353            .unwrap();
1354
1355        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1356        assert_eq!(withheld_count, 0);
1357
1358        // Help how do I simulate the creation of a new session for the device
1359        // with no session now?
1360    }
1361
1362    #[async_test]
1363    async fn test_ratcheted_sharing() {
1364        let machine = machine_with_shared_room_key_test_helper().await;
1365
1366        let room_id = room_id!("!test:localhost");
1367        let late_joiner = user_id!("@bob:localhost");
1368        let keys_claim = keys_claim_response();
1369
1370        let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1371        users.insert(late_joiner);
1372
1373        let requests = machine
1374            .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1375            .await
1376            .unwrap();
1377
1378        let event_count: usize = requests
1379            .iter()
1380            .filter(|r| r.event_type == "m.room.encrypted".into())
1381            .map(|r| r.message_count())
1382            .sum();
1383        let outbound =
1384            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1385
1386        assert_eq!(event_count, 1);
1387        assert!(!outbound.pending_requests().is_empty());
1388    }
1389
1390    #[async_test]
1391    async fn test_changing_encryption_settings() {
1392        let machine = machine_with_shared_room_key_test_helper().await;
1393        let room_id = room_id!("!test:localhost");
1394        let keys_claim = keys_claim_response();
1395
1396        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1397        let outbound =
1398            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1399
1400        let CollectRecipientsResult { should_rotate, .. } = machine
1401            .inner
1402            .group_session_manager
1403            .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1404            .await
1405            .unwrap();
1406
1407        assert!(!should_rotate);
1408
1409        let settings = EncryptionSettings {
1410            history_visibility: HistoryVisibility::Invited,
1411            ..Default::default()
1412        };
1413
1414        let CollectRecipientsResult { should_rotate, .. } = machine
1415            .inner
1416            .group_session_manager
1417            .collect_session_recipients(users.clone(), &settings, &outbound)
1418            .await
1419            .unwrap();
1420
1421        assert!(should_rotate);
1422
1423        let settings = EncryptionSettings {
1424            algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1425            ..Default::default()
1426        };
1427
1428        let CollectRecipientsResult { should_rotate, .. } = machine
1429            .inner
1430            .group_session_manager
1431            .collect_session_recipients(users, &settings, &outbound)
1432            .await
1433            .unwrap();
1434
1435        assert!(should_rotate);
1436    }
1437
1438    #[async_test]
1439    async fn test_key_recipient_collecting() {
1440        // The user id comes from the fact that the keys_query.json file uses
1441        // this one.
1442        let user_id = user_id!("@example:localhost");
1443        let device_id = device_id!("TESTDEVICE");
1444        let room_id = room_id!("!test:localhost");
1445
1446        let machine = machine_with_user_test_helper(user_id, device_id).await;
1447
1448        let (outbound, _) = machine
1449            .inner
1450            .group_session_manager
1451            .get_or_create_outbound_session(
1452                room_id,
1453                EncryptionSettings::default(),
1454                SenderData::unknown(),
1455            )
1456            .await
1457            .expect("We should be able to create a new session");
1458        let history_visibility = HistoryVisibility::Joined;
1459        let settings = EncryptionSettings { history_visibility, ..Default::default() };
1460
1461        let users = [user_id].into_iter();
1462
1463        let CollectRecipientsResult { devices: recipients, .. } = machine
1464            .inner
1465            .group_session_manager
1466            .collect_session_recipients(users, &settings, &outbound)
1467            .await
1468            .expect("We should be able to collect the session recipients");
1469
1470        assert!(!recipients[user_id].is_empty());
1471
1472        // Make sure that our own device isn't part of the recipients.
1473        assert!(
1474            !recipients[user_id]
1475                .iter()
1476                .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1477        );
1478
1479        let settings = EncryptionSettings {
1480            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1481            ..Default::default()
1482        };
1483        let users = [user_id].into_iter();
1484
1485        let CollectRecipientsResult { devices: recipients, .. } = machine
1486            .inner
1487            .group_session_manager
1488            .collect_session_recipients(users, &settings, &outbound)
1489            .await
1490            .expect("We should be able to collect the session recipients");
1491
1492        assert!(recipients[user_id].is_empty());
1493
1494        let device_id = "AFGUOBTZWM".into();
1495        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1496        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1497        let users = [user_id].into_iter();
1498
1499        let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1500            machine
1501                .inner
1502                .group_session_manager
1503                .collect_session_recipients(users, &settings, &outbound)
1504                .await
1505                .expect("We should be able to collect the session recipients");
1506
1507        assert!(
1508            recipients[user_id]
1509                .iter()
1510                .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1511        );
1512
1513        let devices = machine.get_user_devices(user_id, None).await.unwrap();
1514        devices
1515            .devices()
1516            // Ignore our own device
1517            .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1518            .for_each(|d| {
1519                if d.is_blacklisted() {
1520                    assert!(withheld.iter().any(|(dev, w)| {
1521                        dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1522                    }));
1523                } else if !d.is_verified() {
1524                    // the device should then be in the list of withhelds
1525                    assert!(withheld.iter().any(|(dev, w)| {
1526                        dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1527                    }));
1528                }
1529            });
1530
1531        assert_eq!(149, withheld.len());
1532    }
1533
1534    #[async_test]
1535    async fn test_sharing_withheld_only_trusted() {
1536        let machine = machine().await;
1537        let room_id = room_id!("!test:localhost");
1538        let keys_claim = keys_claim_response();
1539
1540        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1541        let settings = EncryptionSettings {
1542            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1543            ..Default::default()
1544        };
1545
1546        // Trust only one
1547        let user_id = user_id!("@example:localhost");
1548        let device_id = "MWFXPINOAO".into();
1549        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1550        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1551        machine
1552            .get_device(user_id, "MWVTUXDNNM".into(), None)
1553            .await
1554            .unwrap()
1555            .unwrap()
1556            .set_local_trust(LocalTrust::BlackListed)
1557            .await
1558            .unwrap();
1559
1560        let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1561
1562        // One room key should be sent
1563        let room_key_count =
1564            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1565
1566        assert_eq!(1, room_key_count);
1567
1568        let withheld_count =
1569            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1570        // Can be send in one batch
1571        assert_eq!(1, withheld_count);
1572
1573        let event_count: usize = requests
1574            .iter()
1575            .filter(|r| r.event_type == "m.room_key.withheld".into())
1576            .map(|r| r.message_count())
1577            .sum();
1578
1579        // withhelds are sent in clear so all device should be counted (even if no OTK)
1580        assert_eq!(event_count, 149);
1581
1582        // One should be blacklisted
1583        let has_blacklist =
1584            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1585                let device_key = DeviceIdOrAllDevices::from(owned_device_id!("MWVTUXDNNM"));
1586                let content = &r.messages[user_id][&device_key];
1587                let withheld: RoomKeyWithheldContent =
1588                    content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1589                if let MegolmV1AesSha2(content) = withheld {
1590                    content.withheld_code() == WithheldCode::Blacklisted
1591                } else {
1592                    false
1593                }
1594            });
1595
1596        assert!(has_blacklist);
1597    }
1598
1599    #[async_test]
1600    async fn test_no_olm_withheld_only_sent_once() {
1601        let keys_query = keys_query_response();
1602        let txn_id = TransactionId::new();
1603
1604        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1605
1606        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1607        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1608
1609        let first_room = room_id!("!test:localhost");
1610        let second_room = room_id!("!test2:localhost");
1611        let bob_id = user_id!("@bob:localhost");
1612
1613        let settings = EncryptionSettings::default();
1614        let users = [bob_id];
1615
1616        let requests = machine
1617            .share_room_key(first_room, users.into_iter(), settings.to_owned())
1618            .await
1619            .unwrap();
1620
1621        // One withheld request should be sent.
1622        let withheld_count =
1623            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1624
1625        assert_eq!(withheld_count, 1);
1626        assert_eq!(requests.len(), 1);
1627
1628        // On the second room key share attempt we're not sending another `m.no_olm`
1629        // code since the first one is taking care of this.
1630        let second_requests =
1631            machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1632
1633        let withheld_count =
1634            second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1635
1636        assert_eq!(withheld_count, 0);
1637        assert_eq!(second_requests.len(), 0);
1638
1639        let response = ToDeviceResponse::new();
1640
1641        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1642
1643        // The device should be marked as having the `m.no_olm` code received only after
1644        // the request has been marked as sent.
1645        assert!(!device.was_withheld_code_sent());
1646
1647        for request in requests {
1648            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1649        }
1650
1651        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1652
1653        assert!(device.was_withheld_code_sent());
1654    }
1655
    /// Checks that once Bob "unwedges" his Olm channel with Alice (by creating a new
    /// Olm session and sending her a message over it), Alice re-shares the room key
    /// with Bob's device exactly once: not before the unwedge, and not again after.
    #[async_test]
    async fn test_resend_session_after_unwedging() {
        // Alice uploads her device keys and one-time keys (OTKs); the fake upload
        // response reports every OTK from the request as stored on the server.
        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
        assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
        let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
            OneTimeKeyAlgorithm::SignedCurve25519,
            UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
        )]));
        machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();

        let room_id = room_id!("!test:localhost");

        // Bob is a bare `Account` rather than a full `OlmMachine`; Alice learns
        // about his device through a hand-built `/keys/query` response.
        let bob_id = user_id!("@bob:localhost");
        let bob_account = Account::new(bob_id);
        let keys_query_data = json!({
            "device_keys": {
                "@bob:localhost": {
                    bob_account.device_id.clone(): bob_account.device_keys()
                }
            }
        });
        let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
        let txn_id = TransactionId::new();
        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();

        // Bob's view of Alice's device, built from the keys Alice just uploaded.
        // Each outbound Olm session Bob creates below consumes the next of Alice's
        // OTKs from this iterator.
        let alice_device_keys =
            device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
        let mut alice_otks = device_keys_request.one_time_keys.iter();
        let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);

        {
            // Bob creates an Olm session with Alice and encrypts a message to her
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            // Alice decrypts the message
            // (it is delivered to her machine as a to-device event in a sync).
            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };

            let decryption_settings =
                DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };

            let (decrypted, _) =
                machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        // Alice shares the room key with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            // We should have had one to-device event
            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            // Acknowledge the requests, as a real client would after sending them.
            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        // When Alice shares the room key again, there shouldn't be any
        // to-device events, since we already shared with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }

        // Pretend that Bob wasn't able to decrypt, so he tries to unwedge
        {
            // Bob sets up a brand-new Olm session using the next unused OTK of
            // Alice's and sends a message over it.
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            // Alice decrypts the unwedge message
            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };

            let decryption_settings =
                DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };

            let (decrypted, _) =
                machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        // When Alice shares the room key again, it should be re-shared with Bob
        // (the only thing that changed since the previous share is Bob's new Olm
        // session, so that is what must trigger the re-share).
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            // Acknowledge the requests, as a real client would after sending them.
            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        // When Alice shares the room key yet again, there shouldn't be any
        // to-device events
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }
    }
1825
1826    #[async_test]
1827    async fn test_room_key_bundle_sharing() {
1828        let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1829            user_id!("@alice:localhost"),
1830            user_id!("@bob:localhost"),
1831            false,
1832        )
1833        .await;
1834
1835        // Alice trusts Bob's device
1836        let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1837        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1838
1839        let content = RoomKeyBundleContent {
1840            room_id: owned_room_id!("!room:id"),
1841            file: EncryptedFile::new(
1842                OwnedMxcUri::from("test"),
1843                V2EncryptedFileInfo::encode([0; 32], [0; 16]).into(),
1844                Default::default(),
1845            ),
1846        };
1847
1848        let requests = alice
1849            .share_room_key_bundle_data(
1850                bob.user_id(),
1851                &CollectStrategy::OnlyTrustedDevices,
1852                content,
1853            )
1854            .await
1855            .unwrap();
1856
1857        // There should be exactly one message
1858        let requests: Vec<_> =
1859            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1860        let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1861        assert_eq!(message_count, 1);
1862
1863        // Bob decrypts the message
1864        let bob_message = requests[0]
1865            .messages
1866            .get(bob.user_id())
1867            .unwrap()
1868            .get(&(bob.device_id().to_owned().into()))
1869            .unwrap();
1870        let to_device = EncryptedToDeviceEvent::new(
1871            alice.user_id().to_owned(),
1872            bob_message.deserialize_as_unchecked().unwrap(),
1873        );
1874
1875        let sync_changes = EncryptionSyncChanges {
1876            to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1877            changed_devices: &Default::default(),
1878            one_time_keys_counts: &Default::default(),
1879            unused_fallback_keys: None,
1880            next_batch_token: None,
1881        };
1882
1883        let decryption_settings =
1884            DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1885
1886        let (decrypted, _) =
1887            bob.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1888        assert_eq!(1, decrypted.len());
1889        use crate::types::events::EventType;
1890        assert_let!(
1891            ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1892        );
1893        assert_eq!(
1894            raw.get_field::<String>("type").unwrap().unwrap(),
1895            RoomKeyBundleContent::EVENT_TYPE,
1896        );
1897    }
1898}