1use async_trait::async_trait;
2use conrad_core::{
3 auth::Authenticator, database::DatabaseAdapter, errors::AuthError, Key, NaiveKeyType, User,
4 UserData, UserId,
5};
6use errors::OAuthError;
7use oauth2::url::Url;
8
9pub mod errors;
10pub mod providers;
11mod utils;
12
/// OAuth application settings: the client id, the client secret and the list
/// of scopes to request.
///
/// Note: this type only implements `Clone`, not `Debug`.
#[derive(Clone)]
pub struct OAuthConfig {
    client_id: String,
    client_secret: String,
    scope: Vec<String>,
}

impl OAuthConfig {
    /// Bundles the given client id, client secret and scopes into a config.
    ///
    /// All arguments are stored as-is; no validation is performed.
    pub fn new(client_id: String, client_secret: String, scope: Vec<String>) -> Self {
        OAuthConfig {
            scope,
            client_secret,
            client_id,
        }
    }
}
29
30#[async_trait]
31pub trait OAuthProvider {
32 type Config;
33 type UserInfo;
34
35 fn new(config: Self::Config) -> Self;
36 fn get_authorization_url(&self) -> RedirectInfo;
37 async fn validate_callback(
38 &self,
39 code: String,
40 ) -> Result<ValidationResult<Self::UserInfo>, OAuthError>;
41}
42
43#[derive(Clone, Debug)]
44pub struct RedirectInfo {
45 pub url: Url,
46 pub csrf_token: String,
47}
48
49#[derive(Clone)]
50pub struct ValidationResult<T> {
51 pub tokens: Tokens,
52 pub provider_user: T,
53 pub auth_info: AuthInfo,
54}
55
56#[derive(Clone)]
57pub struct AuthInfo {
58 provider_id: &'static str,
59 provider_user_id: String,
60}
61
62impl AuthInfo {
63 pub fn into_auth_connector<D, U>(self, auth: &Authenticator<D, U>) -> AuthConnector<D, U> {
64 AuthConnector {
65 auth_info: self,
66 auth,
67 }
68 }
69}
70
71#[derive(Clone)]
72pub struct AuthConnector<'a, D, U> {
73 auth_info: AuthInfo,
74 auth: &'a Authenticator<D, U>,
75}
76
77impl<'a, D, U> AuthConnector<'a, D, U>
78where
79 D: DatabaseAdapter<U>,
80{
81 pub async fn get_existing_user(&self) -> Result<Option<U>, AuthError> {
82 let res = {
83 let key = self
84 .auth
85 .use_key(
86 self.auth_info.provider_id,
87 &self.auth_info.provider_user_id,
88 None,
89 )
90 .await?;
91 self.auth.get_user(&key.user_id).await
92 };
93 match res {
94 Ok(e) => Ok(Some(e)),
95 Err(AuthError::InvalidKeyId) => Ok(None),
96 Err(err) => Err(err),
97 }
98 }
99
100 pub async fn create_persistent_key(&self, user_id: UserId) -> Result<Key, AuthError> {
101 let user_data = UserData::new(
102 self.auth_info.provider_id.to_string(),
103 self.auth_info.provider_user_id.clone(),
104 None,
105 );
106 self.auth
107 .create_key(user_id, user_data, &NaiveKeyType::Persistent)
108 .await
109 }
110
111 pub async fn create_user(&self, attributes: U) -> Result<User<U>, AuthError> {
112 let user_data = UserData::new(
113 self.auth_info.provider_id.to_string(),
114 self.auth_info.provider_user_id.clone(),
115 None,
116 );
117 self.auth.create_user(user_data, attributes).await
118 }
119}
120
/// Token data obtained from a provider during callback validation.
///
/// NOTE(review): the derived `Debug` prints the raw access/refresh tokens;
/// avoid logging this value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Tokens {
    /// Access token issued by the provider.
    pub access_token: String,
    /// Refresh/expiry data, when the provider supplied it.
    pub expiration_info: Option<ExpirationInfo>,
    /// Granted scopes, when the provider reported them.
    pub scope: Option<Vec<String>>,
}

/// Refresh token plus expiry for an access token.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpirationInfo {
    /// Token used to obtain a new access token.
    pub refresh_token: String,
    /// Token lifetime. NOTE(review): presumably seconds, as for OAuth 2.0's
    /// `expires_in` — confirm against the providers that populate it.
    pub expires_in: i64,
}