matrix_sdk_crypto/session_manager/group_sessions/
mod.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15mod share_strategy;
16
17use std::{
18    collections::{BTreeMap, BTreeSet},
19    fmt::Debug,
20    iter,
21    iter::zip,
22    sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28    deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30#[cfg(feature = "experimental-encrypted-state-events")]
31use ruma::events::AnyStateEventContent;
32use ruma::{
33    events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
34    serde::Raw,
35    to_device::DeviceIdOrAllDevices,
36    DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
37    UserId,
38};
39use serde::Serialize;
40#[cfg(feature = "experimental-send-custom-to-device")]
41pub(crate) use share_strategy::split_devices_for_share_strategy;
42pub use share_strategy::CollectStrategy;
43pub(crate) use share_strategy::{
44    withheld_code_for_device_for_share_strategy, CollectRecipientsResult,
45};
46use tracing::{debug, error, info, instrument, trace, warn, Instrument};
47
48use crate::{
49    error::{EventError, MegolmResult, OlmResult},
50    identities::device::MaybeEncryptedRoomKey,
51    olm::{
52        InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
53        ShareInfo, ShareState,
54    },
55    store::{types::Changes, CryptoStoreWrapper, Result as StoreResult, Store},
56    types::{
57        events::{
58            room::encrypted::{RoomEncryptedEventContent, ToDeviceEncryptedEventContent},
59            room_key_bundle::RoomKeyBundleContent,
60            EventType,
61        },
62        requests::ToDeviceRequest,
63    },
64    Device, DeviceData, EncryptionSettings, OlmError,
65};
66
67#[derive(Clone, Debug)]
68pub(crate) struct GroupSessionCache {
69    store: Store,
70    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
71    /// A map from the request id to the group session that the request belongs
72    /// to. Used to mark requests belonging to the session as shared.
73    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
74}
75
76impl GroupSessionCache {
77    pub(crate) fn new(store: Store) -> Self {
78        Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
79    }
80
81    pub(crate) fn insert(&self, session: OutboundGroupSession) {
82        self.sessions.write().insert(session.room_id().to_owned(), session);
83    }
84
85    /// Either get a session for the given room from the cache or load it from
86    /// the store.
87    ///
88    /// # Arguments
89    ///
90    /// * `room_id` - The id of the room this session is used for.
91    pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
92        // Get the cached session, if there isn't one load one from the store
93        // and put it in the cache.
94        if let Some(s) = self.sessions.read().get(room_id) {
95            return Some(s.clone());
96        }
97
98        match self.store.get_outbound_group_session(room_id).await {
99            Ok(Some(s)) => {
100                {
101                    let mut sessions_being_shared = self.sessions_being_shared.write();
102                    for request_id in s.pending_request_ids() {
103                        sessions_being_shared.insert(request_id, s.clone());
104                    }
105                }
106
107                self.sessions.write().insert(room_id.to_owned(), s.clone());
108
109                Some(s)
110            }
111            Ok(None) => None,
112            Err(e) => {
113                error!("Couldn't restore an outbound group session: {e:?}");
114                None
115            }
116        }
117    }
118
119    /// Get an outbound group session for a room, if one exists.
120    ///
121    /// # Arguments
122    ///
123    /// * `room_id` - The id of the room for which we should get the outbound
124    ///   group session.
125    fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
126        self.sessions.read().get(room_id).cloned()
127    }
128
129    /// Returns whether any session is withheld with the given device and code.
130    fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
131        self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
132    }
133
134    fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
135        self.sessions_being_shared.write().remove(id)
136    }
137
138    fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
139        self.sessions_being_shared.write().insert(id, session);
140    }
141}
142
143#[derive(Debug, Clone)]
144pub(crate) struct GroupSessionManager {
145    /// Store for the encryption keys.
146    /// Persists all the encryption keys so a client can resume the session
147    /// without the need to create new keys.
148    store: Store,
149    /// The currently active outbound group sessions.
150    sessions: GroupSessionCache,
151}
152
153impl GroupSessionManager {
154    const MAX_TO_DEVICE_MESSAGES: usize = 250;
155
156    pub fn new(store: Store) -> Self {
157        Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
158    }
159
160    pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
161        if let Some(s) = self.sessions.get(room_id) {
162            s.invalidate_session();
163
164            let mut changes = Changes::default();
165            changes.outbound_group_sessions.push(s.clone());
166            self.store.save_changes(changes).await?;
167
168            Ok(true)
169        } else {
170            Ok(false)
171        }
172    }
173
174    pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
175        let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
176            return Ok(());
177        };
178
179        let no_olm = session.mark_request_as_sent(request_id);
180
181        let mut changes = Changes::default();
182
183        for (user_id, devices) in &no_olm {
184            for device_id in devices {
185                let device = self.store.get_device(user_id, device_id).await;
186
187                if let Ok(Some(device)) = device {
188                    device.mark_withheld_code_as_sent();
189                    changes.devices.changed.push(device.inner.clone());
190                } else {
191                    error!(
192                        ?request_id,
193                        "Marking to-device no olm as sent but device not found, might \
194                            have been deleted?"
195                    );
196                }
197            }
198        }
199
200        changes.outbound_group_sessions.push(session.clone());
201        self.store.save_changes(changes).await
202    }
203
204    #[cfg(test)]
205    pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
206        self.sessions.get(room_id)
207    }
208
209    pub async fn encrypt(
210        &self,
211        room_id: &RoomId,
212        event_type: &str,
213        content: &Raw<AnyMessageLikeEventContent>,
214    ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
215        let session =
216            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
217
218        assert!(!session.expired(), "Session expired");
219
220        let content = session.encrypt(event_type, content).await;
221
222        let mut changes = Changes::default();
223        changes.outbound_group_sessions.push(session);
224        self.store.save_changes(changes).await?;
225
226        Ok(content)
227    }
228
229    /// Encrypts a state event for the given room using its outbound group
230    /// session.
231    ///
232    /// # Arguments
233    ///
234    /// * `room_id` - The ID of the room where the state event will be sent.
235    /// * `event_type` - The type of the state event to encrypt.
236    /// * `state_key` - The state key associated with the event.
237    /// * `content` - The raw content of the state event to encrypt.
238    ///
239    /// # Returns
240    ///
241    /// Returns the raw encrypted state event content.
242    ///
243    /// # Errors
244    ///
245    /// Returns an error if saving changes to the store fails.
246    ///
247    /// # Panics
248    ///
249    /// Panics if no session exists for the given room ID, or the session
250    /// has expired.
251    #[cfg(feature = "experimental-encrypted-state-events")]
252    pub async fn encrypt_state(
253        &self,
254        room_id: &RoomId,
255        event_type: &str,
256        state_key: &str,
257        content: &Raw<AnyStateEventContent>,
258    ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
259        let session =
260            self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
261
262        assert!(!session.expired(), "Session expired");
263
264        let content = session.encrypt_state(event_type, state_key, content).await;
265
266        let mut changes = Changes::default();
267        changes.outbound_group_sessions.push(session);
268        self.store.save_changes(changes).await?;
269
270        Ok(content)
271    }
272
273    /// Create a new outbound group session.
274    ///
275    /// This also creates a matching inbound group session.
276    pub async fn create_outbound_group_session(
277        &self,
278        room_id: &RoomId,
279        settings: EncryptionSettings,
280        own_sender_data: SenderData,
281    ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
282        let (outbound, inbound) = self
283            .store
284            .static_account()
285            .create_group_session_pair(room_id, settings, own_sender_data)
286            .await
287            .map_err(|_| EventError::UnsupportedAlgorithm)?;
288
289        self.sessions.insert(outbound.clone());
290        Ok((outbound, inbound))
291    }
292
293    pub async fn get_or_create_outbound_session(
294        &self,
295        room_id: &RoomId,
296        settings: EncryptionSettings,
297        own_sender_data: SenderData,
298    ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
299        let outbound_session = self.sessions.get_or_load(room_id).await;
300
301        // If there is no session or the session has expired or is invalid,
302        // create a new one.
303        if let Some(s) = outbound_session {
304            if s.expired() || s.invalidated() {
305                self.create_outbound_group_session(room_id, settings, own_sender_data)
306                    .await
307                    .map(|(o, i)| (o, i.into()))
308            } else {
309                Ok((s, None))
310            }
311        } else {
312            self.create_outbound_group_session(room_id, settings, own_sender_data)
313                .await
314                .map(|(o, i)| (o, i.into()))
315        }
316    }
317
318    /// Encrypt the given group session key for the given devices and create
319    /// to-device requests that sends the encrypted content to them.
320    ///
321    /// See also [`encrypt_content_for_devices`] which is similar
322    /// but is not specific to group sessions, and does not return the
323    /// [`ShareInfo`] data.
324    async fn encrypt_session_for(
325        store: Arc<CryptoStoreWrapper>,
326        group_session: OutboundGroupSession,
327        devices: Vec<DeviceData>,
328    ) -> OlmResult<(
329        EncryptForDevicesResult,
330        BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
331    )> {
332        // Use a named type instead of a tuple with rather long type name
333        pub struct DeviceResult {
334            device: DeviceData,
335            maybe_encrypted_room_key: MaybeEncryptedRoomKey,
336        }
337
338        let mut result_builder = EncryptForDevicesResultBuilder::default();
339        let mut share_infos = BTreeMap::new();
340
341        // XXX is there a way to do this that doesn't involve cloning the
342        // `Arc<CryptoStoreWrapper>` for each device?
343        let encrypt = |store: Arc<CryptoStoreWrapper>,
344                       device: DeviceData,
345                       session: OutboundGroupSession| async move {
346            let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
347
348            Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
349        };
350
351        let tasks: Vec<_> = devices
352            .iter()
353            .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
354            .collect();
355
356        let results = join_all(tasks).await;
357
358        for result in results {
359            let result = result.expect("Encryption task panicked")?;
360
361            match result.maybe_encrypted_room_key {
362                MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
363                    result_builder.on_successful_encryption(&result.device, *used_session, message);
364
365                    let user_id = result.device.user_id().to_owned();
366                    let device_id = result.device.device_id().to_owned();
367                    share_infos
368                        .entry(user_id)
369                        .or_insert_with(BTreeMap::new)
370                        .insert(device_id, *share_info);
371                }
372                MaybeEncryptedRoomKey::MissingSession => {
373                    result_builder.on_missing_session(result.device);
374                }
375            }
376        }
377
378        Ok((result_builder.into_result(), share_infos))
379    }
380
381    /// Given a list of user and an outbound session, return the list of users
382    /// and their devices that this session should be shared with.
383    ///
384    /// Returns information indicating whether the session needs to be rotated
385    /// and the list of users/devices that should receive or not the session
386    /// (with withheld reason).
387    #[instrument(skip_all)]
388    pub async fn collect_session_recipients(
389        &self,
390        users: impl Iterator<Item = &UserId>,
391        settings: &EncryptionSettings,
392        outbound: &OutboundGroupSession,
393    ) -> OlmResult<CollectRecipientsResult> {
394        share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
395    }
396
397    async fn encrypt_request(
398        store: Arc<CryptoStoreWrapper>,
399        chunk: Vec<DeviceData>,
400        outbound: OutboundGroupSession,
401        sessions: GroupSessionCache,
402    ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
403        let (result, share_infos) =
404            Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
405
406        if let Some(request) = result.to_device_request {
407            let id = request.txn_id.clone();
408            outbound.add_request(id.clone(), request.into(), share_infos);
409            sessions.mark_as_being_shared(id, outbound.clone());
410        }
411
412        Ok((result.updated_olm_sessions, result.no_olm_devices))
413    }
414
415    pub(crate) fn session_cache(&self) -> GroupSessionCache {
416        self.sessions.clone()
417    }
418
419    async fn maybe_rotate_group_session(
420        &self,
421        should_rotate: bool,
422        room_id: &RoomId,
423        outbound: OutboundGroupSession,
424        encryption_settings: EncryptionSettings,
425        changes: &mut Changes,
426        own_device: Option<Device>,
427    ) -> OlmResult<OutboundGroupSession> {
428        Ok(if should_rotate {
429            let old_session_id = outbound.session_id();
430
431            let (outbound, mut inbound) = self
432                .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
433                .await?;
434
435            // Use our own device info to populate the SenderData that validates the
436            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
437            // are sending out.
438            let own_sender_data = if let Some(device) = own_device {
439                SenderDataFinder::find_using_device_data(
440                    &self.store,
441                    device.inner.clone(),
442                    &inbound,
443                )
444                .await?
445            } else {
446                error!("Unable to find our own device!");
447                SenderData::unknown()
448            };
449            inbound.sender_data = own_sender_data;
450
451            changes.outbound_group_sessions.push(outbound.clone());
452            changes.inbound_group_sessions.push(inbound);
453
454            debug!(
455                old_session_id = old_session_id,
456                session_id = outbound.session_id(),
457                "A user or device has left the room since we last sent a \
458                message, or the encryption settings have changed. Rotating the \
459                room key.",
460            );
461
462            outbound
463        } else {
464            outbound
465        })
466    }
467
468    async fn encrypt_for_devices(
469        &self,
470        recipient_devices: Vec<DeviceData>,
471        group_session: &OutboundGroupSession,
472        changes: &mut Changes,
473    ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
474        // If we have some recipients, log them here.
475        if !recipient_devices.is_empty() {
476            let recipients = recipient_list_to_users_and_devices(&recipient_devices);
477
478            // If there are new recipients we need to persist the outbound group
479            // session as the to-device requests are persisted with the session.
480            changes.outbound_group_sessions = vec![group_session.clone()];
481
482            let message_index = group_session.message_index().await;
483
484            info!(
485                ?recipients,
486                message_index,
487                room_id = ?group_session.room_id(),
488                session_id = group_session.session_id(),
489                "Trying to encrypt a room key",
490            );
491        }
492
493        // Chunk the recipients out so each to-device request will contain a
494        // limited amount of to-device messages.
495        //
496        // Create concurrent tasks for each chunk of recipients.
497        let tasks: Vec<_> = recipient_devices
498            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
499            .map(|chunk| {
500                spawn(Self::encrypt_request(
501                    self.store.crypto_store(),
502                    chunk.to_vec(),
503                    group_session.clone(),
504                    self.sessions.clone(),
505                ))
506            })
507            .collect();
508
509        let mut withheld_devices = Vec::new();
510
511        // Wait for all the tasks to finish up and queue up the Olm session that
512        // was used to encrypt the room key to be persisted again. This is
513        // needed because each encryption step will mutate the Olm session,
514        // ratcheting its state forward.
515        for result in join_all(tasks).await {
516            let result = result.expect("Encryption task panicked");
517
518            let (used_sessions, failed_no_olm) = result?;
519
520            changes.sessions.extend(used_sessions);
521            withheld_devices.extend(failed_no_olm);
522        }
523
524        Ok(withheld_devices)
525    }
526
527    fn is_withheld_to(
528        &self,
529        group_session: &OutboundGroupSession,
530        device: &DeviceData,
531        code: &WithheldCode,
532    ) -> bool {
533        // The `m.no_olm` withheld code is special because it is supposed to be sent
534        // only once for a given device. The `Device` remembers the flag if we
535        // already sent a `m.no_olm` to this particular device so let's check
536        // that first.
537        //
538        // Keep in mind that any outbound group session might want to send this code to
539        // the device. So we need to check if any of our outbound group sessions
540        // is attempting to send the code to the device.
541        //
542        // This still has a slight race where some other thread might remove the
543        // outbound group session while a third is marking the device as having
544        // received the code.
545        //
546        // Since nothing terrible happens if we do end up sending the withheld code
547        // twice, and removing the race requires us to lock the store because the
548        // `OutboundGroupSession` and the `Device` both interact with the flag we'll
549        // leave it be.
550        if code == &WithheldCode::NoOlm {
551            device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
552        } else {
553            group_session.sharing_view().is_withheld_to(device, code)
554        }
555    }
556
557    fn handle_withheld_devices(
558        &self,
559        group_session: &OutboundGroupSession,
560        withheld_devices: Vec<(DeviceData, WithheldCode)>,
561    ) -> OlmResult<()> {
562        // Convert a withheld code for the group session into a to-device event content.
563        let to_content = |code| {
564            let content = group_session.withheld_code(code);
565            Raw::new(&content).expect("We can always serialize a withheld content info").cast()
566        };
567
568        // Helper to convert a chunk of device and withheld code pairs into a to-device
569        // request and it's accompanying share info.
570        let chunk_to_request = |chunk| {
571            let mut messages = BTreeMap::new();
572            let mut share_infos = BTreeMap::new();
573
574            for (device, code) in chunk {
575                let device: DeviceData = device;
576                let code: WithheldCode = code;
577
578                let user_id = device.user_id().to_owned();
579                let device_id = device.device_id().to_owned();
580
581                let share_info = ShareInfo::new_withheld(code.to_owned());
582                let content = to_content(code);
583
584                messages
585                    .entry(user_id.to_owned())
586                    .or_insert_with(BTreeMap::new)
587                    .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
588
589                share_infos
590                    .entry(user_id)
591                    .or_insert_with(BTreeMap::new)
592                    .insert(device_id, share_info);
593            }
594
595            let txn_id = TransactionId::new();
596
597            let request = ToDeviceRequest {
598                event_type: ToDeviceEventType::from("m.room_key.withheld"),
599                txn_id,
600                messages,
601            };
602
603            (request, share_infos)
604        };
605
606        let result: Vec<_> = withheld_devices
607            .into_iter()
608            .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
609            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
610            .into_iter()
611            .map(chunk_to_request)
612            .collect();
613
614        for (request, share_info) in result {
615            if !request.messages.is_empty() {
616                let txn_id = request.txn_id.to_owned();
617                group_session.add_request(txn_id.to_owned(), request.into(), share_info);
618
619                self.sessions.mark_as_being_shared(txn_id, group_session.clone());
620            }
621        }
622
623        Ok(())
624    }
625
626    fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
627        for request in requests {
628            let message_list = Self::to_device_request_to_log_list(request);
629            info!(
630                request_id = ?request.txn_id,
631                ?message_list,
632                "Created batch of to-device messages of type {}",
633                request.event_type
634            );
635        }
636    }
637
638    /// Given a to-device request, build a recipient map suitable for logging.
639    ///
640    /// Returns a list of triples of (message_id, user id, device_id).
641    fn to_device_request_to_log_list(
642        request: &Arc<ToDeviceRequest>,
643    ) -> Vec<(String, String, String)> {
644        #[derive(serde::Deserialize)]
645        struct ContentStub<'a> {
646            #[serde(borrow, default, rename = "org.matrix.msgid")]
647            message_id: Option<&'a str>,
648        }
649
650        let mut result: Vec<(String, String, String)> = Vec::new();
651
652        for (user_id, device_map) in &request.messages {
653            for (device, content) in device_map {
654                let message_id: Option<&str> = content
655                    .deserialize_as_unchecked::<ContentStub<'_>>()
656                    .expect("We should be able to deserialize the content we generated")
657                    .message_id;
658
659                result.push((
660                    message_id.unwrap_or("<undefined>").to_owned(),
661                    user_id.to_string(),
662                    device.to_string(),
663                ));
664            }
665        }
666        result
667    }
668
    /// Get to-device requests to share a room key with users in a room.
    ///
    /// Reuses the room's current outbound group session or creates a new one,
    /// rotates it if the recipients or settings require it, encrypts the room
    /// key for every device that doesn't have it yet, queues withheld codes
    /// for the devices that shouldn't or can't receive it, and persists all
    /// resulting changes.
    ///
    /// # Arguments
    ///
    /// `room_id` - The room id of the room where the room key will be used.
    ///
    /// `users` - The list of users that should receive the room key.
    ///
    /// `encryption_settings` - The settings that should be used for
    /// the room key.
    ///
    /// # Returns
    ///
    /// All to-device requests that are pending on the (possibly rotated)
    /// outbound group session, including requests queued by earlier calls.
    ///
    /// # Errors
    ///
    /// Returns an error if loading our own device, creating or rotating the
    /// session, collecting the recipients, encrypting the room key, or
    /// persisting the changes fails.
    #[instrument(skip(self, users, encryption_settings), fields(session_id))]
    pub async fn share_room_key(
        &self,
        room_id: &RoomId,
        users: impl Iterator<Item = &UserId>,
        encryption_settings: impl Into<EncryptionSettings>,
    ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
        trace!("Checking if a room key needs to be shared");

        // Our own device; used below to derive the sender data of any inbound
        // session we create. `None` if it isn't in the store.
        let account = self.store.static_account();
        let device = self.store.get_device(account.user_id(), account.device_id()).await?;

        let encryption_settings = encryption_settings.into();
        // Everything that must be persisted is accumulated here and saved once
        // at the end.
        let mut changes = Changes::default();

        // Try to get an existing session or create a new one.
        let (outbound, inbound) = self
            .get_or_create_outbound_session(
                room_id,
                encryption_settings.clone(),
                SenderData::unknown(),
            )
            .await?;
        // NOTE(review): this records the ID of the session from before a
        // possible rotation further down; after a rotation the span still
        // carries the old session ID.
        tracing::Span::current().record("session_id", outbound.session_id());

        // Having an inbound group session here means that we created a new
        // group session pair, which we then need to store.
        if let Some(mut inbound) = inbound {
            // Use our own device info to populate the SenderData that validates the
            // InboundGroupSession that we create as a pair to the OutboundGroupSession we
            // are sending out.
            let own_sender_data = if let Some(device) = &device {
                SenderDataFinder::find_using_device_data(
                    &self.store,
                    device.inner.clone(),
                    &inbound,
                )
                .await?
            } else {
                error!("Unable to find our own device!");
                SenderData::unknown()
            };
            inbound.sender_data = own_sender_data;

            changes.outbound_group_sessions.push(outbound.clone());
            changes.inbound_group_sessions.push(inbound);
        }

        // Collect the recipient devices and check if either the settings
        // or the recipient list changed in a way that requires the
        // session to be rotated.
        let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
            self.collect_session_recipients(users, &encryption_settings, &outbound).await?;

        // From here on `outbound` is the session that will actually be shared:
        // either the one from above or, if rotation was required, a new one
        // (whose pair is queued in `changes`).
        let outbound = self
            .maybe_rotate_group_session(
                should_rotate,
                room_id,
                outbound,
                encryption_settings,
                &mut changes,
                device,
            )
            .await?;

        // Filter out the devices that already received this room key or have a
        // to-device message already queued up.
        let devices: Vec<_> = devices
            .into_iter()
            .flat_map(|(_, d)| {
                d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
                    ShareState::NotShared => true,
                    ShareState::Shared { message_index: _, olm_wedging_index } => {
                        // If the recipient device's Olm wedging index is higher
                        // than the value that we stored with the session, that
                        // means that they tried to unwedge the session since we
                        // last shared the room key.  So we re-share it with
                        // them in case they weren't able to decrypt the room
                        // key the last time we shared it.
                        olm_wedging_index < d.olm_wedging_index
                    }
                    _ => false,
                })
            })
            .collect();

        // The `encrypt_for_devices()` method adds the to-device requests that will send
        // out the room key to the `OutboundGroupSession`. It doesn't do that
        // for the m.room_key_withheld events since we might have more of those
        // coming from the `collect_session_recipients()` method. Instead they get
        // returned by the method.
        let unable_to_encrypt_devices =
            self.encrypt_for_devices(devices, &outbound, &mut changes).await?;

        // Merge the withheld recipients.
        withheld_devices.extend(unable_to_encrypt_devices);

        // Now handle and add the withheld recipients to the resulting requests to the
        // `OutboundGroupSession`.
        self.handle_withheld_devices(&outbound, withheld_devices)?;

        // The to-device requests get added to the outbound group session, this
        // way we're making sure that they are persisted and scoped to the
        // session.
        let requests = outbound.pending_requests();

        if requests.is_empty() {
            // Nothing is queued; if the session was never marked as shared,
            // do that now and persist the flag.
            if !outbound.shared() {
                debug!("The room key doesn't need to be shared with anyone. Marking as shared.");

                outbound.mark_as_shared();
                changes.outbound_group_sessions.push(outbound.clone());
            }
        } else {
            Self::log_room_key_sharing_result(&requests)
        }

        // Persist any changes we might have collected.
        if !changes.is_empty() {
            let session_count = changes.sessions.len();

            self.store.save_changes(changes).await?;

            trace!(
                session_count = session_count,
                "Stored the changed sessions after encrypting an room key"
            );
        }

        Ok(requests)
    }
