1use std::{
42    collections::{BTreeMap, BTreeSet, HashMap, HashSet},
43    fmt::Debug,
44    ops::Deref,
45    pin::pin,
46    sync::{atomic::Ordering, Arc},
47    time::Duration,
48};
49
50use as_variant::as_variant;
51use futures_core::Stream;
52use futures_util::StreamExt;
53use matrix_sdk_common::locks::RwLock as StdRwLock;
54use ruma::{
55    encryption::KeyUsage, events::secret::request::SecretName, DeviceId, OwnedDeviceId,
56    OwnedRoomId, OwnedUserId, RoomId, UserId,
57};
58use serde::{de::DeserializeOwned, Deserialize, Serialize};
59use thiserror::Error;
60use tokio::sync::{Mutex, MutexGuard, Notify, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
61use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
62use tracing::{info, warn};
63use vodozemac::{base64_encode, megolm::SessionOrdering, Curve25519PublicKey};
64use zeroize::{Zeroize, ZeroizeOnDrop};
65
66#[cfg(doc)]
67use crate::{backups::BackupMachine, identities::OwnUserIdentity};
68use crate::{
69    gossiping::GossippedSecret,
70    identities::{user::UserIdentity, Device, DeviceData, UserDevices, UserIdentityData},
71    olm::{
72        Account, ExportedRoomKey, InboundGroupSession, OlmMessageHash, OutboundGroupSession,
73        PrivateCrossSigningIdentity, Session, StaticAccountData,
74    },
75    types::{
76        events::room_key_withheld::RoomKeyWithheldEvent, BackupSecrets, CrossSigningSecrets,
77        EventEncryptionAlgorithm, MegolmBackupV1Curve25519AesSha2Secrets, SecretsBundle,
78    },
79    verification::VerificationMachine,
80    CrossSigningStatus, OwnUserIdentityData, RoomKeyImportResult,
81};
82
83pub mod caches;
84mod crypto_store_wrapper;
85mod error;
86mod memorystore;
87mod traits;
88
89#[cfg(any(test, feature = "testing"))]
90#[macro_use]
91#[allow(missing_docs)]
92pub mod integration_tests;
93
94use caches::{SequenceNumber, UsersForKeyQuery};
95pub(crate) use crypto_store_wrapper::CryptoStoreWrapper;
96pub use error::{CryptoStoreError, Result};
97use matrix_sdk_common::{
98    deserialized_responses::WithheldCode, store_locks::CrossProcessStoreLock, timeout::timeout,
99};
100pub use memorystore::MemoryStore;
101pub use traits::{CryptoStore, DynCryptoStore, IntoCryptoStore};
102
103use crate::types::{
104    events::room_key_withheld::RoomKeyWithheldContent, room_history::RoomKeyBundle,
105};
106pub use crate::{
107    dehydrated_devices::DehydrationError,
108    gossiping::{GossipRequest, SecretInfo},
109};
110
111#[derive(Debug, Clone)]
118pub struct Store {
119    inner: Arc<StoreInner>,
120}
121
122#[derive(Debug, Default)]
123pub(crate) struct KeyQueryManager {
124    users_for_key_query: Mutex<UsersForKeyQuery>,
126
127    users_for_key_query_notify: Notify,
129}
130
131impl KeyQueryManager {
132    pub async fn synced<'a>(&'a self, cache: &'a StoreCache) -> Result<SyncedKeyQueryManager<'a>> {
133        self.ensure_sync_tracked_users(cache).await?;
134        Ok(SyncedKeyQueryManager { cache, manager: self })
135    }
136
137    async fn ensure_sync_tracked_users(&self, cache: &StoreCache) -> Result<()> {
144        let loaded = cache.loaded_tracked_users.read().await;
146        if *loaded {
147            return Ok(());
148        }
149
150        drop(loaded);
152        let mut loaded = cache.loaded_tracked_users.write().await;
153
154        if *loaded {
158            return Ok(());
159        }
160
161        let tracked_users = cache.store.load_tracked_users().await?;
162
163        let mut query_users_lock = self.users_for_key_query.lock().await;
164        let mut tracked_users_cache = cache.tracked_users.write();
165        for user in tracked_users {
166            tracked_users_cache.insert(user.user_id.to_owned());
167
168            if user.dirty {
169                query_users_lock.insert_user(&user.user_id);
170            }
171        }
172
173        *loaded = true;
174
175        Ok(())
176    }
177
178    pub async fn wait_if_user_key_query_pending(
188        &self,
189        cache: StoreCacheGuard,
190        timeout_duration: Duration,
191        user: &UserId,
192    ) -> Result<UserKeyQueryResult> {
193        {
194            self.ensure_sync_tracked_users(&cache).await?;
197            drop(cache);
198        }
199
200        let mut users_for_key_query = self.users_for_key_query.lock().await;
201        let Some(waiter) = users_for_key_query.maybe_register_waiting_task(user) else {
202            return Ok(UserKeyQueryResult::WasNotPending);
203        };
204
205        let wait_for_completion = async {
206            while !waiter.completed.load(Ordering::Relaxed) {
207                let mut notified = pin!(self.users_for_key_query_notify.notified());
211                notified.as_mut().enable();
212                drop(users_for_key_query);
213
214                notified.await;
216
217                users_for_key_query = self.users_for_key_query.lock().await;
221            }
222        };
223
224        match timeout(Box::pin(wait_for_completion), timeout_duration).await {
225            Err(_) => {
226                warn!(
227                    user_id = ?user,
228                    "The user has a pending `/keys/query` request which did \
229                    not finish yet, some devices might be missing."
230                );
231
232                Ok(UserKeyQueryResult::TimeoutExpired)
233            }
234            _ => Ok(UserKeyQueryResult::WasPending),
235        }
236    }
237}
238
239pub(crate) struct SyncedKeyQueryManager<'a> {
240    cache: &'a StoreCache,
241    manager: &'a KeyQueryManager,
242}
243
244impl SyncedKeyQueryManager<'_> {
245    pub async fn update_tracked_users(&self, users: impl Iterator<Item = &UserId>) -> Result<()> {
250        let mut store_updates = Vec::new();
251        let mut key_query_lock = self.manager.users_for_key_query.lock().await;
252
253        {
254            let mut tracked_users = self.cache.tracked_users.write();
255            for user_id in users {
256                if tracked_users.insert(user_id.to_owned()) {
257                    key_query_lock.insert_user(user_id);
258                    store_updates.push((user_id, true))
259                }
260            }
261        }
262
263        self.cache.store.save_tracked_users(&store_updates).await
264    }
265
266    pub async fn mark_tracked_users_as_changed(
273        &self,
274        users: impl Iterator<Item = &UserId>,
275    ) -> Result<()> {
276        let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
277        let mut key_query_lock = self.manager.users_for_key_query.lock().await;
278
279        {
280            let tracked_users = &self.cache.tracked_users.read();
281            for user_id in users {
282                if tracked_users.contains(user_id) {
283                    key_query_lock.insert_user(user_id);
284                    store_updates.push((user_id, true));
285                }
286            }
287        }
288
289        self.cache.store.save_tracked_users(&store_updates).await
290    }
291
292    pub async fn mark_tracked_users_as_up_to_date(
298        &self,
299        users: impl Iterator<Item = &UserId>,
300        sequence_number: SequenceNumber,
301    ) -> Result<()> {
302        let mut store_updates: Vec<(&UserId, bool)> = Vec::new();
303        let mut key_query_lock = self.manager.users_for_key_query.lock().await;
304
305        {
306            let tracked_users = self.cache.tracked_users.read();
307            for user_id in users {
308                if tracked_users.contains(user_id) {
309                    let clean = key_query_lock.maybe_remove_user(user_id, sequence_number);
310                    store_updates.push((user_id, !clean));
311                }
312            }
313        }
314
315        self.cache.store.save_tracked_users(&store_updates).await?;
316        self.manager.users_for_key_query_notify.notify_waiters();
318
319        Ok(())
320    }
321
322    pub async fn users_for_key_query(&self) -> (HashSet<OwnedUserId>, SequenceNumber) {
334        self.manager.users_for_key_query.lock().await.users_for_key_query()
335    }
336
337    pub fn tracked_users(&self) -> HashSet<OwnedUserId> {
339        self.cache.tracked_users.read().iter().cloned().collect()
340    }
341
342    pub async fn mark_user_as_changed(&self, user: &UserId) -> Result<()> {
348        self.manager.users_for_key_query.lock().await.insert_user(user);
349        self.cache.tracked_users.write().insert(user.to_owned());
350
351        self.cache.store.save_tracked_users(&[(user, true)]).await
352    }
353}
354
355#[derive(Debug)]
356pub(crate) struct StoreCache {
357    store: Arc<CryptoStoreWrapper>,
358    tracked_users: StdRwLock<BTreeSet<OwnedUserId>>,
359    loaded_tracked_users: RwLock<bool>,
360    account: Mutex<Option<Account>>,
361}
362
363impl StoreCache {
364    pub(crate) fn store_wrapper(&self) -> &CryptoStoreWrapper {
365        self.store.as_ref()
366    }
367
368    async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
380        let mut guard = self.account.lock().await;
381        if guard.is_some() {
382            Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
383        } else {
384            match self.store.load_account().await? {
385                Some(account) => {
386                    *guard = Some(account);
387                    Ok(MutexGuard::map(guard, |acc| acc.as_mut().unwrap()))
388                }
389                None => Err(CryptoStoreError::AccountUnset),
390            }
391        }
392    }
393}
394
395pub(crate) struct StoreCacheGuard {
401    cache: OwnedRwLockReadGuard<StoreCache>,
402    }
404
405impl StoreCacheGuard {
406    pub async fn account(&self) -> Result<impl Deref<Target = Account> + '_> {
414        self.cache.account().await
415    }
416}
417
418impl Deref for StoreCacheGuard {
419    type Target = StoreCache;
420
421    fn deref(&self) -> &Self::Target {
422        &self.cache
423    }
424}
425
426#[allow(missing_debug_implementations)]
428pub struct StoreTransaction {
429    store: Store,
430    changes: PendingChanges,
431    cache: OwnedRwLockWriteGuard<StoreCache>,
433}
434
435impl StoreTransaction {
436    async fn new(store: Store) -> Self {
438        let cache = store.inner.cache.clone();
439
440        Self { store, changes: PendingChanges::default(), cache: cache.clone().write_owned().await }
441    }
442
443    pub(crate) fn cache(&self) -> &StoreCache {
444        &self.cache
445    }
446
447    pub fn store(&self) -> &Store {
449        &self.store
450    }
451
452    pub async fn account(&mut self) -> Result<&mut Account> {
459        if self.changes.account.is_none() {
460            let _ = self.cache.account().await?;
462            self.changes.account = self.cache.account.lock().await.take();
463        }
464        Ok(self.changes.account.as_mut().unwrap())
465    }
466
467    pub async fn commit(self) -> Result<()> {
470        if self.changes.is_empty() {
471            return Ok(());
472        }
473
474        let account = self.changes.account.as_ref().map(|acc| acc.deep_clone());
476
477        self.store.save_pending_changes(self.changes).await?;
478
479        if let Some(account) = account {
481            *self.cache.account.lock().await = Some(account);
482        }
483
484        Ok(())
485    }
486}
487
488#[derive(Debug)]
489struct StoreInner {
490    identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
491    store: Arc<CryptoStoreWrapper>,
492
493    cache: Arc<RwLock<StoreCache>>,
497
498    verification_machine: VerificationMachine,
499
500    static_account: StaticAccountData,
503}
504
505#[derive(Default, Debug)]
511#[allow(missing_docs)]
512pub struct PendingChanges {
513    pub account: Option<Account>,
514}
515
516impl PendingChanges {
517    pub fn is_empty(&self) -> bool {
519        self.account.is_none()
520    }
521}
522
523#[derive(Default, Debug)]
526#[allow(missing_docs)]
527pub struct Changes {
528    pub private_identity: Option<PrivateCrossSigningIdentity>,
529    pub backup_version: Option<String>,
530    pub backup_decryption_key: Option<BackupDecryptionKey>,
531    pub dehydrated_device_pickle_key: Option<DehydratedDeviceKey>,
532    pub sessions: Vec<Session>,
533    pub message_hashes: Vec<OlmMessageHash>,
534    pub inbound_group_sessions: Vec<InboundGroupSession>,
535    pub outbound_group_sessions: Vec<OutboundGroupSession>,
536    pub key_requests: Vec<GossipRequest>,
537    pub identities: IdentityChanges,
538    pub devices: DeviceChanges,
539    pub withheld_session_info: BTreeMap<OwnedRoomId, BTreeMap<String, RoomKeyWithheldEvent>>,
541    pub room_settings: HashMap<OwnedRoomId, RoomSettings>,
542    pub secrets: Vec<GossippedSecret>,
543    pub next_batch_token: Option<String>,
544}
545
546#[derive(Clone, Debug, Serialize, Deserialize)]
548pub struct TrackedUser {
549    pub user_id: OwnedUserId,
551    pub dirty: bool,
556}
557
558impl Changes {
559    pub fn is_empty(&self) -> bool {
561        self.private_identity.is_none()
562            && self.backup_version.is_none()
563            && self.backup_decryption_key.is_none()
564            && self.dehydrated_device_pickle_key.is_none()
565            && self.sessions.is_empty()
566            && self.message_hashes.is_empty()
567            && self.inbound_group_sessions.is_empty()
568            && self.outbound_group_sessions.is_empty()
569            && self.key_requests.is_empty()
570            && self.identities.is_empty()
571            && self.devices.is_empty()
572            && self.withheld_session_info.is_empty()
573            && self.room_settings.is_empty()
574            && self.secrets.is_empty()
575            && self.next_batch_token.is_none()
576    }
577}
578
579#[derive(Debug, Clone, Default)]
590#[allow(missing_docs)]
591pub struct IdentityChanges {
592    pub new: Vec<UserIdentityData>,
593    pub changed: Vec<UserIdentityData>,
594    pub unchanged: Vec<UserIdentityData>,
595}
596
597impl IdentityChanges {
598    fn is_empty(&self) -> bool {
599        self.new.is_empty() && self.changed.is_empty()
600    }
601
602    fn into_maps(
605        self,
606    ) -> (
607        BTreeMap<OwnedUserId, UserIdentityData>,
608        BTreeMap<OwnedUserId, UserIdentityData>,
609        BTreeMap<OwnedUserId, UserIdentityData>,
610    ) {
611        let new: BTreeMap<_, _> = self
612            .new
613            .into_iter()
614            .map(|identity| (identity.user_id().to_owned(), identity))
615            .collect();
616
617        let changed: BTreeMap<_, _> = self
618            .changed
619            .into_iter()
620            .map(|identity| (identity.user_id().to_owned(), identity))
621            .collect();
622
623        let unchanged: BTreeMap<_, _> = self
624            .unchanged
625            .into_iter()
626            .map(|identity| (identity.user_id().to_owned(), identity))
627            .collect();
628
629        (new, changed, unchanged)
630    }
631}
632
633#[derive(Debug, Clone, Default)]
634#[allow(missing_docs)]
635pub struct DeviceChanges {
636    pub new: Vec<DeviceData>,
637    pub changed: Vec<DeviceData>,
638    pub deleted: Vec<DeviceData>,
639}
640
641fn collect_device_updates(
647    verification_machine: VerificationMachine,
648    own_identity: Option<OwnUserIdentityData>,
649    identities: IdentityChanges,
650    devices: DeviceChanges,
651) -> DeviceUpdates {
652    let mut new: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
653    let mut changed: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new();
654
655    let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
656
657    let map_device = |device: DeviceData| {
658        let device_owner_identity = new_identities
659            .get(device.user_id())
660            .or_else(|| changed_identities.get(device.user_id()))
661            .or_else(|| unchanged_identities.get(device.user_id()))
662            .cloned();
663
664        Device {
665            inner: device,
666            verification_machine: verification_machine.to_owned(),
667            own_identity: own_identity.to_owned(),
668            device_owner_identity,
669        }
670    };
671
672    for device in devices.new {
673        let device = map_device(device);
674
675        new.entry(device.user_id().to_owned())
676            .or_default()
677            .insert(device.device_id().to_owned(), device);
678    }
679
680    for device in devices.changed {
681        let device = map_device(device);
682
683        changed
684            .entry(device.user_id().to_owned())
685            .or_default()
686            .insert(device.device_id().to_owned(), device.to_owned());
687    }
688
689    DeviceUpdates { new, changed }
690}
691
692#[derive(Clone, Debug, Default)]
695pub struct DeviceUpdates {
696    pub new: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
702    pub changed: BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, Device>>,
704}
705
706#[derive(Clone, Debug, Default)]
709pub struct IdentityUpdates {
710    pub new: BTreeMap<OwnedUserId, UserIdentity>,
716    pub changed: BTreeMap<OwnedUserId, UserIdentity>,
718    pub unchanged: BTreeMap<OwnedUserId, UserIdentity>,
720}
721
722#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
732#[serde(transparent)]
733pub struct BackupDecryptionKey {
734    pub(crate) inner: Box<[u8; BackupDecryptionKey::KEY_SIZE]>,
735}
736
737impl BackupDecryptionKey {
738    pub const KEY_SIZE: usize = 32;
740
741    pub fn new() -> Result<Self, rand::Error> {
743        let mut rng = rand::thread_rng();
744
745        let mut key = Box::new([0u8; Self::KEY_SIZE]);
746        rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
747
748        Ok(Self { inner: key })
749    }
750
751    pub fn to_base64(&self) -> String {
753        base64_encode(self.inner.as_slice())
754    }
755}
756
757#[cfg(not(tarpaulin_include))]
758impl Debug for BackupDecryptionKey {
759    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
760        f.debug_tuple("BackupDecryptionKey").field(&"...").finish()
761    }
762}
763
764#[derive(Clone, Zeroize, ZeroizeOnDrop, Deserialize, Serialize)]
769#[serde(transparent)]
770pub struct DehydratedDeviceKey {
771    pub(crate) inner: Box<[u8; DehydratedDeviceKey::KEY_SIZE]>,
772}
773
774impl DehydratedDeviceKey {
775    pub const KEY_SIZE: usize = 32;
777
778    pub fn new() -> Result<Self, rand::Error> {
780        let mut rng = rand::thread_rng();
781
782        let mut key = Box::new([0u8; Self::KEY_SIZE]);
783        rand::Fill::try_fill(key.as_mut_slice(), &mut rng)?;
784
785        Ok(Self { inner: key })
786    }
787
788    pub fn from_slice(slice: &[u8]) -> Result<Self, DehydrationError> {
792        if slice.len() == 32 {
793            let mut key = Box::new([0u8; 32]);
794            key.copy_from_slice(slice);
795            Ok(DehydratedDeviceKey { inner: key })
796        } else {
797            Err(DehydrationError::PickleKeyLength(slice.len()))
798        }
799    }
800
801    pub fn from_bytes(raw_key: &[u8; 32]) -> Self {
803        let mut inner = Box::new([0u8; Self::KEY_SIZE]);
804        inner.copy_from_slice(raw_key);
805
806        Self { inner }
807    }
808
809    pub fn to_base64(&self) -> String {
811        base64_encode(self.inner.as_slice())
812    }
813}
814
815impl From<&[u8; 32]> for DehydratedDeviceKey {
816    fn from(value: &[u8; 32]) -> Self {
817        DehydratedDeviceKey { inner: Box::new(*value) }
818    }
819}
820
821impl From<DehydratedDeviceKey> for Vec<u8> {
822    fn from(key: DehydratedDeviceKey) -> Self {
823        key.inner.to_vec()
824    }
825}
826
827#[cfg(not(tarpaulin_include))]
828impl Debug for DehydratedDeviceKey {
829    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
830        f.debug_tuple("DehydratedDeviceKey").field(&"...").finish()
831    }
832}
833
834impl DeviceChanges {
835    pub fn extend(&mut self, other: DeviceChanges) {
837        self.new.extend(other.new);
838        self.changed.extend(other.changed);
839        self.deleted.extend(other.deleted);
840    }
841
842    fn is_empty(&self) -> bool {
843        self.new.is_empty() && self.changed.is_empty() && self.deleted.is_empty()
844    }
845}
846
/// Counters describing a set of room keys.
#[derive(Clone, Debug, Default)]
pub struct RoomKeyCounts {
    /// The total number of room keys.
    pub total: usize,
    /// How many of them are backed up (going by the field name).
    pub backed_up: usize,
}
855
856#[derive(Default, Clone, Debug)]
858pub struct BackupKeys {
859    pub decryption_key: Option<BackupDecryptionKey>,
861    pub backup_version: Option<String>,
863}
864
865#[derive(Default, Zeroize, ZeroizeOnDrop)]
868pub struct CrossSigningKeyExport {
869    pub master_key: Option<String>,
871    pub self_signing_key: Option<String>,
873    pub user_signing_key: Option<String>,
875}
876
877#[cfg(not(tarpaulin_include))]
878impl Debug for CrossSigningKeyExport {
879    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
880        f.debug_struct("CrossSigningKeyExport")
881            .field("master_key", &self.master_key.is_some())
882            .field("self_signing_key", &self.self_signing_key.is_some())
883            .field("user_signing_key", &self.user_signing_key.is_some())
884            .finish_non_exhaustive()
885    }
886}
887
888#[derive(Debug, Error)]
891pub enum SecretImportError {
892    #[error(transparent)]
894    Key(#[from] vodozemac::KeyError),
895    #[error(
898        "The public key of the imported private key doesn't match to the \
899            public key that was uploaded to the server"
900    )]
901    MismatchedPublicKeys,
902    #[error(transparent)]
904    Store(#[from] CryptoStoreError),
905}
906
907#[derive(Debug, Error)]
912pub enum SecretsBundleExportError {
913    #[error(transparent)]
915    Store(#[from] CryptoStoreError),
916    #[error("The store is missing one or multiple cross-signing keys")]
918    MissingCrossSigningKey(KeyUsage),
919    #[error("The store doesn't contain any cross-signing keys")]
921    MissingCrossSigningKeys,
922    #[error("The store contains a backup key, but no backup version")]
925    MissingBackupVersion,
926}
927
/// The outcome of waiting for a user's pending key query.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum UserKeyQueryResult {
    /// A key query was pending and the wait finished before the timeout.
    WasPending,
    /// No key query was pending, so there was nothing to wait for.
    WasNotPending,

    /// A key query was pending, but the timeout elapsed first.
    TimeoutExpired,
}
938
939#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
941pub struct RoomSettings {
942    pub algorithm: EventEncryptionAlgorithm,
944
945    pub only_allow_trusted_devices: bool,
948
949    pub session_rotation_period: Option<Duration>,
952
953    pub session_rotation_period_messages: Option<usize>,
956}
957
958impl Default for RoomSettings {
959    fn default() -> Self {
960        Self {
961            algorithm: EventEncryptionAlgorithm::MegolmV1AesSha2,
962            only_allow_trusted_devices: false,
963            session_rotation_period: None,
964            session_rotation_period_messages: None,
965        }
966    }
967}
968
969#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
971pub struct RoomKeyInfo {
972    pub algorithm: EventEncryptionAlgorithm,
977
978    pub room_id: OwnedRoomId,
980
981    pub sender_key: Curve25519PublicKey,
983
984    pub session_id: String,
986}
987
988impl From<&InboundGroupSession> for RoomKeyInfo {
989    fn from(group_session: &InboundGroupSession) -> Self {
990        RoomKeyInfo {
991            algorithm: group_session.algorithm().clone(),
992            room_id: group_session.room_id().to_owned(),
993            sender_key: group_session.sender_key(),
994            session_id: group_session.session_id().to_owned(),
995        }
996    }
997}
998
999#[derive(Clone, Debug, Deserialize, Serialize)]
1001pub struct RoomKeyWithheldInfo {
1002    pub room_id: OwnedRoomId,
1004
1005    pub session_id: String,
1007
1008    pub withheld_event: RoomKeyWithheldEvent,
1011}
1012
1013impl Store {
1014    pub(crate) fn new(
1016        account: StaticAccountData,
1017        identity: Arc<Mutex<PrivateCrossSigningIdentity>>,
1018        store: Arc<CryptoStoreWrapper>,
1019        verification_machine: VerificationMachine,
1020    ) -> Self {
1021        Self {
1022            inner: Arc::new(StoreInner {
1023                static_account: account,
1024                identity,
1025                store: store.clone(),
1026                verification_machine,
1027                cache: Arc::new(RwLock::new(StoreCache {
1028                    store,
1029                    tracked_users: Default::default(),
1030                    loaded_tracked_users: Default::default(),
1031                    account: Default::default(),
1032                })),
1033            }),
1034        }
1035    }
1036
1037    pub(crate) fn user_id(&self) -> &UserId {
1039        &self.inner.static_account.user_id
1040    }
1041
1042    pub(crate) fn device_id(&self) -> &DeviceId {
1044        self.inner.verification_machine.own_device_id()
1045    }
1046
1047    pub(crate) fn static_account(&self) -> &StaticAccountData {
1049        &self.inner.static_account
1050    }
1051
1052    pub(crate) async fn cache(&self) -> Result<StoreCacheGuard> {
1053        Ok(StoreCacheGuard { cache: self.inner.cache.clone().read_owned().await })
1058    }
1059
1060    pub(crate) async fn transaction(&self) -> StoreTransaction {
1061        StoreTransaction::new(self.clone()).await
1062    }
1063
1064    pub(crate) async fn with_transaction<
1067        T,
1068        Fut: futures_core::Future<Output = Result<(StoreTransaction, T), crate::OlmError>>,
1069        F: FnOnce(StoreTransaction) -> Fut,
1070    >(
1071        &self,
1072        func: F,
1073    ) -> Result<T, crate::OlmError> {
1074        let tr = self.transaction().await;
1075        let (tr, res) = func(tr).await?;
1076        tr.commit().await?;
1077        Ok(res)
1078    }
1079
1080    #[cfg(test)]
1081    pub(crate) async fn reset_cross_signing_identity(&self) {
1083        self.inner.identity.lock().await.reset();
1084    }
1085
1086    pub(crate) fn private_identity(&self) -> Arc<Mutex<PrivateCrossSigningIdentity>> {
1088        self.inner.identity.clone()
1089    }
1090
1091    pub(crate) async fn save_sessions(&self, sessions: &[Session]) -> Result<()> {
1093        let changes = Changes { sessions: sessions.to_vec(), ..Default::default() };
1094
1095        self.save_changes(changes).await
1096    }
1097
1098    pub(crate) async fn get_sessions(
1099        &self,
1100        sender_key: &str,
1101    ) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
1102        self.inner.store.get_sessions(sender_key).await
1103    }
1104
1105    pub(crate) async fn save_changes(&self, changes: Changes) -> Result<()> {
1106        self.inner.store.save_changes(changes).await
1107    }
1108
1109    pub(crate) async fn compare_group_session(
1116        &self,
1117        session: &InboundGroupSession,
1118    ) -> Result<SessionOrdering> {
1119        let old_session = self
1120            .inner
1121            .store
1122            .get_inbound_group_session(session.room_id(), session.session_id())
1123            .await?;
1124
1125        Ok(if let Some(old_session) = old_session {
1126            session.compare(&old_session).await
1127        } else {
1128            SessionOrdering::Better
1129        })
1130    }
1131
1132    #[cfg(test)]
1133    pub(crate) async fn save_device_data(&self, devices: &[DeviceData]) -> Result<()> {
1135        let changes = Changes {
1136            devices: DeviceChanges { changed: devices.to_vec(), ..Default::default() },
1137            ..Default::default()
1138        };
1139
1140        self.save_changes(changes).await
1141    }
1142
1143    pub(crate) async fn save_inbound_group_sessions(
1145        &self,
1146        sessions: &[InboundGroupSession],
1147    ) -> Result<()> {
1148        let changes = Changes { inbound_group_sessions: sessions.to_vec(), ..Default::default() };
1149
1150        self.save_changes(changes).await
1151    }
1152
1153    pub(crate) async fn device_display_name(&self) -> Result<Option<String>, CryptoStoreError> {
1155        Ok(self
1156            .inner
1157            .store
1158            .get_device(self.user_id(), self.device_id())
1159            .await?
1160            .and_then(|d| d.display_name().map(|d| d.to_owned())))
1161    }
1162
1163    pub(crate) async fn get_device_data(
1168        &self,
1169        user_id: &UserId,
1170        device_id: &DeviceId,
1171    ) -> Result<Option<DeviceData>> {
1172        self.inner.store.get_device(user_id, device_id).await
1173    }
1174
1175    pub(crate) async fn get_device_data_for_user_filtered(
1183        &self,
1184        user_id: &UserId,
1185    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1186        self.inner.store.get_user_devices(user_id).await.map(|mut d| {
1187            if user_id == self.user_id() {
1188                d.remove(self.device_id());
1189            }
1190            d
1191        })
1192    }
1193
1194    pub(crate) async fn get_device_data_for_user(
1203        &self,
1204        user_id: &UserId,
1205    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
1206        self.inner.store.get_user_devices(user_id).await
1207    }
1208
1209    pub(crate) async fn get_device_from_curve_key(
1215        &self,
1216        user_id: &UserId,
1217        curve_key: Curve25519PublicKey,
1218    ) -> Result<Option<Device>> {
1219        self.get_user_devices(user_id)
1220            .await
1221            .map(|d| d.devices().find(|d| d.curve25519_key() == Some(curve_key)))
1222    }
1223
1224    pub(crate) async fn get_user_devices(&self, user_id: &UserId) -> Result<UserDevices> {
1234        let devices = self.get_device_data_for_user(user_id).await?;
1235
1236        let own_identity = self
1237            .inner
1238            .store
1239            .get_user_identity(self.user_id())
1240            .await?
1241            .and_then(|i| i.own().cloned());
1242        let device_owner_identity = self.inner.store.get_user_identity(user_id).await?;
1243
1244        Ok(UserDevices {
1245            inner: devices,
1246            verification_machine: self.inner.verification_machine.clone(),
1247            own_identity,
1248            device_owner_identity,
1249        })
1250    }
1251
1252    pub(crate) async fn get_device(
1262        &self,
1263        user_id: &UserId,
1264        device_id: &DeviceId,
1265    ) -> Result<Option<Device>> {
1266        if let Some(device_data) = self.inner.store.get_device(user_id, device_id).await? {
1267            Ok(Some(self.wrap_device_data(device_data).await?))
1268        } else {
1269            Ok(None)
1270        }
1271    }
1272
1273    pub(crate) async fn wrap_device_data(&self, device_data: DeviceData) -> Result<Device> {
1278        let own_identity = self
1279            .inner
1280            .store
1281            .get_user_identity(self.user_id())
1282            .await?
1283            .and_then(|i| i.own().cloned());
1284
1285        let device_owner_identity =
1286            self.inner.store.get_user_identity(device_data.user_id()).await?;
1287
1288        Ok(Device {
1289            inner: device_data,
1290            verification_machine: self.inner.verification_machine.clone(),
1291            own_identity,
1292            device_owner_identity,
1293        })
1294    }
1295
1296    pub(crate) async fn get_identity(&self, user_id: &UserId) -> Result<Option<UserIdentity>> {
1298        let own_identity = self
1299            .inner
1300            .store
1301            .get_user_identity(self.user_id())
1302            .await?
1303            .and_then(as_variant!(UserIdentityData::Own));
1304
1305        Ok(self.inner.store.get_user_identity(user_id).await?.map(|i| {
1306            UserIdentity::new(
1307                self.clone(),
1308                i,
1309                self.inner.verification_machine.to_owned(),
1310                own_identity,
1311            )
1312        }))
1313    }
1314
1315    pub async fn export_secret(
1324        &self,
1325        secret_name: &SecretName,
1326    ) -> Result<Option<String>, CryptoStoreError> {
1327        Ok(match secret_name {
1328            SecretName::CrossSigningMasterKey
1329            | SecretName::CrossSigningUserSigningKey
1330            | SecretName::CrossSigningSelfSigningKey => {
1331                self.inner.identity.lock().await.export_secret(secret_name).await
1332            }
1333            SecretName::RecoveryKey => {
1334                if let Some(key) = self.load_backup_keys().await?.decryption_key {
1335                    let exported = key.to_base64();
1336                    Some(exported)
1337                } else {
1338                    None
1339                }
1340            }
1341            name => {
1342                warn!(secret = ?name, "Unknown secret was requested");
1343                None
1344            }
1345        })
1346    }
1347
1348    pub async fn export_cross_signing_keys(
1356        &self,
1357    ) -> Result<Option<CrossSigningKeyExport>, CryptoStoreError> {
1358        let master_key = self.export_secret(&SecretName::CrossSigningMasterKey).await?;
1359        let self_signing_key = self.export_secret(&SecretName::CrossSigningSelfSigningKey).await?;
1360        let user_signing_key = self.export_secret(&SecretName::CrossSigningUserSigningKey).await?;
1361
1362        Ok(if master_key.is_none() && self_signing_key.is_none() && user_signing_key.is_none() {
1363            None
1364        } else {
1365            Some(CrossSigningKeyExport { master_key, self_signing_key, user_signing_key })
1366        })
1367    }
1368
1369    pub async fn import_cross_signing_keys(
1374        &self,
1375        export: CrossSigningKeyExport,
1376    ) -> Result<CrossSigningStatus, SecretImportError> {
1377        if let Some(public_identity) =
1378            self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1379        {
1380            let identity = self.inner.identity.lock().await;
1381
1382            identity
1383                .import_secrets(
1384                    public_identity.to_owned(),
1385                    export.master_key.as_deref(),
1386                    export.self_signing_key.as_deref(),
1387                    export.user_signing_key.as_deref(),
1388                )
1389                .await?;
1390
1391            let status = identity.status().await;
1392
1393            let diff = identity.get_public_identity_diff(&public_identity.inner).await;
1394
1395            let mut changes =
1396                Changes { private_identity: Some(identity.clone()), ..Default::default() };
1397
1398            if diff.none_differ() {
1399                public_identity.mark_as_verified();
1400                changes.identities.changed.push(UserIdentityData::Own(public_identity.inner));
1401            }
1402
1403            info!(?status, "Successfully imported the private cross-signing keys");
1404
1405            self.save_changes(changes).await?;
1406        } else {
1407            warn!("No public identity found while importing cross-signing keys, a /keys/query needs to be done");
1408        }
1409
1410        Ok(self.inner.identity.lock().await.status().await)
1411    }
1412
1413    pub async fn export_secrets_bundle(&self) -> Result<SecretsBundle, SecretsBundleExportError> {
1425        let Some(cross_signing) = self.export_cross_signing_keys().await? else {
1426            return Err(SecretsBundleExportError::MissingCrossSigningKeys);
1427        };
1428
1429        let Some(master_key) = cross_signing.master_key.clone() else {
1430            return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::Master));
1431        };
1432
1433        let Some(user_signing_key) = cross_signing.user_signing_key.clone() else {
1434            return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::UserSigning));
1435        };
1436
1437        let Some(self_signing_key) = cross_signing.self_signing_key.clone() else {
1438            return Err(SecretsBundleExportError::MissingCrossSigningKey(KeyUsage::SelfSigning));
1439        };
1440
1441        let backup_keys = self.load_backup_keys().await?;
1442
1443        let backup = if let Some(key) = backup_keys.decryption_key {
1444            if let Some(backup_version) = backup_keys.backup_version {
1445                Some(BackupSecrets::MegolmBackupV1Curve25519AesSha2(
1446                    MegolmBackupV1Curve25519AesSha2Secrets { key, backup_version },
1447                ))
1448            } else {
1449                return Err(SecretsBundleExportError::MissingBackupVersion);
1450            }
1451        } else {
1452            None
1453        };
1454
1455        Ok(SecretsBundle {
1456            cross_signing: CrossSigningSecrets { master_key, user_signing_key, self_signing_key },
1457            backup,
1458        })
1459    }
1460
1461    pub async fn import_secrets_bundle(
1474        &self,
1475        bundle: &SecretsBundle,
1476    ) -> Result<(), SecretImportError> {
1477        let mut changes = Changes::default();
1478
1479        if let Some(backup_bundle) = &bundle.backup {
1480            match backup_bundle {
1481                BackupSecrets::MegolmBackupV1Curve25519AesSha2(bundle) => {
1482                    changes.backup_decryption_key = Some(bundle.key.clone());
1483                    changes.backup_version = Some(bundle.backup_version.clone());
1484                }
1485            }
1486        }
1487
1488        let identity = self.inner.identity.lock().await;
1489
1490        identity
1491            .import_secrets_unchecked(
1492                Some(&bundle.cross_signing.master_key),
1493                Some(&bundle.cross_signing.self_signing_key),
1494                Some(&bundle.cross_signing.user_signing_key),
1495            )
1496            .await?;
1497
1498        let public_identity = identity.to_public_identity().await.expect(
1499            "We should be able to create a new public identity since we just imported \
1500             all the private cross-signing keys",
1501        );
1502
1503        changes.private_identity = Some(identity.clone());
1504        changes.identities.new.push(UserIdentityData::Own(public_identity));
1505
1506        Ok(self.save_changes(changes).await?)
1507    }
1508
1509    pub async fn import_secret(&self, secret: &GossippedSecret) -> Result<(), SecretImportError> {
1511        match &secret.secret_name {
1512            SecretName::CrossSigningMasterKey
1513            | SecretName::CrossSigningUserSigningKey
1514            | SecretName::CrossSigningSelfSigningKey => {
1515                if let Some(public_identity) =
1516                    self.get_identity(self.user_id()).await?.and_then(|i| i.own())
1517                {
1518                    let identity = self.inner.identity.lock().await;
1519
1520                    identity
1521                        .import_secret(
1522                            public_identity,
1523                            &secret.secret_name,
1524                            &secret.event.content.secret,
1525                        )
1526                        .await?;
1527                    info!(
1528                        secret_name = ?secret.secret_name,
1529                        "Successfully imported a private cross signing key"
1530                    );
1531
1532                    let changes =
1533                        Changes { private_identity: Some(identity.clone()), ..Default::default() };
1534
1535                    self.save_changes(changes).await?;
1536                }
1537            }
1538            SecretName::RecoveryKey => {
1539                }
1545            name => {
1546                warn!(secret = ?name, "Tried to import an unknown secret");
1547            }
1548        }
1549
1550        Ok(())
1551    }
1552
1553    pub async fn get_only_allow_trusted_devices(&self) -> Result<bool> {
1556        let value = self.get_value("only_allow_trusted_devices").await?.unwrap_or_default();
1557        Ok(value)
1558    }
1559
1560    pub async fn set_only_allow_trusted_devices(
1563        &self,
1564        block_untrusted_devices: bool,
1565    ) -> Result<()> {
1566        self.set_value("only_allow_trusted_devices", &block_untrusted_devices).await
1567    }
1568
1569    pub async fn get_value<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
1571        let Some(value) = self.get_custom_value(key).await? else {
1572            return Ok(None);
1573        };
1574        let deserialized = self.deserialize_value(&value)?;
1575        Ok(Some(deserialized))
1576    }
1577
1578    pub async fn set_value(&self, key: &str, value: &impl Serialize) -> Result<()> {
1580        let serialized = self.serialize_value(value)?;
1581        self.set_custom_value(key, serialized).await?;
1582        Ok(())
1583    }
1584
1585    fn serialize_value(&self, value: &impl Serialize) -> Result<Vec<u8>> {
1586        let serialized =
1587            rmp_serde::to_vec_named(value).map_err(|x| CryptoStoreError::Backend(x.into()))?;
1588        Ok(serialized)
1589    }
1590
1591    fn deserialize_value<T: DeserializeOwned>(&self, value: &[u8]) -> Result<T> {
1592        let deserialized =
1593            rmp_serde::from_slice(value).map_err(|e| CryptoStoreError::Backend(e.into()))?;
1594        Ok(deserialized)
1595    }
1596
1597    pub fn room_keys_received_stream(
1609        &self,
1610    ) -> impl Stream<Item = Result<Vec<RoomKeyInfo>, BroadcastStreamRecvError>> {
1611        self.inner.store.room_keys_received_stream()
1612    }
1613
1614    pub fn room_keys_withheld_received_stream(
1623        &self,
1624    ) -> impl Stream<Item = Vec<RoomKeyWithheldInfo>> {
1625        self.inner.store.room_keys_withheld_received_stream()
1626    }
1627
1628    pub fn user_identities_stream(&self) -> impl Stream<Item = IdentityUpdates> {
1659        let verification_machine = self.inner.verification_machine.to_owned();
1660
1661        let this = self.clone();
1662        self.inner.store.identities_stream().map(move |(own_identity, identities, _)| {
1663            let (new_identities, changed_identities, unchanged_identities) = identities.into_maps();
1664
1665            let map_identity = |(user_id, identity)| {
1666                (
1667                    user_id,
1668                    UserIdentity::new(
1669                        this.clone(),
1670                        identity,
1671                        verification_machine.to_owned(),
1672                        own_identity.to_owned(),
1673                    ),
1674                )
1675            };
1676
1677            let new = new_identities.into_iter().map(map_identity).collect();
1678            let changed = changed_identities.into_iter().map(map_identity).collect();
1679            let unchanged = unchanged_identities.into_iter().map(map_identity).collect();
1680
1681            IdentityUpdates { new, changed, unchanged }
1682        })
1683    }
1684
1685    pub fn devices_stream(&self) -> impl Stream<Item = DeviceUpdates> {
1717        let verification_machine = self.inner.verification_machine.to_owned();
1718
1719        self.inner.store.identities_stream().map(move |(own_identity, identities, devices)| {
1720            collect_device_updates(
1721                verification_machine.to_owned(),
1722                own_identity,
1723                identities,
1724                devices,
1725            )
1726        })
1727    }
1728
1729    pub fn identities_stream_raw(&self) -> impl Stream<Item = (IdentityChanges, DeviceChanges)> {
1739        self.inner.store.identities_stream().map(|(_, identities, devices)| (identities, devices))
1740    }
1741
1742    pub fn create_store_lock(
1745        &self,
1746        lock_key: String,
1747        lock_value: String,
1748    ) -> CrossProcessStoreLock<LockableCryptoStore> {
1749        self.inner.store.create_store_lock(lock_key, lock_value)
1750    }
1751
1752    pub fn secrets_stream(&self) -> impl Stream<Item = GossippedSecret> {
1792        self.inner.store.secrets_stream()
1793    }
1794
1795    pub async fn import_room_keys(
1808        &self,
1809        exported_keys: Vec<ExportedRoomKey>,
1810        from_backup_version: Option<&str>,
1811        progress_listener: impl Fn(usize, usize),
1812    ) -> Result<RoomKeyImportResult> {
1813        let mut sessions = Vec::new();
1814
1815        async fn new_session_better(
1816            session: &InboundGroupSession,
1817            old_session: Option<InboundGroupSession>,
1818        ) -> bool {
1819            if let Some(old_session) = &old_session {
1820                session.compare(old_session).await == SessionOrdering::Better
1821            } else {
1822                true
1823            }
1824        }
1825
1826        let total_count = exported_keys.len();
1827        let mut keys = BTreeMap::new();
1828
1829        for (i, key) in exported_keys.into_iter().enumerate() {
1830            match InboundGroupSession::from_export(&key) {
1831                Ok(session) => {
1832                    let old_session = self
1833                        .inner
1834                        .store
1835                        .get_inbound_group_session(session.room_id(), session.session_id())
1836                        .await?;
1837
1838                    if new_session_better(&session, old_session).await {
1841                        if from_backup_version.is_some() {
1842                            session.mark_as_backed_up();
1843                        }
1844
1845                        keys.entry(session.room_id().to_owned())
1846                            .or_insert_with(BTreeMap::new)
1847                            .entry(session.sender_key().to_base64())
1848                            .or_insert_with(BTreeSet::new)
1849                            .insert(session.session_id().to_owned());
1850
1851                        sessions.push(session);
1852                    }
1853                }
1854                Err(e) => {
1855                    warn!(
1856                        sender_key= key.sender_key.to_base64(),
1857                        room_id = ?key.room_id,
1858                        session_id = key.session_id,
1859                        error = ?e,
1860                        "Couldn't import a room key from a file export."
1861                    );
1862                }
1863            }
1864
1865            progress_listener(i, total_count);
1866        }
1867
1868        let imported_count = sessions.len();
1869
1870        self.inner.store.save_inbound_group_sessions(sessions, from_backup_version).await?;
1871
1872        info!(total_count, imported_count, room_keys = ?keys, "Successfully imported room keys");
1873
1874        Ok(RoomKeyImportResult::new(imported_count, total_count, keys))
1875    }
1876
1877    pub async fn import_exported_room_keys(
1904        &self,
1905        exported_keys: Vec<ExportedRoomKey>,
1906        progress_listener: impl Fn(usize, usize),
1907    ) -> Result<RoomKeyImportResult> {
1908        self.import_room_keys(exported_keys, None, progress_listener).await
1909    }
1910
1911    pub(crate) fn crypto_store(&self) -> Arc<CryptoStoreWrapper> {
1912        self.inner.store.clone()
1913    }
1914
1915    pub async fn export_room_keys(
1938        &self,
1939        predicate: impl FnMut(&InboundGroupSession) -> bool,
1940    ) -> Result<Vec<ExportedRoomKey>> {
1941        let mut exported = Vec::new();
1942
1943        let mut sessions = self.get_inbound_group_sessions().await?;
1944        sessions.retain(predicate);
1945
1946        for session in sessions {
1947            let export = session.export().await;
1948            exported.push(export);
1949        }
1950
1951        Ok(exported)
1952    }
1953
1954    pub async fn export_room_keys_stream(
1987        &self,
1988        predicate: impl FnMut(&InboundGroupSession) -> bool,
1989    ) -> Result<impl Stream<Item = ExportedRoomKey>> {
1990        let sessions = self.get_inbound_group_sessions().await?;
1992        Ok(futures_util::stream::iter(sessions.into_iter().filter(predicate))
1993            .then(|session| async move { session.export().await }))
1994    }
1995
1996    pub async fn build_room_key_bundle(
2001        &self,
2002        room_id: &RoomId,
2003    ) -> std::result::Result<RoomKeyBundle, CryptoStoreError> {
2004        let mut sessions = self.get_inbound_group_sessions().await?;
2007        sessions.retain(|session| session.room_id == room_id);
2008
2009        let mut bundle = RoomKeyBundle::default();
2010        for session in sessions {
2011            if session.shared_history() {
2012                bundle.room_keys.push(session.export().await.into());
2013            } else {
2014                bundle.withheld.push(RoomKeyWithheldContent::new(
2015                    session.algorithm().to_owned(),
2016                    WithheldCode::Unauthorised,
2017                    session.room_id().to_owned(),
2018                    session.session_id().to_owned(),
2019                    session.sender_key().to_owned(),
2020                    self.device_id().to_owned(),
2021                ));
2022            }
2023        }
2024
2025        Ok(bundle)
2026    }
2027}
2028
2029impl Deref for Store {
2030    type Target = DynCryptoStore;
2031
2032    fn deref(&self) -> &Self::Target {
2033        self.inner.store.deref().deref()
2034    }
2035}
2036
/// Newtype around a shared, type-erased [`CryptoStore`] handle.
///
/// It implements [`matrix_sdk_common::store_locks::BackingStore`], which is
/// what allows it to be the backing store of the [`CrossProcessStoreLock`]
/// returned by `Store::create_store_lock`.
#[derive(Clone, Debug)]
pub struct LockableCryptoStore(Arc<dyn CryptoStore<Error = CryptoStoreError>>);
2040
2041#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
2042#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
2043impl matrix_sdk_common::store_locks::BackingStore for LockableCryptoStore {
2044    type LockError = CryptoStoreError;
2045
2046    async fn try_lock(
2047        &self,
2048        lease_duration_ms: u32,
2049        key: &str,
2050        holder: &str,
2051    ) -> std::result::Result<bool, Self::LockError> {
2052        self.0.try_take_leased_lock(lease_duration_ms, key, holder).await
2053    }
2054}
2055
2056#[cfg(test)]
2057mod tests {
2058    use std::pin::pin;
2059
2060    use futures_util::StreamExt;
2061    use insta::{_macro_support::Content, assert_json_snapshot, internals::ContentPath};
2062    use matrix_sdk_test::async_test;
2063    use ruma::{device_id, room_id, user_id, RoomId};
2064    use vodozemac::megolm::SessionKey;
2065
2066    use crate::{
2067        machine::test_helpers::get_machine_pair,
2068        olm::{InboundGroupSession, SenderData},
2069        store::DehydratedDeviceKey,
2070        types::EventEncryptionAlgorithm,
2071        OlmMachine,
2072    };
2073
2074    #[async_test]
2075    async fn test_import_room_keys_notifies_stream() {
2076        use futures_util::FutureExt;
2077
2078        let (alice, bob, _) =
2079            get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
2080
2081        let room1_id = room_id!("!room1:localhost");
2082        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
2083        let exported_sessions = alice.store().export_room_keys(|_| true).await.unwrap();
2084
2085        let mut room_keys_received_stream = Box::pin(bob.store().room_keys_received_stream());
2086        bob.store().import_room_keys(exported_sessions, None, |_, _| {}).await.unwrap();
2087
2088        let room_keys = room_keys_received_stream
2089            .next()
2090            .now_or_never()
2091            .flatten()
2092            .expect("We should have received an update of room key infos")
2093            .unwrap();
2094        assert_eq!(room_keys.len(), 1);
2095        assert_eq!(room_keys[0].room_id, "!room1:localhost");
2096    }
2097
2098    #[async_test]
2099    async fn test_export_room_keys_provides_selected_keys() {
2100        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
2102        let room1_id = room_id!("!room1:localhost");
2103        let room2_id = room_id!("!room2:localhost");
2104        let room3_id = room_id!("!room3:localhost");
2105        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
2106        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();
2107        alice.create_outbound_group_session_with_defaults_test_helper(room3_id).await.unwrap();
2108
2109        let keys = alice
2111            .store()
2112            .export_room_keys(|s| s.room_id() == room2_id || s.room_id() == room3_id)
2113            .await
2114            .unwrap();
2115
2116        assert_eq!(keys.len(), 2);
2118        assert_eq!(keys[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
2119        assert_eq!(keys[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
2120        assert_eq!(keys[0].room_id, "!room2:localhost");
2121        assert_eq!(keys[1].room_id, "!room3:localhost");
2122        assert_eq!(keys[0].session_key.to_base64().len(), 220);
2123        assert_eq!(keys[1].session_key.to_base64().len(), 220);
2124    }
2125
2126    #[async_test]
2127    async fn test_export_room_keys_stream_can_provide_all_keys() {
2128        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
2130        let room1_id = room_id!("!room1:localhost");
2131        let room2_id = room_id!("!room2:localhost");
2132        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
2133        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();
2134
2135        let mut keys = pin!(alice.store().export_room_keys_stream(|_| true).await.unwrap());
2137
2138        let mut collected = vec![];
2140        while let Some(key) = keys.next().await {
2141            collected.push(key);
2142        }
2143
2144        assert_eq!(collected.len(), 2);
2146        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
2147        assert_eq!(collected[1].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
2148        assert_eq!(collected[0].room_id, "!room1:localhost");
2149        assert_eq!(collected[1].room_id, "!room2:localhost");
2150        assert_eq!(collected[0].session_key.to_base64().len(), 220);
2151        assert_eq!(collected[1].session_key.to_base64().len(), 220);
2152    }
2153
2154    #[async_test]
2155    async fn test_export_room_keys_stream_can_provide_a_subset_of_keys() {
2156        let (alice, _, _) = get_machine_pair(user_id!("@a:s.co"), user_id!("@b:s.co"), false).await;
2158        let room1_id = room_id!("!room1:localhost");
2159        let room2_id = room_id!("!room2:localhost");
2160        alice.create_outbound_group_session_with_defaults_test_helper(room1_id).await.unwrap();
2161        alice.create_outbound_group_session_with_defaults_test_helper(room2_id).await.unwrap();
2162
2163        let mut keys =
2165            pin!(alice.store().export_room_keys_stream(|s| s.room_id() == room1_id).await.unwrap());
2166
2167        let mut collected = vec![];
2169        while let Some(key) = keys.next().await {
2170            collected.push(key);
2171        }
2172
2173        assert_eq!(collected.len(), 1);
2175        assert_eq!(collected[0].algorithm, EventEncryptionAlgorithm::MegolmV1AesSha2);
2176        assert_eq!(collected[0].room_id, "!room1:localhost");
2177        assert_eq!(collected[0].session_key.to_base64().len(), 220);
2178    }
2179
2180    #[async_test]
2181    async fn test_export_secrets_bundle() {
2182        let user_id = user_id!("@alice:example.com");
2183        let (first, second, _) = get_machine_pair(user_id, user_id, false).await;
2184
2185        let _ = first
2186            .bootstrap_cross_signing(false)
2187            .await
2188            .expect("We should be able to bootstrap cross-signing");
2189
2190        let bundle = first.store().export_secrets_bundle().await.expect(
2191            "We should be able to export the secrets bundle, now that we \
2192             have the cross-signing keys",
2193        );
2194
2195        assert!(bundle.backup.is_none(), "The bundle should not contain a backup key");
2196
2197        second
2198            .store()
2199            .import_secrets_bundle(&bundle)
2200            .await
2201            .expect("We should be able to import the secrets bundle");
2202
2203        let status = second.cross_signing_status().await;
2204        let identity = second.get_identity(user_id, None).await.unwrap().unwrap().own().unwrap();
2205
2206        assert!(identity.is_verified(), "The public identity should be marked as verified.");
2207
2208        assert!(status.is_complete(), "We should have imported all the cross-signing keys");
2209    }
2210
2211    #[async_test]
2212    async fn test_create_dehydrated_device_key() {
2213        let pickle_key = DehydratedDeviceKey::new()
2214            .expect("Should be able to create a random dehydrated device key");
2215
2216        let to_vec = pickle_key.inner.to_vec();
2217        let pickle_key_from_slice = DehydratedDeviceKey::from_slice(to_vec.as_slice())
2218            .expect("Should be able to create a dehydrated device key from slice");
2219
2220        assert_eq!(pickle_key_from_slice.to_base64(), pickle_key.to_base64());
2221    }
2222
2223    #[async_test]
2224    async fn test_create_dehydrated_errors() {
2225        let too_small = [0u8; 22];
2226        let pickle_key = DehydratedDeviceKey::from_slice(&too_small);
2227
2228        assert!(pickle_key.is_err());
2229
2230        let too_big = [0u8; 40];
2231        let pickle_key = DehydratedDeviceKey::from_slice(&too_big);
2232
2233        assert!(pickle_key.is_err());
2234    }
2235
    /// Snapshot test for `Store::build_room_key_bundle`.
    ///
    /// Bob's store is filled with four inbound group sessions that are created
    /// with Alice's identity keys, and the bundle for the first room is
    /// compared against a stored JSON snapshot.
    #[async_test]
    async fn test_build_room_key_bundle() {
        // Alice only provides the identity keys used to create the sessions;
        // Bob's store is the one under test.
        let alice = OlmMachine::new(user_id!("@a:s.co"), device_id!("ALICE")).await;
        let bob = OlmMachine::new(user_id!("@b:s.co"), device_id!("BOB")).await;

        let room1_id = room_id!("!room1:localhost");
        let room2_id = room_id!("!room2:localhost");

        // Fixed session keys, so that the exported data is the same on every
        // run.
        let session_key1 = "AgAAAAC2XHVzsMBKs4QCRElJ92CJKyGtknCSC8HY7cQ7UYwndMKLQAejXLh5UA0l6s736mgctcUMNvELScUWrObdflrHo+vth/gWreXOaCnaSxmyjjKErQwyIYTkUfqbHy40RJfEesLwnN23on9XAkch/iy8R2+Jz7B8zfG01f2Ow2SxPQFnAndcO1ZSD2GmXgedy6n4B20MWI1jGP2wiexOWbFSya8DO/VxC9m5+/mF+WwYqdpKn9g4Y05Yw4uz7cdjTc3rXm7xK+8E7hI//5QD1nHPvuKYbjjM9u2JSL+Bzp61Cw";
        let session_key2 = "AgAAAAC1BXreFTUQQSBGekTEuYxhdytRKyv4JgDGcG+VOBYdPNGgs807SdibCGJky4lJ3I+7ZDGHoUzZPZP/4ogGu4kxni0PWdtWuN7+5zsuamgoFF/BkaGeUUGv6kgIkx8pyPpM5SASTUEP9bN2loDSpUPYwfiIqz74DgC4WQ4435sTBctYvKz8n+TDJwdLXpyT6zKljuqADAioud+s/iqx9LYn9HpbBfezZcvbg67GtE113pLrvde3IcPI5s6dNHK2onGO2B2eoaobcen18bbEDnlUGPeIivArLya7Da6us14jBQ";
        let session_key3 = "AgAAAAAM9KFsliaUUhGSXgwOzM5UemjkNH4n8NHgvC/y8hhw13zTF+ooGD4uIYEXYX630oNvQm/EvgZo+dkoc0re+vsqsx4sQeNODdSjcBsWOa0oDF+irQn9oYoLUDPI1IBtY1rX+FV99Zm/xnG7uFOX7aTVlko2GSdejy1w9mfobmfxu5aUc04A9zaKJP1pOthZvRAlhpymGYHgsDtWPrrjyc/yypMflE4kIUEEEtu1kT6mrAmcl615XYRAHYK9G2+fZsGvokwzbkl4nulGwcZMpQEoM0nD2o3GWgX81HW3nGfKBg";
        let session_key4 = "AgAAAAA4Kkesxq2h4v9PLD6Sm3Smxspz1PXTqytQPCMQMkkrHNmzV2bHlJ+6/Al9cu8vh1Oj69AK0WUAeJOJuaiskEeg/PI3P03+UYLeC379RzgqwSHdBgdQ41G2vD6zpgmE/8vYToe+qpCZACtPOswZxyqxHH+T/Iq0nv13JmlFGIeA6fEPfr5Y28B49viG74Fs9rxV9EH5PfjbuPM/p+Sz5obShuaBPKQBX1jT913nEXPoIJ06exNZGr0285nw/LgVvNlmWmbqNnbzO2cNZjQWA+xZYz5FSfyCxwqEBbEdUCuRCQ";

        // Room 1 gets two shared-history sessions and one that is not shared;
        // room 2 gets a single shared-history session.
        let sessions = [
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key1).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key2).unwrap(),
                true,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room1_id,
                &SessionKey::from_base64(session_key3).unwrap(),
                false,
            ),
            create_inbound_group_session_with_visibility(
                &alice,
                room2_id,
                &SessionKey::from_base64(session_key4).unwrap(),
                true,
            ),
        ];
        bob.store().save_inbound_group_sessions(&sessions).await.unwrap();

        // Build the bundle for room 1 only.
        let mut bundle = bob.store().build_room_key_bundle(room1_id).await.unwrap();

        // Sort the room keys by session ID, presumably so that the snapshot
        // does not depend on the order in which the store returns sessions.
        bundle.room_keys.sort_by_key(|session| session.session_id.clone());

        // Alice's identity keys are created by `OlmMachine::new` above, so the
        // snapshot redacts them: each redaction first checks that the value is
        // Alice's key and then replaces it with a fixed placeholder.
        let alice_curve_key = alice.identity_keys().curve25519.to_base64();
        let map_alice_curve_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_curve_key);
            "[alice curve key]"
        };
        let alice_ed25519_key = alice.identity_keys().ed25519.to_base64();
        let map_alice_ed25519_key = move |value: Content, _path: ContentPath<'_>| {
            assert_eq!(value.as_str().unwrap(), alice_ed25519_key);
            "[alice ed25519 key]"
        };

        insta::with_settings!({ sort_maps => true }, {
            assert_json_snapshot!(bundle, {
                ".room_keys[].sender_key" => insta::dynamic_redaction(map_alice_curve_key.clone()),
                ".withheld[].sender_key" => insta::dynamic_redaction(map_alice_curve_key),
                ".room_keys[].sender_claimed_keys.ed25519" => insta::dynamic_redaction(map_alice_ed25519_key),
            });
        });
    }
2311
2312    fn create_inbound_group_session_with_visibility(
2317        olm_machine: &OlmMachine,
2318        room_id: &RoomId,
2319        session_key: &SessionKey,
2320        shared_history: bool,
2321    ) -> InboundGroupSession {
2322        let identity_keys = &olm_machine.store().static_account().identity_keys;
2323        InboundGroupSession::new(
2324            identity_keys.curve25519,
2325            identity_keys.ed25519,
2326            room_id,
2327            session_key,
2328            SenderData::unknown(),
2329            EventEncryptionAlgorithm::MegolmV1AesSha2,
2330            None,
2331            shared_history,
2332        )
2333        .unwrap()
2334    }
2335}