matrix_sdk/authentication/oauth/
cross_process.rs1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5 store::{LockableCryptoStore, Store},
6 CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9 CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key of the crypto-store custom value that holds the raw bytes of the
/// current session's [`SessionHash`].
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
/// Digest identifying a set of session tokens.
///
/// Values are either computed from tokens or rebuilt from the raw bytes
/// persisted in the store, and are compared to detect session changes.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);

impl SessionHash {
    /// Renders the digest as lowercase hexadecimal prefixed with `0x`.
    ///
    /// An empty digest renders as the empty string, without the prefix.
    fn to_hex(&self) -> String {
        const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

        let bytes = &self.0;
        if bytes.is_empty() {
            return String::new();
        }

        // Each byte yields its high nibble first, then its low nibble.
        let nibbles = bytes.iter().flat_map(|&byte| [byte >> 4, byte & 0x0f]);

        let mut hex = String::with_capacity(2 + 2 * bytes.len());
        hex.push_str("0x");
        hex.extend(nibbles.map(|nibble| char::from(HEX_DIGITS[usize::from(nibble)])));
        hex
    }
}

impl std::fmt::Debug for SessionHash {
    /// Shows the digest in hexadecimal, e.g. `SessionHash("0x1337")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex = self.to_hex();
        f.debug_tuple("SessionHash").field(&hex).finish()
    }
}
52fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54 let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55 if let Some(refresh_token) = &tokens.refresh_token {
56 hash = hash.chain_update(refresh_token.as_bytes());
57 }
58 SessionHash(hash.finalize().to_vec())
59}
60
/// Coordinates OAuth 2.0 token refreshes, both between tasks of this process
/// and with other users of the same crypto store.
///
/// It remembers a hash of the session tokens this process knows about, and
/// compares it with the hash persisted in the store to detect when the tokens
/// were changed elsewhere.
#[derive(Clone)]
pub(super) struct CrossProcessRefreshManager {
    /// Crypto store in which the session hash is persisted.
    store: Store,
    /// Lock taken through the store so that only one holder refreshes at a time.
    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
    /// Hash of the tokens this process last knew about (`None` when unknown).
    /// Behind an `Arc`, so clones of the manager share the same value.
    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
}
68impl CrossProcessRefreshManager {
69 pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71 Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72 }
73
74 pub async fn spin_lock(
80 &self,
81 ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82 trace!("Waiting for intra-process lock...");
85 let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87 trace!("Waiting for inter-process lock...");
90 let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92 let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95 let db_hash = current_db_session_bytes.map(SessionHash);
96
97 let hash_mismatch = match (&db_hash, &*prev_hash) {
98 (None, _) => false,
99 (Some(_), None) => true,
100 (Some(db), Some(known)) => db != known,
101 };
102
103 trace!(
104 "Hash mismatch? {:?} (prev. known={:?}, db={:?})",
105 hash_mismatch,
106 *prev_hash,
107 db_hash
108 );
109
110 let guard = CrossProcessRefreshLockGuard {
111 hash_guard: prev_hash,
112 _store_guard: store_guard,
113 hash_mismatch,
114 db_hash,
115 store: self.store.clone(),
116 };
117
118 Ok(guard)
119 }
120
121 pub async fn restore_session(&self, tokens: &SessionTokens) {
122 let prev_tokens_hash = compute_session_hash(tokens);
123 *self.known_session_hash.lock().await = Some(prev_tokens_hash);
124 }
125
126 pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
127 self.store
128 .remove_custom_value(OIDC_SESSION_HASH_KEY)
129 .await
130 .map_err(CrossProcessRefreshLockError::StoreError)?;
131 *self.known_session_hash.lock().await = None;
132 Ok(())
133 }
134}
135
/// Value returned by `CrossProcessRefreshManager::spin_lock`, holding both the
/// in-process mutex guard over the known session hash and the cross-process
/// store lock guard.
///
/// Fields are dropped in declaration order, so `hash_guard` is released
/// before `_store_guard`; keep that in mind before reordering fields.
pub(super) struct CrossProcessRefreshLockGuard {
    /// Owned guard over this process's known session hash; writes through it
    /// update the value shared with the manager.
    hash_guard: OwnedMutexGuard<Option<SessionHash>>,

    /// Cross-process lock guard. Never read: it is only kept alive for as long
    /// as this value exists.
    _store_guard: CrossProcessStoreLockGuard,

    /// Store in which the session hash is persisted.
    store: Store,

    /// Whether, when the lock was taken, the store held a session hash that
    /// differed from the one this process knew (or this process knew none).
    pub hash_mismatch: bool,

