conrad_oauth/
lib.rs

1use async_trait::async_trait;
2use conrad_core::{
3    auth::Authenticator, database::DatabaseAdapter, errors::AuthError, Key, NaiveKeyType, User,
4    UserData, UserId,
5};
6use errors::OAuthError;
7use oauth2::url::Url;
8
9pub mod errors;
10pub mod providers;
11mod utils;
12
/// OAuth client credentials together with the scopes to request.
#[derive(Clone)]
pub struct OAuthConfig {
    /// OAuth client id.
    client_id: String,
    /// OAuth client secret.
    client_secret: String,
    /// Scopes to request.
    scope: Vec<String>,
}

impl OAuthConfig {
    /// Bundles the client id, client secret and requested scopes into a config.
    pub fn new(client_id: String, client_secret: String, scope: Vec<String>) -> Self {
        OAuthConfig {
            scope,
            client_secret,
            client_id,
        }
    }
}
29
30#[async_trait]
31pub trait OAuthProvider {
32    type Config;
33    type UserInfo;
34
35    fn new(config: Self::Config) -> Self;
36    fn get_authorization_url(&self) -> RedirectInfo;
37    async fn validate_callback(
38        &self,
39        code: String,
40    ) -> Result<ValidationResult<Self::UserInfo>, OAuthError>;
41}
42
43#[derive(Clone, Debug)]
44pub struct RedirectInfo {
45    pub url: Url,
46    pub csrf_token: String,
47}
48
49#[derive(Clone)]
50pub struct ValidationResult<T> {
51    pub tokens: Tokens,
52    pub provider_user: T,
53    pub auth_info: AuthInfo,
54}
55
/// Identifies the provider account that completed the OAuth flow.
///
/// Derives `Debug` so that values containing it (such as `ValidationResult`)
/// can be inspected; none of its fields are secrets.
#[derive(Clone, Debug)]
pub struct AuthInfo {
    /// Provider identifier, used as the provider id of the authentication key.
    provider_id: &'static str,
    /// The user's identifier on the provider's side.
    provider_user_id: String,
}
61
62impl AuthInfo {
63    pub fn into_auth_connector<D, U>(self, auth: &Authenticator<D, U>) -> AuthConnector<D, U> {
64        AuthConnector {
65            auth_info: self,
66            auth,
67        }
68    }
69}
70
71#[derive(Clone)]
72pub struct AuthConnector<'a, D, U> {
73    auth_info: AuthInfo,
74    auth: &'a Authenticator<D, U>,
75}
76
77impl<'a, D, U> AuthConnector<'a, D, U>
78where
79    D: DatabaseAdapter<U>,
80{
81    pub async fn get_existing_user(&self) -> Result<Option<U>, AuthError> {
82        let res = {
83            let key = self
84                .auth
85                .use_key(
86                    self.auth_info.provider_id,
87                    &self.auth_info.provider_user_id,
88                    None,
89                )
90                .await?;
91            self.auth.get_user(&key.user_id).await
92        };
93        match res {
94            Ok(e) => Ok(Some(e)),
95            Err(AuthError::InvalidKeyId) => Ok(None),
96            Err(err) => Err(err),
97        }
98    }
99
100    pub async fn create_persistent_key(&self, user_id: UserId) -> Result<Key, AuthError> {
101        let user_data = UserData::new(
102            self.auth_info.provider_id.to_string(),
103            self.auth_info.provider_user_id.clone(),
104            None,
105        );
106        self.auth
107            .create_key(user_id, user_data, &NaiveKeyType::Persistent)
108            .await
109    }
110
111    pub async fn create_user(&self, attributes: U) -> Result<User<U>, AuthError> {
112        let user_data = UserData::new(
113            self.auth_info.provider_id.to_string(),
114            self.auth_info.provider_user_id.clone(),
115            None,
116        );
117        self.auth.create_user(user_data, attributes).await
118    }
119}
120
/// Tokens obtained from the provider.
///
/// `Debug` is implemented by hand so that the access token is redacted
/// instead of being written into logs.
#[derive(Clone)]
pub struct Tokens {
    /// The access token.
    pub access_token: String,
    /// Refresh token and expiry, or `None` when not available.
    pub expiration_info: Option<ExpirationInfo>,
    /// Scopes associated with the token, if any.
    pub scope: Option<Vec<String>>,
}

impl std::fmt::Debug for Tokens {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Tokens")
            .field("access_token", &"[redacted]")
            .field("expiration_info", &self.expiration_info)
            .field("scope", &self.scope)
            .finish()
    }
}

/// Refresh token plus the access token's expiry.
///
/// `Debug` is implemented by hand so that the refresh token is redacted.
#[derive(Clone)]
pub struct ExpirationInfo {
    /// Refresh token.
    pub refresh_token: String,
    /// Access-token lifetime; presumably seconds, as in OAuth `expires_in` (TODO confirm).
    pub expires_in: i64,
}

impl std::fmt::Debug for ExpirationInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ExpirationInfo")
            .field("refresh_token", &"[redacted]")
            .field("expires_in", &self.expires_in)
            .finish()
    }
}