matrix_sdk/sliding_sync/
cache.rs1use matrix_sdk_base::{StateStore, StoreError};
8use matrix_sdk_common::timer;
9use ruma::UserId;
10use tracing::{trace, warn};
11
12use super::{FrozenSlidingSyncList, SlidingSync, SlidingSyncPositionMarkers};
13#[cfg(feature = "e2e-encryption")]
14use crate::sliding_sync::FrozenSlidingSyncPos;
15#[cfg(doc)]
16use crate::sliding_sync::SlidingSyncList;
17use crate::{sliding_sync::SlidingSyncListCachePolicy, Client, Result};
18
19pub(super) fn format_storage_key_prefix(id: &str, user_id: &UserId) -> String {
22 format!("sliding_sync_store::{id}::{user_id}")
23}
24
25#[cfg(feature = "e2e-encryption")]
/// Derive the key under which the `SlidingSync` instance itself (its `pos`
/// marker) is cached, from the per-instance `storage_key` prefix.
///
/// The result has the shape `<storage_key>::instance`.
///
/// Be careful: the returned string is used as a storage key; changing its
/// format requires migrating existing data!
fn format_storage_key_for_sliding_sync(storage_key: &str) -> String {
    [storage_key, "instance"].join("::")
}
31
/// Derive the key under which the list named `list_name` is cached, from the
/// per-instance `storage_key` prefix.
///
/// The result has the shape `<storage_key>::list::<list_name>`.
///
/// Be careful: the returned string is used as a storage key; changing its
/// format requires migrating existing data!
fn format_storage_key_for_sliding_sync_list(storage_key: &str, list_name: &str) -> String {
    [storage_key, "list", list_name].join("::")
}
37
38async fn remove_cached_list(
40 storage: &dyn StateStore<Error = StoreError>,
41 storage_key: &str,
42 list_name: &str,
43) {
44 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
45 let _ = storage.remove_custom_value(storage_key_for_list.as_bytes()).await;
46}
47
48pub(super) async fn store_sliding_sync_state(
50 sliding_sync: &SlidingSync,
51 _position: &SlidingSyncPositionMarkers,
52) -> Result<()> {
53 let storage_key = &sliding_sync.inner.storage_key;
54
55 trace!(storage_key, "Saving a `SlidingSync` to the state store");
56 let storage = sliding_sync.inner.client.state_store();
57
58 #[cfg(feature = "e2e-encryption")]
59 {
60 let position = _position;
61 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
62
63 if let Some(olm_machine) = &*sliding_sync.inner.client.olm_machine().await {
68 let pos_blob = serde_json::to_vec(&FrozenSlidingSyncPos { pos: position.pos.clone() })?;
69 olm_machine.store().set_custom_value(&instance_storage_key, pos_blob).await?;
70 }
71 }
72
73 let frozen_lists = {
75 sliding_sync
76 .inner
77 .lists
78 .read()
79 .await
80 .iter()
81 .filter(|(_, list)| matches!(list.cache_policy(), SlidingSyncListCachePolicy::Enabled))
82 .map(|(list_name, list)| {
83 Ok((
84 format_storage_key_for_sliding_sync_list(storage_key, list_name),
85 serde_json::to_vec(&FrozenSlidingSyncList::freeze(list))?,
86 ))
87 })
88 .collect::<Result<Vec<_>, crate::Error>>()?
89 };
90
91 for (storage_key_for_list, frozen_list) in frozen_lists {
92 trace!(storage_key_for_list, "Saving a `SlidingSyncList`");
93
94 storage.set_custom_value(storage_key_for_list.as_bytes(), frozen_list).await?;
95 }
96
97 Ok(())
98}
99
100pub(super) async fn restore_sliding_sync_list(
104 storage: &dyn StateStore<Error = StoreError>,
105 storage_key: &str,
106 list_name: &str,
107) -> Result<Option<FrozenSlidingSyncList>> {
108 let _timer = timer!(format!("loading list from DB {list_name}"));
109
110 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
111
112 match storage
113 .get_custom_value(storage_key_for_list.as_bytes())
114 .await?
115 .map(|custom_value| serde_json::from_slice::<FrozenSlidingSyncList>(&custom_value))
116 {
117 Some(Ok(frozen_list)) => {
118 trace!(list_name, "successfully read the list from cache");
120 return Ok(Some(frozen_list));
121 }
122
123 Some(Err(_)) => {
124 warn!(
130 list_name,
131 "failed to deserialize the list from the cache, it is obsolete; removing the cache entry!"
132 );
133 remove_cached_list(storage, storage_key, list_name).await;
135 }
136
137 None => {
138 trace!(list_name, "failed to find the list in the cache");
141 }
142 }
143
144 Ok(None)
145}
146
147#[derive(Default)]
149pub(super) struct RestoredFields {
150 pub to_device_token: Option<String>,
151 pub pos: Option<String>,
152}
153
154pub(super) async fn restore_sliding_sync_state(
159 _client: &Client,
160 storage_key: &str,
161) -> Result<Option<RestoredFields>> {
162 let _timer = timer!(format!("loading sliding sync {storage_key} state from DB"));
163
164 #[cfg_attr(not(feature = "e2e-encryption"), allow(unused_mut))]
165 let mut restored_fields = RestoredFields::default();
166
167 #[cfg(feature = "e2e-encryption")]
168 if let Some(olm_machine) = &*_client.olm_machine().await {
169 match olm_machine.store().next_batch_token().await? {
170 Some(token) => {
171 restored_fields.to_device_token = Some(token);
172 }
173 None => trace!("No `SlidingSync` in the crypto-store cache"),
174 }
175 }
176
177 #[cfg(feature = "e2e-encryption")]
179 if let Some(olm_machine) = &*_client.olm_machine().await {
180 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
181
182 if let Ok(Some(blob)) = olm_machine.store().get_custom_value(&instance_storage_key).await {
183 if let Ok(frozen_pos) = serde_json::from_slice::<FrozenSlidingSyncPos>(&blob) {
184 trace!("Successfully read the `Sliding Sync` pos from the crypto store cache");
185 restored_fields.pos = frozen_pos.pos;
186 }
187 }
188 }
189
190 Ok(Some(restored_fields))
191}
192
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use matrix_sdk_test::async_test;

    #[cfg(feature = "e2e-encryption")]
    use super::format_storage_key_for_sliding_sync;
    use super::{
        super::SlidingSyncList, format_storage_key_for_sliding_sync_list,
        format_storage_key_prefix, restore_sliding_sync_state, store_sliding_sync_state,
    };
    use crate::{test_utils::logged_in_client, Result};

    /// Checks two things:
    /// - only the list added with `add_cached_list` (`list_foo`) is written to
    ///   the state store;
    /// - a freshly built `SlidingSync` with the same id restores that list's
    ///   `maximum_number_of_rooms`, while the non-cached list starts empty.
    // NOTE(review): the lint allowance presumably relates to the
    // `std::sync::RwLock` used to smuggle the stream out of `once_built`;
    // TODO confirm it is still needed.
    #[allow(clippy::await_holding_lock)]
    #[async_test]
    async fn test_sliding_sync_can_be_stored_and_restored() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let store = client.state_store();

        let sync_id = "test-sync-id";
        let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());

        // Before anything is stored, neither list has a cache entry.
        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        // Build a `SlidingSync` with one cached list (`list_foo`) and one
        // non-cached list (`list_bar`). Give each a distinct room count, then
        // write the state to storage. The `SlidingSync` is dropped at the end
        // of this scope.
        let storage_key = {
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo"))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            {
                let lists = sliding_sync.inner.lists.write().await;

                let list_foo = lists.get("list_foo").unwrap();
                list_foo.set_maximum_number_of_rooms(Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                list_bar.set_maximum_number_of_rooms(Some(1337));
            }

            let position_guard = sliding_sync.inner.position.lock().await;
            assert!(sliding_sync.cache_to_storage(&position_guard).await.is_ok());

            storage_key
        };

        // Only the cached list has been written to the state store.
        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_some());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        // Rebuild a `SlidingSync` with the same id. The `once_built` callback
        // stashes the `maximum_number_of_rooms` stream so its first emitted
        // value can be checked at the end of the test.
        let max_number_of_room_stream = Arc::new(RwLock::new(None));
        let cloned_stream = max_number_of_room_stream.clone();
        let sliding_sync = client
            .sliding_sync(sync_id)?
            .add_cached_list(SlidingSyncList::builder("list_foo").once_built(move |list| {
                // The cached value is not yet visible when `once_built` runs.
                assert_eq!(list.maximum_number_of_rooms(), None);

                let mut stream = cloned_stream.write().unwrap();
                *stream = Some(list.maximum_number_of_rooms_stream());
                list
            }))
            .await?
            .add_list(SlidingSyncList::builder("list_bar"))
            .build()
            .await?;

        {
            let lists = sliding_sync.inner.lists.read().await;

            let list_foo = lists.get("list_foo").unwrap();
            // The cached list got its value back from storage.
            assert_eq!(list_foo.maximum_number_of_rooms(), Some(42));

            let list_bar = lists.get("list_bar").unwrap();
            // The non-cached list was not restored.
            assert_eq!(list_bar.maximum_number_of_rooms(), None);
        }

        {
            // The restored value was also emitted on the stream captured above.
            let mut stream =
                max_number_of_room_stream.write().unwrap().take().expect("stream must be set");
            let initial_max_number_of_rooms =
                stream.next().await.expect("stream must have emitted something");
            assert_eq!(initial_max_number_of_rooms, Some(42));
        }

        Ok(())
    }

    /// Checks that the `pos` marker written by `store_sliding_sync_state` is
    /// read back by `restore_sliding_sync_state`. It also checks that storing
    /// the state leaves the crypto store's `next_batch_token` unset.
    #[cfg(feature = "e2e-encryption")]
    #[async_test]
    async fn test_sliding_sync_high_level_cache_and_restore() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let sync_id = "test-sync-id";
        let storage_key_prefix = format_storage_key_prefix(sync_id, client.user_id().unwrap());
        let full_storage_key = format_storage_key_for_sliding_sync(&storage_key_prefix);
        let sliding_sync = client.sliding_sync(sync_id)?.build().await?;

        // Initially there is no to-device token in the crypto store.
        if let Some(olm_machine) = &*client.base_client().olm_machine().await {
            let store = olm_machine.store();
            assert!(store.next_batch_token().await?.is_none());
        }

        // NOTE(review): `store_sliding_sync_state` writes the `pos` to the
        // crypto store, not the state store, so this only checks that the
        // state store has no entry under that key.
        let state_store = client.state_store();
        assert!(state_store.get_custom_value(full_storage_key.as_bytes()).await?.is_none());

        // Set a `pos` and persist the sliding sync state.
        let pos = "pos".to_owned();
        {
            let mut position_guard = sliding_sync.inner.position.lock().await;
            position_guard.pos = Some(pos.clone());

            store_sliding_sync_state(&sliding_sync, &position_guard).await?;
        }

        drop(sliding_sync);

        // The `pos` is read back from the cache.
        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix)
            .await?
            .expect("must have restored sliding sync fields");

        assert_eq!(restored_fields.pos.unwrap(), pos);

        // Storing the sliding sync state did not write a to-device token.
        {
            let olm_machine = client.base_client().olm_machine().await;
            let olm_machine = olm_machine.as_ref().unwrap();
            assert!(olm_machine.store().next_batch_token().await?.is_none());
        }

        Ok(())
    }
}