1mod share_strategy;
16
17use std::{
18 collections::{BTreeMap, BTreeSet},
19 fmt::Debug,
20 iter,
21 iter::zip,
22 sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28 deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30#[cfg(feature = "experimental-encrypted-state-events")]
31use ruma::events::AnyStateEventContent;
32use ruma::{
33 events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
34 serde::Raw,
35 to_device::DeviceIdOrAllDevices,
36 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
37 UserId,
38};
39use serde::Serialize;
40#[cfg(feature = "experimental-send-custom-to-device")]
41pub(crate) use share_strategy::split_devices_for_share_strategy;
42pub use share_strategy::CollectStrategy;
43pub(crate) use share_strategy::{
44 withheld_code_for_device_for_share_strategy, CollectRecipientsResult,
45};
46use tracing::{debug, error, info, instrument, trace, warn, Instrument};
47
48use crate::{
49 error::{EventError, MegolmResult, OlmResult},
50 identities::device::MaybeEncryptedRoomKey,
51 olm::{
52 InboundGroupSession, OutboundGroupSession, SenderData, SenderDataFinder, Session,
53 ShareInfo, ShareState,
54 },
55 store::{types::Changes, CryptoStoreWrapper, Result as StoreResult, Store},
56 types::{
57 events::{
58 room::encrypted::{RoomEncryptedEventContent, ToDeviceEncryptedEventContent},
59 room_key_bundle::RoomKeyBundleContent,
60 EventType,
61 },
62 requests::ToDeviceRequest,
63 },
64 Device, DeviceData, EncryptionSettings, OlmError,
65};
66
67#[derive(Clone, Debug)]
68pub(crate) struct GroupSessionCache {
69 store: Store,
70 sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
71 sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
74}
75
76impl GroupSessionCache {
77 pub(crate) fn new(store: Store) -> Self {
78 Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
79 }
80
81 pub(crate) fn insert(&self, session: OutboundGroupSession) {
82 self.sessions.write().insert(session.room_id().to_owned(), session);
83 }
84
85 pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
92 if let Some(s) = self.sessions.read().get(room_id) {
95 return Some(s.clone());
96 }
97
98 match self.store.get_outbound_group_session(room_id).await {
99 Ok(Some(s)) => {
100 {
101 let mut sessions_being_shared = self.sessions_being_shared.write();
102 for request_id in s.pending_request_ids() {
103 sessions_being_shared.insert(request_id, s.clone());
104 }
105 }
106
107 self.sessions.write().insert(room_id.to_owned(), s.clone());
108
109 Some(s)
110 }
111 Ok(None) => None,
112 Err(e) => {
113 error!("Couldn't restore an outbound group session: {e:?}");
114 None
115 }
116 }
117 }
118
119 fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
126 self.sessions.read().get(room_id).cloned()
127 }
128
129 fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
131 self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
132 }
133
134 fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
135 self.sessions_being_shared.write().remove(id)
136 }
137
138 fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
139 self.sessions_being_shared.write().insert(id, session);
140 }
141}
142
143#[derive(Debug, Clone)]
144pub(crate) struct GroupSessionManager {
145 store: Store,
149 sessions: GroupSessionCache,
151}
152
153impl GroupSessionManager {
154 const MAX_TO_DEVICE_MESSAGES: usize = 250;
155
156 pub fn new(store: Store) -> Self {
157 Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
158 }
159
160 pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
161 if let Some(s) = self.sessions.get(room_id) {
162 s.invalidate_session();
163
164 let mut changes = Changes::default();
165 changes.outbound_group_sessions.push(s.clone());
166 self.store.save_changes(changes).await?;
167
168 Ok(true)
169 } else {
170 Ok(false)
171 }
172 }
173
174 pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
175 let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
176 return Ok(());
177 };
178
179 let no_olm = session.mark_request_as_sent(request_id);
180
181 let mut changes = Changes::default();
182
183 for (user_id, devices) in &no_olm {
184 for device_id in devices {
185 let device = self.store.get_device(user_id, device_id).await;
186
187 if let Ok(Some(device)) = device {
188 device.mark_withheld_code_as_sent();
189 changes.devices.changed.push(device.inner.clone());
190 } else {
191 error!(
192 ?request_id,
193 "Marking to-device no olm as sent but device not found, might \
194 have been deleted?"
195 );
196 }
197 }
198 }
199
200 changes.outbound_group_sessions.push(session.clone());
201 self.store.save_changes(changes).await
202 }
203
204 #[cfg(test)]
205 pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
206 self.sessions.get(room_id)
207 }
208
209 pub async fn encrypt(
210 &self,
211 room_id: &RoomId,
212 event_type: &str,
213 content: &Raw<AnyMessageLikeEventContent>,
214 ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
215 let session =
216 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
217
218 assert!(!session.expired(), "Session expired");
219
220 let content = session.encrypt(event_type, content).await;
221
222 let mut changes = Changes::default();
223 changes.outbound_group_sessions.push(session);
224 self.store.save_changes(changes).await?;
225
226 Ok(content)
227 }
228
229 #[cfg(feature = "experimental-encrypted-state-events")]
252 pub async fn encrypt_state(
253 &self,
254 room_id: &RoomId,
255 event_type: &str,
256 state_key: &str,
257 content: &Raw<AnyStateEventContent>,
258 ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
259 let session =
260 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
261
262 assert!(!session.expired(), "Session expired");
263
264 let content = session.encrypt_state(event_type, state_key, content).await;
265
266 let mut changes = Changes::default();
267 changes.outbound_group_sessions.push(session);
268 self.store.save_changes(changes).await?;
269
270 Ok(content)
271 }
272
273 pub async fn create_outbound_group_session(
277 &self,
278 room_id: &RoomId,
279 settings: EncryptionSettings,
280 own_sender_data: SenderData,
281 ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
282 let (outbound, inbound) = self
283 .store
284 .static_account()
285 .create_group_session_pair(room_id, settings, own_sender_data)
286 .await
287 .map_err(|_| EventError::UnsupportedAlgorithm)?;
288
289 self.sessions.insert(outbound.clone());
290 Ok((outbound, inbound))
291 }
292
293 pub async fn get_or_create_outbound_session(
294 &self,
295 room_id: &RoomId,
296 settings: EncryptionSettings,
297 own_sender_data: SenderData,
298 ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
299 let outbound_session = self.sessions.get_or_load(room_id).await;
300
301 if let Some(s) = outbound_session {
304 if s.expired() || s.invalidated() {
305 self.create_outbound_group_session(room_id, settings, own_sender_data)
306 .await
307 .map(|(o, i)| (o, i.into()))
308 } else {
309 Ok((s, None))
310 }
311 } else {
312 self.create_outbound_group_session(room_id, settings, own_sender_data)
313 .await
314 .map(|(o, i)| (o, i.into()))
315 }
316 }
317
318 async fn encrypt_session_for(
325 store: Arc<CryptoStoreWrapper>,
326 group_session: OutboundGroupSession,
327 devices: Vec<DeviceData>,
328 ) -> OlmResult<(
329 EncryptForDevicesResult,
330 BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
331 )> {
332 pub struct DeviceResult {
334 device: DeviceData,
335 maybe_encrypted_room_key: MaybeEncryptedRoomKey,
336 }
337
338 let mut result_builder = EncryptForDevicesResultBuilder::default();
339 let mut share_infos = BTreeMap::new();
340
341 let encrypt = |store: Arc<CryptoStoreWrapper>,
344 device: DeviceData,
345 session: OutboundGroupSession| async move {
346 let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
347
348 Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
349 };
350
351 let tasks: Vec<_> = devices
352 .iter()
353 .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
354 .collect();
355
356 let results = join_all(tasks).await;
357
358 for result in results {
359 let result = result.expect("Encryption task panicked")?;
360
361 match result.maybe_encrypted_room_key {
362 MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
363 result_builder.on_successful_encryption(&result.device, *used_session, message);
364
365 let user_id = result.device.user_id().to_owned();
366 let device_id = result.device.device_id().to_owned();
367 share_infos
368 .entry(user_id)
369 .or_insert_with(BTreeMap::new)
370 .insert(device_id, *share_info);
371 }
372 MaybeEncryptedRoomKey::MissingSession => {
373 result_builder.on_missing_session(result.device);
374 }
375 }
376 }
377
378 Ok((result_builder.into_result(), share_infos))
379 }
380
381 #[instrument(skip_all)]
388 pub async fn collect_session_recipients(
389 &self,
390 users: impl Iterator<Item = &UserId>,
391 settings: &EncryptionSettings,
392 outbound: &OutboundGroupSession,
393 ) -> OlmResult<CollectRecipientsResult> {
394 share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
395 }
396
397 async fn encrypt_request(
398 store: Arc<CryptoStoreWrapper>,
399 chunk: Vec<DeviceData>,
400 outbound: OutboundGroupSession,
401 sessions: GroupSessionCache,
402 ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
403 let (result, share_infos) =
404 Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
405
406 if let Some(request) = result.to_device_request {
407 let id = request.txn_id.clone();
408 outbound.add_request(id.clone(), request.into(), share_infos);
409 sessions.mark_as_being_shared(id, outbound.clone());
410 }
411
412 Ok((result.updated_olm_sessions, result.no_olm_devices))
413 }
414
415 pub(crate) fn session_cache(&self) -> GroupSessionCache {
416 self.sessions.clone()
417 }
418
419 async fn maybe_rotate_group_session(
420 &self,
421 should_rotate: bool,
422 room_id: &RoomId,
423 outbound: OutboundGroupSession,
424 encryption_settings: EncryptionSettings,
425 changes: &mut Changes,
426 own_device: Option<Device>,
427 ) -> OlmResult<OutboundGroupSession> {
428 Ok(if should_rotate {
429 let old_session_id = outbound.session_id();
430
431 let (outbound, mut inbound) = self
432 .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
433 .await?;
434
435 let own_sender_data = if let Some(device) = own_device {
439 SenderDataFinder::find_using_device_data(
440 &self.store,
441 device.inner.clone(),
442 &inbound,
443 )
444 .await?
445 } else {
446 error!("Unable to find our own device!");
447 SenderData::unknown()
448 };
449 inbound.sender_data = own_sender_data;
450
451 changes.outbound_group_sessions.push(outbound.clone());
452 changes.inbound_group_sessions.push(inbound);
453
454 debug!(
455 old_session_id = old_session_id,
456 session_id = outbound.session_id(),
457 "A user or device has left the room since we last sent a \
458 message, or the encryption settings have changed. Rotating the \
459 room key.",
460 );
461
462 outbound
463 } else {
464 outbound
465 })
466 }
467
468 async fn encrypt_for_devices(
469 &self,
470 recipient_devices: Vec<DeviceData>,
471 group_session: &OutboundGroupSession,
472 changes: &mut Changes,
473 ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
474 if !recipient_devices.is_empty() {
476 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
477
478 changes.outbound_group_sessions = vec![group_session.clone()];
481
482 let message_index = group_session.message_index().await;
483
484 info!(
485 ?recipients,
486 message_index,
487 room_id = ?group_session.room_id(),
488 session_id = group_session.session_id(),
489 "Trying to encrypt a room key",
490 );
491 }
492
493 let tasks: Vec<_> = recipient_devices
498 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
499 .map(|chunk| {
500 spawn(Self::encrypt_request(
501 self.store.crypto_store(),
502 chunk.to_vec(),
503 group_session.clone(),
504 self.sessions.clone(),
505 ))
506 })
507 .collect();
508
509 let mut withheld_devices = Vec::new();
510
511 for result in join_all(tasks).await {
516 let result = result.expect("Encryption task panicked");
517
518 let (used_sessions, failed_no_olm) = result?;
519
520 changes.sessions.extend(used_sessions);
521 withheld_devices.extend(failed_no_olm);
522 }
523
524 Ok(withheld_devices)
525 }
526
527 fn is_withheld_to(
528 &self,
529 group_session: &OutboundGroupSession,
530 device: &DeviceData,
531 code: &WithheldCode,
532 ) -> bool {
533 if code == &WithheldCode::NoOlm {
551 device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
552 } else {
553 group_session.sharing_view().is_withheld_to(device, code)
554 }
555 }
556
557 fn handle_withheld_devices(
558 &self,
559 group_session: &OutboundGroupSession,
560 withheld_devices: Vec<(DeviceData, WithheldCode)>,
561 ) -> OlmResult<()> {
562 let to_content = |code| {
564 let content = group_session.withheld_code(code);
565 Raw::new(&content).expect("We can always serialize a withheld content info").cast()
566 };
567
568 let chunk_to_request = |chunk| {
571 let mut messages = BTreeMap::new();
572 let mut share_infos = BTreeMap::new();
573
574 for (device, code) in chunk {
575 let device: DeviceData = device;
576 let code: WithheldCode = code;
577
578 let user_id = device.user_id().to_owned();
579 let device_id = device.device_id().to_owned();
580
581 let share_info = ShareInfo::new_withheld(code.to_owned());
582 let content = to_content(code);
583
584 messages
585 .entry(user_id.to_owned())
586 .or_insert_with(BTreeMap::new)
587 .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
588
589 share_infos
590 .entry(user_id)
591 .or_insert_with(BTreeMap::new)
592 .insert(device_id, share_info);
593 }
594
595 let txn_id = TransactionId::new();
596
597 let request = ToDeviceRequest {
598 event_type: ToDeviceEventType::from("m.room_key.withheld"),
599 txn_id,
600 messages,
601 };
602
603 (request, share_infos)
604 };
605
606 let result: Vec<_> = withheld_devices
607 .into_iter()
608 .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
609 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
610 .into_iter()
611 .map(chunk_to_request)
612 .collect();
613
614 for (request, share_info) in result {
615 if !request.messages.is_empty() {
616 let txn_id = request.txn_id.to_owned();
617 group_session.add_request(txn_id.to_owned(), request.into(), share_info);
618
619 self.sessions.mark_as_being_shared(txn_id, group_session.clone());
620 }
621 }
622
623 Ok(())
624 }
625
626 fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
627 for request in requests {
628 let message_list = Self::to_device_request_to_log_list(request);
629 info!(
630 request_id = ?request.txn_id,
631 ?message_list,
632 "Created batch of to-device messages of type {}",
633 request.event_type
634 );
635 }
636 }
637
638 fn to_device_request_to_log_list(
642 request: &Arc<ToDeviceRequest>,
643 ) -> Vec<(String, String, String)> {
644 #[derive(serde::Deserialize)]
645 struct ContentStub<'a> {
646 #[serde(borrow, default, rename = "org.matrix.msgid")]
647 message_id: Option<&'a str>,
648 }
649
650 let mut result: Vec<(String, String, String)> = Vec::new();
651
652 for (user_id, device_map) in &request.messages {
653 for (device, content) in device_map {
654 let message_id: Option<&str> = content
655 .deserialize_as_unchecked::<ContentStub<'_>>()
656 .expect("We should be able to deserialize the content we generated")
657 .message_id;
658
659 result.push((
660 message_id.unwrap_or("<undefined>").to_owned(),
661 user_id.to_string(),
662 device.to_string(),
663 ));
664 }
665 }
666 result
667 }
668
669 #[instrument(skip(self, users, encryption_settings), fields(session_id))]
680 pub async fn share_room_key(
681 &self,
682 room_id: &RoomId,
683 users: impl Iterator<Item = &UserId>,
684 encryption_settings: impl Into<EncryptionSettings>,
685 ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
686 trace!("Checking if a room key needs to be shared");
687
688 let account = self.store.static_account();
689 let device = self.store.get_device(account.user_id(), account.device_id()).await?;
690
691 let encryption_settings = encryption_settings.into();
692 let mut changes = Changes::default();
693
694 let (outbound, inbound) = self
696 .get_or_create_outbound_session(
697 room_id,
698 encryption_settings.clone(),
699 SenderData::unknown(),
700 )
701 .await?;
702 tracing::Span::current().record("session_id", outbound.session_id());
703
704 if let Some(mut inbound) = inbound {
707 let own_sender_data = if let Some(device) = &device {
711 SenderDataFinder::find_using_device_data(
712 &self.store,
713 device.inner.clone(),
714 &inbound,
715 )
716 .await?
717 } else {
718 error!("Unable to find our own device!");
719 SenderData::unknown()
720 };
721 inbound.sender_data = own_sender_data;
722
723 changes.outbound_group_sessions.push(outbound.clone());
724 changes.inbound_group_sessions.push(inbound);
725 }
726
727 let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
731 self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
732
733 let outbound = self
734 .maybe_rotate_group_session(
735 should_rotate,
736 room_id,
737 outbound,
738 encryption_settings,
739 &mut changes,
740 device,
741 )
742 .await?;
743
744 let devices: Vec<_> = devices
747 .into_iter()
748 .flat_map(|(_, d)| {
749 d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
750 ShareState::NotShared => true,
751 ShareState::Shared { message_index: _, olm_wedging_index } => {
752 olm_wedging_index < d.olm_wedging_index
759 }
760 _ => false,
761 })
762 })
763 .collect();
764
765 let unable_to_encrypt_devices =
771 self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
772
773 withheld_devices.extend(unable_to_encrypt_devices);
775
776 self.handle_withheld_devices(&outbound, withheld_devices)?;
779
780 let requests = outbound.pending_requests();
784
785 if requests.is_empty() {
786 if !outbound.shared() {
787 debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
788
789 outbound.mark_as_shared();
790 changes.outbound_group_sessions.push(outbound.clone());
791 }
792 } else {
793 Self::log_room_key_sharing_result(&requests)
794 }
795
796 if !changes.is_empty() {
798 let session_count = changes.sessions.len();
799
800 self.store.save_changes(changes).await?;
801
802 trace!(
803 session_count = session_count,
804 "Stored the changed sessions after encrypting an room key"
805 );
806 }
807
808 Ok(requests)
809 }
810
811 #[instrument(skip(self, bundle_data))]
823 pub async fn share_room_key_bundle_data(
824 &self,
825 user_id: &UserId,
826 collect_strategy: &CollectStrategy,
827 bundle_data: RoomKeyBundleContent,
828 ) -> OlmResult<Vec<ToDeviceRequest>> {
829 let collect_strategy = match collect_strategy {
831 CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
832 warn!(
833 "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
834 for room key history sharing",
835 );
836 &CollectStrategy::IdentityBasedStrategy
837 }
838 CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
839 collect_strategy
840 }
841 };
842
843 let mut changes = Changes::default();
844
845 let CollectRecipientsResult { devices, .. } =
846 share_strategy::collect_recipients_for_share_strategy(
847 &self.store,
848 iter::once(user_id),
849 collect_strategy,
850 None,
851 )
852 .await?;
853
854 let devices = devices.into_values().flatten().collect();
855 let event_type = bundle_data.event_type().to_owned();
856 let (requests, _) = self
857 .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
858 .await?;
859
860 if !changes.is_empty() {
864 let session_count = changes.sessions.len();
865
866 self.store.save_changes(changes).await?;
867
868 trace!(
869 session_count = session_count,
870 "Stored the changed sessions after encrypting an room key"
871 );
872 }
873
874 Ok(requests)
875 }
876
877 pub(crate) async fn encrypt_content_for_devices(
884 &self,
885 recipient_devices: Vec<DeviceData>,
886 event_type: &str,
887 content: impl Serialize + Clone + Send + 'static,
888 changes: &mut Changes,
889 ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
890 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
891 info!(?recipients, "Encrypting content of type {}", event_type);
892
893 let tasks: Vec<_> = recipient_devices
898 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
899 .map(|chunk| {
900 spawn(
901 encrypt_content_for_devices(
902 self.store.crypto_store(),
903 event_type.to_owned(),
904 content.clone(),
905 chunk.to_vec(),
906 )
907 .in_current_span(),
908 )
909 })
910 .collect();
911
912 let mut no_olm_devices = Vec::new();
913 let mut to_device_requests = Vec::new();
914
915 for result in join_all(tasks).await {
920 let result = result.expect("Encryption task panicked")?;
921 if let Some(request) = result.to_device_request {
922 to_device_requests.push(request);
923 }
924 changes.sessions.extend(result.updated_olm_sessions);
925 no_olm_devices.extend(result.no_olm_devices);
926 }
927
928 Ok((to_device_requests, no_olm_devices))
929 }
930}
931
932async fn encrypt_content_for_devices(
941 store: Arc<CryptoStoreWrapper>,
942 event_type: String,
943 content: impl Serialize + Clone + Send + 'static,
944 devices: Vec<DeviceData>,
945) -> OlmResult<EncryptForDevicesResult> {
946 let mut result_builder = EncryptForDevicesResultBuilder::default();
947
948 async fn encrypt(
949 store: Arc<CryptoStoreWrapper>,
950 device: DeviceData,
951 event_type: String,
952 bundle_data: impl Serialize,
953 ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
954 device.encrypt(store.as_ref(), &event_type, bundle_data).await
955 }
956
957 let tasks = devices.iter().map(|device| {
958 spawn(
959 encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
960 .in_current_span(),
961 )
962 });
963
964 let results = join_all(tasks).await;
965
966 for (device, result) in zip(devices, results) {
967 let encryption_result = result.expect("Encryption task panicked");
968
969 match encryption_result {
970 Ok((used_session, message)) => {
971 result_builder.on_successful_encryption(&device, used_session, message.cast());
972 }
973 Err(OlmError::MissingSession) => {
974 result_builder.on_missing_session(device);
976 }
977 Err(e) => return Err(e),
978 }
979 }
980
981 Ok(result_builder.into_result())
982}
983
/// Outcome of encrypting something for a batch of devices.
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// The to-device request carrying every successfully encrypted message,
    /// or `None` when no device could be encrypted for.
    to_device_request: Option<ToDeviceRequest>,

    /// Devices lacking an Olm session, each paired with the withheld code to
    /// report for it.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// Olm sessions that were used for encryption and need to be persisted.
    updated_olm_sessions: Vec<Session>,
}
999
1000#[derive(Debug, Default)]
1002struct EncryptForDevicesResultBuilder {
1003 messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,
1005
1006 no_olm_devices: Vec<(DeviceData, WithheldCode)>,
1008
1009 updated_olm_sessions: Vec<Session>,
1012}
1013
1014impl EncryptForDevicesResultBuilder {
1015 pub fn on_successful_encryption(
1019 &mut self,
1020 device: &DeviceData,
1021 used_session: Session,
1022 message: Raw<AnyToDeviceEventContent>,
1023 ) {
1024 self.updated_olm_sessions.push(used_session);
1025
1026 self.messages
1027 .entry(device.user_id().to_owned())
1028 .or_default()
1029 .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
1030 }
1031
1032 pub fn on_missing_session(&mut self, device: DeviceData) {
1034 self.no_olm_devices.push((device, WithheldCode::NoOlm));
1035 }
1036
1037 pub fn into_result(self) -> EncryptForDevicesResult {
1040 let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
1041 self;
1042
1043 let mut encrypt_for_devices_result = EncryptForDevicesResult {
1044 to_device_request: None,
1045 updated_olm_sessions,
1046 no_olm_devices,
1047 };
1048
1049 if !messages.is_empty() {
1050 let request = ToDeviceRequest {
1051 event_type: ToDeviceEventType::RoomEncrypted,
1052 txn_id: TransactionId::new(),
1053 messages,
1054 };
1055 trace!(
1056 recipient_count = request.message_count(),
1057 transaction_id = ?request.txn_id,
1058 "Created a to-device request carrying room keys",
1059 );
1060 encrypt_for_devices_result.to_device_request = Some(request);
1061 }
1062
1063 encrypt_for_devices_result
1064 }
1065}
1066
1067fn recipient_list_to_users_and_devices(
1068 recipient_devices: &[DeviceData],
1069) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1070 #[allow(unknown_lints, clippy::unwrap_or_default)] recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1072 acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1073 acc
1074 })
1075}
1076
1077#[cfg(test)]
1078mod tests {
1079 use std::{
1080 collections::{BTreeMap, BTreeSet},
1081 iter,
1082 ops::Deref,
1083 sync::Arc,
1084 };
1085
1086 use assert_matches2::assert_let;
1087 use matrix_sdk_common::deserialized_responses::{ProcessedToDeviceEvent, WithheldCode};
1088 use matrix_sdk_test::{async_test, ruma_response_from_json};
1089 use ruma::{
1090 api::client::{
1091 keys::{claim_keys, get_keys, upload_keys},
1092 to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1093 },
1094 device_id,
1095 events::room::{
1096 history_visibility::HistoryVisibility, EncryptedFileInit, JsonWebKey, JsonWebKeyInit,
1097 },
1098 owned_room_id, room_id,
1099 serde::Base64,
1100 to_device::DeviceIdOrAllDevices,
1101 user_id, DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1102 };
1103 use serde_json::{json, Value};
1104
1105 use crate::{
1106 identities::DeviceData,
1107 machine::{
1108 test_helpers::get_machine_pair_with_setup_sessions_test_helper, EncryptionSyncChanges,
1109 },
1110 olm::{Account, SenderData},
1111 session_manager::{group_sessions::CollectRecipientsResult, CollectStrategy},
1112 types::{
1113 events::{
1114 room::encrypted::EncryptedToDeviceEvent,
1115 room_key_bundle::RoomKeyBundleContent,
1116 room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1117 },
1118 requests::ToDeviceRequest,
1119 DeviceKeys, EventEncryptionAlgorithm,
1120 },
1121 DecryptionSettings, EncryptionSettings, LocalTrust, OlmMachine, TrustRequirement,
1122 };
1123
1124 fn alice_id() -> &'static UserId {
1125 user_id!("@alice:example.org")
1126 }
1127
1128 fn alice_device_id() -> &'static DeviceId {
1129 device_id!("JLAFKJWSCS")
1130 }
1131
1132 fn keys_query_response() -> get_keys::v3::Response {
1134 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1135 let data: Value = serde_json::from_slice(data).unwrap();
1136 ruma_response_from_json(&data)
1137 }
1138
1139 fn bob_keys_query_response() -> get_keys::v3::Response {
1140 let data = json!({
1141 "device_keys": {
1142 "@bob:localhost": {
1143 "BOBDEVICE": {
1144 "user_id": "@bob:localhost",
1145 "device_id": "BOBDEVICE",
1146 "algorithms": [
1147 "m.olm.v1.curve25519-aes-sha2",
1148 "m.megolm.v1.aes-sha2",
1149 "m.megolm.v2.aes-sha2"
1150 ],
1151 "keys": {
1152 "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1153 "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1154 },
1155 "signatures": {
1156 "@bob:localhost": {
1157 "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1158 }
1159 }
1160 }
1161 }
1162 }
1163 });
1164 ruma_response_from_json(&data)
1165 }
1166
1167 fn bob_one_time_key() -> claim_keys::v3::Response {
1170 let data = json!({
1171 "failures": {},
1172 "one_time_keys":{
1173 "@bob:localhost":{
1174 "BOBDEVICE":{
1175 "signed_curve25519:AAAAAAAAAAA": {
1176 "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1177 "signatures":{
1178 "@bob:localhost":{
1179 "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1180 }
1181 }
1182 }
1183 }
1184 }
1185 }
1186 });
1187 ruma_response_from_json(&data)
1188 }
1189
1190 fn keys_claim_response() -> claim_keys::v3::Response {
1193 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1194 let data: Value = serde_json::from_slice(data).unwrap();
1195 ruma_response_from_json(&data)
1196 }
1197
1198 async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1199 let keys_query = keys_query_response();
1200 let txn_id = TransactionId::new();
1201
1202 let machine = OlmMachine::new(user_id, device_id).await;
1203
1204 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1206 let (txn_id, _keys_claim_request) = machine
1207 .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1208 .await
1209 .unwrap()
1210 .unwrap();
1211 let keys_claim = keys_claim_response();
1212 machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1213
1214 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1216 let (txn_id, _keys_claim_request) = machine
1217 .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1218 .await
1219 .unwrap()
1220 .unwrap();
1221 machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1222
1223 machine
1224 }
1225
1226 async fn machine() -> OlmMachine {
1227 machine_with_user_test_helper(alice_id(), alice_device_id()).await
1228 }
1229
1230 async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1231 let machine = machine().await;
1232 let room_id = room_id!("!test:localhost");
1233 let keys_claim = keys_claim_response();
1234
1235 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1236 let requests =
1237 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1238
1239 let outbound =
1240 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1241
1242 assert!(!outbound.pending_requests().is_empty());
1243 assert!(!outbound.shared());
1244
1245 let response = ToDeviceResponse::new();
1246 for request in requests {
1247 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1248 }
1249
1250 assert!(outbound.shared());
1251 assert!(outbound.pending_requests().is_empty());
1252
1253 machine
1254 }
1255
1256 #[async_test]
1257 async fn test_sharing() {
1258 let machine = machine().await;
1259 let room_id = room_id!("!test:localhost");
1260 let keys_claim = keys_claim_response();
1261
1262 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1263
1264 let requests =
1265 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1266
1267 let event_count: usize = requests
1268 .iter()
1269 .filter(|r| r.event_type == "m.room.encrypted".into())
1270 .map(|r| r.message_count())
1271 .sum();
1272
1273 assert_eq!(event_count, 148);
1277
1278 let withheld_count: usize = requests
1279 .iter()
1280 .filter(|r| r.event_type == "m.room_key.withheld".into())
1281 .map(|r| r.message_count())
1282 .sum();
1283 assert_eq!(withheld_count, 2);
1284 }
1285
1286 fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1287 requests
1288 .iter()
1289 .filter(|r| r.event_type == "m.room_key.withheld".into())
1290 .map(|r| {
1291 let mut count = 0;
1292 for message in r.messages.values() {
1294 message.iter().for_each(|(_, content)| {
1295 let withheld: RoomKeyWithheldContent =
1296 content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1297
1298 if let MegolmV1AesSha2(content) = withheld {
1299 if content.withheld_code() == code {
1300 count += 1;
1301 }
1302 }
1303 })
1304 }
1305 count
1306 })
1307 .sum()
1308 }
1309
1310 #[async_test]
1311 async fn test_no_olm_sent_once() {
1312 let machine = machine().await;
1313 let keys_claim = keys_claim_response();
1314
1315 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1316
1317 let first_room_id = room_id!("!test:localhost");
1318
1319 let requests = machine
1320 .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1321 .await
1322 .unwrap();
1323
1324 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1326 assert_eq!(withheld_count, 2);
1327
1328 let new_requests = machine
1331 .share_room_key(first_room_id, users, EncryptionSettings::default())
1332 .await
1333 .unwrap();
1334 let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1335 assert_eq!(withheld_count, 2);
1337
1338 let response = ToDeviceResponse::new();
1339 for request in requests {
1340 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1341 }
1342
1343 let second_room_id = room_id!("!other:localhost");
1346 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1347 let requests = machine
1348 .share_room_key(second_room_id, users, EncryptionSettings::default())
1349 .await
1350 .unwrap();
1351
1352 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1353 assert_eq!(withheld_count, 0);
1354
1355 }
1358
1359 #[async_test]
1360 async fn test_ratcheted_sharing() {
1361 let machine = machine_with_shared_room_key_test_helper().await;
1362
1363 let room_id = room_id!("!test:localhost");
1364 let late_joiner = user_id!("@bob:localhost");
1365 let keys_claim = keys_claim_response();
1366
1367 let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1368 users.insert(late_joiner);
1369
1370 let requests = machine
1371 .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1372 .await
1373 .unwrap();
1374
1375 let event_count: usize = requests
1376 .iter()
1377 .filter(|r| r.event_type == "m.room.encrypted".into())
1378 .map(|r| r.message_count())
1379 .sum();
1380 let outbound =
1381 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1382
1383 assert_eq!(event_count, 1);
1384 assert!(!outbound.pending_requests().is_empty());
1385 }
1386
1387 #[async_test]
1388 async fn test_changing_encryption_settings() {
1389 let machine = machine_with_shared_room_key_test_helper().await;
1390 let room_id = room_id!("!test:localhost");
1391 let keys_claim = keys_claim_response();
1392
1393 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1394 let outbound =
1395 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1396
1397 let CollectRecipientsResult { should_rotate, .. } = machine
1398 .inner
1399 .group_session_manager
1400 .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1401 .await
1402 .unwrap();
1403
1404 assert!(!should_rotate);
1405
1406 let settings = EncryptionSettings {
1407 history_visibility: HistoryVisibility::Invited,
1408 ..Default::default()
1409 };
1410
1411 let CollectRecipientsResult { should_rotate, .. } = machine
1412 .inner
1413 .group_session_manager
1414 .collect_session_recipients(users.clone(), &settings, &outbound)
1415 .await
1416 .unwrap();
1417
1418 assert!(should_rotate);
1419
1420 let settings = EncryptionSettings {
1421 algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1422 ..Default::default()
1423 };
1424
1425 let CollectRecipientsResult { should_rotate, .. } = machine
1426 .inner
1427 .group_session_manager
1428 .collect_session_recipients(users, &settings, &outbound)
1429 .await
1430 .unwrap();
1431
1432 assert!(should_rotate);
1433 }
1434
1435 #[async_test]
1436 async fn test_key_recipient_collecting() {
1437 let user_id = user_id!("@example:localhost");
1440 let device_id = device_id!("TESTDEVICE");
1441 let room_id = room_id!("!test:localhost");
1442
1443 let machine = machine_with_user_test_helper(user_id, device_id).await;
1444
1445 let (outbound, _) = machine
1446 .inner
1447 .group_session_manager
1448 .get_or_create_outbound_session(
1449 room_id,
1450 EncryptionSettings::default(),
1451 SenderData::unknown(),
1452 )
1453 .await
1454 .expect("We should be able to create a new session");
1455 let history_visibility = HistoryVisibility::Joined;
1456 let settings = EncryptionSettings { history_visibility, ..Default::default() };
1457
1458 let users = [user_id].into_iter();
1459
1460 let CollectRecipientsResult { devices: recipients, .. } = machine
1461 .inner
1462 .group_session_manager
1463 .collect_session_recipients(users, &settings, &outbound)
1464 .await
1465 .expect("We should be able to collect the session recipients");
1466
1467 assert!(!recipients[user_id].is_empty());
1468
1469 assert!(!recipients[user_id]
1471 .iter()
1472 .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1473
1474 let settings = EncryptionSettings {
1475 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1476 ..Default::default()
1477 };
1478 let users = [user_id].into_iter();
1479
1480 let CollectRecipientsResult { devices: recipients, .. } = machine
1481 .inner
1482 .group_session_manager
1483 .collect_session_recipients(users, &settings, &outbound)
1484 .await
1485 .expect("We should be able to collect the session recipients");
1486
1487 assert!(recipients[user_id].is_empty());
1488
1489 let device_id = "AFGUOBTZWM".into();
1490 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1491 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1492 let users = [user_id].into_iter();
1493
1494 let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1495 machine
1496 .inner
1497 .group_session_manager
1498 .collect_session_recipients(users, &settings, &outbound)
1499 .await
1500 .expect("We should be able to collect the session recipients");
1501
1502 assert!(recipients[user_id]
1503 .iter()
1504 .any(|d| d.user_id() == user_id && d.device_id() == device_id));
1505
1506 let devices = machine.get_user_devices(user_id, None).await.unwrap();
1507 devices
1508 .devices()
1509 .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1511 .for_each(|d| {
1512 if d.is_blacklisted() {
1513 assert!(withheld.iter().any(|(dev, w)| {
1514 dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1515 }));
1516 } else if !d.is_verified() {
1517 assert!(withheld.iter().any(|(dev, w)| {
1519 dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1520 }));
1521 }
1522 });
1523
1524 assert_eq!(149, withheld.len());
1525 }
1526
1527 #[async_test]
1528 async fn test_sharing_withheld_only_trusted() {
1529 let machine = machine().await;
1530 let room_id = room_id!("!test:localhost");
1531 let keys_claim = keys_claim_response();
1532
1533 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1534 let settings = EncryptionSettings {
1535 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1536 ..Default::default()
1537 };
1538
1539 let user_id = user_id!("@example:localhost");
1541 let device_id = "MWFXPINOAO".into();
1542 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1543 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1544 machine
1545 .get_device(user_id, "MWVTUXDNNM".into(), None)
1546 .await
1547 .unwrap()
1548 .unwrap()
1549 .set_local_trust(LocalTrust::BlackListed)
1550 .await
1551 .unwrap();
1552
1553 let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1554
1555 let room_key_count =
1557 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1558
1559 assert_eq!(1, room_key_count);
1560
1561 let withheld_count =
1562 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1563 assert_eq!(1, withheld_count);
1565
1566 let event_count: usize = requests
1567 .iter()
1568 .filter(|r| r.event_type == "m.room_key.withheld".into())
1569 .map(|r| r.message_count())
1570 .sum();
1571
1572 assert_eq!(event_count, 149);
1574
1575 let has_blacklist =
1577 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1578 let device_key = DeviceIdOrAllDevices::from(device_id!("MWVTUXDNNM").to_owned());
1579 let content = &r.messages[user_id][&device_key];
1580 let withheld: RoomKeyWithheldContent =
1581 content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1582 if let MegolmV1AesSha2(content) = withheld {
1583 content.withheld_code() == WithheldCode::Blacklisted
1584 } else {
1585 false
1586 }
1587 });
1588
1589 assert!(has_blacklist);
1590 }
1591
1592 #[async_test]
1593 async fn test_no_olm_withheld_only_sent_once() {
1594 let keys_query = keys_query_response();
1595 let txn_id = TransactionId::new();
1596
1597 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1598
1599 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1600 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1601
1602 let first_room = room_id!("!test:localhost");
1603 let second_room = room_id!("!test2:localhost");
1604 let bob_id = user_id!("@bob:localhost");
1605
1606 let settings = EncryptionSettings::default();
1607 let users = [bob_id];
1608
1609 let requests = machine
1610 .share_room_key(first_room, users.into_iter(), settings.to_owned())
1611 .await
1612 .unwrap();
1613
1614 let withheld_count =
1616 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1617
1618 assert_eq!(withheld_count, 1);
1619 assert_eq!(requests.len(), 1);
1620
1621 let second_requests =
1624 machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1625
1626 let withheld_count =
1627 second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1628
1629 assert_eq!(withheld_count, 0);
1630 assert_eq!(second_requests.len(), 0);
1631
1632 let response = ToDeviceResponse::new();
1633
1634 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1635
1636 assert!(!device.was_withheld_code_sent());
1639
1640 for request in requests {
1641 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1642 }
1643
1644 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1645
1646 assert!(device.was_withheld_code_sent());
1647 }
1648
1649 #[async_test]
1650 async fn test_resend_session_after_unwedging() {
1651 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1652 assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
1653 let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
1654 OneTimeKeyAlgorithm::SignedCurve25519,
1655 UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
1656 )]));
1657 machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
1658
1659 let room_id = room_id!("!test:localhost");
1660
1661 let bob_id = user_id!("@bob:localhost");
1662 let bob_account = Account::new(bob_id);
1663 let keys_query_data = json!({
1664 "device_keys": {
1665 "@bob:localhost": {
1666 bob_account.device_id.clone(): bob_account.device_keys()
1667 }
1668 }
1669 });
1670 let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
1671 let txn_id = TransactionId::new();
1672 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1673
1674 let alice_device_keys =
1675 device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
1676 let mut alice_otks = device_keys_request.one_time_keys.iter();
1677 let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
1678
1679 {
1680 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1682 let mut session = bob_account
1683 .create_outbound_session(
1684 &alice_device,
1685 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1686 bob_account.device_keys(),
1687 )
1688 .unwrap();
1689 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1690
1691 let to_device =
1692 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1693
1694 let sync_changes = EncryptionSyncChanges {
1696 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1697 changed_devices: &Default::default(),
1698 one_time_keys_counts: &Default::default(),
1699 unused_fallback_keys: None,
1700 next_batch_token: None,
1701 };
1702
1703 let decryption_settings =
1704 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1705
1706 let (decrypted, _) =
1707 machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1708
1709 assert_eq!(1, decrypted.len());
1710 }
1711
1712 {
1714 let requests = machine
1715 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1716 .await
1717 .unwrap();
1718
1719 let event_count: usize = requests
1721 .iter()
1722 .filter(|r| r.event_type == "m.room.encrypted".into())
1723 .map(|r| r.message_count())
1724 .sum();
1725 assert_eq!(event_count, 1);
1726
1727 let response = ToDeviceResponse::new();
1728 for request in requests {
1729 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1730 }
1731 }
1732
1733 {
1736 let requests = machine
1737 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1738 .await
1739 .unwrap();
1740
1741 let event_count: usize = requests
1742 .iter()
1743 .filter(|r| r.event_type == "m.room.encrypted".into())
1744 .map(|r| r.message_count())
1745 .sum();
1746 assert_eq!(event_count, 0);
1747 }
1748
1749 {
1751 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1752 let mut session = bob_account
1753 .create_outbound_session(
1754 &alice_device,
1755 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1756 bob_account.device_keys(),
1757 )
1758 .unwrap();
1759 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1760
1761 let to_device =
1762 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1763
1764 let sync_changes = EncryptionSyncChanges {
1766 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1767 changed_devices: &Default::default(),
1768 one_time_keys_counts: &Default::default(),
1769 unused_fallback_keys: None,
1770 next_batch_token: None,
1771 };
1772
1773 let decryption_settings =
1774 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1775
1776 let (decrypted, _) =
1777 machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1778
1779 assert_eq!(1, decrypted.len());
1780 }
1781
1782 {
1784 let requests = machine
1785 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1786 .await
1787 .unwrap();
1788
1789 let event_count: usize = requests
1790 .iter()
1791 .filter(|r| r.event_type == "m.room.encrypted".into())
1792 .map(|r| r.message_count())
1793 .sum();
1794 assert_eq!(event_count, 1);
1795
1796 let response = ToDeviceResponse::new();
1797 for request in requests {
1798 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1799 }
1800 }
1801
1802 {
1805 let requests = machine
1806 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1807 .await
1808 .unwrap();
1809
1810 let event_count: usize = requests
1811 .iter()
1812 .filter(|r| r.event_type == "m.room.encrypted".into())
1813 .map(|r| r.message_count())
1814 .sum();
1815 assert_eq!(event_count, 0);
1816 }
1817 }
1818
1819 #[async_test]
1820 async fn test_room_key_bundle_sharing() {
1821 let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1822 user_id!("@alice:localhost"),
1823 user_id!("@bob:localhost"),
1824 false,
1825 )
1826 .await;
1827
1828 let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1830 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1831
1832 let content = RoomKeyBundleContent {
1833 room_id: owned_room_id!("!room:id"),
1834 file: (EncryptedFileInit {
1835 url: OwnedMxcUri::from("test"),
1836 key: JsonWebKey::from(JsonWebKeyInit {
1837 kty: "oct".to_owned(),
1838 key_ops: vec!["encrypt".to_owned(), "decrypt".to_owned()],
1839 alg: "A256CTR".to_owned(),
1840 #[allow(clippy::unnecessary_to_owned)]
1841 k: Base64::new(vec![0u8; 0]),
1842 ext: true,
1843 }),
1844 iv: Base64::new(vec![0u8; 0]),
1845 hashes: Default::default(),
1846 v: "".to_owned(),
1847 })
1848 .into(),
1849 };
1850
1851 let requests = alice
1852 .share_room_key_bundle_data(
1853 bob.user_id(),
1854 &CollectStrategy::OnlyTrustedDevices,
1855 content,
1856 )
1857 .await
1858 .unwrap();
1859
1860 let requests: Vec<_> =
1862 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1863 let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1864 assert_eq!(message_count, 1);
1865
1866 let bob_message = requests[0]
1868 .messages
1869 .get(bob.user_id())
1870 .unwrap()
1871 .get(&(bob.device_id().to_owned().into()))
1872 .unwrap();
1873 let to_device = EncryptedToDeviceEvent::new(
1874 alice.user_id().to_owned(),
1875 bob_message.deserialize_as_unchecked().unwrap(),
1876 );
1877
1878 let sync_changes = EncryptionSyncChanges {
1879 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1880 changed_devices: &Default::default(),
1881 one_time_keys_counts: &Default::default(),
1882 unused_fallback_keys: None,
1883 next_batch_token: None,
1884 };
1885
1886 let decryption_settings =
1887 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1888
1889 let (decrypted, _) =
1890 bob.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1891 assert_eq!(1, decrypted.len());
1892 use crate::types::events::EventType;
1893 assert_let!(
1894 ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1895 );
1896 assert_eq!(
1897 raw.get_field::<String>("type").unwrap().unwrap(),
1898 RoomKeyBundleContent::EVENT_TYPE,
1899 );
1900 }
1901}