matrix_sdk_store_media_cache_wrapper/
lib.rs

1use async_trait::async_trait;
2use base64ct::{Base64UrlUnpadded, Encoding};
3use core::fmt::Debug;
4use matrix_sdk_base::{
5    deserialized_responses::RawAnySyncOrStrippedState,
6    media::{MediaRequest, UniqueKey},
7    store::StoreEncryptionError,
8    MinimalRoomMemberEvent, RoomInfo, RoomMemberships, StateChanges, StateStore, StateStoreDataKey,
9    StateStoreDataValue, StoreError,
10};
11use matrix_sdk_store_encryption::StoreCipher;
12use ruma_common::{serde::Raw, EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, UserId};
13use ruma_events::{
14    presence::PresenceEvent,
15    receipt::{Receipt, ReceiptThread, ReceiptType},
16    AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, GlobalAccountDataEventType,
17    RoomAccountDataEventType, StateEventType,
18};
19use std::{
20    collections::{BTreeMap, BTreeSet},
21    fs,
22    path::PathBuf,
23};
24use tracing::instrument;
25
26#[async_trait]
27trait MediaStore: Debug + Sync + Send {
28    type Error: Debug + Into<StoreError> + From<serde_json::Error>;
29
30    async fn add_media_content(
31        &self,
32        request: &MediaRequest,
33        content: Vec<u8>,
34    ) -> Result<(), Self::Error>;
35
36    async fn get_media_content(
37        &self,
38        request: &MediaRequest,
39    ) -> Result<Option<Vec<u8>>, Self::Error>;
40
41    async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error>;
42
43    async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error>;
44}
45
46pub struct FileCacheMediaStore {
47    cache_dir: PathBuf,
48    store_cipher: StoreCipher,
49}
50
51impl Debug for FileCacheMediaStore {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("FileCacheMediaStore")
54            .field("cache_dir", &self.cache_dir)
55            .finish()
56    }
57}
58
59impl FileCacheMediaStore {
60    pub fn with_store_cipher(cache_dir: PathBuf, store_cipher: StoreCipher) -> FileCacheMediaStore {
61        FileCacheMediaStore {
62            cache_dir,
63            store_cipher,
64        }
65    }
66
67    fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>, StoreError> {
68        let encoded = self
69            .store_cipher
70            .encrypt_value_data(value)
71            .map_err(StoreError::backend)?;
72        rmp_serde::to_vec_named(&encoded).map_err(StoreError::backend)
73    }
74
75    fn decode_value(&self, value: &[u8]) -> Result<Vec<u8>, StoreError> {
76        let encrypted = rmp_serde::from_slice(value).map_err(StoreError::backend)?;
77        self.store_cipher
78            .decrypt_value_data(encrypted)
79            .map_err(StoreError::backend)
80    }
81
82    fn encode_key(&self, key: impl AsRef<[u8]>) -> String {
83        Base64UrlUnpadded::encode_string(&self.store_cipher.hash_key("ext_media", key.as_ref()))
84    }
85}
86
87#[async_trait]
88impl MediaStore for FileCacheMediaStore {
89    type Error = StoreError;
90
91    #[instrument(skip_all)]
92    async fn add_media_content(
93        &self,
94        request: &MediaRequest,
95        content: Vec<u8>,
96    ) -> Result<(), Self::Error> {
97        let base_filename = self.encode_key(request.source.unique_key());
98        let data = self
99            .encode_value(content)
100            .map_err(|e| StoreError::Backend(Box::new(e)))?;
101        fs::write(self.cache_dir.join(base_filename), data)
102            .map_err(|e| StoreError::Backend(Box::new(e)))?;
103        Ok(())
104    }
105
106    #[instrument(skip_all)]
107    async fn get_media_content(
108        &self,
109        request: &MediaRequest,
110    ) -> Result<Option<Vec<u8>>, Self::Error> {
111        let base_filename = self.encode_key(request.source.unique_key());
112        fs::read(self.cache_dir.join(base_filename))
113            .ok()
114            .map(|data| self.decode_value(&data))
115            .transpose()
116    }
117
118    #[instrument(skip_all)]
119    async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> {
120        let base_filename = self.encode_key(request.source.unique_key());
121        fs::remove_file(self.cache_dir.join(base_filename))
122            .map_err(|e| StoreError::Backend(Box::new(e)))?;
123        Ok(())
124    }
125
126    #[instrument(skip_all)]
127    async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> {
128        let base_filename = self.encode_key(uri);
129        fs::remove_file(self.cache_dir.join(base_filename))
130            .map_err(|e| StoreError::Backend(Box::new(e)))?;
131        Ok(())
132    }
133}
134
135#[derive(Debug)]
136pub enum StoreCacheWrapperError {
137    StoreError(StoreError),
138    EncryptionError(StoreEncryptionError),
139}
140
141impl From<StoreError> for StoreCacheWrapperError {
142    fn from(value: StoreError) -> Self {
143        StoreCacheWrapperError::StoreError(value)
144    }
145}
146
147impl From<StoreEncryptionError> for StoreCacheWrapperError {
148    fn from(value: StoreEncryptionError) -> Self {
149        StoreCacheWrapperError::EncryptionError(value)
150    }
151}
152
153impl From<serde_json::error::Error> for StoreCacheWrapperError {
154    fn from(value: serde_json::error::Error) -> Self {
155        StoreCacheWrapperError::StoreError(StoreError::Json(value))
156    }
157}
158
159impl From<StoreCacheWrapperError> for StoreError {
160    fn from(val: StoreCacheWrapperError) -> Self {
161        match val {
162            StoreCacheWrapperError::StoreError(e) => e,
163            StoreCacheWrapperError::EncryptionError(e) => StoreError::backend(Box::new(e)),
164        }
165    }
166}
167
168impl core::fmt::Display for StoreCacheWrapperError {
169    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
170        match self {
171            StoreCacheWrapperError::StoreError(e) => {
172                write!(f, "StoreCacheWrapperError::StoreError: {:?}", e)
173            }
174            StoreCacheWrapperError::EncryptionError(e) => {
175                write!(f, "StoreCacheWrapperError::EncryptionError: {:?}", e)
176            }
177        }
178    }
179}
180
181impl std::error::Error for StoreCacheWrapperError {}
182
183pub async fn wrap_with_file_cache<T>(
184    state_store: T,
185    cache_path: PathBuf,
186    passphrase: &str,
187) -> Result<MediaStoreWrapper<T, FileCacheMediaStore>, StoreCacheWrapperError>
188where
189    T: StateStore + Sync + Send,
190{
191    let cipher = if let Some(enc_key) = state_store
192        .get_custom_value(b"ext_media_key")
193        .await
194        .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?
195    {
196        StoreCipher::import(passphrase, &enc_key)
197            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?
198    } else {
199        let cipher = StoreCipher::new()?;
200        let key = cipher
201            .export(passphrase)
202            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?;
203        state_store
204            .set_custom_value(b"ext_media_key", key)
205            .await
206            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?;
207        cipher
208    };
209
210    fs::create_dir_all(cache_path.as_path())
211        .map_err(|e| StoreCacheWrapperError::StoreError(StoreError::Backend(Box::new(e))))?;
212
213    Ok(MediaStoreWrapper::new(
214        state_store,
215        FileCacheMediaStore::with_store_cipher(cache_path, cipher),
216    ))
217}
218
/// Combines a state store (`inner`) with a separate media backend (`media`).
///
/// Its `StateStore` implementation sends media operations to `media` and
/// everything else to `inner`. No trait bounds are placed on the struct itself;
/// `#[derive(Debug)]` bounds only its own `Debug` impl, and the `StateStore` impl
/// carries the bounds it needs.
#[derive(Debug)]
pub struct MediaStoreWrapper<T, M> {
    /// The wrapped state store, which receives all non-media calls.
    inner: T,
    /// The media backend, which receives all media calls.
    media: M,
}

