matrix_sdk_sled/
crypto_store.rs

1// Copyright 2020 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    borrow::Cow,
17    collections::{HashMap, HashSet},
18    path::{Path, PathBuf},
19    sync::{Arc, RwLock},
20};
21
22use async_trait::async_trait;
23use dashmap::DashSet;
24use matrix_sdk_common::locks::Mutex;
25use matrix_sdk_crypto::{
26    olm::{
27        IdentityKeys, InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
28        PrivateCrossSigningIdentity, Session,
29    },
30    store::{
31        caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, Result,
32        RoomKeyCounts,
33    },
34    types::{events::room_key_request::SupportedKeyInfo, EventEncryptionAlgorithm},
35    GossipRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo,
36};
37use matrix_sdk_store_encryption::StoreCipher;
38use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
39use serde::{de::DeserializeOwned, Deserialize, Serialize};
40pub use sled::Error;
41use sled::{
42    transaction::{ConflictableTransactionError, TransactionError},
43    Batch, Config, Db, IVec, Transactional, Tree,
44};
45use tracing::debug;
46
47use super::OpenStoreError;
48use crate::encode_key::{EncodeKey, ENCODE_SEPARATOR};
49
/// Version of the on-disk layout of this store. It is persisted under the
/// `store_version` key of the database and compared against the stored value
/// in `SledCryptoStore::upgrade` whenever a store is opened.
const DATABASE_VERSION: u8 = 5;
51
// Table names that are used to derive a separate key for each tree. This ensures
// that user ids encoded for different trees won't end up as the same byte
// sequence. This prevents correlation attacks on our tree metadata.
// Used when encoding keys for the `devices` tree.
const DEVICE_TABLE_NAME: &str = "crypto-store-devices";
// Used when encoding keys for the `identities` tree.
const IDENTITIES_TABLE_NAME: &str = "crypto-store-identities";
// Used when encoding keys for the `sessions` tree.
const SESSIONS_TABLE_NAME: &str = "crypto-store-sessions";
// Used when encoding keys for the `inbound_group_sessions` tree.
const INBOUND_GROUP_TABLE_NAME: &str = "crypto-store-inbound-group-sessions";
// Used when encoding keys for the `outbound_group_sessions` tree.
const OUTBOUND_GROUP_TABLE_NAME: &str = "crypto-store-outbound-group-sessions";
// Used when encoding keys for the `secret_requests_by_info` tree.
const SECRET_REQUEST_BY_INFO_TABLE: &str = "crypto-store-secret-request-by-info";
// Used when encoding keys for the `tracked_users` tree.
const TRACKED_USERS_TABLE: &str = "crypto-store-secret-tracked-users";
62
63impl EncodeKey for InboundGroupSession {
64    fn encode(&self) -> Vec<u8> {
65        (self.room_id(), self.sender_key().to_base64(), self.session_id()).encode()
66    }
67
68    fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
69        (self.room_id(), self.sender_key().to_base64(), self.session_id())
70            .encode_secure(table_name, store_cipher)
71    }
72}
73
74impl EncodeKey for OutboundGroupSession {
75    fn encode(&self) -> Vec<u8> {
76        self.room_id().encode()
77    }
78    fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
79        self.room_id().encode_secure(table_name, store_cipher)
80    }
81}
82
83impl EncodeKey for Session {
84    fn encode(&self) -> Vec<u8> {
85        let sender_key = self.sender_key().to_base64();
86        let session_id = self.session_id();
87
88        [sender_key.as_bytes(), &[ENCODE_SEPARATOR], session_id.as_bytes(), &[ENCODE_SEPARATOR]]
89            .concat()
90    }
91
92    fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
93        let sender_key =
94            store_cipher.hash_key(table_name, self.sender_key().to_base64().as_bytes());
95        let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes());
96
97        [sender_key.as_slice(), &[ENCODE_SEPARATOR], session_id.as_slice(), &[ENCODE_SEPARATOR]]
98            .concat()
99    }
100}
101
// A `SecretInfo` has no key material of its own; both encodings are delegated
// to the value wrapped by the active variant.
impl EncodeKey for SecretInfo {
    fn encode(&self) -> Vec<u8> {
        match self {
            SecretInfo::KeyRequest(k) => k.encode(),
            SecretInfo::SecretRequest(s) => s.encode(),
        }
    }
    fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
        match self {
            SecretInfo::KeyRequest(k) => k.encode_secure(table_name, store_cipher),
            SecretInfo::SecretRequest(s) => s.encode_secure(table_name, store_cipher),
        }
    }
}
116
117impl EncodeKey for EventEncryptionAlgorithm {
118    fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
119        let s: &str = self.as_ref();
120        s.as_bytes().into()
121    }
122}
123
124impl EncodeKey for SupportedKeyInfo {
125    fn encode(&self) -> Vec<u8> {
126        (self.room_id(), &self.algorithm(), self.session_id()).encode()
127    }
128    fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
129        let room_id = store_cipher.hash_key(table_name, self.room_id().as_bytes());
130        let algorithm = store_cipher.hash_key(table_name, self.algorithm().as_ref().as_bytes());
131        let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes());
132
133        [
134            room_id.as_slice(),
135            &[ENCODE_SEPARATOR],
136            algorithm.as_slice(),
137            &[ENCODE_SEPARATOR],
138            session_id.as_slice(),
139            &[ENCODE_SEPARATOR],
140        ]
141        .concat()
142    }
143}
144
145impl EncodeKey for ReadOnlyDevice {
146    fn encode(&self) -> Vec<u8> {
147        (self.user_id(), self.device_id()).encode()
148    }
149    fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
150        (self.user_id(), self.device_id()).encode_secure(table_name, store_cipher)
151    }
152}
153
/// Identifying details of the account that owns the store.
///
/// The store fills this in from a `ReadOnlyAccount` when an account is loaded
/// or saved, and uses it later to restore Olm sessions and outbound group
/// sessions from their pickled form.
#[derive(Clone, Debug)]
pub struct AccountInfo {
    /// The user ID of the account.
    user_id: Arc<UserId>,
    /// The device ID of the account.
    device_id: Arc<DeviceId>,
    /// The identity keys of the account.
    identity_keys: Arc<IdentityKeys>,
}
160
/// The record that is serialized into the `tracked_users` tree for every
/// tracked user.
#[derive(Debug, Serialize, Deserialize)]
struct TrackedUser {
    /// The user that is being tracked.
    user_id: OwnedUserId,
    /// Whether the user is dirty. Dirty users are put into the
    /// `users_for_key_query_cache` when the tracked users are loaded.
    dirty: bool,
}
166
/// A [sled] based cryptostore.
///
/// [sled]: https://github.com/spacejam/sled#readme
#[derive(Clone)]
pub struct SledCryptoStore {
    /// Details of the account that was last loaded or saved, `None` until
    /// that happens.
    account_info: Arc<RwLock<Option<AccountInfo>>>,
    /// If set, values are encrypted and keys are hashed with this cipher,
    /// otherwise values are stored as plain JSON and keys are stored as is.
    store_cipher: Option<Arc<StoreCipher>>,
    /// The path the database was opened at, `None` if the store was created
    /// from an existing `Db`. Only used for the `Debug` output.
    path: Option<PathBuf>,
    /// The sled database itself, holds the `store_version` entry and is used
    /// to flush writes.
    inner: Db,

