1use std::{
42 collections::{BTreeMap, BTreeSet, HashMap, HashSet},
43 fmt::Debug,
44 ops::Deref,
45 pin::pin,
46 sync::{atomic::Ordering, Arc},
47 time::Duration,
48};
49
50use as_variant::as_variant;
51use futures_core::Stream;
52use futures_util::StreamExt;
53use itertools::{Either, Itertools};
54use matrix_sdk_common::locks::RwLock as StdRwLock;
55use ruma::{
56 encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
57 OwnedRoomId, OwnedUserId, RoomId, UserId,
58};
59use serde::{de::DeserializeOwned, Deserialize, Serialize};
60use thiserror::Error;
61use tokio::sync::{Mutex, MutexGuard, Notify, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
62use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
63use tracing::{error, info, instrument, trace, warn};
64use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
65use zeroize::{Zeroize, ZeroizeOnDrop};
66
67#[cfg(doc)]
68use crate::{backups::BackupMachine, identities::OwnUserIdentity};
69use crate::{
70 gossiping::GossippedSecret,
71 identities::{user::UserIdentity, Device, DeviceData, UserDevices, UserIdentityData},
72 olm::{
73 Account, ExportedRoomKey, InboundGroupSession, OlmMessageHash, OutboundGroupSession,
74 PrivateCrossSigningIdentity, SenderData, Session, StaticAccountData,
75 },
76 types::{
77 events::room_key_withheld::RoomKeyWithheldEvent, BackupSecrets, CrossSigningSecrets,
78 EventEncryptionAlgorithm, MegolmBackupV1Curve25519AesSha2Secrets, RoomKeyExport,
79 SecretsBundle,
80 },
81 verification::VerificationMachine,
82 CrossSigningStatus, OwnUserIdentityData, RoomKeyImportResult,
83};
84
85pub mod caches;
86mod crypto_store_wrapper;
87mod error;
88mod memorystore;
89mod traits;
90
91#[cfg(any(test, feature = "testing"))]
92#[macro_use]
93#[allow(missing_docs)]
94pub mod integration_tests;
95
96use caches::{SequenceNumber, UsersForKeyQuery};
97pub(crate) use crypto_store_wrapper::CryptoStoreWrapper;
98pub use error::{CryptoStoreError, Result};
99use matrix_sdk_common::{
100 deserialized_responses::WithheldCode, store_locks::CrossProcessStoreLock, timeout::timeout,
101};
102pub use memorystore::MemoryStore;
103pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
104
105use crate::types::{
106 events::{room_key_bundle::RoomKeyBundleContent, room_key_withheld::RoomKeyWithheldContent},
107 room_history::RoomKeyBundle,
108};
109pub use crate::{
110 dehydrated_devices::DehydrationError,
111 gossiping::{GossipRequest, SecretInfo},
112};
113
/// Convenience wrapper around the crypto store.
///
/// Cloning a `Store` is cheap: it only clones the `Arc` around the shared
/// inner state, so every clone refers to the same underlying store, cache and
/// private identity.
#[derive(Debug, Clone)]
pub struct Store {
    inner: Arc<StoreInner>,
}
124
/// Tracks which users need a `/keys/query` and lets tasks wait for a pending
/// query to finish.
#[derive(Debug, Default)]
pub(crate) struct KeyQueryManager {
    /// Users whose device lists should be (re)queried, plus the tasks that
    /// registered to wait for those queries (see
    /// `maybe_register_waiting_task`).
    users_for_key_query: Mutex<UsersForKeyQuery>,

    /// Signalled with `notify_waiters` after users are marked as up to date,
    /// so tasks blocked in `wait_if_user_key_query_pending` re-check whether
    /// their query has completed.
    users_for_key_query_notify: Notify,
}
133
impl KeyQueryManager {
    /// Make sure the tracked users are loaded into `cache`, then return a
    /// [`SyncedKeyQueryManager`] bound to that cache.
    ///
    /// # Errors
    ///
    /// Propagates store errors raised while loading the tracked users.
    pub async fn synced<'a>(&'a self, cache: &'a StoreCache) -> Result<SyncedKeyQueryManager<'a>> {
        self.ensure_sync_tracked_users(cache).await?;
        Ok(SyncedKeyQueryManager { cache, manager: self })
    }

    /// Load the tracked users from the store into the cache, at most once.
    ///
    /// Every loaded user is added to `cache.tracked_users`. Users flagged as
    /// `dirty` are also queued for a `/keys/query`. If loading fails, the
    /// "loaded" flag stays `false`, so a later call retries.
    async fn ensure_sync_tracked_users(&self, cache: &StoreCache) -> Result<()> {
        // Fast path: a read lock is enough to see that loading already happened.
        let loaded = cache.loaded_tracked_users.read().await;
        if *loaded {
            return Ok(());
        }

        // Release the read guard first; the write lock below cannot be acquired
        // while this task still holds a read guard on the same lock.
        drop(loaded);
        let mut loaded = cache.loaded_tracked_users.write().await;

        // Check again: another caller may have finished loading between the
        // read guard being dropped and the write lock being acquired.
        if *loaded {
            return Ok(());
        }

        let tracked_users = cache.store.load_tracked_users().await?;

        let mut query_users_lock = self.users_for_key_query.lock().await;
        let mut tracked_users_cache = cache.tracked_users.write();
        for user in tracked_users {
            tracked_users_cache.insert(user.user_id.to_owned());

            if user.dirty {
                query_users_lock.insert_user(&user.user_id);
            }
        }

        *loaded = true;

        Ok(())
    }

    /// Wait, for at most `timeout_duration`, for a pending `/keys/query`
    /// covering `user` to complete.
    ///
    /// `cache` is only used to make sure the tracked users are loaded. It is
    /// then dropped so the cache read lock is not held while waiting.
    ///
    /// Returns:
    /// - [`UserKeyQueryResult::WasNotPending`] if no query was pending for
    ///   `user`;
    /// - [`UserKeyQueryResult::WasPending`] if one was pending and completed in
    ///   time;
    /// - [`UserKeyQueryResult::TimeoutExpired`] if the timeout elapsed first.
    ///
    /// # Errors
    ///
    /// Propagates store errors from loading the tracked users.
    pub async fn wait_if_user_key_query_pending(
        &self,
        cache: StoreCacheGuard,
        timeout_duration: Duration,
        user: &UserId,
    ) -> Result<UserKeyQueryResult> {
        {
            self.ensure_sync_tracked_users(&cache).await?;
            // Release the cache read guard before potentially waiting a long time.
            drop(cache);
        }

        let mut users_for_key_query = self.users_for_key_query.lock().await;
        let Some(waiter) = users_for_key_query.maybe_register_waiting_task(user) else {
            return Ok(UserKeyQueryResult::WasNotPending);
        };

        let wait_for_completion = async {
            // The flag is always checked while `users_for_key_query` is held
            // (initially, and after re-locking at the end of each iteration).
            // NOTE(review): `Relaxed` presumably relies on the writer setting
            // `completed` under the same mutex — confirm in `caches`.
            while !waiter.completed.load(Ordering::Relaxed) {
                // Register for the notification *before* releasing the mutex,
                // so a `notify_waiters` between the check above and the
                // `.await` below cannot be missed.
                let mut notified = pin!(self.users_for_key_query_notify.notified());
                notified.as_mut().enable();
                drop(users_for_key_query);

                notified.await;

                // Re-acquire the mutex before re-checking the flag.
                users_for_key_query = self.users_for_key_query.lock().await;
            }
        };

        match timeout(Box::pin(wait_for_completion), timeout_duration).await {
            Err(_) => {
                warn!(
                    user_id = ?user,
                    "The user has a pending `/keys/query` request which did \
                    not finish yet, some devices might be missing."
                );

                Ok(UserKeyQueryResult::TimeoutExpired)
            }
            _ => Ok(UserKeyQueryResult::WasPending),
        }
    }
}
241
/// A [`KeyQueryManager`] paired with a [`StoreCache`] whose tracked users are
/// known to have been loaded from the store (see `KeyQueryManager::synced`).
pub(crate) struct SyncedKeyQueryManager<'a> {
    cache: &'a StoreCache,
    manager: &'a KeyQueryManager,
}
246
247impl SyncedKeyQueryManager<'_> {
248 pub async fn update_tracked_users(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
253 let mut store_updates = Vec::new();
254 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
255
256 {
257 let mut tracked_users = self.cache.tracked_users.write();
258 for user_id in users {
259 if tracked_users.insert(user_id.to_owned()) {
260 key_query_lock.insert_user(user_id);
261 store_updates.push((user_id, true))
262 }
263 }
264 }
265
266 self.cache.store.save_tracked_users(&store_updates).await
267 }
268
269 pub async fn mark_tracked_users_as_changed(
276 &self,
277 users: impl Iterator<Item = &UserId>,
278 ) -> Result<()> {
279 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
280 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
281
282 {
283 let tracked_users = &self.cache.tracked_users.read();
284 for user_id in users {
285 if tracked_users.contains(user_id) {
286 key_query_lock.insert_user(user_id);
287 store_updates.push((user_id, true));
288 }
289 }
290 }
291
292 self.cache.store.save_tracked_users(&store_updates).await
293 }
294
295 pub async fn mark_tracked_users_as_up_to_date(
301 &self,
302 users: impl Iterator<Item = &UserId>,
303 sequence_number: SequenceNumber,
304 ) -> Result<()> {
305 let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
306 let mut key_query_lock = self.manager.users_for_key_query.lock().await;
307
308 {
309 let tracked_users = self.cache.tracked_users.read();
310 for user_id in users {
311 if tracked_users.contains(user_id) {
312 let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
313 store_updates.push((user_id, !clean));
314 }
315 }
316 }
317
318 self.cache.store.save_tracked_users(&store_updates).await?;
319 self.manager.users_for_key_query_notify.notify_waiters();
321
322 Ok(())
323 }
324
325 pub async fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
337 self.manager.users_for_key_query.lock().await.users_for_key_query()
338 }
339
340 pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
342 self.cache.tracked_users.read().iter().cloned().collect()
343 }
344
345 pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
351 self.manager.users_for_key_query.lock().await.insert_user(user);
352 self.cache.tracked_users.write().insert(user.to_owned());
353
354 self.cache.store.save_tracked_users(&[(user, true)]).await
355 }
356}
357
/// In-memory cache shared by all `Store` handles. It is guarded by the
/// `RwLock` in `StoreInner::cache`.
#[derive(Debug)]
pub(crate) struct StoreCache {
    /// The underlying store that cached values are loaded from.
    store: Arc<CryptoStoreWrapper>,
    /// Users whose devices are tracked; filled lazily from the store.
    tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
    /// Whether `tracked_users` has been loaded from the store yet.
    loaded_tracked_users: RwLock<bool>,
    /// The account, loaded from the store on first access. It is moved out
    /// while a `StoreTransaction` edits it.
    account: Mutex<Option<Account>>,
}
365
366impl StoreCache {
367 pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
368 self.store.as_ref()
369 }
370
371 async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
383 let mut guard = self.account.lock().await;
384 if guard.is_some() {
385 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
386 } else {
387 match self.store.load_account().await? {
388 Some(account) => {
389 *guard = Some(account);
390 Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
391 }
392 None => Err(CryptoStoreError::AccountUnset),
393 }
394 }
395 }
396}
397
/// Read access to the [`StoreCache`], as returned by `Store::cache`.
///
/// While this guard is alive, a `StoreTransaction`, which needs the write
/// lock, cannot start.
pub(crate) struct StoreCacheGuard {
    cache: OwnedRwLockReadGuard<StoreCache>,
}
407
408impl StoreCacheGuard {
409 pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
417 self.cache.account().await
418 }
419}
420
421impl Deref for StoreCacheGuard {
422 type Target = StoreCache;
423
424 fn deref(&self) -> &Self::Target {
425 &self.cache
426 }
427}
428
/// A transaction on the store.
///
/// Currently only the account can be modified through it (see
/// `PendingChanges`). The transaction holds the write lock on the store cache
/// for its whole lifetime, so cache readers and other transactions wait until
/// it is committed or dropped. Changes are kept in memory and only persisted
/// by `StoreTransaction::commit`.
#[allow(missing_debug_implementations)]
pub struct StoreTransaction {
    store: Store,
    changes: PendingChanges,
    cache: OwnedRwLockWriteGuard<StoreCache>,
}
437
438impl StoreTransaction {
439 async fn new(store: Store) -> Self {
441 let cache = store.inner.cache.clone();
442
443 Self { store, changes: PendingChanges::default(), cache: cache.clone().write_owned().await }
444 }
445
446 pub(crate) fn cache(&self) -> &StoreCache {
447 &self.cache
448 }
449
450 pub fn store(&self) -> &Store {
452 &self.store
453 }
454
455 pub async fn account(&mut self) -> Result<&mut Account> {
462 if self.changes.account.is_none() {
463 let _ = self.cache.account().await?;
465 self.changes.account = self.cache.account.lock().await.take();
466 }
467 Ok(self.changes.account.as_mut().unwrap())
468 }
469
470 pub async fn commit(self) -> Result<()> {
473 if self.changes.is_empty() {
474 return Ok(());
475 }
476
477 let account = self.changes.account.as_ref().map(|acc| acc.deep_clone());
479
480 self.store.save_pending_changes(self.changes).await?;
481
482 if let Some(account) = account {
484 *self.cache.account.lock().await = Some(account);
485 }
486
487 Ok(())
488 }
489}
490
/// State shared by every clone of a [`Store`].
#[derive(Debug)]
struct StoreInner {
    /// Our private cross-signing identity.
    identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
    /// The underlying crypto store.
    store: Arc<CryptoStoreWrapper>,

