etwin_core/
oauth.rs

1use crate::auth::EtwinOauthAccessTokenKey;
2use crate::core::Instant;
3use crate::password::{Password, PasswordHash};
4use crate::twinoid::TwinoidUserId;
5use crate::types::AnyError;
6use crate::user::{ShortUser, UserId, UserIdRef};
7use async_trait::async_trait;
8use auto_impl::auto_impl;
9#[cfg(feature = "serde")]
10use etwin_serde_tools::{Deserialize, Serialize};
11use std::str::FromStr;
12use thiserror::Error;
13use url::Url;
14
// Identifier of an OAuth client: a UUID-backed newtype generated by the
// project's `declare_new_uuid!` macro. Parse failures are reported as
// `OauthClientIdParseError`.
// NOTE(review): `SQL_NAME` presumably names the matching database type/domain;
// confirm against the macro definition.
declare_new_uuid! {
  pub struct OauthClientId(Uuid);
  pub type ParseError = OauthClientIdParseError;
  const SQL_NAME = "oauth_client_id";
}
20
// Human-readable key of an OAuth client, e.g. `my_app@clients`.
// `PATTERN` requires a lowercase identifier of 2 to 32 characters (first char
// from `[a-z_]`, then 1..=31 chars from `[a-z0-9_]`) followed by the literal
// suffix `@clients`. Parse failures are reported as `OauthClientKeyParseError`.
declare_new_string! {
  pub struct OauthClientKey(String);
  pub type ParseError = OauthClientKeyParseError;
  const PATTERN = r"^[a-z_][a-z0-9_]{1,31}@clients$";
  const SQL_NAME = "oauth_client_key";
}
27
// Opaque OAuth access token value, as stored in `TwinoidAccessToken::key` and
// `RfcOauthAccessToken::access_token`.
// `PATTERN` (`^.+$`) only rejects the empty string (and, assuming default regex
// flags, values containing a line break); the content is otherwise unchecked.
declare_new_string! {
  pub struct RfcOauthAccessTokenKey(String);
  pub type ParseError = RfcOauthAccessTokenKeyParseError;
  const PATTERN = r"^.+$";
  const SQL_NAME = "rfc_oauth_access_token_key";
}
34
// Opaque OAuth refresh token value, as stored in `TwinoidRefreshToken::key`
// and the optional `refresh_token` of the access token responses.
// `PATTERN` (`^.+$`) only rejects the empty string (and, assuming default regex
// flags, values containing a line break).
declare_new_string! {
  pub struct RfcOauthRefreshTokenKey(String);
  pub type ParseError = RfcOauthRefreshTokenKeyParseError;
  const PATTERN = r"^.+$";
  const SQL_NAME = "rfc_oauth_refresh_token_key";
}
41
// Supported values of the OAuth `response_type` parameter; each variant maps
// to the string given in its `#[str(...)]` attribute.
declare_new_enum!(
  pub enum RfcOauthResponseType {
    #[str("code")]
    Code,
    #[str("token")]
    Token,
  }
  pub type ParseError = RfcOauthResponseTypeParseError;
);
51
// Supported values of the OAuth `grant_type` parameter. Only
// `authorization_code` is declared; any other string fails to parse with
// `RfcOauthGrantTypeParseError`.
declare_new_enum!(
  pub enum RfcOauthGrantType {
    #[str("authorization_code")]
    AuthorizationCode,
  }
  pub type ParseError = RfcOauthGrantTypeParseError;
);
59
// Supported values of the OAuth `token_type` field. Only `Bearer` is declared.
// NOTE(review): RFC 6749 §5.1 says `token_type` is case insensitive. Matching
// here presumably follows the exact `#[str]` spelling, so e.g. `bearer` may be
// rejected (see the TODO below).
declare_new_enum!(
  pub enum RfcOauthTokenType {
    // TODO: Case-insensitive deserialization
    #[str("Bearer")]
    Bearer,
  }
  pub type ParseError = RfcOauthTokenTypeParseError;
);
68
/// Access token tied to a Twinoid user, with its lifecycle timestamps.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TwinoidAccessToken {
  /// Raw token value.
  pub key: RfcOauthAccessTokenKey,
  /// Creation time; serialized as `ctime`.
  #[cfg_attr(feature = "serde", serde(rename = "ctime"))]
  pub created_at: Instant,
  /// Last access time; serialized as `atime`.
  #[cfg_attr(feature = "serde", serde(rename = "atime"))]
  pub accessed_at: Instant,
  /// Expiration time; serialized as `expiration_time`.
  #[cfg_attr(feature = "serde", serde(rename = "expiration_time"))]
  pub expires_at: Instant,
  /// Twinoid user this token belongs to.
  pub twinoid_user_id: TwinoidUserId,
}
81
/// Refresh token tied to a Twinoid user.
///
/// Unlike [`TwinoidAccessToken`], no expiration time is recorded.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TwinoidRefreshToken {
  /// Raw token value.
  pub key: RfcOauthRefreshTokenKey,
  /// Creation time; serialized as `ctime`.
  #[cfg_attr(feature = "serde", serde(rename = "ctime"))]
  pub created_at: Instant,
  /// Last access time; serialized as `atime`.
  #[cfg_attr(feature = "serde", serde(rename = "atime"))]
  pub accessed_at: Instant,
  /// Twinoid user this token belongs to.
  pub twinoid_user_id: TwinoidUserId,
}
92
/// Access token issued by this service, as persisted and returned by an
/// [`OauthProviderStore`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StoredOauthAccessToken {
  /// Token value.
  pub key: EtwinOauthAccessTokenKey,
  /// Creation time; serialized as `ctime`.
  #[cfg_attr(feature = "serde", serde(rename = "ctime"))]
  pub created_at: Instant,
  /// Last access time; serialized as `atime`.
  #[cfg_attr(feature = "serde", serde(rename = "atime"))]
  pub accessed_at: Instant,
  /// Expiration time; serialized as `expiration_time`.
  #[cfg_attr(feature = "serde", serde(rename = "expiration_time"))]
  pub expires_at: Instant,
  /// User the token was issued for.
  pub user: UserIdRef,
  /// Client the token was issued to.
  pub client: OauthClientIdRef,
}
106
// TODO: Fix this to not be a mix of `Etwin*` and `Rfc*` types.
/// Access token response for a token issued by this service.
///
/// The access token uses [`EtwinOauthAccessTokenKey`] while the refresh token
/// uses [`RfcOauthRefreshTokenKey`]. That mix is what the TODO above refers to.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OauthAccessToken {
  /// Token type (only `Bearer` exists).
  pub token_type: RfcOauthTokenType,
  /// Issued access token.
  pub access_token: EtwinOauthAccessTokenKey,
  /// Token lifetime. NOTE(review): presumably in seconds, as in RFC 6749 §5.1; confirm.
  pub expires_in: i64,
  /// Optional refresh token; omitted from the serialized output when `None`.
  #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
  pub refresh_token: Option<RfcOauthRefreshTokenKey>,
}
117
/// Access token response using opaque RFC-style token keys.
///
/// Same shape as [`OauthAccessToken`], but `access_token` is an
/// [`RfcOauthAccessTokenKey`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RfcOauthAccessToken {
  /// Token type (only `Bearer` exists).
  pub token_type: RfcOauthTokenType,
  /// Issued access token.
  pub access_token: RfcOauthAccessTokenKey,
  /// Token lifetime. NOTE(review): presumably in seconds, as in RFC 6749 §5.1; confirm.
  pub expires_in: i64,
  /// Optional refresh token; omitted from the serialized output when `None`.
  #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
  pub refresh_token: Option<RfcOauthRefreshTokenKey>,
}
127
/// Minimal view of an OAuth client: identity and display name only.
///
/// With the `serde` feature, it serializes with an extra `"type": "OauthClient"`
/// field (internally tagged struct).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ShortOauthClient {
  /// Client id.
  pub id: OauthClientId,
  /// Client key, if the client has one.
  pub key: Option<OauthClientKey>,
  /// Name shown to users.
  pub display_name: OauthClientDisplayName,
}
136
137impl From<SimpleOauthClient> for ShortOauthClient {
138  fn from(client: SimpleOauthClient) -> Self {
139    Self {
140      id: client.id,
141      key: client.key,
142      display_name: client.display_name,
143    }
144  }
145}
146
// Display name of an OAuth client, shown to users.
// `PATTERN` allows 2 to 32 characters drawn from ASCII letters, `_`, space,
// `(`, `)` and `-`. Digits are NOT allowed.
declare_new_string! {
  pub struct OauthClientDisplayName(String);
  pub type ParseError = OauthClientDisplayNameParseError;
  const PATTERN = r"^[A-Za-z_ ()-]{2,32}$";
  const SQL_NAME = "oauth_client_display_name";
}
153
/// OAuth client with its URIs and, if any, a resolved owner ([`ShortUser`]).
///
/// With the `serde` feature, it serializes with an extra `"type": "OauthClient"` field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OauthClient {
  /// Client id.
  pub id: OauthClientId,
  /// Client key, if the client has one.
  pub key: Option<OauthClientKey>,
  /// Name shown to users.
  pub display_name: OauthClientDisplayName,
  /// Application URI.
  pub app_uri: Url,
  /// Callback (redirect) URI.
  pub callback_uri: Url,
  /// Owning user, if any.
  pub owner: Option<ShortUser>,
}
165
/// OAuth client as returned by an [`OauthProviderStore`]. The owner is only a
/// [`UserIdRef`], not a resolved user.
///
/// With the `serde` feature, it serializes with an extra `"type": "OauthClient"` field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SimpleOauthClient {
  /// Client id.
  pub id: OauthClientId,
  /// Client key, if the client has one.
  pub key: Option<OauthClientKey>,
  /// Name shown to users.
  pub display_name: OauthClientDisplayName,
  /// Application URI.
  pub app_uri: Url,
  /// Callback (redirect) URI.
  pub callback_uri: Url,
  /// Reference to the owning user, if any.
  pub owner: Option<UserIdRef>,
}
177
/// Same as [`SimpleOauthClient`], plus the hash of the client secret.
///
/// Returned by [`OauthProviderStore::get_client_with_secret`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "type", rename = "OauthClient"))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SimpleOauthClientWithSecret {
  /// Client id.
  pub id: OauthClientId,
  /// Client key, if the client has one.
  pub key: Option<OauthClientKey>,
  /// Name shown to users.
  pub display_name: OauthClientDisplayName,
  /// Application URI.
  pub app_uri: Url,
  /// Callback (redirect) URI.
  pub callback_uri: Url,
  /// Reference to the owning user, if any.
  pub owner: Option<UserIdRef>,
  /// Stored hash of the client secret.
  pub secret: PasswordHash,
}
190
/// Input of [`OauthProviderStore::upsert_system_client`].
///
/// A system client always has a key (unlike the `Option<OauthClientKey>` on
/// client records) and has no owner field.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UpsertSystemClientOptions {
  /// Client key.
  pub key: OauthClientKey,
  /// Name shown to users.
  pub display_name: OauthClientDisplayName,
  /// Application URI.
  pub app_uri: Url,
  /// Callback (redirect) URI.
  pub callback_uri: Url,
  /// Client secret as a `Password`. NOTE(review): presumably hashed by the
  /// store (records expose a `PasswordHash`); confirm.
  pub secret: Password,
}
200
/// Error for upserting a system client.
///
/// NOTE(review): `OauthProviderStore::upsert_system_client` currently returns
/// `AnyError`, not this type.
#[derive(Error, Debug)]
pub enum UpsertSystemClientError {
  /// Any failure; its `Display` is forwarded unchanged (`#[error(transparent)]`).
  #[error(transparent)]
  Other(AnyError),
}
206
/// Reference to an OAuth client by id.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OauthClientIdRef {
  /// Referenced client id.
  pub id: OauthClientId,
}
212
213impl From<OauthClientId> for OauthClientIdRef {
214  fn from(id: OauthClientId) -> Self {
215    Self { id }
216  }
217}
218
/// Reference to an OAuth client by key.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct OauthClientKeyRef {
  /// Referenced client key.
  pub key: OauthClientKey,
}
224
225impl From<OauthClientKey> for OauthClientKeyRef {
226  fn from(key: OauthClientKey) -> Self {
227    Self { key }
228  }
229}
230
/// Reference to an OAuth client, either by id or by key.
///
/// No serde container attributes are set, so the derived serde format uses
/// serde's default (externally tagged) enum representation.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OauthClientRef {
  /// Lookup by client id.
  Id(OauthClientIdRef),
  /// Lookup by client key.
  Key(OauthClientKeyRef),
}
237
238impl FromStr for OauthClientRef {
239  type Err = ();
240
241  fn from_str(input: &str) -> Result<Self, Self::Err> {
242    if let Ok(client_key) = OauthClientKey::from_str(input) {
243      Ok(OauthClientRef::Key(client_key.into()))
244    } else if let Ok(id) = OauthClientId::from_str(input) {
245      Ok(OauthClientRef::Id(id.into()))
246    } else {
247      Err(())
248    }
249  }
250}
251
// Raw OAuth scope string, before it is parsed into `EtwinOauthScopes`.
// `PATTERN` accepts 0 to 100 characters, so the empty string is valid (and,
// assuming default regex flags, line breaks are not). Unlike the other
// string newtypes in this module, no `SQL_NAME` is declared.
declare_new_string! {
  pub struct EtwinOauthScopesString(String);
  pub type ParseError = EtwinOauthScopesStringParseError;
  const PATTERN = r"^.{0,100}$";
}
257
/// Set of OAuth scopes understood by this service; currently only `base`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EtwinOauthScopes {
  /// Whether the `base` scope is granted (`true` in `Default`).
  pub base: bool,
}
263
264impl EtwinOauthScopes {
265  pub fn strings(&self) -> Vec<String> {
266    if self.base {
267      vec!["base".to_string()]
268    } else {
269      Vec::new()
270    }
271  }
272}
273
274impl Default for EtwinOauthScopes {
275  fn default() -> Self {
276    Self { base: true }
277  }
278}
279
280impl FromStr for EtwinOauthScopes {
281  type Err = ();
282
283  fn from_str(input: &str) -> Result<Self, Self::Err> {
284    let scopes = input.split(' ').map(str::trim).filter(|s| !s.is_empty());
285    let parsed = EtwinOauthScopes::default();
286    for scope in scopes {
287      match scope {
288        "base" => debug_assert!(parsed.base),
289        _ => return Err(()), // Unknown scope
290      }
291    }
292    Ok(parsed)
293  }
294}
295
/// Input of [`OauthProviderStore::get_client`] and
/// [`OauthProviderStore::get_client_with_secret`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct GetOauthClientOptions {
  /// Client to look up, by id or key. (`r#` is needed because `ref` is a keyword.)
  pub r#ref: OauthClientRef,
}
301
/// Input for checking a client secret against a client's stored secret.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VerifyClientSecretOptions {
  /// Client whose secret is checked, by id or key.
  pub r#ref: OauthClientRef,
  /// Candidate secret.
  pub secret: Password,
}
308
/// Input of [`OauthProviderStore::create_access_token`].
///
/// Field names follow the serialized names of [`StoredOauthAccessToken`]
/// (`ctime`, `expiration_time`) rather than its Rust field names.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CreateStoredAccessTokenOptions {
  /// Token value to store.
  pub key: EtwinOauthAccessTokenKey,
  /// Creation time.
  pub ctime: Instant,
  /// Expiration time.
  pub expiration_time: Instant,
  /// User the token is issued for.
  pub user: UserIdRef,
  /// Client the token is issued to.
  pub client: OauthClientIdRef,
}
318
/// Input of [`OauthProviderStore::get_access_token`].
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct GetOauthAccessTokenOptions {
  /// Token to look up.
  pub key: EtwinOauthAccessTokenKey,
  /// NOTE(review): presumably asks the store to update the token's
  /// `accessed_at` during the lookup; confirm against implementations.
  pub touch_accessed_at: bool,
}
325
/// Error returned by [`OauthProviderStore::get_client`].
#[derive(Error, Debug)]
pub enum GetOauthClientError {
  /// No client matches the given reference; the reference is kept and
  /// `Debug`-formatted in the message.
  #[error("oauth client not found: {0:?}")]
  NotFound(OauthClientRef),
  /// Any other failure; its `Display` is forwarded unchanged.
  #[error(transparent)]
  Other(AnyError),
}
333
/// Storage backend for the OAuth provider side: clients and the access tokens
/// this service issues.
///
/// `#[auto_impl(&, Arc)]` also implements the trait for `&T` and `Arc<T>`
/// whenever `T: OauthProviderStore`. Implementors must be `Send + Sync`.
#[async_trait]
#[auto_impl(&, Arc)]
pub trait OauthProviderStore: Send + Sync {
  /// Creates or updates a system client and returns the resulting record.
  ///
  /// NOTE(review): returns `AnyError` rather than the `UpsertSystemClientError`
  /// type declared in this module.
  async fn upsert_system_client(&self, options: &UpsertSystemClientOptions) -> Result<SimpleOauthClient, AnyError>;