    /// In-memory cache of the Olm sessions, grouped by sender key.
    session_cache: SessionStore,
    /// In-memory copy of all the tracked users.
    tracked_users_cache: Arc<DashSet<OwnedUserId>>,
    /// In-memory set of the tracked users that are marked as dirty.
    users_for_key_query_cache: Arc<DashSet<OwnedUserId>>,

    /// Holds the pickled account, the recovery key and the backup version.
    account: Tree,
    /// Holds the pickled private cross signing identity.
    private_identity: Tree,

    /// Holds the Olm message hashes.
    olm_hashes: Tree,
    /// Holds the pickled Olm sessions.
    sessions: Tree,
    /// Holds the pickled inbound group sessions.
    inbound_group_sessions: Tree,
    /// Holds the pickled outbound group sessions.
    outbound_group_sessions: Tree,

    /// Holds the secret requests that have been sent out, keyed by request ID.
    outgoing_secret_requests: Tree,
    /// Holds the secret requests that have not been sent out yet, keyed by
    /// request ID.
    unsent_secret_requests: Tree,
    /// Maps the encoded `SecretInfo` of a request to its request ID.
    secret_requests_by_info: Tree,

    /// Holds the devices of users.
    devices: Tree,
    /// Holds the user identities.
    identities: Tree,