    /// This guard's view of the session hash persisted in the store;
    /// initialised from the value read when the lock was taken.
    db_hash: Option<SessionHash>,
}
164impl CrossProcessRefreshLockGuard {
165 fn save_in_memory(&mut self, hash: SessionHash) {
167 *self.hash_guard = Some(hash);
168 }
169
170 async fn save_in_database(
172 &self,
173 hash: &SessionHash,
174 ) -> Result<(), CrossProcessRefreshLockError> {
175 self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
176 Ok(())
177 }
178
179 pub async fn save_in_memory_and_db(
183 &mut self,
184 tokens: &SessionTokens,
185 ) -> Result<(), CrossProcessRefreshLockError> {
186 let hash = compute_session_hash(tokens);
187 self.save_in_database(&hash).await?;
188 self.save_in_memory(hash);
189 Ok(())
190 }
191
192 pub async fn handle_mismatch(
195 &mut self,
196 trusted_tokens: &SessionTokens,
197 ) -> Result<(), CrossProcessRefreshLockError> {
198 let new_hash = compute_session_hash(trusted_tokens);
199 trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
200
201 if let Some(db_hash) = &self.db_hash {
202 if new_hash != *db_hash {
203 tracing::error!("error: DB and trusted disagree. Overriding in DB.");
207 self.save_in_database(&new_hash).await?;
208 }
209 }
210
211 self.save_in_memory(new_hash);
212 Ok(())
213 }
214}
215
216#[derive(Debug, Error)]
219pub enum CrossProcessRefreshLockError {
220 #[error(transparent)]
222 StoreError(#[from] CryptoStoreError),
223
224 #[error(transparent)]
226 LockError(#[from] LockStoreError),
227
228 #[error("the previous stored hash isn't a valid integer")]
230 InvalidPreviousHash,
231
232 #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
234 MissingLock,
235
236 #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
238 MissingReloadSession,
239
240 #[error(
242 "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
243 )]
244 DuplicatedLock,
245}
246
// Integration tests for the cross-process refresh lock. Every test uses an
// on-disk SQLite store in a temporary directory.
#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_arch = "wasm32")))]
mod tests {

    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::compute_session_hash;
    use crate::{
        authentication::oauth::cross_process::SessionHash,
        test_utils::{
            client::{
                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
                oauth::mock_session, MockClientBuilder,
            },
            mocks::MatrixMockServer,
        },
        Error,
    };

    // Restoring a session with the lock enabled leaves the tokens' hash both in
    // memory and in the store, so taking the lock reports no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new("https://example.org".to_owned())
            .sqlite_store(&tmp_dir)
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        // The first callback hands back the same tokens; the second
        // (save-session) callback panics if it is ever invoked.
        client.set_session_callbacks(
            Box::new({
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        // Scoped so the mutex guard is released before `spin_lock` below
        // locks the same mutex.
        {
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }

    // Loading a session from tokens set on the auth context records their
    // hash in the store and in memory.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        // The whoami endpoint is mocked and expected to be called exactly once.
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client =
            server.client_builder().sqlite_store(&tmp_dir).registered_with_oauth().build().await;
        let oauth = client.oauth();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Tokens are set directly on the auth context rather than obtained
        // through a login request.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Two concurrent refreshes from the same client both succeed while the
    // token endpoint mock expects a single call.
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth();

        let next_tokens = mock_session_tokens_with_refresh();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        // Afterwards, the new tokens' hash is both stored and known in memory.
        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Several clients share one store directory, each with its own lock
    // holder name, to exercise refreshes across clients. The token endpoint
    // mock expects a single call for the whole test.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // Second client: lock enabled, but no session restored yet.
        let unrestored_client =
            server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        // Third client: restores the previous tokens and refreshes them, which
        // stores the new tokens' hash.
        {
            let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        // Second client: restoring the outdated tokens, with a first session
        // callback that returns the new ones, ends with the new hash both in
        // memory and in the store.
        {
            let oauth = unrestored_oauth;

            unrestored_client.set_session_callbacks(
                Box::new({
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        // First client: it still knows the previous hash while the store has
        // the new one, so taking the lock reports a mismatch.
        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        client.set_session_callbacks(
            Box::new({
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // Refreshing from the first client resolves the mismatch; the token
        // mock still expects only the single call made by the third client.
        oauth.refresh_access_token().await?;

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Logging out clears the session hash both in the store and in memory.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let tokens = mock_session_tokens_with_refresh();
        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;

        oauth.logout().await.unwrap();

        {
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    // Empty hashes render without the `0x` prefix; others as lowercase hex.
    #[test]
    fn test_session_hash_to_hex() {
        let hash = SessionHash(vec![]);
        assert_eq!(hash.to_hex(), "");

        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(hash.to_hex(), "0x133742deadcafe");
    }
}