matrix_sdk/authentication/oauth/
cross_process.rs

1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5    store::{LockableCryptoStore, Store},
6    CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9    CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Name of the custom value, in the store, under which the hash of the
/// current session tokens is persisted.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
/// Newtype to identify that a value is a session tokens' hash.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Render the hash as a lowercase hexadecimal string prefixed with `0x`.
    ///
    /// An empty hash renders as the empty string, without the `0x` prefix.
    fn to_hex(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

        if self.0.is_empty() {
            return String::new();
        }

        // Two characters for the prefix, then two per byte.
        let mut res = String::with_capacity(2 + 2 * self.0.len());
        res.push_str("0x");

        for &byte in &self.0 {
            // Bytes are emitted in storage order, and within each byte the high
            // nibble comes first (the usual way of writing hexadecimal). Only
            // the stability of the format matters here.
            res.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            res.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }

        res
    }
}

/// Debug output shows the hexadecimal rendering of the hash rather than the
/// raw byte vector.
impl std::fmt::Debug for SessionHash {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SessionHash").field(&self.to_hex()).finish()
    }
}
51
52/// Compute a hash uniquely identifying the OAuth 2.0 session tokens.
53fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54    let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55    if let Some(refresh_token) = &tokens.refresh_token {
56        hash = hash.chain_update(refresh_token.as_bytes());
57    }
58    SessionHash(hash.finalize().to_vec())
59}
60
61#[derive(Clone)]
62pub(super) struct CrossProcessRefreshManager {
63    store: Store,
64    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
65    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
66}
67
68impl CrossProcessRefreshManager {
69    /// Create a new `CrossProcessRefreshManager`.
70    pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71        Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72    }
73
74    /// Wait for up to 60 seconds to get a cross-process store lock, then either
75    /// timeout (as an error) or return a lock guard.
76    ///
77    /// The guard also contains information useful to react upon another
78    /// background refresh having happened in the database already.
79    pub async fn spin_lock(
80        &self,
81    ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82        // Acquire the intra-process mutex, to avoid multiple requests across threads in
83        // the current process.
84        trace!("Waiting for intra-process lock...");
85        let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87        // Acquire the cross-process mutex, to avoid multiple requests across different
88        // processus.
89        trace!("Waiting for inter-process lock...");
90        let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92        // Read the previous session hash in the database.
93        let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95        let db_hash = current_db_session_bytes.map(SessionHash);
96
97        let hash_mismatch = match (&db_hash, &*prev_hash) {
98            (None, _) => false,
99            (Some(_), None) => true,
100            (Some(db), Some(known)) => db != known,
101        };
102
103        trace!(hash_mismatch, ?prev_hash, ?db_hash);
104
105        let guard = CrossProcessRefreshLockGuard {
106            hash_guard: prev_hash,
107            _store_guard: store_guard,
108            hash_mismatch,
109            db_hash,
110            store: self.store.clone(),
111        };
112
113        Ok(guard)
114    }
115
116    pub async fn restore_session(&self, tokens: &SessionTokens) {
117        let prev_tokens_hash = compute_session_hash(tokens);
118        *self.known_session_hash.lock().await = Some(prev_tokens_hash);
119    }
120
121    pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
122        self.store
123            .remove_custom_value(OIDC_SESSION_HASH_KEY)
124            .await
125            .map_err(CrossProcessRefreshLockError::StoreError)?;
126        *self.known_session_hash.lock().await = None;
127        Ok(())
128    }
129}
130
131pub(super) struct CrossProcessRefreshLockGuard {
132    /// The hash for the latest session, either the one we knew, or the latest
133    /// one read from the database, if it was more up to date.
134    hash_guard: OwnedMutexGuard<Option<SessionHash>>,
135
136    /// Cross-process lock being hold.
137    _store_guard: CrossProcessStoreLockGuard,
138
139    /// Reference to the underlying store, for storing the hash of the latest
140    /// known session (as a custom value).
141    store: Store,
142
143    /// Do the in-memory hash and database hash mismatch?
144    ///
145    /// If so, this indicates that another process may have refreshed the token
146    /// in the background.
147    ///
148    /// We don't consider it a mismatch if there was no previous value in the
149    /// database. We do consider it a mismatch if there was no in-memory
150    /// value known, but one was known in the database.
151    pub hash_mismatch: bool,
152
153    /// Session hash previously stored in the DB.
154    ///
155    /// Used for debugging and testing purposes.
156    db_hash: Option<SessionHash>,
157}
158
159impl CrossProcessRefreshLockGuard {
160    /// Updates the `SessionTokens` hash in-memory only.
161    fn save_in_memory(&mut self, hash: SessionHash) {
162        *self.hash_guard = Some(hash);
163    }
164
165    /// Updates the `SessionTokens` hash in the database only.
166    async fn save_in_database(
167        &self,
168        hash: &SessionHash,
169    ) -> Result<(), CrossProcessRefreshLockError> {
170        self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
171        Ok(())
172    }
173
174    /// Updates the `SessionTokens` hash in both memory and database.
175    ///
176    /// Must be called after a successful refresh.
177    pub async fn save_in_memory_and_db(
178        &mut self,
179        tokens: &SessionTokens,
180    ) -> Result<(), CrossProcessRefreshLockError> {
181        let hash = compute_session_hash(tokens);
182        self.save_in_database(&hash).await?;
183        self.save_in_memory(hash);
184        Ok(())
185    }
186
187    /// Handle a mismatch by making sure values in the database and memory match
188    /// tokens we trust.
189    pub async fn handle_mismatch(
190        &mut self,
191        trusted_tokens: &SessionTokens,
192    ) -> Result<(), CrossProcessRefreshLockError> {
193        let new_hash = compute_session_hash(trusted_tokens);
194        trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
195
196        if let Some(db_hash) = &self.db_hash {
197            if new_hash != *db_hash {
198                // That should never happen, unless we got into an impossible situation!
199                // In this case, we assume the value returned by the callback is always
200                // correct, so override that in the database too.
201                tracing::error!("error: DB and trusted disagree. Overriding in DB.");
202                self.save_in_database(&new_hash).await?;
203            }
204        }
205
206        self.save_in_memory(new_hash);
207        Ok(())
208    }
209}
210
211/// An error that happened when interacting with the cross-process store lock
212/// during a token refresh.
213#[derive(Debug, Error)]
214pub enum CrossProcessRefreshLockError {
215    /// Underlying error caused by the store.
216    #[error(transparent)]
217    StoreError(#[from] CryptoStoreError),
218
219    /// The locking itself failed.
220    #[error(transparent)]
221    LockError(#[from] LockStoreError),
222
223    /// The previous hash isn't valid.
224    #[error("the previous stored hash isn't a valid integer")]
225    InvalidPreviousHash,
226
227    /// The lock hasn't been set up.
228    #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
229    MissingLock,
230
231    /// Cross-process lock was set, but without session callbacks.
232    #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
233    MissingReloadSession,
234
235    /// The store has been created twice.
236    #[error(
237        "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
238    )]
239    DuplicatedLock,
240}
241
242#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_family = "wasm")))]
243mod tests {
244
245    use anyhow::Context as _;
246    use futures_util::future::join_all;
247    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
248    use matrix_sdk_test::async_test;
249    use ruma::{owned_device_id, owned_user_id};
250
251    use super::compute_session_hash;
252    use crate::{
253        authentication::oauth::cross_process::SessionHash,
254        test_utils::{
255            client::{
256                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
257                oauth::mock_session, MockClientBuilder,
258            },
259            mocks::MatrixMockServer,
260        },
261        Error,
262    };
263
264    #[async_test]
265    async fn test_restore_session_lock() -> Result<(), Error> {
266        // Create a client that will use sqlite databases.
267
268        let tmp_dir = tempfile::tempdir()?;
269        let client = MockClientBuilder::new("https://example.org".to_owned())
270            .sqlite_store(&tmp_dir)
271            .unlogged()
272            .build()
273            .await;
274
275        let tokens = mock_session_tokens_with_refresh();
276
277        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;
278
279        client.set_session_callbacks(
280            Box::new({
281                // This is only called because of extra checks in the code.
282                let tokens = tokens.clone();
283                move |_| Ok(tokens.clone())
284            }),
285            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
286        )?;
287
288        let session_hash = compute_session_hash(&tokens);
289        client
290            .oauth()
291            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
292            .await?;
293
294        assert_eq!(client.session_tokens().unwrap(), tokens);
295
296        let oauth = client.oauth();
297        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();
298
299        {
300            let known_session = xp_manager.known_session_hash.lock().await;
301            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
302        }
303
304        {
305            let lock = xp_manager.spin_lock().await.unwrap();
306            assert!(!lock.hash_mismatch);
307            assert_eq!(lock.db_hash.unwrap(), session_hash);
308        }
309
310        Ok(())
311    }
312
313    #[async_test]
314    async fn test_finish_login() -> anyhow::Result<()> {
315        let server = MatrixMockServer::new().await;
316        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;
317
318        let tmp_dir = tempfile::tempdir()?;
319        let client =
320            server.client_builder().sqlite_store(&tmp_dir).registered_with_oauth().build().await;
321        let oauth = client.oauth();
322
323        // Enable cross-process lock.
324        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;
325
326        // Simulate we've done finalize_authorization / restore_session before.
327        let session_tokens = mock_session_tokens_with_refresh();
328        client.auth_ctx().set_session_tokens(session_tokens.clone());
329
330        // Now, finishing logging will get the user ID.
331        oauth.load_session(owned_device_id!("D3V1C31D")).await?;
332
333        let session_meta = client.session_meta().context("should have session meta now")?;
334        assert_eq!(
335            *session_meta,
336            SessionMeta {
337                user_id: owned_user_id!("@joe:example.org"),
338                device_id: owned_device_id!("D3V1C31D")
339            }
340        );
341
342        {
343            // The cross process lock has been correctly updated, and the next attempt to
344            // take it won't result in a mismatch.
345            let xp_manager =
346                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
347            let guard = xp_manager.spin_lock().await?;
348            let actual_hash = compute_session_hash(&session_tokens);
349            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
350            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
351            assert!(!guard.hash_mismatch);
352        }
353
354        Ok(())
355    }
356
357    #[async_test]
358    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
359        // This tests that refresh token works, and that it doesn't cause multiple token
360        // refreshes whenever one spawns two refreshes around the same time.
361
362        let server = MatrixMockServer::new().await;
363
364        let oauth_server = server.oauth();
365        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
366        oauth_server.mock_token().ok().expect(1).named("token").mount().await;
367
368        let tmp_dir = tempfile::tempdir()?;
369        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
370        let oauth = client.oauth();
371
372        let next_tokens = mock_session_tokens_with_refresh();
373
374        // Enable cross-process lock.
375        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;
376
377        // Restore the session.
378        oauth
379            .restore_session(
380                mock_session(mock_prev_session_tokens_with_refresh()),
381                RoomLoadSettings::default(),
382            )
383            .await?;
384
385        // Immediately try to refresh the access token twice in parallel.
386        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
387            result?;
388        }
389
390        {
391            // The cross process lock has been correctly updated, and the next attempt to
392            // take it won't result in a mismatch.
393            let xp_manager =
394                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
395            let guard = xp_manager.spin_lock().await?;
396            let actual_hash = compute_session_hash(&next_tokens);
397            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
398            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
399            assert!(!guard.hash_mismatch);
400        }
401
402        Ok(())
403    }
404
405    #[async_test]
406    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
407        let server = MatrixMockServer::new().await;
408
409        let oauth_server = server.oauth();
410        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
411        oauth_server.mock_token().ok().expect(1).named("token").mount().await;
412
413        let prev_tokens = mock_prev_session_tokens_with_refresh();
414        let next_tokens = mock_session_tokens_with_refresh();
415
416        // Create the first client.
417        let tmp_dir = tempfile::tempdir()?;
418        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
419
420        let oauth = client.oauth();
421        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;
422
423        oauth
424            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
425            .await?;
426
427        // Create a second client, without restoring it, to test that a token update
428        // before restoration doesn't cause new issues.
429        let unrestored_client =
430            server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
431        let unrestored_oauth = unrestored_client.oauth();
432        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;
433
434        {
435            // Create a third client that will run a refresh while the others two are doing
436            // nothing.
437            let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
438
439            let oauth3 = client3.oauth();
440            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
441            oauth3
442                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
443                .await?;
444
445            // Run a refresh in the second client; this will invalidate the tokens from the
446            // first token.
447            oauth3.refresh_access_token().await?;
448
449            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));
450
451            // Reading from the cross-process lock for the second client only shows the new
452            // tokens.
453            let xp_manager =
454                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
455            let guard = xp_manager.spin_lock().await?;
456            let actual_hash = compute_session_hash(&next_tokens);
457            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
458            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
459            assert!(!guard.hash_mismatch);
460        }
461
462        {
463            // Restoring the client that was not restored yet will work Just Fine.
464            let oauth = unrestored_oauth;
465
466            unrestored_client.set_session_callbacks(
467                Box::new({
468                    // This is only called because of extra checks in the code.
469                    let tokens = next_tokens.clone();
470                    move |_| Ok(tokens.clone())
471                }),
472                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
473            )?;
474
475            oauth
476                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
477                .await?;
478
479            // And this client is now aware of the latest tokens.
480            let xp_manager =
481                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
482            let guard = xp_manager.spin_lock().await?;
483            let next_hash = compute_session_hash(&next_tokens);
484            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
485            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
486            assert!(!guard.hash_mismatch);
487
488            drop(oauth);
489            drop(unrestored_client);
490        }
491
492        {
493            // The cross process lock has been correctly updated, and the next attempt to
494            // take it will result in a mismatch.
495            let xp_manager =
496                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
497            let guard = xp_manager.spin_lock().await?;
498            let previous_hash = compute_session_hash(&prev_tokens);
499            let next_hash = compute_session_hash(&next_tokens);
500            assert_eq!(guard.db_hash, Some(next_hash));
501            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
502            assert!(guard.hash_mismatch);
503        }
504
505        client.set_session_callbacks(
506            Box::new({
507                // This is only called because of extra checks in the code.
508                let tokens = next_tokens.clone();
509                move |_| Ok(tokens.clone())
510            }),
511            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
512        )?;
513
514        oauth.refresh_access_token().await?;
515
516        {
517            // The next attempt to take the lock isn't a mismatch.
518            let xp_manager =
519                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
520            let guard = xp_manager.spin_lock().await?;
521            let actual_hash = compute_session_hash(&next_tokens);
522            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
523            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
524            assert!(!guard.hash_mismatch);
525        }
526
527        Ok(())
528    }
529
530    #[async_test]
531    async fn test_logout() -> anyhow::Result<()> {
532        let server = MatrixMockServer::new().await;
533
534        let oauth_server = server.oauth();
535        oauth_server
536            .mock_server_metadata()
537            .ok_https()
538            .expect(1..)
539            .named("server_metadata")
540            .mount()
541            .await;
542        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;
543
544        let tmp_dir = tempfile::tempdir()?;
545        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
546        let oauth = client.oauth().insecure_rewrite_https_to_http();
547
548        // Enable cross-process lock.
549        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;
550
551        // Restore the session.
552        let tokens = mock_session_tokens_with_refresh();
553        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;
554
555        oauth.logout().await.unwrap();
556
557        {
558            // The cross process lock has been correctly updated, and all the hashes are
559            // empty after a logout.
560            let xp_manager =
561                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
562            let guard = xp_manager.spin_lock().await?;
563            assert!(guard.db_hash.is_none());
564            assert!(guard.hash_guard.is_none());
565            assert!(!guard.hash_mismatch);
566        }
567
568        Ok(())
569    }
570
571    #[test]
572    fn test_session_hash_to_hex() {
573        let hash = SessionHash(vec![]);
574        assert_eq!(hash.to_hex(), "");
575
576        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
577        assert_eq!(hash.to_hex(), "0x133742deadcafe");
578    }
579}