matrix_sdk/sliding_sync/
cache.rs1use matrix_sdk_base::{StateStore, StoreError};
8use matrix_sdk_common::timer;
9use ruma::UserId;
10use tracing::{trace, warn};
11
12use super::{FrozenSlidingSyncList, SlidingSync, SlidingSyncPositionMarkers};
13#[cfg(doc)]
14use crate::sliding_sync::SlidingSyncList;
15use crate::{Client, Result, sliding_sync::SlidingSyncListCachePolicy};
16
17pub(super) fn format_storage_key_prefix(id: &str, user_id: &UserId) -> String {
20 format!("sliding_sync_store::{id}::{user_id}")
21}
22
/// Key (derived from the prefix `storage_key`) under which the sliding sync
/// instance's own state is stored; shaped `<storage_key>::instance`.
#[cfg(feature = "e2e-encryption")]
fn format_storage_key_for_sliding_sync(storage_key: &str) -> String {
    [storage_key, "instance"].join("::")
}
29
/// Key (derived from the prefix `storage_key`) under which the list named
/// `list_name` is cached; shaped `<storage_key>::list::<list_name>`.
fn format_storage_key_for_sliding_sync_list(storage_key: &str, list_name: &str) -> String {
    [storage_key, "list", list_name].join("::")
}
35
36async fn remove_cached_list(
38 storage: &dyn StateStore<Error = StoreError>,
39 storage_key: &str,
40 list_name: &str,
41) {
42 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
43 let _ = storage.remove_custom_value(storage_key_for_list.as_bytes()).await;
44}
45
46pub(super) async fn store_sliding_sync_state(
48 sliding_sync: &SlidingSync,
49 _position: &SlidingSyncPositionMarkers,
50) -> Result<()> {
51 let storage_key = &sliding_sync.inner.storage_key;
52
53 trace!(storage_key, "Saving a `SlidingSync` to the state store");
54 let storage = sliding_sync.inner.client.state_store();
55
56 #[cfg(feature = "e2e-encryption")]
57 {
58 let position = _position;
59 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
60
61 if let Some(olm_machine) = &*sliding_sync.inner.client.olm_machine().await {
66 let pos_blob = serde_json::to_vec(&FrozenSlidingSyncPos { pos: position.pos.clone() })?;
67 olm_machine.store().set_custom_value(&instance_storage_key, pos_blob).await?;
68 }
69 }
70
71 let frozen_lists = {
73 sliding_sync
74 .inner
75 .lists
76 .read()
77 .await
78 .iter()
79 .filter(|(_, list)| matches!(list.cache_policy(), SlidingSyncListCachePolicy::Enabled))
80 .map(|(list_name, list)| {
81 Ok((
82 format_storage_key_for_sliding_sync_list(storage_key, list_name),
83 serde_json::to_vec(&FrozenSlidingSyncList::freeze(list))?,
84 ))
85 })
86 .collect::<Result<Vec<_>, crate::Error>>()?
87 };
88
89 for (storage_key_for_list, frozen_list) in frozen_lists {
90 trace!(storage_key_for_list, "Saving a `SlidingSyncList`");
91
92 storage.set_custom_value(storage_key_for_list.as_bytes(), frozen_list).await?;
93 }
94
95 Ok(())
96}
97
98pub(super) async fn restore_sliding_sync_list(
102 storage: &dyn StateStore<Error = StoreError>,
103 storage_key: &str,
104 list_name: &str,
105) -> Result<Option<FrozenSlidingSyncList>> {
106 let _timer = timer!(format!("loading list from DB {list_name}"));
107
108 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
109
110 match storage
111 .get_custom_value(storage_key_for_list.as_bytes())
112 .await?
113 .map(|custom_value| serde_json::from_slice::<FrozenSlidingSyncList>(&custom_value))
114 {
115 Some(Ok(frozen_list)) => {
116 trace!(list_name, "successfully read the list from cache");
118 return Ok(Some(frozen_list));
119 }
120
121 Some(Err(_)) => {
122 warn!(
128 list_name,
129 "failed to deserialize the list from the cache, it is obsolete; removing the cache entry!"
130 );
131 remove_cached_list(storage, storage_key, list_name).await;
133 }
134
135 None => {
136 trace!(list_name, "failed to find the list in the cache");
139 }
140 }
141
142 Ok(None)
143}
144
145#[derive(Default)]
147pub(super) struct RestoredFields {
148 pub to_device_token: Option<String>,
149 pub pos: Option<String>,
150}
151
/// JSON shape of the sliding sync `pos` marker as written to the crypto store.
///
/// The field name `pos` is part of the persisted format. When `pos` is `None`
/// the field is omitted on serialization.
#[cfg(feature = "e2e-encryption")]
#[derive(serde::Deserialize, serde::Serialize)]
struct FrozenSlidingSyncPos {
    /// Copy of `SlidingSyncPositionMarkers::pos` at the time it was saved.
    #[serde(skip_serializing_if = "Option::is_none")]
    pos: Option<String>,
}
160
161pub(super) async fn restore_sliding_sync_state(
166 _client: &Client,
167 _storage_key: &str,
168) -> Result<Option<RestoredFields>> {
169 #[cfg(not(feature = "e2e-encryption"))]
170 return Ok(Some(Default::default()));
171
172 #[cfg(feature = "e2e-encryption")]
173 {
174 let _timer = timer!(format!("loading sliding sync {_storage_key} state from DB"));
175
176 let mut restored_fields = RestoredFields::default();
177
178 if let Some(olm_machine) = &*_client.olm_machine().await {
179 match olm_machine.store().next_batch_token().await? {
180 Some(token) => {
181 restored_fields.to_device_token = Some(token);
182 }
183 None => trace!("Couldn't read the previous to-device token from the crypto store"),
184 }
185
186 let instance_storage_key = format_storage_key_for_sliding_sync(_storage_key);
187
188 if let Ok(Some(blob)) =
189 olm_machine.store().get_custom_value(&instance_storage_key).await
190 && let Ok(frozen_pos) = serde_json::from_slice::<FrozenSlidingSyncPos>(&blob)
191 {
192 trace!("Successfully read the `Sliding Sync` pos from the crypto store cache");
193 restored_fields.pos = frozen_pos.pos;
194 }
195 }
196
197 Ok(Some(restored_fields))
198 }
199}
200
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use matrix_sdk_test::async_test;

    // These helpers are only used by the `e2e-encryption`-gated test below.
    // Importing them unconditionally triggers `unused_imports` when the
    // feature is disabled.
    #[cfg(feature = "e2e-encryption")]
    use super::{
        format_storage_key_for_sliding_sync, restore_sliding_sync_state, store_sliding_sync_state,
    };
    use super::{
        super::SlidingSyncList, format_storage_key_for_sliding_sync_list, format_storage_key_prefix,
    };
    use crate::{Result, test_utils::logged_in_client};

    // NOTE(review): kept from the original; confirm this lint allowance is
    // still needed.
    #[allow(clippy::await_holding_lock)]
    #[async_test]
    async fn test_sliding_sync_can_be_stored_and_restored() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let store = client.state_store();

        let sync_id = "test-sync-id";
        let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());

        // Nothing is cached for either list yet.
        assert!(
            store
                .get_custom_value(
                    format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
                )
                .await?
                .is_none()
        );

        assert!(
            store
                .get_custom_value(
                    format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
                )
                .await?
                .is_none()
        );

        let storage_key = {
            // `list_foo` is cached, `list_bar` is not.
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo"))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            // Give both lists a value worth persisting.
            {
                let lists = sliding_sync.inner.lists.write().await;

                let list_foo = lists.get("list_foo").unwrap();
                list_foo.set_maximum_number_of_rooms(Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                list_bar.set_maximum_number_of_rooms(Some(1337));
            }

            let position_guard = sliding_sync.inner.position.lock().await;
            assert!(sliding_sync.cache_to_storage(&position_guard).await.is_ok());

            storage_key
        };

        // Only the cached list has been written to the store.
        assert!(
            store
                .get_custom_value(
                    format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
                )
                .await?
                .is_some()
        );

        assert!(
            store
                .get_custom_value(
                    format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
                )
                .await?
                .is_none()
        );

        // Rebuild the sliding sync. Capture `list_foo`'s stream from
        // `once_built` to observe the value it emits once restored.
        let max_number_of_room_stream = Arc::new(RwLock::new(None));
        let cloned_stream = max_number_of_room_stream.clone();
        let sliding_sync = client
            .sliding_sync(sync_id)?
            .add_cached_list(SlidingSyncList::builder("list_foo").once_built(move |list| {
                // At `once_built` time the cached value has not been applied yet.
                assert_eq!(list.maximum_number_of_rooms(), None);

                let mut stream = cloned_stream.write().unwrap();
                *stream = Some(list.maximum_number_of_rooms_stream());
                list
            }))
            .await?
            .add_list(SlidingSyncList::builder("list_bar"))
            .build()
            .await?;

        // The cached list got its value back; the uncached one did not.
        {
            let lists = sliding_sync.inner.lists.read().await;

            let list_foo = lists.get("list_foo").unwrap();
            assert_eq!(list_foo.maximum_number_of_rooms(), Some(42));

            let list_bar = lists.get("list_bar").unwrap();
            assert_eq!(list_bar.maximum_number_of_rooms(), None);
        }

        // The stream captured in `once_built` observed the restored value.
        {
            let mut stream =
                max_number_of_room_stream.write().unwrap().take().expect("stream must be set");
            let initial_max_number_of_rooms =
                stream.next().await.expect("stream must have emitted something");
            assert_eq!(initial_max_number_of_rooms, Some(42));
        }

        Ok(())
    }

    #[cfg(feature = "e2e-encryption")]
    #[async_test]
    async fn test_sliding_sync_high_level_cache_and_restore() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let sync_id = "test-sync-id";
        let storage_key_prefix = format_storage_key_prefix(sync_id, client.user_id().unwrap());
        let full_storage_key = format_storage_key_for_sliding_sync(&storage_key_prefix);
        let sliding_sync = client.sliding_sync(sync_id)?.build().await?;

        // No to-device token exists in the crypto store initially.
        if let Some(olm_machine) = &*client.base_client().olm_machine().await {
            let store = olm_machine.store();
            assert!(store.next_batch_token().await?.is_none());
        }

        // The instance key is not present in the state store.
        let state_store = client.state_store();
        assert!(state_store.get_custom_value(full_storage_key.as_bytes()).await?.is_none());

        // Set a `pos` and persist it.
        let pos = "pos".to_owned();
        {
            let mut position_guard = sliding_sync.inner.position.lock().await;
            position_guard.pos = Some(pos.clone());

            store_sliding_sync_state(&sliding_sync, &position_guard).await?;
        }

        drop(sliding_sync);

        // The `pos` is restored from the cache.
        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix)
            .await?
            .expect("must have restored sliding sync fields");

        assert_eq!(restored_fields.pos.unwrap(), pos);

        // Storing the sliding sync state did not write a to-device token.
        {
            let olm_machine = client.base_client().olm_machine().await;
            let olm_machine = olm_machine.as_ref().unwrap();
            assert!(olm_machine.store().next_batch_token().await?.is_none());
        }

        Ok(())
    }
}