1mod share_strategy;
16
17use std::{
18 collections::{BTreeMap, BTreeSet},
19 fmt::Debug,
20 iter,
21 iter::zip,
22 sync::Arc,
23};
24
25use futures_util::future::join_all;
26use itertools::Itertools;
27use matrix_sdk_common::{
28 deserialized_responses::WithheldCode, executor::spawn, locks::RwLock as StdRwLock,
29};
30#[cfg(feature = "experimental-encrypted-state-events")]
31use ruma::events::AnyStateEventContent;
32use ruma::{
33 DeviceId, OwnedDeviceId, OwnedRoomId, OwnedTransactionId, OwnedUserId, RoomId, TransactionId,
34 UserId,
35 events::{AnyMessageLikeEventContent, AnyToDeviceEventContent, ToDeviceEventType},
36 serde::Raw,
37 to_device::DeviceIdOrAllDevices,
38};
39use serde::Serialize;
40pub use share_strategy::CollectStrategy;
41#[cfg(feature = "experimental-send-custom-to-device")]
42pub(crate) use share_strategy::split_devices_for_share_strategy;
43pub(crate) use share_strategy::{
44 CollectRecipientsResult, withheld_code_for_device_for_share_strategy,
45};
46use tracing::{Instrument, debug, error, info, instrument, trace, warn};
47
48#[cfg(feature = "experimental-encrypted-state-events")]
49use crate::types::events::room::encrypted::RoomEncryptedEventContent;
50use crate::{
51 Device, DeviceData, EncryptionSettings, OlmError,
52 error::{EventError, MegolmResult, OlmResult},
53 identities::device::MaybeEncryptedRoomKey,
54 olm::{
55 InboundGroupSession, OutboundGroupSession, OutboundGroupSessionEncryptionResult,
56 SenderData, SenderDataFinder, Session, ShareInfo, ShareState,
57 },
58 store::{CryptoStoreWrapper, Result as StoreResult, Store, types::Changes},
59 types::{
60 events::{
61 EventType, room::encrypted::ToDeviceEncryptedEventContent,
62 room_key_bundle::RoomKeyBundleContent,
63 },
64 requests::ToDeviceRequest,
65 },
66};
67
/// In-memory cache of outbound group sessions, falling back to the crypto
/// store on a cache miss.
///
/// Both maps sit behind an `Arc`, so clones of the cache share the same
/// underlying data.
#[derive(Clone, Debug)]
pub(crate) struct GroupSessionCache {
    /// Store used to load outbound group sessions that are not cached yet.
    store: Store,
    /// Outbound group sessions, keyed by the ID of the room they belong to.
    sessions: Arc<StdRwLock<BTreeMap<OwnedRoomId, OutboundGroupSession>>>,
    /// Outbound group sessions keyed by the transaction ID of a to-device
    /// request queued on them, used to find the session again once that
    /// request is reported as sent.
    sessions_being_shared: Arc<StdRwLock<BTreeMap<OwnedTransactionId, OutboundGroupSession>>>,
}
76
77impl GroupSessionCache {
78 pub(crate) fn new(store: Store) -> Self {
79 Self { store, sessions: Default::default(), sessions_being_shared: Default::default() }
80 }
81
82 pub(crate) fn insert(&self, session: OutboundGroupSession) {
83 self.sessions.write().insert(session.room_id().to_owned(), session);
84 }
85
86 pub async fn get_or_load(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
93 if let Some(s) = self.sessions.read().get(room_id) {
96 return Some(s.clone());
97 }
98
99 match self.store.get_outbound_group_session(room_id).await {
100 Ok(Some(s)) => {
101 {
102 let mut sessions_being_shared = self.sessions_being_shared.write();
103 for request_id in s.pending_request_ids() {
104 sessions_being_shared.insert(request_id, s.clone());
105 }
106 }
107
108 self.sessions.write().insert(room_id.to_owned(), s.clone());
109
110 Some(s)
111 }
112 Ok(None) => None,
113 Err(e) => {
114 error!("Couldn't restore an outbound group session: {e:?}");
115 None
116 }
117 }
118 }
119
120 #[cfg(test)]
127 fn get(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
128 self.sessions.read().get(room_id).cloned()
129 }
130
131 fn has_session_withheld_to(&self, device: &DeviceData, code: &WithheldCode) -> bool {
133 self.sessions.read().values().any(|s| s.sharing_view().is_withheld_to(device, code))
134 }
135
136 fn remove_from_being_shared(&self, id: &TransactionId) -> Option<OutboundGroupSession> {
137 self.sessions_being_shared.write().remove(id)
138 }
139
140 fn mark_as_being_shared(&self, id: OwnedTransactionId, session: OutboundGroupSession) {
141 self.sessions_being_shared.write().insert(id, session);
142 }
143}
144
/// Creates, caches, encrypts with, and shares Megolm outbound group sessions.
#[derive(Debug, Clone)]
pub(crate) struct GroupSessionManager {
    /// Store used to look up devices and to persist sessions and devices.
    store: Store,
    /// Cache of outbound group sessions, by room and by pending request ID.
    sessions: GroupSessionCache,
}
154
155impl GroupSessionManager {
156 const MAX_TO_DEVICE_MESSAGES: usize = 250;
157
158 pub fn new(store: Store) -> Self {
159 Self { store: store.clone(), sessions: GroupSessionCache::new(store) }
160 }
161
162 pub async fn invalidate_group_session(&self, room_id: &RoomId) -> StoreResult<bool> {
163 if let Some(s) = self.sessions.get_or_load(room_id).await {
164 s.invalidate_session();
165
166 let mut changes = Changes::default();
167 changes.outbound_group_sessions.push(s.clone());
168 self.store.save_changes(changes).await?;
169
170 Ok(true)
171 } else {
172 Ok(false)
173 }
174 }
175
176 pub async fn mark_request_as_sent(&self, request_id: &TransactionId) -> StoreResult<()> {
177 let Some(session) = self.sessions.remove_from_being_shared(request_id) else {
178 return Ok(());
179 };
180
181 let no_olm = session.mark_request_as_sent(request_id);
182
183 let mut changes = Changes::default();
184
185 for (user_id, devices) in &no_olm {
186 for device_id in devices {
187 let device = self.store.get_device(user_id, device_id).await;
188
189 if let Ok(Some(device)) = device {
190 device.mark_withheld_code_as_sent();
191 changes.devices.changed.push(device.inner.clone());
192 } else {
193 error!(
194 ?request_id,
195 "Marking to-device no olm as sent but device not found, might \
196 have been deleted?"
197 );
198 }
199 }
200 }
201
202 changes.outbound_group_sessions.push(session.clone());
203 self.store.save_changes(changes).await
204 }
205
206 #[cfg(test)]
207 pub fn get_outbound_group_session(&self, room_id: &RoomId) -> Option<OutboundGroupSession> {
208 self.sessions.get(room_id)
209 }
210
211 pub async fn encrypt(
212 &self,
213 room_id: &RoomId,
214 event_type: &str,
215 content: &Raw<AnyMessageLikeEventContent>,
216 ) -> MegolmResult<OutboundGroupSessionEncryptionResult> {
217 let session =
218 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
219
220 assert!(!session.expired(), "Session expired");
221
222 let result = session.encrypt(event_type, content).await;
223
224 let mut changes = Changes::default();
225 changes.outbound_group_sessions.push(session);
226 self.store.save_changes(changes).await?;
227
228 Ok(result)
229 }
230
231 #[cfg(feature = "experimental-encrypted-state-events")]
254 pub async fn encrypt_state(
255 &self,
256 room_id: &RoomId,
257 event_type: &str,
258 state_key: &str,
259 content: &Raw<AnyStateEventContent>,
260 ) -> MegolmResult<Raw<RoomEncryptedEventContent>> {
261 let session =
262 self.sessions.get_or_load(room_id).await.expect("Session wasn't created nor shared");
263
264 assert!(!session.expired(), "Session expired");
265
266 let content = session.encrypt_state(event_type, state_key, content).await;
267
268 let mut changes = Changes::default();
269 changes.outbound_group_sessions.push(session);
270 self.store.save_changes(changes).await?;
271
272 Ok(content)
273 }
274
275 pub async fn create_outbound_group_session(
279 &self,
280 room_id: &RoomId,
281 settings: EncryptionSettings,
282 own_sender_data: SenderData,
283 ) -> OlmResult<(OutboundGroupSession, InboundGroupSession)> {
284 let (outbound, inbound) = self
285 .store
286 .static_account()
287 .create_group_session_pair(room_id, settings, own_sender_data)
288 .await
289 .map_err(|_| EventError::UnsupportedAlgorithm)?;
290
291 self.sessions.insert(outbound.clone());
292 Ok((outbound, inbound))
293 }
294
295 pub async fn get_or_create_outbound_session(
296 &self,
297 room_id: &RoomId,
298 settings: EncryptionSettings,
299 own_sender_data: SenderData,
300 ) -> OlmResult<(OutboundGroupSession, Option<InboundGroupSession>)> {
301 let outbound_session = self.sessions.get_or_load(room_id).await;
302
303 if let Some(s) = outbound_session {
306 if s.expired() || s.invalidated() {
307 self.create_outbound_group_session(room_id, settings, own_sender_data)
308 .await
309 .map(|(o, i)| (o, i.into()))
310 } else {
311 Ok((s, None))
312 }
313 } else {
314 self.create_outbound_group_session(room_id, settings, own_sender_data)
315 .await
316 .map(|(o, i)| (o, i.into()))
317 }
318 }
319
320 async fn encrypt_session_for(
327 store: Arc<CryptoStoreWrapper>,
328 group_session: OutboundGroupSession,
329 devices: Vec<DeviceData>,
330 ) -> OlmResult<(
331 EncryptForDevicesResult,
332 BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, ShareInfo>>,
333 )> {
334 pub struct DeviceResult {
336 device: DeviceData,
337 maybe_encrypted_room_key: MaybeEncryptedRoomKey,
338 }
339
340 let mut result_builder = EncryptForDevicesResultBuilder::default();
341 let mut share_infos = BTreeMap::new();
342
343 let encrypt = |store: Arc<CryptoStoreWrapper>,
346 device: DeviceData,
347 session: OutboundGroupSession| async move {
348 let encryption_result = device.maybe_encrypt_room_key(store.as_ref(), session).await?;
349
350 Ok::<_, OlmError>(DeviceResult { device, maybe_encrypted_room_key: encryption_result })
351 };
352
353 let tasks: Vec<_> = devices
354 .iter()
355 .map(|d| spawn(encrypt(store.clone(), d.clone(), group_session.clone())))
356 .collect();
357
358 let results = join_all(tasks).await;
359
360 for result in results {
361 let result = result.expect("Encryption task panicked")?;
362
363 match result.maybe_encrypted_room_key {
364 MaybeEncryptedRoomKey::Encrypted { used_session, share_info, message } => {
365 result_builder.on_successful_encryption(&result.device, *used_session, message);
366
367 let user_id = result.device.user_id().to_owned();
368 let device_id = result.device.device_id().to_owned();
369 share_infos
370 .entry(user_id)
371 .or_insert_with(BTreeMap::new)
372 .insert(device_id, *share_info);
373 }
374 MaybeEncryptedRoomKey::MissingSession => {
375 result_builder.on_missing_session(result.device);
376 }
377 }
378 }
379
380 Ok((result_builder.into_result(), share_infos))
381 }
382
383 #[instrument(skip_all)]
390 pub async fn collect_session_recipients(
391 &self,
392 users: impl Iterator<Item = &UserId>,
393 settings: &EncryptionSettings,
394 outbound: &OutboundGroupSession,
395 ) -> OlmResult<CollectRecipientsResult> {
396 share_strategy::collect_session_recipients(&self.store, users, settings, outbound).await
397 }
398
399 async fn encrypt_request(
400 store: Arc<CryptoStoreWrapper>,
401 chunk: Vec<DeviceData>,
402 outbound: OutboundGroupSession,
403 sessions: GroupSessionCache,
404 ) -> OlmResult<(Vec<Session>, Vec<(DeviceData, WithheldCode)>)> {
405 let (result, share_infos) =
406 Self::encrypt_session_for(store, outbound.clone(), chunk).await?;
407
408 if let Some(request) = result.to_device_request {
409 let id = request.txn_id.clone();
410 outbound.add_request(id.clone(), request.into(), share_infos);
411 sessions.mark_as_being_shared(id, outbound.clone());
412 }
413
414 Ok((result.updated_olm_sessions, result.no_olm_devices))
415 }
416
417 pub(crate) fn session_cache(&self) -> GroupSessionCache {
418 self.sessions.clone()
419 }
420
421 async fn maybe_rotate_group_session(
422 &self,
423 should_rotate: bool,
424 room_id: &RoomId,
425 outbound: OutboundGroupSession,
426 encryption_settings: EncryptionSettings,
427 changes: &mut Changes,
428 own_device: Option<Device>,
429 ) -> OlmResult<OutboundGroupSession> {
430 Ok(if should_rotate {
431 let old_session_id = outbound.session_id();
432
433 let (outbound, mut inbound) = self
434 .create_outbound_group_session(room_id, encryption_settings, SenderData::unknown())
435 .await?;
436
437 let own_sender_data = if let Some(device) = own_device {
441 SenderDataFinder::find_using_device_data(
442 &self.store,
443 device.inner.clone(),
444 &inbound,
445 )
446 .await?
447 } else {
448 error!("Unable to find our own device!");
449 SenderData::unknown()
450 };
451 inbound.sender_data = own_sender_data;
452
453 changes.outbound_group_sessions.push(outbound.clone());
454 changes.inbound_group_sessions.push(inbound);
455
456 debug!(
457 old_session_id = old_session_id,
458 session_id = outbound.session_id(),
459 "A user or device has left the room since we last sent a \
460 message, or the encryption settings have changed. Rotating the \
461 room key.",
462 );
463
464 outbound
465 } else {
466 outbound
467 })
468 }
469
470 async fn encrypt_for_devices(
471 &self,
472 recipient_devices: Vec<DeviceData>,
473 group_session: &OutboundGroupSession,
474 changes: &mut Changes,
475 ) -> OlmResult<Vec<(DeviceData, WithheldCode)>> {
476 if !recipient_devices.is_empty() {
478 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
479
480 changes.outbound_group_sessions = vec![group_session.clone()];
483
484 let message_index = group_session.message_index().await;
485
486 info!(
487 ?recipients,
488 message_index,
489 room_id = ?group_session.room_id(),
490 session_id = group_session.session_id(),
491 "Trying to encrypt a room key",
492 );
493 }
494
495 let tasks: Vec<_> = recipient_devices
500 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
501 .map(|chunk| {
502 spawn(Self::encrypt_request(
503 self.store.crypto_store(),
504 chunk.to_vec(),
505 group_session.clone(),
506 self.sessions.clone(),
507 ))
508 })
509 .collect();
510
511 let mut withheld_devices = Vec::new();
512
513 for result in join_all(tasks).await {
518 let result = result.expect("Encryption task panicked");
519
520 let (used_sessions, failed_no_olm) = result?;
521
522 changes.sessions.extend(used_sessions);
523 withheld_devices.extend(failed_no_olm);
524 }
525
526 Ok(withheld_devices)
527 }
528
529 fn is_withheld_to(
530 &self,
531 group_session: &OutboundGroupSession,
532 device: &DeviceData,
533 code: &WithheldCode,
534 ) -> bool {
535 if code == &WithheldCode::NoOlm {
553 device.was_withheld_code_sent() || self.sessions.has_session_withheld_to(device, code)
554 } else {
555 group_session.sharing_view().is_withheld_to(device, code)
556 }
557 }
558
559 fn handle_withheld_devices(
560 &self,
561 group_session: &OutboundGroupSession,
562 withheld_devices: Vec<(DeviceData, WithheldCode)>,
563 ) -> OlmResult<()> {
564 let to_content = |code| {
566 let content = group_session.withheld_code(code);
567 Raw::new(&content).expect("We can always serialize a withheld content info").cast()
568 };
569
570 let chunk_to_request = |chunk| {
573 let mut messages = BTreeMap::new();
574 let mut share_infos = BTreeMap::new();
575
576 for (device, code) in chunk {
577 let device: DeviceData = device;
578 let code: WithheldCode = code;
579
580 let user_id = device.user_id().to_owned();
581 let device_id = device.device_id().to_owned();
582
583 let share_info = ShareInfo::new_withheld(code.to_owned());
584 let content = to_content(code);
585
586 messages
587 .entry(user_id.to_owned())
588 .or_insert_with(BTreeMap::new)
589 .insert(DeviceIdOrAllDevices::DeviceId(device_id.to_owned()), content);
590
591 share_infos
592 .entry(user_id)
593 .or_insert_with(BTreeMap::new)
594 .insert(device_id, share_info);
595 }
596
597 let txn_id = TransactionId::new();
598
599 let request = ToDeviceRequest {
600 event_type: ToDeviceEventType::from("m.room_key.withheld"),
601 txn_id,
602 messages,
603 };
604
605 (request, share_infos)
606 };
607
608 let result: Vec<_> = withheld_devices
609 .into_iter()
610 .filter(|(device, code)| !self.is_withheld_to(group_session, device, code))
611 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
612 .into_iter()
613 .map(chunk_to_request)
614 .collect();
615
616 for (request, share_info) in result {
617 if !request.messages.is_empty() {
618 let txn_id = request.txn_id.to_owned();
619 group_session.add_request(txn_id.to_owned(), request.into(), share_info);
620
621 self.sessions.mark_as_being_shared(txn_id, group_session.clone());
622 }
623 }
624
625 Ok(())
626 }
627
628 fn log_room_key_sharing_result(requests: &[Arc<ToDeviceRequest>]) {
629 for request in requests {
630 let message_list = Self::to_device_request_to_log_list(request);
631 info!(
632 request_id = ?request.txn_id,
633 ?message_list,
634 "Created batch of to-device messages of type {}",
635 request.event_type
636 );
637 }
638 }
639
640 fn to_device_request_to_log_list(
644 request: &Arc<ToDeviceRequest>,
645 ) -> Vec<(String, String, String)> {
646 #[derive(serde::Deserialize)]
647 struct ContentStub<'a> {
648 #[serde(borrow, default, rename = "org.matrix.msgid")]
649 message_id: Option<&'a str>,
650 }
651
652 let mut result: Vec<(String, String, String)> = Vec::new();
653
654 for (user_id, device_map) in &request.messages {
655 for (device, content) in device_map {
656 let message_id: Option<&str> = content
657 .deserialize_as_unchecked::<ContentStub<'_>>()
658 .expect("We should be able to deserialize the content we generated")
659 .message_id;
660
661 result.push((
662 message_id.unwrap_or("<undefined>").to_owned(),
663 user_id.to_string(),
664 device.to_string(),
665 ));
666 }
667 }
668 result
669 }
670
671 #[instrument(skip(self, users, encryption_settings), fields(session_id))]
682 pub async fn share_room_key(
683 &self,
684 room_id: &RoomId,
685 users: impl Iterator<Item = &UserId>,
686 encryption_settings: impl Into<EncryptionSettings>,
687 ) -> OlmResult<Vec<Arc<ToDeviceRequest>>> {
688 trace!("Checking if a room key needs to be shared");
689
690 let account = self.store.static_account();
691 let device = self.store.get_device(account.user_id(), account.device_id()).await?;
692
693 let encryption_settings = encryption_settings.into();
694 let mut changes = Changes::default();
695
696 let (outbound, inbound) = self
698 .get_or_create_outbound_session(
699 room_id,
700 encryption_settings.clone(),
701 SenderData::unknown(),
702 )
703 .await?;
704 tracing::Span::current().record("session_id", outbound.session_id());
705
706 if let Some(mut inbound) = inbound {
709 let own_sender_data = if let Some(device) = &device {
713 SenderDataFinder::find_using_device_data(
714 &self.store,
715 device.inner.clone(),
716 &inbound,
717 )
718 .await?
719 } else {
720 error!("Unable to find our own device!");
721 SenderData::unknown()
722 };
723 inbound.sender_data = own_sender_data;
724
725 changes.outbound_group_sessions.push(outbound.clone());
726 changes.inbound_group_sessions.push(inbound);
727 }
728
729 let CollectRecipientsResult { should_rotate, devices, mut withheld_devices } =
733 self.collect_session_recipients(users, &encryption_settings, &outbound).await?;
734
735 let outbound = self
736 .maybe_rotate_group_session(
737 should_rotate,
738 room_id,
739 outbound,
740 encryption_settings,
741 &mut changes,
742 device,
743 )
744 .await?;
745
746 let devices: Vec<_> = devices
749 .into_values()
750 .flat_map(|d| {
751 d.into_iter().filter(|d| match outbound.sharing_view().get_share_state(d) {
752 ShareState::NotShared => true,
753 ShareState::Shared { message_index: _, olm_wedging_index } => {
754 olm_wedging_index < d.olm_wedging_index
761 }
762 _ => false,
763 })
764 })
765 .collect();
766
767 let unable_to_encrypt_devices =
773 self.encrypt_for_devices(devices, &outbound, &mut changes).await?;
774
775 withheld_devices.extend(unable_to_encrypt_devices);
777
778 self.handle_withheld_devices(&outbound, withheld_devices)?;
781
782 let requests = outbound.pending_requests();
786
787 if requests.is_empty() {
788 if !outbound.shared() {
789 debug!("The room key doesn't need to be shared with anyone. Marking as shared.");
790
791 outbound.mark_as_shared();
792 changes.outbound_group_sessions.push(outbound.clone());
793 }
794 } else {
795 Self::log_room_key_sharing_result(&requests)
796 }
797
798 if !changes.is_empty() {
800 let session_count = changes.sessions.len();
801
802 self.store.save_changes(changes).await?;
803
804 trace!(
805 session_count = session_count,
806 "Stored the changed sessions after encrypting an room key"
807 );
808 }
809
810 Ok(requests)
811 }
812
813 #[instrument(skip(self, bundle_data))]
825 pub async fn share_room_key_bundle_data(
826 &self,
827 user_id: &UserId,
828 collect_strategy: &CollectStrategy,
829 bundle_data: RoomKeyBundleContent,
830 ) -> OlmResult<Vec<ToDeviceRequest>> {
831 let collect_strategy = match collect_strategy {
833 CollectStrategy::AllDevices | CollectStrategy::ErrorOnVerifiedUserProblem => {
834 warn!(
835 "Ignoring request to use unsafe sharing strategy {collect_strategy:?} \
836 for room key history sharing",
837 );
838 &CollectStrategy::IdentityBasedStrategy
839 }
840 CollectStrategy::IdentityBasedStrategy | CollectStrategy::OnlyTrustedDevices => {
841 collect_strategy
842 }
843 };
844
845 let mut changes = Changes::default();
846
847 let CollectRecipientsResult { devices, .. } =
848 share_strategy::collect_recipients_for_share_strategy(
849 &self.store,
850 iter::once(user_id),
851 collect_strategy,
852 None,
853 )
854 .await?;
855
856 let devices = devices.into_values().flatten().collect();
857 let event_type = bundle_data.event_type().to_owned();
858 let (requests, _) = self
859 .encrypt_content_for_devices(devices, &event_type, bundle_data, &mut changes)
860 .await?;
861
862 if !changes.is_empty() {
866 let session_count = changes.sessions.len();
867
868 self.store.save_changes(changes).await?;
869
870 trace!(
871 session_count = session_count,
872 "Stored the changed sessions after encrypting an room key"
873 );
874 }
875
876 Ok(requests)
877 }
878
879 pub(crate) async fn encrypt_content_for_devices(
886 &self,
887 recipient_devices: Vec<DeviceData>,
888 event_type: &str,
889 content: impl Serialize + Clone + Send + 'static,
890 changes: &mut Changes,
891 ) -> OlmResult<(Vec<ToDeviceRequest>, Vec<(DeviceData, WithheldCode)>)> {
892 let recipients = recipient_list_to_users_and_devices(&recipient_devices);
893 info!(?recipients, "Encrypting content of type {}", event_type);
894
895 let tasks: Vec<_> = recipient_devices
900 .chunks(Self::MAX_TO_DEVICE_MESSAGES)
901 .map(|chunk| {
902 spawn(
903 encrypt_content_for_devices(
904 self.store.crypto_store(),
905 event_type.to_owned(),
906 content.clone(),
907 chunk.to_vec(),
908 )
909 .in_current_span(),
910 )
911 })
912 .collect();
913
914 let mut no_olm_devices = Vec::new();
915 let mut to_device_requests = Vec::new();
916
917 for result in join_all(tasks).await {
922 let result = result.expect("Encryption task panicked")?;
923 if let Some(request) = result.to_device_request {
924 to_device_requests.push(request);
925 }
926 changes.sessions.extend(result.updated_olm_sessions);
927 no_olm_devices.extend(result.no_olm_devices);
928 }
929
930 Ok((to_device_requests, no_olm_devices))
931 }
932}
933
934async fn encrypt_content_for_devices(
943 store: Arc<CryptoStoreWrapper>,
944 event_type: String,
945 content: impl Serialize + Clone + Send + 'static,
946 devices: Vec<DeviceData>,
947) -> OlmResult<EncryptForDevicesResult> {
948 let mut result_builder = EncryptForDevicesResultBuilder::default();
949
950 async fn encrypt(
951 store: Arc<CryptoStoreWrapper>,
952 device: DeviceData,
953 event_type: String,
954 bundle_data: impl Serialize,
955 ) -> OlmResult<(Session, Raw<ToDeviceEncryptedEventContent>)> {
956 device
957 .encrypt(store.as_ref(), &event_type, bundle_data)
958 .await
959 .map(|(session, message, _message_id)| (session, message))
960 }
961
962 let tasks = devices.iter().map(|device| {
963 spawn(
964 encrypt(store.clone(), device.clone(), event_type.clone(), content.clone())
965 .in_current_span(),
966 )
967 });
968
969 let results = join_all(tasks).await;
970
971 for (device, result) in zip(devices, results) {
972 let encryption_result = result.expect("Encryption task panicked");
973
974 match encryption_result {
975 Ok((used_session, message)) => {
976 result_builder.on_successful_encryption(&device, used_session, message.cast());
977 }
978 Err(OlmError::MissingSession) => {
979 result_builder.on_missing_session(device);
981 }
982 Err(e) => return Err(e),
983 }
984 }
985
986 Ok(result_builder.into_result())
987}
988
/// Outcome of encrypting a payload for a batch of devices.
#[derive(Debug)]
struct EncryptForDevicesResult {
    /// Single to-device request carrying all successfully encrypted messages,
    /// or `None` if nothing could be encrypted.
    to_device_request: Option<ToDeviceRequest>,

