1use std::{
16 borrow::Cow,
17 collections::{HashMap, HashSet},
18 path::{Path, PathBuf},
19 sync::{Arc, RwLock},
20};
21
22use async_trait::async_trait;
23use dashmap::DashSet;
24use matrix_sdk_common::locks::Mutex;
25use matrix_sdk_crypto::{
26 olm::{
27 IdentityKeys, InboundGroupSession, OutboundGroupSession, PickledInboundGroupSession,
28 PrivateCrossSigningIdentity, Session,
29 },
30 store::{
31 caches::SessionStore, BackupKeys, Changes, CryptoStore, CryptoStoreError, Result,
32 RoomKeyCounts,
33 },
34 types::{events::room_key_request::SupportedKeyInfo, EventEncryptionAlgorithm},
35 GossipRequest, ReadOnlyAccount, ReadOnlyDevice, ReadOnlyUserIdentities, SecretInfo,
36};
37use matrix_sdk_store_encryption::StoreCipher;
38use ruma::{DeviceId, OwnedDeviceId, OwnedUserId, RoomId, TransactionId, UserId};
39use serde::{de::DeserializeOwned, Deserialize, Serialize};
40pub use sled::Error;
41use sled::{
42 transaction::{ConflictableTransactionError, TransactionError},
43 Batch, Config, Db, IVec, Transactional, Tree,
44};
45use tracing::debug;
46
47use super::OpenStoreError;
48use crate::encode_key::{EncodeKey, ENCODE_SEPARATOR};
49
/// Schema version of this store. It is written under the `store_version`
/// key and compared against the stored value when the database is opened.
const DATABASE_VERSION: u8 = 5;

// Table names passed to `encode_key`. They are forwarded to
// `EncodeKey::encode_secure` when a store cipher is configured.
const DEVICE_TABLE_NAME: &str = "crypto-store-devices";
const IDENTITIES_TABLE_NAME: &str = "crypto-store-identities";
const SESSIONS_TABLE_NAME: &str = "crypto-store-sessions";
const INBOUND_GROUP_TABLE_NAME: &str = "crypto-store-inbound-group-sessions";
const OUTBOUND_GROUP_TABLE_NAME: &str = "crypto-store-outbound-group-sessions";
const SECRET_REQUEST_BY_INFO_TABLE: &str = "crypto-store-secret-request-by-info";
// NOTE(review): the "secret" in this value looks like a copy-paste leftover.
// Presumably it cannot be changed without invalidating keys already hashed
// with it -- confirm before touching.
const TRACKED_USERS_TABLE: &str = "crypto-store-secret-tracked-users";

