Skip to main content

matrix_sdk/client/
thread_subscriptions.rs

1// Copyright 2025 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    collections::BTreeMap,
17    sync::{
18        Arc, OnceLock,
19        atomic::{self, AtomicBool},
20    },
21};
22
23use matrix_sdk_base::{
24    StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
25    store::{StoredThreadSubscription, ThreadSubscriptionStatus},
26    task_monitor::BackgroundTaskHandle,
27};
28use ruma::{
29    EventId, OwnedEventId, OwnedRoomId, RoomId,
30    api::client::threads::get_thread_subscriptions_changes::unstable::{
31        ThreadSubscription, ThreadUnsubscription,
32    },
33    assign,
34};
35use tokio::sync::{
36    Mutex, OwnedMutexGuard,
37    mpsc::{Receiver, Sender, channel},
38};
39use tracing::{debug, instrument, trace, warn};
40
41use crate::{Client, Result, client::WeakClient};
42
43struct GuardedStoreAccess {
44    _mutex: OwnedMutexGuard<()>,
45    client: Client,
46    is_outdated: Arc<AtomicBool>,
47}
48
49impl GuardedStoreAccess {
50    /// Return the current list of catchup tokens, if any.
51    ///
52    /// It is guaranteed that if the list is set, then it's non-empty.
53    async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
54        let loaded = self
55            .client
56            .state_store()
57            .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
58            .await?;
59
60        match loaded {
61            Some(data) => {
62                if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
63                    // If the tokens list is empty, automatically clean it up.
64                    if tokens.is_empty() {
65                        self.save_catchup_tokens(tokens).await?;
66                        Ok(None)
67                    } else {
68                        Ok(Some(tokens))
69                    }
70                } else {
71                    warn!(
72                        "invalid data in thread subscriptions catchup tokens state store k/v entry"
73                    );
74                    Ok(None)
75                }
76            }
77
78            None => Ok(None),
79        }
80    }
81
82    /// Saves the tokens in the database.
83    ///
84    /// Returns whether the list of tokens is empty or not.
85    #[instrument(skip_all, fields(num_tokens = tokens.len()))]
86    async fn save_catchup_tokens(
87        &self,
88        tokens: Vec<ThreadSubscriptionCatchupToken>,
89    ) -> Result<bool> {
90        let store = self.client.state_store();
91        let is_empty = if tokens.is_empty() {
92            store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
93
94            trace!("Marking thread subscriptions as not outdated \\o/");
95            self.is_outdated.store(false, atomic::Ordering::SeqCst);
96            true
97        } else {
98            store
99                .set_kv_data(
100                    StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
101                    StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
102                )
103                .await?;
104
105            trace!("Marking thread subscriptions as outdated.");
106            self.is_outdated.store(true, atomic::Ordering::SeqCst);
107            false
108        };
109        Ok(is_empty)
110    }
111}
112
113pub struct ThreadSubscriptionCatchup {
114    /// The task catching up thread subscriptions in the background.
115    _task: OnceLock<BackgroundTaskHandle>,
116
117    /// Whether the known list of thread subscriptions is outdated or not, i.e.
118    /// all thread subscriptions have been caught up
119    is_outdated: Arc<AtomicBool>,
120
121    /// A weak reference to the parent [`Client`] instance.
122    client: WeakClient,
123
124    /// A sender to wake up the catchup task when new catchup tokens are
125    /// available.
126    ping_sender: Sender<()>,
127
128    /// A mutex to ensure there's only one writer on the thread subscriptions
129    /// catchup tokens at a time.
130    uniq_mutex: Arc<Mutex<()>>,
131}
132
133impl ThreadSubscriptionCatchup {
134    pub async fn new(client: Client) -> Arc<Self> {
135        let is_outdated = Arc::new(AtomicBool::new(true));
136        let weak_client = WeakClient::from_client(&client);
137        let (ping_sender, ping_receiver) = channel(8);
138        let uniq_mutex = Arc::new(Mutex::new(()));
139
140        let this = Arc::new(Self {
141            _task: OnceLock::new(),
142            is_outdated,
143            client: weak_client.clone(),
144            ping_sender,
145            uniq_mutex,
146        });
147
148        // Create the task only if the client is configured to handle thread
149        // subscriptions.
150        match client.enabled_thread_subscriptions().await {
151            Ok(true) => {
152                let _ = this._task.get_or_init(|| {
153                    let that = this.clone();
154
155                    client
156                        .task_monitor()
157                        .spawn_infinite_task("client::thread_subscriptions_catchup", async move {
158                            Self::thread_subscriptions_catchup_task(that, ping_receiver).await;
159                        })
160                        .abort_on_drop()
161                });
162            }
163
164            Ok(false) => {
165                debug!("Thread subscriptions catchup not enabled, not starting the catchup task");
166            }
167
168            Err(err) => {
169                warn!("Failed to check if thread subscriptions catchup is enabled: {err}");
170            }
171        }
172
173        this
174    }
175
176    /// Returns whether the known list of thread subscriptions is outdated or
177    /// more thread subscriptions need to be caught up.
178    pub(crate) fn is_outdated(&self) -> bool {
179        self.is_outdated.load(atomic::Ordering::SeqCst)
180    }
181
182    /// Store the new subscriptions changes, received via the sync response or
183    /// from the msc4308 companion endpoint.
184    #[instrument(skip_all)]
185    pub(crate) async fn sync_subscriptions(
186        &self,
187        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
188        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
189        token: Option<ThreadSubscriptionCatchupToken>,
190    ) -> Result<()> {
191        // Precompute the updates so we don't hold the guard for too long.
192        let updates = build_subscription_updates(&subscribed, &unsubscribed);
193        let Some(guard) = self.lock().await else {
194            // Client is shutting down.
195            return Ok(());
196        };
197        self.save_catchup_token(&guard, token).await?;
198        if !updates.is_empty() {
199            trace!(
200                "saving {} new subscriptions and {} unsubscriptions",
201                subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
202                unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
203            );
204            guard.client.state_store().upsert_thread_subscriptions(updates).await?;
205        }
206        Ok(())
207    }
208
209    /// Internal helper to lock writes to the thread subscriptions catchup
210    /// tokens list.
211    async fn lock(&self) -> Option<GuardedStoreAccess> {
212        let client = self.client.get()?;
213        let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
214        Some(GuardedStoreAccess {
215            _mutex: mutex_guard,
216            client,
217            is_outdated: self.is_outdated.clone(),
218        })
219    }
220
221    /// Save a new catchup token (or absence thereof) in the state store.
222    async fn save_catchup_token(
223        &self,
224        guard: &GuardedStoreAccess,
225        token: Option<ThreadSubscriptionCatchupToken>,
226    ) -> Result<()> {
227        // Note: saving an empty tokens list will mark the thread subscriptions list as
228        // not outdated.
229        let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();
230
231        if let Some(token) = token {
232            trace!(?token, "Saving catchup token");
233            tokens.push(token);
234        } else {
235            trace!("No catchup token to save");
236        }
237
238        let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;
239
240        // Wake up the catchup task, in case it's waiting.
241        if !is_token_list_empty {
242            let _ = self.ping_sender.send(()).await;
243        }
244
245        Ok(())
246    }
247
248    /// The background task listening to new catchup tokens, and using them to
249    /// catch up the thread subscriptions via the [MSC4308] companion
250    /// endpoint.
251    ///
252    /// It will continue to process catchup tokens until there are none, and
253    /// then wait for a new one to be available and inserted in the
254    /// database.
255    ///
256    /// It always processes catch up tokens from the newest to the oldest, since
257    /// newest tokens are more interesting than older ones. Indeed, they're
258    /// more likely to include entries with higher bump-stamps, i.e. to include
259    /// more recent thread subscriptions statuses for each thread, so more
260    /// relevant information.
261    ///
262    /// [MSC4308]: https://github.com/matrix-org/matrix-spec-proposals/pull/4308
263    #[instrument(skip_all)]
264    async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
265        loop {
266            // Load the current catchup token.
267            let Some(guard) = this.lock().await else {
268                // Client is shutting down.
269                return;
270            };
271
272            let store_tokens = match guard.load_catchup_tokens().await {
273                Ok(tokens) => tokens,
274                Err(err) => {
275                    warn!("Failed to load thread subscriptions catchup tokens: {err}");
276                    continue;
277                }
278            };
279
280            let Some(mut tokens) = store_tokens else {
281                // Release the mutex.
282                drop(guard);
283
284                // Wait for a wake up.
285                trace!("Waiting for an explicit wake up to process future thread subscriptions");
286
287                if let Some(()) = ping_receiver.recv().await {
288                    trace!("Woke up!");
289                    continue;
290                }
291
292                // Channel closed, the client is shutting down.
293                break;
294            };
295
296            // We do have a tokens. Pop the last value, and use it to catch up!
297            let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
298
299            // Release the mutex before running the network request.
300            let client = guard.client.clone();
301            drop(guard);
302
303            // Start the actual catchup!
304            let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
305                from: Some(last.from.clone()),
306                to: last.to.clone(),
307            });
308
309            match client.send(req).await {
310                Ok(resp) => {
311                    // Precompute the updates so we don't hold the guard for too long.
312                    let updates = build_subscription_updates(&resp.subscribed, &resp.unsubscribed);
313
314                    let guard = this
315                        .lock()
316                        .await
317                        .expect("a client instance is alive, so the locking should not fail");
318
319                    if !updates.is_empty() {
320                        trace!(
321                            "saving {} new subscriptions and {} unsubscriptions",
322                            resp.subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
323                            resp.unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
324                        );
325
326                        if let Err(err) =
327                            guard.client.state_store().upsert_thread_subscriptions(updates).await
328                        {
329                            warn!("Failed to store caught up thread subscriptions: {err}");
330                            continue;
331                        }
332                    }
333
334                    // Refresh the tokens, as the list might have changed while we sent the
335                    // request.
336                    let mut tokens = match guard.load_catchup_tokens().await {
337                        Ok(tokens) => tokens.unwrap_or_default(),
338                        Err(err) => {
339                            warn!("Failed to load thread subscriptions catchup tokens: {err}");
340                            continue;
341                        }
342                    };
343
344                    let Some(index) = tokens.iter().position(|t| *t == last) else {
345                        warn!("Thread subscriptions catchup token disappeared while processing it");
346                        continue;
347                    };
348
349                    if let Some(next_batch) = resp.end {
350                        // If the response contained a next batch token, reuse the same catchup
351                        // token entry, so the `to` value remains the same.
352                        tokens[index] =
353                            ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
354                    } else {
355                        // No next batch, we can remove this token from the list.
356                        tokens.remove(index);
357                    }
358
359                    if let Err(err) = guard.save_catchup_tokens(tokens).await {
360                        warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
361                    }
362                }
363
364                Err(err) => {
365                    warn!("Failed to catch up thread subscriptions: {err}");
366                }
367            }
368        }
369    }
370}
371
372/// Internal helper for building the thread subscription updates Vec.
373fn build_subscription_updates<'a>(
374    subscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
375    unsubscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
376) -> Vec<(&'a RoomId, &'a EventId, StoredThreadSubscription)> {
377    let mut updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)> =
378        Vec::with_capacity(unsubscribed.len() + subscribed.len());
379
380    // Take into account the new unsubscriptions.
381    for (room_id, room_map) in unsubscribed {
382        for (event_id, thread_sub) in room_map {
383            updates.push((
384                room_id,
385                event_id,
386                StoredThreadSubscription {
387                    status: ThreadSubscriptionStatus::Unsubscribed,
388                    bump_stamp: Some(thread_sub.bump_stamp.into()),
389                },
390            ));
391        }
392    }
393
394    // Take into account the new subscriptions.
395    for (room_id, room_map) in subscribed {
396        for (event_id, thread_sub) in room_map {
397            updates.push((
398                room_id,
399                event_id,
400                StoredThreadSubscription {
401                    status: ThreadSubscriptionStatus::Subscribed {
402                        automatic: thread_sub.automatic,
403                    },
404                    bump_stamp: Some(thread_sub.bump_stamp.into()),
405                },
406            ));
407        }
408    }
409
410    updates
411}
412
#[cfg(test)]
mod tests {
    use std::ops::Not as _;

    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    /// Round-trips catchup tokens through the state store. Checks that the
    /// "outdated" flag stays set while tokens exist and is cleared once the
    /// list is emptied.
    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;

