matrix_sdk/sliding_sync/
cache.rs1use matrix_sdk_base::{StateStore, StoreError};
8use matrix_sdk_common::timer;
9use ruma::UserId;
10use tracing::{trace, warn};
11
12use super::{FrozenSlidingSyncList, SlidingSync, SlidingSyncPositionMarkers};
13#[cfg(feature = "e2e-encryption")]
14use crate::sliding_sync::FrozenSlidingSyncPos;
15use crate::{sliding_sync::SlidingSyncListCachePolicy, Client, Result};
16
17pub(super) fn format_storage_key_prefix(id: &str, user_id: &UserId) -> String {
20 format!("sliding_sync_store::{id}::{user_id}")
21}
22
23#[cfg(feature = "e2e-encryption")]
/// Derive, from a prefix built by `format_storage_key_prefix`, the key under
/// which the `pos` marker of a `SlidingSync` instance is kept in the crypto
/// store.
///
/// This is a persistent storage key: changing it orphans existing entries.
fn format_storage_key_for_sliding_sync(storage_key: &str) -> String {
    [storage_key, "::instance"].concat()
}
29
/// Derive the state-store key of the frozen list named `list_name`, from a
/// prefix built by `format_storage_key_prefix`.
///
/// This is a persistent storage key: changing it orphans existing entries.
fn format_storage_key_for_sliding_sync_list(storage_key: &str, list_name: &str) -> String {
    [storage_key, "::list::", list_name].concat()
}
35
36async fn invalidate_cached_list(
39 storage: &dyn StateStore<Error = StoreError>,
40 storage_key: &str,
41 list_name: &str,
42) {
43 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
44 let _ = storage.remove_custom_value(storage_key_for_list.as_bytes()).await;
45}
46
47pub(super) async fn store_sliding_sync_state(
49 sliding_sync: &SlidingSync,
50 _position: &SlidingSyncPositionMarkers,
51) -> Result<()> {
52 let storage_key = &sliding_sync.inner.storage_key;
53
54 trace!(storage_key, "Saving a `SlidingSync` to the state store");
55 let storage = sliding_sync.inner.client.state_store();
56
57 #[cfg(feature = "e2e-encryption")]
58 {
59 let position = _position;
60 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
61
62 if let Some(olm_machine) = &*sliding_sync.inner.client.olm_machine().await {
67 let pos_blob = serde_json::to_vec(&FrozenSlidingSyncPos { pos: position.pos.clone() })?;
68 olm_machine.store().set_custom_value(&instance_storage_key, pos_blob).await?;
69 }
70 }
71
72 let frozen_lists = {
74 sliding_sync
75 .inner
76 .lists
77 .read()
78 .await
79 .iter()
80 .filter(|(_, list)| matches!(list.cache_policy(), SlidingSyncListCachePolicy::Enabled))
81 .map(|(list_name, list)| {
82 Ok((
83 format_storage_key_for_sliding_sync_list(storage_key, list_name),
84 serde_json::to_vec(&FrozenSlidingSyncList::freeze(list))?,
85 ))
86 })
87 .collect::<Result<Vec<_>, crate::Error>>()?
88 };
89
90 for (storage_key_for_list, frozen_list) in frozen_lists {
91 trace!(storage_key_for_list, "Saving a `SlidingSyncList`");
92
93 storage.set_custom_value(storage_key_for_list.as_bytes(), frozen_list).await?;
94 }
95
96 Ok(())
97}
98
99pub(super) async fn restore_sliding_sync_list(
103 storage: &dyn StateStore<Error = StoreError>,
104 storage_key: &str,
105 list_name: &str,
106) -> Result<Option<FrozenSlidingSyncList>> {
107 let _timer = timer!(format!("loading list from DB {list_name}"));
108
109 let storage_key_for_list = format_storage_key_for_sliding_sync_list(storage_key, list_name);
110
111 match storage
112 .get_custom_value(storage_key_for_list.as_bytes())
113 .await?
114 .map(|custom_value| serde_json::from_slice::<FrozenSlidingSyncList>(&custom_value))
115 {
116 Some(Ok(frozen_list)) => {
117 trace!(list_name, "successfully read the list from cache");
119 return Ok(Some(frozen_list));
120 }
121
122 Some(Err(_)) => {
123 warn!(
129 list_name,
130 "failed to deserialize the list from the cache, it is obsolete; removing the cache entry!"
131 );
132 invalidate_cached_list(storage, storage_key, list_name).await;
134 }
135
136 None => {
137 trace!(list_name, "failed to find the list in the cache");
140 }
141 }
142
143 Ok(None)
144}
145
146#[derive(Default)]
148pub(super) struct RestoredFields {
149 pub to_device_token: Option<String>,
150 pub pos: Option<String>,
151}
152
153pub(super) async fn restore_sliding_sync_state(
158 _client: &Client,
159 storage_key: &str,
160) -> Result<Option<RestoredFields>> {
161 let _timer = timer!(format!("loading sliding sync {storage_key} state from DB"));
162
163 #[cfg_attr(not(feature = "e2e-encryption"), allow(unused_mut))]
164 let mut restored_fields = RestoredFields::default();
165
166 #[cfg(feature = "e2e-encryption")]
167 if let Some(olm_machine) = &*_client.olm_machine().await {
168 match olm_machine.store().next_batch_token().await? {
169 Some(token) => {
170 restored_fields.to_device_token = Some(token);
171 }
172 None => trace!("No `SlidingSync` in the crypto-store cache"),
173 }
174 }
175
176 #[cfg(feature = "e2e-encryption")]
178 if let Some(olm_machine) = &*_client.olm_machine().await {
179 let instance_storage_key = format_storage_key_for_sliding_sync(storage_key);
180
181 if let Ok(Some(blob)) = olm_machine.store().get_custom_value(&instance_storage_key).await {
182 if let Ok(frozen_pos) = serde_json::from_slice::<FrozenSlidingSyncPos>(&blob) {
183 trace!("Successfully read the `Sliding Sync` pos from the crypto store cache");
184 restored_fields.pos = frozen_pos.pos;
185 }
186 }
187 }
188
189 Ok(Some(restored_fields))
190}
191
#[cfg(test)]
mod tests {
    use std::sync::{Arc, RwLock};