62
63impl EncodeKey for InboundGroupSession {
64 fn encode(&self) -> Vec<u8> {
65 (self.room_id(), self.sender_key().to_base64(), self.session_id()).encode()
66 }
67
68 fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
69 (self.room_id(), self.sender_key().to_base64(), self.session_id())
70 .encode_secure(table_name, store_cipher)
71 }
72}
73
74impl EncodeKey for OutboundGroupSession {
75 fn encode(&self) -> Vec<u8> {
76 self.room_id().encode()
77 }
78 fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
79 self.room_id().encode_secure(table_name, store_cipher)
80 }
81}
82
83impl EncodeKey for Session {
84 fn encode(&self) -> Vec<u8> {
85 let sender_key = self.sender_key().to_base64();
86 let session_id = self.session_id();
87
88 [sender_key.as_bytes(), &[ENCODE_SEPARATOR], session_id.as_bytes(), &[ENCODE_SEPARATOR]]
89 .concat()
90 }
91
92 fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
93 let sender_key =
94 store_cipher.hash_key(table_name, self.sender_key().to_base64().as_bytes());
95 let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes());
96
97 [sender_key.as_slice(), &[ENCODE_SEPARATOR], session_id.as_slice(), &[ENCODE_SEPARATOR]]
98 .concat()
99 }
100}
101
102impl EncodeKey for SecretInfo {
103 fn encode(&self) -> Vec<u8> {
104 match self {
105 SecretInfo::KeyRequest(k) => k.encode(),
106 SecretInfo::SecretRequest(s) => s.encode(),
107 }
108 }
109 fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
110 match self {
111 SecretInfo::KeyRequest(k) => k.encode_secure(table_name, store_cipher),
112 SecretInfo::SecretRequest(s) => s.encode_secure(table_name, store_cipher),
113 }
114 }
115}
116
117impl EncodeKey for EventEncryptionAlgorithm {
118 fn encode_as_bytes(&self) -> Cow<'_, [u8]> {
119 let s: &str = self.as_ref();
120 s.as_bytes().into()
121 }
122}
123
124impl EncodeKey for SupportedKeyInfo {
125 fn encode(&self) -> Vec<u8> {
126 (self.room_id(), &self.algorithm(), self.session_id()).encode()
127 }
128 fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
129 let room_id = store_cipher.hash_key(table_name, self.room_id().as_bytes());
130 let algorithm = store_cipher.hash_key(table_name, self.algorithm().as_ref().as_bytes());
131 let session_id = store_cipher.hash_key(table_name, self.session_id().as_bytes());
132
133 [
134 room_id.as_slice(),
135 &[ENCODE_SEPARATOR],
136 algorithm.as_slice(),
137 &[ENCODE_SEPARATOR],
138 session_id.as_slice(),
139 &[ENCODE_SEPARATOR],
140 ]
141 .concat()
142 }
143}
144
145impl EncodeKey for ReadOnlyDevice {
146 fn encode(&self) -> Vec<u8> {
147 (self.user_id(), self.device_id()).encode()
148 }
149 fn encode_secure(&self, table_name: &str, store_cipher: &StoreCipher) -> Vec<u8> {
150 (self.user_id(), self.device_id()).encode_secure(table_name, store_cipher)
151 }
152}
153
/// Identifying details of the account that owns this store.
///
/// The store caches a copy whenever an account is saved or loaded and uses
/// it to restore Olm sessions and outbound group sessions from their pickles.
#[derive(Clone, Debug)]
pub struct AccountInfo {
    /// The user ID of the account.
    user_id: Arc<UserId>,
    /// The device ID of the account.
    device_id: Arc<DeviceId>,
    /// The identity keys of the account.
    identity_keys: Arc<IdentityKeys>,
}
160
/// The record stored in the `tracked_users` tree for one tracked user.
#[derive(Debug, Serialize, Deserialize)]
struct TrackedUser {
    /// The tracked user.
    user_id: OwnedUserId,
    /// When set, the user is also added to the key-query cache on load.
    dirty: bool,
}
166
/// A `CryptoStore` implementation backed by a sled database.
#[derive(Clone)]
pub struct SledCryptoStore {
    /// Details of the current account; `None` until an account has been
    /// saved or loaded.
    account_info: Arc<RwLock<Option<AccountInfo>>>,
    /// When present, values are encrypted and keys are hashed with this
    /// cipher; otherwise values are stored as plain JSON.
    store_cipher: Option<Arc<StoreCipher>>,
    /// Path of the database; `None` when the store was opened from an
    /// existing `Db` handle.
    path: Option<PathBuf>,
    /// The sled database handle, used for flushing and for the store
    /// version entry.
    inner: Db,

    // In-memory caches.
    session_cache: SessionStore,
    tracked_users_cache: Arc<DashSet<OwnedUserId>>,
    users_for_key_query_cache: Arc<DashSet<OwnedUserId>>,

    // Account data: the account pickle, recovery key and backup version
    // live in `account`, the cross-signing identity in `private_identity`.
    account: Tree,
    private_identity: Tree,

    // Message hashes and the Olm / Megolm sessions.
    olm_hashes: Tree,
    sessions: Tree,
    inbound_group_sessions: Tree,
    outbound_group_sessions: Tree,

    // Secret requests: already sent, not yet sent, and an index that maps
    // the request info to the request ID.
    outgoing_secret_requests: Tree,
    unsent_secret_requests: Tree,
    secret_requests_by_info: Tree,

    // Devices and user identities.
    devices: Tree,
    identities: Tree,