        let catchup = client.thread_subscription_catchup();

        // Initially nothing is stored, and the subscriptions are considered outdated.
        let access = catchup.lock().await.unwrap();
        let initial = access.load_catchup_tokens().await.unwrap();
        assert!(initial.is_none());
        assert!(catchup.is_outdated());

        // Saving a single token persists it...
        let token =
            ThreadSubscriptionCatchupToken { from: "from".to_owned(), to: Some("to".to_owned()) };
        access.save_catchup_tokens(vec![token.clone()]).await.unwrap();

        let reloaded = access.load_catchup_tokens().await.unwrap();
        assert_eq!(reloaded, Some(vec![token]));

        // ...and the subscriptions remain outdated.
        assert!(catchup.is_outdated());

        // Saving an empty list removes the stored tokens...
        access.save_catchup_tokens(Vec::new()).await.unwrap();
        assert!(access.load_catchup_tokens().await.unwrap().is_none());

        // ...and the subscriptions are now up to date.
        assert!(catchup.is_outdated().not());
    }
}
455
#[cfg(all(test, not(target_family = "wasm")))]
mod timed_tests {
    use std::time::Duration;

    use matrix_sdk_base::{ThreadingSupport, sleep::sleep};
    use matrix_sdk_test::async_test;
    use tokio::task::yield_now;