    /// In-memory cache. Read access goes through `Store::cache` and write
    /// access through `Store::transaction`; keep it private.
    cache: Arc<RwLock<StoreCache>>,

    verification_machine: VerificationMachine,

    /// Static account data; `Store::user_id` reads our user ID from here.
    static_account: StaticAccountData,
}
507
/// Changes collected by a `StoreTransaction` and saved when it is committed.
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct PendingChanges {
    pub account: Option<Account>,
}
518
519impl PendingChanges {
520 pub fn is_empty(&self) -> bool {
522 self.account.is_none()
523 }
524}
525
/// A batch of changes to persist with a single `Store::save_changes` call.
///
/// Every field defaults to empty, so callers typically set only what they
/// need and fill in the rest with `..Default::default()`.
#[derive(Default, Debug)]
#[allow(missing_docs)]
pub struct Changes {
    pub private_identity: Option<PrivateCrossSigningIdentity>,
    pub backup_version: Option<String>,
    pub backup_decryption_key: Option<BackupDecryptionKey>,
    pub dehydrated_device_pickle_key: Option<DehydratedDeviceKey>,
    pub sessions: Vec<Session>,
    pub message_hashes: Vec<OlmMessageHash>,
    pub inbound_group_sessions: Vec<InboundGroupSession>,
    pub outbound_group_sessions: Vec<OutboundGroupSession>,
    pub key_requests: Vec<GossipRequest>,
    pub identities: IdentityChanges,
    pub devices: DeviceChanges,
    // Keyed by room ID; the inner key is presumably the session ID — confirm.
    pub withheld_session_info: BTreeMap<OwnedRoomId, BTreeMap<String, RoomKeyWithheldEvent>>,
    pub room_settings: HashMap<OwnedRoomId, RoomSettings>,
    pub secrets: Vec<GossippedSecret>,
    pub next_batch_token: Option<String>,