810
811    /// Collect the devices belonging to the given user, and send the details of
812    /// a room key bundle to those devices.
813    ///
814    /// Returns a list of to-device requests which must be sent.
815    ///
816    /// For security reasons, only "safe" [`CollectStrategy`]s are supported, in
817    /// which the recipient must have signed their
818    /// devices. [`CollectStrategy::AllDevices`] and
819    /// [`CollectStrategy::ErrorOnVerifiedUserProblem`] are "unsafe" in this
820    /// respect,and are treated the same as
821    /// [`CollectStrategy::IdentityBasedStrategy`].
822    #[instrument(skip(self, bundle_data))]
823    pub async fn share_room_key_bundle_data(
824        &self,
825        user_id: &UserId,
826        collect_strategy: &CollectStrategy,
827        bundle_data: RoomKeyBundleContent,
828    ) -> OlmResult<Vec<ToDeviceRequest>> {
829        // Only allow conservative sharing strategies
830        let collect_strategy = match collect_strategy {
831            CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
832                warn!(
833                    "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
834                     for room key history sharing",
835                );
836                &CollectStrategy::IdentityBasedStrategy
837            }
838            CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
839                collect_strategy
840            }
841        };
842
843        let mut changes = Changes::default();
844
845        let CollectRecipientsResult { devices, .. } =
846            share_strategy::collect_recipients_for_share_strategy(
847                &self.store,
848                iter::once(user_id),
849                collect_strategy,
850                None,
851            )
852            .await?;
853
854        let devices = devices.into_values().flatten().collect();
855        let event_type = bundle_data.event_type().to_owned();
856        let (requests, _) = self
857            .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
858            .await?;
859
860        // TODO: figure out what to do with withheld devices
861
862        // Persist any changes we might have collected.
863        if !changes.is_empty() {
864            let session_count = changes.sessions.len();
865
866            self.store.save_changes(changes).await?;
867
868            trace!(
869                session_count = session_count,
870                "Stored the changed sessions after encrypting an room key"
871            );
872        }
873
874        Ok(requests)
875    }
876
877    /// Encrypt the given content for the given devices and build to-device
878    /// requests to send the encrypted content to them.
879    ///
880    /// Returns a tuple containing (1) the list of to-device requests, and (2)
881    /// the list of devices that we could not find an olm session for (so
882    /// need a withheld message).
883    pub(crate) async fn encrypt_content_for_devices(
884        &self,
885        recipient_devices: Vec<DeviceData>,
886        event_type: &str,
887        content: impl Serialize + Clone + Send + 'static,
888        changes: &mut Changes,
889    ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
890        let recipients = recipient_list_to_users_and_devices(&recipient_devices);
891        info!(?recipients, "Encrypting content of type {}", event_type);
892
893        // Chunk the recipients out so each to-device request will contain a
894        // limited amount of to-device messages.
895        //
896        // Create concurrent tasks for each chunk of recipients.
897        let tasks: Vec<_> = recipient_devices
898            .chunks(Self::MAX_TO_DEVICE_MESSAGES)
899            .map(|chunk| {
900                spawn(
901                    encrypt_content_for_devices(
902                        self.store.crypto_store(),
903                        event_type.to_owned(),
904                        content.clone(),
905                        chunk.to_vec(),
906                    )
907                    .in_current_span(),
908                )
909            })
910            .collect();
911
912        let mut no_olm_devices = Vec::new();
913        let mut to_device_requests = Vec::new();
914
915        // Wait for all the tasks to finish up and queue up the Olm session that
916        // was used to encrypt the room key to be persisted again. This is
917        // needed because each encryption step will mutate the Olm session,
918        // ratcheting its state forward.
919        for result in join_all(tasks).await {
920            let result = result.expect("Encryption task panicked")?;
921            if let Some(request) = result.to_device_request {
922                to_device_requests.push(request);
923            }
924            changes.sessions.extend(result.updated_olm_sessions);
925            no_olm_devices.extend(result.no_olm_devices);
926        }
927
928        Ok((to_device_requests, no_olm_devices))
929    }
930}
931
932/// Helper for [`GroupSessionManager::encrypt_content_for_devices`].
933///
934/// Encrypt the given content for the given devices and build a to-device
935/// request to send the encrypted content to them.
936///
937/// See also [`GroupSessionManager::encrypt_session_for`], which is similar
938/// but applies specifically to `m.room_key` messages that hold a megolm
939/// session key.
940async fn encrypt_content_for_devices(
941    store: Arc<CryptoStoreWrapper>,
942    event_type: String,
943    content: impl Serialize + Clone + Send + 'static,
944    devices: Vec<DeviceData>,
945) -> OlmResult<EncryptForDevicesResult> {
946    let mut result_builder = EncryptForDevicesResultBuilder::default();
947
948    async fn encrypt(
949        store: Arc<CryptoStoreWrapper>,
950        device: DeviceData,
951        event_type: String,
952        bundle_data: impl Serialize,
953    ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
954        device.encrypt(store.as_ref(), &event_type, bundle_data).await
955    }
956
957    let tasks = devices.iter().map(|device| {
958        spawn(
959            encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
960                .in_current_span(),
961        )
962    });
963
964    let results = join_all(tasks).await;
965
966    for (device, result) in zip(devices, results) {
967        let encryption_result = result.expect("Encryption task panicked");
968
969        match encryption_result {
970            Ok((used_session, message)) => {
971                result_builder.on_successful_encryption(&device, used_session, message.cast());
972            }
973            Err(OlmError::MissingSession) => {
974                // There is no established Olm session for this device
975                result_builder.on_missing_session(device);
976            }
977            Err(e) => return Err(e),
978        }
979    }
980
981    Ok(result_builder.into_result())
982}
983
/// Result of [`GroupSessionManager::encrypt_session_for`] and
/// [`encrypt_content_for_devices`]: the outcome of encrypting a payload for a
/// set of devices.
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// The request to send the to-device messages containing the encrypted
    /// payload, or `None` when there were no messages to send (i.e. the
    /// payload was not successfully encrypted for any device).
    to_device_request: Option<ToDeviceRequest>,

    /// The devices which lack an Olm session and therefore need a withheld
    /// code.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store.
    updated_olm_sessions: Vec<Session>,
}
999
/// A helper for building an [`EncryptForDevicesResult`] incrementally, one
/// device at a time, before converting it with `into_result`.
#[derive(Debug, Default)]
struct EncryptForDevicesResultBuilder {
    /// The payloads of the to-device messages, keyed first by recipient user
    /// ID and then by recipient device ID.
    messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,

