openauth_plugins/generic_oauth/
provider.rs1use openauth_oauth::oauth2::{
2 authorization_code_request, create_authorization_url, refresh_access_token,
3 refresh_access_token_request, validate_authorization_code, AuthorizationCodeRequest,
4 AuthorizationUrlRequest, ClientTokenRequest, OAuth2Tokens, OAuth2UserInfo, OAuthError,
5 OAuthFormRequest, ProviderOptions, RefreshAccessTokenRequest, SocialAuthorizationCodeRequest,
6 SocialAuthorizationUrlRequest, SocialIdTokenRequest, SocialOAuthProvider, SocialProviderFuture,
7};
8use url::Url;
9
10use super::config::{GenericOAuthConfig, GenericOAuthTokenRequest};
11use super::discovery::DiscoveryCache;
12use super::user_info;
13
/// OAuth2 provider whose behavior is driven entirely by a [`GenericOAuthConfig`].
///
/// Endpoints, scopes, PKCE, and the optional user-supplied hooks (token exchange,
/// user info, profile mapping, refresh, revocation, ID-token verification) all
/// come from the configuration.
#[derive(Debug, Clone)]
pub struct GenericOAuthProvider {
    /// Configuration this provider was built from.
    config: GenericOAuthConfig,
    /// Discovery cache used to resolve the token endpoint when `config.token_url`
    /// is not set. `None` for providers created with `GenericOAuthProvider::new`.
    discovery_cache: Option<DiscoveryCache>,
}
25
26impl GenericOAuthProvider {
27 pub fn new(config: GenericOAuthConfig) -> Self {
28 Self {
29 config,
30 discovery_cache: None,
31 }
32 }
33
34 pub(crate) fn with_discovery_cache(
35 config: GenericOAuthConfig,
36 discovery_cache: DiscoveryCache,
37 ) -> Self {
38 Self {
39 config,
40 discovery_cache: Some(discovery_cache),
41 }
42 }
43
44 pub fn config(&self) -> &GenericOAuthConfig {
45 &self.config
46 }
47
48 pub fn authorization_code_request(
49 &self,
50 input: SocialAuthorizationCodeRequest,
51 ) -> Result<OAuthFormRequest, OAuthError> {
52 authorization_code_request(self.authorization_code_input(input)?)
53 }
54
55 pub fn refresh_access_token_request(
56 &self,
57 refresh_token: impl Into<String>,
58 ) -> Result<OAuthFormRequest, OAuthError> {
59 refresh_access_token_request(RefreshAccessTokenRequest {
60 refresh_token: refresh_token.into(),
61 options: self.config.provider_options(),
62 authentication: self.config.authentication,
63 extra_params: self.config.token_url_params.clone(),
64 ..RefreshAccessTokenRequest::default()
65 })
66 }
67
68 fn authorization_code_input(
69 &self,
70 input: SocialAuthorizationCodeRequest,
71 ) -> Result<AuthorizationCodeRequest, OAuthError> {
72 Ok(AuthorizationCodeRequest {
73 code: input.code,
74 redirect_uri: input.redirect_uri,
75 options: self.config.provider_options(),
76 code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
77 device_id: input.device_id,
78 authentication: self.config.authentication,
79 headers: super::discovery::headers(&self.config.authorization_headers),
80 additional_params: self.config.token_url_params.clone(),
81 ..AuthorizationCodeRequest::default()
82 })
83 }
84
85 async fn token_endpoint(&self) -> Result<String, OAuthError> {
86 if let Some(token_url) = &self.config.token_url {
87 return Ok(token_url.clone());
88 }
89 let Some(discovery_cache) = &self.discovery_cache else {
90 return Err(OAuthError::InvalidResponse(
91 "Invalid OAuth configuration. Token URL not found.".to_owned(),
92 ));
93 };
94 let discovery = discovery_cache
95 .fetch(&self.config)
96 .await
97 .map_err(|error| OAuthError::InvalidResponse(error.to_string()))?
98 .ok_or_else(|| {
99 OAuthError::InvalidResponse(
100 "Invalid OAuth configuration. Token URL not found.".to_owned(),
101 )
102 })?;
103 discovery.token_endpoint.ok_or_else(|| {
104 OAuthError::InvalidResponse(
105 "Invalid OAuth configuration. Token URL not found.".to_owned(),
106 )
107 })
108 }
109}
110
111impl SocialOAuthProvider for GenericOAuthProvider {
112 fn id(&self) -> &str {
113 &self.config.provider_id
114 }
115
116 fn name(&self) -> &str {
117 &self.config.provider_id
118 }
119
120 fn provider_options(&self) -> ProviderOptions {
121 self.config.provider_options()
122 }
123
124 fn create_authorization_url(
125 &self,
126 input: SocialAuthorizationUrlRequest,
127 ) -> Result<Url, OAuthError> {
128 let Some(authorization_endpoint) = self.config.authorization_url.clone() else {
129 return Err(OAuthError::InvalidResponse(
130 "Invalid OAuth configuration".to_owned(),
131 ));
132 };
133 create_authorization_url(AuthorizationUrlRequest {
134 id: self.config.provider_id.clone(),
135 options: self.config.provider_options(),
136 authorization_endpoint,
137 redirect_uri: input.redirect_uri,
138 state: input.state,
139 code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
140 scopes: self.config.scopes(input.scopes),
141 prompt: self.config.prompt.clone(),
142 access_type: self.config.access_type.clone(),
143 response_type: self.config.response_type.clone(),
144 response_mode: self.config.response_mode.clone(),
145 login_hint: input.login_hint,
146 additional_params: self.config.authorization_url_params.clone(),
147 ..AuthorizationUrlRequest::default()
148 })
149 }
150
151 fn validate_authorization_code(
152 &self,
153 input: SocialAuthorizationCodeRequest,
154 ) -> SocialProviderFuture<'_, OAuth2Tokens> {
155 Box::pin(async move {
156 if let Some(get_token) = &self.config.get_token {
157 return get_token(GenericOAuthTokenRequest {
158 code: input.code,
159 redirect_uri: self
160 .config
161 .redirect_uri
162 .clone()
163 .unwrap_or(input.redirect_uri),
164 code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
165 device_id: input.device_id,
166 })
167 .await;
168 }
169 let token_endpoint = self.token_endpoint().await?;
170 validate_authorization_code(ClientTokenRequest {
171 token_endpoint,
172 request: self.authorization_code_input(input)?,
173 })
174 .await
175 })
176 }
177
178 fn get_user_info(
179 &self,
180 tokens: OAuth2Tokens,
181 _provider_user: Option<serde_json::Value>,
182 ) -> SocialProviderFuture<'_, Option<OAuth2UserInfo>> {
183 Box::pin(async move {
184 let user = if let Some(get_user_info) = &self.config.get_user_info {
185 get_user_info(tokens).await?
186 } else {
187 user_info::get_user_info(&tokens, self.config.user_info_url.as_deref()).await?
188 };
189 if let Some(map_profile) = &self.config.map_profile_to_user {
190 if let Some(user) = user {
191 return map_profile(user).await.map(Some);
192 }
193 return Ok(None);
194 }
195 Ok(user)
196 })
197 }
198
199 fn verify_id_token(&self, input: SocialIdTokenRequest) -> SocialProviderFuture<'_, bool> {
200 Box::pin(async move {
201 if let Some(verify_id_token) = &self.config.verify_id_token {
202 return verify_id_token(input).await;
203 }
204 Ok(false)
205 })
206 }
207
208 fn refresh_access_token(
209 &self,
210 refresh_token_value: String,
211 ) -> SocialProviderFuture<'_, OAuth2Tokens> {
212 Box::pin(async move {
213 if let Some(refresh_access_token) = &self.config.refresh_access_token {
214 return refresh_access_token(refresh_token_value).await;
215 }
216 let token_endpoint = self.token_endpoint().await?;
217 refresh_access_token(ClientTokenRequest {
218 token_endpoint,
219 request: RefreshAccessTokenRequest {
220 refresh_token: refresh_token_value,
221 options: self.config.provider_options(),
222 authentication: self.config.authentication,
223 extra_params: self.config.token_url_params.clone(),
224 ..RefreshAccessTokenRequest::default()
225 },
226 })
227 .await
228 })
229 }
230
231 fn revoke_token(&self, token: String) -> SocialProviderFuture<'_, ()> {
232 Box::pin(async move {
233 if let Some(revoke_token) = &self.config.revoke_token {
234 return revoke_token(token).await;
235 }
236 Err(OAuthError::InvalidResponse(format!(
237 "provider does not support token revocation for token `{token}`"
238 )))
239 })
240 }
241}