    pub received_room_key_bundles: Vec<StoredRoomKeyBundleData>,
}
552
/// A received room key bundle, stored together with data about its sender.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StoredRoomKeyBundleData {
    /// The user that sent the bundle.
    pub sender_user: OwnedUserId,

    /// Sender information recorded for the bundle.
    pub sender_data: SenderData,

    /// The content of the received bundle.
    pub bundle_data: RoomKeyBundleContent,
}
568
/// A user whose devices are tracked, as persisted in the store.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrackedUser {
    /// The tracked user.
    pub user_id: OwnedUserId,
    /// Whether the user's device list needs to be (re)queried. Dirty users
    /// are queued for a `/keys/query` when tracked users are loaded.
    pub dirty: bool,
}
580
581impl Changes {
582 pub fn is_empty(&self) -> bool {
584 self.private_identity.is_none()
585 && self.backup_version.is_none()
586 && self.backup_decryption_key.is_none()
587 && self.dehydrated_device_pickle_key.is_none()
588 && self.sessions.is_empty()
589 && self.message_hashes.is_empty()
590 && self.inbound_group_sessions.is_empty()
591 && self.outbound_group_sessions.is_empty()
592 && self.key_requests.is_empty()
593 && self.identities.is_empty()
594 && self.devices.is_empty()
595 && self.withheld_session_info.is_empty()
596 && self.room_settings.is_empty()
597 && self.secrets.is_empty()
598 && self.next_batch_token.is_none()
599 && self.received_room_key_bundles.is_empty()
600 }
601}
602
/// User identities to save, grouped by kind of change.
///
/// Only `new` and `changed` make the set non-empty. `unchanged` identities
/// are still used to look up device owners in `collect_device_updates`.
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct IdentityChanges {
    pub new: Vec<UserIdentityData>,
    pub changed: Vec<UserIdentityData>,
    pub unchanged: Vec<UserIdentityData>,
}
620
621impl IdentityChanges {
622 fn is_empty(&self) -> bool {
623 self.new.is_empty() && self.changed.is_empty()
624 }
625
626 fn into_maps(
629 self,
630 ) -> (
631 BTreeMap<OwnedUserId, UserIdentityData>,
632 BTreeMap<OwnedUserId, UserIdentityData>,
633 BTreeMap<OwnedUserId, UserIdentityData>,
634 ) {
635 let new: BTreeMap<_, _> = self
636 .new
637 .into_iter()
638 .map(|identity| (identity.user_id().to_owned(), identity))
639 .collect();
640
641 let changed: BTreeMap<_, _> = self
642 .changed
643 .into_iter()
644 .map(|identity| (identity.user_id().to_owned(), identity))
645 .collect();
646
647 let unchanged: BTreeMap<_, _> = self
648 .unchanged
649 .into_iter()
650 .map(|identity| (identity.user_id().to_owned(), identity))
651 .collect();
652
653 (new, changed, unchanged)
654 }
655}
656
/// Devices to save, grouped by kind of change.
#[derive(Debug, Clone, Default)]
#[allow(missing_docs)]
pub struct DeviceChanges {
    pub new: Vec<DeviceData>,
    pub changed: Vec<DeviceData>,
    pub deleted: Vec<DeviceData>,
}
664
665fn collect_device_updates(
671 verification_machine: VerificationMachine,
672 own_identity: Option<OwnUserIdentityData>,
673 identities: IdentityChanges,
674 devices: DeviceChanges,
675) -> DeviceUpdates {
676 let mut new: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
677 let mut changed: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
678
679 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
680
681 let map_device = |device: DeviceData| {
682 let device_owner_identity = new_identities
683 .get(device.user_id())
684 .or_else(|| changed_identities.get(device.user_id()))
685 .or_else(|| unchanged_identities.get(device.user_id()))
686 .cloned();
687
688 Device {
689 inner: device,
690 verification_machine: verification_machine.to_owned(),
691 own_identity: own_identity.to_owned(),
692 device_owner_identity,
693 }
694 };
695
696 for device in devices.new {
697 let device = map_device(device);
698
699 new.entry(device.user_id().to_owned())
700 .or_default()
701 .insert(device.device_id().to_owned(), device);
702 }
703
704 for device in devices.changed {
705 let device = map_device(device);
706
707 changed
708 .entry(device.user_id().to_owned())
709 .or_default()
710 .insert(device.device_id().to_owned(), device.to_owned());
711 }
712
713 DeviceUpdates { new, changed }
714}
715
/// Devices that were added or changed, keyed by user ID and then device ID.
#[derive(Clone, Debug, Default)]
pub struct DeviceUpdates {
    /// Newly discovered devices.
    pub new: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
    /// Devices whose data changed.
    pub changed: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
}
729
/// User identities that were added, changed or left unchanged, keyed by user
/// ID.
#[derive(Clone, Debug, Default)]
pub struct IdentityUpdates {
    /// Newly discovered identities.
    pub new: BTreeMap<OwnedUserId, UserIdentity>,
    /// Identities that changed.
    pub changed: BTreeMap<OwnedUserId, UserIdentity>,
    /// Identities that did not change.
    pub unchanged: BTreeMap<OwnedUserId, UserIdentity>,
}
745
/// A backup decryption key: `KEY_SIZE` bytes, kept on the heap.
///
/// The bytes are zeroized on drop, serialized transparently as the inner
/// array, and redacted from `Debug` output. `Store::export_secret` exports the
/// key, base64-encoded, as the `RecoveryKey` secret.
#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
#[serde(transparent)]
pub struct BackupDecryptionKey {
    pub(crate) inner: Box<[u8; BackupDecryptionKey::KEY_SIZE]>,
}
760
761impl BackupDecryptionKey {
762 pub const KEY_SIZE: usize = 32;
764
765 pub fn new() -> Result<Self, rand::Error> {
767 let mut rng = rand::thread_rng();
768
769 let mut key = Box::new([0u8; Self::KEY_SIZE]);
770 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
771
772 Ok(Self { inner: key })
773 }
774
775 pub fn to_base64(&self) -> String {
777 base64_encode(self.inner.as_slice())
778 }
779}
780
781#[cfg(not(tarpaulin_include))]
782impl Debug for BackupDecryptionKey {
783 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
784 f.debug_tuple("BackupDecryptionKey").field(&"...").finish()
785 }
786}
787
/// The pickle key for a dehydrated device (see
/// `Changes::dehydrated_device_pickle_key`): `KEY_SIZE` bytes on the heap.
///
/// The bytes are zeroized on drop, serialized transparently as the inner
/// array, and redacted from `Debug` output.
#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
#[serde(transparent)]
pub struct DehydratedDeviceKey {
    pub(crate) inner: Box<[u8; DehydratedDeviceKey::KEY_SIZE]>,
}
797
798impl DehydratedDeviceKey {
799 pub const KEY_SIZE: usize = 32;
801
802 pub fn new() -> Result<Self, rand::Error> {
804 let mut rng = rand::thread_rng();
805
806 let mut key = Box::new([0u8; Self::KEY_SIZE]);
807 rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
808
809 Ok(Self { inner: key })
810 }
811
812 pub fn from_slice(slice: &[u8]) -> Result<Self, DehydrationError> {
816 if slice.len() == 32 {
817 let mut key = Box::new([0u8; 32]);
818 key.copy_from_slice(slice);
819 Ok(DehydratedDeviceKey { inner: key })
820 } else {
821 Err(DehydrationError::PickleKeyLength(slice.len()))
822 }
823 }
824
825 pub fn from_bytes(raw_key: &[u8; 32]) -> Self {
827 let mut inner = Box::new([0u8; Self::KEY_SIZE]);
828 inner.copy_from_slice(raw_key);
829
830 Self { inner }
831 }
832
833 pub fn to_base64(&self) -> String {
835 base64_encode(self.inner.as_slice())
836 }
837}
838
839impl From<&[u8; 32]> for DehydratedDeviceKey {
840 fn from(value: &[u8; 32]) -> Self {
841 DehydratedDeviceKey { inner: Box::new(*value) }
842 }
843}
844
845impl From<DehydratedDeviceKey> for Vec<u8> {
846 fn from(key: DehydratedDeviceKey) -> Self {
847 key.inner.to_vec()
848 }
849}
850
851#[cfg(not(tarpaulin_include))]
852impl Debug for DehydratedDeviceKey {
853 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
854 f.debug_tuple("DehydratedDeviceKey").field(&"...").finish()
855 }
856}
857
858impl DeviceChanges {
859 pub fn extend(&mut self, other: DeviceChanges) {
861 self.new.extend(other.new);
862 self.changed.extend(other.changed);
863 self.deleted.extend(other.deleted);
864 }
865
866 fn is_empty(&self) -> bool {
867 self.new.is_empty() && self.changed.is_empty() && self.deleted.is_empty()
868 }
869}
870
/// Counts of room keys in the store.
#[derive(Debug, Clone, Default)]
pub struct RoomKeyCounts {
    /// Total number of room keys.
    pub total: usize,
    /// Number of room keys that are backed up.
    pub backed_up: usize,
}
879
/// The stored backup decryption key and backup version, if any.
#[derive(Default, Clone, Debug)]
pub struct BackupKeys {
    /// The backup decryption key, if stored.
    pub decryption_key: Option<BackupDecryptionKey>,
    /// The backup version the key belongs to, if stored.
    pub backup_version: Option<String>,
}
888
/// Exported private cross-signing keys, as produced by
/// `Store::export_cross_signing_keys`.
///
/// Zeroized on drop. The `Debug` output only reports which keys are present.
#[derive(Default, Zeroize, ZeroizeOnDrop)]
pub struct CrossSigningKeyExport {
    /// The exported master key, if available.
    pub master_key: Option<String>,
    /// The exported self-signing key, if available.
    pub self_signing_key: Option<String>,
    /// The exported user-signing key, if available.
    pub user_signing_key: Option<String>,
}
900
901#[cfg(not(tarpaulin_include))]
902impl Debug for CrossSigningKeyExport {
903 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
904 f.debug_struct("CrossSigningKeyExport")
905 .field("master_key", &self.master_key.is_some())
906 .field("self_signing_key", &self.self_signing_key.is_some())
907 .field("user_signing_key", &self.user_signing_key.is_some())
908 .finish_non_exhaustive()
909 }
910}
911
/// Error returned when importing a secret into the store.
#[derive(Debug, Error)]
pub enum SecretImportError {
    /// vodozemac rejected the key.
    #[error(transparent)]
    Key(#[from] vodozemac::KeyError),
    /// The imported private key does not match the public key known for it.
    #[error(
        "The public key of the imported private key doesn't match to the \
        public key that was uploaded to the server"
    )]
    MismatchedPublicKeys,
    /// The store returned an error.
    #[error(transparent)]
    Store(#[from] CryptoStoreError),
}
930
/// Error returned by `Store::export_secrets_bundle`.
#[derive(Debug, Error)]
pub enum SecretsBundleExportError {
    /// The store returned an error.
    #[error(transparent)]
    Store(#[from] CryptoStoreError),
    /// The cross-signing key with the given usage is missing.
    #[error("The store is missing one or multiple cross-signing keys")]
    MissingCrossSigningKey(KeyUsage),
    /// None of the cross-signing keys are available.
    #[error("The store doesn't contain any cross-signing keys")]
    MissingCrossSigningKeys,
    /// A backup decryption key is stored without a backup version.
    #[error("The store contains a backup key, but no backup version")]
    MissingBackupVersion,
}
951
/// Outcome of `KeyQueryManager::wait_if_user_key_query_pending`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum UserKeyQueryResult {
    /// A query was pending and completed before the timeout.
    WasPending,
    /// No query was pending for the user.
    WasNotPending,

