1use std::{
16 collections::BTreeMap,
17 sync::{
18 Arc, OnceLock,
19 atomic::{self, AtomicBool},
20 },
21};
22
23use matrix_sdk_base::{
24 StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
25 store::{StoredThreadSubscription, ThreadSubscriptionStatus},
26 task_monitor::BackgroundTaskHandle,
27};
28use ruma::{
29 EventId, OwnedEventId, OwnedRoomId, RoomId,
30 api::client::threads::get_thread_subscriptions_changes::unstable::{
31 ThreadSubscription, ThreadUnsubscription,
32 },
33 assign,
34};
35use tokio::sync::{
36 Mutex, OwnedMutexGuard,
37 mpsc::{Receiver, Sender, channel},
38};
39use tracing::{debug, instrument, trace, warn};
40
41use crate::{Client, Result, client::WeakClient};
42
/// Exclusive access to the thread subscriptions catchup tokens stored in the
/// state store.
///
/// Values are created by `ThreadSubscriptionCatchup::lock`. Holding one means
/// holding the catchup mutex, which is released when the value is dropped.
struct GuardedStoreAccess {
    /// Guard of the catchup mutex; it is only held, never read.
    _mutex: OwnedMutexGuard<()>,
    /// Strong client reference, kept alive for as long as this guard lives.
    client: Client,
    /// Shared flag telling whether thread subscriptions may be outdated;
    /// written by `save_catchup_tokens`.
    is_outdated: Arc<AtomicBool>,
}
48
49impl GuardedStoreAccess {
50 async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
54 let loaded = self
55 .client
56 .state_store()
57 .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
58 .await?;
59
60 match loaded {
61 Some(data) => {
62 if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
63 if tokens.is_empty() {
65 self.save_catchup_tokens(tokens).await?;
66 Ok(None)
67 } else {
68 Ok(Some(tokens))
69 }
70 } else {
71 warn!(
72 "invalid data in thread subscriptions catchup tokens state store k/v entry"
73 );
74 Ok(None)
75 }
76 }
77
78 None => Ok(None),
79 }
80 }
81
82 #[instrument(skip_all, fields(num_tokens = tokens.len()))]
86 async fn save_catchup_tokens(
87 &self,
88 tokens: Vec<ThreadSubscriptionCatchupToken>,
89 ) -> Result<bool> {
90 let store = self.client.state_store();
91 let is_empty = if tokens.is_empty() {
92 store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
93
94 trace!("Marking thread subscriptions as not outdated \\o/");
95 self.is_outdated.store(false, atomic::Ordering::SeqCst);
96 true
97 } else {
98 store
99 .set_kv_data(
100 StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
101 StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
102 )
103 .await?;
104
105 trace!("Marking thread subscriptions as outdated.");
106 self.is_outdated.store(true, atomic::Ordering::SeqCst);
107 false
108 };
109 Ok(is_empty)
110 }
111}
112
/// Keeps the locally stored thread subscriptions up to date.
///
/// Subscription changes coming from sync are applied directly. Gaps described
/// by persisted catchup tokens are filled by a background task.
pub struct ThreadSubscriptionCatchup {
    /// Abort-on-drop handle of the background catchup task. It is set at most
    /// once, in `new`, and only when thread subscriptions are enabled.
    _task: OnceLock<BackgroundTaskHandle>,

    /// Whether the locally known thread subscriptions may be missing updates,
    /// i.e. whether catchup tokens remain to be processed. The same flag is
    /// shared with every `GuardedStoreAccess` created by this object.
    is_outdated: Arc<AtomicBool>,

    /// Weak reference to the client, so this object doesn't keep the client
    /// alive on its own.
    client: WeakClient,

    /// Wakes up the background task after catchup tokens have been saved.
    ping_sender: Sender<()>,

