1use async_trait::async_trait;
2use base64ct::{Base64UrlUnpadded, Encoding};
3use core::fmt::Debug;
4use matrix_sdk_base::{
5 deserialized_responses::RawAnySyncOrStrippedState,
6 media::{MediaRequest, UniqueKey},
7 store::StoreEncryptionError,
8 MinimalRoomMemberEvent, RoomInfo, RoomMemberships, StateChanges, StateStore, StateStoreDataKey,
9 StateStoreDataValue, StoreError,
10};
11use matrix_sdk_store_encryption::StoreCipher;
12use ruma_common::{serde::Raw, EventId, MxcUri, OwnedEventId, OwnedUserId, RoomId, UserId};
13use ruma_events::{
14 presence::PresenceEvent,
15 receipt::{Receipt, ReceiptThread, ReceiptType},
16 AnyGlobalAccountDataEvent, AnyRoomAccountDataEvent, GlobalAccountDataEventType,
17 RoomAccountDataEventType, StateEventType,
18};
19use std::{
20 collections::{BTreeMap, BTreeSet},
21 fs,
22 path::PathBuf,
23};
24use tracing::instrument;
25
26#[async_trait]
27trait MediaStore: Debug + Sync + Send {
28 type Error: Debug + Into<StoreError> + From<serde_json::Error>;
29
30 async fn add_media_content(
31 &self,
32 request: &MediaRequest,
33 content: Vec<u8>,
34 ) -> Result<(), Self::Error>;
35
36 async fn get_media_content(
37 &self,
38 request: &MediaRequest,
39 ) -> Result<Option<Vec<u8>>, Self::Error>;
40
41 async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error>;
42
43 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error>;
44}
45
46pub struct FileCacheMediaStore {
47 cache_dir: PathBuf,
48 store_cipher: StoreCipher,
49}
50
51impl Debug for FileCacheMediaStore {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("FileCacheMediaStore")
54 .field("cache_dir", &self.cache_dir)
55 .finish()
56 }
57}
58
59impl FileCacheMediaStore {
60 pub fn with_store_cipher(cache_dir: PathBuf, store_cipher: StoreCipher) -> FileCacheMediaStore {
61 FileCacheMediaStore {
62 cache_dir,
63 store_cipher,
64 }
65 }
66
67 fn encode_value(&self, value: Vec<u8>) -> Result<Vec<u8>, StoreError> {
68 let encoded = self
69 .store_cipher
70 .encrypt_value_data(value)
71 .map_err(StoreError::backend)?;
72 rmp_serde::to_vec_named(&encoded).map_err(StoreError::backend)
73 }
74
75 fn decode_value(&self, value: &[u8]) -> Result<Vec<u8>, StoreError> {
76 let encrypted = rmp_serde::from_slice(value).map_err(StoreError::backend)?;
77 self.store_cipher
78 .decrypt_value_data(encrypted)
79 .map_err(StoreError::backend)
80 }
81
82 fn encode_key(&self, key: impl AsRef<[u8]>) -> String {
83 Base64UrlUnpadded::encode_string(&self.store_cipher.hash_key("ext_media", key.as_ref()))
84 }
85}
86
87#[async_trait]
88impl MediaStore for FileCacheMediaStore {
89 type Error = StoreError;
90
91 #[instrument(skip_all)]
92 async fn add_media_content(
93 &self,
94 request: &MediaRequest,
95 content: Vec<u8>,
96 ) -> Result<(), Self::Error> {
97 let base_filename = self.encode_key(request.source.unique_key());
98 let data = self
99 .encode_value(content)
100 .map_err(|e| StoreError::Backend(Box::new(e)))?;
101 fs::write(self.cache_dir.join(base_filename), data)
102 .map_err(|e| StoreError::Backend(Box::new(e)))?;
103 Ok(())
104 }
105
106 #[instrument(skip_all)]
107 async fn get_media_content(
108 &self,
109 request: &MediaRequest,
110 ) -> Result<Option<Vec<u8>>, Self::Error> {
111 let base_filename = self.encode_key(request.source.unique_key());
112 fs::read(self.cache_dir.join(base_filename))
113 .ok()
114 .map(|data| self.decode_value(&data))
115 .transpose()
116 }
117
118 #[instrument(skip_all)]
119 async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> {
120 let base_filename = self.encode_key(request.source.unique_key());
121 fs::remove_file(self.cache_dir.join(base_filename))
122 .map_err(|e| StoreError::Backend(Box::new(e)))?;
123 Ok(())
124 }
125
126 #[instrument(skip_all)]
127 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> {
128 let base_filename = self.encode_key(uri);
129 fs::remove_file(self.cache_dir.join(base_filename))
130 .map_err(|e| StoreError::Backend(Box::new(e)))?;
131 Ok(())
132 }
133}
134
135#[derive(Debug)]
136pub enum StoreCacheWrapperError {
137 StoreError(StoreError),
138 EncryptionError(StoreEncryptionError),
139}
140
141impl From<StoreError> for StoreCacheWrapperError {
142 fn from(value: StoreError) -> Self {
143 StoreCacheWrapperError::StoreError(value)
144 }
145}
146
147impl From<StoreEncryptionError> for StoreCacheWrapperError {
148 fn from(value: StoreEncryptionError) -> Self {
149 StoreCacheWrapperError::EncryptionError(value)
150 }
151}
152
153impl From<serde_json::error::Error> for StoreCacheWrapperError {
154 fn from(value: serde_json::error::Error) -> Self {
155 StoreCacheWrapperError::StoreError(StoreError::Json(value))
156 }
157}
158
159impl From<StoreCacheWrapperError> for StoreError {
160 fn from(val: StoreCacheWrapperError) -> Self {
161 match val {
162 StoreCacheWrapperError::StoreError(e) => e,
163 StoreCacheWrapperError::EncryptionError(e) => StoreError::backend(Box::new(e)),
164 }
165 }
166}
167
168impl core::fmt::Display for StoreCacheWrapperError {
169 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
170 match self {
171 StoreCacheWrapperError::StoreError(e) => {
172 write!(f, "StoreCacheWrapperError::StoreError: {:?}", e)
173 }
174 StoreCacheWrapperError::EncryptionError(e) => {
175 write!(f, "StoreCacheWrapperError::EncryptionError: {:?}", e)
176 }
177 }
178 }
179}
180
181impl std::error::Error for StoreCacheWrapperError {}
182
183pub async fn wrap_with_file_cache<T>(
184 state_store: T,
185 cache_path: PathBuf,
186 passphrase: &str,
187) -> Result<MediaStoreWrapper<T, FileCacheMediaStore>, StoreCacheWrapperError>
188where
189 T: StateStore + Sync + Send,
190{
191 let cipher = if let Some(enc_key) = state_store
192 .get_custom_value(b"ext_media_key")
193 .await
194 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?
195 {
196 StoreCipher::import(passphrase, &enc_key)
197 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?
198 } else {
199 let cipher = StoreCipher::new()?;
200 let key = cipher
201 .export(passphrase)
202 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?;
203 state_store
204 .set_custom_value(b"ext_media_key", key)
205 .await
206 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?;
207 cipher
208 };
209
210 fs::create_dir_all(cache_path.as_path())
211 .map_err(|e| StoreCacheWrapperError::StoreError(StoreError::Backend(Box::new(e))))?;
212
213 Ok(MediaStoreWrapper::new(
214 state_store,
215 FileCacheMediaStore::with_store_cipher(cache_path, cipher),
216 ))
217}
218
/// Combines a state store (`inner`) with a separate media store (`media`).
///
/// The struct itself places no trait bounds on its parameters. `#[derive(Debug)]`
/// adds `T: Debug, M: Debug` to its own `Debug` impl, and each impl block declares
/// whatever else it needs, so the bounds need not be repeated everywhere the type
/// is named.
#[derive(Debug)]
pub struct MediaStoreWrapper<T, M> {
    /// Store that receives every non-media operation.
    inner: T,
    /// Store that receives every media operation.
    media: M,
}
228
229impl<T, M> MediaStoreWrapper<T, M>
230where
231 T: Debug,
232 M: Debug,
233{
234 pub fn new(inner: T, media: M) -> MediaStoreWrapper<T, M> {
235 MediaStoreWrapper { inner, media }
236 }
237}
238
239#[async_trait]
240impl<T, M> StateStore for MediaStoreWrapper<T, M>
241where
242 T: StateStore + Sync + Send,
243 M: MediaStore + Sync + Send,
244{
245 type Error = StoreCacheWrapperError;
246
247 async fn get_kv_data(
248 &self,
249 key: StateStoreDataKey<'_>,
250 ) -> Result<Option<StateStoreDataValue>, Self::Error> {
251 Ok(self
252 .inner
253 .get_kv_data(key)
254 .await
255 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
256 }
257
258 async fn set_kv_data(
259 &self,
260 key: StateStoreDataKey<'_>,
261 value: StateStoreDataValue,
262 ) -> Result<(), Self::Error> {
263 Ok(self
264 .inner
265 .set_kv_data(key, value)
266 .await
267 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
268 }
269
270 async fn remove_kv_data(&self, key: StateStoreDataKey<'_>) -> Result<(), Self::Error> {
271 Ok(self
272 .inner
273 .remove_kv_data(key)
274 .await
275 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
276 }
277
278 async fn save_changes(&self, changes: &StateChanges) -> Result<(), Self::Error> {
279 Ok(self
280 .inner
281 .save_changes(changes)
282 .await
283 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
284 }
285
286 async fn get_presence_event(
287 &self,
288 user_id: &UserId,
289 ) -> Result<Option<Raw<PresenceEvent>>, Self::Error> {
290 Ok(self
291 .inner
292 .get_presence_event(user_id)
293 .await
294 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
295 }
296
297 async fn get_presence_events(
298 &self,
299 user_ids: &[OwnedUserId],
300 ) -> Result<Vec<Raw<PresenceEvent>>, Self::Error> {
301 Ok(self
302 .inner
303 .get_presence_events(user_ids)
304 .await
305 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
306 }
307
308 async fn get_state_event(
309 &self,
310 room_id: &RoomId,
311 event_type: StateEventType,
312 state_key: &str,
313 ) -> Result<Option<RawAnySyncOrStrippedState>, Self::Error> {
314 Ok(self
315 .inner
316 .get_state_event(room_id, event_type, state_key)
317 .await
318 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
319 }
320
321 async fn get_state_events(
322 &self,
323 room_id: &RoomId,
324 event_type: StateEventType,
325 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
326 Ok(self
327 .inner
328 .get_state_events(room_id, event_type)
329 .await
330 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
331 }
332
333 async fn get_state_events_for_keys(
334 &self,
335 room_id: &RoomId,
336 event_type: StateEventType,
337 state_keys: &[&str],
338 ) -> Result<Vec<RawAnySyncOrStrippedState>, Self::Error> {
339 Ok(self
340 .inner
341 .get_state_events_for_keys(room_id, event_type, state_keys)
342 .await
343 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
344 }
345
346 async fn get_profile(
347 &self,
348 room_id: &RoomId,
349 user_id: &UserId,
350 ) -> Result<Option<MinimalRoomMemberEvent>, Self::Error> {
351 Ok(self
352 .inner
353 .get_profile(room_id, user_id)
354 .await
355 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
356 }
357
358 async fn get_profiles<'a>(
359 &self,
360 room_id: &RoomId,
361 user_ids: &'a [OwnedUserId],
362 ) -> Result<BTreeMap<&'a UserId, MinimalRoomMemberEvent>, Self::Error> {
363 Ok(self
364 .inner
365 .get_profiles(room_id, user_ids)
366 .await
367 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
368 }
369
370 async fn get_user_ids(
371 &self,
372 room_id: &RoomId,
373 membership: RoomMemberships,
374 ) -> Result<Vec<OwnedUserId>, Self::Error> {
375 Ok(self
376 .inner
377 .get_user_ids(room_id, membership)
378 .await
379 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
380 }
381
382 #[allow(deprecated)]
383 async fn get_invited_user_ids(
384 &self,
385 room_id: &RoomId,
386 ) -> Result<Vec<OwnedUserId>, Self::Error> {
387 Ok(self
388 .inner
389 .get_invited_user_ids(room_id)
390 .await
391 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
392 }
393
394 #[allow(deprecated)]
395 async fn get_joined_user_ids(&self, room_id: &RoomId) -> Result<Vec<OwnedUserId>, Self::Error> {
396 Ok(self
397 .inner
398 .get_joined_user_ids(room_id)
399 .await
400 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
401 }
402
403 async fn get_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error> {
404 Ok(self
405 .inner
406 .get_room_infos()
407 .await
408 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
409 }
410
411 #[allow(deprecated)]
412 async fn get_stripped_room_infos(&self) -> Result<Vec<RoomInfo>, Self::Error> {
413 Ok(self
414 .inner
415 .get_stripped_room_infos()
416 .await
417 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
418 }
419
420 async fn get_users_with_display_name(
421 &self,
422 room_id: &RoomId,
423 display_name: &str,
424 ) -> Result<BTreeSet<OwnedUserId>, Self::Error> {
425 Ok(self
426 .inner
427 .get_users_with_display_name(room_id, display_name)
428 .await
429 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
430 }
431
432 async fn get_users_with_display_names<'a>(
433 &self,
434 room_id: &RoomId,
435 display_names: &'a [String],
436 ) -> Result<BTreeMap<&'a str, BTreeSet<OwnedUserId>>, Self::Error> {
437 Ok(self
438 .inner
439 .get_users_with_display_names(room_id, display_names)
440 .await
441 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
442 }
443
444 async fn get_account_data_event(
445 &self,
446 event_type: GlobalAccountDataEventType,
447 ) -> Result<Option<Raw<AnyGlobalAccountDataEvent>>, Self::Error> {
448 Ok(self
449 .inner
450 .get_account_data_event(event_type)
451 .await
452 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
453 }
454
455 async fn get_room_account_data_event(
456 &self,
457 room_id: &RoomId,
458 event_type: RoomAccountDataEventType,
459 ) -> Result<Option<Raw<AnyRoomAccountDataEvent>>, Self::Error> {
460 Ok(self
461 .inner
462 .get_room_account_data_event(room_id, event_type)
463 .await
464 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
465 }
466
467 async fn get_user_room_receipt_event(
468 &self,
469 room_id: &RoomId,
470 receipt_type: ReceiptType,
471 thread: ReceiptThread,
472 user_id: &UserId,
473 ) -> Result<Option<(OwnedEventId, Receipt)>, Self::Error> {
474 Ok(self
475 .inner
476 .get_user_room_receipt_event(room_id, receipt_type, thread, user_id)
477 .await
478 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
479 }
480
481 async fn get_event_room_receipt_events(
482 &self,
483 room_id: &RoomId,
484 receipt_type: ReceiptType,
485 thread: ReceiptThread,
486 event_id: &EventId,
487 ) -> Result<Vec<(OwnedUserId, Receipt)>, Self::Error> {
488 Ok(self
489 .inner
490 .get_event_room_receipt_events(room_id, receipt_type, thread, event_id)
491 .await
492 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
493 }
494
495 async fn get_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
496 Ok(self
497 .inner
498 .get_custom_value(key)
499 .await
500 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
501 }
502
503 async fn set_custom_value(
504 &self,
505 key: &[u8],
506 value: Vec<u8>,
507 ) -> Result<Option<Vec<u8>>, Self::Error> {
508 Ok(self
509 .inner
510 .set_custom_value(key, value)
511 .await
512 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
513 }
514
515 async fn remove_custom_value(&self, key: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
516 Ok(self
517 .inner
518 .remove_custom_value(key)
519 .await
520 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
521 }
522
523 async fn remove_room(&self, room_id: &RoomId) -> Result<(), Self::Error> {
524 Ok(self
525 .inner
526 .remove_room(room_id)
527 .await
528 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
529 }
530
531 async fn add_media_content(
534 &self,
535 request: &MediaRequest,
536 content: Vec<u8>,
537 ) -> Result<(), Self::Error> {
538 Ok(self
539 .media
540 .add_media_content(request, content)
541 .await
542 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
543 }
544
545 async fn get_media_content(
546 &self,
547 request: &MediaRequest,
548 ) -> Result<Option<Vec<u8>>, Self::Error> {
549 Ok(self
550 .media
551 .get_media_content(request)
552 .await
553 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
554 }
555
556 async fn remove_media_content(&self, request: &MediaRequest) -> Result<(), Self::Error> {
557 Ok(self
558 .media
559 .remove_media_content(request)
560 .await
561 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
562 }
563
564 async fn remove_media_content_for_uri(&self, uri: &MxcUri) -> Result<(), Self::Error> {
565 Ok(self
566 .media
567 .remove_media_content_for_uri(uri)
568 .await
569 .map_err(|e| StoreCacheWrapperError::StoreError(e.into()))?)
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use matrix_sdk_base::media::MediaFormat;
    use matrix_sdk_sqlite::SqliteStateStore;
    use matrix_sdk_test::async_test;
    use ruma_common::OwnedMxcUri;
    use ruma_events::room::MediaSource;
    use uuid::Uuid;

    /// Builds a plain (unencrypted) full-file media request for the MXC id `id`.
    fn fake_mr(id: &str) -> MediaRequest {
        MediaRequest {
            source: MediaSource::Plain(OwnedMxcUri::from(id)),
            format: MediaFormat::File,
        }
    }

    #[async_test]
    async fn it_works() -> Result<()> {
        let cache_dir = tempfile::tempdir()?;
        let cipher = StoreCipher::new()?;
        // Borrow the path rather than calling `into_path()`. `into_path()`
        // detaches the directory from the guard, so it would never be cleaned up.
        let fmc = FileCacheMediaStore::with_store_cipher(cache_dir.path().to_path_buf(), cipher);
        let some_content = "this is some content";
        fmc.add_media_content(&fake_mr("my_id"), some_content.into())
            .await?;
        assert_eq!(
            fmc.get_media_content(&fake_mr("my_id")).await?,
            Some(some_content.into())
        );

        Ok(())
    }

    #[async_test]
    async fn it_works_after_restart() -> Result<()> {
        let cache_dir = tempfile::tempdir()?;
        let passphrase = "this is a secret passphrase";
        let some_content = "this is some content";
        let my_item_id = "my_id";
        let enc_key = {
            let cipher = StoreCipher::new()?;
            let export = cipher.export(passphrase)?;
            let fmc =
                FileCacheMediaStore::with_store_cipher(cache_dir.path().to_path_buf(), cipher);
            fmc.add_media_content(&fake_mr(my_item_id), some_content.into())
                .await?;
            assert_eq!(
                fmc.get_media_content(&fake_mr(my_item_id)).await?,
                Some(some_content.into())
            );
            export
        };

        // A fresh store with the re-imported cipher must decrypt the existing file.
        let cipher = StoreCipher::import(passphrase, &enc_key)?;
        let fmc = FileCacheMediaStore::with_store_cipher(cache_dir.path().to_path_buf(), cipher);
        assert_eq!(
            fmc.get_media_content(&fake_mr(my_item_id)).await?,
            Some(some_content.into())
        );

        Ok(())
    }

    #[async_test]
    async fn test_with_sqlite_store() -> Result<()> {
        let db_path = tempfile::tempdir()?;
        let cache_dir = tempfile::tempdir()?;
        let passphrase = Uuid::new_v4().to_string();
        let some_content = "this is some content";
        let my_item_id = "my_id";
        {
            let db = SqliteStateStore::open(db_path.path(), Some(&passphrase)).await?;
            let outer =
                wrap_with_file_cache(db, cache_dir.path().to_path_buf(), &passphrase).await?;
            outer
                .add_media_content(&fake_mr(my_item_id), some_content.into())
                .await?;
            assert_eq!(
                outer.get_media_content(&fake_mr(my_item_id)).await?,
                Some(some_content.into())
            );
        }

        // Pass a borrowed path. Passing `db_path` by value would move the
        // `TempDir` guard into `open` and drop it there, deleting the database
        // directory while the store is still in use.
        let db = SqliteStateStore::open(db_path.path(), Some(&passphrase)).await?;
        let outer = wrap_with_file_cache(db, cache_dir.path().to_path_buf(), &passphrase).await?;
        outer
            .add_media_content(&fake_mr(my_item_id), some_content.into())
            .await?;
        assert_eq!(
            outer.get_media_content(&fake_mr(my_item_id)).await?,
            Some(some_content.into())
        );

        outer.set_custom_value(b"A", "b".into()).await?;
        assert_eq!(outer.get_custom_value(b"A").await?, Some("b".into()));

        Ok(())
    }
}
679}