    /// The devices which lack an Olm session and therefore need a withheld
    /// code.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// The Olm sessions which were used to encrypt the requests and now need
    /// persisting to the store.
    updated_olm_sessions: Vec<Session>,
}
1013
1014impl EncryptForDevicesResultBuilder {
1015    /// Record a successful encryption. The encrypted message is added to the
1016    /// list to be sent, and the olm session is added to the list of those
1017    /// that have been modified.
1018    pub fn on_successful_encryption(
1019        &mut self,
1020        device: &DeviceData,
1021        used_session: Session,
1022        message: Raw<AnyToDeviceEventContent>,
1023    ) {
1024        self.updated_olm_sessions.push(used_session);
1025
1026        self.messages
1027            .entry(device.user_id().to_owned())
1028            .or_default()
1029            .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
1030    }
1031
1032    /// Record a device which didn't have an active Olm session.
1033    pub fn on_missing_session(&mut self, device: DeviceData) {
1034        self.no_olm_devices.push((device, WithheldCode::NoOlm));
1035    }
1036
1037    /// Transform the accumulated results into an [`EncryptForDevicesResult`],
1038    /// wrapping the messages, if any, into a `ToDeviceRequest`.
1039    pub fn into_result(self) -> EncryptForDevicesResult {
1040        let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
1041            self;
1042
1043        let mut encrypt_for_devices_result = EncryptForDevicesResult {
1044            to_device_request: None,
1045            updated_olm_sessions,
1046            no_olm_devices,
1047        };
1048
1049        if !messages.is_empty() {
1050            let request = ToDeviceRequest {
1051                event_type: ToDeviceEventType::RoomEncrypted,
1052                txn_id: TransactionId::new(),
1053                messages,
1054            };
1055            trace!(
1056                recipient_count = request.message_count(),
1057                transaction_id = ?request.txn_id,
1058                "Created a to-device request carrying room keys",
1059            );
1060            encrypt_for_devices_result.to_device_request = Some(request);
1061        }
1062
1063        encrypt_for_devices_result
1064    }
1065}
1066
1067fn recipient_list_to_users_and_devices(
1068    recipient_devices: &[DeviceData],
1069) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1070    #[allow(unknown_lints, clippy::unwrap_or_default)] // false positive
1071    recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1072        acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1073        acc
1074    })
1075}
1076
1077#[cfg(test)]
1078mod tests {
1079    use std::{
1080        collections::{BTreeMap, BTreeSet},
1081        iter,
1082        ops::Deref,
1083        sync::Arc,
1084    };
1085
1086    use assert_matches2::assert_let;
1087    use matrix_sdk_common::deserialized_responses::{ProcessedToDeviceEvent, WithheldCode};
1088    use matrix_sdk_test::{async_test, ruma_response_from_json};
1089    use ruma::{
1090        api::client::{
1091            keys::{claim_keys, get_keys, upload_keys},
1092            to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1093        },
1094        device_id,
1095        events::room::{
1096            history_visibility::HistoryVisibility, EncryptedFileInit, JsonWebKey, JsonWebKeyInit,
1097        },
1098        owned_room_id, room_id,
1099        serde::Base64,
1100        to_device::DeviceIdOrAllDevices,
1101        user_id, DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1102    };
1103    use serde_json::{json, Value};
1104
1105    use crate::{
1106        identities::DeviceData,
1107        machine::{
1108            test_helpers::get_machine_pair_with_setup_sessions_test_helper, EncryptionSyncChanges,
1109        },
1110        olm::{Account, SenderData},
1111        session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
1112        types::{
1113            events::{
1114                room::encrypted::EncryptedToDeviceEvent,
1115                room_key_bundle::RoomKeyBundleContent,
1116                room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1117            },
1118            requests::ToDeviceRequest,
1119            DeviceKeys, EventEncryptionAlgorithm,
1120        },
1121        DecryptionSettings, EncryptionSettings, LocalTrust, OlmMachine, TrustRequirement,
1122    };
1123
1124    fn alice_id() -> &'static UserId {
1125        user_id!("@alice:example.org")
1126    }
1127
1128    fn alice_device_id() -> &'static DeviceId {
1129        device_id!("JLAFKJWSCS")
1130    }
1131
1132    /// Returns a /keys/query response for user "@example:localhost"
1133    fn keys_query_response() -> get_keys::v3::Response {
1134        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1135        let data: Value = serde_json::from_slice(data).unwrap();
1136        ruma_response_from_json(&data)
1137    }
1138
1139    fn bob_keys_query_response() -> get_keys::v3::Response {
1140        let data = json!({
1141            "device_keys": {
1142                "@bob:localhost": {
1143                    "BOBDEVICE": {
1144                        "user_id": "@bob:localhost",
1145                        "device_id": "BOBDEVICE",
1146                        "algorithms": [
1147                            "m.olm.v1.curve25519-aes-sha2",
1148                            "m.megolm.v1.aes-sha2",
1149                            "m.megolm.v2.aes-sha2"
1150                        ],
1151                        "keys": {
1152                            "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1153                            "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1154                        },
1155                        "signatures": {
1156                            "@bob:localhost": {
1157                                "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1158                            }
1159                        }
1160                    }
1161                }
1162            }
1163        });
1164        ruma_response_from_json(&data)
1165    }
1166
1167    /// Returns a keys claim response for device `BOBDEVICE` of user
1168    /// `@bob:localhost`.
1169    fn bob_one_time_key() -> claim_keys::v3::Response {
1170        let data = json!({
1171            "failures": {},
1172            "one_time_keys":{
1173                "@bob:localhost":{
1174                    "BOBDEVICE":{
1175                      "signed_curve25519:AAAAAAAAAAA": {
1176                          "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1177                          "signatures":{
1178                              "@bob:localhost":{
1179                                  "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1180                              }
1181                          }
1182                      }
1183                    }
1184                }
1185            }
1186        });
1187        ruma_response_from_json(&data)
1188    }
1189
1190    /// Returns a key claim response for device `NMMBNBUSNR` of user
1191    /// `@example2:localhost`
1192    fn keys_claim_response() -> claim_keys::v3::Response {
1193        let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1194        let data: Value = serde_json::from_slice(data).unwrap();
1195        ruma_response_from_json(&data)
1196    }
1197
1198    async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1199        let keys_query = keys_query_response();
1200        let txn_id = TransactionId::new();
1201
1202        let machine = OlmMachine::new(user_id, device_id).await;
1203
1204        // complete a /keys/query and /keys/claim for @example:localhost
1205        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1206        let (txn_id, _keys_claim_request) = machine
1207            .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1208            .await
1209            .unwrap()
1210            .unwrap();
1211        let keys_claim = keys_claim_response();
1212        machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1213
1214        // complete a /keys/query and /keys/claim for @bob:localhost
1215        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1216        let (txn_id, _keys_claim_request) = machine
1217            .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1218            .await
1219            .unwrap()
1220            .unwrap();
1221        machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1222
1223        machine
1224    }
1225
1226    async fn machine() -> OlmMachine {
1227        machine_with_user_test_helper(alice_id(), alice_device_id()).await
1228    }
1229
1230    async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1231        let machine = machine().await;
1232        let room_id = room_id!("!test:localhost");
1233        let keys_claim = keys_claim_response();
1234
1235        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1236        let requests =
1237            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1238
1239        let outbound =
1240            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1241
1242        assert!(!outbound.pending_requests().is_empty());
1243        assert!(!outbound.shared());
1244
1245        let response = ToDeviceResponse::new();
1246        for request in requests {
1247            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1248        }
1249
1250        assert!(outbound.shared());
1251        assert!(outbound.pending_requests().is_empty());
1252
1253        machine
1254    }
1255
1256    #[async_test]
1257    async fn test_sharing() {
1258        let machine = machine().await;
1259        let room_id = room_id!("!test:localhost");
1260        let keys_claim = keys_claim_response();
1261
1262        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1263
1264        let requests =
1265            machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1266
1267        let event_count: usize = requests
1268            .iter()
1269            .filter(|r| r.event_type == "m.room.encrypted".into())
1270            .map(|r| r.message_count())
1271            .sum();
1272
1273        // The keys claim response has a couple of one-time keys with invalid
1274        // signatures, thus only 148 sessions are actually created, we check
1275        // that all 148 valid sessions get an room key.
1276        assert_eq!(event_count, 148);
1277
1278        let withheld_count: usize = requests
1279            .iter()
1280            .filter(|r| r.event_type == "m.room_key.withheld".into())
1281            .map(|r| r.message_count())
1282            .sum();
1283        assert_eq!(withheld_count, 2);
1284    }
1285
1286    fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1287        requests
1288            .iter()
1289            .filter(|r| r.event_type == "m.room_key.withheld".into())
1290            .map(|r| {
1291                let mut count = 0;
1292                // count targets
1293                for message in r.messages.values() {
1294                    message.iter().for_each(|(_, content)| {
1295                        let withheld: RoomKeyWithheldContent =
1296                            content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1297
1298                        if let MegolmV1AesSha2(content) = withheld {
1299                            if content.withheld_code() == code {
1300                                count += 1;
1301                            }
1302                        }
1303                    })
1304                }
1305                count
1306            })
1307            .sum()
1308    }
1309
1310    #[async_test]
1311    async fn test_no_olm_sent_once() {
1312        let machine = machine().await;
1313        let keys_claim = keys_claim_response();
1314
1315        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1316
1317        let first_room_id = room_id!("!test:localhost");
1318
1319        let requests = machine
1320            .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1321            .await
1322            .unwrap();
1323
1324        // there will be two no_olm
1325        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1326        assert_eq!(withheld_count, 2);
1327
1328        // Re-sharing same session while request has not been sent should not produces
1329        // withheld
1330        let new_requests = machine
1331            .share_room_key(first_room_id, users, EncryptionSettings::default())
1332            .await
1333            .unwrap();
1334        let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1335        // No additional request was added, still the 2 already pending
1336        assert_eq!(withheld_count, 2);
1337
1338        let response = ToDeviceResponse::new();
1339        for request in requests {
1340            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1341        }
1342
1343        // The fact that an olm was sent should be remembered even if sharing another
1344        // session in an other room.
1345        let second_room_id = room_id!("!other:localhost");
1346        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1347        let requests = machine
1348            .share_room_key(second_room_id, users, EncryptionSettings::default())
1349            .await
1350            .unwrap();
1351
1352        let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1353        assert_eq!(withheld_count, 0);
1354
1355        // Help how do I simulate the creation of a new session for the device
1356        // with no session now?
1357    }
1358
1359    #[async_test]
1360    async fn test_ratcheted_sharing() {
1361        let machine = machine_with_shared_room_key_test_helper().await;
1362
1363        let room_id = room_id!("!test:localhost");
1364        let late_joiner = user_id!("@bob:localhost");
1365        let keys_claim = keys_claim_response();
1366
1367        let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1368        users.insert(late_joiner);
1369
1370        let requests = machine
1371            .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1372            .await
1373            .unwrap();
1374
1375        let event_count: usize = requests
1376            .iter()
1377            .filter(|r| r.event_type == "m.room.encrypted".into())
1378            .map(|r| r.message_count())
1379            .sum();
1380        let outbound =
1381            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1382
1383        assert_eq!(event_count, 1);
1384        assert!(!outbound.pending_requests().is_empty());
1385    }
1386
1387    #[async_test]
1388    async fn test_changing_encryption_settings() {
1389        let machine = machine_with_shared_room_key_test_helper().await;
1390        let room_id = room_id!("!test:localhost");
1391        let keys_claim = keys_claim_response();
1392
1393        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1394        let outbound =
1395            machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1396
1397        let CollectRecipientsResult { should_rotate, .. } = machine
1398            .inner
1399            .group_session_manager
1400            .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1401            .await
1402            .unwrap();
1403
1404        assert!(!should_rotate);
1405
1406        let settings = EncryptionSettings {
1407            history_visibility: HistoryVisibility::Invited,
1408            ..Default::default()
1409        };
1410
1411        let CollectRecipientsResult { should_rotate, .. } = machine
1412            .inner
1413            .group_session_manager
1414            .collect_session_recipients(users.clone(), &settings, &outbound)
1415            .await
1416            .unwrap();
1417
1418        assert!(should_rotate);
1419
1420        let settings = EncryptionSettings {
1421            algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1422            ..Default::default()
1423        };
1424
1425        let CollectRecipientsResult { should_rotate, .. } = machine
1426            .inner
1427            .group_session_manager
1428            .collect_session_recipients(users, &settings, &outbound)
1429            .await
1430            .unwrap();
1431
1432        assert!(should_rotate);
1433    }
1434
1435    #[async_test]
1436    async fn test_key_recipient_collecting() {
1437        // The user id comes from the fact that the keys_query.json file uses
1438        // this one.
1439        let user_id = user_id!("@example:localhost");
1440        let device_id = device_id!("TESTDEVICE");
1441        let room_id = room_id!("!test:localhost");
1442
1443        let machine = machine_with_user_test_helper(user_id, device_id).await;
1444
1445        let (outbound, _) = machine
1446            .inner
1447            .group_session_manager
1448            .get_or_create_outbound_session(
1449                room_id,
1450                EncryptionSettings::default(),
1451                SenderData::unknown(),
1452            )
1453            .await
1454            .expect("We should be able to create a new session");
1455        let history_visibility = HistoryVisibility::Joined;
1456        let settings = EncryptionSettings { history_visibility, ..Default::default() };
1457
1458        let users = [user_id].into_iter();
1459
1460        let CollectRecipientsResult { devices: recipients, .. } = machine
1461            .inner
1462            .group_session_manager
1463            .collect_session_recipients(users, &settings, &outbound)
1464            .await
1465            .expect("We should be able to collect the session recipients");
1466
1467        assert!(!recipients[user_id].is_empty());
1468
1469        // Make sure that our own device isn't part of the recipients.
1470        assert!(!recipients[user_id]
1471            .iter()
1472            .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1473
1474        let settings = EncryptionSettings {
1475            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1476            ..Default::default()
1477        };
1478        let users = [user_id].into_iter();
1479
1480        let CollectRecipientsResult { devices: recipients, .. } = machine
1481            .inner
1482            .group_session_manager
1483            .collect_session_recipients(users, &settings, &outbound)
1484            .await
1485            .expect("We should be able to collect the session recipients");
1486
1487        assert!(recipients[user_id].is_empty());
1488
1489        let device_id = "AFGUOBTZWM".into();
1490        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1491        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1492        let users = [user_id].into_iter();
1493
1494        let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1495            machine
1496                .inner
1497                .group_session_manager
1498                .collect_session_recipients(users, &settings, &outbound)
1499                .await
1500                .expect("We should be able to collect the session recipients");
1501
1502        assert!(recipients[user_id]
1503            .iter()
1504            .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1505
1506        let devices = machine.get_user_devices(user_id, None).await.unwrap();
1507        devices
1508            .devices()
1509            // Ignore our own device
1510            .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1511            .for_each(|d| {
1512                if d.is_blacklisted() {
1513                    assert!(withheld.iter().any(|(dev, w)| {
1514                        dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1515                    }));
1516                } else if !d.is_verified() {
1517                    // the device should then be in the list of withhelds
1518                    assert!(withheld.iter().any(|(dev, w)| {
1519                        dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1520                    }));
1521                }
1522            });
1523
1524        assert_eq!(149, withheld.len());
1525    }
1526
1527    #[async_test]
1528    async fn test_sharing_withheld_only_trusted() {
1529        let machine = machine().await;
1530        let room_id = room_id!("!test:localhost");
1531        let keys_claim = keys_claim_response();
1532
1533        let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1534        let settings = EncryptionSettings {
1535            sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1536            ..Default::default()
1537        };
1538
1539        // Trust only one
1540        let user_id = user_id!("@example:localhost");
1541        let device_id = "MWFXPINOAO".into();
1542        let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1543        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1544        machine
1545            .get_device(user_id, "MWVTUXDNNM".into(), None)
1546            .await
1547            .unwrap()
1548            .unwrap()
1549            .set_local_trust(LocalTrust::BlackListed)
1550            .await
1551            .unwrap();
1552
1553        let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1554
1555        // One room key should be sent
1556        let room_key_count =
1557            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1558
1559        assert_eq!(1, room_key_count);
1560
1561        let withheld_count =
1562            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1563        // Can be send in one batch
1564        assert_eq!(1, withheld_count);
1565
1566        let event_count: usize = requests
1567            .iter()
1568            .filter(|r| r.event_type == "m.room_key.withheld".into())
1569            .map(|r| r.message_count())
1570            .sum();
1571
1572        // withhelds are sent in clear so all device should be counted (even if no OTK)
1573        assert_eq!(event_count, 149);
1574
1575        // One should be blacklisted
1576        let has_blacklist =
1577            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1578                let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1579                let content = &r.messages[user_id][&device_key];
1580                let withheld: RoomKeyWithheldContent =
1581                    content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1582                if let MegolmV1AesSha2(content) = withheld {
1583                    content.withheld_code() == WithheldCode::Blacklisted
1584                } else {
1585                    false
1586                }
1587            });
1588
1589        assert!(has_blacklist);
1590    }
1591
1592    #[async_test]
1593    async fn test_no_olm_withheld_only_sent_once() {
1594        let keys_query = keys_query_response();
1595        let txn_id = TransactionId::new();
1596
1597        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1598
1599        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1600        machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1601
1602        let first_room = room_id!("!test:localhost");
1603        let second_room = room_id!("!test2:localhost");
1604        let bob_id = user_id!("@bob:localhost");
1605
1606        let settings = EncryptionSettings::default();
1607        let users = [bob_id];
1608
1609        let requests = machine
1610            .share_room_key(first_room, users.into_iter(), settings.to_owned())
1611            .await
1612            .unwrap();
1613
1614        // One withheld request should be sent.
1615        let withheld_count =
1616            requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1617
1618        assert_eq!(withheld_count, 1);
1619        assert_eq!(requests.len(), 1);
1620
1621        // On the second room key share attempt we're not sending another `m.no_olm`
1622        // code since the first one is taking care of this.
1623        let second_requests =
1624            machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1625
1626        let withheld_count =
1627            second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1628
1629        assert_eq!(withheld_count, 0);
1630        assert_eq!(second_requests.len(), 0);
1631
1632        let response = ToDeviceResponse::new();
1633
1634        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1635
1636        // The device should be marked as having the `m.no_olm` code received only after
1637        // the request has been marked as sent.
1638        assert!(!device.was_withheld_code_sent());
1639
1640        for request in requests {
1641            machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1642        }
1643
1644        let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1645
1646        assert!(device.was_withheld_code_sent());
1647    }
1648
    /// Checks that a room key is re-shared with a device after that device
    /// "unwedges" its Olm session with us.
    ///
    /// Flow exercised below:
    /// 1. Alice uploads her device keys and one-time keys, and learns about
    ///    Bob's device through a keys query response.
    /// 2. Bob creates an Olm session using one of Alice's one-time keys and
    ///    sends her an `m.dummy` message, which she decrypts.
    /// 3. Alice shares a room key: exactly one encrypted to-device event is
    ///    produced for Bob, and sharing again produces none.
    /// 4. Bob creates a *new* Olm session with a second one-time key, simulating
    ///    an unwedge after failing to decrypt, and Alice decrypts it.
    /// 5. Alice shares the room key again: it is re-sent to Bob exactly once,
    ///    and a further share produces no events.
    #[async_test]
    async fn test_resend_session_after_unwedging() {
        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
        // Upload Alice's device keys and one-time keys. The fake server response
        // reports every generated signed-curve25519 one-time key as uploaded.
        assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
        let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
            OneTimeKeyAlgorithm::SignedCurve25519,
            UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
        )]));
        machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();

        let room_id = room_id!("!test:localhost");

        // Feed Alice a keys query response that contains Bob's device keys, so
        // that she knows about his device.
        let bob_id = user_id!("@bob:localhost");
        let bob_account = Account::new(bob_id);
        let keys_query_data = json!({
            "device_keys": {
                "@bob:localhost": {
                    bob_account.device_id.clone(): bob_account.device_keys()
                }
            }
        });
        let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
        let txn_id = TransactionId::new();
        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();

        // Bob's view of Alice's device, built from the keys Alice uploaded above.
        let alice_device_keys =
            device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
        // Bob takes Alice's uploaded one-time keys from this iterator, one per
        // outbound Olm session he creates (twice in this test).
        let mut alice_otks = device_keys_request.one_time_keys.iter();
        let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);

        {
            // Bob creates an Olm session with Alice and encrypts a message to her
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            // Alice decrypts the message
            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };

            let decryption_settings =
                DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };

            let (decrypted, _) =
                machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        // Alice shares the room key with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            // We should have had one to-device event
            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            // Acknowledge the requests, as a client would after a successful
            // send, before the next share attempt.
            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        // When Alice shares the room key again, there shouldn't be any
        // to-device events, since we already shared with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }

        // Pretend that Bob wasn't able to decrypt, so he tries to unwedge by
        // creating a brand-new Olm session using Alice's next one-time key.
        {
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            // Alice decrypts the unwedge message
            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };

            let decryption_settings =
                DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };

            let (decrypted, _) =
                machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        // When Alice shares the room key again, it should be re-shared with Bob
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            // Acknowledge the re-share so the final share attempt below sees it
            // as already delivered.
            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        // When Alice shares the room key yet again, there shouldn't be any
        // to-device events
        {
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }
    }
