matrix-sdk 0.16.0

A high-level Matrix client-server library.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
// Copyright 2025 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    collections::BTreeMap,
    sync::{
        Arc,
        atomic::{self, AtomicBool},
    },
};

use matrix_sdk_base::{
    StateStoreDataKey, StateStoreDataValue, ThreadSubscriptionCatchupToken,
    executor::AbortOnDrop,
    store::{StoredThreadSubscription, ThreadSubscriptionStatus},
};
use matrix_sdk_common::executor::spawn;
use once_cell::sync::OnceCell;
use ruma::{
    OwnedEventId, OwnedRoomId,
    api::client::threads::get_thread_subscriptions_changes::unstable::{
        ThreadSubscription, ThreadUnsubscription,
    },
    assign,
};
use tokio::sync::{
    Mutex, OwnedMutexGuard,
    mpsc::{Receiver, Sender, channel},
};
use tracing::{instrument, trace, warn};

use crate::{Client, Result, client::WeakClient};

/// Exclusive, serialized access to the thread-subscription catchup data in
/// the state store.
///
/// Holding one of these means the `uniq_mutex` of
/// [`ThreadSubscriptionCatchup`] is locked, so only one writer can touch the
/// catchup tokens list at a time.
struct GuardedStoreAccess {
    /// Owned guard keeping the unique-writer mutex locked for as long as this
    /// value is alive; the `()` payload is irrelevant, only exclusion matters.
    _mutex: OwnedMutexGuard<()>,
    /// Strong client handle, used to reach the state store.
    client: Client,
    /// Shared flag mirroring whether the thread subscriptions list is still
    /// outdated (same `Arc` as [`ThreadSubscriptionCatchup::is_outdated`]).
    is_outdated: Arc<AtomicBool>,
}

impl GuardedStoreAccess {
    /// Return the current list of catchup tokens, if any.
    ///
    /// It is guaranteed that if the list is set, then it's non-empty.
    async fn load_catchup_tokens(&self) -> Result<Option<Vec<ThreadSubscriptionCatchupToken>>> {
        let Some(data) = self
            .client
            .state_store()
            .get_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens)
            .await?
        else {
            // No entry at all in the store.
            return Ok(None);
        };

        let Some(tokens) = data.into_thread_subscriptions_catchup_tokens() else {
            warn!("invalid data in thread subscriptions catchup tokens state store k/v entry");
            return Ok(None);
        };

        if tokens.is_empty() {
            // An empty list shouldn't be persisted: saving it back removes the
            // entry (and clears the outdated flag), upholding the non-empty
            // guarantee above.
            self.save_catchup_tokens(tokens).await?;
            Ok(None)
        } else {
            Ok(Some(tokens))
        }
    }

    /// Saves the tokens in the database.
    ///
    /// Returns whether the list of tokens is empty or not.
    #[instrument(skip_all, fields(num_tokens = tokens.len()))]
    async fn save_catchup_tokens(
        &self,
        tokens: Vec<ThreadSubscriptionCatchupToken>,
    ) -> Result<bool> {
        let store = self.client.state_store();

        if tokens.is_empty() {
            // Nothing left to catch up on: drop the entry entirely, and mark
            // the thread subscriptions as up to date.
            store.remove_kv_data(StateStoreDataKey::ThreadSubscriptionsCatchupTokens).await?;

            trace!("Marking thread subscriptions as not outdated \\o/");
            self.is_outdated.store(false, atomic::Ordering::SeqCst);
            return Ok(true);
        }

        // Some catching up remains: persist the tokens, and remember that the
        // subscription list is still outdated.
        store
            .set_kv_data(
                StateStoreDataKey::ThreadSubscriptionsCatchupTokens,
                StateStoreDataValue::ThreadSubscriptionsCatchupTokens(tokens),
            )
            .await?;

        trace!("Marking thread subscriptions as outdated.");
        self.is_outdated.store(true, atomic::Ordering::SeqCst);
        Ok(false)
    }
}