    use matrix_sdk_test::async_test;

    #[cfg(feature = "e2e-encryption")]
    use super::format_storage_key_for_sliding_sync;
    use super::{
        super::SlidingSyncList, format_storage_key_for_sliding_sync_list,
        format_storage_key_prefix, restore_sliding_sync_state, store_sliding_sync_state,
    };
    use crate::{test_utils::logged_in_client, Result};

    // Checks that a save/restore round-trip covers only the lists added with
    // `add_cached_list` ("list_foo"), not those added with `add_list` ("list_bar").
    // NOTE(review): the reason for allowing `clippy::await_holding_lock` is not
    // stated here; confirm whether it is still needed.
    #[allow(clippy::await_holding_lock)]
    #[async_test]
    async fn test_sliding_sync_can_be_stored_and_restored() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let store = client.state_store();

        let sync_id = "test-sync-id";
        let storage_key = format_storage_key_prefix(sync_id, client.user_id().unwrap());

        // Before anything is saved, neither list has a cache entry.
        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_none());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        let storage_key = {
            // "list_foo" is cacheable, "list_bar" is not.
            let sliding_sync = client
                .sliding_sync(sync_id)?
                .add_cached_list(SlidingSyncList::builder("list_foo"))
                .await?
                .add_list(SlidingSyncList::builder("list_bar"))
                .build()
                .await?;

            // Give each list a distinctive value to look for after restoring.
            {
                let lists = sliding_sync.inner.lists.write().await;

                let list_foo = lists.get("list_foo").unwrap();
                list_foo.set_maximum_number_of_rooms(Some(42));

                let list_bar = lists.get("list_bar").unwrap();
                list_bar.set_maximum_number_of_rooms(Some(1337));
            }

            let position_guard = sliding_sync.inner.position.lock().await;
            assert!(sliding_sync.cache_to_storage(&position_guard).await.is_ok());

            storage_key
        };

        // After saving, only the cacheable list has an entry in the state store.
        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_foo").as_bytes()
            )
            .await?
            .is_some());

        assert!(store
            .get_custom_value(
                format_storage_key_for_sliding_sync_list(&storage_key, "list_bar").as_bytes()
            )
            .await?
            .is_none());

        // Slot through which the `once_built` callback hands out a subscription
        // to the maximum-number-of-rooms stream of "list_foo".
        let max_number_of_room_stream = Arc::new(RwLock::new(None));
        let cloned_stream = max_number_of_room_stream.clone();
        let sliding_sync = client
            .sliding_sync(sync_id)?
            .add_cached_list(SlidingSyncList::builder("list_foo").once_built(move |list| {
                // Inside `once_built`, the cached value is not visible yet.
                assert_eq!(list.maximum_number_of_rooms(), None);

                let mut stream = cloned_stream.write().unwrap();
                *stream = Some(list.maximum_number_of_rooms_stream());
                list
            }))
            .await?
            .add_list(SlidingSyncList::builder("list_bar"))
            .build()
            .await?;

        // After building, the cached list has its value back; the other does not.
        {
            let lists = sliding_sync.inner.lists.read().await;

            let list_foo = lists.get("list_foo").unwrap();
            assert_eq!(list_foo.maximum_number_of_rooms(), Some(42));

            let list_bar = lists.get("list_bar").unwrap();
            assert_eq!(list_bar.maximum_number_of_rooms(), None);
        }

        // The subscription taken in `once_built` observes the restored value.
        {
            let mut stream =
                max_number_of_room_stream.write().unwrap().take().expect("stream must be set");
            let initial_max_number_of_rooms =
                stream.next().await.expect("stream must have emitted something");
            assert_eq!(initial_max_number_of_rooms, Some(42));
        }

        Ok(())
    }

    // Checks that `store_sliding_sync_state` / `restore_sliding_sync_state`
    // round-trip the `pos` marker, and that no `next_batch` token appears in the
    // crypto store along the way.
    #[cfg(feature = "e2e-encryption")]
    #[async_test]
    async fn test_sliding_sync_high_level_cache_and_restore() -> Result<()> {
        let client = logged_in_client(Some("https://foo.bar".to_owned())).await;

        let sync_id = "test-sync-id";
        let storage_key_prefix = format_storage_key_prefix(sync_id, client.user_id().unwrap());
        let full_storage_key = format_storage_key_for_sliding_sync(&storage_key_prefix);
        let sliding_sync = client.sliding_sync(sync_id)?.build().await?;

        // Initially, the crypto store has no `next_batch` token...
        if let Some(olm_machine) = &*client.base_client().olm_machine().await {
            let store = olm_machine.store();
            assert!(store.next_batch_token().await?.is_none());
        }

        // ...and the state store has no entry under the instance key.
        let state_store = client.state_store();
        assert!(state_store.get_custom_value(full_storage_key.as_bytes()).await?.is_none());

        // Set a `pos` and persist it while holding the position lock.
        let pos = "pos".to_owned();
        {
            let mut position_guard = sliding_sync.inner.position.lock().await;
            position_guard.pos = Some(pos.clone());

            store_sliding_sync_state(&sliding_sync, &position_guard).await?;
        }

        drop(sliding_sync);

        let restored_fields = restore_sliding_sync_state(&client, &storage_key_prefix)
            .await?
            .expect("must have restored sliding sync fields");

        // The `pos` comes back as it was stored.
        assert_eq!(restored_fields.pos.unwrap(), pos);

        // Storing the state did not set a `next_batch` token in the crypto store.
        {
            let olm_machine = client.base_client().olm_machine().await;
            let olm_machine = olm_machine.as_ref().unwrap();
            assert!(olm_machine.store().next_batch_token().await?.is_none());
        }

        Ok(())
    }
}