matrix_sdk/authentication/oauth/
cross_process.rs1use std::sync::Arc;
2
3#[cfg(feature = "e2e-encryption")]
4use matrix_sdk_base::crypto::{
5 store::{LockableCryptoStore, Store},
6 CryptoStoreError,
7};
8use matrix_sdk_common::store_locks::{
9 CrossProcessStoreLock, CrossProcessStoreLockGuard, LockStoreError,
10};
11use sha2::{Digest as _, Sha256};
12use thiserror::Error;
13use tokio::sync::{Mutex, OwnedMutexGuard};
14use tracing::trace;
15
16use crate::SessionTokens;
17
/// Key of the custom value, in the crypto store, holding the hash of the latest
/// known session tokens.
const OIDC_SESSION_HASH_KEY: &str = "oidc_session_hash";
21
/// Newtype marking a byte vector as a hash of session tokens (as produced by
/// `compute_session_hash`).
///
/// `Debug` is implemented by hand so the bytes are printed as hex.
#[derive(Clone, PartialEq, Eq)]
struct SessionHash(Vec<u8>);
25
26impl SessionHash {
27 fn to_hex(&self) -> String {
28 const CHARS: &[char; 16] =
29 &['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'];
30 let mut res = String::with_capacity(2 * self.0.len() + 2);
31 if !self.0.is_empty() {
32 res.push('0');
33 res.push('x');
34 }
35 for &c in &self.0 {
36 res.push(CHARS[(c >> 4) as usize]);
40 res.push(CHARS[(c & 0b1111) as usize]);
41 }
42 res
43 }
44}
45
46impl std::fmt::Debug for SessionHash {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_tuple("SessionHash").field(&self.to_hex()).finish()
49 }
50}
51
52fn compute_session_hash(tokens: &SessionTokens) -> SessionHash {
54 let mut hash = Sha256::new().chain_update(tokens.access_token.as_bytes());
55 if let Some(refresh_token) = &tokens.refresh_token {
56 hash = hash.chain_update(refresh_token.as_bytes());
57 }
58 SessionHash(hash.finalize().to_vec())
59}
60
/// Coordinates OAuth 2.0 token refreshes: it owns the cross-process store lock
/// and tracks the session-tokens hash known to this process.
#[derive(Clone)]
pub(super) struct CrossProcessRefreshManager {
    /// Crypto store in which the session hash is persisted as a custom value.
    store: Store,
    /// Cross-process lock over the crypto store, acquired by `spin_lock`.
    store_lock: CrossProcessStoreLock<LockableCryptoStore>,
    /// Hash of the session tokens last known to this process, if any.
    ///
    /// Shared through the `Arc` by all clones of the manager; its mutex is also
    /// the intra-process lock taken first by `spin_lock`.
    known_session_hash: Arc<Mutex<Option<SessionHash>>>,
}
67
68impl CrossProcessRefreshManager {
69 pub fn new(store: Store, lock: CrossProcessStoreLock<LockableCryptoStore>) -> Self {
71 Self { store, store_lock: lock, known_session_hash: Arc::new(Mutex::new(None)) }
72 }
73
74 pub async fn spin_lock(
80 &self,
81 ) -> Result<CrossProcessRefreshLockGuard, CrossProcessRefreshLockError> {
82 trace!("Waiting for intra-process lock...");
85 let prev_hash = self.known_session_hash.clone().lock_owned().await;
86
87 trace!("Waiting for inter-process lock...");
90 let store_guard = self.store_lock.spin_lock(Some(60000)).await?;
91
92 let current_db_session_bytes = self.store.get_custom_value(OIDC_SESSION_HASH_KEY).await?;
94
95 let db_hash = current_db_session_bytes.map(SessionHash);
96
97 let hash_mismatch = match (&db_hash, &*prev_hash) {
98 (None, _) => false,
99 (Some(_), None) => true,
100 (Some(db), Some(known)) => db != known,
101 };
102
103 trace!(hash_mismatch, ?prev_hash, ?db_hash);
104
105 let guard = CrossProcessRefreshLockGuard {
106 hash_guard: prev_hash,
107 _store_guard: store_guard,
108 hash_mismatch,
109 db_hash,
110 store: self.store.clone(),
111 };
112
113 Ok(guard)
114 }
115
116 pub async fn restore_session(&self, tokens: &SessionTokens) {
117 let prev_tokens_hash = compute_session_hash(tokens);
118 *self.known_session_hash.lock().await = Some(prev_tokens_hash);
119 }
120
121 pub async fn on_logout(&self) -> Result<(), CrossProcessRefreshLockError> {
122 self.store
123 .remove_custom_value(OIDC_SESSION_HASH_KEY)
124 .await
125 .map_err(CrossProcessRefreshLockError::StoreError)?;
126 *self.known_session_hash.lock().await = None;
127 Ok(())
128 }
129}
130
131pub(super) struct CrossProcessRefreshLockGuard {
132 hash_guard: OwnedMutexGuard<Option<SessionHash>>,
135
136 _store_guard: CrossProcessStoreLockGuard,
138
139 store: Store,
142
143 pub hash_mismatch: bool,
152
153 db_hash: Option<SessionHash>,
157}
158
159impl CrossProcessRefreshLockGuard {
160 fn save_in_memory(&mut self, hash: SessionHash) {
162 *self.hash_guard = Some(hash);
163 }
164
165 async fn save_in_database(
167 &self,
168 hash: &SessionHash,
169 ) -> Result<(), CrossProcessRefreshLockError> {
170 self.store.set_custom_value(OIDC_SESSION_HASH_KEY, hash.0.clone()).await?;
171 Ok(())
172 }
173
174 pub async fn save_in_memory_and_db(
178 &mut self,
179 tokens: &SessionTokens,
180 ) -> Result<(), CrossProcessRefreshLockError> {
181 let hash = compute_session_hash(tokens);
182 self.save_in_database(&hash).await?;
183 self.save_in_memory(hash);
184 Ok(())
185 }
186
187 pub async fn handle_mismatch(
190 &mut self,
191 trusted_tokens: &SessionTokens,
192 ) -> Result<(), CrossProcessRefreshLockError> {
193 let new_hash = compute_session_hash(trusted_tokens);
194 trace!("Trusted OAuth 2.0 tokens have hash {new_hash:?}; db had {:?}", self.db_hash);
195
196 if let Some(db_hash) = &self.db_hash {
197 if new_hash != *db_hash {
198 tracing::error!("error: DB and trusted disagree. Overriding in DB.");
202 self.save_in_database(&new_hash).await?;
203 }
204 }
205
206 self.save_in_memory(new_hash);
207 Ok(())
208 }
209}
210
211#[derive(Debug, Error)]
214pub enum CrossProcessRefreshLockError {
215 #[error(transparent)]
217 StoreError(#[from] CryptoStoreError),
218
219 #[error(transparent)]
221 LockError(#[from] LockStoreError),
222
223 #[error("the previous stored hash isn't a valid integer")]
225 InvalidPreviousHash,
226
227 #[error("the cross-process lock hasn't been set up with `enable_cross_process_refresh_lock")]
229 MissingLock,
230
231 #[error("reload session callback must be set with Client::set_session_callbacks() for the cross-process lock to work")]
233 MissingReloadSession,
234
235 #[error(
237 "the cross-process lock has been set up twice with `enable_cross_process_refresh_lock`"
238 )]
239 DuplicatedLock,
240}
241
#[cfg(all(test, feature = "e2e-encryption", feature = "sqlite", not(target_family = "wasm")))]
mod tests {