  /// Fetches a client by id or key.
  ///
  /// The error type distinguishes `GetOauthClientError::NotFound` from other
  /// failures.
  async fn get_client(&self, options: &GetOauthClientOptions) -> Result<SimpleOauthClient, GetOauthClientError>;

  /// Like `get_client`, but the result also carries the secret hash.
  ///
  /// NOTE(review): the error type is `AnyError`, so callers cannot tell a
  /// missing client apart from other failures, unlike with `get_client`.
  async fn get_client_with_secret(
    &self,
    options: &GetOauthClientOptions,
  ) -> Result<SimpleOauthClientWithSecret, AnyError>;

  /// Persists a new access token and returns the stored record.
  async fn create_access_token(
    &self,
    options: &CreateStoredAccessTokenOptions,
  ) -> Result<StoredOauthAccessToken, AnyError>;

  /// Looks up an access token by key.
  async fn get_access_token(&self, options: &GetOauthAccessTokenOptions) -> Result<StoredOauthAccessToken, AnyError>;
}
353
/// Claims carried in the OAuth `state` value.
///
/// Serialized field names are shortened (`rfp`, `a`, `iat`, `as`, `exp`), and
/// both timestamps are encoded as integer POSIX timestamps via
/// `serde_posix_timestamp`.
/// NOTE(review): `iat`/`exp` are JWT claim names; presumably this is encoded
/// as a JWT; confirm.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EtwinOauthStateClaims {
  /// String used for XSRF protection; serialized as `rfp`.
  #[cfg_attr(feature = "serde", serde(rename = "rfp"))]
  pub request_forgery_protection: String,
  /// Action to perform once the flow completes; serialized as `a`.
  #[cfg_attr(feature = "serde", serde(rename = "a"))]
  pub action: EtwinOauthStateAction,
  /// Issue time; serialized as `iat` (POSIX timestamp).
  #[cfg_attr(feature = "serde", serde(rename = "iat", with = "serde_posix_timestamp"))]
  pub issued_at: Instant,
  /// Authorization server; serialized as `as`.
  #[cfg_attr(feature = "serde", serde(rename = "as"))]
  pub authorization_server: String,
  /// Expiration time; serialized as `exp` (POSIX timestamp).
  #[cfg_attr(feature = "serde", serde(rename = "exp", with = "serde_posix_timestamp"))]
  pub expiration_time: Instant,
}
371
/// Action recorded in the OAuth state.
///
/// Internally tagged on `type`: `{"type": "Login"}` or
/// `{"type": "Link", "user_id": ...}`. The explicit `rename`s equal the variant
/// names. Presumably they are there to pin the wire format if the variants are
/// renamed.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", serde(tag = "type"))]
pub enum EtwinOauthStateAction {
  /// Log in.
  #[cfg_attr(feature = "serde", serde(rename = "Login"))]
  Login,
  /// Link to an existing user. NOTE(review): presumably links the external
  /// account to `user_id`; confirm.
  #[cfg_attr(feature = "serde", serde(rename = "Link"))]
  Link { user_id: UserId },
}
381
#[cfg(feature = "serde")]
pub mod serde_posix_timestamp {
  //! `#[serde(with = "serde_posix_timestamp")]` adapter that encodes an
  //! `Instant` as a POSIX timestamp.
  use crate::core::Instant;
  use serde::{Deserialize, Deserializer, Serialize, Serializer};

  /// Serializes `value` as the result of `Instant::into_posix_timestamp`
  /// (read back as an `i64` by `deserialize`).
  pub fn serialize<S>(value: &Instant, serializer: S) -> Result<S::Ok, S::Error>
  where
    S: Serializer,
  {
    let timestamp = value.into_posix_timestamp();
    Serialize::serialize(&timestamp, serializer)
  }

  /// Reads an `i64` POSIX timestamp and converts it with
  /// `Instant::from_posix_timestamp`.
  pub fn deserialize<'de, D>(deserializer: D) -> Result<Instant, D::Error>
  where
    D: Deserializer<'de>,
  {
    let timestamp = i64::deserialize(deserializer)?;
    Ok(Instant::from_posix_timestamp(timestamp))
  }
}