    use crate::{client::WeakClient, test_utils::mocks::MatrixMockServer};

    /// Regression test for issue #6573: once the last strong `Client` handle is
    /// dropped, no strong reference to the client must survive, even though the
    /// thread subscriptions catchup task was started.
    #[async_test]
    async fn test_issue_6573_client_can_drop_thread_subscriptions_task() {
        let server = MatrixMockServer::new().await;
        server.mock_versions().with_thread_subscriptions().ok().mount().await;

        let client = server
            .client_builder()
            .no_server_versions()
            .on_builder(|builder| {
                builder
                    .with_threading_support(ThreadingSupport::Enabled { with_subscriptions: true })
            })
            .build()
            .await;

        let catchup = client.thread_subscription_catchup();

        // Let the freshly spawned background work get scheduled.
        yield_now().await;

        // Thread subscriptions are enabled, so the catchup task must exist.
        assert!(catchup._task.get().is_some());

        // Keep only a weak handle, to observe whether the client stays alive.
        let observer = WeakClient::from_client(&client);

        // Dropping the client is expected to drop the task as well.
        drop(client);

        // Give the runtime time to run any shutdown logic.
        sleep(Duration::from_secs(2)).await;

        // No strong reference survived: the client was fully released.
        assert_eq!(observer.strong_count(), 0);
    }
}
501}