    /// A query was pending but the timeout elapsed before it completed.
    TimeoutExpired,
}
962
/// Per-room encryption settings.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct RoomSettings {
    /// The encryption algorithm for the room.
    pub algorithm: EventEncryptionAlgorithm,

    /// Whether only trusted devices should be sent room keys.
    pub only_allow_trusted_devices: bool,

    /// Optional time after which the room key should be rotated.
    pub session_rotation_period: Option<Duration>,

    /// Optional number of messages after which the room key should be rotated.
    pub session_rotation_period_messages: Option<usize>,
}
981
982impl Default for RoomSettings {
983 fn default() -> Self {
984 Self {
985 algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
986 only_allow_trusted_devices: false,
987 session_rotation_period: None,
988 session_rotation_period_messages: None,
989 }
990 }
991}
992
/// Identifies a room key; built from an [`InboundGroupSession`] via `From`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct RoomKeyInfo {
    /// The session's encryption algorithm.
    pub algorithm: EventEncryptionAlgorithm,

    /// The room the session belongs to.
    pub room_id: OwnedRoomId,

    /// The session's sender key.
    pub sender_key: Curve25519PublicKey,

    /// The session ID.
    pub session_id: String,
}
1011
1012impl From<&InboundGroupSession> for RoomKeyInfo {
1013 fn from(group_session: &InboundGroupSession) -> Self {
1014 RoomKeyInfo {
1015 algorithm: group_session.algorithm().clone(),
1016 room_id: group_session.room_id().to_owned(),
1017 sender_key: group_session.sender_key(),
1018 session_id: group_session.session_id().to_owned(),
1019 }
1020 }
1021}
1022
/// Information about a room key that was withheld.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RoomKeyWithheldInfo {
    /// The room the withheld key belongs to.
    pub room_id: OwnedRoomId,

    /// The session ID of the withheld key.
    pub session_id: String,

    /// The withheld event that was received.
    pub withheld_event: RoomKeyWithheldEvent,
}
1036
1037impl Store {
1038 pub(crate) fn new(
1040 account: StaticAccountData,
1041 identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
1042 store: Arc<CryptoStoreWrapper>,
1043 verification_machine: VerificationMachine,
1044 ) -> Self {
1045 Self {
1046 inner: Arc::new(StoreInner {
1047 static_account: account,
1048 identity,
1049 store: store.clone(),
1050 verification_machine,
1051 cache: Arc::new(RwLock::new(StoreCache {
1052 store,
1053 tracked_users: Default::default(),
1054 loaded_tracked_users: Default::default(),
1055 account: Default::default(),
1056 })),
1057 }),
1058 }
1059 }
1060
1061 pub(crate) fn user_id(&self) -> &UserId {
1063 &self.inner.static_account.user_id
1064 }
1065
1066 pub(crate) fn device_id(&self) -> &DeviceId {
1068 self.inner.verification_machine.own_device_id()
1069 }
1070
1071 pub(crate) fn static_account(&self) -> &StaticAccountData {
1073 &self.inner.static_account
1074 }
1075
1076 pub(crate) async fn cache(&self) -> Result<StoreCacheGuard> {
1077 Ok(StoreCacheGuard { cache: self.inner.cache.clone().read_owned().await })
1082 }
1083
1084 pub(crate) async fn transaction(&self) -> StoreTransaction {
1085 StoreTransaction::new(self.clone()).await
1086 }
1087
1088 pub(crate) async fn with_transaction<
1091 T,
1092 Fut: futures_core::Future<Output = Result<(StoreTransaction, T), crate::OlmError>>,
1093 F: FnOnce(StoreTransaction) -> Fut,
1094 >(
1095 &self,
1096 func: F,
1097 ) -> Result<T, crate::OlmError> {
1098 let tr = self.transaction().await;
1099 let (tr, res) = func(tr).await?;
1100 tr.commit().await?;
1101 Ok(res)
1102 }
1103
1104 #[cfg(test)]
1105 pub(crate) async fn reset_cross_signing_identity(&self) {
1107 self.inner.identity.lock().await.reset();
1108 }
1109
1110 pub(crate) fn private_identity(&self) -> Arc<Mutex<PrivateCrossSigningIdentity>> {
1112 self.inner.identity.clone()
1113 }
1114
1115 pub(crate) async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
1117 let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
1118
1119 self.save_changes(changes).await
1120 }
1121
1122 pub(crate) async fn get_sessions(
1123 &self,
1124 sender_key: &str,
1125 ) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
1126 self.inner.store.get_sessions(sender_key).await
1127 }
1128
1129 pub(crate) async fn save_changes(&self, changes: Changes) -> Result<()> {
1130 self.inner.store.save_changes(changes).await
1131 }
1132
1133 pub(crate) async fn compare_group_session(
1140 &self,
1141 session: &InboundGroupSession,
1142 ) -> Result<SessionOrdering> {
1143 let old_session = self
1144 .inner
1145 .store
1146 .get_inbound_group_session(session.room_id(), session.session_id())
1147 .await?;
1148
1149 Ok(if let Some(old_session) = old_session {
1150 session.compare(&old_session).await
1151 } else {
1152 SessionOrdering::Better
1153 })
1154 }
1155
1156 #[cfg(test)]
1157 pub(crate) async fn save_device_data(&self, devices: &[DeviceData]) -> Result<()> {
1159 let changes = Changes {
1160 devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
1161 ..Default::default()
1162 };
1163
1164 self.save_changes(changes).await
1165 }
1166
1167 pub(crate) async fn save_inbound_group_sessions(
1169 &self,
1170 sessions: &[InboundGroupSession],
1171 ) -> Result<()> {
1172 let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
1173
1174 self.save_changes(changes).await
1175 }
1176
1177 pub(crate) async fn device_display_name(&self) -> Result<Option<String>, CryptoStoreError> {
1179 Ok(self
1180 .inner
1181 .store
1182 .get_device(self.user_id(), self.device_id())
1183 .await?
1184 .and_then(|d| d.display_name().map(|d| d.to_owned())))
1185 }
1186
1187 pub(crate) async fn get_device_data(
1192 &self,
1193 user_id: &UserId,
1194 device_id: &DeviceId,
1195 ) -> Result<Option<DeviceData>> {
1196 self.inner.store.get_device(user_id, device_id).await
1197 }
1198
1199 pub(crate) async fn get_device_data_for_user_filtered(
1207 &self,
1208 user_id: &UserId,
1209 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1210 self.inner.store.get_user_devices(user_id).await.map(|mut d| {
1211 if user_id == self.user_id() {
1212 d.remove(self.device_id());
1213 }
1214 d
1215 })
1216 }
1217
1218 pub(crate) async fn get_device_data_for_user(
1227 &self,
1228 user_id: &UserId,
1229 ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1230 self.inner.store.get_user_devices(user_id).await
1231 }
1232
1233 pub(crate) async fn get_device_from_curve_key(
1239 &self,
1240 user_id: &UserId,
1241 curve_key: Curve25519PublicKey,
1242 ) -> Result<Option<Device>> {
1243 self.get_user_devices(user_id)
1244 .await
1245 .map(|d| d.devices().find(|d| d.curve25519_key() == Some(curve_key)))
1246 }
1247
1248 pub(crate) async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
1258 let devices = self.get_device_data_for_user(user_id).await?;
1259
1260 let own_identity = self
1261 .inner
1262 .store
1263 .get_user_identity(self.user_id())
1264 .await?
1265 .and_then(|i| i.own().cloned());
1266 let device_owner_identity = self.inner.store.get_user_identity(user_id).await?;
1267
1268 Ok(UserDevices {
1269 inner: devices,
1270 verification_machine: self.inner.verification_machine.clone(),
1271 own_identity,
1272 device_owner_identity,
1273 })
1274 }
1275
1276 pub(crate) async fn get_device(
1286 &self,
1287 user_id: &UserId,
1288 device_id: &DeviceId,
1289 ) -> Result<Option<Device>> {
1290 if let Some(device_data) = self.inner.store.get_device(user_id, device_id).await? {
1291 Ok(Some(self.wrap_device_data(device_data).await?))
1292 } else {
1293 Ok(None)
1294 }
1295 }
1296
1297 pub(crate) async fn wrap_device_data(&self, device_data: DeviceData) -> Result<Device> {
1302 let own_identity = self
1303 .inner
1304 .store
1305 .get_user_identity(self.user_id())
1306 .await?
1307 .and_then(|i| i.own().cloned());
1308
1309 let device_owner_identity =
1310 self.inner.store.get_user_identity(device_data.user_id()).await?;
1311
1312 Ok(Device {
1313 inner: device_data,
1314 verification_machine: self.inner.verification_machine.clone(),
1315 own_identity,
1316 device_owner_identity,
1317 })
1318 }
1319
1320 pub(crate) async fn get_identity(&self, user_id: &UserId) -> Result<Option<UserIdentity>> {
1322 let own_identity = self
1323 .inner
1324 .store
1325 .get_user_identity(self.user_id())
1326 .await?
1327 .and_then(as_variant!(UserIdentityData::Own));
1328
1329 Ok(self.inner.store.get_user_identity(user_id).await?.map(|i| {
1330 UserIdentity::new(
1331 self.clone(),
1332 i,
1333 self.inner.verification_machine.to_owned(),
1334 own_identity,
1335 )
1336 }))
1337 }
1338
1339 pub async fn export_secret(
1348 &self,
1349 secret_name: &SecretName,
1350 ) -> Result<Option<String>, CryptoStoreError> {
1351 Ok(match secret_name {
1352 SecretName::CrossSigningMasterKey
1353 | SecretName::CrossSigningUserSigningKey
1354 | SecretName::CrossSigningSelfSigningKey => {
1355 self.inner.identity.lock().await.export_secret(secret_name).await
1356 }
1357 SecretName::RecoveryKey => {
1358 if let Some(key) = self.load_backup_keys().await?.decryption_key {
1359 let exported = key.to_base64();
1360 Some(exported)
1361 } else {
1362 None
1363 }
1364 }
1365 name => {
1366 warn!(secret = ?name, "Unknown secret was requested");
1367 None
1368 }
1369 })
1370 }
1371
1372 pub async fn export_cross_signing_keys(
1380 &self,
1381 ) -> Result<Option<CrossSigningKeyExport>, CryptoStoreError> {
1382 let master_key = self.export_secret(&SecretName::CrossSigningMasterKey).await?;
1383 let self_signing_key = self.export_secret(&SecretName::CrossSigningSelfSigningKey).await?;
1384 let user_signing_key = self.export_secret(&SecretName::CrossSigningUserSigningKey).await?;
1385
1386 Ok(if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() {
1387 None
1388 } else {
1389 Some(CrossSigningKeyExport { master_key, self_signing_key, user_signing_key })
1390 })
1391 }
1392
    /// Import our private cross-signing keys from `export`.
    ///
    /// Keys are only imported if our own public identity is known (via
    /// `get_identity`). Otherwise a warning is logged and nothing is imported.
    /// After importing, if the private identity shows no difference from the
    /// public identity, the public identity is marked as verified and saved as
    /// changed.
    ///
    /// Returns the cross-signing status of the private identity at the end of
    /// the call.
    ///
    /// # Errors
    ///
    /// Propagates errors from `import_secrets` and from the store.
    pub async fn import_cross_signing_keys(
        &self,
        export: CrossSigningKeyExport,
    ) -> Result<CrossSigningStatus, SecretImportError> {
        if let Some(public_identity) =
            self.get_identity(self.user_id()).await?.and_then(|i| i.own())
        {
            let identity = self.inner.identity.lock().await;

            identity
                .import_secrets(
                    public_identity.to_owned(),
                    export.master_key.as_deref(),
                    export.self_signing_key.as_deref(),
                    export.user_signing_key.as_deref(),
                )
                .await?;

            let status = identity.status().await;

            let diff = identity.get_public_identity_diff(&public_identity.inner).await;

            let mut changes =
                Changes { private_identity: Some(identity.clone()), ..Default::default() };

            // Only mark our public identity as verified if it matches the
            // private keys we now hold.
            if diff.none_differ() {
                public_identity.mark_as_verified();
                changes.identities.changed.push(UserIdentityData::Own(public_identity.inner));
            }

            info!(?status, "Successfully imported the private cross-signing keys");

            self.save_changes(changes).await?;
        } else {
            warn!(
                "No public identity found while importing cross-signing keys, \
                 a /keys/query needs to be done"
            );
        }

        // The `identity` guard above was dropped at the end of the `if`
        // branch, so the mutex can be locked again here.
        Ok(self.inner.identity.lock().await.status().await)
    }