/// Background mechanism catching up on thread subscription changes, driven by
/// catchup tokens persisted in the state store.
pub struct ThreadSubscriptionCatchup {
    /// The task catching up thread subscriptions in the background.
    _task: OnceCell<AbortOnDrop<()>>,

    /// Whether the known list of thread subscriptions is outdated, i.e.
    /// whether there are still catchup tokens left to process; cleared once
    /// all thread subscriptions have been caught up.
    is_outdated: Arc<AtomicBool>,

    /// A weak reference to the parent [`Client`] instance.
    client: WeakClient,

    /// A sender to wake up the catchup task when new catchup tokens are
    /// available.
    ping_sender: Sender<()>,

    /// A mutex to ensure there's only one writer on the thread subscriptions
    /// catchup tokens at a time.
    uniq_mutex: Arc<Mutex<()>>,
}

impl ThreadSubscriptionCatchup {
    /// Create a new [`ThreadSubscriptionCatchup`] attached to the given
    /// client.
    ///
    /// The background catchup task is spawned only if the client is
    /// configured to handle thread subscriptions; otherwise only the shared
    /// state is set up.
    pub fn new(client: Client) -> Arc<Self> {
        // Assume the local view of thread subscriptions is outdated, until a
        // save of an empty token list proves otherwise.
        let is_outdated = Arc::new(AtomicBool::new(true));

        // Keep only a weak reference, so this helper doesn't keep the client
        // alive on its own.
        let weak_client = WeakClient::from_client(&client);

        // Pings are pure wake-up signals; a small bounded channel suffices
        // since dropped/coalesced pings are harmless.
        let (ping_sender, ping_receiver) = channel(8);

        let uniq_mutex = Arc::new(Mutex::new(()));

        let this = Arc::new(Self {
            _task: OnceCell::new(),
            is_outdated,
            client: weak_client,
            ping_sender,
            uniq_mutex,
        });

        // Create the task only if the client is configured to handle thread
        // subscriptions.
        if client.enabled_thread_subscriptions() {
            let _ = this._task.get_or_init(|| {
                AbortOnDrop::new(spawn(Self::thread_subscriptions_catchup_task(
                    this.clone(),
                    ping_receiver,
                )))
            });
        }

        this
    }

    /// Returns whether the known list of thread subscriptions is outdated or
    /// more thread subscriptions need to be caught up.
    pub(crate) fn is_outdated(&self) -> bool {
        self.is_outdated.load(atomic::Ordering::SeqCst)
    }

    /// Store the new subscriptions changes, received via the sync response or
    /// from the msc4308 companion endpoint.
    ///
    /// Persists the optional catchup `token` first (which may wake up the
    /// background task), then the subscription status changes themselves.
    /// A no-op if the client is shutting down.
    #[instrument(skip_all)]
    pub(crate) async fn sync_subscriptions(
        &self,
        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
        token: Option<ThreadSubscriptionCatchupToken>,
    ) -> Result<()> {
        let Some(guard) = self.lock().await else {
            // Client is shutting down.
            return Ok(());
        };
        self.save_catchup_token(&guard, token).await?;
        self.store_subscriptions(&guard, subscribed, unsubscribed).await?;
        Ok(())
    }

