1use std::{
16 collections::BTreeMap,
17 sync::{
18 Arc,
19 atomic::{self, AtomicBool},
20 },
21};
22
23use matrix_sdk_base::{
24 StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
25 executor::AbortOnDrop,
26 store::{StoredThreadSubscription, ThreadSubscriptionStatus},
27};
28use matrix_sdk_common::executor::spawn;
29use once_cell::sync::OnceCell;
30use ruma::{
31 OwnedEventId, OwnedRoomId,
32 api::client::threads::get_thread_subscriptions_changes::unstable::{
33 ThreadSubscription, ThreadUnsubscription,
34 },
35 assign,
36};
37use tokio::sync::{
38 Mutex, OwnedMutexGuard,
39 mpsc::{Receiver, Sender, channel},
40};
41use tracing::{instrument, trace, warn};
42
43use crate::{Client, Result, client::WeakClient};
44
/// Exclusive access to the thread subscriptions catch-up data in the state store.
///
/// Instances are created by `ThreadSubscriptionCatchup::lock`, and each one holds the owned guard
/// of that helper's mutex for as long as it lives. As a result, at most one instance is alive at a
/// time for a given `ThreadSubscriptionCatchup`.
struct GuardedStoreAccess {
    /// Owned guard of the catch-up mutex. It is held only to keep other accessors out, and is
    /// released when this value is dropped.
    _mutex: OwnedMutexGuard<()>,
    /// Strong client handle, used to reach the state store.
    client: Client,
    /// Shared "outdated" flag, updated every time the catch-up token list is saved.
    is_outdated: Arc<AtomicBool>,
}
50
51impl GuardedStoreAccess {
52 async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
56 let loaded = self
57 .client
58 .state_store()
59 .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
60 .await?;
61
62 match loaded {
63 Some(data) => {
64 if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
65 if tokens.is_empty() {
67 self.save_catchup_tokens(tokens).await?;
68 Ok(None)
69 } else {
70 Ok(Some(tokens))
71 }
72 } else {
73 warn!(
74 "invalid data in thread subscriptions catchup tokens state store k/v entry"
75 );
76 Ok(None)
77 }
78 }
79
80 None => Ok(None),
81 }
82 }
83
84 #[instrument(skip_all, fields(num_tokens = tokens.len()))]
88 async fn save_catchup_tokens(
89 &self,
90 tokens: Vec<ThreadSubscriptionCatchupToken>,
91 ) -> Result<bool> {
92 let store = self.client.state_store();
93 let is_empty = if tokens.is_empty() {
94 store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
95
96 trace!("Marking thread subscriptions as not outdated \\o/");
97 self.is_outdated.store(false, atomic::Ordering::SeqCst);
98 true
99 } else {
100 store
101 .set_kv_data(
102 StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
103 StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
104 )
105 .await?;
106
107 trace!("Marking thread subscriptions as outdated.");
108 self.is_outdated.store(true, atomic::Ordering::SeqCst);
109 false
110 };
111 Ok(is_empty)
112 }
113}
114
/// Keeps locally stored thread subscriptions in sync with the server.
///
/// Catch-up tokens received through sync are persisted in the state store. When the client has
/// thread subscriptions enabled, a background task processes those tokens.
pub struct ThreadSubscriptionCatchup {
    /// Handle to the background catch-up task. It is only initialized when the client has thread
    /// subscriptions enabled.
    ///
    /// Being declared first, this field is dropped first. `AbortOnDrop` presumably aborts the
    /// task at that point (TODO confirm).
    _task: OnceCell<AbortOnDrop<()>>,

    /// Whether the stored subscriptions may be missing updates. The flag starts as `true` and is
    /// shared with every `GuardedStoreAccess`, which updates it on each save of the token list.
    is_outdated: Arc<AtomicBool>,

    /// Non-owning handle to the client. `lock` returns `None` once it can no longer be upgraded.
    client: WeakClient,

    /// Wakes up the background task after a non-empty catch-up token list has been saved.
    ping_sender: Sender<()>,