    // One `TrackedUser` record per tracked user.
    tracked_users: Tree,
}
198
199impl std::fmt::Debug for SledCryptoStore {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 if let Some(path) = &self.path {
202 f.debug_struct("SledCryptoStore").field("path", &path).finish()
203 } else {
204 f.debug_struct("SledCryptoStore").field("path", &"memory store").finish()
205 }
206 }
207}
208
209impl SledCryptoStore {
210 pub fn open_with_passphrase(
213 path: impl AsRef<Path>,
214 passphrase: Option<&str>,
215 ) -> Result<Self, OpenStoreError> {
216 let path = path.as_ref().join("matrix-sdk-crypto");
217 let db =
218 Config::new().temporary(false).path(&path).open().map_err(CryptoStoreError::backend)?;
219
220 let store_cipher = passphrase
221 .map(|p| Self::get_or_create_store_cipher(p, &db))
222 .transpose()?
223 .map(Into::into);
224
225 SledCryptoStore::open_helper(db, Some(path), store_cipher)
226 }
227
228 pub fn open_with_database(db: Db, passphrase: Option<&str>) -> Result<Self, OpenStoreError> {
231 let store_cipher = passphrase
232 .map(|p| Self::get_or_create_store_cipher(p, &db))
233 .transpose()?
234 .map(Into::into);
235
236 SledCryptoStore::open_helper(db, None, store_cipher)
237 }
238
239 fn get_account_info(&self) -> Option<AccountInfo> {
240 self.account_info.read().unwrap().clone()
241 }
242
243 fn serialize_value(&self, event: &impl Serialize) -> Result<Vec<u8>, CryptoStoreError> {
244 if let Some(key) = &self.store_cipher {
245 key.encrypt_value(event).map_err(CryptoStoreError::backend)
246 } else {
247 Ok(serde_json::to_vec(event)?)
248 }
249 }
250
251 fn deserialize_value<T: DeserializeOwned>(&self, event: &[u8]) -> Result<T, CryptoStoreError> {
252 if let Some(key) = &self.store_cipher {
253 key.decrypt_value(event).map_err(CryptoStoreError::backend)
254 } else {
255 Ok(serde_json::from_slice(event)?)
256 }
257 }
258
259 fn encode_key<T: EncodeKey>(&self, table_name: &str, key: T) -> Vec<u8> {
260 if let Some(store_cipher) = &self.store_cipher {
261 key.encode_secure(table_name, store_cipher).to_vec()
262 } else {
263 key.encode()
264 }
265 }
266
267 async fn reset_backup_state(&self) -> Result<()> {
268 let mut pickles: Vec<(IVec, PickledInboundGroupSession)> = self
269 .inbound_group_sessions
270 .iter()
271 .map(|p| {
272 let item = p.map_err(CryptoStoreError::backend)?;
273 Ok((item.0, self.deserialize_value(&item.1)?))
274 })
275 .collect::<Result<_>>()?;
276
277 for (_, pickle) in &mut pickles {
278 pickle.backed_up = false;
279 }
280
281 let ret: Result<(), TransactionError<CryptoStoreError>> =
282 self.inbound_group_sessions.transaction(|inbound_sessions| {
283 for (key, pickle) in &pickles {
284 inbound_sessions.insert(
285 key,
286 self.serialize_value(pickle)
287 .map_err(ConflictableTransactionError::Abort)?,
288 )?;
289 }
290
291 Ok(())
292 });
293
294 ret.map_err(CryptoStoreError::backend)?;
295
296 self.inner.flush_async().await.map_err(CryptoStoreError::backend)?;
297
298 Ok(())
299 }
300
301 fn upgrade(&self) -> Result<()> {
302 let version = self
303 .inner
304 .get("store_version")
305 .map_err(CryptoStoreError::backend)?
306 .map(|v| {
307 let (version_bytes, _) = v.split_at(std::mem::size_of::<u8>());
308 u8::from_be_bytes(version_bytes.try_into().unwrap_or_default())
309 })
310 .unwrap_or(DATABASE_VERSION);
311
312 if version != DATABASE_VERSION {
313 debug!(version, new_version = DATABASE_VERSION, "Upgrading the Sled crypto store");
314 }
315
316 if version <= 3 {
317 return Err(CryptoStoreError::UnsupportedDatabaseVersion(
318 version.into(),
319 DATABASE_VERSION.into(),
320 ));
321 }
322
323 if version <= 4 {
324 self.outgoing_secret_requests.clear().map_err(CryptoStoreError::backend)?;
328 self.unsent_secret_requests.clear().map_err(CryptoStoreError::backend)?;
329 self.secret_requests_by_info.clear().map_err(CryptoStoreError::backend)?;
330 }
331
332 self.inner
333 .insert("store_version", DATABASE_VERSION.to_be_bytes().as_ref())
334 .map_err(CryptoStoreError::backend)?;
335 self.inner.flush().map_err(CryptoStoreError::backend)?;
336
337 Ok(())
338 }
339
340 fn get_or_create_store_cipher(passphrase: &str, database: &Db) -> Result<StoreCipher> {
341 let cipher = if let Some(key) =
342 database.get("store_cipher".encode()).map_err(CryptoStoreError::backend)?
343 {
344 StoreCipher::import(passphrase, &key).map_err(|_| CryptoStoreError::UnpicklingError)?
345 } else {
346 let cipher = StoreCipher::new().map_err(CryptoStoreError::backend)?;
347 #[cfg(not(test))]
348 let export = cipher.export(passphrase);
349 #[cfg(test)]
350 let export = cipher._insecure_export_fast_for_testing(passphrase);
351 database
352 .insert("store_cipher".encode(), export.map_err(CryptoStoreError::backend)?)
353 .map_err(CryptoStoreError::backend)?;
354 cipher
355 };
356
357 Ok(cipher)
358 }
359
360 pub(crate) fn open_helper(
361 db: Db,
362 path: Option<PathBuf>,
363 store_cipher: Option<Arc<StoreCipher>>,
364 ) -> Result<Self, OpenStoreError> {
365 let account = db.open_tree("account")?;
366 let private_identity = db.open_tree("private_identity")?;
367
368 let sessions = db.open_tree("session")?;
369 let inbound_group_sessions = db.open_tree("inbound_group_sessions")?;
370
371 let outbound_group_sessions = db.open_tree("outbound_group_sessions")?;
372
373 let tracked_users = db.open_tree("tracked_users")?;
374 let olm_hashes = db.open_tree("olm_hashes")?;
375
376 let devices = db.open_tree("devices")?;
377 let identities = db.open_tree("identities")?;
378
379 let outgoing_secret_requests = db.open_tree("outgoing_secret_requests")?;
380 let unsent_secret_requests = db.open_tree("unsent_secret_requests")?;
381 let secret_requests_by_info = db.open_tree("secret_requests_by_info")?;
382
383 let session_cache = SessionStore::new();
384
385 let database = Self {
386 account_info: RwLock::new(None).into(),
387 path,
388 inner: db,
389 store_cipher,
390 account,
391 private_identity,
392 sessions,
393 session_cache,
394 tracked_users_cache: DashSet::new().into(),
395 users_for_key_query_cache: DashSet::new().into(),
396 inbound_group_sessions,
397 outbound_group_sessions,
398 outgoing_secret_requests,
399 unsent_secret_requests,
400 secret_requests_by_info,
401 devices,
402 tracked_users,
403 olm_hashes,
404 identities,
405 };
406
407 database.upgrade()?;
408
409 Ok(database)
410 }
411
412 async fn load_tracked_users(&self) -> Result<()> {
413 for value in &self.tracked_users {
414 let (_, user) = value.map_err(CryptoStoreError::backend)?;
415 let user: TrackedUser = self.deserialize_value(&user)?;
416
417 self.tracked_users_cache.insert(user.user_id.to_owned());
418
419 if user.dirty {
420 self.users_for_key_query_cache.insert(user.user_id);
421 }
422 }
423
424 Ok(())
425 }
426
427 async fn load_outbound_group_session(
428 &self,
429 room_id: &RoomId,
430 ) -> Result<Option<OutboundGroupSession>> {
431 let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
432
433 self.outbound_group_sessions
434 .get(self.encode_key(OUTBOUND_GROUP_TABLE_NAME, room_id))
435 .map_err(CryptoStoreError::backend)?
436 .map(|p| self.deserialize_value(&p))
437 .transpose()?
438 .map(|p| {
439 Ok(OutboundGroupSession::from_pickle(
440 account_info.device_id,
441 account_info.identity_keys,
442 p,
443 )?)
444 })
445 .transpose()
446 }
447
448 async fn save_changes(&self, changes: Changes) -> Result<()> {
449 let account_pickle = if let Some(account) = changes.account {
450 let account_info = AccountInfo {
451 user_id: account.user_id.clone(),
452 device_id: account.device_id.clone(),
453 identity_keys: account.identity_keys.clone(),
454 };
455
456 *self.account_info.write().unwrap() = Some(account_info);
457 Some(account.pickle().await)
458 } else {
459 None
460 };
461
462 let private_identity_pickle =
463 if let Some(i) = changes.private_identity { Some(i.pickle().await?) } else { None };
464
465 let recovery_key_pickle = changes.recovery_key;
466
467 let device_changes = changes.devices;
468 let mut session_changes = HashMap::new();
469
470 for session in changes.sessions {
471 let pickle = session.pickle().await;
472 let key = self.encode_key(SESSIONS_TABLE_NAME, &session);
473
474 self.session_cache.add(session).await;
475 session_changes.insert(key, pickle);
476 }
477
478 let mut inbound_session_changes = HashMap::new();
479
480 for session in changes.inbound_group_sessions {
481 let key = self.encode_key(INBOUND_GROUP_TABLE_NAME, &session);
482 let pickle = session.pickle().await;
483
484 inbound_session_changes.insert(key, pickle);
485 }
486
487 let mut outbound_session_changes = HashMap::new();
488
489 for session in changes.outbound_group_sessions {
490 let key = self.encode_key(OUTBOUND_GROUP_TABLE_NAME, &session);
491 let pickle = session.pickle().await;
492
493 outbound_session_changes.insert(key, pickle);
494 }
495
496 let identity_changes = changes.identities;
497 let olm_hashes = changes.message_hashes;
498 let key_requests = changes.key_requests;
499 let backup_version = changes.backup_version;
500
501 let ret: Result<(), TransactionError<CryptoStoreError>> = (
502 &self.account,
503 &self.private_identity,
504 &self.devices,
505 &self.identities,
506 &self.sessions,
507 &self.inbound_group_sessions,
508 &self.outbound_group_sessions,
509 &self.olm_hashes,
510 &self.outgoing_secret_requests,
511 &self.unsent_secret_requests,
512 &self.secret_requests_by_info,
513 )
514 .transaction(
515 |(
516 account,
517 private_identity,
518 devices,
519 identities,
520 sessions,
521 inbound_sessions,
522 outbound_sessions,
523 hashes,
524 outgoing_secret_requests,
525 unsent_secret_requests,
526 secret_requests_by_info,
527 )| {
528 if let Some(a) = &account_pickle {
529 account.insert(
530 "account".encode(),
531 self.serialize_value(a).map_err(ConflictableTransactionError::Abort)?,
532 )?;
533 }
534
535 if let Some(i) = &private_identity_pickle {
536 private_identity.insert(
537 "identity".encode(),
538 self.serialize_value(&i)
539 .map_err(ConflictableTransactionError::Abort)?,
540 )?;
541 }
542
543 if let Some(r) = &recovery_key_pickle {
544 account.insert(
545 "recovery_key_v1".encode(),
546 self.serialize_value(r).map_err(ConflictableTransactionError::Abort)?,
547 )?;
548 }
549
550 if let Some(b) = &backup_version {
551 account.insert(
552 "backup_version_v1".encode(),
553 self.serialize_value(b).map_err(ConflictableTransactionError::Abort)?,
554 )?;
555 }
556
557 for device in device_changes.new.iter().chain(&device_changes.changed) {
558 let key = self.encode_key(DEVICE_TABLE_NAME, device);
559 let device = self
560 .serialize_value(&device)
561 .map_err(ConflictableTransactionError::Abort)?;
562 devices.insert(key, device)?;
563 }
564
565 for device in &device_changes.deleted {
566 let key = self.encode_key(DEVICE_TABLE_NAME, device);
567 devices.remove(key)?;
568 }
569
570 for identity in identity_changes.changed.iter().chain(&identity_changes.new) {
571 identities.insert(
572 self.encode_key(IDENTITIES_TABLE_NAME, identity.user_id()),
573 self.serialize_value(&identity)
574 .map_err(ConflictableTransactionError::Abort)?,
575 )?;
576 }
577
578 for (key, session) in &session_changes {
579 sessions.insert(
580 key.as_slice(),
581 self.serialize_value(&session)
582 .map_err(ConflictableTransactionError::Abort)?,
583 )?;
584 }
585
586 for (key, session) in &inbound_session_changes {
587 inbound_sessions.insert(
588 key.as_slice(),
589 self.serialize_value(&session)
590 .map_err(ConflictableTransactionError::Abort)?,
591 )?;
592 }
593
594 for (key, session) in &outbound_session_changes {
595 outbound_sessions.insert(
596 key.as_slice(),
597 self.serialize_value(&session)
598 .map_err(ConflictableTransactionError::Abort)?,
599 )?;
600 }
601
602 for hash in &olm_hashes {
603 hashes.insert(
604 serde_json::to_vec(hash)
605 .map_err(CryptoStoreError::Serialization)
606 .map_err(ConflictableTransactionError::Abort)?,
607 &[0],
608 )?;
609 }
610
611 for key_request in &key_requests {
612 secret_requests_by_info.insert(
613 self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, &key_request.info),
614 key_request.request_id.encode(),
615 )?;
616
617 let key_request_id = key_request.request_id.encode();
618
619 if key_request.sent_out {
620 unsent_secret_requests.remove(key_request_id.clone())?;
621 outgoing_secret_requests.insert(
622 key_request_id,
623 self.serialize_value(&key_request)
624 .map_err(ConflictableTransactionError::Abort)?,
625 )?;
626 } else {
627 outgoing_secret_requests.remove(key_request_id.clone())?;
628 unsent_secret_requests.insert(
629 key_request_id,
630 self.serialize_value(&key_request)
631 .map_err(ConflictableTransactionError::Abort)?,
632 )?;
633 }
634 }
635
636 Ok(())
637 },
638 );
639
640 ret.map_err(CryptoStoreError::backend)?;
641 self.inner.flush().map_err(CryptoStoreError::backend)?;
642
643 Ok(())
644 }
645
646 async fn get_outgoing_key_request_helper(&self, id: &[u8]) -> Result<Option<GossipRequest>> {
647 let request = self
648 .outgoing_secret_requests
649 .get(id)
650 .map_err(CryptoStoreError::backend)?
651 .map(|r| self.deserialize_value(&r))
652 .transpose()?;
653
654 let request = if request.is_none() {
655 self.unsent_secret_requests
656 .get(id)
657 .map_err(CryptoStoreError::backend)?
658 .map(|r| self.deserialize_value(&r))
659 .transpose()?
660 } else {
661 request
662 };
663
664 Ok(request)
665 }
666
667 pub async fn save_tracked_users(
675 &self,
676 tracked_users: &[(&UserId, bool)],
677 ) -> Result<(), CryptoStoreError> {
678 let users: Vec<TrackedUser> = tracked_users
679 .iter()
680 .map(|(u, d)| TrackedUser { user_id: (*u).into(), dirty: *d })
681 .collect();
682
683 let mut batch = Batch::default();
684
685 for user in users {
686 batch.insert(
687 self.encode_key(TRACKED_USERS_TABLE, user.user_id.as_str()),
688 self.serialize_value(&user)?,
689 );
690 }
691
692 self.tracked_users.apply_batch(batch).map_err(CryptoStoreError::backend)
693 }
694}
695
696#[async_trait]
697impl CryptoStore for SledCryptoStore {
698 async fn load_account(&self) -> Result<Option<ReadOnlyAccount>> {
699 if let Some(pickle) =
700 self.account.get("account".encode()).map_err(CryptoStoreError::backend)?
701 {
702 let pickle = self.deserialize_value(&pickle)?;
703
704 self.load_tracked_users().await?;
705 let account = ReadOnlyAccount::from_pickle(pickle)?;
706
707 let account_info = AccountInfo {
708 user_id: account.user_id.clone(),
709 device_id: account.device_id.clone(),
710 identity_keys: account.identity_keys.clone(),
711 };
712
713 *self.account_info.write().unwrap() = Some(account_info);
714
715 Ok(Some(account))
716 } else {
717 Ok(None)
718 }
719 }
720
721 async fn save_account(&self, account: ReadOnlyAccount) -> Result<()> {
722 self.save_changes(Changes { account: Some(account), ..Default::default() }).await
723 }
724
725 async fn load_identity(&self) -> Result<Option<PrivateCrossSigningIdentity>> {
726 if let Some(i) =
727 self.private_identity.get("identity".encode()).map_err(CryptoStoreError::backend)?
728 {
729 let pickle = self.deserialize_value(&i)?;
730 Ok(Some(
731 PrivateCrossSigningIdentity::from_pickle(pickle)
732 .await
733 .map_err(|_| CryptoStoreError::UnpicklingError)?,
734 ))
735 } else {
736 Ok(None)
737 }
738 }
739
740 async fn save_changes(&self, changes: Changes) -> Result<()> {
741 self.save_changes(changes).await
742 }
743
744 async fn get_sessions(&self, sender_key: &str) -> Result<Option<Arc<Mutex<Vec<Session>>>>> {
745 let account_info = self.get_account_info().ok_or(CryptoStoreError::AccountUnset)?;
746
747 if self.session_cache.get(sender_key).is_none() {
748 let sessions: Result<Vec<Session>> = self
749 .sessions
750 .scan_prefix(self.encode_key(SESSIONS_TABLE_NAME, sender_key))
751 .map(|s| self.deserialize_value(&s.map_err(CryptoStoreError::backend)?.1))
752 .map(|p| {
753 Ok(Session::from_pickle(
754 account_info.user_id.clone(),
755 account_info.device_id.clone(),
756 account_info.identity_keys.clone(),
757 p?,
758 ))
759 })
760 .collect();
761
762 self.session_cache.set_for_sender(sender_key, sessions?);
763 }
764
765 Ok(self.session_cache.get(sender_key))
766 }
767
768 async fn get_inbound_group_session(
769 &self,
770 room_id: &RoomId,
771 sender_key: &str,
772 session_id: &str,
773 ) -> Result<Option<InboundGroupSession>> {
774 let key = self.encode_key(INBOUND_GROUP_TABLE_NAME, (room_id, sender_key, session_id));
775 let pickle = self
776 .inbound_group_sessions
777 .get(&key)
778 .map_err(CryptoStoreError::backend)?
779 .map(|p| self.deserialize_value(&p));
780
781 if let Some(pickle) = pickle {
782 Ok(Some(InboundGroupSession::from_pickle(pickle?)?))
783 } else {
784 Ok(None)
785 }
786 }
787
788 async fn get_inbound_group_sessions(&self) -> Result<Vec<InboundGroupSession>> {
789 let pickles: Result<Vec<PickledInboundGroupSession>> = self
790 .inbound_group_sessions
791 .iter()
792 .map(|p| self.deserialize_value(&p.map_err(CryptoStoreError::backend)?.1))
793 .collect();
794
795 Ok(pickles?.into_iter().filter_map(|p| InboundGroupSession::from_pickle(p).ok()).collect())
796 }
797
798 async fn inbound_group_session_counts(&self) -> Result<RoomKeyCounts> {
799 let pickles: Vec<PickledInboundGroupSession> = self
800 .inbound_group_sessions
801 .iter()
802 .map(|p| {
803 let item = p.map_err(CryptoStoreError::backend)?;
804 self.deserialize_value(&item.1)
805 })
806 .collect::<Result<_>>()?;
807
808 let total = pickles.len();
809 let backed_up = pickles.into_iter().filter(|p| p.backed_up).count();
810
811 Ok(RoomKeyCounts { total, backed_up })
812 }
813
814 async fn inbound_group_sessions_for_backup(
815 &self,
816 limit: usize,
817 ) -> Result<Vec<InboundGroupSession>> {
818 let pickles: Vec<InboundGroupSession> = self
819 .inbound_group_sessions
820 .iter()
821 .map(|p| {
822 let item = p.map_err(CryptoStoreError::backend)?;
823 self.deserialize_value(&item.1)
824 })
825 .filter_map(|p: Result<PickledInboundGroupSession, CryptoStoreError>| match p {
826 Ok(p) => {
827 if !p.backed_up {
828 Some(InboundGroupSession::from_pickle(p).map_err(CryptoStoreError::from))
829 } else {
830 None
831 }
832 }
833
834 Err(p) => Some(Err(p)),
835 })
836 .take(limit)
837 .collect::<Result<_>>()?;
838
839 Ok(pickles)
840 }
841
842 async fn reset_backup_state(&self) -> Result<()> {
843 self.reset_backup_state().await
844 }
845
846 async fn get_outbound_group_sessions(
847 &self,
848 room_id: &RoomId,
849 ) -> Result<Option<OutboundGroupSession>> {
850 self.load_outbound_group_session(room_id).await
851 }
852
853 fn is_user_tracked(&self, user_id: &UserId) -> bool {
854 self.tracked_users_cache.contains(user_id)
855 }
856
857 fn has_users_for_key_query(&self) -> bool {
858 !self.users_for_key_query_cache.is_empty()
859 }
860
861 fn users_for_key_query(&self) -> HashSet<OwnedUserId> {
862 self.users_for_key_query_cache.iter().map(|u| u.clone()).collect()
863 }
864
865 fn tracked_users(&self) -> HashSet<OwnedUserId> {
866 self.tracked_users_cache.to_owned().iter().map(|u| u.clone()).collect()
867 }
868
869 async fn update_tracked_user(&self, user: &UserId, dirty: bool) -> Result<bool> {
870 let already_added = self.tracked_users_cache.insert(user.to_owned());
871
872 if dirty {
873 self.users_for_key_query_cache.insert(user.to_owned());
874 } else {
875 self.users_for_key_query_cache.remove(user);
876 }
877
878 self.save_tracked_users(&[(user, dirty)]).await?;
879
880 Ok(already_added)
881 }
882
883 async fn get_device(
884 &self,
885 user_id: &UserId,
886 device_id: &DeviceId,
887 ) -> Result<Option<ReadOnlyDevice>> {
888 let key = self.encode_key(DEVICE_TABLE_NAME, (user_id, device_id));
889
890 Ok(self
891 .devices
892 .get(key)
893 .map_err(CryptoStoreError::backend)?
894 .map(|d| self.deserialize_value(&d))
895 .transpose()?)
896 }
897
898 async fn get_user_devices(
899 &self,
900 user_id: &UserId,
901 ) -> Result<HashMap<OwnedDeviceId, ReadOnlyDevice>> {
902 let key = self.encode_key(DEVICE_TABLE_NAME, user_id);
903 self.devices
904 .scan_prefix(key)
905 .map(|d| self.deserialize_value(&d.map_err(CryptoStoreError::backend)?.1))
906 .map(|d| {
907 let d: ReadOnlyDevice = d?;
908 Ok((d.device_id().to_owned(), d))
909 })
910 .collect()
911 }
912
913 async fn get_user_identity(&self, user_id: &UserId) -> Result<Option<ReadOnlyUserIdentities>> {
914 let key = self.encode_key(IDENTITIES_TABLE_NAME, user_id);
915
916 Ok(self
917 .identities
918 .get(key)
919 .map_err(CryptoStoreError::backend)?
920 .map(|i| self.deserialize_value(&i))
921 .transpose()?)
922 }
923
924 async fn is_message_known(
925 &self,
926 message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
927 ) -> Result<bool> {
928 Ok(self
929 .olm_hashes
930 .contains_key(serde_json::to_vec(message_hash)?)
931 .map_err(CryptoStoreError::backend)?)
932 }
933
934 async fn get_outgoing_secret_requests(
935 &self,
936 request_id: &TransactionId,
937 ) -> Result<Option<GossipRequest>> {
938 let request_id = request_id.encode();
939
940 self.get_outgoing_key_request_helper(&request_id).await
941 }
942
943 async fn get_secret_request_by_info(
944 &self,
945 key_info: &SecretInfo,
946 ) -> Result<Option<GossipRequest>> {
947 let id = self
948 .secret_requests_by_info
949 .get(self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, key_info))
950 .map_err(CryptoStoreError::backend)?;
951
952 if let Some(id) = id {
953 self.get_outgoing_key_request_helper(&id).await
954 } else {
955 Ok(None)
956 }
957 }
958
959 async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
960 let requests: Result<Vec<GossipRequest>> = self
961 .unsent_secret_requests
962 .iter()
963 .map(|i| self.deserialize_value(&i.map_err(CryptoStoreError::backend)?.1))
964 .collect();
965
966 requests
967 }
968
969 async fn delete_outgoing_secret_requests(&self, request_id: &TransactionId) -> Result<()> {
970 let ret: Result<(), TransactionError<CryptoStoreError>> = (
971 &self.outgoing_secret_requests,
972 &self.unsent_secret_requests,
973 &self.secret_requests_by_info,
974 )
975 .transaction(
976 |(outgoing_key_requests, unsent_key_requests, key_requests_by_info)| {
977 let sent_request: Option<GossipRequest> = outgoing_key_requests
978 .remove(request_id.encode())?
979 .map(|r| self.deserialize_value(&r))
980 .transpose()
981 .map_err(ConflictableTransactionError::Abort)?;
982
983 let unsent_request: Option<GossipRequest> = unsent_key_requests
984 .remove(request_id.encode())?
985 .map(|r| self.deserialize_value(&r))
986 .transpose()
987 .map_err(ConflictableTransactionError::Abort)?;
988
989 if let Some(request) = sent_request {
990 key_requests_by_info
991 .remove(self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, &request.info))?;
992 }
993
994 if let Some(request) = unsent_request {
995 key_requests_by_info
996 .remove(self.encode_key(SECRET_REQUEST_BY_INFO_TABLE, &request.info))?;
997 }
998
999 Ok(())
1000 },
1001 );
1002
1003 ret.map_err(CryptoStoreError::backend)?;
1004 self.inner.flush_async().await.map_err(CryptoStoreError::backend)?;
1005
1006 Ok(())
1007 }
1008
1009 async fn load_backup_keys(&self) -> Result<BackupKeys> {
1010 let key = {
1011 let backup_version = self
1012 .account
1013 .get("backup_version_v1".encode())
1014 .map_err(CryptoStoreError::backend)?
1015 .map(|v| self.deserialize_value(&v))
1016 .transpose()?;
1017
1018 let recovery_key = {
1019 self.account
1020 .get("recovery_key_v1".encode())
1021 .map_err(CryptoStoreError::backend)?
1022 .map(|p| self.deserialize_value(&p))
1023 .transpose()?
1024 };
1025
1026 BackupKeys { backup_version, recovery_key }
1027 };
1028
1029 Ok(key)
1030 }
1031}
1032
/// Runs the shared crypto store integration tests against a store that is
/// only encrypted when a test supplies a passphrase.
#[cfg(test)]
mod tests {
    use matrix_sdk_crypto::cryptostore_integration_tests;
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};

    use super::SledCryptoStore;

    /// Temporary directory shared by all tests of this module.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Open a store in a subdirectory of the shared temporary directory.
    async fn get_store(name: &str, passphrase: Option<&str>) -> SledCryptoStore {
        let store_dir = TMP_DIR.path().join(name);
        let store_dir = store_dir.to_str().unwrap();

        SledCryptoStore::open_with_passphrase(store_dir, passphrase)
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
}
1055
/// Runs the shared crypto store integration tests against a store that is
/// always encrypted.
#[cfg(test)]
mod encrypted_tests {
    use matrix_sdk_crypto::cryptostore_integration_tests;
    use once_cell::sync::Lazy;
    use tempfile::{tempdir, TempDir};

    use super::SledCryptoStore;

    /// Temporary directory shared by all tests of this module.
    static TMP_DIR: Lazy<TempDir> = Lazy::new(|| tempdir().unwrap());

    /// Open an encrypted store in a subdirectory of the shared temporary
    /// directory, using a default passphrase when the test gives none.
    async fn get_store(name: &str, passphrase: Option<&str>) -> SledCryptoStore {
        let store_dir = TMP_DIR.path().join(name);
        let store_dir = store_dir.to_str().unwrap();
        let passphrase = passphrase.unwrap_or("default_test_password");

        SledCryptoStore::open_with_passphrase(store_dir, Some(passphrase))
            .expect("Can't create a passphrase protected store")
    }

    cryptostore_integration_tests!();
}