    /// Mutex serializing all accesses to the catchup tokens in the store; see
    /// `GuardedStoreAccess`.
    uniq_mutex: Arc<Mutex<()>>,
}
132
133impl ThreadSubscriptionCatchup {
134 pub async fn new(client: Client) -> Arc<Self> {
135 let is_outdated = Arc::new(AtomicBool::new(true));
136 let weak_client = WeakClient::from_client(&client);
137 let (ping_sender, ping_receiver) = channel(8);
138 let uniq_mutex = Arc::new(Mutex::new(()));
139
140 let this = Arc::new(Self {
141 _task: OnceLock::new(),
142 is_outdated,
143 client: weak_client.clone(),
144 ping_sender,
145 uniq_mutex,
146 });
147
148 match client.enabled_thread_subscriptions().await {
151 Ok(true) => {
152 let _ = this._task.get_or_init(|| {
153 let that = this.clone();
154
155 client
156 .task_monitor()
157 .spawn_infinite_task("client::thread_subscriptions_catchup", async move {
158 Self::thread_subscriptions_catchup_task(that, ping_receiver).await;
159 })
160 .abort_on_drop()
161 });
162 }
163
164 Ok(false) => {
165 debug!("Thread subscriptions catchup not enabled, not starting the catchup task");
166 }
167
168 Err(err) => {
169 warn!("Failed to check if thread subscriptions catchup is enabled: {err}");
170 }
171 }
172
173 this
174 }
175
176 pub(crate) fn is_outdated(&self) -> bool {
179 self.is_outdated.load(atomic::Ordering::SeqCst)
180 }
181
182 #[instrument(skip_all)]
185 pub(crate) async fn sync_subscriptions(
186 &self,
187 subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
188 unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
189 token: Option<ThreadSubscriptionCatchupToken>,
190 ) -> Result<()> {
191 let updates = build_subscription_updates(&subscribed, &unsubscribed);
193 let Some(guard) = self.lock().await else {
194 return Ok(());
196 };
197 self.save_catchup_token(&guard, token).await?;
198 if !updates.is_empty() {
199 trace!(
200 "saving {} new subscriptions and {} unsubscriptions",
201 subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
202 unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
203 );
204 guard.client.state_store().upsert_thread_subscriptions(updates).await?;
205 }
206 Ok(())
207 }
208
209 async fn lock(&self) -> Option<GuardedStoreAccess> {
212 let client = self.client.get()?;
213 let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
214 Some(GuardedStoreAccess {
215 _mutex: mutex_guard,
216 client,
217 is_outdated: self.is_outdated.clone(),
218 })
219 }
220
221 async fn save_catchup_token(
223 &self,
224 guard: &GuardedStoreAccess,
225 token: Option<ThreadSubscriptionCatchupToken>,
226 ) -> Result<()> {
227 let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();
230
231 if let Some(token) = token {
232 trace!(?token, "Saving catchup token");
233 tokens.push(token);
234 } else {
235 trace!("No catchup token to save");
236 }
237
238 let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;
239
240 if !is_token_list_empty {
242 let _ = self.ping_sender.send(()).await;
243 }
244
245 Ok(())
246 }
247
248 #[instrument(skip_all)]
264 async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
265 loop {
266 let Some(guard) = this.lock().await else {
268 return;
270 };
271
272 let store_tokens = match guard.load_catchup_tokens().await {
273 Ok(tokens) => tokens,
274 Err(err) => {
275 warn!("Failed to load thread subscriptions catchup tokens: {err}");
276 continue;
277 }
278 };
279
280 let Some(mut tokens) = store_tokens else {
281 drop(guard);
283
284 trace!("Waiting for an explicit wake up to process future thread subscriptions");
286
287 if let Some(()) = ping_receiver.recv().await {
288 trace!("Woke up!");
289 continue;
290 }
291
292 break;
294 };
295
296 let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
298
299 let client = guard.client.clone();
301 drop(guard);
302
303 let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
305 from: Some(last.from.clone()),
306 to: last.to.clone(),
307 });
308
309 match client.send(req).await {
310 Ok(resp) => {
311 let updates = build_subscription_updates(&resp.subscribed, &resp.unsubscribed);
313
314 let guard = this
315 .lock()
316 .await
317 .expect("a client instance is alive, so the locking should not fail");
318
319 if !updates.is_empty() {
320 trace!(
321 "saving {} new subscriptions and {} unsubscriptions",
322 resp.subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
323 resp.unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
324 );
325
326 if let Err(err) =
327 guard.client.state_store().upsert_thread_subscriptions(updates).await
328 {
329 warn!("Failed to store caught up thread subscriptions: {err}");
330 continue;
331 }
332 }
333
334 let mut tokens = match guard.load_catchup_tokens().await {
337 Ok(tokens) => tokens.unwrap_or_default(),
338 Err(err) => {
339 warn!("Failed to load thread subscriptions catchup tokens: {err}");
340 continue;
341 }
342 };
343
344 let Some(index) = tokens.iter().position(|t| *t == last) else {
345 warn!("Thread subscriptions catchup token disappeared while processing it");
346 continue;
347 };
348
349 if let Some(next_batch) = resp.end {
350 tokens[index] =
353 ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
354 } else {
355 tokens.remove(index);
357 }
358
359 if let Err(err) = guard.save_catchup_tokens(tokens).await {
360 warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
361 }
362 }
363
364 Err(err) => {
365 warn!("Failed to catch up thread subscriptions: {err}");
366 }
367 }
368 }
369 }
370}
371
372fn build_subscription_updates<'a>(
374 subscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
375 unsubscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
376) -> Vec<(&'a RoomId, &'a EventId, StoredThreadSubscription)> {
377 let mut updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)> =
378 Vec::with_capacity(unsubscribed.len() + subscribed.len());
379
380 for (room_id, room_map) in unsubscribed {
382 for (event_id, thread_sub) in room_map {
383 updates.push((
384 room_id,
385 event_id,
386 StoredThreadSubscription {
387 status: ThreadSubscriptionStatus::Unsubscribed,
388 bump_stamp: Some(thread_sub.bump_stamp.into()),
389 },
390 ));
391 }
392 }
393
394 for (room_id, room_map) in subscribed {
396 for (event_id, thread_sub) in room_map {
397 updates.push((
398 room_id,
399 event_id,
400 StoredThreadSubscription {
401 status: ThreadSubscriptionStatus::Subscribed {
402 automatic: thread_sub.automatic,
403 },
404 bump_stamp: Some(thread_sub.bump_stamp.into()),
405 },
406 ));
407 }
408 }
409
410 updates
411}
412
#[cfg(test)]
mod tests {
    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    /// Round-trips catchup tokens through the store and checks how the
    /// outdated flag follows the emptiness of the saved list.
    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;
        let catchup = client.thread_subscription_catchup();
        let store_access = catchup.lock().await.unwrap();

