matrix_sdk_redis/
redis_crypto_store.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::{BTreeMap, HashMap},
17    sync::{Arc, RwLock},
18};
19
20use async_trait::async_trait;
21use matrix_sdk_base::ruma::{
22    events::secret::request::SecretName, DeviceId, OwnedDeviceId, RoomId,
23    TransactionId, UserId,
24};
25use matrix_sdk_crypto::{
26    olm::{
27        Curve25519PublicKey, InboundGroupSession, OutboundGroupSession,
28        PickledInboundGroupSession, PickledOutboundGroupSession,
29        PickledSession, PrivateCrossSigningIdentity, SenderDataType, Session,
30        StaticAccountData,
31    },
32    store::{
33        BackupDecryptionKey, BackupKeys, Changes, CryptoStore,
34        CryptoStoreError, PendingChanges, Result, RoomKeyCounts, RoomSettings,
35    },
36    types::{
37        events::{
38            room_key_request::SupportedKeyInfo,
39            room_key_withheld::RoomKeyWithheldEvent,
40        },
41        EventEncryptionAlgorithm,
42    },
43    Account, DeviceData, GossipRequest, GossippedSecret, SecretInfo,
44    TrackedUser, UserIdentityData,
45};
46use matrix_sdk_store_encryption::StoreCipher;
47use serde::{Deserialize, Serialize};
48use tracing::{info, warn};
49
50use crate::redis_shim::{RedisClientShim, RedisConnectionShim};
51
52trait RedisKey {
53    fn redis_key(&self) -> String;
54}
55
56impl RedisKey for TransactionId {
57    fn redis_key(&self) -> String {
58        self.to_string()
59    }
60}
61
62impl RedisKey for SecretName {
63    fn redis_key(&self) -> String {
64        self.to_string()
65    }
66}
67
68impl RedisKey for SecretInfo {
69    fn redis_key(&self) -> String {
70        match self {
71            SecretInfo::KeyRequest(k) => k.redis_key(),
72            SecretInfo::SecretRequest(s) => s.redis_key(),
73        }
74    }
75}
76
77impl RedisKey for SupportedKeyInfo {
78    fn redis_key(&self) -> String {
79        (self.room_id(), self.algorithm(), self.session_id()).redis_key()
80    }
81}
82
83impl RedisKey for EventEncryptionAlgorithm {
84    fn redis_key(&self) -> String {
85        self.as_ref().redis_key()
86    }
87}
88
89impl RedisKey for &RoomId {
90    fn redis_key(&self) -> String {
91        self.as_str().redis_key()
92    }
93}
94
95impl RedisKey for &str {
96    fn redis_key(&self) -> String {
97        String::from(*self)
98    }
99}
100
101impl<A, B, C> RedisKey for (A, B, C)
102where
103    A: RedisKey,
104    B: RedisKey,
105    C: RedisKey,
106{
107    fn redis_key(&self) -> String {
108        format!(
109            "{}|{}|{}",
110            self.0.redis_key(),
111            self.1.redis_key(),
112            self.2.redis_key()
113        )
114    }
115}
116
117/// A store that holds its information in a Redis database
118#[derive(Clone)]
119pub struct RedisStore<C>
120where
121    C: RedisClientShim,
122{
123    key_prefix: String,
124    client: C,
125    static_account: Arc<RwLock<Option<StaticAccountData>>>,
126    store_cipher: Option<Arc<StoreCipher>>,
127}
128
129impl<C> std::fmt::Debug for RedisStore<C>
130where
131    C: RedisClientShim,
132{
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        f.debug_struct("RedisStore<C>")
135            .field(
136                "client.get_connection_info().redis",
137                &self.client.get_connection_info().redis,
138            )
139            .field("key_prefix", &self.key_prefix)
140            .field("account_info", &self.static_account)
141            .finish()
142    }
143}
144
145impl<C> RedisStore<C>
146where
147    C: RedisClientShim,
148{
149    #[allow(dead_code)]
150    /// Open the Redis-based cryptostore at the given URL using the given
151    /// passphrase to encrypt private data.
152    pub async fn open_with_passphrase(
153        client: C,
154        passphrase: Option<&str>,
155    ) -> Result<Self> {
156        Self::open(client, passphrase, String::from("matrix-sdk-crypto|")).await
157    }
158
159    /// Open the Redis-based cryptostore at the given URL using the given
160    /// passphrase to encrypt private data and assuming all Redis keys are
161    /// prefixed with the given string.
162    pub async fn open(
163        client: C,
164        passphrase: Option<&str>,
165        key_prefix: String,
166    ) -> Result<Self> {
167        let mut connection = client.get_async_connection().await.unwrap();
168
169        let store_cipher = if let Some(passphrase) = passphrase {
170            Some(
171                Self::get_or_create_store_cipher(
172                    passphrase,
173                    &key_prefix,
174                    &mut connection,
175                )
176                .await?
177                .into(),
178            )
179        } else {
180            None
181        };
182
183        Ok(Self {
184            key_prefix,
185            client,
186            static_account: RwLock::new(None).into(),
187            store_cipher,
188        })
189    }
190
191    fn get_static_account(&self) -> Option<StaticAccountData> {
192        self.static_account.read().unwrap().clone()
193    }
194
195    fn serialize_value(
196        &self,
197        event: &impl Serialize,
198    ) -> Result<Vec<u8>, CryptoStoreError> {
199        if let Some(key) = &self.store_cipher {
200            key.encrypt_value(event)
201                .map_err(|e| CryptoStoreError::Backend(Box::new(e)))
202        } else {
203            Ok(serde_json::to_vec(event)?)
204        }
205    }
206
207    fn deserialize_value<T: for<'b> Deserialize<'b>>(
208        &self,
209        event: &[u8],
210    ) -> Result<T, CryptoStoreError> {
211        if let Some(key) = &self.store_cipher {
212            key.decrypt_value(event)
213                .map_err(|e| CryptoStoreError::Backend(Box::new(e)))
214        } else {
215            Ok(serde_json::from_slice(event)?)
216        }
217    }
218
219    async fn reset_backup_state(&self) -> Result<()> {
220        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
221        let mut connection = self.client.get_async_connection().await?;
222
223        // Read out all the sessions, set them as not backed up
224        let sessions: Vec<(String, String)> =
225            connection.hgetall(&redis_key).await?;
226        let pickles: Vec<(String, PickledInboundGroupSession)> = sessions
227            .into_iter()
228            .map(|(k, s)| {
229                let mut pickle: PickledInboundGroupSession =
230                    self.deserialize_value(s.as_bytes()).unwrap();
231                pickle.backed_up = false;
232                (k, pickle)
233            })
234            .collect();
235
236        // Write them back out in a transaction
237        let mut pipeline = self.client.create_pipe();
238
239        for (k, pickle) in pickles {
240            pipeline.hset(&redis_key, &k, self.serialize_value(&pickle)?);
241        }
242
243        pipeline.query_async(&mut connection).await.unwrap();
244
245        Ok(())
246    }
247
248    async fn get_or_create_store_cipher<Conn>(
249        passphrase: &str,
250        key_prefix: &str,
251        connection: &mut Conn,
252    ) -> Result<StoreCipher>
253    where
254        Conn: RedisConnectionShim,
255    {
256        let key_id = format!("{}{}", key_prefix, "store_cipher");
257        let key_db_entry: Option<Vec<u8>> = connection.get(&key_id).await?;
258        let key = if let Some(key_db_entry) = key_db_entry {
259            StoreCipher::import(passphrase, &key_db_entry)
260                .map_err(|_| CryptoStoreError::UnpicklingError)?
261        } else {
262            let cipher = StoreCipher::new()
263                .map_err(|e| CryptoStoreError::Backend(Box::new(e)))?;
264
265            #[cfg(not(test))]
266            let export = cipher.export(passphrase);
267            #[cfg(test)]
268            let export = cipher._insecure_export_fast_for_testing(passphrase);
269
270            let export =
271                export.map_err(|e| CryptoStoreError::Backend(Box::new(e)))?;
272            connection.set(&key_id, export).await?;
273            cipher
274        };
275
276        Ok(key)
277    }
278
279    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
280        let mut connection = self.client.get_async_connection().await?;
281        let tracked_users: HashMap<String, Vec<u8>> = connection
282            .hgetall(&format!("{}tracked_users", self.key_prefix))
283            .await?;
284
285        let mut users = Vec::new();
286
287        for (_, user) in tracked_users {
288            let user: TrackedUser = self.deserialize_value(&user)?;
289            users.push(user)
290        }
291
292        Ok(users)
293    }
294
295    async fn load_outbound_group_session(
296        &self,
297        room_id: &RoomId,
298    ) -> Result<Option<OutboundGroupSession>> {
299        let account_info = self
300            .get_static_account()
301            .ok_or(CryptoStoreError::AccountUnset)?;
302
303        let mut connection = self.client.get_async_connection().await?;
304
305        let redis_key = format!("{}outbound_session_changes", self.key_prefix);
306        let session_vec: Option<Vec<u8>> =
307            connection.hget(&redis_key, room_id.as_str()).await?;
308
309        if let Some(session_vec) = session_vec {
310            let session: PickledOutboundGroupSession =
311                self.deserialize_value(&session_vec)?;
312
313            let unpickled: OutboundGroupSession =
314                OutboundGroupSession::from_pickle(
315                    account_info.device_id.into(),
316                    account_info.identity_keys,
317                    session,
318                )
319                .map_err(CryptoStoreError::Pickle)?;
320
321            Ok(Some(unpickled))
322        } else {
323            Ok(None)
324        }
325    }
326
327    async fn save_pending_changes(
328        &self,
329        changes: PendingChanges,
330    ) -> Result<()> {
331        let mut connection = self.client.get_async_connection().await?;
332
333        let pickled_account = if let Some(account) = changes.account {
334            *self.static_account.write().unwrap() =
335                Some(account.static_data().clone());
336            Some(account.pickle())
337        } else {
338            None
339        };
340
341        let mut pipeline = self.client.create_pipe();
342        if let Some(pickled_account) = pickled_account {
343            let serialized_account = self.serialize_value(&pickled_account)?;
344            pipeline.set_vec(
345                &format!("{}account", self.key_prefix),
346                serialized_account,
347            );
348        }
349
350        pipeline.query_async(&mut connection).await?;
351
352        Ok(())
353    }
354
355    async fn save_changes(&self, changes: Changes) -> Result<()> {
356        let private_identity_pickle = if let Some(i) = changes.private_identity
357        {
358            Some(i.pickle().await)
359        } else {
360            None
361        };
362
363        let backup_decryption_key_pickle = changes.backup_decryption_key;
364
365        let device_changes = changes.devices;
366        let mut session_changes: HashMap<String, Vec<PickledSession>> =
367            HashMap::new();
368
369        for session in changes.sessions {
370            let pickle = session.pickle().await;
371            let sender_key = session.sender_key().to_base64();
372            session_changes.entry(sender_key).or_default().push(pickle);
373        }
374
375        let mut inbound_session_changes = HashMap::new();
376
377        for session in changes.inbound_group_sessions {
378            let room_id = session.room_id();
379            let session_id = session.session_id();
380            let key = format!("{}|{}", room_id.as_str(), session_id);
381            let pickle = session.pickle().await;
382
383            inbound_session_changes.insert(key, pickle);
384        }
385
386        let mut outbound_session_changes = HashMap::new();
387
388        for session in changes.outbound_group_sessions {
389            let room_id = session.room_id().to_owned();
390            let pickle = session.pickle().await;
391            outbound_session_changes.insert(room_id.clone(), pickle);
392        }
393
394        let identity_changes = changes.identities;
395        let olm_hashes = changes.message_hashes;
396        let key_requests = changes.key_requests;
397        let backup_version = changes.backup_version;
398
399        let mut connection = self.client.get_async_connection().await?;
400
401        // Wrap in a Redis transaction
402        let mut pipeline = self.client.create_pipe();
403
404        if let Some(i) = &private_identity_pickle {
405            let redis_key = format!("{}private_identity", self.key_prefix);
406            pipeline.set_vec(&redis_key, self.serialize_value(&i)?);
407        }
408
409        for (key, sessions) in &session_changes {
410            let redis_key = format!("{}sessions|{}", self.key_prefix, key);
411            pipeline.set_vec(&redis_key, self.serialize_value(sessions)?);
412        }
413
414        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
415        for (key, inbound_group_session) in &inbound_session_changes {
416            pipeline.hset(
417                &redis_key,
418                key,
419                self.serialize_value(&inbound_group_session)?,
420            );
421        }
422
423        let redis_key = format!("{}outbound_session_changes", self.key_prefix);
424        for (key, outbound_group_sessions) in &outbound_session_changes {
425            pipeline.hset(
426                &redis_key,
427                key.as_str(),
428                self.serialize_value(outbound_group_sessions)?,
429            );
430        }
431
432        let redis_key = format!("{}olm_hashes", self.key_prefix);
433        for hash in &olm_hashes {
434            pipeline.sadd(&redis_key, serde_json::to_string(hash)?);
435        }
436
437        let unsent_secret_requests_key =
438            format!("{}unsent_secret_requests", self.key_prefix);
439
440        for key_request in &key_requests {
441            let key_request_id = key_request.request_id.redis_key();
442
443            let secret_requests_by_info_key = format!(
444                "{}secret_requests_by_info|{}",
445                self.key_prefix,
446                key_request.info.redis_key()
447            );
448            pipeline.set(
449                &secret_requests_by_info_key,
450                key_request.request_id.redis_key(),
451            );
452
453            let outgoing_secret_requests_key = format!(
454                "{}outgoing_secret_requests|{}",
455                self.key_prefix, key_request_id
456            );
457            if key_request.sent_out {
458                pipeline.hdel(&unsent_secret_requests_key, &key_request_id);
459                pipeline.set_vec(
460                    &outgoing_secret_requests_key,
461                    self.serialize_value(&key_request)?,
462                );
463            } else {
464                pipeline.del(&outgoing_secret_requests_key);
465                pipeline.hset(
466                    &unsent_secret_requests_key,
467                    &key_request_id,
468                    self.serialize_value(&key_request)?,
469                );
470            }
471        }
472
473        for device in device_changes.new.iter().chain(&device_changes.changed) {
474            let redis_key =
475                format!("{}devices|{}", self.key_prefix, device.user_id());
476
477            pipeline.hset(
478                &redis_key,
479                device.device_id().as_str(),
480                self.serialize_value(device)?,
481            );
482        }
483
484        for device in device_changes.deleted {
485            let redis_key =
486                format!("{}devices|{}", self.key_prefix, device.user_id());
487            pipeline.hdel(&redis_key, device.device_id().as_str());
488        }
489
490        for identity in
491            identity_changes.changed.iter().chain(&identity_changes.new)
492        {
493            let redis_key =
494                format!("{}identities|{}", self.key_prefix, identity.user_id());
495
496            pipeline.set_vec(&redis_key, self.serialize_value(identity)?);
497        }
498
499        if let Some(r) = &backup_decryption_key_pickle {
500            let redis_key = format!("{}recovery_key_v1", self.key_prefix);
501            pipeline.set_vec(&redis_key, self.serialize_value(r)?);
502        }
503
504        if let Some(r) = &backup_version {
505            let redis_key = format!("{}backup_version_v1", self.key_prefix);
506            pipeline.set_vec(&redis_key, self.serialize_value(r)?);
507        }
508
509        for (room_id, data) in changes.withheld_session_info {
510            for (session_id, event) in data {
511                let redis_key = format!(
512                    "{}direct_withheld_info|{}|{}",
513                    self.key_prefix, room_id, session_id
514                );
515                let value = self.serialize_value(&event)?;
516                pipeline.set_vec(&redis_key, value);
517            }
518        }
519
520        for (room_id, settings) in changes.room_settings {
521            let redis_key =
522                format!("{}room_settings|{}", self.key_prefix, room_id);
523            let value = self.serialize_value(&settings)?;
524            pipeline.set_vec(&redis_key, value)
525        }
526
527        let secrets: BTreeMap<SecretName, Vec<GossippedSecret>> =
528            changes.secrets.iter().fold(BTreeMap::new(), |mut m, s| {
529                m.entry(s.secret_name.clone()).or_default().push(s.clone());
530                m
531            });
532        for (secret_name, secrets) in secrets {
533            let redis_key =
534                format!("{}secrets{}", self.key_prefix, secret_name);
535            for secret in secrets {
536                let value = self.serialize_value(&secret)?;
537                pipeline.lpush(&redis_key, value)
538            }
539        }
540
541        pipeline.query_async(&mut connection).await?;
542
543        Ok(())
544    }
545
546    async fn get_outgoing_key_request_helper(
547        &self,
548        request_id: &str,
549    ) -> Result<Option<GossipRequest>> {
550        let mut connection = self.client.get_async_connection().await?;
551        let redis_key = format!(
552            "{}outgoing_secret_requests|{}",
553            self.key_prefix, request_id
554        );
555        let req_vec: Option<Vec<u8>> = connection.get(&redis_key).await?;
556        let request = req_vec
557            .map(|req_vec| self.deserialize_value(&req_vec))
558            .transpose()?;
559        let request = if request.is_none() {
560            let redis_key =
561                format!("{}unsent_secret_requests", self.key_prefix);
562            let req_bytes: Option<Vec<u8>> =
563                connection.hget(&redis_key, request_id).await?;
564            req_bytes
565                .map(|req_bytes| self.deserialize_value(&req_bytes))
566                .transpose()?
567        } else {
568            request
569        };
570
571        Ok(request)
572    }
573
574    /// Save a batch of tracked users.
575    ///
576    /// # Arguments
577    ///
578    /// * `tracked_users` - A list of tuples. The first element of the tuple is
579    /// the user ID, the second element is if the user should be considered to
580    /// be dirty.
581    pub async fn save_tracked_users(
582        &self,
583        tracked_users: &[(&UserId, bool)],
584    ) -> Result<(), CryptoStoreError> {
585        let mut connection = self.client.get_async_connection().await.unwrap();
586
587        let users: Vec<TrackedUser> = tracked_users
588            .iter()
589            .map(|(u, d)| TrackedUser {
590                user_id: (*u).into(),
591                dirty: *d,
592            })
593            .collect();
594
595        let mut pipeline = self.client.create_pipe();
596
597        for user in users {
598            pipeline.hset(
599                &format!("{}tracked_users", self.key_prefix),
600                user.user_id.as_str(),
601                self.serialize_value(&user)?,
602            );
603        }
604
605        pipeline.query_async(&mut connection).await?;
606
607        Ok(())
608    }
609}
610
611#[async_trait]
612impl<C> CryptoStore for RedisStore<C>
613where
614    C: RedisClientShim,
615{
616    type Error = CryptoStoreError;
617
618    async fn load_account(&self) -> Result<Option<Account>> {
619        let mut connection = self.client.get_async_connection().await?;
620        let acct_json: Option<Vec<u8>> = connection
621            .get(&format!("{}account", self.key_prefix))
622            .await?;
623
624        if let Some(pickle) = acct_json {
625            let pickle = self.deserialize_value(&pickle)?;
626
627            let account = Account::from_pickle(pickle)?;
628
629            *self.static_account.write().unwrap() =
630                Some(account.static_data().clone());
631
632            Ok(Some(account))
633        } else {
634            Ok(None)
635        }
636    }
637
638    async fn load_identity(
639        &self,
640    ) -> Result<Option<PrivateCrossSigningIdentity>> {
641        let mut connection = self.client.get_async_connection().await?;
642        let key_prefix: String = format!("{}private_identity", self.key_prefix);
643        let i_string: Option<Vec<u8>> = connection.get(&key_prefix).await?;
644        if let Some(i) = i_string {
645            let pickle = self.deserialize_value(&i)?;
646            Ok(Some(
647                PrivateCrossSigningIdentity::from_pickle(pickle)
648                    .map_err(|_| CryptoStoreError::UnpicklingError)?,
649            ))
650        } else {
651            Ok(None)
652        }
653    }
654
655    async fn save_pending_changes(
656        &self,
657        changes: PendingChanges,
658    ) -> Result<()> {
659        self.save_pending_changes(changes).await
660    }
661
662    async fn save_changes(&self, changes: Changes) -> Result<()> {
663        self.save_changes(changes).await
664    }
665
666    async fn save_inbound_group_sessions(
667        &self,
668        sessions: Vec<InboundGroupSession>,
669        backed_up_to_version: Option<&str>,
670    ) -> matrix_sdk_crypto::store::Result<(), Self::Error> {
671        // Sanity-check that the data in the sessions corresponds to backed_up_version
672        sessions.iter().for_each(|s| {
673            let backed_up = s.backed_up();
674            if backed_up != backed_up_to_version.is_some() {
675                warn!(
676                    backed_up,
677                    backed_up_to_version,
678                    "Session backed-up flag does not correspond to backup version setting",
679                );
680            }
681        });
682
683        // Currently, this store doesn't save the backup version separately, so this
684        // just delegates to save_changes.
685        self.save_changes(Changes {
686            inbound_group_sessions: sessions,
687            ..Changes::default()
688        })
689        .await
690    }
691
692    async fn get_sessions(
693        &self,
694        sender_key: &str,
695    ) -> Result<Option<Vec<Session>>> {
696        let mut connection = self.client.get_async_connection().await.unwrap();
697
698        let key = format!("{}sessions|{}", self.key_prefix, sender_key);
699        let sessions_list_as_vec: Option<Vec<u8>> =
700            connection.get(&key).await?;
701        let sessions_list: Vec<PickledSession> = match sessions_list_as_vec {
702            Some(sessions_list_as_vec) => {
703                self.deserialize_value(&sessions_list_as_vec)?
704            }
705            _ => Vec::new(),
706        };
707
708        let device_keys = self.get_own_device().await?.as_device_keys().clone();
709        let sessions: Vec<Session> = sessions_list
710            .into_iter()
711            .map(|p| {
712                Session::from_pickle(device_keys.clone(), p)
713                    .map_err(|_| CryptoStoreError::AccountUnset)
714            })
715            .collect::<Result<_>>()?;
716
717        if sessions.is_empty() {
718            Ok(None)
719        } else {
720            Ok(Some(sessions))
721        }
722    }
723
724    async fn get_inbound_group_session(
725        &self,
726        room_id: &RoomId,
727        session_id: &str,
728    ) -> Result<Option<InboundGroupSession>> {
729        let key = format!("{room_id}|{session_id}");
730        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
731        let mut connection = self.client.get_async_connection().await?;
732        let pickle_str: Option<String> =
733            connection.hget(&redis_key, &key).await?;
734
735        match pickle_str {
736            Some(pickle_str) => Ok(Some(InboundGroupSession::from_pickle(
737                self.deserialize_value(pickle_str.as_bytes())?,
738            )?)),
739            _ => Ok(None),
740        }
741    }
742
743    async fn get_inbound_group_sessions(
744        &self,
745    ) -> Result<Vec<InboundGroupSession>> {
746        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
747        let mut connection = self.client.get_async_connection().await?;
748        let igss: Vec<String> = connection.hvals(&redis_key).await?;
749
750        let pickles: Result<Vec<PickledInboundGroupSession>> = igss
751            .iter()
752            .map(|p| self.deserialize_value(p.as_bytes()))
753            .collect();
754
755        Ok(pickles?
756            .into_iter()
757            .filter_map(|p| InboundGroupSession::from_pickle(p).ok())
758            .collect())
759    }
760
761    async fn get_inbound_group_sessions_for_device_batch(
762        &self,
763        sender_key: Curve25519PublicKey,
764        sender_data_type: SenderDataType,
765        after_session_id: Option<String>,
766        limit: usize,
767    ) -> Result<Vec<InboundGroupSession>> {
768        // Have we hit the after_session_id yet?
769        let mut hit_session_id = after_session_id.is_none();
770        let pipe_after_session_id = after_session_id.map(|s| format!("|{s}"));
771
772        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
773        let mut connection = self.client.get_async_connection().await?;
774        // TODO: fetches all and filters here
775        let igss: Vec<(String, String)> =
776            connection.hgetall(&redis_key).await?;
777
778        let mut pickles: Vec<Result<_, _>> = vec![];
779
780        for (room_and_session_id, session_string) in igss {
781            let session_pickle: Result<PickledInboundGroupSession> =
782                self.deserialize_value(session_string.as_bytes());
783            match session_pickle {
784                Ok(session_pickle) => {
785                    if hit_session_id {
786                        if session_pickle.sender_key == sender_key
787                            && session_pickle.sender_data.to_type()
788                                == sender_data_type
789                        {
790                            pickles.push(
791                                InboundGroupSession::from_pickle(
792                                    session_pickle,
793                                )
794                                .map_err(CryptoStoreError::from),
795                            );
796                        }
797                    } else if let Some(pipe_after_session_id) =
798                        &pipe_after_session_id
799                    {
800                        if room_and_session_id.ends_with(pipe_after_session_id)
801                        {
802                            hit_session_id = true;
803                        }
804                    }
805                }
806                Err(e) => {
807                    pickles.push(Err(e));
808                }
809            }
810        }
811
812        pickles.into_iter().take(limit).collect()
813    }
814
815    async fn inbound_group_session_counts(
816        &self,
817        _backup_version: Option<&str>,
818    ) -> Result<RoomKeyCounts> {
819        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
820        let mut connection = self.client.get_async_connection().await?;
821        let igss: Vec<String> = connection.hvals(&redis_key).await?;
822
823        let pickles: Result<Vec<PickledInboundGroupSession>> = igss
824            .iter()
825            .map(|p| self.deserialize_value(p.as_bytes()))
826            .collect();
827
828        let pickles = pickles?;
829
830        let total = pickles.len();
831        let backed_up = pickles.into_iter().filter(|p| p.backed_up).count();
832
833        Ok(RoomKeyCounts { total, backed_up })
834    }
835
836    async fn inbound_group_sessions_for_backup(
837        &self,
838        _backup_version: &str,
839        limit: usize,
840    ) -> Result<Vec<InboundGroupSession>> {
841        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
842        let mut connection = self.client.get_async_connection().await?;
843        let igss: Vec<String> = connection.hvals(&redis_key).await?;
844
845        let pickles = igss
846            .iter()
847            .map(|p| self.deserialize_value(p.as_bytes()))
848            .filter_map(
849                |p: Result<PickledInboundGroupSession, CryptoStoreError>| {
850                    match p {
851                        Ok(p) => {
852                            if !p.backed_up {
853                                Some(
854                                    InboundGroupSession::from_pickle(p)
855                                        .map_err(CryptoStoreError::from),
856                                )
857                            } else {
858                                None
859                            }
860                        }
861
862                        Err(p) => Some(Err(p)),
863                    }
864                },
865            )
866            .take(limit)
867            .collect::<Result<_>>()?;
868
869        Ok(pickles)
870    }
871
872    async fn mark_inbound_group_sessions_as_backed_up(
873        &self,
874        _backup_version: &str,
875        session_ids: &[(&RoomId, &str)],
876    ) -> Result<()> {
877        if session_ids.is_empty() {
878            // We are not expecting to be called with an empty list of sessions
879            warn!("No sessions to mark as backed up!");
880            return Ok(());
881        }
882
883        let redis_key = format!("{}inbound_group_sessions", self.key_prefix);
884        let mut connection = self.client.get_async_connection().await?;
885
886        for (room_id, session_id) in session_ids {
887            let key = format!("{room_id}|{session_id}");
888            let session: Option<String> =
889                connection.hget(&redis_key, &key).await?;
890            if let Some(session) = session {
891                let mut session: PickledInboundGroupSession =
892                    self.deserialize_value(session.as_bytes())?;
893                if !session.backed_up {
894                    session.backed_up = true;
895                    connection
896                        .hset(&redis_key, &key, self.serialize_value(&session)?)
897                        .await?;
898                    info!("AJB WROTE!");
899                }
900            }
901        }
902
903        Ok(())
904    }
905
    /// Trait entry point for resetting backup state; delegates to the store's
    /// own `reset_backup_state` method and returns its result unchanged.
    ///
    /// NOTE(review): this relies on an inherent `reset_backup_state` defined
    /// elsewhere on the store type. Method-call resolution prefers inherent
    /// methods over trait methods, so the call is not recursive only while
    /// that inherent method exists — confirm before renaming it.
    async fn reset_backup_state(&self) -> Result<()> {
        self.reset_backup_state().await
    }
909
910    async fn get_outbound_group_session(
911        &self,
912        room_id: &RoomId,
913    ) -> Result<Option<OutboundGroupSession>> {
914        self.load_outbound_group_session(room_id).await
915    }
916
    /// Trait entry point for loading the tracked-user list; delegates to the
    /// store's own `load_tracked_users` method.
    ///
    /// NOTE(review): relies on an inherent method of the same name, which
    /// method-call resolution picks over this trait method. Without it this
    /// call would recurse — confirm before renaming.
    async fn load_tracked_users(&self) -> Result<Vec<TrackedUser>> {
        self.load_tracked_users().await
    }
920
    /// Trait entry point for persisting `(user_id, dirty)` tracked-user pairs;
    /// delegates to the store's own `save_tracked_users` method.
    ///
    /// The inner result is propagated with `?` (converting its error if
    /// needed) and its success value is discarded.
    ///
    /// NOTE(review): relies on an inherent method of the same name, which
    /// method-call resolution picks over this trait method — confirm before
    /// renaming.
    async fn save_tracked_users(
        &self,
        users: &[(&UserId, bool)],
    ) -> Result<()> {
        self.save_tracked_users(users).await?;
        Ok(())
    }
928
929    async fn get_device(
930        &self,
931        user_id: &UserId,
932        device_id: &DeviceId,
933    ) -> Result<Option<DeviceData>> {
934        let mut connection = self.client.get_async_connection().await?;
935        let key = format!("{}devices|{user_id}", self.key_prefix);
936        let dev: Option<Vec<u8>> =
937            connection.hget(&key, device_id.as_str()).await?;
938        Ok(dev.map(|d| self.deserialize_value(&d)).transpose()?)
939    }
940
941    async fn get_user_devices(
942        &self,
943        user_id: &UserId,
944    ) -> Result<HashMap<OwnedDeviceId, DeviceData>> {
945        let mut connection = self.client.get_async_connection().await?;
946        let user_device: HashMap<String, Vec<u8>> = connection
947            .hgetall(&format!("{}devices|{user_id}", self.key_prefix))
948            .await?;
949
950        user_device
951            .into_iter()
952            .map(|(device_id, device_str)| {
953                let d = self.deserialize_value(&device_str)?;
954                Ok((device_id.into(), d))
955            })
956            .collect()
957    }
958
959    async fn get_own_device(&self) -> Result<DeviceData> {
960        let account_info = self
961            .get_static_account()
962            .ok_or(CryptoStoreError::AccountUnset)?;
963
964        Ok(self
965            .get_device(&account_info.user_id, &account_info.device_id)
966            .await?
967            .expect("We should be able to find our own device."))
968    }
969
970    async fn get_user_identity(
971        &self,
972        user_id: &UserId,
973    ) -> Result<Option<UserIdentityData>> {
974        let mut connection = self.client.get_async_connection().await?;
975        let redis_key = format!("{}identities|{user_id}", self.key_prefix);
976        let identity_string: Option<Vec<u8>> =
977            connection.get(&redis_key).await?;
978        let identity = identity_string
979            .map(|s| self.deserialize_value(&s))
980            .transpose()?;
981        Ok(identity)
982    }
983
984    async fn is_message_known(
985        &self,
986        message_hash: &matrix_sdk_crypto::olm::OlmMessageHash,
987    ) -> Result<bool> {
988        let mut connection = self.client.get_async_connection().await?;
989        let redis_key = format!("{}olm_hashes", self.key_prefix);
990        let ret = connection
991            .sismember(&redis_key, &serde_json::to_string(message_hash)?)
992            .await?;
993        Ok(ret)
994    }
995
996    async fn get_outgoing_secret_requests(
997        &self,
998        request_id: &TransactionId,
999    ) -> Result<Option<GossipRequest>> {
1000        self.get_outgoing_key_request_helper(&request_id.redis_key())
1001            .await
1002    }
1003
1004    async fn get_secret_request_by_info(
1005        &self,
1006        key_info: &SecretInfo,
1007    ) -> Result<Option<GossipRequest>> {
1008        let mut connection = self.client.get_async_connection().await.unwrap();
1009        let redis_key = format!(
1010            "{}secret_requests_by_info|{}",
1011            self.key_prefix,
1012            key_info.redis_key()
1013        );
1014        let id: Option<String> = connection.get(&redis_key).await.unwrap();
1015
1016        if let Some(id) = id {
1017            self.get_outgoing_key_request_helper(&id).await
1018        } else {
1019            Ok(None)
1020        }
1021    }
1022
1023    async fn get_unsent_secret_requests(&self) -> Result<Vec<GossipRequest>> {
1024        let mut connection = self.client.get_async_connection().await.unwrap();
1025        let redis_key = format!("{}unsent_secret_requests", self.key_prefix);
1026        let req_map: HashMap<String, Vec<u8>> =
1027            connection.hgetall(&redis_key).await.unwrap();
1028        Ok(req_map
1029            .values()
1030            .map(|req| self.deserialize_value(req).unwrap())
1031            .collect())
1032    }
1033
1034    async fn delete_outgoing_secret_requests(
1035        &self,
1036        request_id: &TransactionId,
1037    ) -> Result<()> {
1038        let mut connection = self.client.get_async_connection().await?;
1039        let okr_req_id_key = format!(
1040            "{}outgoing_secret_requests|{}",
1041            self.key_prefix,
1042            request_id.redis_key()
1043        );
1044        let sent_request: Option<Vec<u8>> =
1045            connection.get(&okr_req_id_key).await?;
1046
1047        // Wrap the deletes in a Redis transaction
1048        // TODO: race: if someone updates sent_request before we delete it, we
1049        // could be deleting the old stuff, when others are using a newer version,
1050        // so we would be in an inconsistent state where the sent_request is deleted,
1051        // but the things it refers to still exists.
1052        let mut pipeline = self.client.create_pipe();
1053        if let Some(sent_request) = sent_request {
1054            pipeline.del(&okr_req_id_key);
1055            let usr_key = format!("{}unsent_secret_requests", self.key_prefix);
1056            pipeline.hdel(&usr_key, &request_id.redis_key());
1057            let sent_request: GossipRequest =
1058                self.deserialize_value(&sent_request)?;
1059            let srbi_info_key = format!(
1060                "{}secret_requests_by_info|{}",
1061                self.key_prefix,
1062                sent_request.info.redis_key()
1063            );
1064            pipeline.del(&srbi_info_key);
1065        }
1066        pipeline.query_async(&mut connection).await?;
1067
1068        Ok(())
1069    }
1070
1071    async fn get_secrets_from_inbox(
1072        &self,
1073        secret_name: &SecretName,
1074    ) -> Result<Vec<GossippedSecret>> {
1075        let redis_key = format!("{}secrets{}", self.key_prefix, secret_name);
1076        let mut connection = self.client.get_async_connection().await?;
1077        let secrets: Option<Vec<Vec<u8>>> =
1078            connection.lrange(&redis_key, 0, -1).await?;
1079
1080        if let Some(secrets) = secrets {
1081            secrets.iter().map(|s| self.deserialize_value(&s)).collect()
1082        } else {
1083            Ok(vec![])
1084        }
1085    }
1086
1087    async fn delete_secrets_from_inbox(
1088        &self,
1089        secret_name: &SecretName,
1090    ) -> Result<()> {
1091        let redis_key = format!("{}secrets{}", self.key_prefix, secret_name);
1092        let mut connection = self.client.get_async_connection().await?;
1093        connection.del(&redis_key).await?;
1094        Ok(())
1095    }
1096
1097    async fn get_withheld_info(
1098        &self,
1099        room_id: &RoomId,
1100        session_id: &str,
1101    ) -> Result<Option<RoomKeyWithheldEvent>> {
1102        let mut connection = self.client.get_async_connection().await?;
1103        let redis_key = format!(
1104            "{}direct_withheld_info|{}|{}",
1105            self.key_prefix,
1106            room_id.redis_key(),
1107            session_id
1108        );
1109
1110        let value: Option<Vec<u8>> = connection.get(&redis_key).await?;
1111        value.map(|v| self.deserialize_value(&v)).transpose()
1112    }
1113
1114    async fn load_backup_keys(&self) -> Result<BackupKeys> {
1115        let mut connection = self.client.get_async_connection().await?;
1116        let redis_key = format!("{}backup_version_v1", self.key_prefix);
1117        let version_v: Option<Vec<u8>> = connection.get(&redis_key).await?;
1118        let version =
1119            version_v.map(|v| self.deserialize_value(&v)).transpose()?;
1120
1121        let redis_key = format!("{}recovery_key_v1", self.key_prefix);
1122        let decryption_key_str: Option<Vec<u8>> =
1123            connection.get(&redis_key).await?;
1124        let decryption_key: Option<BackupDecryptionKey> = decryption_key_str
1125            .map(|s| self.deserialize_value(&s))
1126            .transpose()?;
1127
1128        Ok(BackupKeys {
1129            backup_version: version,
1130            decryption_key,
1131        })
1132    }
1133
1134    async fn get_room_settings(
1135        &self,
1136        room_id: &RoomId,
1137    ) -> Result<Option<RoomSettings>> {
1138        let mut connection = self.client.get_async_connection().await?;
1139        let redis_key = format!("{}room_settings|{}", self.key_prefix, room_id);
1140        let value: Option<Vec<u8>> = connection.get(&redis_key).await?;
1141        value.map(|v| self.deserialize_value(&v)).transpose()
1142    }
1143
1144    async fn get_custom_value(&self, key: &str) -> Result<Option<Vec<u8>>> {
1145        let mut connection = self.client.get_async_connection().await?;
1146        let redis_key = format!("{}custom_value|{}", self.key_prefix, key);
1147        let value: Option<Vec<u8>> = connection.get(&redis_key).await?;
1148        value.map(|v| self.deserialize_value(&v)).transpose()
1149    }
1150
1151    async fn set_custom_value(&self, key: &str, value: Vec<u8>) -> Result<()> {
1152        let mut connection = self.client.get_async_connection().await?;
1153        let redis_key = format!("{}custom_value|{}", self.key_prefix, key);
1154
1155        Ok(connection
1156            .set(&redis_key, self.serialize_value(&value)?)
1157            .await?)
1158    }
1159
1160    async fn remove_custom_value(&self, key: &str) -> Result<()> {
1161        let mut connection = self.client.get_async_connection().await?;
1162        let redis_key = format!("{}custom_value|{}", self.key_prefix, key);
1163        connection.del(&redis_key).await?;
1164        Ok(())
1165    }
1166
    /// Leased cross-process locks are not implemented for this store.
    ///
    /// TODO: implement this; until then callers must not rely on it.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()`; it never returns `Ok` or `Err`.
    async fn try_take_leased_lock(
        &self,
        _lease_duration_ms: u32,
        _key: &str,
        _holder: &str,
    ) -> Result<bool> {
        todo!("try_take_leased_lock is not implemented!")
    }
1176
1177    /// Load the next-batch token for a to-device query, if any.
1178    async fn next_batch_token(&self) -> Result<Option<String>, Self::Error> {
1179        let mut connection = self.client.get_async_connection().await?;
1180        let redis_key = format!("{}next_batch_token", self.key_prefix);
1181        if let Some(token) = connection.get::<Vec<u8>>(&redis_key).await? {
1182            let maybe_token: Option<String> = self.deserialize_value(&token)?;
1183            Ok(maybe_token)
1184        } else {
1185            Ok(None)
1186        }
1187    }
1188}
1189
#[cfg(test)]
mod test_fake_redis {
    use matrix_sdk_crypto::cryptostore_integration_tests;
    use once_cell::sync::Lazy;

