matrix_sdk/client/
thread_subscriptions.rs

// Copyright 2025 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::{
16    collections::BTreeMap,
17    sync::{
18        Arc,
19        atomic::{self, AtomicBool},
20    },
21};
22
23use matrix_sdk_base::{
24    StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
25    executor::AbortOnDrop,
26    store::{StoredThreadSubscription, ThreadSubscriptionStatus},
27};
28use matrix_sdk_common::executor::spawn;
29use once_cell::sync::OnceCell;
30use ruma::{
31    OwnedEventId, OwnedRoomId,
32    api::client::threads::get_thread_subscriptions_changes::unstable::{
33        ThreadSubscription, ThreadUnsubscription,
34    },
35    assign,
36};
37use tokio::sync::{
38    Mutex, OwnedMutexGuard,
39    mpsc::{Receiver, Sender, channel},
40};
41use tracing::{instrument, trace, warn};
42
43use crate::{Client, Result, client::WeakClient};
44
45struct GuardedStoreAccess {
46    _mutex: OwnedMutexGuard<()>,
47    client: Client,
48    is_outdated: Arc<AtomicBool>,
49}
50
51impl GuardedStoreAccess {
52    /// Return the current list of catchup tokens, if any.
53    ///
54    /// It is guaranteed that if the list is set, then it's non-empty.
55    async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
56        let loaded = self
57            .client
58            .state_store()
59            .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
60            .await?;
61
62        match loaded {
63            Some(data) => {
64                if let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() {
65                    // If the tokens list is empty, automatically clean it up.
66                    if tokens.is_empty() {
67                        self.save_catchup_tokens(tokens).await?;
68                        Ok(None)
69                    } else {
70                        Ok(Some(tokens))
71                    }
72                } else {
73                    warn!(
74                        "invalid data in thread subscriptions catchup tokens state store k/v entry"
75                    );
76                    Ok(None)
77                }
78            }
79
80            None => Ok(None),
81        }
82    }
83
84    /// Saves the tokens in the database.
85    ///
86    /// Returns whether the list of tokens is empty or not.
87    #[instrument(skip_all, fields(num_tokens = tokens.len()))]
88    async fn save_catchup_tokens(
89        &self,
90        tokens: Vec<ThreadSubscriptionCatchupToken>,
91    ) -> Result<bool> {
92        let store = self.client.state_store();
93        let is_empty = if tokens.is_empty() {
94            store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;
95
96            trace!("Marking thread subscriptions as not outdated \\o/");
97            self.is_outdated.store(false, atomic::Ordering::SeqCst);
98            true
99        } else {
100            store
101                .set_kv_data(
102                    StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
103                    StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
104                )
105                .await?;
106
107            trace!("Marking thread subscriptions as outdated.");
108            self.is_outdated.store(true, atomic::Ordering::SeqCst);
109            false
110        };
111        Ok(is_empty)
112    }
113}
114
115pub struct ThreadSubscriptionCatchup {
116    /// The task catching up thread subscriptions in the background.
117    _task: OnceCell<AbortOnDrop<()>>,
118
119    /// Whether the known list of thread subscriptions is outdated or not, i.e.
120    /// all thread subscriptions have been caught up
121    is_outdated: Arc<AtomicBool>,
122
123    /// A weak reference to the parent [`Client`] instance.
124    client: WeakClient,
125
126    /// A sender to wake up the catchup task when new catchup tokens are
127    /// available.
128    ping_sender: Sender<()>,
129
130    /// A mutex to ensure there's only one writer on the thread subscriptions
131    /// catchup tokens at a time.
132    uniq_mutex: Arc<Mutex<()>>,
133}
134
135impl ThreadSubscriptionCatchup {
136    pub fn new(client: Client) -> Arc<Self> {
137        let is_outdated = Arc::new(AtomicBool::new(true));
138
139        let weak_client = WeakClient::from_client(&client);
140
141        let (ping_sender, ping_receiver) = channel(8);
142
143        let uniq_mutex = Arc::new(Mutex::new(()));
144
145        let this = Arc::new(Self {
146            _task: OnceCell::new(),
147            is_outdated,
148            client: weak_client,
149            ping_sender,
150            uniq_mutex,
151        });
152
153        // Create the task only if the client is configured to handle thread
154        // subscriptions.
155        if client.enabled_thread_subscriptions() {
156            let _ = this._task.get_or_init(|| {
157                AbortOnDrop::new(spawn(Self::thread_subscriptions_catchup_task(
158                    this.clone(),
159                    ping_receiver,
160                )))
161            });
162        }
163
164        this
165    }
166
167    /// Returns whether the known list of thread subscriptions is outdated or
168    /// more thread subscriptions need to be caught up.
169    pub(crate) fn is_outdated(&self) -> bool {
170        self.is_outdated.load(atomic::Ordering::SeqCst)
171    }
172
173    /// Store the new subscriptions changes, received via the sync response or
174    /// from the msc4308 companion endpoint.
175    #[instrument(skip_all)]
176    pub(crate) async fn sync_subscriptions(
177        &self,
178        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
179        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
180        token: Option<ThreadSubscriptionCatchupToken>,
181    ) -> Result<()> {
182        let Some(guard) = self.lock().await else {
183            // Client is shutting down.
184            return Ok(());
185        };
186        self.save_catchup_token(&guard, token).await?;
187        self.store_subscriptions(&guard, subscribed, unsubscribed).await?;
188        Ok(())
189    }
190
191    async fn store_subscriptions(
192        &self,
193        guard: &GuardedStoreAccess,
194        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
195        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
196    ) -> Result<()> {
197        if subscribed.is_empty() && unsubscribed.is_empty() {
198            // Nothing to do.
199            return Ok(());
200        }
201
202        trace!(
203            "saving {} new subscriptions and {} unsubscriptions",
204            subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
205            unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
206        );
207
208        // Take into account the new unsubscriptions.
209        for (room_id, room_map) in unsubscribed {
210            for (event_id, thread_sub) in room_map {
211                guard
212                    .client
213                    .state_store()
214                    .upsert_thread_subscription(
215                        &room_id,
216                        &event_id,
217                        StoredThreadSubscription {
218                            status: ThreadSubscriptionStatus::Unsubscribed,
219                            bump_stamp: Some(thread_sub.bump_stamp.into()),
220                        },
221                    )
222                    .await?;
223            }
224        }
225
226        // Take into account the new subscriptions.
227        for (room_id, room_map) in subscribed {
228            for (event_id, thread_sub) in room_map {
229                guard
230                    .client
231                    .state_store()
232                    .upsert_thread_subscription(
233                        &room_id,
234                        &event_id,
235                        StoredThreadSubscription {
236                            status: ThreadSubscriptionStatus::Subscribed {
237                                automatic: thread_sub.automatic,
238                            },
239                            bump_stamp: Some(thread_sub.bump_stamp.into()),
240                        },
241                    )
242                    .await?;
243            }
244        }
245
246        Ok(())
247    }
248
249    /// Internal helper to lock writes to the thread subscriptions catchup
250    /// tokens list.
251    async fn lock(&self) -> Option<GuardedStoreAccess> {
252        let client = self.client.get()?;
253        let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
254        Some(GuardedStoreAccess {
255            _mutex: mutex_guard,
256            client,
257            is_outdated: self.is_outdated.clone(),
258        })
259    }
260
261    /// Save a new catchup token (or absence thereof) in the state store.
262    async fn save_catchup_token(
263        &self,
264        guard: &GuardedStoreAccess,
265        token: Option<ThreadSubscriptionCatchupToken>,
266    ) -> Result<()> {
267        // Note: saving an empty tokens list will mark the thread subscriptions list as
268        // not outdated.
269        let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();
270
271        if let Some(token) = token {
272            trace!(?token, "Saving catchup token");
273            tokens.push(token);
274        } else {
275            trace!("No catchup token to save");
276        }
277
278        let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;
279
280        // Wake up the catchup task, in case it's waiting.
281        if !is_token_list_empty {
282            let _ = self.ping_sender.send(()).await;
283        }
284
285        Ok(())
286    }
287
288    /// The background task listening to new catchup tokens, and using them to
289    /// catch up the thread subscriptions via the [MSC4308] companion
290    /// endpoint.
291    ///
292    /// It will continue to process catchup tokens until there are none, and
293    /// then wait for a new one to be available and inserted in the
294    /// database.
295    ///
296    /// It always processes catch up tokens from the newest to the oldest, since
297    /// newest tokens are more interesting than older ones. Indeed, they're
298    /// more likely to include entries with higher bump-stamps, i.e. to include
299    /// more recent thread subscriptions statuses for each thread, so more
300    /// relevant information.
301    ///
302    /// [MSC4308]: https://github.com/matrix-org/matrix-spec-proposals/pull/4308
303    #[instrument(skip_all)]
304    async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
305        loop {
306            // Load the current catchup token.
307            let Some(guard) = this.lock().await else {
308                // Client is shutting down.
309                return;
310            };
311
312            let store_tokens = match guard.load_catchup_tokens().await {
313                Ok(tokens) => tokens,
314                Err(err) => {
315                    warn!("Failed to load thread subscriptions catchup tokens: {err}");
316                    continue;
317                }
318            };
319
320            let Some(mut tokens) = store_tokens else {
321                // Release the mutex.
322                drop(guard);
323
324                // Wait for a wake up.
325                trace!("Waiting for an explicit wake up to process future thread subscriptions");
326
327                if let Some(()) = ping_receiver.recv().await {
328                    trace!("Woke up!");
329                    continue;
330                }
331
332                // Channel closed, the client is shutting down.
333                break;
334            };
335
336            // We do have a tokens. Pop the last value, and use it to catch up!
337            let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");
338
339            // Release the mutex before running the network request.
340            let client = guard.client.clone();
341            drop(guard);
342
343            // Start the actual catchup!
344            let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
345                from: Some(last.from.clone()),
346                to: last.to.clone(),
347            });
348
349            match client.send(req).await {
350                Ok(resp) => {
351                    let guard = this
352                        .lock()
353                        .await
354                        .expect("a client instance is alive, so the locking should not fail");
355
356                    if let Err(err) =
357                        this.store_subscriptions(&guard, resp.subscribed, resp.unsubscribed).await
358                    {
359                        warn!("Failed to store caught up thread subscriptions: {err}");
360                        continue;
361                    }
362
363                    // Refresh the tokens, as the list might have changed while we sent the
364                    // request.
365                    let mut tokens = match guard.load_catchup_tokens().await {
366                        Ok(tokens) => tokens.unwrap_or_default(),
367                        Err(err) => {
368                            warn!("Failed to load thread subscriptions catchup tokens: {err}");
369                            continue;
370                        }
371                    };
372
373                    let Some(index) = tokens.iter().position(|t| *t == last) else {
374                        warn!("Thread subscriptions catchup token disappeared while processing it");
375                        continue;
376                    };
377
378                    if let Some(next_batch) = resp.end {
379                        // If the response contained a next batch token, reuse the same catchup
380                        // token entry, so the `to` value remains the same.
381                        tokens[index] =
382                            ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
383                    } else {
384                        // No next batch, we can remove this token from the list.
385                        tokens.remove(index);
386                    }
387
388                    if let Err(err) = guard.save_catchup_tokens(tokens).await {
389                        warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
390                    }
391                }
392
393                Err(err) => {
394                    warn!("Failed to catch up thread subscriptions: {err}");
395                }
396            }
397        }
398    }
399}
400
#[cfg(test)]
mod tests {
    use std::ops::Not as _;

    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;
        let catchup = client.thread_subscription_catchup();
        let guard = catchup.lock().await.unwrap();

        // Fresh client: nothing stored yet, and subscriptions are outdated.
        assert!(guard.load_catchup_tokens().await.unwrap().is_none());
        assert!(catchup.is_outdated());

        // Saving a single token persists it...
        let token =
            ThreadSubscriptionCatchupToken { from: "from".to_owned(), to: Some("to".to_owned()) };
        guard.save_catchup_tokens(vec![token.clone()]).await.unwrap();
        assert_eq!(guard.load_catchup_tokens().await.unwrap(), Some(vec![token]));

        // ...and keeps the subscriptions marked as outdated.
        assert!(catchup.is_outdated());

        // Saving an empty list wipes the stored tokens...
        guard.save_catchup_tokens(vec![]).await.unwrap();
        assert!(guard.load_catchup_tokens().await.unwrap().is_none());

        // ...and marks the subscriptions as up to date.
        assert!(catchup.is_outdated().not());
    }
}