// matrix_sdk/authentication/oauth/cross_process.rs

1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5    store::{LockableCryptoStore, Store},
6    CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9    CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key in the database for the custom value holding the hash of the current
/// session tokens.
///
/// NOTE(review): the key still says "oidc" while the rest of this module talks
/// about OAuth 2.0; presumably it is left unchanged so hashes already stored
/// under it stay readable — confirm before renaming.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
/// Newtype to identify that a value is a session tokens' hash.
///
/// Wraps the raw hash bytes; equality compares those bytes directly.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Render the hash as a lowercase hexadecimal string, for logging.
    ///
    /// A non-empty hash is prefixed with `0x` and every byte is printed as
    /// exactly two digits. An empty hash renders as the empty string, without
    /// any prefix.
    fn to_hex(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

        if self.0.is_empty() {
            // Nothing to print: no prefix, and no allocation either.
            return String::new();
        }

        // Two characters for the prefix, then two per byte.
        let mut res = String::with_capacity(2 + 2 * self.0.len());
        res.push_str("0x");

        for &byte in &self.0 {
            // Bytes are emitted in storage order, each one with its high
            // nibble first (the conventional way of writing a hex byte). Any
            // stable format would do, since this is only used for display.
            res.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            res.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
        }

        res
    }
}

impl std::fmt::Debug for SessionHash {
    /// Formats as `SessionHash("<hex>")`, using [`SessionHash::to_hex`] rather
    /// than dumping the raw byte vector.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SessionHash").field(&self.to_hex()).finish()
    }
}
51
52/// Compute a hash uniquely identifying the OAuth 2.0 session tokens.
53fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54    let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55    if let Some(refresh_token) = &tokens.refresh_token {
56        hash = hash.chain_update(refresh_token.as_bytes());
57    }
58    SessionHash(hash.finalize().to_vec())
59}
60
/// Coordinates token refreshes between the tasks of this process and other
/// processes sharing the same store.
///
/// It compares a hash of the session tokens known in memory with the one
/// saved in the store, while holding both an in-process mutex and the
/// cross-process store lock (see [`CrossProcessRefreshManager::spin_lock`]).
#[derive(Clone)]
pub(super) struct CrossProcessRefreshManager {
    /// Store in which the session hash is read, written and removed, as a
    /// custom value.
    store: Store,
    /// Cross-process lock acquired by `spin_lock`.
    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
    /// Hash of the session tokens last known to this process, if any.
    ///
    /// Lives behind an `Arc`, so all clones of the manager share the same
    /// value; the mutex around it doubles as the in-process lock.
    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
}
67
68impl CrossProcessRefreshManager {
69    /// Create a new `CrossProcessRefreshManager`.
70    pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71        Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72    }
73
74    /// Wait for up to 60 seconds to get a cross-process store lock, then either
75    /// timeout (as an error) or return a lock guard.
76    ///
77    /// The guard also contains information useful to react upon another
78    /// background refresh having happened in the database already.
79    pub async fn spin_lock(
80        &self,
81    ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82        // Acquire the intra-process mutex, to avoid multiple requests across threads in
83        // the current process.
84        trace!("Waiting for intra-process lock...");
85        let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87        // Acquire the cross-process mutex, to avoid multiple requests across different
88        // processus.
89        trace!("Waiting for inter-process lock...");
90        let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92        // Read the previous session hash in the database.
93        let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95        let db_hash = current_db_session_bytes.map(SessionHash);
96
97        let hash_mismatch = match (&db_hash, &*prev_hash) {
98            (None, _) => false,
99            (Some(_), None) => true,
100            (Some(db), Some(known)) => db != known,
101        };
102
103        trace!(
104            "Hash mismatch? {:?} (prev. known={:?}, db={:?})",
105            hash_mismatch,
106            *prev_hash,
107            db_hash
108        );
109
110        let guard = CrossProcessRefreshLockGuard {
111            hash_guard: prev_hash,
112            _store_guard: store_guard,
113            hash_mismatch,
114            db_hash,
115            store: self.store.clone(),
116        };
117
118        Ok(guard)
119    }
120
121    pub async fn restore_session(&self, tokens: &SessionTokens) {
122        let prev_tokens_hash = compute_session_hash(tokens);
123        *self.known_session_hash.lock().await = Some(prev_tokens_hash);
124    }
125
126    pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
127        self.store
128            .remove_custom_value(OIDC_SESSION_HASH_KEY)
129            .await
130            .map_err(CrossProcessRefreshLockError::StoreError)?;
131        *self.known_session_hash.lock().await = None;
132        Ok(())
133    }
134}
135
/// Guard returned by `CrossProcessRefreshManager::spin_lock`.
///
/// It owns both the in-process mutex guard and the cross-process store lock
/// guard, and carries what was observed in the database when the locks were
/// taken.
pub(super) struct CrossProcessRefreshLockGuard {
    /// The hash for the latest session known in memory.
    ///
    /// It starts as the value this process knew when the lock was taken, and
    /// is overwritten by `save_in_memory`. Being an owned mutex guard, it keeps
    /// the in-process mutex locked for as long as this struct is alive.
    hash_guard: OwnedMutexGuard<Option<SessionHash>>,