    /// Holds the serialized `TrackedUser` records.
    tracked_users: Tree,
}
198
199impl std::fmt::Debug for SledCryptoStore {
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        if let Some(path) = &self.path {
202            f.debug_struct("SledCryptoStore").field("path", &path).finish()
203        } else {
204            f.debug_struct("SledCryptoStore").field("path", &"memory store").finish()
205        }
206    }
207}
208
209impl SledCryptoStore {
210    /// Open the sled-based crypto store at the given path using the given
211    /// passphrase to encrypt private data.
212    pub fn open_with_passphrase(
213        path: impl AsRef<Path>,
214        passphrase: Option<&str>,
215    ) -> Result<Self, OpenStoreError> {
216        let path = path.as_ref().join("matrix-sdk-crypto");
217        let db =
218            Config::new().temporary(false).path(&path).open().map_err(CryptoStoreError::backend)?;
219
220        let store_cipher = passphrase
221            .map(|p| Self::get_or_create_store_cipher(p, &db))
222            .transpose()?
223            .map(Into::into);
224
225        SledCryptoStore::open_helper(db, Some(path), store_cipher)
226    }
227
228    /// Create a sled-based crypto store using the given sled database.
229    /// The given passphrase will be used to encrypt private data.
230    pub fn open_with_database(db: Db, passphrase: Option<&str>) -> Result<Self, OpenStoreError> {
231        let store_cipher = passphrase
232            .map(|p| Self::get_or_create_store_cipher(p, &db))
233            .transpose()?
234            .map(Into::into);
235
236        SledCryptoStore::open_helper(db, None, store_cipher)
237    }
238
239    fn get_account_info(&self) -> Option<AccountInfo> {
240        self.account_info.read().unwrap().clone()
241    }
242
243    fn serialize_value(&self, event: &impl Serialize) -> Result<Vec<u8>, CryptoStoreError> {
244        if let Some(key) = &self.store_cipher {
245            key.encrypt_value(event).map_err(CryptoStoreError::backend)
246        } else {
247            Ok(serde_json::to_vec(event)?)
248        }
249    }
250
251    fn deserialize_value<T: DeserializeOwned>(&self, event: &[u8]) -> Result<T, CryptoStoreError> {
252        if let Some(key) = &self.store_cipher {
253            key.decrypt_value(event).map_err(CryptoStoreError::backend)
254        } else {
255            Ok(serde_json::from_slice(event)?)
256        }
257    }
258
259    fn encode_key<T: EncodeKey>(&self, table_name: &str, key: T) -> Vec<u8> {
260        if let Some(store_cipher) = &self.store_cipher {
261            key.encode_secure(table_name, store_cipher).to_vec()
262        } else {
263            key.encode()
264        }
265    }
266
267    async fn reset_backup_state(&self) -> Result<()> {
268        let mut pickles: Vec<(IVec, PickledInboundGroupSession)> = self
269            .inbound_group_sessions
270            .iter()
271            .map(|p| {
272                let item = p.map_err(CryptoStoreError::backend)?;
273                Ok((item.0, self.deserialize_value(&item.1)?))
274            })
275            .collect::<Result<_>>()?;
276
277        for (_, pickle) in &mut pickles {
278            pickle.backed_up = false;
279        }
280
281        let ret: Result<(), TransactionError<CryptoStoreError>> =
282            self.inbound_group_sessions.transaction(|inbound_sessions| {
283                for (key, pickle) in &pickles {
284                    inbound_sessions.insert(
285                        key,
286                        self.serialize_value(pickle)
287                            .map_err(ConflictableTransactionError::Abort)?,
288                    )?;
289                }
290
291                Ok(())
292            });
293
294        ret.map_err(CryptoStoreError::backend)?;
295
296        self.inner.flush_async().await.map_err(CryptoStoreError::backend)?;
297
298        Ok(())
299    }
300
301    fn upgrade(&self) -> Result<()> {
302        let version = self
303            .inner
304            .get("store_version")
305            .map_err(CryptoStoreError::backend)?
306            .map(|v| {
307                let (version_bytes, _) = v.split_at(std::mem::size_of::<u8>());
308                u8::from_be_bytes(version_bytes.try_into().unwrap_or_default())
309            })
310            .unwrap_or(DATABASE_VERSION);
311
312        if version != DATABASE_VERSION {
313            debug!(version, new_version = DATABASE_VERSION, "Upgrading the Sled crypto store");
314        }
315
316        if version <= 3 {
317            return Err(CryptoStoreError::UnsupportedDatabaseVersion(
318                version.into(),
319                DATABASE_VERSION.into(),
320            ));
321        }
322
323        if version <= 4 {
324            // Room key requests are not that important, if they are needed they
325            // will be sent out again. So let's drop all of them since we
326            // removed the `sender_key` from the hash key.
327            self.outgoing_secret_requests.clear().map_err(CryptoStoreError::backend)?;
328            self.unsent_secret_requests.clear().map_err(CryptoStoreError::backend)?;
329            self.secret_requests_by_info.clear().map_err(CryptoStoreError::backend)?;
330        }
331
332        self.inner
333            .insert("store_version", DATABASE_VERSION.to_be_bytes().as_ref())
334            .map_err(CryptoStoreError::backend)?;
335        self.inner.flush().map_err(CryptoStoreError::backend)?;
336
337        Ok(())
338    }
339
340    fn get_or_create_store_cipher(passphrase: &str, database: &Db) -> Result<StoreCipher> {
341        let cipher = if let Some(key) =
342            database.get("store_cipher".encode()).map_err(CryptoStoreError::backend)?
343        {
344            StoreCipher::import(passphrase, &key).map_err(|_| CryptoStoreError::UnpicklingError)?
345        } else {
346            let cipher = StoreCipher::new().map_err(CryptoStoreError::backend)?;
347            #[cfg(not(test))]
348            let export = cipher.export(passphrase);
349            #[cfg(test)]
350            let export = cipher._insecure_export_fast_for_testing(passphrase);
351            database
352                .insert("store_cipher".encode(), export.map_err(CryptoStoreError::backend)?)
353                .map_err(CryptoStoreError::backend)?;
354            cipher
355        };
356
357        Ok(cipher)
358    }
359
360    pub(crate) fn open_helper(
361        db: Db,
362        path: Option<PathBuf>,
363        store_cipher: Option<Arc<StoreCipher>>,
364    ) -> Result<Self, OpenStoreError> {
365        let account = db.open_tree("account")?;
366        let private_identity = db.open_tree("private_identity")?;
367
368        let sessions = db.open_tree("session")?;
369        let inbound_group_sessions = db.open_tree("inbound_group_sessions")?;
370
371        let outbound_group_sessions = db.open_tree("outbound_group_sessions")?;
372
373        let tracked_users = db.open_tree("tracked_users")?;
374        let olm_hashes = db.open_tree("olm_hashes")?;
375
376        let devices = db.open_tree("devices")?;
377        let identities = db.open_tree("identities")?;
378
379        let outgoing_secret_requests = db.open_tree("outgoing_secret_requests")?;
380        let unsent_secret_requests = db.open_tree("unsent_secret_requests")?;
381        let secret_requests_by_info = db.open_tree("secret_requests_by_info")?;
382
383        let session_cache = SessionStore::new();
384
385        let database = Self {
386            account_info: RwLock::new(None).into(),
387            path,
388            inner: db,
389            store_cipher,
390            account,
391            private_identity,
392            sessions,
393            session_cache,
394            tracked_users_cache: DashSet::new().into(),
395            users_for_key_query_cache: DashSet::new().into(),
396            inbound_group_sessions,
397            outbound_group_sessions,
398            outgoing_secret_requests,
399            unsent_secret_requests,
400            secret_requests_by_info,
401            devices,
402            tracked_users,
403            olm_hashes,
404            identities,
405        };
406
407        database.upgrade()?;
408
409        Ok(database)
410    }
411
412    async fn load_tracked_users(&self) -> Result<()> {
413        for value in &self.tracked_users {
414            let (_, user) = value.map_err(CryptoStoreError::backend)?;
415            let user: TrackedUser = self.deserialize_value(&user)?;
416
417            self.tracked_users_cache.insert(user.user_id.to_owned());
418
419            if user.dirty {
420                self.users_for_key_query_cache.insert(user.user_id);
421            }
422        }
423
424        Ok(())
425    }
426
427    async fn load_outbound_group_session(
428        &self,
429        room_id: &RoomId,
430    ) -> Result<Option<OutboundGroupSession>> {
431        let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
432
433        self.outbound_group_sessions
434            .get(self.encode_key(OUTBOUND_GROUP_TABLE_NAME, room_id))
435            .map_err(CryptoStoreError::backend)?
436            .map(|p| self.deserialize_value(&p))
437            .transpose()?
438            .map(|p| {
439                Ok(OutboundGroupSession::from_pickle(
440                    account_info.device_id,
441                    account_info.identity_keys,
442                    p,
443                )?)
444            })
445            .transpose()
446    }
447
    /// Persist the given set of changes.
    ///
    /// The method works in two phases:
    ///
    /// 1. Everything that needs to be awaited is prepared up front: the
    ///    account, the private identity and the sessions are pickled and the
    ///    keys of the sessions are encoded.
    /// 2. All the prepared data is written in a single transaction spanning
    ///    eleven trees, after which the database is flushed.
    ///
    /// Note that the cached account info and the session cache are updated
    /// during the first phase, i.e. before the transaction has run.
    async fn save_changes(&self, changes: Changes) -> Result<()> {
        let account_pickle = if let Some(account) = changes.account {
            let account_info = AccountInfo {
                user_id: account.user_id.clone(),
                device_id: account.device_id.clone(),
                identity_keys: account.identity_keys.clone(),
            };

            // Remember the account details, they are needed to restore
            // sessions from their pickles later on.
            *self.account_info.write().unwrap() = Some(account_info);
            Some(account.pickle().await)
        } else {
            None
        };

        let private_identity_pickle =
            if let Some(i) = changes.private_identity { Some(i.pickle().await?) } else { None };

        let recovery_key_pickle = changes.recovery_key;

        let device_changes = changes.devices;
        let mut session_changes = HashMap::new();

        for session in changes.sessions {
            let pickle = session.pickle().await;
            let key = self.encode_key(SESSIONS_TABLE_NAME, &session);

            // The session is moved into the cache, only its pickle is kept
            // around for the transaction.
            self.session_cache.add(session).await;
            session_changes.insert(key, pickle);
        }

        let mut inbound_session_changes = HashMap::new();

        for session in changes.inbound_group_sessions {
            let key = self.encode_key(INBOUND_GROUP_TABLE_NAME, &session);
            let pickle = session.pickle().await;

            inbound_session_changes.insert(key, pickle);
        }

        let mut outbound_session_changes = HashMap::new();

        for session in changes.outbound_group_sessions {
            let key = self.encode_key(OUTBOUND_GROUP_TABLE_NAME, &session);
            let pickle = session.pickle().await;

            outbound_session_changes.insert(key, pickle);
        }

        let identity_changes = changes.identities;
        let olm_hashes = changes.message_hashes;
        let key_requests = changes.key_requests;
        let backup_version = changes.backup_version;

        // The order of the trees in this tuple needs to match the order of the
        // arguments of the closure below.
        let ret: Result<(), TransactionError<CryptoStoreError>> = (
            &self.account,
            &self.private_identity,
            &self.devices,
            &self.identities,
            &self.sessions,
            &self.inbound_group_sessions,
            &self.outbound_group_sessions,
            &self.olm_hashes,
            &self.outgoing_secret_requests,
            &self.unsent_secret_requests,
            &self.secret_requests_by_info,
        )
            .transaction(
                |(
                    account,
                    private_identity,
                    devices,
                    identities,
                    sessions,
                    inbound_sessions,
                    outbound_sessions,
                    hashes,
                    outgoing_secret_requests,
                    unsent_secret_requests,
                    secret_requests_by_info,
                )| {
                    // The closure only borrows the data that was prepared
                    // above. A failure to serialize a value aborts the whole
                    // transaction.
                    if let Some(a) = &account_pickle {
                        account.insert(
                            "account".encode(),
                            self.serialize_value(a).map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    if let Some(i) = &private_identity_pickle {
                        private_identity.insert(
                            "identity".encode(),
                            self.serialize_value(&i)
                                .map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    // The recovery key and the backup version live in the
                    // account tree as well, under their own keys.
                    if let Some(r) = &recovery_key_pickle {
                        account.insert(
                            "recovery_key_v1".encode(),
                            self.serialize_value(r).map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    if let Some(b) = &backup_version {
                        account.insert(
                            "backup_version_v1".encode(),
                            self.serialize_value(b).map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    for device in device_changes.new.iter().chain(&device_changes.changed) {
                        let key = self.encode_key(DEVICE_TABLE_NAME, device);
                        let device = self
                            .serialize_value(&device)
                            .map_err(ConflictableTransactionError::Abort)?;
                        devices.insert(key, device)?;
                    }

                    for device in &device_changes.deleted {
                        let key = self.encode_key(DEVICE_TABLE_NAME, device);
                        devices.remove(key)?;
                    }

                    for identity in identity_changes.changed.iter().chain(&identity_changes.new) {
                        identities.insert(
                            self.encode_key(IDENTITIES_TABLE_NAME, identity.user_id()),
                            self.serialize_value(&identity)
                                .map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    for (key, session) in &session_changes {
                        sessions.insert(
                            key.as_slice(),
                            self.serialize_value(&session)
                                .map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    for (key, session) in &inbound_session_changes {
                        inbound_sessions.insert(
                            key.as_slice(),
                            self.serialize_value(&session)
                                .map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    for (key, session) in &outbound_session_changes {
                        outbound_sessions.insert(
                            key.as_slice(),
                            self.serialize_value(&session)
                                .map_err(ConflictableTransactionError::Abort)?,
                        )?;
                    }

                    // The hash itself is the key here, serialized as plain
                    // JSON without going through `encode_key` or
                    // `serialize_value`. The value is just a placeholder.
                    for hash in &olm_hashes {
                        hashes.insert(
                            serde_json::to_vec(hash)
                                .map_err(CryptoStoreError::Serialization)
                                .map_err(ConflictableTransactionError::Abort)?,
                            &[0],
                        )?;
                    }

                    for key_request in &key_requests {
                        // Index the request by its info so the request ID can
                        // be found using the `SecretInfo`.
                        secret_requests_by_info.insert(
                            self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, &key_request.info),
                            key_request.request_id.encode(),
                        )?;

                        let key_request_id = key_request.request_id.encode();

                        // A request lives in exactly one of the two trees,
                        // depending on whether it has been sent out. It's
                        // removed from the other tree in case its state
                        // changed.
                        if key_request.sent_out {
                            unsent_secret_requests.remove(key_request_id.clone())?;
                            outgoing_secret_requests.insert(
                                key_request_id,
                                self.serialize_value(&key_request)
                                    .map_err(ConflictableTransactionError::Abort)?,
                            )?;
                        } else {
                            outgoing_secret_requests.remove(key_request_id.clone())?;
                            unsent_secret_requests.insert(
                                key_request_id,
                                self.serialize_value(&key_request)
                                    .map_err(ConflictableTransactionError::Abort)?,
                            )?;
                        }
                    }

                    Ok(())
                },
            );

        ret.map_err(CryptoStoreError::backend)?;
        // NOTE(review): this is the synchronous `flush()` inside of an async
        // method, while `reset_backup_state` uses `flush_async()` -- confirm
        // that blocking here is intended.
        self.inner.flush().map_err(CryptoStoreError::backend)?;

        Ok(())
    }
645
646    async fn get_outgoing_key_request_helper(&self, id: &[u8]) -> Result<Option<GossipRequest>> {
647        let request = self
648            .outgoing_secret_requests
649            .get(id)
650            .map_err(CryptoStoreError::backend)?
651            .map(|r| self.deserialize_value(&r))
652            .transpose()?;
653
654        let request = if request.is_none() {
655            self.unsent_secret_requests
656                .get(id)
657                .map_err(CryptoStoreError::backend)?
658                .map(|r| self.deserialize_value(&r))
659                .transpose()?
660        } else {
661            request
662        };
663
664        Ok(request)
665    }
666
667    /// Save a batch of tracked users.
668    ///
669    /// # Arguments
670    ///
671    /// * `tracked_users` - A list of tuples. The first element of the tuple is
672    /// the user ID, the second element is if the user should be considered to
673    /// be dirty.
674    pub async fn save_tracked_users(
675        &self,
676        tracked_users: &[(&UserId, bool)],
677    ) -> Result<(), CryptoStoreError> {
678        let users: Vec<TrackedUser> = tracked_users
679            .iter()
680            .map(|(u, d)| TrackedUser { user_id: (*u).into(), dirty: *d })
681            .collect();
682
683        let mut batch = Batch::default();
684
685        for user in users {
686            batch.insert(
687                self.encode_key(TRACKED_USERS_TABLE, user.user_id.as_str()),
688                self.serialize_value(&user)?,
689            );
690        }
691
692        self.tracked_users.apply_batch(batch).map_err(CryptoStoreError::backend)
693    }
694}
695
696#[async_trait]
697impl CryptoStore for SledCryptoStore {
698    async fn load_account(&self) -> Result<Option<ReadOnlyAccount>> {
699        if let Some(pickle) =
700            self.account.get("account".encode()).map_err(CryptoStoreError::backend)?
701        {
702            let pickle = self.deserialize_value(&pickle)?;
703
704            self.load_tracked_users().await?;
705            let account = ReadOnlyAccount::from_pickle(pickle)?;
706
707            let account_info = AccountInfo {
708                user_id: account.user_id.clone(),
709                device_id: account.device_id.clone(),
710                identity_keys: account.identity_keys.clone(),
711            };
712
713            *self.account_info.write().unwrap() = Some(account_info);
714
715            Ok(Some(account))
716        } else {
717            Ok(None)
718        }
719    }
720
721    async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
722        self.save_changes(Changes { account: Some(account), ..Default::default() }).await
723    }
724
725    async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
726        if let Some(i) =
727            self.private_identity.get("identity".encode()).map_err(CryptoStoreError::backend)?
728        {
729            let pickle = self.deserialize_value(&i)?;
730            Ok(Some(
731                PrivateCrossSigningIdentity::from_pickle(pickle)
732                    .await
733                    .map_err(|_| CryptoStoreError::UnpicklingError)?,
734            ))
735        } else {
736            Ok(None)
737        }
738    }
739
    // Forwards to the inherent `SledCryptoStore::save_changes()` method.
    // Inherent methods take precedence over trait methods when a method call
    // is resolved, so this doesn't call itself.
    async fn save_changes(&self, changes: Changes) -> Result<()> {
        self.save_changes(changes).await
    }
743
744    async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
745        let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
746
747        if self.session_cache.get(sender_key).is_none() {
748            let sessions: Result<Vec<Session>> = self
749                .sessions
750                .scan_prefix(self.encode_key(SESSIONS_TABLE_NAME, sender_key))
751                .map(|s| self.deserialize_value(&s.map_err(CryptoStoreError::backend)?.1))
752                .map(|p| {
753                    Ok(Session::from_pickle(
754                        account_info.user_id.clone(),
755                        account_info.device_id.clone(),
756                        account_info.identity_keys.clone(),
757                        p?,
758                    ))
759                })
760                .collect();
761
762            self.session_cache.set_for_sender(sender_key, sessions?);
763        }
764
765        Ok(self.session_cache.get(sender_key))
766    }
767
768    async fn get_inbound_group_session(
769        &self,
770        room_id: &RoomId,
771        sender_key: &str,
772        session_id: &str,
773    ) -> Result<Option<InboundGroupSession>> {
774        let key = self.encode_key(INBOUND_GROUP_TABLE_NAME, (room_id, sender_key, session_id));
775        let pickle = self
776            .inbound_group_sessions
777            .get(&key)
778            .map_err(CryptoStoreError::backend)?
779            .map(|p| self.deserialize_value(&p));
780
781        if let Some(pickle) = pickle {
782            Ok(Some(InboundGroupSession::from_pickle(pickle?)?))
783        } else {
784            Ok(None)
785        }
786    }
787
788    async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
789        let pickles: Result<Vec<PickledInboundGroupSession>> = self
790            .inbound_group_sessions
791            .iter()
792            .map(|p| self.deserialize_value(&p.map_err(CryptoStoreError::backend)?.1))
793            .collect();
794
795        Ok(pickles?.into_iter().filter_map(|p| InboundGroupSession::from_pickle(p).ok()).collect())
796    }
797
798    async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
799        let pickles: Vec<PickledInboundGroupSession> = self
800            .inbound_group_sessions
801            .iter()
802            .map(|p| {
803                let item = p.map_err(CryptoStoreError::backend)?;
804                self.deserialize_value(&item.1)
805            })
806            .collect::<Result<_>>()?;
807
808        let total = pickles.len();
809        let backed_up = pickles.into_iter().filter(|p| p.backed_up).count();
810
811        Ok(RoomKeyCounts { total, backed_up })
812    }
813
814    async fn inbound_group_sessions_for_backup(
815        &self,
816        limit: usize,
817    ) -> Result<Vec<InboundGroupSession>> {
818        let pickles: Vec<InboundGroupSession> = self
819            .inbound_group_sessions
820            .iter()
821            .map(|p| {
822                let item = p.map_err(CryptoStoreError::backend)?;
823                self.deserialize_value(&item.1)
824            })
825            .filter_map(|p: Result<PickledInboundGroupSession, CryptoStoreError>| match p {
826                Ok(p) => {
827                    if !p.backed_up {
828                        Some(InboundGroupSession::from_pickle(p).map_err(CryptoStoreError::from))
829                    } else {
830                        None
831                    }
832                }
833
834                Err(p) => Some(Err(p)),
835            })
836            .take(limit)
837            .collect::<Result<_>>()?;
838
839        Ok(pickles)
840    }
841
842    async fn reset_backup_state(&self) -> Result<()> {
843        self.reset_backup_state().await
844    }
845
846    async fn get_outbound_group_sessions(
847        &self,
848        room_id: &RoomId,
849    ) -> Result<Option<OutboundGroupSession>> {
850        self.load_outbound_group_session(room_id).await
851    }
852
853    fn is_user_tracked(&self, user_id: &UserId) -> bool {
854        self.tracked_users_cache.contains(user_id)
855    }
856
857    fn has_users_for_key_query(&self) -> bool {
858        !self.users_for_key_query_cache.is_empty()
859    }
860
861    fn users_for_key_query(&self) -> HashSet<OwnedUserId> {
862        self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
863    }
864
865    fn tracked_users(&self) -> HashSet<OwnedUserId> {
866        self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()
867    }
868
869    async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
870        let already_added = self.tracked_users_cache.insert(user.to_owned());
871
872        if dirty {
873            self.users_for_key_query_cache.insert(user.to_owned());
874        } else {
875            self.users_for_key_query_cache.remove(user);
876        }
877
878        self.save_tracked_users(&[(user, dirty)]).await?;
879
880        Ok(already_added)
881    }
882
883    async fn get_device(
884        &self,
885        user_id: &UserId,
886        device_id: &DeviceId,
887    ) -> Result<Option<ReadOnlyDevice>> {
888        let key = self.encode_key(DEVICE_TABLE_NAME, (user_id, device_id));
889
890        Ok(self
891            .devices
892            .get(key)
893            .map_err(CryptoStoreError::backend)?
894            .map(|d| self.deserialize_value(&d))
895            .transpose()?)
896    }
897
898    async fn get_user_devices(
899        &self,
900        user_id: &UserId,
901    ) -> Result<HashMap<OwnedDeviceId, ReadOnlyDevice>> {
902        let key = self.encode_key(DEVICE_TABLE_NAME, user_id);
903        self.devices
904            .scan_prefix(key)
905            .map(|d| self.deserialize_value(&d.map_err(CryptoStoreError::backend)?.1))
906            .map(|d| {
907                let d: ReadOnlyDevice = d?;
908                Ok((d.device_id().to_owned(), d))
909            })
910            .collect()
911    }
912
913    async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<ReadOnlyUserIdentities>> {
914        let key = self.encode_key(IDENTITIES_TABLE_NAME, user_id);
915
916        Ok(self
917            .identities
918            .get(key)
919            .map_err(CryptoStoreError::backend)?
920            .map(|i| self.deserialize_value(&i))
921            .transpose()?)
922    }
923
924    async fn is_message_known(
925        &self,
926        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
927    ) -> Result<bool> {
928        Ok(self
929            .olm_hashes
930            .contains_key(serde_json::to_vec(message_hash)?)
931            .map_err(CryptoStoreError::backend)?)
932    }
933
934    async fn get_outgoing_secret_requests(
935        &self,
936        request_id: &TransactionId,
937    ) -> Result<Option<GossipRequest>> {
938        let request_id = request_id.encode();
939
940        self.get_outgoing_key_request_helper(&request_id).await
941    }
942
943    async fn get_secret_request_by_info(
944        &self,
945        key_info: &SecretInfo,
946    ) -> Result<Option<GossipRequest>> {
947        let id = self
948            .secret_requests_by_info
949            .get(self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, key_info))
950            .map_err(CryptoStoreError::backend)?;
951
952        if let Some(id) = id {
953            self.get_outgoing_key_request_helper(&id).await
954        } else {
955            Ok(None)
956        }
957    }
958
959    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
960        let requests: Result<Vec<GossipRequest>> = self
961            .unsent_secret_requests
962            .iter()
963            .map(|i| self.deserialize_value(&i.map_err(CryptoStoreError::backend)?.1))
964            .collect();
965
966        requests
967    }
968
969    async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
970        let ret: Result<(), TransactionError<CryptoStoreError>> = (
971            &self.outgoing_secret_requests,
972            &self.unsent_secret_requests,
973            &self.secret_requests_by_info,
974        )
975            .transaction(
976                |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
977                    let sent_request: Option<GossipRequest> = outgoing_key_requests
978                        .remove(request_id.encode())?
979                        .map(|r| self.deserialize_value(&r))
980                        .transpose()
981                        .map_err(ConflictableTransactionError::Abort)?;
982
983                    let unsent_request: Option<GossipRequest> = unsent_key_requests
984                        .remove(request_id.encode())?
985                        .map(|r| self.deserialize_value(&r))
986                        .transpose()
987                        .map_err(ConflictableTransactionError::Abort)?;
988
989                    if let Some(request) = sent_request {
990                        key_requests_by_info
991                            .remove(self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, &request.info))?;
992                    }
993
994                    if let Some(request) = unsent_request {
995                        key_requests_by_info
996                            .remove(self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, &request.info))?;
997                    }
998
999                    Ok(())
1000                },
1001            );
1002
1003        ret.map_err(CryptoStoreError::backend)?;
1004        self.inner.flush_async().await.map_err(CryptoStoreError::backend)?;
1005
1006        Ok(())
1007    }
1008
1009    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1010        let key = {
1011            let backup_version = self
1012                .account
1013                .get("backup_version_v1".encode())
1014                .map_err(CryptoStoreError::backend)?
1015                .map(|v| self.deserialize_value(&v))
1016                .transpose()?;
1017
1018            let recovery_key = {
1019                self.account
1020                    .get("recovery_key_v1".encode())
1021                    .map_err(CryptoStoreError::backend)?
1022                    .map(|p| self.deserialize_value(&p))
1023                    .transpose()?
1024            };
1025
1026            BackupKeys { backup_version, recovery_key }
1027        };
1028
1029        Ok(key)
1030    }
1031}
1032
#[cfg(test)]
mod tests {
    use matrix_sdk_crypto::cryptostore_integration_tests;
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};

    use super::SledCryptoStore;

    // Temporary directory shared by all stores opened in this module; each
    // store lives in its own sub-directory.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Open a store in the `name` sub-directory of `TMP_DIR`, forwarding the
    /// optional `passphrase` unchanged.
    async fn get_store(name: &str, passphrase: Option<&str>) -> SledCryptoStore {
        let store_dir = TMP_DIR.path().join(name);
        let store_dir = store_dir.to_str().unwrap();

        SledCryptoStore::open_with_passphrase(store_dir, passphrase)
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
}
1055
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::cryptostore_integration_tests;
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};

    use super::SledCryptoStore;

    // Temporary directory shared by all stores opened in this module; each
    // store lives in its own sub-directory.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Open a store in the `name` sub-directory of `TMP_DIR` that is always
    /// passphrase protected: a fixed test password is used when the caller
    /// supplies none.
    async fn get_store(name: &str, passphrase: Option<&str>) -> SledCryptoStore {
        let store_dir = TMP_DIR.path().join(name);
        let store_dir = store_dir.to_str().unwrap();
        let effective_passphrase = passphrase.unwrap_or("default_test_password");

        SledCryptoStore::open_with_passphrase(store_dir, Some(effective_passphrase))
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
}