    /// Devices we could not encrypt for because no Olm session was available,
    /// each paired with the withheld code to report.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// Olm sessions used during encryption; callers add them to `Changes` so
    /// they get persisted.
    updated_olm_sessions: Vec<Session>,
}
1004
/// Accumulates per-device encryption outcomes into an
/// [`EncryptForDevicesResult`].
#[derive(Debug, Default)]
struct EncryptForDevicesResultBuilder {
    /// Encrypted messages, keyed by user ID and then by device ID.
    messages: BTreeMap<OwnedUserId, BTreeMap<DeviceIdOrAllDevices, Raw<AnyToDeviceEventContent>>>,

    /// Devices that had no Olm session, with the withheld code to report.
    no_olm_devices: Vec<(DeviceData, WithheldCode)>,

    /// Olm sessions used for the successful encryptions.
    updated_olm_sessions: Vec<Session>,
}
1018
1019impl EncryptForDevicesResultBuilder {
1020 pub fn on_successful_encryption(
1024 &mut self,
1025 device: &DeviceData,
1026 used_session: Session,
1027 message: Raw<AnyToDeviceEventContent>,
1028 ) {
1029 self.updated_olm_sessions.push(used_session);
1030
1031 self.messages
1032 .entry(device.user_id().to_owned())
1033 .or_default()
1034 .insert(DeviceIdOrAllDevices::DeviceId(device.device_id().to_owned()), message);
1035 }
1036
1037 pub fn on_missing_session(&mut self, device: DeviceData) {
1039 self.no_olm_devices.push((device, WithheldCode::NoOlm));
1040 }
1041
1042 pub fn into_result(self) -> EncryptForDevicesResult {
1045 let EncryptForDevicesResultBuilder { updated_olm_sessions, no_olm_devices, messages } =
1046 self;
1047
1048 let mut encrypt_for_devices_result = EncryptForDevicesResult {
1049 to_device_request: None,
1050 updated_olm_sessions,
1051 no_olm_devices,
1052 };
1053
1054 if !messages.is_empty() {
1055 let request = ToDeviceRequest {
1056 event_type: ToDeviceEventType::RoomEncrypted,
1057 txn_id: TransactionId::new(),
1058 messages,
1059 };
1060 trace!(
1061 recipient_count = request.message_count(),
1062 transaction_id = ?request.txn_id,
1063 "Created a to-device request carrying room keys",
1064 );
1065 encrypt_for_devices_result.to_device_request = Some(request);
1066 }
1067
1068 encrypt_for_devices_result
1069 }
1070}
1071
1072fn recipient_list_to_users_and_devices(
1073 recipient_devices: &[DeviceData],
1074) -> BTreeMap<&UserId, BTreeSet<&DeviceId>> {
1075 #[allow(unknown_lints, clippy::unwrap_or_default)] recipient_devices.iter().fold(BTreeMap::new(), |mut acc, d| {
1077 acc.entry(d.user_id()).or_insert_with(BTreeSet::new).insert(d.device_id());
1078 acc
1079 })
1080}
1081
1082#[cfg(test)]
1083mod tests {
1084 use std::{
1085 collections::{BTreeMap, BTreeSet},
1086 iter,
1087 ops::Deref,
1088 sync::Arc,
1089 };
1090
1091 use assert_matches2::assert_let;
1092 use matrix_sdk_common::deserialized_responses::{ProcessedToDeviceEvent, WithheldCode};
1093 use matrix_sdk_test::{async_test, ruma_response_from_json};
1094 use ruma::{
1095 DeviceId, OneTimeKeyAlgorithm, OwnedMxcUri, TransactionId, UInt, UserId,
1096 api::client::{
1097 keys::{claim_keys, get_keys, upload_keys},
1098 to_device::send_event_to_device::v3::Response as ToDeviceResponse,
1099 },
1100 device_id,
1101 events::room::{EncryptedFile, V2EncryptedFileInfo, history_visibility::HistoryVisibility},
1102 owned_device_id, owned_room_id, room_id,
1103 to_device::DeviceIdOrAllDevices,
1104 user_id,
1105 };
1106 use serde_json::{Value, json};
1107
1108 use crate::{
1109 DecryptionSettings, EncryptionSettings, LocalTrust, OlmMachine, TrustRequirement,
1110 identities::DeviceData,
1111 machine::{
1112 EncryptionSyncChanges, test_helpers::get_machine_pair_with_setup_sessions_test_helper,
1113 },
1114 olm::{Account, SenderData},
1115 session_manager::{CollectStrategy, group_sessions::CollectRecipientsResult},
1116 types::{
1117 DeviceKeys, EventEncryptionAlgorithm,
1118 events::{
1119 room::encrypted::EncryptedToDeviceEvent,
1120 room_key_bundle::RoomKeyBundleContent,
1121 room_key_withheld::RoomKeyWithheldContent::{self, MegolmV1AesSha2},
1122 },
1123 requests::ToDeviceRequest,
1124 },
1125 };
1126
1127 fn alice_id() -> &'static UserId {
1128 user_id!("@alice:example.org")
1129 }
1130
1131 fn alice_device_id() -> &'static DeviceId {
1132 device_id!("JLAFKJWSCS")
1133 }
1134
1135 fn keys_query_response() -> get_keys::v3::Response {
1137 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_query.json");
1138 let data: Value = serde_json::from_slice(data).unwrap();
1139 ruma_response_from_json(&data)
1140 }
1141
1142 fn bob_keys_query_response() -> get_keys::v3::Response {
1143 let data = json!({
1144 "device_keys": {
1145 "@bob:localhost": {
1146 "BOBDEVICE": {
1147 "user_id": "@bob:localhost",
1148 "device_id": "BOBDEVICE",
1149 "algorithms": [
1150 "m.olm.v1.curve25519-aes-sha2",
1151 "m.megolm.v1.aes-sha2",
1152 "m.megolm.v2.aes-sha2"
1153 ],
1154 "keys": {
1155 "curve25519:BOBDEVICE": "QzXDFZj0Pt5xG4r11XGSrqE4mnFOTgRM5pz7n3tzohU",
1156 "ed25519:BOBDEVICE": "T7QMEXcEo/NfiC/8doVHT+2XnMm0pDpRa27bmE8PlPI"
1157 },
1158 "signatures": {
1159 "@bob:localhost": {
1160 "ed25519:BOBDEVICE": "1Ee9J02KoVf4DKhT+LkurpZJEygiznqpgkT4lqvMTLtZyzShsVTnwmoMPttuGcJkLp9lMK1egveNYCEaYP80Cw"
1161 }
1162 }
1163 }
1164 }
1165 }
1166 });
1167 ruma_response_from_json(&data)
1168 }
1169
1170 fn bob_one_time_key() -> claim_keys::v3::Response {
1173 let data = json!({
1174 "failures": {},
1175 "one_time_keys":{
1176 "@bob:localhost":{
1177 "BOBDEVICE":{
1178 "signed_curve25519:AAAAAAAAAAA": {
1179 "key":"bm1olfbksjC5SwKxCLLK4XaINCA0FwR/155J85gIpCk",
1180 "signatures":{
1181 "@bob:localhost":{
1182 "ed25519:BOBDEVICE":"BKyS/+EV76zdZkWgny2D0svZ0ycS3etfyHCrsDgm7MYe166HqQmSoX29HsjGLvE/5F+Sg2zW7RJileUvquPwDA"
1183 }
1184 }
1185 }
1186 }
1187 }
1188 }
1189 });
1190 ruma_response_from_json(&data)
1191 }
1192
1193 fn keys_claim_response() -> claim_keys::v3::Response {
1196 let data = include_bytes!("../../../../../benchmarks/benches/crypto_bench/keys_claim.json");
1197 let data: Value = serde_json::from_slice(data).unwrap();
1198 ruma_response_from_json(&data)
1199 }
1200
1201 async fn machine_with_user_test_helper(user_id: &UserId, device_id: &DeviceId) -> OlmMachine {
1202 let keys_query = keys_query_response();
1203 let txn_id = TransactionId::new();
1204
1205 let machine = OlmMachine::new(user_id, device_id).await;
1206
1207 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1209 let (txn_id, _keys_claim_request) = machine
1210 .get_missing_sessions(iter::once(user_id!("@example:localhost")))
1211 .await
1212 .unwrap()
1213 .unwrap();
1214 let keys_claim = keys_claim_response();
1215 machine.mark_request_as_sent(&txn_id, &keys_claim).await.unwrap();
1216
1217 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1219 let (txn_id, _keys_claim_request) = machine
1220 .get_missing_sessions(iter::once(user_id!("@bob:localhost")))
1221 .await
1222 .unwrap()
1223 .unwrap();
1224 machine.mark_request_as_sent(&txn_id, &bob_one_time_key()).await.unwrap();
1225
1226 machine
1227 }
1228
1229 async fn machine() -> OlmMachine {
1230 machine_with_user_test_helper(alice_id(), alice_device_id()).await
1231 }
1232
1233 async fn machine_with_shared_room_key_test_helper() -> OlmMachine {
1234 let machine = machine().await;
1235 let room_id = room_id!("!test:localhost");
1236 let keys_claim = keys_claim_response();
1237
1238 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1239 let requests =
1240 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1241
1242 let outbound =
1243 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1244
1245 assert!(!outbound.pending_requests().is_empty());
1246 assert!(!outbound.shared());
1247
1248 let response = ToDeviceResponse::new();
1249 for request in requests {
1250 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1251 }
1252
1253 assert!(outbound.shared());
1254 assert!(outbound.pending_requests().is_empty());
1255
1256 machine
1257 }
1258
1259 #[async_test]
1260 async fn test_sharing() {
1261 let machine = machine().await;
1262 let room_id = room_id!("!test:localhost");
1263 let keys_claim = keys_claim_response();
1264
1265 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1266
1267 let requests =
1268 machine.share_room_key(room_id, users, EncryptionSettings::default()).await.unwrap();
1269
1270 let event_count: usize = requests
1271 .iter()
1272 .filter(|r| r.event_type == "m.room.encrypted".into())
1273 .map(|r| r.message_count())
1274 .sum();
1275
1276 assert_eq!(event_count, 148);
1280
1281 let withheld_count: usize = requests
1282 .iter()
1283 .filter(|r| r.event_type == "m.room_key.withheld".into())
1284 .map(|r| r.message_count())
1285 .sum();
1286 assert_eq!(withheld_count, 2);
1287 }
1288
1289 fn count_withheld_from(requests: &[Arc<ToDeviceRequest>], code: WithheldCode) -> usize {
1290 requests
1291 .iter()
1292 .filter(|r| r.event_type == "m.room_key.withheld".into())
1293 .map(|r| {
1294 let mut count = 0;
1295 for message in r.messages.values() {
1297 message.iter().for_each(|(_, content)| {
1298 let withheld: RoomKeyWithheldContent =
1299 content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1300
1301 if let MegolmV1AesSha2(content) = withheld
1302 && content.withheld_code() == code
1303 {
1304 count += 1;
1305 }
1306 })
1307 }
1308 count
1309 })
1310 .sum()
1311 }
1312
1313 #[async_test]
1314 async fn test_no_olm_sent_once() {
1315 let machine = machine().await;
1316 let keys_claim = keys_claim_response();
1317
1318 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1319
1320 let first_room_id = room_id!("!test:localhost");
1321
1322 let requests = machine
1323 .share_room_key(first_room_id, users.to_owned(), EncryptionSettings::default())
1324 .await
1325 .unwrap();
1326
1327 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1329 assert_eq!(withheld_count, 2);
1330
1331 let new_requests = machine
1334 .share_room_key(first_room_id, users, EncryptionSettings::default())
1335 .await
1336 .unwrap();
1337 let withheld_count: usize = count_withheld_from(&new_requests, WithheldCode::NoOlm);
1338 assert_eq!(withheld_count, 2);
1340
1341 let response = ToDeviceResponse::new();
1342 for request in requests {
1343 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1344 }
1345
1346 let second_room_id = room_id!("!other:localhost");
1349 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1350 let requests = machine
1351 .share_room_key(second_room_id, users, EncryptionSettings::default())
1352 .await
1353 .unwrap();
1354
1355 let withheld_count: usize = count_withheld_from(&requests, WithheldCode::NoOlm);
1356 assert_eq!(withheld_count, 0);
1357
1358 }
1361
1362 #[async_test]
1363 async fn test_ratcheted_sharing() {
1364 let machine = machine_with_shared_room_key_test_helper().await;
1365
1366 let room_id = room_id!("!test:localhost");
1367 let late_joiner = user_id!("@bob:localhost");
1368 let keys_claim = keys_claim_response();
1369
1370 let mut users: BTreeSet<_> = keys_claim.one_time_keys.keys().map(Deref::deref).collect();
1371 users.insert(late_joiner);
1372
1373 let requests = machine
1374 .share_room_key(room_id, users.into_iter(), EncryptionSettings::default())
1375 .await
1376 .unwrap();
1377
1378 let event_count: usize = requests
1379 .iter()
1380 .filter(|r| r.event_type == "m.room.encrypted".into())
1381 .map(|r| r.message_count())
1382 .sum();
1383 let outbound =
1384 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1385
1386 assert_eq!(event_count, 1);
1387 assert!(!outbound.pending_requests().is_empty());
1388 }
1389
1390 #[async_test]
1391 async fn test_changing_encryption_settings() {
1392 let machine = machine_with_shared_room_key_test_helper().await;
1393 let room_id = room_id!("!test:localhost");
1394 let keys_claim = keys_claim_response();
1395
1396 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1397 let outbound =
1398 machine.inner.group_session_manager.get_outbound_group_session(room_id).unwrap();
1399
1400 let CollectRecipientsResult { should_rotate, .. } = machine
1401 .inner
1402 .group_session_manager
1403 .collect_session_recipients(users.clone(), &EncryptionSettings::default(), &outbound)
1404 .await
1405 .unwrap();
1406
1407 assert!(!should_rotate);
1408
1409 let settings = EncryptionSettings {
1410 history_visibility: HistoryVisibility::Invited,
1411 ..Default::default()
1412 };
1413
1414 let CollectRecipientsResult { should_rotate, .. } = machine
1415 .inner
1416 .group_session_manager
1417 .collect_session_recipients(users.clone(), &settings, &outbound)
1418 .await
1419 .unwrap();
1420
1421 assert!(should_rotate);
1422
1423 let settings = EncryptionSettings {
1424 algorithm: EventEncryptionAlgorithm::from("m.megolm.v2.aes-sha2"),
1425 ..Default::default()
1426 };
1427
1428 let CollectRecipientsResult { should_rotate, .. } = machine
1429 .inner
1430 .group_session_manager
1431 .collect_session_recipients(users, &settings, &outbound)
1432 .await
1433 .unwrap();
1434
1435 assert!(should_rotate);
1436 }
1437
1438 #[async_test]
1439 async fn test_key_recipient_collecting() {
1440 let user_id = user_id!("@example:localhost");
1443 let device_id = device_id!("TESTDEVICE");
1444 let room_id = room_id!("!test:localhost");
1445
1446 let machine = machine_with_user_test_helper(user_id, device_id).await;
1447
1448 let (outbound, _) = machine
1449 .inner
1450 .group_session_manager
1451 .get_or_create_outbound_session(
1452 room_id,
1453 EncryptionSettings::default(),
1454 SenderData::unknown(),
1455 )
1456 .await
1457 .expect("We should be able to create a new session");
1458 let history_visibility = HistoryVisibility::Joined;
1459 let settings = EncryptionSettings { history_visibility, ..Default::default() };
1460
1461 let users = [user_id].into_iter();
1462
1463 let CollectRecipientsResult { devices: recipients, .. } = machine
1464 .inner
1465 .group_session_manager
1466 .collect_session_recipients(users, &settings, &outbound)
1467 .await
1468 .expect("We should be able to collect the session recipients");
1469
1470 assert!(!recipients[user_id].is_empty());
1471
1472 assert!(
1474 !recipients[user_id]
1475 .iter()
1476 .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1477 );
1478
1479 let settings = EncryptionSettings {
1480 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1481 ..Default::default()
1482 };
1483 let users = [user_id].into_iter();
1484
1485 let CollectRecipientsResult { devices: recipients, .. } = machine
1486 .inner
1487 .group_session_manager
1488 .collect_session_recipients(users, &settings, &outbound)
1489 .await
1490 .expect("We should be able to collect the session recipients");
1491
1492 assert!(recipients[user_id].is_empty());
1493
1494 let device_id = "AFGUOBTZWM".into();
1495 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1496 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1497 let users = [user_id].into_iter();
1498
1499 let CollectRecipientsResult { devices: recipients, withheld_devices: withheld, .. } =
1500 machine
1501 .inner
1502 .group_session_manager
1503 .collect_session_recipients(users, &settings, &outbound)
1504 .await
1505 .expect("We should be able to collect the session recipients");
1506
1507 assert!(
1508 recipients[user_id]
1509 .iter()
1510 .any(|d| d.user_id() == user_id && d.device_id() == device_id)
1511 );
1512
1513 let devices = machine.get_user_devices(user_id, None).await.unwrap();
1514 devices
1515 .devices()
1516 .filter(|d| d.device_id() != device_id!("TESTDEVICE"))
1518 .for_each(|d| {
1519 if d.is_blacklisted() {
1520 assert!(withheld.iter().any(|(dev, w)| {
1521 dev.device_id() == d.device_id() && w == &WithheldCode::Blacklisted
1522 }));
1523 } else if !d.is_verified() {
1524 assert!(withheld.iter().any(|(dev, w)| {
1526 dev.device_id() == d.device_id() && w == &WithheldCode::Unverified
1527 }));
1528 }
1529 });
1530
1531 assert_eq!(149, withheld.len());
1532 }
1533
1534 #[async_test]
1535 async fn test_sharing_withheld_only_trusted() {
1536 let machine = machine().await;
1537 let room_id = room_id!("!test:localhost");
1538 let keys_claim = keys_claim_response();
1539
1540 let users = keys_claim.one_time_keys.keys().map(Deref::deref);
1541 let settings = EncryptionSettings {
1542 sharing_strategy: CollectStrategy::OnlyTrustedDevices,
1543 ..Default::default()
1544 };
1545
1546 let user_id = user_id!("@example:localhost");
1548 let device_id = "MWFXPINOAO".into();
1549 let device = machine.get_device(user_id, device_id, None).await.unwrap().unwrap();
1550 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1551 machine
1552 .get_device(user_id, "MWVTUXDNNM".into(), None)
1553 .await
1554 .unwrap()
1555 .unwrap()
1556 .set_local_trust(LocalTrust::BlackListed)
1557 .await
1558 .unwrap();
1559
1560 let requests = machine.share_room_key(room_id, users, settings).await.unwrap();
1561
1562 let room_key_count =
1564 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).count();
1565
1566 assert_eq!(1, room_key_count);
1567
1568 let withheld_count =
1569 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1570 assert_eq!(1, withheld_count);
1572
1573 let event_count: usize = requests
1574 .iter()
1575 .filter(|r| r.event_type == "m.room_key.withheld".into())
1576 .map(|r| r.message_count())
1577 .sum();
1578
1579 assert_eq!(event_count, 149);
1581
1582 let has_blacklist =
1584 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).any(|r| {
1585 let device_key = DeviceIdOrAllDevices::from(owned_device_id!("MWVTUXDNNM"));
1586 let content = &r.messages[user_id][&device_key];
1587 let withheld: RoomKeyWithheldContent =
1588 content.deserialize_as_unchecked::<RoomKeyWithheldContent>().unwrap();
1589 if let MegolmV1AesSha2(content) = withheld {
1590 content.withheld_code() == WithheldCode::Blacklisted
1591 } else {
1592 false
1593 }
1594 });
1595
1596 assert!(has_blacklist);
1597 }
1598
1599 #[async_test]
1600 async fn test_no_olm_withheld_only_sent_once() {
1601 let keys_query = keys_query_response();
1602 let txn_id = TransactionId::new();
1603
1604 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1605
1606 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1607 machine.mark_request_as_sent(&txn_id, &bob_keys_query_response()).await.unwrap();
1608
1609 let first_room = room_id!("!test:localhost");
1610 let second_room = room_id!("!test2:localhost");
1611 let bob_id = user_id!("@bob:localhost");
1612
1613 let settings = EncryptionSettings::default();
1614 let users = [bob_id];
1615
1616 let requests = machine
1617 .share_room_key(first_room, users.into_iter(), settings.to_owned())
1618 .await
1619 .unwrap();
1620
1621 let withheld_count =
1623 requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1624
1625 assert_eq!(withheld_count, 1);
1626 assert_eq!(requests.len(), 1);
1627
1628 let second_requests =
1631 machine.share_room_key(second_room, users.into_iter(), settings).await.unwrap();
1632
1633 let withheld_count =
1634 second_requests.iter().filter(|r| r.event_type == "m.room_key.withheld".into()).count();
1635
1636 assert_eq!(withheld_count, 0);
1637 assert_eq!(second_requests.len(), 0);
1638
1639 let response = ToDeviceResponse::new();
1640
1641 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1642
1643 assert!(!device.was_withheld_code_sent());
1646
1647 for request in requests {
1648 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1649 }
1650
1651 let device = machine.get_device(bob_id, "BOBDEVICE".into(), None).await.unwrap().unwrap();
1652
1653 assert!(device.was_withheld_code_sent());
1654 }
1655
1656 #[async_test]
1657 async fn test_resend_session_after_unwedging() {
1658 let machine = OlmMachine::new(alice_id(), alice_device_id()).await;
1659 assert_let!(Ok(Some((txn_id, device_keys_request))) = machine.upload_device_keys().await);
1660 let device_keys_response = upload_keys::v3::Response::new(BTreeMap::from([(
1661 OneTimeKeyAlgorithm::SignedCurve25519,
1662 UInt::new(device_keys_request.one_time_keys.len() as u64).unwrap(),
1663 )]));
1664 machine.mark_request_as_sent(&txn_id, &device_keys_response).await.unwrap();
1665
1666 let room_id = room_id!("!test:localhost");
1667
1668 let bob_id = user_id!("@bob:localhost");
1669 let bob_account = Account::new(bob_id);
1670 let keys_query_data = json!({
1671 "device_keys": {
1672 "@bob:localhost": {
1673 bob_account.device_id.clone(): bob_account.device_keys()
1674 }
1675 }
1676 });
1677 let keys_query: get_keys::v3::Response = ruma_response_from_json(&keys_query_data);
1678 let txn_id = TransactionId::new();
1679 machine.mark_request_as_sent(&txn_id, &keys_query).await.unwrap();
1680
1681 let alice_device_keys =
1682 device_keys_request.device_keys.unwrap().deserialize_as::<DeviceKeys>().unwrap();
1683 let mut alice_otks = device_keys_request.one_time_keys.iter();
1684 let alice_device = DeviceData::new(alice_device_keys, LocalTrust::Unset);
1685
1686 {
1687 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1689 let mut session = bob_account
1690 .create_outbound_session(
1691 &alice_device,
1692 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1693 bob_account.device_keys(),
1694 )
1695 .unwrap();
1696 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1697
1698 let to_device =
1699 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1700
1701 let sync_changes = EncryptionSyncChanges {
1703 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1704 changed_devices: &Default::default(),
1705 one_time_keys_counts: &Default::default(),
1706 unused_fallback_keys: None,
1707 next_batch_token: None,
1708 };
1709
1710 let decryption_settings =
1711 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1712
1713 let (decrypted, _) =
1714 machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1715
1716 assert_eq!(1, decrypted.len());
1717 }
1718
1719 {
1721 let requests = machine
1722 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1723 .await
1724 .unwrap();
1725
1726 let event_count: usize = requests
1728 .iter()
1729 .filter(|r| r.event_type == "m.room.encrypted".into())
1730 .map(|r| r.message_count())
1731 .sum();
1732 assert_eq!(event_count, 1);
1733
1734 let response = ToDeviceResponse::new();
1735 for request in requests {
1736 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1737 }
1738 }
1739
1740 {
1743 let requests = machine
1744 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1745 .await
1746 .unwrap();
1747
1748 let event_count: usize = requests
1749 .iter()
1750 .filter(|r| r.event_type == "m.room.encrypted".into())
1751 .map(|r| r.message_count())
1752 .sum();
1753 assert_eq!(event_count, 0);
1754 }
1755
1756 {
1758 let (alice_otk_id, alice_otk) = alice_otks.next().unwrap();
1759 let mut session = bob_account
1760 .create_outbound_session(
1761 &alice_device,
1762 &BTreeMap::from([(alice_otk_id.clone(), alice_otk.clone())]),
1763 bob_account.device_keys(),
1764 )
1765 .unwrap();
1766 let content = session.encrypt(&alice_device, "m.dummy", json!({}), None).await.unwrap();
1767
1768 let to_device =
1769 EncryptedToDeviceEvent::new(bob_id.to_owned(), content.deserialize().unwrap());
1770
1771 let sync_changes = EncryptionSyncChanges {
1773 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1774 changed_devices: &Default::default(),
1775 one_time_keys_counts: &Default::default(),
1776 unused_fallback_keys: None,
1777 next_batch_token: None,
1778 };
1779
1780 let decryption_settings =
1781 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1782
1783 let (decrypted, _) =
1784 machine.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1785
1786 assert_eq!(1, decrypted.len());
1787 }
1788
1789 {
1791 let requests = machine
1792 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1793 .await
1794 .unwrap();
1795
1796 let event_count: usize = requests
1797 .iter()
1798 .filter(|r| r.event_type == "m.room.encrypted".into())
1799 .map(|r| r.message_count())
1800 .sum();
1801 assert_eq!(event_count, 1);
1802
1803 let response = ToDeviceResponse::new();
1804 for request in requests {
1805 machine.mark_request_as_sent(&request.txn_id, &response).await.unwrap();
1806 }
1807 }
1808
1809 {
1812 let requests = machine
1813 .share_room_key(room_id, [bob_id].into_iter(), EncryptionSettings::default())
1814 .await
1815 .unwrap();
1816
1817 let event_count: usize = requests
1818 .iter()
1819 .filter(|r| r.event_type == "m.room.encrypted".into())
1820 .map(|r| r.message_count())
1821 .sum();
1822 assert_eq!(event_count, 0);
1823 }
1824 }
1825
1826 #[async_test]
1827 async fn test_room_key_bundle_sharing() {
1828 let (alice, bob) = get_machine_pair_with_setup_sessions_test_helper(
1829 user_id!("@alice:localhost"),
1830 user_id!("@bob:localhost"),
1831 false,
1832 )
1833 .await;
1834
1835 let device = alice.get_device(bob.user_id(), bob.device_id(), None).await.unwrap().unwrap();
1837 device.set_local_trust(LocalTrust::Verified).await.unwrap();
1838
1839 let content = RoomKeyBundleContent {
1840 room_id: owned_room_id!("!room:id"),
1841 file: EncryptedFile::new(
1842 OwnedMxcUri::from("test"),
1843 V2EncryptedFileInfo::encode([0; 32], [0; 16]).into(),
1844 Default::default(),
1845 ),
1846 };
1847
1848 let requests = alice
1849 .share_room_key_bundle_data(
1850 bob.user_id(),
1851 &CollectStrategy::OnlyTrustedDevices,
1852 content,
1853 )
1854 .await
1855 .unwrap();
1856
1857 let requests: Vec<_> =
1859 requests.iter().filter(|r| r.event_type == "m.room.encrypted".into()).collect();
1860 let message_count: usize = requests.iter().map(|r| r.message_count()).sum();
1861 assert_eq!(message_count, 1);
1862
1863 let bob_message = requests[0]
1865 .messages
1866 .get(bob.user_id())
1867 .unwrap()
1868 .get(&(bob.device_id().to_owned().into()))
1869 .unwrap();
1870 let to_device = EncryptedToDeviceEvent::new(
1871 alice.user_id().to_owned(),
1872 bob_message.deserialize_as_unchecked().unwrap(),
1873 );
1874
1875 let sync_changes = EncryptionSyncChanges {
1876 to_device_events: vec![crate::utilities::json_convert(&to_device).unwrap()],
1877 changed_devices: &Default::default(),
1878 one_time_keys_counts: &Default::default(),
1879 unused_fallback_keys: None,
1880 next_batch_token: None,
1881 };
1882
1883 let decryption_settings =
1884 DecryptionSettings { sender_device_trust_requirement: TrustRequirement::Untrusted };
1885
1886 let (decrypted, _) =
1887 bob.receive_sync_changes(sync_changes, &decryption_settings).await.unwrap();
1888 assert_eq!(1, decrypted.len());
1889 use crate::types::events::EventType;
1890 assert_let!(
1891 ProcessedToDeviceEvent::Decrypted { raw, .. } = decrypted.first().unwrap().clone()
1892 );
1893 assert_eq!(
1894 raw.get_field::<String>("type").unwrap().unwrap(),
1895 RoomKeyBundleContent::EVENT_TYPE,
1896 );
1897 }
1898}