eternaltwin_core/
oauth.rs

1use crate::auth::EtwinOauthAccessTokenKey;
2use crate::core::Instant;
3use crate::password::{Password, PasswordHash};
4use crate::twinoid::TwinoidUserId;
5use crate::types::WeakError;
6use crate::user::{ShortUser, UserId, UserIdRef};
7use async_trait::async_trait;
8#[cfg(feature = "serde")]
9use eternaltwin_serde_tools::{Deserialize, Serialize};
10use once_cell::sync::Lazy;
11use regex::Regex;
12use std::error::Error;
13use std::ops::Deref;
14use std::str::FromStr;
15use thiserror::Error;
16use url::Url;
17
18declare_new_uuid! {
19  pub struct OauthClientId(Uuid);
20  pub type ParseError = OauthClientIdParseError;
21  const SQL_NAME = "oauth_client_id";
22}
23
24declare_new_string! {
25  pub struct OauthClientKey(String);
26  pub type ParseError = OauthClientKeyParseError;
27  const PATTERN = r"^[a-z_][a-z0-9_]{1,31}@clients$";
28  const SQL_NAME = "oauth_client_key";
29}
30
31impl OauthClientKey {
32  /// Extract the app name and channel from the conventional oauth client key format
33  /// TODO: avoid this kind of key parsing and have dedicated fields
34  ///
35  /// ```
36  /// use eternaltwin_core::oauth::OauthClientKey;
37  ///
38  /// let key: OauthClientKey = "neoparc_production@clients".parse().unwrap();
39  /// assert_eq!(("neoparc", Some("production")), key.parts());
40  /// let key: OauthClientKey = "myhordes@clients".parse().unwrap();
41  /// assert_eq!(("myhordes", None), key.parts());
42  /// ```
43  pub fn parts(&self) -> (&str, Option<&str>) {
44    static CONVENTIONAL_PARTS: Lazy<Regex> =
45      Lazy::new(|| Regex::new(r"^([a-z0-9]+)(?:_([a-z0-9_]+))?@clients$").unwrap());
46    let parts = CONVENTIONAL_PARTS
47      .captures(self.as_str())
48      .expect("pattern always matches");
49    let app = parts.get(1).expect("group 1 always exists").as_str();
50    let channel = parts.get(2).map(|m| m.as_str());
51    (app, channel)
52  }
53}
54
55declare_new_string! {
56  pub struct RfcOauthAccessTokenKey(String);
57  pub type ParseError = RfcOauthAccessTokenKeyParseError;
58  const PATTERN = r"^.+$";
59  const SQL_NAME = "rfc_oauth_access_token_key";
60}
61
62declare_new_string! {
63  pub struct RfcOauthRefreshTokenKey(String);
64  pub type ParseError = RfcOauthRefreshTokenKeyParseError;
65  const PATTERN = r"^.+$";
66  const SQL_NAME = "rfc_oauth_refresh_token_key";
67}
68
69declare_new_enum!(
70  pub enum RfcOauthResponseType {
71    #[str("code")]
72    Code,
73    #[str("token")]
74    Token,
75  }
76  pub type ParseError = RfcOauthResponseTypeParseError;
77);
78
79declare_new_enum!(
80  pub enum RfcOauthGrantType {
81    #[str("authorization_code")]
82    AuthorizationCode,
83  }
84  pub type ParseError = RfcOauthGrantTypeParseError;
85);
86
87declare_new_enum!(
88  pub enum RfcOauthTokenType {
89    // TODO: Case-insensitive deserialization
90    #[str("Bearer")]
91    Bearer,
92  }
93  pub type ParseError = RfcOauthTokenTypeParseError;
94);
95
96#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
97#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
98pub struct TwinoidAccessToken {
99  pub key: RfcOauthAccessTokenKey,
100  #[cfg_attr(feature = "serde", serde(rename = "ctime"))]
101  pub created_at: Instant,
102  #[cfg_attr(feature = "serde", serde(rename = "atime"))]
103  pub accessed_at: Instant,
104  #[cfg_attr(feature = "serde", serde(rename = "expiration_time"))]
105  pub expires_at: Instant,
106  pub twinoid_user_id: TwinoidUserId,
107}
108
109#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
110#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
111pub struct TwinoidRefreshToken {
112  pub key: RfcOauthRefreshTokenKey,
113  #[cfg_attr(feature = "serde", serde(rename = "ctime"))]
114  pub created_at: Instant,
115  #[cfg_attr(feature = "serde", serde(rename = "atime"))]
116  pub accessed_at: Instant,
117  pub twinoid_user_id: TwinoidUserId,
118}
119
120#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
121#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
122pub struct StoredOauthAccessToken {
123  pub key: EtwinOauthAccessTokenKey,
124  #[cfg_attr(feature = "serde", serde(rename = "ctime"))]
125  pub created_at: Instant,
126  #[cfg_attr(feature = "serde", serde(rename = "atime"))]
127  pub accessed_at: Instant,
128  #[cfg_attr(feature = "serde", serde(rename = "expiration_time"))]
129  pub expires_at: Instant,
130  pub user: UserIdRef,
131  pub client: OauthClientIdRef,
132}
133
134// TODO: Fix this to not be a mix of `Etwin*` and `Rfc*` types.
135#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
136#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
137pub struct OauthAccessToken {
138  pub token_type: RfcOauthTokenType,
139  pub access_token: EtwinOauthAccessTokenKey,
140  pub expires_in: i64,
141  #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
142  pub refresh_token: Option<RfcOauthRefreshTokenKey>,
143}
144
145#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
146#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
147pub struct RfcOauthAccessToken {
148  pub token_type: RfcOauthTokenType,
149  pub access_token: RfcOauthAccessTokenKey,
150  pub expires_in: i64,
151  #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
152  pub refresh_token: Option<RfcOauthRefreshTokenKey>,
153}
154
155#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
156#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
157#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
158pub struct ShortOauthClient {
159  pub id: OauthClientId,
160  pub key: Option<OauthClientKey>,
161  pub display_name: OauthClientDisplayName,
162}
163
164impl From<SimpleOauthClient> for ShortOauthClient {
165  fn from(client: SimpleOauthClient) -> Self {
166    Self {
167      id: client.id,
168      key: client.key,
169      display_name: client.display_name,
170    }
171  }
172}
173
174declare_new_string! {
175  pub struct OauthClientDisplayName(String);
176  pub type ParseError = OauthClientDisplayNameParseError;
177  const PATTERN = r"^[A-Za-z_ ()-]{2,32}$";
178  const SQL_NAME = "oauth_client_display_name";
179}
180
181#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
182#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
183#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
184pub struct OauthClient {
185  pub id: OauthClientId,
186  pub key: Option<OauthClientKey>,
187  pub display_name: OauthClientDisplayName,
188  pub app_uri: Url,
189  pub callback_uri: Url,
190  pub owner: Option<ShortUser>,
191}
192
193#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
194#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
195#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
196pub struct SimpleOauthClient {
197  pub id: OauthClientId,
198  pub key: Option<OauthClientKey>,
199  pub display_name: OauthClientDisplayName,
200  pub app_uri: Url,
201  pub callback_uri: Url,
202  pub owner: Option<UserIdRef>,
203}
204
205#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
206#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
207#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
208pub struct SimpleOauthClientWithSecret {
209  pub id: OauthClientId,
210  pub key: Option<OauthClientKey>,
211  pub display_name: OauthClientDisplayName,
212  pub app_uri: Url,
213  pub callback_uri: Url,
214  pub owner: Option<UserIdRef>,
215  pub secret: PasswordHash,
216}
217
218#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
219#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
220pub struct UpsertSystemClientOptions {
221  pub key: OauthClientKey,
222  pub display_name: OauthClientDisplayName,
223  pub app_uri: Url,
224  pub callback_uri: Url,
225  pub secret: Password,
226}
227
228#[derive(Error, Debug)]
229pub enum RawUpsertSystemOauthClientError {
230  #[error(transparent)]
231  Other(WeakError),
232}
233
234impl RawUpsertSystemOauthClientError {
235  pub fn other<E: Error>(e: E) -> Self {
236    Self::Other(WeakError::wrap(e))
237  }
238}
239
240#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
241#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
242pub struct OauthClientIdRef {
243  pub id: OauthClientId,
244}
245
246impl From<OauthClientId> for OauthClientIdRef {
247  fn from(id: OauthClientId) -> Self {
248    Self { id }
249  }
250}
251
252#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
253#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
254pub struct OauthClientKeyRef {
255  pub key: OauthClientKey,
256}
257
258impl From<OauthClientKey> for OauthClientKeyRef {
259  fn from(key: OauthClientKey) -> Self {
260    Self { key }
261  }
262}
263
264#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
265#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
266pub enum OauthClientRef {
267  Id(OauthClientIdRef),
268  Key(OauthClientKeyRef),
269}
270
271impl FromStr for OauthClientRef {
272  type Err = ();
273
274  fn from_str(input: &str) -> Result<Self, Self::Err> {
275    if let Ok(client_key) = OauthClientKey::from_str(input) {
276      Ok(OauthClientRef::Key(client_key.into()))
277    } else if let Ok(id) = OauthClientId::from_str(input) {
278      Ok(OauthClientRef::Id(id.into()))
279    } else {
280      Err(())
281    }
282  }
283}
284
285declare_new_string! {
286  pub struct EtwinOauthScopesString(String);
287  pub type ParseError = EtwinOauthScopesStringParseError;
288  const PATTERN = r"^.{0,100}$";
289}
290
/// Set of OAuth scopes. `base` is the only scope declared here.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EtwinOauthScopes {
  /// Whether the `base` scope is part of the set.
  pub base: bool,
}