    /// Cross-process lock being held.
    ///
    /// Never read; presumably kept only so the lock stays held until this
    /// struct is dropped — confirm against `CrossProcessStoreLockGuard`.
    _store_guard: CrossProcessStoreLockGuard,

    /// Reference to the underlying store, for storing the hash of the latest
    /// known session (as a custom value).
    store: Store,

    /// Do the in-memory hash and database hash mismatch?
    ///
    /// If so, this indicates that another process may have refreshed the token
    /// in the background.
    ///
    /// We don't consider it a mismatch if there was no previous value in the
    /// database. We do consider it a mismatch if there was no in-memory
    /// value known, but one was known in the database.
    pub hash_mismatch: bool,

    /// Session hash previously stored in the DB, as read when the lock was
    /// taken.
    ///
    /// Used for debugging and testing purposes.
    db_hash: Option<SessionHash>,
}
163
164impl CrossProcessRefreshLockGuard {
165    /// Updates the `SessionTokens` hash in-memory only.
166    fn save_in_memory(&mut self, hash: SessionHash) {
167        *self.hash_guard = Some(hash);
168    }
169
170    /// Updates the `SessionTokens` hash in the database only.
171    async fn save_in_database(
172        &self,
173        hash: &SessionHash,
174    ) -> Result<(), CrossProcessRefreshLockError> {
175        self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
176        Ok(())
177    }
178
179    /// Updates the `SessionTokens` hash in both memory and database.
180    ///
181    /// Must be called after a successful refresh.
182    pub async fn save_in_memory_and_db(
183        &mut self,
184        tokens: &SessionTokens,
185    ) -> Result<(), CrossProcessRefreshLockError> {
186        let hash = compute_session_hash(tokens);
187        self.save_in_database(&hash).await?;
188        self.save_in_memory(hash);
189        Ok(())
190    }
191
192    /// Handle a mismatch by making sure values in the database and memory match
193    /// tokens we trust.
194    pub async fn handle_mismatch(
195        &mut self,
196        trusted_tokens: &SessionTokens,
197    ) -> Result<(), CrossProcessRefreshLockError> {
198        let new_hash = compute_session_hash(trusted_tokens);
199        trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
200
201        if let Some(db_hash) = &self.db_hash {
202            if new_hash != *db_hash {
203                // That should never happen, unless we got into an impossible situation!
204                // In this case, we assume the value returned by the callback is always
205                // correct, so override that in the database too.
206                tracing::error!("error: DB and trusted disagree. Overriding in DB.");
207                self.save_in_database(&new_hash).await?;
208            }
209        }
210
211        self.save_in_memory(new_hash);
212        Ok(())
213    }
214}
215
216/// An error that happened when interacting with the cross-process store lock
217/// during a token refresh.
218#[derive(Debug, Error)]
219pub enum CrossProcessRefreshLockError {
220    /// Underlying error caused by the store.
221    #[error(transparent)]
222    StoreError(#[from] CryptoStoreError),
223
224    /// The locking itself failed.
225    #[error(transparent)]
226    LockError(#[from] LockStoreError),
227
228    /// The previous hash isn't valid.
229    #[error("the previous stored hash isn't a valid integer")]
230    InvalidPreviousHash,
231
232    /// The lock hasn't been set up.
233    #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
234    MissingLock,
235
236    /// Cross-process lock was set, but without session callbacks.
237    #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
238    MissingReloadSession,
239
240    /// The store has been created twice.
241    #[error(
242        "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
243    )]
244    DuplicatedLock,
245}
246
247#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_arch = "wasm32")))]
248mod tests {
249
250    use anyhow::Context as _;
251    use futures_util::future::join_all;
252    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
253    use matrix_sdk_test::async_test;
254    use ruma::{owned_device_id, owned_user_id};
255
256    use super::compute_session_hash;
257    use crate::{
258        authentication::oauth::cross_process::SessionHash,
259        test_utils::{
260            client::{
261                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
262                oauth::mock_session, MockClientBuilder,
263            },
264            mocks::MatrixMockServer,
265        },
266        Error,
267    };
268
    // Restoring a session with the cross-process lock enabled must record the
    // tokens' hash both in memory and in the database.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // Create a client that will use sqlite databases.

        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new("https://example.org".to_owned())
            .sqlite_store(&tmp_dir)
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // Hash we expect to find everywhere once the session is restored.
        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        {
            // The in-memory hash matches. The mutex guard is scoped so that it
            // is released before `spin_lock` needs the same mutex below.
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            // The database holds the same hash, so taking the lock reports no
            // mismatch.
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }
317
    // Loading the session after the tokens have been set must fetch the user
    // ID and leave the cross-process hashes in sync with those tokens.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        // The whoami endpoint is expected to be hit exactly once.
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client =
            server.client_builder().sqlite_store(&tmp_dir).registered_with_oauth().build().await;
        let oauth = client.oauth();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Simulate we've done finalize_authorization / restore_session before.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        // Now, finishing the login will get the user ID.
        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
361
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        // This tests that refreshing the token works, and that spawning two
        // refreshes around the same time doesn't cause multiple token refreshes.

        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // The token endpoint must be hit exactly once, even though two
        // refreshes are requested below.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth();

        // Tokens whose hash is expected in memory and in the database once the
        // refresh is done.
        let next_tokens = mock_session_tokens_with_refresh();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        // Immediately try to refresh the access token twice in parallel.
        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        {
            // The cross process lock has been correctly updated, and the next attempt to
            // take it won't result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
409
    // Several clients share one sqlite store, standing in for several
    // processes: one of them refreshes the tokens, and the others must notice
    // it through the hash stored in the database.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        // Only one request to the token endpoint is expected over the whole
        // test.
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // Create the first client.
        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // Create a second client, without restoring it, to test that a token update
        // before restoration doesn't cause new issues.
        let unrestored_client =
            server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Create a third client that will run a refresh while the other two are doing
            // nothing.
            let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // Run a refresh in the third client; this will invalidate the tokens held by
            // the first client.
            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            // Reading from the cross-process lock for the third client only shows the new
            // tokens.
            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        {
            // Restoring the client that was not restored yet will work Just Fine.
            let oauth = unrestored_oauth;

            unrestored_client.set_session_callbacks(
                Box::new({
                    // This is only called because of extra checks in the code.
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            // The session is restored with the *previous* tokens, while the database
            // already holds the hash of the new ones.
            oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // And this client is now aware of the latest tokens.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        {
            // Back to the first client: the database holds the hash of the new tokens
            // while this client still has the previous hash in memory, so the next
            // attempt to take the lock will result in a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                // This is only called because of extra checks in the code.
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // Refreshing from the first client must succeed. The token endpoint mock
        // only allows the single request already made by the third client.
        oauth.refresh_access_token().await?;

        {
            // The next attempt to take the lock isn't a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
534
    // Logging out must clear the session hash both in memory and in the
    // database.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        // The revocation endpoint is expected to be hit exactly once.
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        // Enable cross-process lock.
        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Restore the session.
        let tokens = mock_session_tokens_with_refresh();
        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;

        oauth.logout().await.unwrap();

        {
            // The cross process lock has been correctly updated, and all the hashes are
            // empty after a logout.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }
575
576    #[test]
577    fn test_session_hash_to_hex() {
578        let hash = SessionHash(vec![]);
579        assert_eq!(hash.to_hex(), "");
580
581        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
582        assert_eq!(hash.to_hex(), "0x133742deadcafe");
583    }
584}