    /// Serializes all accesses to the catch-up data, via `GuardedStoreAccess`.
    uniq_mutex: Arc<Mutex<()>>,
}
134
135impl ThreadSubscriptionCatchup {
136 pub fn new(client: Client) -> Arc<Self> {
137 let is_outdated = Arc::new(AtomicBool::new(true));
138
139 let weak_client = WeakClient::from_client(&client);
140
141 let (ping_sender, ping_receiver) = channel(8);
142
143 let uniq_mutex = Arc::new(Mutex::new(()));
144
145 let this = Arc::new(Self {
146 _task: OnceCell::new(),
147 is_outdated,
148 client: weak_client,
149 ping_sender,
150 uniq_mutex,
151 });
152
153 if client.enabled_thread_subscriptions() {
156 let _ = this._task.get_or_init(|| {
157 AbortOnDrop::new(spawn(Self::thread_subscriptions_catchup_task(
158 this.clone(),
159 ping_receiver,
160 )))
161 });
162 }
163
164 this
165 }
166
167 pub(crate) fn is_outdated(&self) -> bool {
170 self.is_outdated.load(atomic::Ordering::SeqCst)
171 }
172
173 #[instrument(skip_all)]
176 pub(crate) async fn sync_subscriptions(
177 &self,
178 subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
179 unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
180 token: Option<ThreadSubscriptionCatchupToken>,
181 ) -> Result<()> {
182 let Some(guard) = self.lock().await else {
183 return Ok(());
185 };
186 self.save_catchup_token(&guard, token).await?;
187 self.store_subscriptions(&guard, subscribed, unsubscribed).await?;
188 Ok(())
189 }
190
191 async fn store_subscriptions(
192 &self,
193 guard: &GuardedStoreAccess,
194 subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
195 unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
196 ) -> Result<()> {
197 if subscribed.is_empty() && unsubscribed.is_empty() {
198 return Ok(());
200 }
201
202 trace!(
203 "saving {} new subscriptions and {} unsubscriptions",
204 subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
205 unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
206 );
207
208 for (room_id, room_map) in unsubscribed {
210 for (event_id, thread_sub) in room_map {
211 guard
212 .client
213 .state_store()
214 .upsert_thread_subscription(
215 &room_id,
216 &event_id,
217 StoredThreadSubscription {
218 status: ThreadSubscriptionStatus::Unsubscribed,
219 bump_stamp: Some(thread_sub.bump_stamp.into()),
220 },
221 )
222 .await?;
223 }
224 }
225
226 for (room_id, room_map) in subscribed {
228 for (event_id, thread_sub) in room_map {
229 guard
230 .client
231 .state_store()
232 .upsert_thread_subscription(
233 &room_id,
234 &event_id,
235 StoredThreadSubscription {
236 status: ThreadSubscriptionStatus::Subscribed {
237 automatic: thread_sub.automatic,
238 },
239 bump_stamp: Some(thread_sub.bump_stamp.into()),
240 },
241 )
242 .await?;
243 }
244 }
245
246 Ok(())
247 }
248
249 async fn lock(&self) -> Option<GuardedStoreAccess> {
252 let client = self.client.get()?;
253 let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
254 Some(GuardedStoreAccess {
255 _mutex: mutex_guard,
256 client,
257 is_outdated: self.is_outdated.clone(),
258 })
259 }
260
261 async fn save_catchup_token(
263 &self,
264 guard: &GuardedStoreAccess,
265 token: Option<ThreadSubscriptionCatchupToken>,
266 ) -> Result<()> {
267 let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();
270
271 if let Some(token) = token {
272 trace!(?token, "Saving catchup token");
273 tokens.push(token);
274 } else {
275 trace!("No catchup token to save");
276 }
277
278 let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;
279
280 if !is_token_list_empty {
282 let _ = self.ping_sender.send(()).await;
283 }
284
285 Ok(())
286 }
287
288 #[instrument(skip_all)]
304 async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
305 loop {
306 let Some(guard) = this.lock().await else {
308 return;
310 };
311
312 let store_tokens = match guard.load_catchup_tokens().await {
313 Ok(tokens) => tokens,
314 Err(err) => {
315 warn!("Failed to load thread subscriptions catchup tokens: {err}");
316 continue;
317 }
318 };
319
320 let Some(mut tokens) = store_tokens else {
321 drop(guard);
323
324 trace!("Waiting for an explicit wake up to process future thread subscriptions");
326
327 if let Some(()) = ping_receiver.recv().await {
328 trace!("Woke up!");
329 continue;
330 }
331
332 break;
334 };
335
336 let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
338
339 let client = guard.client.clone();
341 drop(guard);
342
343 let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
345 from: Some(last.from.clone()),
346 to: last.to.clone(),
347 });
348
349 match client.send(req).await {
350 Ok(resp) => {
351 let guard = this
352 .lock()
353 .await
354 .expect("a client instance is alive, so the locking should not fail");
355
356 if let Err(err) =
357 this.store_subscriptions(&guard, resp.subscribed, resp.unsubscribed).await
358 {
359 warn!("Failed to store caught up thread subscriptions: {err}");
360 continue;
361 }
362
363 let mut tokens = match guard.load_catchup_tokens().await {
366 Ok(tokens) => tokens.unwrap_or_default(),
367 Err(err) => {
368 warn!("Failed to load thread subscriptions catchup tokens: {err}");
369 continue;
370 }
371 };
372
373 let Some(index) = tokens.iter().position(|t| *t == last) else {
374 warn!("Thread subscriptions catchup token disappeared while processing it");
375 continue;
376 };
377
378 if let Some(next_batch) = resp.end {
379 tokens[index] =
382 ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
383 } else {
384 tokens.remove(index);
386 }
387
388 if let Err(err) = guard.save_catchup_tokens(tokens).await {
389 warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
390 }
391 }
392
393 Err(err) => {
394 warn!("Failed to catch up thread subscriptions: {err}");
395 }
396 }
397 }
398 }
399}
400
#[cfg(test)]
mod tests {
    use std::ops::Not as _;

    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;
        let catchup = client.thread_subscription_catchup();
        let store_access = catchup.lock().await.unwrap();

        // A fresh store has no tokens, and subscriptions start out as outdated.
        let initial = store_access.load_catchup_tokens().await.unwrap();
        assert!(initial.is_none());
        assert!(catchup.is_outdated());

        // Saving a non-empty list persists it and keeps the outdated flag set.
        let single = ThreadSubscriptionCatchupToken {
            from: "from".to_owned(),
            to: Some("to".to_owned()),
        };
        store_access.save_catchup_tokens(vec![single.clone()]).await.unwrap();
        let reloaded = store_access.load_catchup_tokens().await.unwrap();
        assert_eq!(reloaded, Some(vec![single]));
        assert!(catchup.is_outdated());

        // Saving an empty list clears the entry and marks the subscriptions as up to date.
        store_access.save_catchup_tokens(Vec::new()).await.unwrap();
        let cleared = store_access.load_catchup_tokens().await.unwrap();
        assert!(cleared.is_none());
        assert!(catchup.is_outdated().not());
    }
}
442}