impl EtwinOauthScopes {
  /// Names of the scopes in the set: `["base"]` when `base` is set, empty otherwise.
  pub fn strings(&self) -> Vec<String> {
    self.base.then(|| "base".to_string()).into_iter().collect()
  }
}

impl Default for EtwinOauthScopes {
  /// The default set has `base` enabled.
  fn default() -> Self {
    EtwinOauthScopes { base: true }
  }
}

impl FromStr for EtwinOauthScopes {
  type Err = ();

  /// Parse a space-separated list of scope names.
  ///
  /// Parsing starts from the default set. Each name is trimmed and empty names are
  /// skipped; any name other than `base` yields `Err(())`.
  fn from_str(input: &str) -> Result<Self, Self::Err> {
    let parsed = EtwinOauthScopes::default();
    for raw in input.split(' ') {
      let scope = raw.trim();
      if scope.is_empty() {
        continue;
      }
      if scope != "base" {
        // Unknown scope
        return Err(());
      }
      debug_assert!(parsed.base);
    }
    Ok(parsed)
  }
}
328
329#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
330#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
331pub struct GetOauthClientOptions {
332  pub r#ref: OauthClientRef,
333}
334
335#[derive(Error, Debug)]
336pub enum RawGetOauthClientError {
337  #[error("oauth client not found: {0:?}")]
338  NotFound(OauthClientRef),
339  #[error(transparent)]
340  Other(WeakError),
341}
342
343impl RawGetOauthClientError {
344  pub fn other<E: Error>(e: E) -> Self {
345    Self::Other(WeakError::wrap(e))
346  }
347}
348
349#[derive(Error, Debug)]
350pub enum RawGetOauthClientWithSecretError {
351  #[error("oauth client not found: {0:?}")]
352  NotFound(OauthClientRef),
353  #[error(transparent)]
354  Other(WeakError),
355}
356
357impl RawGetOauthClientWithSecretError {
358  pub fn other<E: Error>(e: E) -> Self {
359    Self::Other(WeakError::wrap(e))
360  }
361}
362
363#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
364#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
365pub struct VerifyClientSecretOptions {
366  pub r#ref: OauthClientRef,
367  pub secret: Password,
368}
369
370#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
371#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
372pub struct CreateStoredAccessTokenOptions {
373  pub key: EtwinOauthAccessTokenKey,
374  pub ctime: Instant,
375  pub expiration_time: Instant,
376  pub user: UserIdRef,
377  pub client: OauthClientIdRef,
378}
379
380#[derive(Error, Debug)]
381pub enum RawCreateAccessTokenError {
382  #[error(transparent)]
383  Other(WeakError),
384}
385
386impl RawCreateAccessTokenError {
387  pub fn other<E: Error>(e: E) -> Self {
388    Self::Other(WeakError::wrap(e))
389  }
390}
391
392#[derive(Error, Debug)]
393pub enum RawGetAccessTokenError {
394  #[error("oauth access token not found")]
395  NotFound,
396  #[error(transparent)]
397  Other(WeakError),
398}
399
400impl RawGetAccessTokenError {
401  pub fn other<E: Error>(e: E) -> Self {
402    Self::Other(WeakError::wrap(e))
403  }
404}
405
406#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
407#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
408pub struct GetOauthAccessTokenOptions {
409  pub key: EtwinOauthAccessTokenKey,
410  pub touch_accessed_at: bool,
411}
412
413#[async_trait]
414pub trait OauthProviderStore: Send + Sync {
415  async fn upsert_system_client(
416    &self,
417    options: &UpsertSystemClientOptions,
418  ) -> Result<SimpleOauthClient, RawUpsertSystemOauthClientError>;
419
420  async fn get_client(&self, options: &GetOauthClientOptions) -> Result<SimpleOauthClient, RawGetOauthClientError>;
421
422  async fn get_client_with_secret(
423    &self,
424    options: &GetOauthClientOptions,
425  ) -> Result<SimpleOauthClientWithSecret, RawGetOauthClientWithSecretError>;
426
427  async fn create_access_token(
428    &self,
429    options: &CreateStoredAccessTokenOptions,
430  ) -> Result<StoredOauthAccessToken, RawCreateAccessTokenError>;
431
432  async fn get_access_token(
433    &self,
434    options: &GetOauthAccessTokenOptions,
435  ) -> Result<StoredOauthAccessToken, RawGetAccessTokenError>;
436}
437
438/// Like [`Deref`], but the target has the bound [`OauthProviderStore`]
439pub trait OauthProviderStoreRef: Send + Sync {
440  type OauthProviderStore: OauthProviderStore + ?Sized;
441
442  fn oauth_provider_store(&self) -> &Self::OauthProviderStore;
443}
444
445impl<TyRef> OauthProviderStoreRef for TyRef
446where
447  TyRef: Deref + Send + Sync,
448  TyRef::Target: OauthProviderStore,
449{
450  type OauthProviderStore = TyRef::Target;
451
452  fn oauth_provider_store(&self) -> &Self::OauthProviderStore {
453    self.deref()
454  }
455}
456
457#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
458#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
459pub struct EtwinOauthStateClaims {
460  /**
461   * String used for XSRF protection.
462   */
463  #[cfg_attr(feature = "serde", serde(rename = "rfp"))]
464  pub request_forgery_protection: String,
465  #[cfg_attr(feature = "serde", serde(rename = "a"))]
466  pub action: EtwinOauthStateAction,
467  #[cfg_attr(feature = "serde", serde(rename = "iat", with = "serde_posix_timestamp"))]
468  pub issued_at: Instant,
469  #[cfg_attr(feature = "serde", serde(rename = "as"))]
470  pub authorization_server: String,
471  #[cfg_attr(feature = "serde", serde(rename = "exp", with = "serde_posix_timestamp"))]
472  pub expiration_time: Instant,
473}
474
475#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
476#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
477#[cfg_attr(feature = "serde", serde(tag = "type"))]
478pub enum EtwinOauthStateAction {
479  #[cfg_attr(feature = "serde", serde(rename = "Login"))]
480  Login,
481  #[cfg_attr(feature = "serde", serde(rename = "Link"))]
482  Link { user_id: UserId },
483}
484
/// Serde helpers for `#[serde(with = "serde_posix_timestamp")]` fields: an [`Instant`] is
/// written as its POSIX timestamp and read back from an `i64`.
#[cfg(feature = "serde")]
pub mod serde_posix_timestamp {
  use crate::core::Instant;
  use serde::{Deserialize, Deserializer, Serialize, Serializer};

  /// Serialize `value` as the result of `Instant::into_posix_timestamp`.
  pub fn serialize<S: Serializer>(value: &Instant, serializer: S) -> Result<S::Ok, S::Error> {
    let timestamp = value.into_posix_timestamp();
    timestamp.serialize(serializer)
  }

  /// Deserialize an `i64` and convert it with `Instant::from_posix_timestamp`.
  pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Instant, D::Error> {
    let timestamp = i64::deserialize(deserializer)?;
    Ok(Instant::from_posix_timestamp(timestamp))
  }
}
503}