1818
1819    #[async_test]
1820    async fn test_room_key_bundle_sharing() {
1821        let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1822            user_id!("@alice:localhost"),
1823            user_id!("@bob:localhost"),
1824            false,
1825        )
1826        .await;
1827
1828        // Alice trusts Bob's device
1829        let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1830        device.set_local_trust(LocalTrust::Verified).await.unwrap();
1831
1832        let content = RoomKeyBundleContent {
1833            room_id: owned_room_id!("!room:id"),
1834            file: (EncryptedFileInit {
1835                url: OwnedMxcUri::from("test"),
1836                key: JsonWebKey::from(JsonWebKeyInit {
1837                    kty: "oct".to_owned(),
1838                    key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
1839                    alg: "A256CTR".to_owned(),
1840                    #[allow(clippy::unnecessary_to_owned)]
1841                    k: Base64::new(vec![0u8; 0]),
1842                    ext: true,
1843                }),
1844                iv: Base64::new(vec![0u8; 0]),
1845                hashes: Default::default(),
1846                v: "".to_owned(),
1847            })
1848            .into(),
1849        };
1850
1851        let requests = alice
1852            .share_room_key_bundle_data(
1853                bob.user_id(),
1854                &CollectStrategy::OnlyTrustedDevices,
1855                content,
1856            )
1857            .await
1858            .unwrap();
1859
1860        // There should be exactly one message
1861        let requests: Vec<_> =
1862            requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1863        let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1864        assert_eq!(message_count, 1);
1865
1866        // Bob decrypts the message
1867        let bob_message = requests[0]
1868            .messages
1869            .get(bob.user_id())
1870            .unwrap()
1871            .get(&(bob.device_id().to_owned().into()))
1872            .unwrap();
1873        let to_device = EncryptedToDeviceEvent::new(
1874            alice.user_id().to_owned(),
1875            bob_message.deserialize_as_unchecked().unwrap(),
1876        );
1877
1878        let sync_changes = EncryptionSyncChanges {
1879            to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1880            changed_devices: &Default::default(),
1881            one_time_keys_counts: &Default::default(),
1882            unused_fallback_keys: None,
1883            next_batch_token: None,
1884        };
1885
1886        let decryption_settings =
1887            DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1888
1889        let (decrypted, _) =
1890            bob.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1891        assert_eq!(1, decrypted.len());
1892        use crate::types::events::EventType;
1893        assert_let!(
1894            ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1895        );
1896        assert_eq!(
1897            raw.get_field::<String>("type").unwrap().unwrap(),
1898            RoomKeyBundleContent::EVENT_TYPE,
1899        );
1900    }
1901}