impl<T, M> MediaStoreWrapper<T, M> {
    /// Pairs `inner` (for state) with `media` (for media content).
    pub fn new(inner: T, media: M) -> Self {
        Self { inner, media }
    }
}
238
239#[async_trait]
240impl<T, M> StateStore for MediaStoreWrapper<T, M>
241where
242    T: StateStore + Sync + Send,
243    M: MediaStore + Sync + Send,
244{
245    type Error = StoreCacheWrapperError;
246
247    async fn get_kv_data(
248        &self,
249        key: StateStoreDataKey<'_>,
250    ) -> Result<Option<StateStoreDataValue>, Self::Error> {
251        Ok(self
252            .inner
253            .get_kv_data(key)
254            .await
255            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
256    }
257
258    async fn set_kv_data(
259        &self,
260        key: StateStoreDataKey<'_>,
261        value: StateStoreDataValue,
262    ) -> Result<(), Self::Error> {
263        Ok(self
264            .inner
265            .set_kv_data(key, value)
266            .await
267            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
268    }
269
270    async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
271        Ok(self
272            .inner
273            .remove_kv_data(key)
274            .await
275            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
276    }
277
278    async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
279        Ok(self
280            .inner
281            .save_changes(changes)
282            .await
283            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
284    }
285
286    async fn get_presence_event(
287        &self,
288        user_id: &UserId,
289    ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
290        Ok(self
291            .inner
292            .get_presence_event(user_id)
293            .await
294            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
295    }
296
297    async fn get_presence_events(
298        &self,
299        user_ids: &[OwnedUserId],
300    ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
301        Ok(self
302            .inner
303            .get_presence_events(user_ids)
304            .await
305            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
306    }
307
308    async fn get_state_event(
309        &self,
310        room_id: &RoomId,
311        event_type: StateEventType,
312        state_key: &str,
313    ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
314        Ok(self
315            .inner
316            .get_state_event(room_id, event_type, state_key)
317            .await
318            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
319    }
320
321    async fn get_state_events(
322        &self,
323        room_id: &RoomId,
324        event_type: StateEventType,
325    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
326        Ok(self
327            .inner
328            .get_state_events(room_id, event_type)
329            .await
330            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
331    }
332
333    async fn get_state_events_for_keys(
334        &self,
335        room_id: &RoomId,
336        event_type: StateEventType,
337        state_keys: &[&str],
338    ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
339        Ok(self
340            .inner
341            .get_state_events_for_keys(room_id, event_type, state_keys)
342            .await
343            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
344    }
345
346    async fn get_profile(
347        &self,
348        room_id: &RoomId,
349        user_id: &UserId,
350    ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
351        Ok(self
352            .inner
353            .get_profile(room_id, user_id)
354            .await
355            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
356    }
357
358    async fn get_profiles<'a>(
359        &self,
360        room_id: &RoomId,
361        user_ids: &'a [OwnedUserId],
362    ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
363        Ok(self
364            .inner
365            .get_profiles(room_id, user_ids)
366            .await
367            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
368    }
369
370    async fn get_user_ids(
371        &self,
372        room_id: &RoomId,
373        membership: RoomMemberships,
374    ) -> Result<Vec<OwnedUserId>, Self::Error> {
375        Ok(self
376            .inner
377            .get_user_ids(room_id, membership)
378            .await
379            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
380    }
381
382    #[allow(deprecated)]
383    async fn get_invited_user_ids(
384        &self,
385        room_id: &RoomId,
386    ) -> Result<Vec<OwnedUserId>, Self::Error> {
387        Ok(self
388            .inner
389            .get_invited_user_ids(room_id)
390            .await
391            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
392    }
393
394    #[allow(deprecated)]
395    async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>, Self::Error> {
396        Ok(self
397            .inner
398            .get_joined_user_ids(room_id)
399            .await
400            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
401    }
402
403    async fn get_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error> {
404        Ok(self
405            .inner
406            .get_room_infos()
407            .await
408            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
409    }
410
411    #[allow(deprecated)]
412    async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error> {
413        Ok(self
414            .inner
415            .get_stripped_room_infos()
416            .await
417            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
418    }
419
420    async fn get_users_with_display_name(
421        &self,
422        room_id: &RoomId,
423        display_name: &str,
424    ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
425        Ok(self
426            .inner
427            .get_users_with_display_name(room_id, display_name)
428            .await
429            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
430    }
431
432    async fn get_users_with_display_names<'a>(
433        &self,
434        room_id: &RoomId,
435        display_names: &'a [String],
436    ) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error> {
437        Ok(self
438            .inner
439            .get_users_with_display_names(room_id, display_names)
440            .await
441            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
442    }
443
444    async fn get_account_data_event(
445        &self,
446        event_type: GlobalAccountDataEventType,
447    ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
448        Ok(self
449            .inner
450            .get_account_data_event(event_type)
451            .await
452            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
453    }
454
455    async fn get_room_account_data_event(
456        &self,
457        room_id: &RoomId,
458        event_type: RoomAccountDataEventType,
459    ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
460        Ok(self
461            .inner
462            .get_room_account_data_event(room_id, event_type)
463            .await
464            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
465    }
466
467    async fn get_user_room_receipt_event(
468        &self,
469        room_id: &RoomId,
470        receipt_type: ReceiptType,
471        thread: ReceiptThread,
472        user_id: &UserId,
473    ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
474        Ok(self
475            .inner
476            .get_user_room_receipt_event(room_id, receipt_type, thread, user_id)
477            .await
478            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
479    }
480
481    async fn get_event_room_receipt_events(
482        &self,
483        room_id: &RoomId,
484        receipt_type: ReceiptType,
485        thread: ReceiptThread,
486        event_id: &EventId,
487    ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
488        Ok(self
489            .inner
490            .get_event_room_receipt_events(room_id, receipt_type, thread, event_id)
491            .await
492            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
493    }
494
495    async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
496        Ok(self
497            .inner
498            .get_custom_value(key)
499            .await
500            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
501    }
502
503    async fn set_custom_value(
504        &self,
505        key: &[u8],
506        value: Vec<u8>,
507    ) -> Result<Option<Vec<u8>>, Self::Error> {
508        Ok(self
509            .inner
510            .set_custom_value(key, value)
511            .await
512            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
513    }
514
515    async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
516        Ok(self
517            .inner
518            .remove_custom_value(key)
519            .await
520            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
521    }
522
523    async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
524        Ok(self
525            .inner
526            .remove_room(room_id)
527            .await
528            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
529    }
530
531    // All the media stuff!
532
533    async fn add_media_content(
534        &self,
535        request: &MediaRequest,
536        content: Vec<u8>,
537    ) -> Result<(), Self::Error> {
538        Ok(self
539            .media
540            .add_media_content(request, content)
541            .await
542            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
543    }
544
545    async fn get_media_content(
546        &self,
547        request: &MediaRequest,
548    ) -> Result<Option<Vec<u8>>, Self::Error> {
549        Ok(self
550            .media
551            .get_media_content(request)
552            .await
553            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
554    }
555
556    async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> {
557        Ok(self
558            .media
559            .remove_media_content(request)
560            .await
561            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
562    }
563
564    async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> {
565        Ok(self
566            .media
567            .remove_media_content_for_uri(uri)
568            .await
569            .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use matrix_sdk_base::media::MediaFormat;
    use matrix_sdk_sqlite::SqliteStateStore;
    use matrix_sdk_test::async_test;
    use ruma_common::OwnedMxcUri;
    use ruma_events::room::MediaSource;
    use uuid::Uuid;

    /// Builds a `MediaSource::Plain` file request for the MXC id `id`.
    fn fake_mr(id: &str) -> MediaRequest {
        MediaRequest {
            source: MediaSource::Plain(OwnedMxcUri::from(id)),
            format: MediaFormat::File,
        }
    }

    /// Content written to the cache can be read back from the same instance.
    #[async_test]
    async fn it_works() -> Result<()> {
        let cache_dir = tempfile::tempdir()?;
        let cipher = StoreCipher::new()?;
        // Borrow the path rather than calling `into_path()`. That call would detach
        // the directory from the `TempDir` guard and leave it on disk after the test.
        let fmc = FileCacheMediaStore::with_store_cipher(cache_dir.path().to_path_buf(), cipher);
        let some_content = "this is some content";
        fmc.add_media_content(&fake_mr("my_id"), some_content.into())
            .await?;
        assert_eq!(
            fmc.get_media_content(&fake_mr("my_id")).await?,
            Some(some_content.into())
        );

        Ok(())
    }

    /// A second store built from the re-imported cipher reads what the first wrote.
    #[async_test]
    async fn it_works_after_restart() -> Result<()> {
        let cache_dir = tempfile::tempdir()?;
        let passphrase = "this is a secret passphrase";
        let some_content = "this is some content";
        let my_item_id = "my_id";
        let enc_key = {
            // first media cache
            let cipher = StoreCipher::new()?;
            let export = cipher.export(passphrase)?;
            let fmc =
                FileCacheMediaStore::with_store_cipher(cache_dir.path().to_path_buf(), cipher);
            fmc.add_media_content(&fake_mr(my_item_id), some_content.into())
                .await?;
            assert_eq!(
                fmc.get_media_content(&fake_mr(my_item_id)).await?,
                Some(some_content.into())
            );
            export
        };

        // second media cache
        let cipher = StoreCipher::import(passphrase, &enc_key)?;
        let fmc = FileCacheMediaStore::with_store_cipher(cache_dir.path().to_path_buf(), cipher);
        assert_eq!(
            fmc.get_media_content(&fake_mr(my_item_id)).await?,
            Some(some_content.into())
        );

        Ok(())
    }

    /// The full wrapper over SQLite keeps cached media readable across a reopen.
    #[async_test]
    async fn test_with_sqlite_store() -> Result<()> {
        let db_path = tempfile::tempdir()?;
        let cache_dir = tempfile::tempdir()?;
        let passphrase = Uuid::new_v4().to_string();
        let some_content = "this is some content";
        let my_item_id = "my_id";
        {
            // Scoped so the first store instance is dropped before reopening.
            let db = SqliteStateStore::open(db_path.path(), Some(&passphrase)).await?;
            let outer =
                wrap_with_file_cache(db, cache_dir.path().to_path_buf(), &passphrase).await?;
            // first media cache
            outer
                .add_media_content(&fake_mr(my_item_id), some_content.into())
                .await?;
            assert_eq!(
                outer.get_media_content(&fake_mr(my_item_id)).await?,
                Some(some_content.into())
            );
        }

        // Second media cache, reopened from the same database and cache directory.
        // Pass the path by reference: moving `db_path` into `open` would drop the
        // `TempDir` guard and delete the database directory while it is still in use.
        let db = SqliteStateStore::open(db_path.path(), Some(&passphrase)).await?;
        let outer = wrap_with_file_cache(db, cache_dir.path().to_path_buf(), &passphrase).await?;
        // Content written by the first instance must still be readable. This shows the
        // media cipher was persisted and re-imported rather than regenerated.
        assert_eq!(
            outer.get_media_content(&fake_mr(my_item_id)).await?,
            Some(some_content.into())
        );
        // Writing through the reopened store still works as well.
        outer
            .add_media_content(&fake_mr(my_item_id), some_content.into())
            .await?;
        assert_eq!(
            outer.get_media_content(&fake_mr(my_item_id)).await?,
            Some(some_content.into())
        );

        // and try out all the functions.
        outer.set_custom_value(b"A", "b".into()).await?;
        assert_eq!(outer.get_custom_value(b"A").await?, Some("b".into()));

        Ok(())
    }
}