Skip to main content

openauth_plugins/generic_oauth/
provider.rs

1use openauth_oauth::oauth2::{
2    authorization_code_request, create_authorization_url, refresh_access_token,
3    refresh_access_token_request, validate_authorization_code, AuthorizationCodeRequest,
4    AuthorizationUrlRequest, ClientTokenRequest, OAuth2Tokens, OAuth2UserInfo, OAuthError,
5    OAuthFormRequest, ProviderOptions, RefreshAccessTokenRequest, SocialAuthorizationCodeRequest,
6    SocialAuthorizationUrlRequest, SocialIdTokenRequest, SocialOAuthProvider, SocialProviderFuture,
7};
8use url::Url;
9
10use super::config::{GenericOAuthConfig, GenericOAuthTokenRequest};
11use super::discovery::DiscoveryCache;
12use super::user_info;
13
14/// Social provider implementation used by the generic OAuth plugin.
15///
16/// `SocialOAuthProvider::create_authorization_url` is synchronous, so providers that only
17/// define `discovery_url` cannot resolve their authorization endpoint through this trait method.
18/// Use the plugin routes (`/sign-in/oauth2`, `/oauth2/callback/:providerId`, `/oauth2/link`) as
19/// the canonical flow for discovery-only generic providers.
20#[derive(Debug, Clone)]
21pub struct GenericOAuthProvider {
22    config: GenericOAuthConfig,
23    discovery_cache: Option<DiscoveryCache>,
24}
25
26impl GenericOAuthProvider {
27    pub fn new(config: GenericOAuthConfig) -> Self {
28        Self {
29            config,
30            discovery_cache: None,
31        }
32    }
33
34    pub(crate) fn with_discovery_cache(
35        config: GenericOAuthConfig,
36        discovery_cache: DiscoveryCache,
37    ) -> Self {
38        Self {
39            config,
40            discovery_cache: Some(discovery_cache),
41        }
42    }
43
44    pub fn config(&self) -> &GenericOAuthConfig {
45        &self.config
46    }
47
48    pub fn authorization_code_request(
49        &self,
50        input: SocialAuthorizationCodeRequest,
51    ) -> Result<OAuthFormRequest, OAuthError> {
52        authorization_code_request(self.authorization_code_input(input)?)
53    }
54
55    pub fn refresh_access_token_request(
56        &self,
57        refresh_token: impl Into<String>,
58    ) -> Result<OAuthFormRequest, OAuthError> {
59        refresh_access_token_request(RefreshAccessTokenRequest {
60            refresh_token: refresh_token.into(),
61            options: self.config.provider_options(),
62            authentication: self.config.authentication,
63            extra_params: self.config.token_url_params.clone(),
64            ..RefreshAccessTokenRequest::default()
65        })
66    }
67
68    fn authorization_code_input(
69        &self,
70        input: SocialAuthorizationCodeRequest,
71    ) -> Result<AuthorizationCodeRequest, OAuthError> {
72        Ok(AuthorizationCodeRequest {
73            code: input.code,
74            redirect_uri: input.redirect_uri,
75            options: self.config.provider_options(),
76            code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
77            device_id: input.device_id,
78            authentication: self.config.authentication,
79            headers: super::discovery::headers(&self.config.authorization_headers),
80            additional_params: self.config.token_url_params.clone(),
81            ..AuthorizationCodeRequest::default()
82        })
83    }
84
85    async fn token_endpoint(&self) -> Result<String, OAuthError> {
86        if let Some(token_url) = &self.config.token_url {
87            return Ok(token_url.clone());
88        }
89        let Some(discovery_cache) = &self.discovery_cache else {
90            return Err(OAuthError::InvalidResponse(
91                "Invalid OAuth configuration. Token URL not found.".to_owned(),
92            ));
93        };
94        let discovery = discovery_cache
95            .fetch(&self.config)
96            .await
97            .map_err(|error| OAuthError::InvalidResponse(error.to_string()))?
98            .ok_or_else(|| {
99                OAuthError::InvalidResponse(
100                    "Invalid OAuth configuration. Token URL not found.".to_owned(),
101                )
102            })?;
103        discovery.token_endpoint.ok_or_else(|| {
104            OAuthError::InvalidResponse(
105                "Invalid OAuth configuration. Token URL not found.".to_owned(),
106            )
107        })
108    }
109}
110
111impl SocialOAuthProvider for GenericOAuthProvider {
112    fn id(&self) -> &str {
113        &self.config.provider_id
114    }
115
116    fn name(&self) -> &str {
117        &self.config.provider_id
118    }
119
120    fn provider_options(&self) -> ProviderOptions {
121        self.config.provider_options()
122    }
123
124    fn create_authorization_url(
125        &self,
126        input: SocialAuthorizationUrlRequest,
127    ) -> Result<Url, OAuthError> {
128        let Some(authorization_endpoint) = self.config.authorization_url.clone() else {
129            return Err(OAuthError::InvalidResponse(
130                "Invalid OAuth configuration".to_owned(),
131            ));
132        };
133        create_authorization_url(AuthorizationUrlRequest {
134            id: self.config.provider_id.clone(),
135            options: self.config.provider_options(),
136            authorization_endpoint,
137            redirect_uri: input.redirect_uri,
138            state: input.state,
139            code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
140            scopes: self.config.scopes(input.scopes),
141            prompt: self.config.prompt.clone(),
142            access_type: self.config.access_type.clone(),
143            response_type: self.config.response_type.clone(),
144            response_mode: self.config.response_mode.clone(),
145            login_hint: input.login_hint,
146            additional_params: self.config.authorization_url_params.clone(),
147            ..AuthorizationUrlRequest::default()
148        })
149    }
150
151    fn validate_authorization_code(
152        &self,
153        input: SocialAuthorizationCodeRequest,
154    ) -> SocialProviderFuture<'_, OAuth2Tokens> {
155        Box::pin(async move {
156            if let Some(get_token) = &self.config.get_token {
157                return get_token(GenericOAuthTokenRequest {
158                    code: input.code,
159                    redirect_uri: self
160                        .config
161                        .redirect_uri
162                        .clone()
163                        .unwrap_or(input.redirect_uri),
164                    code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
165                    device_id: input.device_id,
166                })
167                .await;
168            }
169            let token_endpoint = self.token_endpoint().await?;
170            validate_authorization_code(ClientTokenRequest {
171                token_endpoint,
172                request: self.authorization_code_input(input)?,
173            })
174            .await
175        })
176    }
177
178    fn get_user_info(
179        &self,
180        tokens: OAuth2Tokens,
181        _provider_user: Option<serde_json::Value>,
182    ) -> SocialProviderFuture<'_, Option<OAuth2UserInfo>> {
183        Box::pin(async move {
184            let user = if let Some(get_user_info) = &self.config.get_user_info {
185                get_user_info(tokens).await?
186            } else {
187                user_info::get_user_info(&tokens, self.config.user_info_url.as_deref()).await?
188            };
189            if let Some(map_profile) = &self.config.map_profile_to_user {
190                if let Some(user) = user {
191                    return map_profile(user).await.map(Some);
192                }
193                return Ok(None);
194            }
195            Ok(user)
196        })
197    }
198
199    fn verify_id_token(&self, input: SocialIdTokenRequest) -> SocialProviderFuture<'_, bool> {
200        Box::pin(async move {
201            if let Some(verify_id_token) = &self.config.verify_id_token {
202                return verify_id_token(input).await;
203            }
204            Ok(false)
205        })
206    }
207
208    fn refresh_access_token(
209        &self,
210        refresh_token_value: String,
211    ) -> SocialProviderFuture<'_, OAuth2Tokens> {
212        Box::pin(async move {
213            if let Some(refresh_access_token) = &self.config.refresh_access_token {
214                return refresh_access_token(refresh_token_value).await;
215            }
216            let token_endpoint = self.token_endpoint().await?;
217            refresh_access_token(ClientTokenRequest {
218                token_endpoint,
219                request: RefreshAccessTokenRequest {
220                    refresh_token: refresh_token_value,
221                    options: self.config.provider_options(),
222                    authentication: self.config.authentication,
223                    extra_params: self.config.token_url_params.clone(),
224                    ..RefreshAccessTokenRequest::default()
225                },
226            })
227            .await
228        })
229    }
230
231    fn revoke_token(&self, token: String) -> SocialProviderFuture<'_, ()> {
232        Box::pin(async move {
233            if let Some(revoke_token) = &self.config.revoke_token {
234                return revoke_token(token).await;
235            }
236            Err(OAuthError::InvalidResponse(format!(
237                "provider does not support token revocation for token `{token}`"
238            )))
239        })
240    }
241}