1use std::{
16 collections::BTreeMap,
17 sync::{
18 Arc, OnceLock,
19 atomic::{self, AtomicBool},
20 },
21};
22
23use matrix_sdk_base::{
24 StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
25 store::{StoredThreadSubscription, ThreadSubscriptionStatus},
26 task_monitor::BackgroundTaskHandle,
27};
28use ruma::{
29 EventId, OwnedEventId, OwnedRoomId, RoomId,
30 api::client::threads::get_thread_subscriptions_changes::unstable::{
31 ThreadSubscription, ThreadUnsubscription,
32 },
33 assign,
34};
35use tokio::sync::{
36 Mutex, OwnedMutexGuard,
37 mpsc::{Receiver, Sender, channel},
38};
39use tracing::{debug, instrument, trace, warn};
40
41use crate::{Client, Result, client::WeakClient};
42
43struct GuardedStoreAccess {
44 _mutex: OwnedMutexGuard<()>,
45 client: Client,
46 is_outdated: Arc<AtomicBool>,
47}
48
49impl GuardedStoreAccess {
50 async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
54 let loaded = self
55 .client
56 .state_store()
57 .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
58 .await?;
59
60 match loaded {
61 Some(data) => {
62 if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
63 if tokens.is_empty() {
65 self.save_catchup_tokens(tokens).await?;
66 Ok(None)
67 } else {
68 Ok(Some(tokens))
69 }
70 } else {
71 warn!(
72 "invalid data in thread subscriptions catchup tokens state store k/v entry"
73 );
74 Ok(None)
75 }
76 }
77
78 None => Ok(None),
79 }
80 }
81
82 #[instrument(skip_all, fields(num_tokens = tokens.len()))]
86 async fn save_catchup_tokens(
87 &self,
88 tokens: Vec<ThreadSubscriptionCatchupToken>,
89 ) -> Result<bool> {
90 let store = self.client.state_store();
91 let is_empty = if tokens.is_empty() {
92 store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
93
94 trace!("Marking thread subscriptions as not outdated \\o/");
95 self.is_outdated.store(false, atomic::Ordering::SeqCst);
96 true
97 } else {
98 store
99 .set_kv_data(
100 StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
101 StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
102 )
103 .await?;
104
105 trace!("Marking thread subscriptions as outdated.");
106 self.is_outdated.store(true, atomic::Ordering::SeqCst);
107 false
108 };
109 Ok(is_empty)
110 }
111}
112
113pub struct ThreadSubscriptionCatchup {
114 _task: OnceLock<BackgroundTaskHandle>,
116
117 is_outdated: Arc<AtomicBool>,
120
121 client: WeakClient,
123
124 ping_sender: Sender<()>,
127
128 uniq_mutex: Arc<Mutex<()>>,
131}
132
133impl ThreadSubscriptionCatchup {
134 pub fn new(client: Client) -> Arc<Self> {
135 let is_outdated = Arc::new(AtomicBool::new(true));
136
137 let weak_client = WeakClient::from_client(&client);
138
139 let (ping_sender, ping_receiver) = channel(8);
140
141 let uniq_mutex = Arc::new(Mutex::new(()));
142
143 let this = Arc::new(Self {
144 _task: OnceLock::new(),
145 is_outdated,
146 client: weak_client,
147 ping_sender,
148 uniq_mutex,
149 });
150
151 let _ = this._task.get_or_init(|| {
154 let that = this.clone();
155 let client_clone = client.clone();
156
157 client.task_monitor().spawn_infinite_task("client::thread_subscriptions_catchup", async move {
158 match client_clone.enabled_thread_subscriptions().await {
159 Ok(enabled) => {
160 if !enabled {
161 debug!("Thread subscriptions catchup not enabled, not starting the catchup task");
162 return;
163 }
164 }
165
166 Err(err) => {
167 warn!("Failed to check if thread subscriptions catchup is enabled: {err}");
168 return;
169 }
170 }
171
172 Self::thread_subscriptions_catchup_task(
173 that,
174 ping_receiver,
175 ).await
176 }).abort_on_drop()
177 });
178
179 this
180 }
181
182 pub(crate) fn is_outdated(&self) -> bool {
185 self.is_outdated.load(atomic::Ordering::SeqCst)
186 }
187
188 #[instrument(skip_all)]
191 pub(crate) async fn sync_subscriptions(
192 &self,
193 subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
194 unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
195 token: Option<ThreadSubscriptionCatchupToken>,
196 ) -> Result<()> {
197 let updates = build_subscription_updates(&subscribed, &unsubscribed);
199 let Some(guard) = self.lock().await else {
200 return Ok(());
202 };
203 self.save_catchup_token(&guard, token).await?;
204 if !updates.is_empty() {
205 trace!(
206 "saving {} new subscriptions and {} unsubscriptions",
207 subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
208 unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
209 );
210 guard.client.state_store().upsert_thread_subscriptions(updates).await?;
211 }
212 Ok(())
213 }
214
215 async fn lock(&self) -> Option<GuardedStoreAccess> {
218 let client = self.client.get()?;
219 let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
220 Some(GuardedStoreAccess {
221 _mutex: mutex_guard,
222 client,
223 is_outdated: self.is_outdated.clone(),
224 })
225 }
226
227 async fn save_catchup_token(
229 &self,
230 guard: &GuardedStoreAccess,
231 token: Option<ThreadSubscriptionCatchupToken>,
232 ) -> Result<()> {
233 let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();
236
237 if let Some(token) = token {
238 trace!(?token, "Saving catchup token");
239 tokens.push(token);
240 } else {
241 trace!("No catchup token to save");
242 }
243
244 let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;
245
246 if !is_token_list_empty {
248 let _ = self.ping_sender.send(()).await;
249 }
250
251 Ok(())
252 }
253
254 #[instrument(skip_all)]
270 async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
271 loop {
272 let Some(guard) = this.lock().await else {
274 return;
276 };
277
278 let store_tokens = match guard.load_catchup_tokens().await {
279 Ok(tokens) => tokens,
280 Err(err) => {
281 warn!("Failed to load thread subscriptions catchup tokens: {err}");
282 continue;
283 }
284 };
285
286 let Some(mut tokens) = store_tokens else {
287 drop(guard);
289
290 trace!("Waiting for an explicit wake up to process future thread subscriptions");
292
293 if let Some(()) = ping_receiver.recv().await {
294 trace!("Woke up!");
295 continue;
296 }
297
298 break;
300 };
301
302 let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
304
305 let client = guard.client.clone();
307 drop(guard);
308
309 let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
311 from: Some(last.from.clone()),
312 to: last.to.clone(),
313 });
314
315 match client.send(req).await {
316 Ok(resp) => {
317 let updates = build_subscription_updates(&resp.subscribed, &resp.unsubscribed);
319
320 let guard = this
321 .lock()
322 .await
323 .expect("a client instance is alive, so the locking should not fail");
324
325 if !updates.is_empty() {
326 trace!(
327 "saving {} new subscriptions and {} unsubscriptions",
328 resp.subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
329 resp.unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
330 );
331
332 if let Err(err) =
333 guard.client.state_store().upsert_thread_subscriptions(updates).await
334 {
335 warn!("Failed to store caught up thread subscriptions: {err}");
336 continue;
337 }
338 }
339
340 let mut tokens = match guard.load_catchup_tokens().await {
343 Ok(tokens) => tokens.unwrap_or_default(),
344 Err(err) => {
345 warn!("Failed to load thread subscriptions catchup tokens: {err}");
346 continue;
347 }
348 };
349
350 let Some(index) = tokens.iter().position(|t| *t == last) else {
351 warn!("Thread subscriptions catchup token disappeared while processing it");
352 continue;
353 };
354
355 if let Some(next_batch) = resp.end {
356 tokens[index] =
359 ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
360 } else {
361 tokens.remove(index);
363 }
364
365 if let Err(err) = guard.save_catchup_tokens(tokens).await {
366 warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
367 }
368 }
369
370 Err(err) => {
371 warn!("Failed to catch up thread subscriptions: {err}");
372 }
373 }
374 }
375 }
376}
377
378fn build_subscription_updates<'a>(
380 subscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
381 unsubscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
382) -> Vec<(&'a RoomId, &'a EventId, StoredThreadSubscription)> {
383 let mut updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)> =
384 Vec::with_capacity(unsubscribed.len() + subscribed.len());
385
386 for (room_id, room_map) in unsubscribed {
388 for (event_id, thread_sub) in room_map {
389 updates.push((
390 room_id,
391 event_id,
392 StoredThreadSubscription {
393 status: ThreadSubscriptionStatus::Unsubscribed,
394 bump_stamp: Some(thread_sub.bump_stamp.into()),
395 },
396 ));
397 }
398 }
399
400 for (room_id, room_map) in subscribed {
402 for (event_id, thread_sub) in room_map {
403 updates.push((
404 room_id,
405 event_id,
406 StoredThreadSubscription {
407 status: ThreadSubscriptionStatus::Subscribed {
408 automatic: thread_sub.automatic,
409 },
410 bump_stamp: Some(thread_sub.bump_stamp.into()),
411 },
412 ));
413 }
414 }
415
416 updates
417}
418
#[cfg(test)]
mod tests {
    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;
        let catchup = client.thread_subscription_catchup();
        let store_access = catchup.lock().await.unwrap();

        // Nothing has been stored yet: no tokens, and subscriptions start out outdated.
        assert_eq!(store_access.load_catchup_tokens().await.unwrap(), None);
        assert!(catchup.is_outdated());

        // A non-empty token list round-trips and keeps the subscriptions outdated.
        let token =
            ThreadSubscriptionCatchupToken { from: "from".to_owned(), to: Some("to".to_owned()) };
        store_access.save_catchup_tokens(vec![token.clone()]).await.unwrap();
        assert_eq!(store_access.load_catchup_tokens().await.unwrap(), Some(vec![token]));
        assert!(catchup.is_outdated());

        // Saving an empty list clears the entry and marks subscriptions as up to date.
        store_access.save_catchup_tokens(Vec::new()).await.unwrap();
        assert_eq!(store_access.load_catchup_tokens().await.unwrap(), None);
        assert!(!catchup.is_outdated());
    }
}