    /// Persist the given subscription and unsubscription changes into the
    /// state store, under the exclusive-access `guard`.
    async fn store_subscriptions(
        &self,
        guard: &GuardedStoreAccess,
        subscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadSubscription>>,
        unsubscribed: BTreeMap<OwnedRoomId, BTreeMap<OwnedEventId, ThreadUnsubscription>>,
    ) -> Result<()> {
        if subscribed.is_empty() && unsubscribed.is_empty() {
            // Nothing to do.
            return Ok(());
        }

        trace!(
            "saving {} new subscriptions and {} unsubscriptions",
            subscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
            unsubscribed.values().map(|by_room| by_room.len()).sum::<usize>(),
        );

        // Take into account the new unsubscriptions.
        for (room_id, room_map) in unsubscribed {
            for (event_id, thread_sub) in room_map {
                guard
                    .client
                    .state_store()
                    .upsert_thread_subscription(
                        &room_id,
                        &event_id,
                        StoredThreadSubscription {
                            status: ThreadSubscriptionStatus::Unsubscribed,
                            bump_stamp: Some(thread_sub.bump_stamp.into()),
                        },
                    )
                    .await?;
            }
        }

        // Take into account the new subscriptions.
        for (room_id, room_map) in subscribed {
            for (event_id, thread_sub) in room_map {
                guard
                    .client
                    .state_store()
                    .upsert_thread_subscription(
                        &room_id,
                        &event_id,
                        StoredThreadSubscription {
                            status: ThreadSubscriptionStatus::Subscribed {
                                automatic: thread_sub.automatic,
                            },
                            bump_stamp: Some(thread_sub.bump_stamp.into()),
                        },
                    )
                    .await?;
            }
        }

        Ok(())
    }

    /// Internal helper to lock writes to the thread subscriptions catchup
    /// tokens list.
    ///
    /// Returns `None` if the parent client has been dropped, i.e. the client
    /// is shutting down.
    async fn lock(&self) -> Option<GuardedStoreAccess> {
        let client = self.client.get()?;
        let mutex_guard = self.uniq_mutex.clone().lock_owned().await;
        Some(GuardedStoreAccess {
            _mutex: mutex_guard,
            client,
            is_outdated: self.is_outdated.clone(),
        })
    }

    /// Save a new catchup token (or absence thereof) in the state store.
    ///
    /// Also wakes up the background task whenever the resulting token list is
    /// non-empty.
    async fn save_catchup_token(
        &self,
        guard: &GuardedStoreAccess,
        token: Option<ThreadSubscriptionCatchupToken>,
    ) -> Result<()> {
        // Note: saving an empty tokens list will mark the thread subscriptions list as
        // not outdated.
        let mut tokens = guard.load_catchup_tokens().await?.unwrap_or_default();

        if let Some(token) = token {
            trace!(?token, "Saving catchup token");
            tokens.push(token);
        } else {
            trace!("No catchup token to save");
        }

        let is_token_list_empty = guard.save_catchup_tokens(tokens).await?;

        // Wake up the catchup task, in case it's waiting.
        if !is_token_list_empty {
            let _ = self.ping_sender.send(()).await;
        }

        Ok(())
    }

