matrix_sdk/authentication/
mod.rs

// Copyright 2023 Kévin Commaille
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Types and functions related to authentication in Matrix.
16
17use std::{fmt, sync::Arc};
18
19use matrix_sdk_base::{SessionMeta, locks::Mutex};
20use serde::{Deserialize, Serialize};
21use tokio::sync::{Mutex as AsyncMutex, OnceCell, broadcast};
22
23pub mod matrix;
24pub mod oauth;
25
26use self::{
27    matrix::MatrixAuth,
28    oauth::{OAuth, OAuthAuthData, OAuthCtx},
29};
30use crate::{Client, RefreshTokenError, SessionChange};
31
32/// The tokens for a user session.
33#[derive(Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
34#[allow(missing_debug_implementations)]
35pub struct SessionTokens {
36    /// The access token used for this session.
37    pub access_token: String,
38
39    /// The token used for refreshing the access token, if any.
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    pub refresh_token: Option<String>,
42}
43
44#[cfg(not(tarpaulin_include))]
45impl fmt::Debug for SessionTokens {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        f.debug_struct("SessionTokens").finish_non_exhaustive()
48    }
49}
50
51/// The tokens for a user session and their state.
52pub(crate) struct SessionTokensState {
53    /// The inner tokens.
54    inner: SessionTokens,
55
56    /// Whether the access token is expired.
57    ///
58    /// We keep track of this information here, rather than dropping the access
59    /// token, because we still want to make most requests with the expired
60    /// access token to try to refresh it, or wait for it to be refreshed. If we
61    /// make a request without the access token we will get the wrong error.
62    access_token_expired: bool,
63}
64
65pub(crate) type SessionCallbackError = Box<dyn std::error::Error + Send + Sync>;
66
67#[cfg(not(target_family = "wasm"))]
68pub(crate) type SaveSessionCallback =
69    dyn Fn(Client) -> Result<(), SessionCallbackError> + Send + Sync;
70#[cfg(target_family = "wasm")]
71pub(crate) type SaveSessionCallback = dyn Fn(Client) -> Result<(), SessionCallbackError>;
72
73#[cfg(not(target_family = "wasm"))]
74pub(crate) type ReloadSessionCallback =
75    dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError> + Send + Sync;
76#[cfg(target_family = "wasm")]
77pub(crate) type ReloadSessionCallback =
78    dyn Fn(Client) -> Result<SessionTokens, SessionCallbackError>;
79
80/// All the data relative to authentication, and that must be shared between a
81/// client and all its children.
82pub(crate) struct AuthCtx {
83    oauth: OAuthCtx,
84
85    /// Whether to try to refresh the access token automatically when an
86    /// `M_UNKNOWN_TOKEN` error is encountered.
87    pub(crate) handle_refresh_tokens: bool,
88
89    /// Lock making sure we're only doing one token refresh at a time.
90    refresh_token_lock: Arc<AsyncMutex<Result<(), RefreshTokenError>>>,
91
92    /// Session change publisher. Allows the subscriber to handle changes to the
93    /// session such as logging out when the access token is invalid or
94    /// persisting updates to the access/refresh tokens.
95    pub(crate) session_change_sender: broadcast::Sender<SessionChange>,
96
97    /// Authentication data to keep in memory.
98    pub(crate) auth_data: OnceCell<AuthData>,
99
100    /// The current session tokens and their state.
101    tokens: OnceCell<Mutex<SessionTokensState>>,
102
103    /// A callback called whenever we need an absolute source of truth for the
104    /// current session tokens.
105    ///
106    /// This is required only in multiple processes setups.
107    pub(crate) reload_session_callback: OnceCell<Box<ReloadSessionCallback>>,
108
109    /// A callback to save a session back into the app's secure storage.
110    ///
111    /// This is always called, independently of the presence of a cross-process
112    /// lock.
113    ///
114    /// Internal invariant: this must be called only after `set_session_tokens`
115    /// has been called, not before.
116    pub(crate) save_session_callback: OnceCell<Box<SaveSessionCallback>>,
117}
118
119impl AuthCtx {
120    /// Construct a new `AuthCtx` with the given settings.
121    pub(crate) fn new(handle_refresh_tokens: bool, allow_insecure_oauth: bool) -> Self {
122        Self {
123            handle_refresh_tokens,
124            refresh_token_lock: Arc::new(AsyncMutex::new(Ok(()))),
125            session_change_sender: broadcast::Sender::new(1),
126            auth_data: OnceCell::default(),
127            tokens: OnceCell::default(),
128            reload_session_callback: OnceCell::default(),
129            save_session_callback: OnceCell::default(),
130            oauth: OAuthCtx::new(allow_insecure_oauth),
131        }
132    }
133
134    /// The current session tokens.
135    pub(crate) fn session_tokens(&self) -> Option<SessionTokens> {
136        Some(self.tokens.get()?.lock().inner.clone())
137    }
138
139    /// The current access token.
140    pub(crate) fn access_token(&self) -> Option<String> {
141        Some(self.tokens.get()?.lock().inner.access_token.clone())
142    }
143
144    /// Whether we have a valid session token.
145    pub(crate) fn has_valid_access_token(&self) -> bool {
146        self.tokens.get().is_some_and(|tokens| !tokens.lock().access_token_expired)
147    }
148
149    /// Set the current session tokens.
150    pub(crate) fn set_session_tokens(&self, session_tokens: SessionTokens) {
151        let session_tokens = SessionTokensState {
152            inner: session_tokens,
153            // We just got the tokens, so we assume that they are not expired.
154            access_token_expired: false,
155        };
156
157        if let Some(tokens) = self.tokens.get() {
158            *tokens.lock() = session_tokens;
159        } else {
160            let _ = self.tokens.set(Mutex::new(session_tokens));
161        }
162    }
163
164    /// Set the given access token as expired.
165    ///
166    /// We take the value of the access token to make sure that we don't mark
167    /// the wrong access token as expired.
168    pub(crate) fn set_access_token_expired(&self, access_token: &str) {
169        if let Some(tokens) = self.tokens.get() {
170            let mut tokens = tokens.lock();
171
172            if tokens.inner.access_token == access_token {
173                tokens.access_token_expired = true;
174            }
175        }
176    }
177}
178
179/// An enum over all the possible authentication APIs.
180#[derive(Debug, Clone)]
181#[non_exhaustive]
182pub enum AuthApi {
183    /// The native Matrix authentication API.
184    Matrix(MatrixAuth),
185
186    /// The OAuth 2.0 API.
187    OAuth(OAuth),
188}
189
190/// A user session using one of the available authentication APIs.
191#[derive(Debug, Clone)]
192#[non_exhaustive]
193pub enum AuthSession {
194    /// A session using the native Matrix authentication API.
195    Matrix(matrix::MatrixSession),
196
197    /// A session using the OAuth 2.0 API.
198    OAuth(Box<oauth::OAuthSession>),
199}
200
201impl AuthSession {
202    /// Get the matrix user information of this session.
203    pub fn meta(&self) -> &SessionMeta {
204        match self {
205            AuthSession::Matrix(session) => &session.meta,
206            AuthSession::OAuth(session) => &session.user.meta,
207        }
208    }
209
210    /// Take the matrix user information of this session.
211    pub fn into_meta(self) -> SessionMeta {
212        match self {
213            AuthSession::Matrix(session) => session.meta,
214            AuthSession::OAuth(session) => session.user.meta,
215        }
216    }
217
218    /// Get the access token of this session.
219    pub fn access_token(&self) -> &str {
220        match self {
221            AuthSession::Matrix(session) => &session.tokens.access_token,
222            AuthSession::OAuth(session) => &session.user.tokens.access_token,
223        }
224    }
225
226    /// Get the refresh token of this session.
227    pub fn get_refresh_token(&self) -> Option<&str> {
228        match self {
229            AuthSession::Matrix(session) => session.tokens.refresh_token.as_deref(),
230            AuthSession::OAuth(session) => session.user.tokens.refresh_token.as_deref(),
231        }
232    }
233}
234
235impl From<matrix::MatrixSession> for AuthSession {
236    fn from(session: matrix::MatrixSession) -> Self {
237        Self::Matrix(session)
238    }
239}
240
241impl From<oauth::OAuthSession> for AuthSession {
242    fn from(session: oauth::OAuthSession) -> Self {
243        Self::OAuth(session.into())
244    }
245}
246
247/// Data for an authentication API.
248#[derive(Debug)]
249pub(crate) enum AuthData {
250    /// Data for the native Matrix authentication API.
251    Matrix,
252    /// Data for the OAuth 2.0 API.
253    OAuth(OAuthAuthData),
254}