1mod share_strategy;
16
17use std::{
18 collections::{BTreeMap, BTreeSet},
19 fmt::Debug,
20 iter,
21 iter::zip,
22 sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28 deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30use ruma::{
31 events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
32 serde::Raw,
33 to_device::DeviceIdOrAllDevices,
34 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
35 UserId,
36};
37use serde::Serialize;
38pub(crate) use share_strategy::CollectRecipientsResult;
39pub use share_strategy::CollectStrategy;
40use tracing::{debug, error, info, instrument, trace, warn, Instrument};
41
42use crate::{
43 error::{EventError, MegolmResult, OlmResult},
44 identities::device::MaybeEncryptedRoomKey,
45 olm::{
46 InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
47 ShareInfo, ShareState,
48 },
49 store::{Changes, CryptoStoreWrapper, Result as StoreResult, Store},
50 types::{
51 events::{
52 room::encrypted::{RoomEncryptedEventContent, ToDeviceEncryptedEventContent},
53 room_key_bundle::RoomKeyBundleContent,
54 EventType,
55 },
56 requests::ToDeviceRequest,
57 },
58 Device, DeviceData, EncryptionSettings, OlmError,
59};
60
61#[derive(Clone, Debug)]
62pub(crate) struct GroupSessionCache {
63 store: Store,
64 sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
65 sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
68}
69
70impl GroupSessionCache {
71 pub(crate) fn new(store: Store) -> Self {
72 Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
73 }
74
75 pub(crate) fn insert(&self, session: OutboundGroupSession) {
76 self.sessions.write().insert(session.room_id().to_owned(), session);
77 }
78
79 pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
86 if let Some(s) = self.sessions.read().get(room_id) {
89 return Some(s.clone());
90 }
91
92 match self.store.get_outbound_group_session(room_id).await {
93 Ok(Some(s)) => {
94 {
95 let mut sessions_being_shared = self.sessions_being_shared.write();
96 for request_id in s.pending_request_ids() {
97 sessions_being_shared.insert(request_id, s.clone());
98 }
99 }
100
101 self.sessions.write().insert(room_id.to_owned(), s.clone());
102
103 Some(s)
104 }
105 Ok(None) => None,
106 Err(e) => {
107 error!("Couldn't restore an outbound group session: {e:?}");
108 None
109 }
110 }
111 }
112
113 fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
120 self.sessions.read().get(room_id).cloned()
121 }
122
123 fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
125 self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
126 }
127
128 fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
129 self.sessions_being_shared.write().remove(id)
130 }
131
132 fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
133 self.sessions_being_shared.write().insert(id, session);
134 }
135}
136
137#[derive(Debug, Clone)]
138pub(crate) struct GroupSessionManager {
139 store: Store,
143 sessions: GroupSessionCache,
145}
146
147impl GroupSessionManager {
148 const MAX_TO_DEVICE_MESSAGES: usize = 250;
149
150 pub fn new(store: Store) -> Self {
151 Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
152 }
153
154 pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
155 if let Some(s) = self.sessions.get(room_id) {
156 s.invalidate_session();
157
158 let mut changes = Changes::default();
159 changes.outbound_group_sessions.push(s.clone());
160 self.store.save_changes(changes).await?;
161
162 Ok(true)
163 } else {
164 Ok(false)
165 }
166 }
167
168 pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
169 let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
170 return Ok(());
171 };
172
173 let no_olm = session.mark_request_as_sent(request_id);
174
175 let mut changes = Changes::default();
176
177 for (user_id, devices) in &no_olm {
178 for device_id in devices {
179 let device = self.store.get_device(user_id, device_id).await;
180
181 if let Ok(Some(device)) = device {
182 device.mark_withheld_code_as_sent();
183 changes.devices.changed.push(device.inner.clone());
184 } else {
185 error!(
186 ?request_id,
187 "Marking to-device no olm as sent but device not found, might \
188 have been deleted?"
189 );
190 }
191 }
192 }
193
194 changes.outbound_group_sessions.push(session.clone());
195 self.store.save_changes(changes).await
196 }
197
198 #[cfg(test)]
199 pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
200 self.sessions.get(room_id)
201 }
202
203 pub async fn encrypt(
204 &self,
205 room_id: &RoomId,
206 event_type: &str,
207 content: &Raw<AnyMessageLikeEventContent>,
208 ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
209 let session =
210 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
211
212 assert!(!session.expired(), "Session expired");
213
214 let content = session.encrypt(event_type, content).await;
215
216 let mut changes = Changes::default();
217 changes.outbound_group_sessions.push(session);
218 self.store.save_changes(changes).await?;
219
220 Ok(content)
221 }
222
223 pub async fn create_outbound_group_session(
227 &self,
228 room_id: &RoomId,
229 settings: EncryptionSettings,
230 own_sender_data: SenderData,
231 ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
232 let (outbound, inbound) = self
233 .store
234 .static_account()
235 .create_group_session_pair(room_id, settings, own_sender_data)
236 .await
237 .map_err(|_| EventError::UnsupportedAlgorithm)?;
238
239 self.sessions.insert(outbound.clone());
240 Ok((outbound, inbound))
241 }
242
243 pub async fn get_or_create_outbound_session(
244 &self,
245 room_id: &RoomId,
246 settings: EncryptionSettings,
247 own_sender_data: SenderData,
248 ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
249 let outbound_session = self.sessions.get_or_load(room_id).await;
250
251 if let Some(s) = outbound_session {
254 if s.expired() || s.invalidated() {
255 self.create_outbound_group_session(room_id, settings, own_sender_data)
256 .await
257 .map(|(o, i)| (o, i.into()))
258 } else {
259 Ok((s, None))
260 }
261 } else {
262 self.create_outbound_group_session(room_id, settings, own_sender_data)
263 .await
264 .map(|(o, i)| (o, i.into()))
265 }
266 }
267
268 async fn encrypt_session_for(
275 store: Arc<CryptoStoreWrapper>,
276 group_session: OutboundGroupSession,
277 devices: Vec<DeviceData>,
278 ) -> OlmResult<(
279 EncryptForDevicesResult,
280 BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
281 )> {
282 pub struct DeviceResult {
284 device: DeviceData,
285 maybe_encrypted_room_key: MaybeEncryptedRoomKey,
286 }
287
288 let mut result_builder = EncryptForDevicesResultBuilder::default();
289 let mut share_infos = BTreeMap::new();
290
291 let encrypt = |store: Arc<CryptoStoreWrapper>,
294 device: DeviceData,
295 session: OutboundGroupSession| async move {
296 let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
297
298 Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
299 };
300
301 let tasks: Vec<_> = devices
302 .iter()
303 .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
304 .collect();
305
306 let results = join_all(tasks).await;
307
308 for result in results {
309 let result = result.expect("Encryption task panicked")?;
310
311 match result.maybe_encrypted_room_key {
312 MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
313 result_builder.on_successful_encryption(&result.device, *used_session, message);
314
315 let user_id = result.device.user_id().to_owned();
316 let device_id = result.device.device_id().to_owned();
317 share_infos
318 .entry(user_id)
319 .or_insert_with(BTreeMap::new)
320 .insert(device_id, *share_info);
321 }
322 MaybeEncryptedRoomKey::MissingSession => {
323 result_builder.on_missing_session(result.device);
324 }
325 }
326 }
327
328 Ok((result_builder.into_result(), share_infos))
329 }
330
331 #[instrument(skip_all)]
338 pub async fn collect_session_recipients(
339 &self,
340 users: impl Iterator<Item = &UserId>,
341 settings: &EncryptionSettings,
342 outbound: &OutboundGroupSession,
343 ) -> OlmResult<CollectRecipientsResult> {
344 share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
345 }
346
347 async fn encrypt_request(
348 store: Arc<CryptoStoreWrapper>,
349 chunk: Vec<DeviceData>,
350 outbound: OutboundGroupSession,
351 sessions: GroupSessionCache,
352 ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
353 let (result, share_infos) =
354 Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
355
356 if let Some(request) = result.to_device_request {
357 let id = request.txn_id.clone();
358 outbound.add_request(id.clone(), request.into(), share_infos);
359 sessions.mark_as_being_shared(id, outbound.clone());
360 }
361
362 Ok((result.updated_olm_sessions, result.no_olm_devices))
363 }
364
365 pub(crate) fn session_cache(&self) -> GroupSessionCache {
366 self.sessions.clone()
367 }
368
369 async fn maybe_rotate_group_session(
370 &self,
371 should_rotate: bool,
372 room_id: &RoomId,
373 outbound: OutboundGroupSession,
374 encryption_settings: EncryptionSettings,
375 changes: &mut Changes,
376 own_device: Option<Device>,
377 ) -> OlmResult<OutboundGroupSession> {
378 Ok(if should_rotate {
379 let old_session_id = outbound.session_id();
380
381 let (outbound, mut inbound) = self
382 .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
383 .await?;
384
385 let own_sender_data = if let Some(device) = own_device {
389 SenderDataFinder::find_using_device_data(
390 &self.store,
391 device.inner.clone(),
392 &inbound,
393 )
394 .await?
395 } else {
396 error!("Unable to find our own device!");
397 SenderData::unknown()
398 };
399 inbound.sender_data = own_sender_data;
400
401 changes.outbound_group_sessions.push(outbound.clone());
402 changes.inbound_group_sessions.push(inbound);
403
404 debug!(
405 old_session_id = old_session_id,
406 session_id = outbound.session_id(),
407 "A user or device has left the room since we last sent a \
408 message, or the encryption settings have changed. Rotating the \
409 room key.",
410 );
411
412 outbound
413 } else {
414 outbound
415 })
416 }
417
418 async fn encrypt_for_devices(
419 &self,
420 recipient_devices: Vec<DeviceData>,
421 group_session: &OutboundGroupSession,
422 changes: &mut Changes,
423 ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
424 if !recipient_devices.is_empty() {
426 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
427
428 changes.outbound_group_sessions = vec![group_session.clone()];
431
432 let message_index = group_session.message_index().await;
433
434 info!(
435 ?recipients,
436 message_index,
437 room_id = ?group_session.room_id(),
438 session_id = group_session.session_id(),
439 "Trying to encrypt a room key",
440 );
441 }
442
443 let tasks: Vec<_> = recipient_devices
448 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
449 .map(|chunk| {
450 spawn(Self::encrypt_request(
451 self.store.crypto_store(),
452 chunk.to_vec(),
453 group_session.clone(),
454 self.sessions.clone(),
455 ))
456 })
457 .collect();
458
459 let mut withheld_devices = Vec::new();
460
461 for result in join_all(tasks).await {
466 let result = result.expect("Encryption task panicked");
467
468 let (used_sessions, failed_no_olm) = result?;
469
470 changes.sessions.extend(used_sessions);
471 withheld_devices.extend(failed_no_olm);
472 }
473
474 Ok(withheld_devices)
475 }
476
477 fn is_withheld_to(
478 &self,
479 group_session: &OutboundGroupSession,
480 device: &DeviceData,
481 code: &WithheldCode,
482 ) -> bool {
483 if code == &WithheldCode::NoOlm {
501 device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
502 } else {
503 group_session.sharing_view().is_withheld_to(device, code)
504 }
505 }
506
507 fn handle_withheld_devices(
508 &self,
509 group_session: &OutboundGroupSession,
510 withheld_devices: Vec<(DeviceData, WithheldCode)>,
511 ) -> OlmResult<()> {
512 let to_content = |code| {
514 let content = group_session.withheld_code(code);
515 Raw::new(&content).expect("We can always serialize a withheld content info").cast()
516 };
517
518 let chunk_to_request = |chunk| {
521 let mut messages = BTreeMap::new();
522 let mut share_infos = BTreeMap::new();
523
524 for (device, code) in chunk {
525 let device: DeviceData = device;
526 let code: WithheldCode = code;
527
528 let user_id = device.user_id().to_owned();
529 let device_id = device.device_id().to_owned();
530
531 let share_info = ShareInfo::new_withheld(code.to_owned());
532 let content = to_content(code);
533
534 messages
535 .entry(user_id.to_owned())
536 .or_insert_with(BTreeMap::new)
537 .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
538
539 share_infos
540 .entry(user_id)
541 .or_insert_with(BTreeMap::new)
542 .insert(device_id, share_info);
543 }
544
545 let txn_id = TransactionId::new();
546
547 let request = ToDeviceRequest {
548 event_type: ToDeviceEventType::from("m.room_key.withheld"),
549 txn_id,
550 messages,
551 };
552
553 (request, share_infos)
554 };
555
556 let result: Vec<_> = withheld_devices
557 .into_iter()
558 .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
559 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
560 .into_iter()
561 .map(chunk_to_request)
562 .collect();
563
564 for (request, share_info) in result {
565 if !request.messages.is_empty() {
566 let txn_id = request.txn_id.to_owned();
567 group_session.add_request(txn_id.to_owned(), request.into(), share_info);
568
569 self.sessions.mark_as_being_shared(txn_id, group_session.clone());
570 }
571 }
572
573 Ok(())
574 }
575
576 fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
577 for request in requests {
578 let message_list = Self::to_device_request_to_log_list(request);
579 info!(
580 request_id = ?request.txn_id,
581 ?message_list,
582 "Created batch of to-device messages of type {}",
583 request.event_type
584 );
585 }
586 }
587
588 fn to_device_request_to_log_list(
592 request: &Arc<ToDeviceRequest>,
593 ) -> Vec<(String, String, String)> {
594 #[derive(serde::Deserialize)]
595 struct ContentStub<'a> {
596 #[serde(borrow, default, rename = "org.matrix.msgid")]
597 message_id: Option<&'a str>,
598 }
599
600 let mut result: Vec<(String, String, String)> = Vec::new();
601
602 for (user_id, device_map) in &request.messages {
603 for (device, content) in device_map {
604 let message_id: Option<&str> = content
605 .deserialize_as::<ContentStub<'_>>()
606 .expect("We should be able to deserialize the content we generated")
607 .message_id;
608
609 result.push((
610 message_id.unwrap_or("<undefined>").to_owned(),
611 user_id.to_string(),
612 device.to_string(),
613 ));
614 }
615 }
616 result
617 }
618
619 #[instrument(skip(self, users, encryption_settings), fields(session_id))]
630 pub async fn share_room_key(
631 &self,
632 room_id: &RoomId,
633 users: impl Iterator<Item = &UserId>,
634 encryption_settings: impl Into<EncryptionSettings>,
635 ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
636 trace!("Checking if a room key needs to be shared");
637
638 let account = self.store.static_account();
639 let device = self.store.get_device(account.user_id(), account.device_id()).await?;
640
641 let encryption_settings = encryption_settings.into();
642 let mut changes = Changes::default();
643
644 let (outbound, inbound) = self
646 .get_or_create_outbound_session(
647 room_id,
648 encryption_settings.clone(),
649 SenderData::unknown(),
650 )
651 .await?;
652 tracing::Span::current().record("session_id", outbound.session_id());
653
654 if let Some(mut inbound) = inbound {
657 let own_sender_data = if let Some(device) = &device {
661 SenderDataFinder::find_using_device_data(
662 &self.store,
663 device.inner.clone(),
664 &inbound,
665 )
666 .await?
667 } else {
668 error!("Unable to find our own device!");
669 SenderData::unknown()
670 };
671 inbound.sender_data = own_sender_data;
672
673 changes.outbound_group_sessions.push(outbound.clone());
674 changes.inbound_group_sessions.push(inbound);
675 }
676
677 let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
681 self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
682
683 let outbound = self
684 .maybe_rotate_group_session(
685 should_rotate,
686 room_id,
687 outbound,
688 encryption_settings,
689 &mut changes,
690 device,
691 )
692 .await?;
693
694 let devices: Vec<_> = devices
697 .into_iter()
698 .flat_map(|(_, d)| {
699 d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
700 ShareState::NotShared => true,
701 ShareState::Shared { message_index: _, olm_wedging_index } => {
702 olm_wedging_index < d.olm_wedging_index
709 }
710 _ => false,
711 })
712 })
713 .collect();
714
715 let unable_to_encrypt_devices =
721 self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
722
723 withheld_devices.extend(unable_to_encrypt_devices);
725
726 self.handle_withheld_devices(&outbound, withheld_devices)?;
729
730 let requests = outbound.pending_requests();
734
735 if requests.is_empty() {
736 if !outbound.shared() {
737 debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
738
739 outbound.mark_as_shared();
740 changes.outbound_group_sessions.push(outbound.clone());
741 }
742 } else {
743 Self::log_room_key_sharing_result(&requests)
744 }
745
746 if !changes.is_empty() {
748 let session_count = changes.sessions.len();
749
750 self.store.save_changes(changes).await?;
751
752 trace!(
753 session_count = session_count,
754 "Stored the changed sessions after encrypting an room key"
755 );
756 }
757
758 Ok(requests)
759 }
760
761 #[instrument(skip(self, bundle_data))]
773 pub async fn share_room_key_bundle_data(
774 &self,
775 user_id: &UserId,
776 collect_strategy: &CollectStrategy,
777 bundle_data: RoomKeyBundleContent,
778 ) -> OlmResult<Vec<ToDeviceRequest>> {
779 let collect_strategy = match collect_strategy {
781 CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
782 warn!(
783 "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
784 for room key history sharing",
785 );
786 &CollectStrategy::IdentityBasedStrategy
787 }
788 CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
789 collect_strategy
790 }
791 };
792
793 let mut changes = Changes::default();
794
795 let CollectRecipientsResult { devices, .. } =
796 share_strategy::collect_recipients_for_share_strategy(
797 &self.store,
798 iter::once(user_id),
799 collect_strategy,
800 None,
801 )
802 .await?;
803
804 let devices = devices.into_values().flatten().collect();
805 let event_type = bundle_data.event_type().to_owned();
806 let (requests, _) = self
807 .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
808 .await?;
809
810 if !changes.is_empty() {
814 let session_count = changes.sessions.len();
815
816 self.store.save_changes(changes).await?;
817
818 trace!(
819 session_count = session_count,
820 "Stored the changed sessions after encrypting an room key"
821 );
822 }
823
824 Ok(requests)
825 }
826
827 pub(crate) async fn encrypt_content_for_devices(
834 &self,
835 recipient_devices: Vec<DeviceData>,
836 event_type: &str,
837 content: impl Serialize + Clone + Send + 'static,
838 changes: &mut Changes,
839 ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
840 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
841 info!(?recipients, "Encrypting content of type {}", event_type);
842
843 let tasks: Vec<_> = recipient_devices
848 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
849 .map(|chunk| {
850 spawn(
851 encrypt_content_for_devices(
852 self.store.crypto_store(),
853 event_type.to_owned(),
854 content.clone(),
855 chunk.to_vec(),
856 )
857 .in_current_span(),
858 )
859 })
860 .collect();
861
862 let mut no_olm_devices = Vec::new();
863 let mut to_device_requests = Vec::new();
864
865 for result in join_all(tasks).await {
870 let result = result.expect("Encryption task panicked")?;
871 if let Some(request) = result.to_device_request {
872 to_device_requests.push(request);
873 }
874 changes.sessions.extend(result.updated_olm_sessions);
875 no_olm_devices.extend(result.no_olm_devices);
876 }
877
878 Ok((to_device_requests, no_olm_devices))
879 }
880}
881
882async fn encrypt_content_for_devices(
891 store: Arc<CryptoStoreWrapper>,
892 event_type: String,
893 content: impl Serialize + Clone + Send + 'static,
894 devices: Vec<DeviceData>,
895) -> OlmResult<EncryptForDevicesResult> {
896 let mut result_builder = EncryptForDevicesResultBuilder::default();
897
898 async fn encrypt(
899 store: Arc<CryptoStoreWrapper>,
900 device: DeviceData,
901 event_type: String,
902 bundle_data: impl Serialize,
903 ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
904 device.encrypt(store.as_ref(), &event_type, bundle_data).await
905 }
906
907 let tasks = devices.iter().map(|device| {
908 spawn(
909 encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
910 .in_current_span(),
911 )
912 });
913
914 let results = join_all(tasks).await;
915
916 for (device, result) in zip(devices, results) {
917 let encryption_result = result.expect("Encryption task panicked");
918
919 match encryption_result {
920 Ok((used_session, message)) => {
921 result_builder.on_successful_encryption(&device, used_session, message.cast());
922 }
923 Err(OlmError::MissingSession) => {
924 result_builder.on_missing_session(device);
926 }
927 Err(e) => return Err(e),
928 }
929 }
930
931 Ok(result_builder.into_result())
932}
933
/// Outcome of encrypting a payload for a batch of devices, as produced by
/// `EncryptForDevicesResultBuilder::into_result`.
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// One `m.room.encrypted` to-device request carrying the ciphertext for
    /// every device that was encrypted for, or `None` if there was none.
    to_device_request: Option<ToDeviceRequest>,

    /// Devices that couldn't be encrypted for, paired with the withheld code
    /// to report. The builder only ever records `WithheldCode::NoOlm` here.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// Olm sessions used during encryption. Callers add these to
    /// `Changes::sessions` so they get persisted.
    updated_olm_sessions: Vec<Session>,
}
949
950#[derive(Debug, Default)]
952struct EncryptForDevicesResultBuilder {
953 messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,
955
956 no_olm_devices: Vec<(DeviceData, WithheldCode)>,
958
959 updated_olm_sessions: Vec<Session>,
962}
963
964impl EncryptForDevicesResultBuilder {
965 pub fn on_successful_encryption(
969 &mut self,
970 device: &DeviceData,
971 used_session: Session,
972 message: Raw<AnyToDeviceEventContent>,
973 ) {
974 self.updated_olm_sessions.push(used_session);
975
976 self.messages
977 .entry(device.user_id().to_owned())
978 .or_default()
979 .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
980 }
981
982 pub fn on_missing_session(&mut self, device: DeviceData) {
984 self.no_olm_devices.push((device, WithheldCode::NoOlm));
985 }
986
987 pub fn into_result(self) -> EncryptForDevicesResult {
990 let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
991 self;
992
993 let mut encrypt_for_devices_result = EncryptForDevicesResult {
994 to_device_request: None,
995 updated_olm_sessions,
996 no_olm_devices,
997 };
998
999 if !messages.is_empty() {
1000 let request = ToDeviceRequest {
1001 event_type: ToDeviceEventType::RoomEncrypted,
1002 txn_id: TransactionId::new(),
1003 messages,
1004 };
1005 trace!(
1006 recipient_count = request.message_count(),
1007 transaction_id = ?request.txn_id,
1008 "Created a to-device request carrying room keys",
1009 );
1010 encrypt_for_devices_result.to_device_request = Some(request);
1011 };
1012
1013 encrypt_for_devices_result
1014 }
1015}
1016
1017fn recipient_list_to_users_and_devices(
1018 recipient_devices: &[DeviceData],
1019) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1020 #[allow(unknown_lints, clippy::unwrap_or_default)] recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1022 acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1023 acc
1024 })
1025}
1026
1027#[cfg(test)]
1028mod tests {
1029 use std::{
1030 collections::{BTreeMap, BTreeSet},
1031 iter,
1032 ops::Deref,
1033 sync::Arc,
1034 };
1035
1036 use assert_matches2::assert_let;
1037 use matrix_sdk_common::deserialized_responses::WithheldCode;
1038 use matrix_sdk_test::{async_test, ruma_response_from_json};
1039 use ruma::{
1040 api::client::{
1041 keys::{claim_keys, get_keys, upload_keys},
1042 to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1043 },
1044 device_id,
1045 events::room::{
1046 history_visibility::HistoryVisibility, EncryptedFileInit, JsonWebKey, JsonWebKeyInit,
1047 },
1048 owned_room_id, room_id,
1049 serde::Base64,
1050 to_device::DeviceIdOrAllDevices,
1051 user_id, DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1052 };
1053 use serde_json::{json, Value};
1054
1055 use crate::{
1056 identities::DeviceData,
1057 machine::{
1058 test_helpers::get_machine_pair_with_setup_sessions_test_helper, EncryptionSyncChanges,
1059 },
1060 olm::{Account, SenderData},
1061 session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
1062 types::{
1063 events::{
1064 room::encrypted::EncryptedToDeviceEvent,
1065 room_key_bundle::RoomKeyBundleContent,
1066 room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1067 },
1068 requests::ToDeviceRequest,
1069 DeviceKeys, EventEncryptionAlgorithm, ProcessedToDeviceEvent,
1070 },
1071 EncryptionSettings, LocalTrust, OlmMachine,
1072 };
1073
1074 fn alice_id() -> &'static UserId {
1075 user_id!("@alice:example.org")
1076 }
1077
1078 fn alice_device_id() -> &'static DeviceId {
1079 device_id!("JLAFKJWSCS")
1080 }
1081
1082 fn keys_query_response() -> get_keys::v3::Response {
1084 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1085 let data: Value = serde_json::from_slice(data).unwrap();
1086 ruma_response_from_json(&data)
1087 }
1088
1089 fn bob_keys_query_response() -> get_keys::v3::Response {
1090 let data = json!({
1091 "device_keys": {
1092 "@bob:localhost": {
1093 "BOBDEVICE": {
1094 "user_id": "@bob:localhost",
1095 "device_id": "BOBDEVICE",
1096 "algorithms": [
1097 "m.olm.v1.curve25519-aes-sha2",
1098 "m.megolm.v1.aes-sha2",
1099 "m.megolm.v2.aes-sha2"
1100 ],
1101 "keys": {
1102 "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1103 "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1104 },
1105 "signatures": {
1106 "@bob:localhost": {
1107 "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1108 }
1109 }
1110 }
1111 }
1112 }
1113 });
1114 ruma_response_from_json(&data)
1115 }
1116
1117 fn bob_one_time_key() -> claim_keys::v3::Response {
1120 let data = json!({
1121 "failures": {},
1122 "one_time_keys":{
1123 "@bob:localhost":{
1124 "BOBDEVICE":{
1125 "signed_curve25519:AAAAAAAAAAA": {
1126 "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1127 "signatures":{
1128 "@bob:localhost":{
1129 "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1130 }
1131 }
1132 }
1133 }
1134 }
1135 }
1136 });
1137 ruma_response_from_json(&data)
1138 }
1139
1140 fn keys_claim_response() -> claim_keys::v3::Response {
1143 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1144 let data: Value = serde_json::from_slice(data).unwrap();
1145 ruma_response_from_json(&data)
1146 }
1147
1148 async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1149 let keys_query = keys_query_response();
1150 let txn_id = TransactionId::new();
1151
1152 let machine = OlmMachine::new(user_id, device_id).await;
1153
1154 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1156 let (txn_id, _keys_claim_request) = machine
1157 .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1158 .await
1159 .unwrap()
1160 .unwrap();
1161 let keys_claim = keys_claim_response();
1162 machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1163
1164 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1166 let (txn_id, _keys_claim_request) = machine
1167 .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1168 .await
1169 .unwrap()
1170 .unwrap();
1171 machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1172
1173 machine
1174 }
1175
1176 async fn machine() -> OlmMachine {
1177 machine_with_user_test_helper(alice_id(), alice_device_id()).await
1178 }
1179
1180 async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1181 let machine = machine().await;
1182 let room_id = room_id!("!test:localhost");
1183 let keys_claim = keys_claim_response();
1184
1185 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1186 let requests =
1187 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1188
1189 let outbound =
1190 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1191
1192 assert!(!outbound.pending_requests().is_empty());
1193 assert!(!outbound.shared());
1194
1195 let response = ToDeviceResponse::new();
1196 for request in requests {
1197 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1198 }
1199
1200 assert!(outbound.shared());
1201 assert!(outbound.pending_requests().is_empty());
1202
1203 machine
1204 }
1205
1206 #[async_test]
1207 async fn test_sharing() {
1208 let machine = machine().await;
1209 let room_id = room_id!("!test:localhost");
1210 let keys_claim = keys_claim_response();
1211
1212 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1213
1214 let requests =
1215 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1216
1217 let event_count: usize = requests
1218 .iter()
1219 .filter(|r| r.event_type == "m.room.encrypted".into())
1220 .map(|r| r.message_count())
1221 .sum();
1222
1223 assert_eq!(event_count, 148);
1227
1228 let withheld_count: usize = requests
1229 .iter()
1230 .filter(|r| r.event_type == "m.room_key.withheld".into())
1231 .map(|r| r.message_count())
1232 .sum();
1233 assert_eq!(withheld_count, 2);
1234 }
1235
1236 fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1237 requests
1238 .iter()
1239 .filter(|r| r.event_type == "m.room_key.withheld".into())
1240 .map(|r| {
1241 let mut count = 0;
1242 for message in r.messages.values() {
1244 message.iter().for_each(|(_, content)| {
1245 let withheld: RoomKeyWithheldContent =
1246 content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
1247
1248 if let MegolmV1AesSha2(content) = withheld {
1249 if content.withheld_code() == code {
1250 count += 1;
1251 }
1252 }
1253 })
1254 }
1255 count
1256 })
1257 .sum()
1258 }
1259
1260 #[async_test]
1261 async fn test_no_olm_sent_once() {
1262 let machine = machine().await;
1263 let keys_claim = keys_claim_response();
1264
1265 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1266
1267 let first_room_id = room_id!("!test:localhost");
1268
1269 let requests = machine
1270 .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1271 .await
1272 .unwrap();
1273
1274 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1276 assert_eq!(withheld_count, 2);
1277
1278 let new_requests = machine
1281 .share_room_key(first_room_id, users, EncryptionSettings::default())
1282 .await
1283 .unwrap();
1284 let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1285 assert_eq!(withheld_count, 2);
1287
1288 let response = ToDeviceResponse::new();
1289 for request in requests {
1290 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1291 }
1292
1293 let second_room_id = room_id!("!other:localhost");
1296 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1297 let requests = machine
1298 .share_room_key(second_room_id, users, EncryptionSettings::default())
1299 .await
1300 .unwrap();
1301
1302 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1303 assert_eq!(withheld_count, 0);
1304
1305 }
1308
1309 #[async_test]
1310 async fn test_ratcheted_sharing() {
1311 let machine = machine_with_shared_room_key_test_helper().await;
1312
1313 let room_id = room_id!("!test:localhost");
1314 let late_joiner = user_id!("@bob:localhost");
1315 let keys_claim = keys_claim_response();
1316
1317 let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1318 users.insert(late_joiner);
1319
1320 let requests = machine
1321 .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1322 .await
1323 .unwrap();
1324
1325 let event_count: usize = requests
1326 .iter()
1327 .filter(|r| r.event_type == "m.room.encrypted".into())
1328 .map(|r| r.message_count())
1329 .sum();
1330 let outbound =
1331 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1332
1333 assert_eq!(event_count, 1);
1334 assert!(!outbound.pending_requests().is_empty());
1335 }
1336
1337 #[async_test]
1338 async fn test_changing_encryption_settings() {
1339 let machine = machine_with_shared_room_key_test_helper().await;
1340 let room_id = room_id!("!test:localhost");
1341 let keys_claim = keys_claim_response();
1342
1343 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1344 let outbound =
1345 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1346
1347 let CollectRecipientsResult { should_rotate, .. } = machine
1348 .inner
1349 .group_session_manager
1350 .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1351 .await
1352 .unwrap();
1353
1354 assert!(!should_rotate);
1355
1356 let settings = EncryptionSettings {
1357 history_visibility: HistoryVisibility::Invited,
1358 ..Default::default()
1359 };
1360
1361 let CollectRecipientsResult { should_rotate, .. } = machine
1362 .inner
1363 .group_session_manager
1364 .collect_session_recipients(users.clone(), &settings, &outbound)
1365 .await
1366 .unwrap();
1367
1368 assert!(should_rotate);
1369
1370 let settings = EncryptionSettings {
1371 algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1372 ..Default::default()
1373 };
1374
1375 let CollectRecipientsResult { should_rotate, .. } = machine
1376 .inner
1377 .group_session_manager
1378 .collect_session_recipients(users, &settings, &outbound)
1379 .await
1380 .unwrap();
1381
1382 assert!(should_rotate);
1383 }
1384
1385 #[async_test]
1386 async fn test_key_recipient_collecting() {
1387 let user_id = user_id!("@example:localhost");
1390 let device_id = device_id!("TESTDEVICE");
1391 let room_id = room_id!("!test:localhost");
1392
1393 let machine = machine_with_user_test_helper(user_id, device_id).await;
1394
1395 let (outbound, _) = machine
1396 .inner
1397 .group_session_manager
1398 .get_or_create_outbound_session(
1399 room_id,
1400 EncryptionSettings::default(),
1401 SenderData::unknown(),
1402 )
1403 .await
1404 .expect("We should be able to create a new session");
1405 let history_visibility = HistoryVisibility::Joined;
1406 let settings = EncryptionSettings { history_visibility, ..Default::default() };
1407
1408 let users = [user_id].into_iter();
1409
1410 let CollectRecipientsResult { devices: recipients, .. } = machine
1411 .inner
1412 .group_session_manager
1413 .collect_session_recipients(users, &settings, &outbound)
1414 .await
1415 .expect("We should be able to collect the session recipients");
1416
1417 assert!(!recipients[user_id].is_empty());
1418
1419 assert!(!recipients[user_id]
1421 .iter()
1422 .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1423
1424 let settings = EncryptionSettings {
1425 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1426 ..Default::default()
1427 };
1428 let users = [user_id].into_iter();
1429
1430 let CollectRecipientsResult { devices: recipients, .. } = machine
1431 .inner
1432 .group_session_manager
1433 .collect_session_recipients(users, &settings, &outbound)
1434 .await
1435 .expect("We should be able to collect the session recipients");
1436
1437 assert!(recipients[user_id].is_empty());
1438
1439 let device_id = "AFGUOBTZWM".into();
1440 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1441 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1442 let users = [user_id].into_iter();
1443
1444 let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1445 machine
1446 .inner
1447 .group_session_manager
1448 .collect_session_recipients(users, &settings, &outbound)
1449 .await
1450 .expect("We should be able to collect the session recipients");
1451
1452 assert!(recipients[user_id]
1453 .iter()
1454 .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1455
1456 let devices = machine.get_user_devices(user_id, None).await.unwrap();
1457 devices
1458 .devices()
1459 .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1461 .for_each(|d| {
1462 if d.is_blacklisted() {
1463 assert!(withheld.iter().any(|(dev, w)| {
1464 dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1465 }));
1466 } else if !d.is_verified() {
1467 assert!(withheld.iter().any(|(dev, w)| {
1469 dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1470 }));
1471 }
1472 });
1473
1474 assert_eq!(149, withheld.len());
1475 }
1476
1477 #[async_test]
1478 async fn test_sharing_withheld_only_trusted() {
1479 let machine = machine().await;
1480 let room_id = room_id!("!test:localhost");
1481 let keys_claim = keys_claim_response();
1482
1483 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1484 let settings = EncryptionSettings {
1485 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1486 ..Default::default()
1487 };
1488
1489 let user_id = user_id!("@example:localhost");
1491 let device_id = "MWFXPINOAO".into();
1492 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1493 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1494 machine
1495 .get_device(user_id, "MWVTUXDNNM".into(), None)
1496 .await
1497 .unwrap()
1498 .unwrap()
1499 .set_local_trust(LocalTrust::BlackListed)
1500 .await
1501 .unwrap();
1502
1503 let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1504
1505 let room_key_count =
1507 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1508
1509 assert_eq!(1, room_key_count);
1510
1511 let withheld_count =
1512 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1513 assert_eq!(1, withheld_count);
1515
1516 let event_count: usize = requests
1517 .iter()
1518 .filter(|r| r.event_type == "m.room_key.withheld".into())
1519 .map(|r| r.message_count())
1520 .sum();
1521
1522 assert_eq!(event_count, 149);
1524
1525 let has_blacklist =
1527 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1528 let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1529 let content = &r.messages[user_id][&device_key];
1530 let withheld: RoomKeyWithheldContent =
1531 content.deserialize_as::<RoomKeyWithheldContent>().unwrap();
1532 if let MegolmV1AesSha2(content) = withheld {
1533 content.withheld_code() == WithheldCode::Blacklisted
1534 } else {
1535 false
1536 }
1537 });
1538
1539 assert!(has_blacklist);
1540 }
1541
1542 #[async_test]
1543 async fn test_no_olm_withheld_only_sent_once() {
1544 let keys_query = keys_query_response();
1545 let txn_id = TransactionId::new();
1546
1547 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1548
1549 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1550 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1551
1552 let first_room = room_id!("!test:localhost");
1553 let second_room = room_id!("!test2:localhost");
1554 let bob_id = user_id!("@bob:localhost");
1555
1556 let settings = EncryptionSettings::default();
1557 let users = [bob_id];
1558
1559 let requests = machine
1560 .share_room_key(first_room, users.into_iter(), settings.to_owned())
1561 .await
1562 .unwrap();
1563
1564 let withheld_count =
1566 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1567
1568 assert_eq!(withheld_count, 1);
1569 assert_eq!(requests.len(), 1);
1570
1571 let second_requests =
1574 machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1575
1576 let withheld_count =
1577 second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1578
1579 assert_eq!(withheld_count, 0);
1580 assert_eq!(second_requests.len(), 0);
1581
1582 let response = ToDeviceResponse::new();
1583
1584 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1585
1586 assert!(!device.was_withheld_code_sent());
1589
1590 for request in requests {
1591 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1592 }
1593
1594 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1595
1596 assert!(device.was_withheld_code_sent());
1597 }
1598
    /// Checks that a room key already shared with Bob is shared again after
    /// Bob establishes a new Olm session with Alice. The test name calls this
    /// the "unwedging" case. Repeated shares without a new session must not
    /// resend the key.
    #[async_test]
    async fn test_resend_session_after_unwedging() {
        // Create Alice's machine, upload her device and one-time keys, and mark
        // the upload request as sent with a response echoing the key count.
        let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
        assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
        let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
            OneTimeKeyAlgorithm::SignedCurve25519,
            UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
        )]));
        machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();

        let room_id = room_id!("!test:localhost");

        // Create Bob's account and feed his device keys to Alice's machine
        // through a hand-built keys/query response.
        let bob_id = user_id!("@bob:localhost");
        let bob_account = Account::new(bob_id);
        let keys_query_data = json!({
            "device_keys": {
                "@bob:localhost": {
                    bob_account.device_id.clone(): bob_account.device_keys()
                }
            }
        });
        let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
        let txn_id = TransactionId::new();
        machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();

        // Bob uses Alice's uploaded device keys and one-time keys to create
        // outbound Olm sessions. Each session below takes the next one-time key
        // from this iterator.
        let alice_device_keys =
            device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
        let mut alice_otks = device_keys_request.one_time_keys.iter();
        let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);

        {
            // Bob opens a first Olm session with Alice and sends an encrypted
            // `m.dummy` event. Alice must decrypt exactly one to-device event.
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };
            let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        {
            // First share of the room key: one encrypted message for Bob.
            // Mark it as sent.
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        {
            // Sharing again with no new Olm session: nothing is resent.
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }

        {
            // Bob opens a second, fresh Olm session using the next one-time key
            // and again sends an encrypted `m.dummy`. Alice must decrypt it.
            let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
            let mut session = bob_account
                .create_outbound_session(
                    &alice_device,
                    &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
                    bob_account.device_keys(),
                )
                .unwrap();
            let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();

            let to_device =
                EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());

            let sync_changes = EncryptionSyncChanges {
                to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
                changed_devices: &Default::default(),
                one_time_keys_counts: &Default::default(),
                unused_fallback_keys: None,
                next_batch_token: None,
            };
            let (decrypted, _) = machine.receive_sync_changes(sync_changes).await.unwrap();

            assert_eq!(1, decrypted.len());
        }

        {
            // After the new Olm session, sharing produces one encrypted message
            // for Bob again. Mark it as sent.
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 1);

            let response = ToDeviceResponse::new();
            for request in requests {
                machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
            }
        }

        {
            // With no further new session, sharing once more sends nothing.
            let requests = machine
                .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
                .await
                .unwrap();

            let event_count: usize = requests
                .iter()
                .filter(|r| r.event_type == "m.room.encrypted".into())
                .map(|r| r.message_count())
                .sum();
            assert_eq!(event_count, 0);
        }
    }
