Skip to main content

matrix_sdk/client/
thread_subscriptions.rs

// Copyright 2025 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::{
16    collections::BTreeMap,
17    sync::{
18        Arc, OnceLock,
19        atomic::{self, AtomicBool},
20    },
21};
22
23use matrix_sdk_base::{
24    StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
25    store::{StoredThreadSubscription, ThreadSubscriptionStatus},
26    task_monitor::BackgroundTaskHandle,
27};
28use ruma::{
29    EventId, OwnedEventId, OwnedRoomId, RoomId,
30    api::client::threads::get_thread_subscriptions_changes::unstable::{
31        ThreadSubscription, ThreadUnsubscription,
32    },
33    assign,
34};
35use tokio::sync::{
36    Mutex, OwnedMutexGuard,
37    mpsc::{Receiver, Sender, channel},
38};
39use tracing::{debug, instrument, trace, warn};
40
41use crate::{Client, Result, client::WeakClient};
42
43struct GuardedStoreAccess {
44    _mutex: OwnedMutexGuard<()>,
45    client: Client,
46    is_outdated: Arc<AtomicBool>,
47}
48
49impl GuardedStoreAccess {
50    /// Return the current list of catchup tokens, if any.
51    ///
52    /// It is guaranteed that if the list is set, then it's non-empty.
53    async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
54        let loaded = self
55            .client
56            .state_store()
57            .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
58            .await?;
59
60        match loaded {
61            Some(data) => {
62                if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
63                    // If the tokens list is empty, automatically clean it up.
64                    if tokens.is_empty() {
65                        self.save_catchup_tokens(tokens).await?;
66                        Ok(None)
67                    } else {
68                        Ok(Some(tokens))
69                    }
70                } else {
71                    warn!(
72                        "invalid data in thread subscriptions catchup tokens state store k/v entry"
73                    );
74                    Ok(None)
75                }
76            }
77
78            None => Ok(None),
79        }
80    }
81
82    /// Saves the tokens in the database.
83    ///
84    /// Returns whether the list of tokens is empty or not.
85    #[instrument(skip_all, fields(num_tokens = tokens.len()))]
86    async fn save_catchup_tokens(
87        &self,
88        tokens: Vec<ThreadSubscriptionCatchupToken>,
89    ) -> Result<bool> {
90        let store = self.client.state_store();
91        let is_empty = if tokens.is_empty() {
92            store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
93
94            trace!("Marking thread subscriptions as not outdated \\o/");
95            self.is_outdated.store(false, atomic::Ordering::SeqCst);
96            true
97        } else {
98            store
99                .set_kv_data(
100                    StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
101                    StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
102                )
103                .await?;
104
105            trace!("Marking thread subscriptions as outdated.");
106            self.is_outdated.store(true, atomic::Ordering::SeqCst);
107            false
108        };
109        Ok(is_empty)
110    }
111}
112
113pub struct ThreadSubscriptionCatchup {
114    /// The task catching up thread subscriptions in the background.
115    _task: OnceLock<BackgroundTaskHandle>,
116
117    /// Whether the known list of thread subscriptions is outdated or not, i.e.
118    /// all thread subscriptions have been caught up
119    is_outdated: Arc<AtomicBool>,
120
121    /// A weak reference to the parent [`Client`] instance.
122    client: WeakClient,
123
124    /// A sender to wake up the catchup task when new catchup tokens are
125    /// available.
126    ping_sender: Sender<()>,
127
128    /// A mutex to ensure there's only one writer on the thread subscriptions
129    /// catchup tokens at a time.
130    uniq_mutex: Arc<Mutex<()>>,
131}
132
133impl ThreadSubscriptionCatchup {
134    pub fn new(client: Client) -> Arc<Self> {
135        let is_outdated = Arc::new(AtomicBool::new(true));
136
137        let weak_client = WeakClient::from_client(&client);
138
139        let (ping_sender, ping_receiver) = channel(8);
140
141        let uniq_mutex = Arc::new(Mutex::new(()));
142
143        let this = Arc::new(Self {
144            _task: OnceLock::new(),
145            is_outdated,
146            client: weak_client,
147            ping_sender,
148            uniq_mutex,
149        });
150
151        // Create the task only if the client is configured to handle thread
152        // subscriptions.
153        let _ = this._task.get_or_init(|| {
154            let that = this.clone();
155            let client_clone = client.clone();
156
157            client.task_monitor().spawn_infinite_task("client::thread_subscriptions_catchup", async move {
158                match client_clone.enabled_thread_subscriptions().await {
159                    Ok(enabled) => {
160                        if !enabled {
161                            debug!("Thread subscriptions catchup not enabled, not starting the catchup task");
162                            return;
163                        }
164                    }
165
166                    Err(err) => {
167                        warn!("Failed to check if thread subscriptions catchup is enabled: {err}");
168                        return;
169                    }
170                }
171
172                Self::thread_subscriptions_catchup_task(
173                    that,
174                    ping_receiver,
175                ).await
176            }).abort_on_drop()
177        });
178
179        this
180    }
181
182    /// Returns whether the known list of thread subscriptions is outdated or
183    /// more thread subscriptions need to be caught up.
184    pub(crate) fn is_outdated(&self) -> bool {
185        self.is_outdated.load(atomic::Ordering::SeqCst)
186    }
187
188    /// Store the new subscriptions changes, received via the sync response or
189    /// from the msc4308 companion endpoint.
190    #[instrument(skip_all)]
191    pub(crate) async fn sync_subscriptions(
192        &self,
193        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
194        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
195        token: Option<ThreadSubscriptionCatchupToken>,
196    ) -> Result<()> {
197        // Precompute the updates so we don't hold the guard for too long.
198        let updates = build_subscription_updates(&subscribed, &unsubscribed);
199        let Some(guard) = self.lock().await else {
200            // Client is shutting down.
201            return Ok(());
202        };
203        self.save_catchup_token(&guard, token).await?;
204        if !updates.is_empty() {
205            trace!(
206                "saving {} new subscriptions and {} unsubscriptions",
207                subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
208                unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
209            );
210            guard.client.state_store().upsert_thread_subscriptions(updates).await?;
211        }
212        Ok(())
213    }
214
215    /// Internal helper to lock writes to the thread subscriptions catchup
216    /// tokens list.
217    async fn lock(&self) -> Option<GuardedStoreAccess> {
218        let client = self.client.get()?;
219        let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
220        Some(GuardedStoreAccess {
221            _mutex: mutex_guard,
222            client,
223            is_outdated: self.is_outdated.clone(),
224        })
225    }
226
227    /// Save a new catchup token (or absence thereof) in the state store.
228    async fn save_catchup_token(
229        &self,
230        guard: &GuardedStoreAccess,
231        token: Option<ThreadSubscriptionCatchupToken>,
232    ) -> Result<()> {
233        // Note: saving an empty tokens list will mark the thread subscriptions list as
234        // not outdated.
235        let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();
236
237        if let Some(token) = token {
238            trace!(?token, "Saving catchup token");
239            tokens.push(token);
240        } else {
241            trace!("No catchup token to save");
242        }
243
244        let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;
245
246        // Wake up the catchup task, in case it's waiting.
247        if !is_token_list_empty {
248            let _ = self.ping_sender.send(()).await;
249        }
250
251        Ok(())
252    }
253
254    /// The background task listening to new catchup tokens, and using them to
255    /// catch up the thread subscriptions via the [MSC4308] companion
256    /// endpoint.
257    ///
258    /// It will continue to process catchup tokens until there are none, and
259    /// then wait for a new one to be available and inserted in the
260    /// database.
261    ///
262    /// It always processes catch up tokens from the newest to the oldest, since
263    /// newest tokens are more interesting than older ones. Indeed, they're
264    /// more likely to include entries with higher bump-stamps, i.e. to include
265    /// more recent thread subscriptions statuses for each thread, so more
266    /// relevant information.
267    ///
268    /// [MSC4308]: https://github.com/matrix-org/matrix-spec-proposals/pull/4308
269    #[instrument(skip_all)]
270    async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
271        loop {
272            // Load the current catchup token.
273            let Some(guard) = this.lock().await else {
274                // Client is shutting down.
275                return;
276            };
277
278            let store_tokens = match guard.load_catchup_tokens().await {
279                Ok(tokens) => tokens,
280                Err(err) => {
281                    warn!("Failed to load thread subscriptions catchup tokens: {err}");
282                    continue;
283                }
284            };
285
286            let Some(mut tokens) = store_tokens else {
287                // Release the mutex.
288                drop(guard);
289
290                // Wait for a wake up.
291                trace!("Waiting for an explicit wake up to process future thread subscriptions");
292
293                if let Some(()) = ping_receiver.recv().await {
294                    trace!("Woke up!");
295                    continue;
296                }
297
298                // Channel closed, the client is shutting down.
299                break;
300            };
301
302            // We do have a tokens. Pop the last value, and use it to catch up!
303            let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
304
305            // Release the mutex before running the network request.
306            let client = guard.client.clone();
307            drop(guard);
308
309            // Start the actual catchup!
310            let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
311                from: Some(last.from.clone()),
312                to: last.to.clone(),
313            });
314
315            match client.send(req).await {
316                Ok(resp) => {
317                    // Precompute the updates so we don't hold the guard for too long.
318                    let updates = build_subscription_updates(&resp.subscribed, &resp.unsubscribed);
319
320                    let guard = this
321                        .lock()
322                        .await
323                        .expect("a client instance is alive, so the locking should not fail");
324
325                    if !updates.is_empty() {
326                        trace!(
327                            "saving {} new subscriptions and {} unsubscriptions",
328                            resp.subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
329                            resp.unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
330                        );
331
332                        if let Err(err) =
333                            guard.client.state_store().upsert_thread_subscriptions(updates).await
334                        {
335                            warn!("Failed to store caught up thread subscriptions: {err}");
336                            continue;
337                        }
338                    }
339
340                    // Refresh the tokens, as the list might have changed while we sent the
341                    // request.
342                    let mut tokens = match guard.load_catchup_tokens().await {
343                        Ok(tokens) => tokens.unwrap_or_default(),
344                        Err(err) => {
345                            warn!("Failed to load thread subscriptions catchup tokens: {err}");
346                            continue;
347                        }
348                    };
349
350                    let Some(index) = tokens.iter().position(|t| *t == last) else {
351                        warn!("Thread subscriptions catchup token disappeared while processing it");
352                        continue;
353                    };
354
355                    if let Some(next_batch) = resp.end {
356                        // If the response contained a next batch token, reuse the same catchup
357                        // token entry, so the `to` value remains the same.
358                        tokens[index] =
359                            ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
360                    } else {
361                        // No next batch, we can remove this token from the list.
362                        tokens.remove(index);
363                    }
364
365                    if let Err(err) = guard.save_catchup_tokens(tokens).await {
366                        warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
367                    }
368                }
369
370                Err(err) => {
371                    warn!("Failed to catch up thread subscriptions: {err}");
372                }
373            }
374        }
375    }
376}
377
378/// Internal helper for building the thread subscription updates Vec.
379fn build_subscription_updates<'a>(
380    subscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
381    unsubscribed: &'a BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
382) -> Vec<(&'a RoomId, &'a EventId, StoredThreadSubscription)> {
383    let mut updates: Vec<(&RoomId, &EventId, StoredThreadSubscription)> =
384        Vec::with_capacity(unsubscribed.len() + subscribed.len());
385
386    // Take into account the new unsubscriptions.
387    for (room_id, room_map) in unsubscribed {
388        for (event_id, thread_sub) in room_map {
389            updates.push((
390                room_id,
391                event_id,
392                StoredThreadSubscription {
393                    status: ThreadSubscriptionStatus::Unsubscribed,
394                    bump_stamp: Some(thread_sub.bump_stamp.into()),
395                },
396            ));
397        }
398    }
399
400    // Take into account the new subscriptions.
401    for (room_id, room_map) in subscribed {
402        for (event_id, thread_sub) in room_map {
403            updates.push((
404                room_id,
405                event_id,
406                StoredThreadSubscription {
407                    status: ThreadSubscriptionStatus::Subscribed {
408                        automatic: thread_sub.automatic,
409                    },
410                    bump_stamp: Some(thread_sub.bump_stamp.into()),
411                },
412            ));
413        }
414    }
415
416    updates
417}
418
#[cfg(test)]
mod tests {
    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;
        let catchup = client.thread_subscription_catchup();
        let guard = catchup.lock().await.unwrap();

        // Initially: no stored tokens, and the subscriptions are considered outdated.
        assert_eq!(guard.load_catchup_tokens().await.unwrap(), None);
        assert!(catchup.is_outdated());

        // Saving a single token persists it, and keeps the outdated flag set.
        let token = ThreadSubscriptionCatchupToken {
            from: "from".to_owned(),
            to: Some("to".to_owned()),
        };
        guard.save_catchup_tokens(vec![token.clone()]).await.unwrap();
        assert_eq!(guard.load_catchup_tokens().await.unwrap(), Some(vec![token]));
        assert!(catchup.is_outdated());

        // Saving an empty list removes the entry, and clears the outdated flag.
        guard.save_catchup_tokens(Vec::new()).await.unwrap();
        assert_eq!(guard.load_catchup_tokens().await.unwrap(), None);
        assert!(!catchup.is_outdated());
    }
}