    use super::RedisStore;
    use crate::fake_redis::FakeRedisClient;

    /// One in-memory fake Redis, created lazily and cloned into every store.
    static REDIS_CLIENT: Lazy<FakeRedisClient> =
        Lazy::new(FakeRedisClient::new);

    /// Build a store for the integration-test macro.
    ///
    /// Every store's keys are namespaced under
    /// `matrix-sdk-crypto|test|{name}|`. `_clear_data` is ignored.
    /// NOTE(review): test isolation presumably relies on distinct `name`s;
    /// confirm against the macro.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        _clear_data: bool,
    ) -> RedisStore<FakeRedisClient> {
        let prefix = format!("matrix-sdk-crypto|test|{}|", name);
        let client = REDIS_CLIENT.clone();

        RedisStore::open(client, passphrase, prefix)
            .await
            .expect("Can't create a Redis store")
    }

    cryptostore_integration_tests!();
}
1215
1216// To run tests against a real Redis, use:
1217// ```sh
1218// cargo test redis --features=real-redis-tests
1219// ```
#[cfg(feature = "real-redis-tests")]
#[cfg(test)]
mod test_real_redis {
    use matrix_sdk_crypto::cryptostore_integration_tests;
    use once_cell::sync::Lazy;
    use redis::Commands;