1439
1440 pub async fn export_secrets_bundle(&self) -> Result<SecretsBundle, SecretsBundleExportError> {
1452 let Some(cross_signing) = self.export_cross_signing_keys().await? else {
1453 return Err(SecretsBundleExportError::MissingCrossSigningKeys);
1454 };
1455
1456 let Some(master_key) = cross_signing.master_key.clone() else {
1457 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::Master));
1458 };
1459
1460 let Some(user_signing_key) = cross_signing.user_signing_key.clone() else {
1461 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::UserSigning));
1462 };
1463
1464 let Some(self_signing_key) = cross_signing.self_signing_key.clone() else {
1465 return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::SelfSigning));
1466 };
1467
1468 let backup_keys = self.load_backup_keys().await?;
1469
1470 let backup = if let Some(key) = backup_keys.decryption_key {
1471 if let Some(backup_version) = backup_keys.backup_version {
1472 Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(
1473 MegolmBackupV1Curve25519AesSha2Secrets { key, backup_version },
1474 ))
1475 } else {
1476 return Err(SecretsBundleExportError::MissingBackupVersion);
1477 }
1478 } else {
1479 None
1480 };
1481
1482 Ok(SecretsBundle {
1483 cross_signing: CrossSigningSecrets { master_key, user_signing_key, self_signing_key },
1484 backup,
1485 })
1486 }
1487
1488 pub async fn import_secrets_bundle(
1501 &self,
1502 bundle: &SecretsBundle,
1503 ) -> Result<(), SecretImportError> {
1504 let mut changes = Changes::default();
1505
1506 if let Some(backup_bundle) = &bundle.backup {
1507 match backup_bundle {
1508 BackupSecrets::MegolmBackupV1Curve25519AesSha2(bundle) => {
1509 changes.backup_decryption_key = Some(bundle.key.clone());
1510 changes.backup_version = Some(bundle.backup_version.clone());
1511 }
1512 }
1513 }
1514
1515 let identity = self.inner.identity.lock().await;
1516
1517 identity
1518 .import_secrets_unchecked(
1519 Some(&bundle.cross_signing.master_key),
1520 Some(&bundle.cross_signing.self_signing_key),
1521 Some(&bundle.cross_signing.user_signing_key),
1522 )
1523 .await?;
1524
1525 let public_identity = identity.to_public_identity().await.expect(
1526 "We should be able to create a new public identity since we just imported \
1527 all the private cross-signing keys",
1528 );
1529
1530 changes.private_identity = Some(identity.clone());
1531 changes.identities.new.push(UserIdentityData::Own(public_identity));
1532
1533 Ok(self.save_changes(changes).await?)
1534 }
1535
1536 pub async fn import_secret(&self, secret: &GossippedSecret) -> Result<(), SecretImportError> {
1538 match &secret.secret_name {
1539 SecretName::CrossSigningMasterKey
1540 | SecretName::CrossSigningUserSigningKey
1541 | SecretName::CrossSigningSelfSigningKey => {
1542 if let Some(public_identity) =
1543 self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1544 {
1545 let identity = self.inner.identity.lock().await;
1546
1547 identity
1548 .import_secret(
1549 public_identity,
1550 &secret.secret_name,
1551 &secret.event.content.secret,
1552 )
1553 .await?;
1554 info!(
1555 secret_name = ?secret.secret_name,
1556 "Successfully imported a private cross signing key"
1557 );
1558
1559 let changes =
1560 Changes { private_identity: Some(identity.clone()), ..Default::default() };
1561
1562 self.save_changes(changes).await?;
1563 }
1564 }
1565 SecretName::RecoveryKey => {
1566 }
1572 name => {
1573 warn!(secret = ?name, "Tried to import an unknown secret");
1574 }
1575 }
1576
1577 Ok(())
1578 }
1579
1580 pub async fn get_only_allow_trusted_devices(&self) -> Result<bool> {
1583 let value = self.get_value("only_allow_trusted_devices").await?.unwrap_or_default();
1584 Ok(value)
1585 }
1586
1587 pub async fn set_only_allow_trusted_devices(
1590 &self,
1591 block_untrusted_devices: bool,
1592 ) -> Result<()> {
1593 self.set_value("only_allow_trusted_devices", &block_untrusted_devices).await
1594 }
1595
1596 pub async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
1598 let Some(value) = self.get_custom_value(key).await? else {
1599 return Ok(None);
1600 };
1601 let deserialized = self.deserialize_value(&value)?;
1602 Ok(Some(deserialized))
1603 }
1604
1605 pub async fn set_value(&self, key: &str, value: &impl Serialize) -> Result<()> {
1607 let serialized = self.serialize_value(value)?;
1608 self.set_custom_value(key, serialized).await?;
1609 Ok(())
1610 }
1611
1612 fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
1613 let serialized =
1614 rmp_serde::to_vec_named(value).map_err(|x| CryptoStoreError::Backend(x.into()))?;
1615 Ok(serialized)
1616 }
1617
1618 fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
1619 let deserialized =
1620 rmp_serde::from_slice(value).map_err(|e| CryptoStoreError::Backend(e.into()))?;
1621 Ok(deserialized)
1622 }
1623
1624 pub fn room_keys_received_stream(
1636 &self,
1637 ) -> impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>> {
1638 self.inner.store.room_keys_received_stream()
1639 }
1640
1641 pub fn room_keys_withheld_received_stream(
1650 &self,
1651 ) -> impl Stream<Item = Vec<RoomKeyWithheldInfo>> {
1652 self.inner.store.room_keys_withheld_received_stream()
1653 }
1654
1655 pub fn user_identities_stream(&self) -> impl Stream<Item = IdentityUpdates> {
1686 let verification_machine = self.inner.verification_machine.to_owned();
1687
1688 let this = self.clone();
1689 self.inner.store.identities_stream().map(move |(own_identity, identities, _)| {
1690 let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
1691
1692 let map_identity = |(user_id, identity)| {
1693 (
1694 user_id,
1695 UserIdentity::new(
1696 this.clone(),
1697 identity,
1698 verification_machine.to_owned(),
1699 own_identity.to_owned(),
1700 ),
1701 )
1702 };
1703
1704 let new = new_identities.into_iter().map(map_identity).collect();
1705 let changed = changed_identities.into_iter().map(map_identity).collect();
1706 let unchanged = unchanged_identities.into_iter().map(map_identity).collect();
1707
1708 IdentityUpdates { new, changed, unchanged }
1709 })
1710 }
1711
1712 pub fn devices_stream(&self) -> impl Stream<Item = DeviceUpdates> {
1744 let verification_machine = self.inner.verification_machine.to_owned();
1745
1746 self.inner.store.identities_stream().map(move |(own_identity, identities, devices)| {
1747 collect_device_updates(
1748 verification_machine.to_owned(),
1749 own_identity,
1750 identities,
1751 devices,
1752 )
1753 })
1754 }
1755
1756 pub fn identities_stream_raw(&self) -> impl Stream<Item = (IdentityChanges, DeviceChanges)> {
1766 self.inner.store.identities_stream().map(|(_, identities, devices)| (identities, devices))
1767 }
1768
1769 pub fn create_store_lock(
1772 &self,
1773 lock_key: String,
1774 lock_value: String,
1775 ) -> CrossProcessStoreLock<LockableCryptoStore> {
1776 self.inner.store.create_store_lock(lock_key, lock_value)
1777 }
1778
1779 pub fn secrets_stream(&self) -> impl Stream<Item = GossippedSecret> {
1819 self.inner.store.secrets_stream()
1820 }
1821
1822 pub async fn import_room_keys(
1835 &self,
1836 exported_keys: Vec<ExportedRoomKey>,
1837 from_backup_version: Option<&str>,
1838 progress_listener: impl Fn(usize, usize),
1839 ) -> Result<RoomKeyImportResult> {
1840 let exported_keys: Vec<&ExportedRoomKey> = exported_keys.iter().collect();
1841 self.import_sessions_impl(exported_keys, from_backup_version, progress_listener).await
1842 }
1843
1844 pub async fn import_exported_room_keys(
1871 &self,
1872 exported_keys: Vec<ExportedRoomKey>,
1873 progress_listener: impl Fn(usize, usize),
1874 ) -> Result<RoomKeyImportResult> {
1875 self.import_room_keys(exported_keys, None, progress_listener).await
1876 }
1877
    /// Shared implementation behind [`Store::import_room_keys`] and
    /// [`Store::receive_room_key_bundle`].
    ///
    /// Each key is converted into an [`InboundGroupSession`]. Only sessions
    /// that improve on what the store already holds are kept, and they are
    /// saved in a single batch at the end.
    ///
    /// # Arguments
    ///
    /// * `room_keys` - The keys to import, each converted with `TryInto`. Keys
    ///   that fail to convert are logged and skipped.
    /// * `from_backup_version` - When `Some`, every kept session is marked as
    ///   backed up, and the version is forwarded to the store's save call.
    /// * `progress_listener` - Called once per input key, after that key has
    ///   been processed, as `(index, total_count)` with a zero-based `index`.
    ///   The last call therefore reports `total_count - 1`.
    ///
    /// # Returns
    ///
    /// A [`RoomKeyImportResult`] with the number of kept sessions, the total
    /// number of input keys, and a map of room id → sender key (base64) → set
    /// of imported session ids.
    ///
    /// # Errors
    ///
    /// A store error while looking up an existing session aborts the import
    /// immediately; nothing from this batch has been saved at that point.
    /// A store error while saving the batch is also returned.
    async fn import_sessions_impl<T>(
        &self,
        room_keys: Vec<T>,
        from_backup_version: Option<&str>,
        progress_listener: impl Fn(usize, usize),
    ) -> Result<RoomKeyImportResult>
    where
        T: TryInto<InboundGroupSession> + RoomKeyExport + Copy,
        T::Error: Debug,
    {
        let mut sessions = Vec::new();

        // A candidate is worth importing when the store has no session with the
        // same room and session id, or when `compare` ranks the candidate as
        // strictly `Better` than the stored one.
        async fn new_session_better(
            session: &InboundGroupSession,
            old_session: Option<InboundGroupSession>,
        ) -> bool {
            if let Some(old_session) = &old_session {
                session.compare(old_session).await == SessionOrdering::Better
            } else {
                true
            }
        }

        let total_count = room_keys.len();
        let mut keys = BTreeMap::new();

        for (i, key) in room_keys.into_iter().enumerate() {
            match key.try_into() {
                Ok(session) => {
                    // NOTE(review): candidates are only compared against the
                    // store, not against earlier keys in this same batch, so a
                    // duplicate session in the input is kept twice.
                    let old_session = self
                        .inner
                        .store
                        .get_inbound_group_session(session.room_id(), session.session_id())
                        .await?;

                    if new_session_better(&session, old_session).await {
                        // Keys restored from a backup must not be uploaded to it again.
                        if from_backup_version.is_some() {
                            session.mark_as_backed_up();
                        }

                        keys.entry(session.room_id().to_owned())
                            .or_insert_with(BTreeMap::new)
                            .entry(session.sender_key().to_base64())
                            .or_insert_with(BTreeSet::new)
                            .insert(session.session_id().to_owned());

                        sessions.push(session);
                    }
                }
                Err(e) => {
                    // `key` is `Copy`, so it is still usable for logging after
                    // the failed `try_into`.
                    warn!(
                        sender_key = key.sender_key().to_base64(),
                        room_id = ?key.room_id(),
                        session_id = key.session_id(),
                        error = ?e,
                        "Couldn't import a room key from a file export."
                    );
                }
            }

            progress_listener(i, total_count);
        }

        let imported_count = sessions.len();

        self.inner.store.save_inbound_group_sessions(sessions, from_backup_version).await?;

        info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys");

        Ok(RoomKeyImportResult::new(imported_count, total_count, keys))
    }