    use anyhow::Context as _;
    use futures_util::future::join_all;
    use matrix_sdk_base::{store::RoomLoadSettings, SessionMeta};
    use matrix_sdk_test::async_test;
    use ruma::{owned_device_id, owned_user_id};

    use super::compute_session_hash;
    use crate::{
        authentication::oauth::cross_process::SessionHash,
        test_utils::{
            client::{
                mock_prev_session_tokens_with_refresh, mock_session_tokens_with_refresh,
                oauth::mock_session, MockClientBuilder,
            },
            mocks::MatrixMockServer,
        },
        Error,
    };

    /// After `restore_session`, the session hash is known in memory and in the
    /// database, and taking the lock reports no mismatch.
    #[async_test]
    async fn test_restore_session_lock() -> Result<(), Error> {
        // Client backed by an sqlite store in a temporary directory.
        let tmp_dir = tempfile::tempdir()?;
        let client = MockClientBuilder::new("https://example.org".to_owned())
            .sqlite_store(&tmp_dir)
            .unlogged()
            .build()
            .await;

        let tokens = mock_session_tokens_with_refresh();

        client.oauth().enable_cross_process_refresh_lock("test".to_owned()).await?;

        // The reload callback hands back the same tokens; saving must never happen.
        client.set_session_callbacks(
            Box::new({
                let tokens = tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        let session_hash = compute_session_hash(&tokens);
        client
            .oauth()
            .restore_session(mock_session(tokens.clone()), RoomLoadSettings::default())
            .await?;

        assert_eq!(client.session_tokens().unwrap(), tokens);

        let oauth = client.oauth();
        let xp_manager = oauth.ctx().cross_process_token_refresh_manager.get().unwrap();

        // Scoped so the mutex guard is released before `spin_lock` below, which
        // locks the same mutex.
        {
            let known_session = xp_manager.known_session_hash.lock().await;
            assert_eq!(known_session.as_ref().unwrap(), &session_hash);
        }

        {
            let lock = xp_manager.spin_lock().await.unwrap();
            assert!(!lock.hash_mismatch);
            assert_eq!(lock.db_hash.unwrap(), session_hash);
        }

        Ok(())
    }

    /// `load_session` on a client whose tokens are already set stores their hash
    /// both in memory and in the database.
    #[async_test]
    async fn test_finish_login() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;
        server.mock_who_am_i().ok().expect(1).named("whoami").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client =
            server.client_builder().sqlite_store(&tmp_dir).registered_with_oauth().build().await;
        let oauth = client.oauth();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Set the tokens directly on the auth context.
        let session_tokens = mock_session_tokens_with_refresh();
        client.auth_ctx().set_session_tokens(session_tokens.clone());

        // Loading the session hits the mocked `whoami` endpoint (expected once).
        oauth.load_session(owned_device_id!("D3V1C31D")).await?;

        let session_meta = client.session_meta().context("should have session meta now")?;
        assert_eq!(
            *session_meta,
            SessionMeta {
                user_id: owned_user_id!("@joe:example.org"),
                device_id: owned_device_id!("D3V1C31D")
            }
        );

        {
            // Both hashes now match the tokens, so the next lock sees no mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&session_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Two concurrent refreshes on the same client result in a single call to the
    /// token endpoint (the mock expects exactly one).
    #[async_test]
    async fn test_refresh_access_token_twice() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let oauth = client.oauth();

        let next_tokens = mock_session_tokens_with_refresh();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        // Start from the "previous" tokens, so a refresh yields `next_tokens`.
        oauth
            .restore_session(
                mock_session(mock_prev_session_tokens_with_refresh()),
                RoomLoadSettings::default(),
            )
            .await?;

        // Run both refreshes concurrently; both must succeed.
        for result in join_all([oauth.refresh_access_token(), oauth.refresh_access_token()]).await {
            result?;
        }

        {
            // Both hashes now reflect the refreshed tokens.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Several clients share one sqlite store. A refresh done by one of them is
    /// detected as a mismatch by the others.
    #[async_test]
    async fn test_cross_process_concurrent_refresh() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        // The token endpoint may only be hit once across all clients.
        let oauth_server = server.oauth();
        oauth_server.mock_server_metadata().ok().expect(1..).named("server_metadata").mount().await;
        oauth_server.mock_token().ok().expect(1).named("token").mount().await;

        let prev_tokens = mock_prev_session_tokens_with_refresh();
        let next_tokens = mock_session_tokens_with_refresh();

        // First client: restored with the previous tokens.
        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

        let oauth = client.oauth();
        oauth.enable_cross_process_refresh_lock("client1".to_owned()).await?;

        oauth
            .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
            .await?;

        // Second client on the same store: lock enabled, but not restored yet.
        let unrestored_client =
            server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        let unrestored_oauth = unrestored_client.oauth();
        unrestored_oauth.enable_cross_process_refresh_lock("unrestored_client".to_owned()).await?;

        {
            // Third client on the same store: performs the single real refresh.
            let client3 = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;

            let oauth3 = client3.oauth();
            oauth3.enable_cross_process_refresh_lock("client3".to_owned()).await?;
            oauth3
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            oauth3.refresh_access_token().await?;

            assert_eq!(client3.session_tokens(), Some(next_tokens.clone()));

            // Client 3 sees only the new tokens' hash, in memory and in the DB.
            let xp_manager =
                oauth3.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
            // `guard` is dropped at the end of this scope, releasing the locks
            // for the other clients.
        }

        {
            // Restore the second client with the now-stale `prev_tokens`.
            let oauth = unrestored_oauth;

            // The reload callback returns `next_tokens`; saving must never happen.
            unrestored_client.set_session_callbacks(
                Box::new({
                    let tokens = next_tokens.clone();
                    move |_| Ok(tokens.clone())
                }),
                Box::new(|_| panic!("save_session_callback shouldn't be called here")),
            )?;

            oauth
                .restore_session(mock_session(prev_tokens.clone()), RoomLoadSettings::default())
                .await?;

            // Despite being restored with `prev_tokens`, this client ends up
            // knowing the hash of `next_tokens`.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&next_hash));
            assert!(!guard.hash_mismatch);

            drop(oauth);
            drop(unrestored_client);
        }

        {
            // The first client still knows the old hash in memory while the DB
            // holds the new one: a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let previous_hash = compute_session_hash(&prev_tokens);
            let next_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash, Some(next_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&previous_hash));
            assert!(guard.hash_mismatch);
        }

        // The reload callback for the first client returns `next_tokens`.
        client.set_session_callbacks(
            Box::new({
                let tokens = next_tokens.clone();
                move |_| Ok(tokens.clone())
            }),
            Box::new(|_| panic!("save_session_callback shouldn't be called here")),
        )?;

        // The token mock (expect(1)) was already used by client 3, so this
        // refresh presumably resolves the mismatch via the reload callback rather
        // than a new token request; verify against the mock server's expectations.
        oauth.refresh_access_token().await?;

        {
            // The first client is back in sync with the database.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            let actual_hash = compute_session_hash(&next_tokens);
            assert_eq!(guard.db_hash.as_ref(), Some(&actual_hash));
            assert_eq!(guard.hash_guard.as_ref(), Some(&actual_hash));
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Logging out clears the session hash both in memory and in the database.
    #[async_test]
    async fn test_logout() -> anyhow::Result<()> {
        let server = MatrixMockServer::new().await;

        let oauth_server = server.oauth();
        oauth_server
            .mock_server_metadata()
            .ok_https()
            .expect(1..)
            .named("server_metadata")
            .mount()
            .await;
        oauth_server.mock_revocation().ok().expect(1).named("revocation").mount().await;

        let tmp_dir = tempfile::tempdir()?;
        let client = server.client_builder().sqlite_store(&tmp_dir).unlogged().build().await;
        // The metadata mock advertises HTTPS URLs; rewrite them to reach the
        // plain-HTTP mock server.
        let oauth = client.oauth().insecure_rewrite_https_to_http();

        oauth.enable_cross_process_refresh_lock("lock".to_owned()).await?;

        let tokens = mock_session_tokens_with_refresh();
        oauth.restore_session(mock_session(tokens.clone()), RoomLoadSettings::default()).await?;

        oauth.logout().await.unwrap();

        {
            // Both hashes are gone, and an absent DB hash is not a mismatch.
            let xp_manager =
                oauth.ctx().cross_process_token_refresh_manager.get().context("must have lock")?;
            let guard = xp_manager.spin_lock().await?;
            assert!(guard.db_hash.is_none());
            assert!(guard.hash_guard.is_none());
            assert!(!guard.hash_mismatch);
        }

        Ok(())
    }

    /// Hex rendering: empty hash has no prefix; otherwise `0x` + lowercase digits.
    #[test]
    fn test_session_hash_to_hex() {
        let hash = SessionHash(vec![]);
        assert_eq!(hash.to_hex(), "");

        let hash = SessionHash(vec![0x13, 0x37, 0x42, 0xde, 0xad, 0xca, 0xfe]);
        assert_eq!(hash.to_hex(), "0x133742deadcafe");
    }
}