        // A fresh store has no tokens, and subscriptions start out outdated.
        assert!(store_access.load_catchup_tokens().await.unwrap().is_none());
        assert!(catchup.is_outdated());

        // A non-empty list is read back unchanged and keeps the flag set.
        let token = ThreadSubscriptionCatchupToken {
            from: "from".to_owned(),
            to: Some("to".to_owned()),
        };
        store_access.save_catchup_tokens(vec![token.clone()]).await.unwrap();
        assert_eq!(store_access.load_catchup_tokens().await.unwrap(), Some(vec![token]));
        assert!(catchup.is_outdated());

        // An empty list clears the entry and marks subscriptions up to date.
        store_access.save_catchup_tokens(vec![]).await.unwrap();
        assert!(store_access.load_catchup_tokens().await.unwrap().is_none());
        assert!(!catchup.is_outdated());
    }
}
455
#[cfg(all(test, not(target_family = "wasm")))]
mod timed_tests {
    use std::time::Duration;

    use matrix_sdk_base::{ThreadingSupport, sleep::sleep};
    use matrix_sdk_test::async_test;
    use tokio::task::yield_now;

    use crate::{client::WeakClient, test_utils::mocks::MatrixMockServer};

    /// Regression test for issue 6573: the catchup task must not keep the
    /// client alive after every user-held reference is dropped.
    #[async_test]
    async fn test_issue_6573_client_can_drop_thread_subscriptions_task() {
        let server = MatrixMockServer::new().await;
        server.mock_versions().with_thread_subscriptions().ok().mount().await;

        let client = server
            .client_builder()
            .no_server_versions()
            .on_builder(|builder| {
                builder
                    .with_threading_support(ThreadingSupport::Enabled { with_subscriptions: true })
            })
            .build()
            .await;
        let catchup = client.thread_subscription_catchup();

        // Let the runtime poll the freshly spawned task once.
        yield_now().await;
        assert!(catchup._task.get().is_some());

        let weak = WeakClient::from_client(&client);
        drop(client);

        // Leave time for any strong reference still held elsewhere to be released.
        sleep(Duration::from_secs(2)).await;
        assert_eq!(weak.strong_count(), 0);
    }
}