1951
1952 pub(crate) fn crypto_store(&self) -> Arc<CryptoStoreWrapper> {
1953 self.inner.store.clone()
1954 }
1955
1956 pub async fn export_room_keys(
1979 &self,
1980 predicate: impl FnMut(&InboundGroupSession) -> bool,
1981 ) -> Result<Vec<ExportedRoomKey>> {
1982 let mut exported = Vec::new();
1983
1984 let mut sessions = self.get_inbound_group_sessions().await?;
1985 sessions.retain(predicate);
1986
1987 for session in sessions {
1988 let export = session.export().await;
1989 exported.push(export);
1990 }
1991
1992 Ok(exported)
1993 }
1994
1995 pub async fn export_room_keys_stream(
2028 &self,
2029 predicate: impl FnMut(&InboundGroupSession) -> bool,
2030 ) -> Result<impl Stream<Item = ExportedRoomKey>> {
2031 let sessions = self.get_inbound_group_sessions().await?;
2033 Ok(futures_util::stream::iter(sessions.into_iter().filter(predicate))
2034 .then(|session| async move { session.export().await }))
2035 }
2036
2037 pub async fn build_room_key_bundle(
2042 &self,
2043 room_id: &RoomId,
2044 ) -> std::result::Result<RoomKeyBundle, CryptoStoreError> {
2045 let mut sessions = self.get_inbound_group_sessions().await?;
2048 sessions.retain(|session| session.room_id == room_id);
2049
2050 let mut bundle = RoomKeyBundle::default();
2051 for session in sessions {
2052 if session.shared_history() {
2053 bundle.room_keys.push(session.export().await.into());
2054 } else {
2055 bundle.withheld.push(RoomKeyWithheldContent::new(
2056 session.algorithm().to_owned(),
2057 WithheldCode::Unauthorised,
2058 session.room_id().to_owned(),
2059 session.session_id().to_owned(),
2060 session.sender_key().to_owned(),
2061 self.device_id().to_owned(),
2062 ));
2063 }
2064 }
2065
2066 Ok(bundle)
2067 }
2068
2069 #[instrument(skip(self, bundle, progress_listener), fields(bundle_size = bundle.room_keys.len()))]
2082 pub async fn receive_room_key_bundle(
2083 &self,
2084 room_id: &RoomId,
2085 sender_user: &UserId,
2086 sender_data: &SenderData,
2087 bundle: RoomKeyBundle,
2088 progress_listener: impl Fn(usize, usize),
2089 ) -> Result<(), CryptoStoreError> {
2090 let (good, bad): (Vec<_>, Vec<_>) = bundle.room_keys.iter().partition_map(|key| {
2091 if key.room_id != room_id {
2092 trace!("Ignoring key for incorrect room {} in bundle", key.room_id);
2093 Either::Right(key)
2094 } else {
2095 Either::Left(key)
2096 }
2097 });
2098
2099 match (bad.is_empty(), good.is_empty()) {
2100 (true, true) => {
2102 warn!("Received a completely empty room key bundle");
2103 }
2104
2105 (false, true) => {
2107 let bad_keys: Vec<_> =
2108 bad.iter().map(|&key| (&key.room_id, &key.session_id)).collect();
2109
2110 warn!(
2111 ?bad_keys,
2112 "Received a room key bundle for the wrong room, ignoring all room keys from the bundle"
2113 );
2114 }
2115
2116 (_, false) => {
2118 if !bad.is_empty() {
2121 warn!(
2122 bad_key_count = bad.len(),
2123 "The room key bundle contained some room keys \
2124 that were meant for a different room"
2125 );
2126 }
2127
2128 self.import_sessions_impl(good, None, progress_listener).await?;
2129 }
2130 }
2131
2132 Ok(())
2133 }
2134}
2135
2136impl Deref for Store {
2137 type Target = DynCryptoStore;
2138
2139 fn deref(&self) -> &Self::Target {
2140 self.inner.store.deref().deref()
2141 }
2142}
2143
/// A cloneable handle to a shared crypto store, used as the backing store of
/// the cross-process lock returned by `Store::create_store_lock`.
///
/// Lock attempts are forwarded to the store's `try_take_leased_lock`.
#[derive(Clone, Debug)]
pub struct LockableCryptoStore(Arc<dyn CryptoStore<Error = CryptoStoreError>>);
2147
2148impl matrix_sdk_common::store_locks::BackingStore for LockableCryptoStore {
2149 type LockError = CryptoStoreError;
2150
2151 async fn try_lock(
2152 &self,
2153 lease_duration_ms: u32,
2154 key: &str,
2155 holder: &str,
2156 ) -> std::result::Result<bool, Self::LockError> {
2157 self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
2158 }
2159}
2160
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures_util::StreamExt;
    use insta::{_macro_support::Content, assert_json_snapshot, internals::ContentPath};
    use matrix_sdk_test::async_test;
    use ruma::{device_id, room_id, user_id, RoomId};
    use vodozemac::megolm::SessionKey;

    use crate::{
        machine::test_helpers::get_machine_pair,
        olm::{InboundGroupSession, SenderData},
        store::DehydratedDeviceKey,
        types::EventEncryptionAlgorithm,
        OlmMachine,
    };

    /// Importing room keys into Bob's store emits one update on
    /// `room_keys_received_stream`.
    #[async_test]
    async fn test_import_room_keys_notifies_stream() {
        use futures_util::FutureExt;

        let (alice, bob, _) =
            get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;

        let room1_id = room_id!("!room1:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        let exported_sessions = alice.store().export_room_keys(|_| true).await.unwrap();

        // Subscribe before importing so the update is not missed.
        let mut room_keys_received_stream = Box::pin(bob.store().room_keys_received_stream());
        bob.store().import_room_keys(exported_sessions, None, |_, _| {}).await.unwrap();

        // `now_or_never` does not wait: the update must already be available.
        let room_keys = room_keys_received_stream
            .next()
            .now_or_never()
            .flatten()
            .expect("We should have received an update of room key infos")
            .unwrap();
        assert_eq!(room_keys.len(), 1);
        assert_eq!(room_keys[0].room_id, "!room1:localhost");
    }

    /// `export_room_keys` only exports sessions accepted by the predicate.
    #[async_test]
    async fn test_export_room_keys_provides_selected_keys() {
        // Given Alice has sessions in three rooms
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        let room3_id = room_id!("!room3:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room3_id).await.unwrap();

        // When only rooms 2 and 3 are selected
        let keys = alice
            .store()
            .export_room_keys(|s| s.room_id() == room2_id || s.room_id() == room3_id)
            .await
            .unwrap();

        // Then exactly those two keys are exported
        assert_eq!(keys.len(), 2);
        assert_eq!(keys[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(keys[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(keys[0].room_id, "!room2:localhost");
        assert_eq!(keys[1].room_id, "!room3:localhost");
        assert_eq!(keys[0].session_key.to_base64().len(), 220);
        assert_eq!(keys[1].session_key.to_base64().len(), 220);
    }

    /// With an always-true predicate, the stream yields every session.
    #[async_test]
    async fn test_export_room_keys_stream_can_provide_all_keys() {
        // Given Alice has sessions in two rooms
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();

        // When all keys are requested as a stream
        let mut keys = pin!(alice.store().export_room_keys_stream(|_| true).await.unwrap());

        // And the stream is drained
        let mut collected = vec![];
        while let Some(key) = keys.next().await {
            collected.push(key);
        }

        // Then both keys were produced
        assert_eq!(collected.len(), 2);
        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[0].room_id, "!room1:localhost");
        assert_eq!(collected[1].room_id, "!room2:localhost");
        assert_eq!(collected[0].session_key.to_base64().len(), 220);
        assert_eq!(collected[1].session_key.to_base64().len(), 220);
    }

    /// The stream honours the predicate and only yields selected sessions.
    #[async_test]
    async fn test_export_room_keys_stream_can_provide_a_subset_of_keys() {
        // Given Alice has sessions in two rooms
        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");
        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();

        // When only room 1 is selected
        let mut keys =
            pin!(alice.store().export_room_keys_stream(|s| s.room_id() == room1_id).await.unwrap());

        // And the stream is drained
        let mut collected = vec![];
        while let Some(key) = keys.next().await {
            collected.push(key);
        }

        // Then only room 1's key was produced
        assert_eq!(collected.len(), 1);
        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
        assert_eq!(collected[0].room_id, "!room1:localhost");
        assert_eq!(collected[0].session_key.to_base64().len(), 220);
    }

    /// A secrets bundle exported from one device of a user can be imported on
    /// another device of the same user. The importing device ends up with a
    /// complete, verified cross-signing setup.
    #[async_test]
    async fn test_export_secrets_bundle() {
        let user_id = user_id!("@alice:example.com");
        let (first, second, _) = get_machine_pair(user_id, user_id, false).await;

        let _ = first
            .bootstrap_cross_signing(false)
            .await
            .expect("We should be able to bootstrap cross-signing");

        let bundle = first.store().export_secrets_bundle().await.expect(
            "We should be able to export the secrets bundle, now that we \
             have the cross-signing keys",
        );

        // No backup was set up, so no backup secrets are expected.
        assert!(bundle.backup.is_none(), "The bundle should not contain a backup key");

        second
            .store()
            .import_secrets_bundle(&bundle)
            .await
            .expect("We should be able to import the secrets bundle");

        let status = second.cross_signing_status().await;
        let identity = second.get_identity(user_id, None).await.unwrap().unwrap().own().unwrap();

        assert!(identity.is_verified(), "The public identity should be marked as verified.");

        assert!(status.is_complete(), "We should have imported all the cross-signing keys");
    }

    /// A freshly generated dehydrated-device key round-trips through
    /// `from_slice`.
    #[async_test]
    async fn test_create_dehydrated_device_key() {
        let pickle_key = DehydratedDeviceKey::new()
            .expect("Should be able to create a random dehydrated device key");

        let to_vec = pickle_key.inner.to_vec();
        let pickle_key_from_slice = DehydratedDeviceKey::from_slice(to_vec.as_slice())
            .expect("Should be able to create a dehydrated device key from slice");

        assert_eq!(pickle_key_from_slice.to_base64(), pickle_key.to_base64());
    }

    /// `from_slice` rejects inputs that are too short (22 bytes) or too long
    /// (40 bytes).
    #[async_test]
    async fn test_create_dehydrated_errors() {
        let too_small = [0u8; 22];
        let pickle_key = DehydratedDeviceKey::from_slice(&too_small);

        assert!(pickle_key.is_err());

        let too_big = [0u8; 40];
        let pickle_key = DehydratedDeviceKey::from_slice(&too_big);

        assert!(pickle_key.is_err());
    }

    /// `build_room_key_bundle` exports shared-history sessions for the
    /// requested room and lists that room's other sessions as withheld.
    #[async_test]
    async fn test_build_room_key_bundle() {
        // Given Bob holds sessions created by Alice: two shared-history sessions
        // and one non-shared-history session in room 1, plus one shared-history
        // session in room 2.
        let alice = OlmMachine::new(user_id!("@a:s.co"), device_id!("ALICE")).await;
        let bob = OlmMachine::new(user_id!("@b:s.co"), device_id!("BOB")).await;

        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");

        let session_key1 = "AgAAAAC2XHVzsMBKs4QCRElJ92CJKyGtknCSC8HY7cQ7UYwndMKLQAejXLh5UA0l6s736mgctcUMNvELScUWrObdflrHo+vth/gWreXOaCnaSxmyjjKErQwyIYTkUfqbHy40RJfEesLwnN23on9XAkch/iy8R2+Jz7B8zfG01f2Ow2SxPQFnAndcO1ZSD2GmXgedy6n4B20MWI1jGP2wiexOWbFSya8DO/VxC9m5+/mF+WwYqdpKn9g4Y05Yw4uz7cdjTc3rXm7xK+8E7hI//5QD1nHPvuKYbjjM9u2JSL+Bzp61Cw";
        let session_key2 = "AgAAAAC1BXreFTUQQSBGekTEuYxhdytRKyv4JgDGcG+VOBYdPNGgs807SdibCGJky4lJ3I+7ZDGHoUzZPZP/4ogGu4kxni0PWdtWuN7+5zsuamgoFF/BkaGeUUGv6kgIkx8pyPpM5SASTUEP9bN2loDSpUPYwfiIqz74DgC4WQ4435sTBctYvKz8n+TDJwdLXpyT6zKljuqADAioud+s/iqx9LYn9HpbBfezZcvbg67GtE113pLrvde3IcPI5s6dNHK2onGO2B2eoaobcen18bbEDnlUGPeIivArLya7Da6us14jBQ";
        let session_key3 = "AgAAAAAM9KFsliaUUhGSXgwOzM5UemjkNH4n8NHgvC/y8hhw13zTF+ooGD4uIYEXYX630oNvQm/EvgZo+dkoc0re+vsqsx4sQeNODdSjcBsWOa0oDF+irQn9oYoLUDPI1IBtY1rX+FV99Zm/xnG7uFOX7aTVlko2GSdejy1w9mfobmfxu5aUc04A9zaKJP1pOthZvRAlhpymGYHgsDtWPrrjyc/yypMflE4kIUEEEtu1kT6mrAmcl615XYRAHYK9G2+fZsGvokwzbkl4nulGwcZMpQEoM0nD2o3GWgX81HW3nGfKBg";
        let session_key4 = "AgAAAAA4Kkesxq2h4v9PLD6Sm3Smxspz1PXTqytQPCMQMkkrHNmzV2bHlJ+6/Al9cu8vh1Oj69AK0WUAeJOJuaiskEeg/PI3P03+UYLeC379RzgqwSHdBgdQ41G2vD6zpgmE/8vYToe+qpCZACtPOswZxyqxHH+T/Iq0nv13JmlFGIeA6fEPfr5Y28B49viG74Fs9rxV9EH5PfjbuPM/p+Sz5obShuaBPKQBX1jT913nEXPoIJ06exNZGr0285nw/LgVvNlmWmbqNnbzO2cNZjQWA+xZYz5FSfyCxwqEBbEdUCuRCQ";

        let sessions = [
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key1).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key2).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key3).unwrap(),
                false,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room2_id,
                &SessionKey::from_base64(session_key4).unwrap(),
                true,
            ),
        ];
        bob.store().save_inbound_group_sessions(&sessions).await.unwrap();

        // When Bob builds a bundle for room 1
        let mut bundle = bob.store().build_room_key_bundle(room1_id).await.unwrap();

        // Sort for a deterministic snapshot; presumably the store does not
        // guarantee the order of returned sessions.
        bundle.room_keys.sort_by_key(|session| session.session_id.clone());

        // Then the snapshot should contain the two shared-history room 1 keys,
        // and the non-shared-history session as withheld. Alice's keys are
        // random per run, so they are asserted and then redacted.
        let alice_curve_key = alice.identity_keys().curve25519.to_base64();
        let map_alice_curve_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_curve_key);
            "[alice curve key]"
        };
        let alice_ed25519_key = alice.identity_keys().ed25519.to_base64();
        let map_alice_ed25519_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_ed25519_key);
            "[alice ed25519 key]"
        };

        insta::with_settings!({ sort_maps => true }, {
            assert_json_snapshot!(bundle, {
                ".room_keys[].sender_key" => insta::dynamic_redaction(map_alice_curve_key.clone()),
                ".withheld[].sender_key" => insta::dynamic_redaction(map_alice_curve_key),
                ".room_keys[].sender_claimed_keys.ed25519" => insta::dynamic_redaction(map_alice_ed25519_key),
            });
        });
    }

    /// Build an inbound group session attributed to `olm_machine`'s identity
    /// keys for `room_id`, with the given `shared_history` flag and unknown
    /// sender data.
    fn create_inbound_group_session_with_visibility(
        olm_machine: &OlmMachine,
        room_id: &RoomId,
        session_key: &SessionKey,
        shared_history: bool,
    ) -> InboundGroupSession {
        let identity_keys = &olm_machine.store().static_account().identity_keys;
        InboundGroupSession::new(
            identity_keys.curve25519,
            identity_keys.ed25519,
            room_id,
            session_key,
            SenderData::unknown(),
            EventEncryptionAlgorithm::MegolmV1AesSha2,
            None,
            shared_history,
        )
        .unwrap()
    }
}