use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet, EndpointSet,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl,
basic::BasicClient,
};
use crate::error::{AuthError, OAuthError};
use crate::oauth::{OAuthProviderConfig, OAuthTokens};
type ConfiguredClient = BasicClient<
EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet, >;
pub struct OAuthAuthorization {
pub authorize_url: String,
pub csrf_state: String,
pub pkce_verifier: String,
}
pub fn build_authorization(config: &OAuthProviderConfig) -> Result<OAuthAuthorization, AuthError> {
let client = build_client(config)?;
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let mut auth_request = client.authorize_url(CsrfToken::new_random);
auth_request = auth_request.set_pkce_challenge(pkce_challenge);
for scope in &config.scopes {
auth_request = auth_request.add_scope(Scope::new(scope.clone()));
}
let (auth_url, csrf_token) = auth_request.url();
Ok(OAuthAuthorization {
authorize_url: auth_url.to_string(),
csrf_state: csrf_token.secret().clone(),
pkce_verifier: pkce_verifier.secret().clone(),
})
}
pub async fn exchange_code(
config: &OAuthProviderConfig,
code: &str,
pkce_verifier: &str,
) -> Result<OAuthTokens, AuthError> {
let client = build_client(config)?;
let http_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|_| AuthError::OAuth(OAuthError::ExchangeFailed))?;
let token_response = client
.exchange_code(AuthorizationCode::new(code.to_string()))
.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier.to_string()))
.request_async(&http_client)
.await
.map_err(|_| AuthError::OAuth(OAuthError::ExchangeFailed))?;
let access_token = token_response.access_token().secret().clone();
let refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
let expires_in = token_response
.expires_in()
.and_then(|d| time::Duration::try_from(d).ok());
let scope = token_response.scopes().map(|scopes| {
scopes
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(" ")
});
Ok(OAuthTokens {
access_token: Some(access_token),
refresh_token,
expires_in,
scope,
})
}
pub async fn refresh_access_token(
config: &OAuthProviderConfig,
refresh_token_str: &str,
) -> Result<OAuthTokens, AuthError> {
let client = build_client(config)?;
let http_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.map_err(|_| AuthError::OAuth(OAuthError::RefreshFailed))?;
let token_response = client
.exchange_refresh_token(&RefreshToken::new(refresh_token_str.to_string()))
.request_async(&http_client)
.await
.map_err(|_| AuthError::OAuth(OAuthError::RefreshFailed))?;
let access_token = token_response.access_token().secret().clone();
let refresh_token = token_response.refresh_token().map(|t| t.secret().clone());
let expires_in = token_response
.expires_in()
.and_then(|d| time::Duration::try_from(d).ok());
let scope = token_response.scopes().map(|scopes| {
scopes
.iter()
.map(|s| s.as_str())
.collect::<Vec<_>>()
.join(" ")
});
Ok(OAuthTokens {
access_token: Some(access_token),
refresh_token,
expires_in,
scope,
})
}
fn build_client(config: &OAuthProviderConfig) -> Result<ConfiguredClient, AuthError> {
let client = BasicClient::new(ClientId::new(config.client_id.clone()))
.set_client_secret(ClientSecret::new(config.client_secret.clone()))
.set_auth_uri(AuthUrl::new(config.auth_url.clone()).map_err(|e| {
AuthError::OAuth(OAuthError::Misconfigured {
message: format!("invalid auth_url: {e}"),
})
})?)
.set_token_uri(TokenUrl::new(config.token_url.clone()).map_err(|e| {
AuthError::OAuth(OAuthError::Misconfigured {
message: format!("invalid token_url: {e}"),
})
})?)
.set_redirect_uri(RedirectUrl::new(config.redirect_url.clone()).map_err(|e| {
AuthError::OAuth(OAuthError::Misconfigured {
message: format!("invalid redirect_url: {e}"),
})
})?);
Ok(client)
}