use openauth_oauth::oauth2::{
authorization_code_request, create_authorization_url, refresh_access_token,
refresh_access_token_request, validate_authorization_code, AuthorizationCodeRequest,
AuthorizationUrlRequest, ClientTokenRequest, OAuth2Tokens, OAuth2UserInfo, OAuthError,
OAuthFormRequest, ProviderOptions, RefreshAccessTokenRequest, SocialAuthorizationCodeRequest,
SocialAuthorizationUrlRequest, SocialIdTokenRequest, SocialOAuthProvider, SocialProviderFuture,
};
use url::Url;
use super::config::{GenericOAuthConfig, GenericOAuthTokenRequest};
use super::discovery::DiscoveryCache;
use super::user_info;
#[derive(Debug, Clone)]
pub struct GenericOAuthProvider {
config: GenericOAuthConfig,
discovery_cache: Option<DiscoveryCache>,
}
impl GenericOAuthProvider {
pub fn new(config: GenericOAuthConfig) -> Self {
Self {
config,
discovery_cache: None,
}
}
pub(crate) fn with_discovery_cache(
config: GenericOAuthConfig,
discovery_cache: DiscoveryCache,
) -> Self {
Self {
config,
discovery_cache: Some(discovery_cache),
}
}
pub fn config(&self) -> &GenericOAuthConfig {
&self.config
}
pub fn authorization_code_request(
&self,
input: SocialAuthorizationCodeRequest,
) -> Result<OAuthFormRequest, OAuthError> {
authorization_code_request(self.authorization_code_input(input)?)
}
pub fn refresh_access_token_request(
&self,
refresh_token: impl Into<String>,
) -> Result<OAuthFormRequest, OAuthError> {
refresh_access_token_request(RefreshAccessTokenRequest {
refresh_token: refresh_token.into(),
options: self.config.provider_options(),
authentication: self.config.authentication,
extra_params: self.config.token_url_params.clone(),
..RefreshAccessTokenRequest::default()
})
}
fn authorization_code_input(
&self,
input: SocialAuthorizationCodeRequest,
) -> Result<AuthorizationCodeRequest, OAuthError> {
Ok(AuthorizationCodeRequest {
code: input.code,
redirect_uri: input.redirect_uri,
options: self.config.provider_options(),
code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
device_id: input.device_id,
authentication: self.config.authentication,
headers: super::discovery::headers(&self.config.authorization_headers),
additional_params: self.config.token_url_params.clone(),
..AuthorizationCodeRequest::default()
})
}
async fn token_endpoint(&self) -> Result<String, OAuthError> {
if let Some(token_url) = &self.config.token_url {
return Ok(token_url.clone());
}
let Some(discovery_cache) = &self.discovery_cache else {
return Err(OAuthError::InvalidResponse(
"Invalid OAuth configuration. Token URL not found.".to_owned(),
));
};
let discovery = discovery_cache
.fetch(&self.config)
.await
.map_err(|error| OAuthError::InvalidResponse(error.to_string()))?
.ok_or_else(|| {
OAuthError::InvalidResponse(
"Invalid OAuth configuration. Token URL not found.".to_owned(),
)
})?;
discovery.token_endpoint.ok_or_else(|| {
OAuthError::InvalidResponse(
"Invalid OAuth configuration. Token URL not found.".to_owned(),
)
})
}
}
impl SocialOAuthProvider for GenericOAuthProvider {
fn id(&self) -> &str {
&self.config.provider_id
}
fn name(&self) -> &str {
&self.config.provider_id
}
fn provider_options(&self) -> ProviderOptions {
self.config.provider_options()
}
fn create_authorization_url(
&self,
input: SocialAuthorizationUrlRequest,
) -> Result<Url, OAuthError> {
let Some(authorization_endpoint) = self.config.authorization_url.clone() else {
return Err(OAuthError::InvalidResponse(
"Invalid OAuth configuration".to_owned(),
));
};
create_authorization_url(AuthorizationUrlRequest {
id: self.config.provider_id.clone(),
options: self.config.provider_options(),
authorization_endpoint,
redirect_uri: input.redirect_uri,
state: input.state,
code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
scopes: self.config.scopes(input.scopes),
prompt: self.config.prompt.clone(),
access_type: self.config.access_type.clone(),
response_type: self.config.response_type.clone(),
response_mode: self.config.response_mode.clone(),
login_hint: input.login_hint,
additional_params: self.config.authorization_url_params.clone(),
..AuthorizationUrlRequest::default()
})
}
fn validate_authorization_code(
&self,
input: SocialAuthorizationCodeRequest,
) -> SocialProviderFuture<'_, OAuth2Tokens> {
Box::pin(async move {
if let Some(get_token) = &self.config.get_token {
return get_token(GenericOAuthTokenRequest {
code: input.code,
redirect_uri: self
.config
.redirect_uri
.clone()
.unwrap_or(input.redirect_uri),
code_verifier: self.config.pkce.then_some(input.code_verifier).flatten(),
device_id: input.device_id,
})
.await;
}
let token_endpoint = self.token_endpoint().await?;
validate_authorization_code(ClientTokenRequest {
token_endpoint,
request: self.authorization_code_input(input)?,
})
.await
})
}
fn get_user_info(
&self,
tokens: OAuth2Tokens,
_provider_user: Option<serde_json::Value>,
) -> SocialProviderFuture<'_, Option<OAuth2UserInfo>> {
Box::pin(async move {
let user = if let Some(get_user_info) = &self.config.get_user_info {
get_user_info(tokens).await?
} else {
user_info::get_user_info(&tokens, self.config.user_info_url.as_deref()).await?
};
if let Some(map_profile) = &self.config.map_profile_to_user {
if let Some(user) = user {
return map_profile(user).await.map(Some);
}
return Ok(None);
}
Ok(user)
})
}
fn verify_id_token(&self, input: SocialIdTokenRequest) -> SocialProviderFuture<'_, bool> {
Box::pin(async move { Ok(user_info::decode_id_token_claims(&input.token).is_some()) })
}
fn refresh_access_token(
&self,
refresh_token_value: String,
) -> SocialProviderFuture<'_, OAuth2Tokens> {
Box::pin(async move {
let token_endpoint = self.token_endpoint().await?;
refresh_access_token(ClientTokenRequest {
token_endpoint,
request: RefreshAccessTokenRequest {
refresh_token: refresh_token_value,
options: self.config.provider_options(),
authentication: self.config.authentication,
extra_params: self.config.token_url_params.clone(),
..RefreshAccessTokenRequest::default()
},
})
.await
})
}
}