    /// The background task listening to new catchup tokens, and using them to
    /// catch up the thread subscriptions via the [MSC4308] companion
    /// endpoint.
    ///
    /// It will continue to process catchup tokens until there are none, and
    /// then wait for a new one to be available and inserted in the
    /// database.
    ///
    /// It always processes catch up tokens from the newest to the oldest, since
    /// newest tokens are more interesting than older ones. Indeed, they're
    /// more likely to include entries with higher bump-stamps, i.e. to include
    /// more recent thread subscriptions statuses for each thread, so more
    /// relevant information.
    ///
    /// [MSC4308]: https://github.com/matrix-org/matrix-spec-proposals/pull/4308
    #[instrument(skip_all)]
    async fn thread_subscriptions_catchup_task(this: Arc<Self>, mut ping_receiver: Receiver<()>) {
        loop {
            // Load the current catchup token.
            let Some(guard) = this.lock().await else {
                // Client is shutting down.
                return;
            };

            let store_tokens = match guard.load_catchup_tokens().await {
                Ok(tokens) => tokens,
                Err(err) => {
                    // NOTE(review): `continue` retries immediately with no
                    // backoff; a persistent store error would make this loop
                    // spin — consider adding a delay.
                    warn!("Failed to load thread subscriptions catchup tokens: {err}");
                    continue;
                }
            };

            let Some(mut tokens) = store_tokens else {
                // Release the mutex.
                drop(guard);

                // Wait for a wake up.
                trace!("Waiting for an explicit wake up to process future thread subscriptions");

                if let Some(()) = ping_receiver.recv().await {
                    trace!("Woke up!");
                    continue;
                }

                // Channel closed, the client is shutting down.
                break;
            };

            // We do have a tokens. Pop the last value, and use it to catch up!
            // (Popping the *last* token implements the newest-first policy
            // documented above.)
            let last = tokens.pop().expect("must be set per `load_catchup_tokens` contract");

            // Release the mutex before running the network request.
            let client = guard.client.clone();
            drop(guard);

            // Start the actual catchup!
            let req = assign!(ruma::api::client::threads::get_thread_subscriptions_changes::unstable::Request::new(), {
                from: Some(last.from.clone()),
                to: last.to.clone(),
            });

            match client.send(req).await {
                Ok(resp) => {
                    let guard = this
                        .lock()
                        .await
                        .expect("a client instance is alive, so the locking should not fail");

                    if let Err(err) =
                        this.store_subscriptions(&guard, resp.subscribed, resp.unsubscribed).await
                    {
                        warn!("Failed to store caught up thread subscriptions: {err}");
                        continue;
                    }

                    // Refresh the tokens, as the list might have changed while we sent the
                    // request.
                    let mut tokens = match guard.load_catchup_tokens().await {
                        Ok(tokens) => tokens.unwrap_or_default(),
                        Err(err) => {
                            warn!("Failed to load thread subscriptions catchup tokens: {err}");
                            continue;
                        }
                    };

                    let Some(index) = tokens.iter().position(|t| *t == last) else {
                        warn!("Thread subscriptions catchup token disappeared while processing it");
                        continue;
                    };

                    if let Some(next_batch) = resp.end {
                        // If the response contained a next batch token, reuse the same catchup
                        // token entry, so the `to` value remains the same.
                        tokens[index] =
                            ThreadSubscriptionCatchupToken { from: next_batch, to: last.to };
                    } else {
                        // No next batch, we can remove this token from the list.
                        tokens.remove(index);
                    }

                    if let Err(err) = guard.save_catchup_tokens(tokens).await {
                        warn!("Failed to save updated thread subscriptions catchup tokens: {err}");
                    }
                }

                Err(err) => {
                    // The token was not consumed; it stays in the store and
                    // will be retried on the next loop iteration.
                    warn!("Failed to catch up thread subscriptions: {err}");
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use std::ops::Not as _;

    use matrix_sdk_base::ThreadSubscriptionCatchupToken;
    use matrix_sdk_test::async_test;

    use crate::test_utils::client::MockClientBuilder;

    #[async_test]
    async fn test_load_save_catchup_tokens() {
        let client = MockClientBuilder::new(None).build().await;

        let catchup = client.thread_subscription_catchup();

        // Initially, the store holds no catchup tokens and the subscriptions
        // are considered outdated.
        let guard = catchup.lock().await.unwrap();
        assert!(guard.load_catchup_tokens().await.unwrap().is_none());
        assert!(catchup.is_outdated());

        // Persist a single token…
        let token =
            ThreadSubscriptionCatchupToken { from: "from".to_owned(), to: Some("to".to_owned()) };
        guard.save_catchup_tokens(vec![token.clone()]).await.unwrap();

        // …check it round-trips through the store…
        let reloaded = guard.load_catchup_tokens().await.unwrap();
        assert_eq!(reloaded, Some(vec![token]));

        // …and that the outdated flag stays set while a token remains.
        assert!(catchup.is_outdated());

        // Saving an empty list removes the entry from the store…
        guard.save_catchup_tokens(vec![]).await.unwrap();
        assert!(guard.load_catchup_tokens().await.unwrap().is_none());

        // …and clears the outdated flag.
        assert!(catchup.is_outdated().not());
    }
}
}