1758
1759 #[async_test]
1760 async fn test_room_key_bundle_sharing() {
1761 let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1762 user_id!("@alice:localhost"),
1763 user_id!("@bob:localhost"),
1764 false,
1765 )
1766 .await;
1767
1768 let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1770 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1771
1772 let content = RoomKeyBundleContent {
1773 room_id: owned_room_id!("!room:id"),
1774 file: (EncryptedFileInit {
1775 url: OwnedMxcUri::from("test"),
1776 key: JsonWebKey::from(JsonWebKeyInit {
1777 kty: "oct".to_owned(),
1778 key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
1779 alg: "A256CTR".to_owned(),
1780 #[allow(clippy::unnecessary_to_owned)]
1781 k: Base64::new(vec![0u8; 0]),
1782 ext: true,
1783 }),
1784 iv: Base64::new(vec![0u8; 0]),
1785 hashes: Default::default(),
1786 v: "".to_owned(),
1787 })
1788 .into(),
1789 };
1790
1791 let requests = alice
1792 .share_room_key_bundle_data(
1793 bob.user_id(),
1794 &CollectStrategy::OnlyTrustedDevices,
1795 content,
1796 )
1797 .await
1798 .unwrap();
1799
1800 let requests: Vec<_> =
1802 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1803 let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1804 assert_eq!(message_count, 1);
1805
1806 let bob_message = requests[0]
1808 .messages
1809 .get(bob.user_id())
1810 .unwrap()
1811 .get(&(bob.device_id().to_owned().into()))
1812 .unwrap();
1813 let to_device = EncryptedToDeviceEvent::new(
1814 alice.user_id().to_owned(),
1815 bob_message.cast_ref().deserialize().unwrap(),
1816 );
1817
1818 let sync_changes = EncryptionSyncChanges {
1819 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1820 changed_devices: &Default::default(),
1821 one_time_keys_counts: &Default::default(),
1822 unused_fallback_keys: None,
1823 next_batch_token: None,
1824 };
1825 let (decrypted, _) = bob.receive_sync_changes(sync_changes).await.unwrap();
1826 assert_eq!(1, decrypted.len());
1827 use crate::types::events::EventType;
1828 assert_let!(
1829 ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1830 );
1831 assert_eq!(
1832 raw.get_field::<String>("type").unwrap().unwrap(),
1833 RoomKeyBundleContent::EVENT_TYPE,
1834 );
1835 }
1836}