    use super::RedisStore;
    static REDIS_URL: &str = "redis://127.0.0.1/";

    /// Template client for the local Redis server.
    ///
    /// The first access wipes every key left over from earlier runs. All test
    /// keys start with `matrix-sdk-crypto|test|`. Each test then clones this
    /// client, so the stores do not share a client instance.
    static REDIS_CLIENT: Lazy<redis::Client> = Lazy::new(|| {
        let client = redis::Client::open(REDIS_URL).unwrap();
        let mut connection = client.get_connection().unwrap();
        let keys: Vec<String> =
            connection.keys("matrix-sdk-crypto|test|*").unwrap();
        // One multi-key DEL instead of one round trip per key. DEL with no
        // keys is a Redis error, so skip it when nothing matched.
        if !keys.is_empty() {
            let _: () = connection.del(keys).unwrap();
        }
        client
    });

    /// Build a store for the integration-test macro, namespaced under
    /// `matrix-sdk-crypto|test|{name}|`. `_clear_data` is ignored.
    async fn get_store(
        name: &str,
        passphrase: Option<&str>,
        _clear_data: bool,
    ) -> RedisStore<redis::Client> {
        let key_prefix = format!("matrix-sdk-crypto|test|{name}|");
        RedisStore::open(REDIS_CLIENT.clone(), passphrase, key_prefix)
            .await
            .expect("Can't create a Redis store")
    }

    cryptostore_integration_tests!();
}