use std::sync::{Arc, RwLock};
use base64::Engine;
use log::debug;
use rand::{distributions::Alphanumeric, Rng};
use reqwest::{IntoUrl, Method, Url};
use serde::Deserialize;
use sha2::Digest;
use super::{
private, ACCOUNTS_API_TOKEN_ENDPOINT, ACCOUNTS_AUTHORIZE_ENDPOINT, PKCE_VERIFIER_LENGTH, RANDOM_STATE_LENGTH,
};
#[cfg(feature = "async")]
use super::{private::AsyncClient, AccessTokenRefreshAsync};
#[cfg(feature = "sync")]
use super::{private::SyncClient, AccessTokenRefreshSync};
use crate::{
error::{Error, Result},
model::error::AuthenticationErrorKind,
scope::ToScopesString,
};
/// Authorization code flow user client backed by the asynchronous HTTP client.
#[cfg(feature = "async")]
pub type AsyncAuthorizationCodeUserClient = AuthorizationCodeUserClient<AsyncClient>;
/// Authorization code flow user client backed by the blocking HTTP client.
#[cfg(feature = "sync")]
pub type SyncAuthorizationCodeUserClient = AuthorizationCodeUserClient<SyncClient>;
/// Not-yet-finalized authorization code flow client backed by the asynchronous HTTP client.
#[cfg(feature = "async")]
pub type AsyncIncompleteAuthorizationCodeUserClient = IncompleteAuthorizationCodeUserClient<AsyncClient>;
/// Not-yet-finalized authorization code flow client backed by the blocking HTTP client.
#[cfg(feature = "sync")]
pub type SyncIncompleteAuthorizationCodeUserClient = IncompleteAuthorizationCodeUserClient<SyncClient>;
/// Builder for an authorization code flow client backed by the asynchronous HTTP client.
#[cfg(feature = "async")]
pub type AsyncAuthorizationCodeUserClientBuilder = AuthorizationCodeUserClientBuilder<AsyncClient>;
/// Builder for an authorization code flow client backed by the blocking HTTP client.
#[cfg(feature = "sync")]
pub type SyncAuthorizationCodeUserClientBuilder = AuthorizationCodeUserClientBuilder<SyncClient>;
/// A user-authenticated client obtained through the authorization code flow (optionally with
/// PKCE).
///
/// The token state lives behind an `Arc`, so clones of this client share the same access and
/// refresh tokens: a token refresh performed through one clone is visible through all of them.
#[derive(Debug, Clone)]
pub struct AuthorizationCodeUserClient<C>
where
    C: private::HttpClient + Clone,
{
    /// Token state shared between clones.
    inner: Arc<AuthorizationCodeUserClientRef>,
    /// HTTP client used for both API requests and token refresh requests.
    http_client: C,
}
/// Mutable token state shared by all clones of an `AuthorizationCodeUserClient`.
///
/// `Debug` is implemented by hand so that formatting a client (e.g. `{:?}` in user logs or
/// `dbg!`) never prints the bearer access token or the refresh token.
struct AuthorizationCodeUserClientRef {
    /// Current access token, sent as a bearer token on API requests and replaced on refresh.
    access_token: RwLock<String>,
    /// Current refresh token; replaced only when a refresh response supplies a new one.
    refresh_token: RwLock<String>,
    /// Client ID included in refresh requests; `Some` only for clients created with PKCE.
    client_id: Option<String>,
}

impl std::fmt::Debug for AuthorizationCodeUserClientRef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Credentials are deliberately redacted; the client ID is not a secret.
        f.debug_struct("AuthorizationCodeUserClientRef")
            .field("access_token", &"<redacted>")
            .field("refresh_token", &"<redacted>")
            .field("client_id", &self.client_id)
            .finish()
    }
}
/// An authorization code flow client that has not yet obtained tokens.
///
/// Send the user to the URL from `get_authorize_url`. Then call `finalize` with the `code` and
/// `state` delivered to the redirect URI to obtain an `AuthorizationCodeUserClient`.
#[derive(Debug)]
pub struct IncompleteAuthorizationCodeUserClient<C>
where
    C: private::HttpClient + Clone,
{
    /// Application client ID; sent in the authorize URL, and in the token request when PKCE is used.
    client_id: String,
    /// Redirect URI; sent in both the authorize URL and the token request.
    redirect_uri: String,
    /// Random alphanumeric value embedded in the authorize URL. `finalize` rejects a callback
    /// whose `state` differs from it.
    state: String,
    /// Requested scopes as produced by `ToScopesString`; `None` omits the `scope` parameter.
    scopes: Option<String>,
    /// Value of the `show_dialog` query parameter.
    show_dialog: bool,
    /// PKCE code verifier; `Some` when PKCE is enabled.
    pkce_verifier: Option<String>,
    /// HTTP client used for the token request in `finalize`.
    http_client: C,
}
/// Builder that configures an authorization code flow before producing an
/// `IncompleteAuthorizationCodeUserClient` via `build`.
#[derive(Debug)]
pub struct AuthorizationCodeUserClientBuilder<C>
where
    C: private::HttpClient + Clone,
{
    /// Application client ID.
    client_id: String,
    /// Redirect URI registered for the application.
    redirect_uri: String,
    /// Requested scopes, set via `scopes`; `None` until then.
    scopes: Option<String>,
    /// Whether to request `show_dialog=true` in the authorize URL; `false` by default.
    show_dialog: bool,
    /// PKCE code verifier, generated by `with_pkce`; `None` means PKCE is disabled.
    pkce_verifier: Option<String>,
    /// HTTP client handed on to the built client.
    http_client: C,
}
/// Token endpoint response for the authorization code grant.
///
/// NOTE(review): the derived `Debug` output includes the raw tokens, and this value is logged
/// at debug level when the client is built.
#[derive(Debug, Deserialize)]
struct AuthorizeUserTokenResponse {
    access_token: String,
    refresh_token: String,
    /// Deserialized but not used by this module.
    #[allow(dead_code)]
    scope: Option<String>,
    /// Deserialized but not used by this module (presumably a lifetime in seconds; TODO confirm).
    #[allow(dead_code)]
    expires_in: u32,
    /// Deserialized but not used by this module.
    #[allow(dead_code)]
    token_type: String,
}
/// Token endpoint response for the refresh token grant.
///
/// `refresh_token` is optional. When it is absent, the previously held refresh token stays in
/// use. NOTE(review): the derived `Debug` output includes the raw tokens, and this value is
/// logged at debug level.
#[derive(Debug, Deserialize)]
struct RefreshUserTokenResponse {
    access_token: String,
    refresh_token: Option<String>,
    /// Deserialized but not used by this module.
    #[allow(dead_code)]
    scope: Option<String>,
    /// Deserialized but not used by this module (presumably a lifetime in seconds; TODO confirm).
    #[allow(dead_code)]
    expires_in: u32,
    /// Deserialized but not used by this module.
    #[allow(dead_code)]
    token_type: String,
}
impl<C> AuthorizationCodeUserClient<C>
where
C: private::HttpClient + Clone,
{
fn new_from_refresh_token(
token_response: RefreshUserTokenResponse,
refresh_token: String,
client_id: Option<String>,
http_client: C,
) -> Self {
debug!(
"Got token response for refreshing authorization code flow tokens: {:?}",
token_response
);
let refresh_token = token_response.refresh_token.unwrap_or(refresh_token);
Self {
inner: Arc::new(AuthorizationCodeUserClientRef {
access_token: RwLock::new(token_response.access_token),
refresh_token: RwLock::new(refresh_token),
client_id,
}),
http_client,
}
}
pub fn get_refresh_token(&self) -> String {
self.inner
.refresh_token
.read()
.expect("refresh token rwlock poisoned")
.to_owned()
}
fn update_access_and_refresh_tokens(&self, token_response: RefreshUserTokenResponse) {
debug!(
"Got token response for refreshing authorization code flow tokens: {:?}",
token_response
);
*self.inner.access_token.write().expect("access token rwlock poisoned") = token_response.access_token;
if let Some(refresh_token) = token_response.refresh_token {
*self.inner.refresh_token.write().expect("refresh token rwlock poisoned") = refresh_token;
}
}
}
#[cfg(feature = "async")]
impl AsyncAuthorizationCodeUserClient {
    /// Creates a client from an existing refresh token by immediately exchanging it for a new
    /// access token.
    ///
    /// When `client_id` is `Some`, it is included in the refresh request (as is done for PKCE
    /// clients) and retained for later refreshes. If the server does not return a new refresh
    /// token, `refresh_token` is kept.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRefreshToken` if the server rejects the refresh token as an
    /// invalid grant. Transport and response-decoding errors are propagated.
    pub(crate) async fn new_with_refresh_token(
        http_client: AsyncClient,
        refresh_token: String,
        client_id: Option<String>,
    ) -> Result<Self> {
        debug!(
            "Attempting to create new authorization code flow client with existing refresh token: {} and client ID \
             (for PKCE): {:?}",
            refresh_token, client_id
        );
        let response = http_client
            .post(ACCOUNTS_API_TOKEN_ENDPOINT)
            .form(&build_refresh_token_request_form(&refresh_token, client_id.as_deref()))
            .send()
            .await?;
        let response = super::extract_authentication_error_async(response)
            .await
            .map_err(map_refresh_token_error)?;
        let token_response = response.json().await?;
        Ok(Self::new_from_refresh_token(
            token_response,
            refresh_token,
            client_id,
            http_client,
        ))
    }
}
#[cfg(feature = "sync")]
impl SyncAuthorizationCodeUserClient {
    /// Creates a client from an existing refresh token by immediately exchanging it for a new
    /// access token (blocking).
    ///
    /// When `client_id` is `Some`, it is included in the refresh request (as is done for PKCE
    /// clients) and retained for later refreshes. If the server does not return a new refresh
    /// token, `refresh_token` is kept.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRefreshToken` if the server rejects the refresh token as an
    /// invalid grant. Transport and response-decoding errors are propagated.
    pub(crate) fn new_with_refresh_token(
        http_client: SyncClient,
        refresh_token: String,
        client_id: Option<String>,
    ) -> Result<Self> {
        debug!(
            "Attempting to create new authorization code flow client with existing refresh token: {} and client ID \
             (for PKCE): {:?}",
            refresh_token, client_id
        );
        let response = http_client
            .post(ACCOUNTS_API_TOKEN_ENDPOINT)
            .form(&build_refresh_token_request_form(&refresh_token, client_id.as_deref()))
            .send()?;
        let response = super::extract_authentication_error_sync(response).map_err(map_refresh_token_error)?;
        let token_response = response.json()?;
        Ok(Self::new_from_refresh_token(
            token_response,
            refresh_token,
            client_id,
            http_client,
        ))
    }
}
impl<C> IncompleteAuthorizationCodeUserClient<C>
where
    C: private::HttpClient + Clone,
{
    /// Returns the URL the user must visit to authorize this application.
    ///
    /// The query contains, in order: `response_type`, `redirect_uri`, `client_id`, `state`,
    /// `show_dialog`, then `scope` if scopes were set, then `code_challenge_method=S256` and
    /// `code_challenge` if PKCE is enabled.
    ///
    /// # Panics
    ///
    /// Panics if the authorize endpoint constant is not a valid base URL (a bug).
    pub fn get_authorize_url(&self) -> String {
        // Compute the PKCE challenge (unpadded base64url of the SHA-256 of the verifier) first,
        // so that it outlives the borrowed query parameters below.
        let pkce_challenge: Option<String> = self.pkce_verifier.as_deref().map(|verifier| {
            let mut hasher = sha2::Sha256::new();
            hasher.update(verifier);
            let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
            debug!(
                "Using PKCE extension with verifier: {} and challenge: {}",
                verifier, challenge
            );
            challenge
        });

        let show_dialog = if self.show_dialog { "true" } else { "false" };
        let mut params: Vec<(&str, &str)> = vec![
            ("response_type", "code"),
            ("redirect_uri", self.redirect_uri.as_str()),
            ("client_id", self.client_id.as_str()),
            ("state", self.state.as_str()),
            ("show_dialog", show_dialog),
        ];
        if let Some(scopes) = self.scopes.as_deref() {
            params.push(("scope", scopes));
        }
        if let Some(challenge) = pkce_challenge.as_deref() {
            params.push(("code_challenge_method", "S256"));
            params.push(("code_challenge", challenge));
        }

        Url::parse_with_params(ACCOUNTS_AUTHORIZE_ENDPOINT, &params)
            .expect("failed to build authorize URL: invalid base URL (this is likely a bug)")
            .into()
    }

    /// Validates the callback `state` and builds the form for the authorization code token
    /// request.
    ///
    /// With PKCE enabled, `client_id` and `code_verifier` are appended to the form.
    ///
    /// # Errors
    ///
    /// Returns `Error::AuthorizationCodeStateMismatch` if `state` differs from the value
    /// generated for this flow.
    fn build_authorization_code_token_request_form<'a>(
        &'a self,
        code: &'a str,
        state: &str,
    ) -> Result<Vec<(&'a str, &'a str)>> {
        debug!(
            "Attempting to finalize authorization code flow user client with code: {} and state: {}",
            code, state
        );
        // Reject callbacks that do not carry the state embedded in our authorize URL.
        if self.state != state {
            return Err(Error::AuthorizationCodeStateMismatch);
        }
        let mut form = vec![
            ("grant_type", "authorization_code"),
            ("code", code),
            ("redirect_uri", self.redirect_uri.as_str()),
        ];
        match self.pkce_verifier.as_deref() {
            Some(verifier) => {
                debug!("Requesting access and refresh tokens for authorization code flow with PKCE");
                form.push(("client_id", self.client_id.as_str()));
                form.push(("code_verifier", verifier));
            }
            None => {
                debug!("Requesting access and refresh tokens for authorization code flow");
            }
        }
        Ok(form)
    }

    /// Consumes this incomplete client and the token response to produce the finished client.
    fn build_client(self, token_response: AuthorizeUserTokenResponse) -> AuthorizationCodeUserClient<C> {
        debug!("Got token response for authorization code flow: {:?}", token_response);
        // The client ID is retained only for PKCE clients; when present it is added to refresh
        // requests by `build_refresh_token_request_form`.
        let client_id = if self.pkce_verifier.is_some() {
            Some(self.client_id)
        } else {
            None
        };
        let shared_state = AuthorizationCodeUserClientRef {
            access_token: RwLock::new(token_response.access_token),
            refresh_token: RwLock::new(token_response.refresh_token),
            client_id,
        };
        AuthorizationCodeUserClient {
            inner: Arc::new(shared_state),
            http_client: self.http_client,
        }
    }
}
#[cfg(feature = "async")]
impl AsyncIncompleteAuthorizationCodeUserClient {
    /// Exchanges the authorization `code` delivered to the redirect URI for access and refresh
    /// tokens, producing the finished client.
    ///
    /// # Errors
    ///
    /// Returns `Error::AuthorizationCodeStateMismatch` if `state` does not match this flow, and
    /// `Error::InvalidAuthorizationCode` if the server rejects the code as an invalid grant.
    /// Transport and response-decoding errors are propagated.
    pub async fn finalize(self, code: &str, state: &str) -> Result<AsyncAuthorizationCodeUserClient> {
        let form = self.build_authorization_code_token_request_form(code, state)?;
        let pending = self.http_client.post(ACCOUNTS_API_TOKEN_ENDPOINT).form(&form).send();
        let raw_response = pending.await?;
        let checked_response = super::extract_authentication_error_async(raw_response)
            .await
            .map_err(map_authentication_error)?;
        let tokens: AuthorizeUserTokenResponse = checked_response.json().await?;
        Ok(self.build_client(tokens))
    }
}
#[cfg(feature = "sync")]
impl SyncIncompleteAuthorizationCodeUserClient {
    /// Exchanges the authorization `code` delivered to the redirect URI for access and refresh
    /// tokens (blocking), producing the finished client.
    ///
    /// # Errors
    ///
    /// Returns `Error::AuthorizationCodeStateMismatch` if `state` does not match this flow, and
    /// `Error::InvalidAuthorizationCode` if the server rejects the code as an invalid grant.
    /// Transport and response-decoding errors are propagated.
    pub fn finalize(self, code: &str, state: &str) -> Result<SyncAuthorizationCodeUserClient> {
        let form = self.build_authorization_code_token_request_form(code, state)?;
        let request = self.http_client.post(ACCOUNTS_API_TOKEN_ENDPOINT).form(&form);
        let raw_response = request.send()?;
        let checked_response =
            super::extract_authentication_error_sync(raw_response).map_err(map_authentication_error)?;
        let tokens: AuthorizeUserTokenResponse = checked_response.json()?;
        Ok(self.build_client(tokens))
    }
}
#[cfg(feature = "async")]
impl AsyncAuthorizationCodeUserClientBuilder {
    /// Starts a builder for the asynchronous client with no scopes, `show_dialog` off and PKCE
    /// disabled.
    pub(super) fn new(redirect_uri: String, client_id: String, http_client: AsyncClient) -> Self {
        let (scopes, pkce_verifier) = (None, None);
        Self {
            http_client,
            client_id,
            redirect_uri,
            show_dialog: false,
            scopes,
            pkce_verifier,
        }
    }
}
#[cfg(feature = "sync")]
impl SyncAuthorizationCodeUserClientBuilder {
    /// Starts a builder for the blocking client with no scopes, `show_dialog` off and PKCE
    /// disabled.
    pub(super) fn new(redirect_uri: String, client_id: String, http_client: SyncClient) -> Self {
        let (scopes, pkce_verifier) = (None, None);
        Self {
            http_client,
            client_id,
            redirect_uri,
            show_dialog: false,
            scopes,
            pkce_verifier,
        }
    }
}
impl<C> AuthorizationCodeUserClientBuilder<C>
where
    C: private::HttpClient + Clone,
{
    /// Draws `length` random ASCII alphanumeric characters from the thread-local RNG.
    fn random_alphanumeric(length: usize) -> String {
        rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(length)
            .map(char::from)
            .collect()
    }

    /// Enables PKCE by generating a fresh random alphanumeric code verifier of
    /// `PKCE_VERIFIER_LENGTH` characters.
    pub(super) fn with_pkce(self) -> Self {
        let pkce_verifier = Some(Self::random_alphanumeric(PKCE_VERIFIER_LENGTH));
        Self { pkce_verifier, ..self }
    }

    /// Sets the scopes to request, replacing any previously set.
    pub fn scopes<T>(self, scopes: T) -> Self
    where
        T: ToScopesString,
    {
        let scopes = Some(scopes.to_scopes_string());
        Self { scopes, ..self }
    }

    /// Sets the value of the `show_dialog` query parameter in the authorize URL.
    pub fn show_dialog(self, show_dialog: bool) -> Self {
        let mut builder = self;
        builder.show_dialog = show_dialog;
        builder
    }

    /// Finishes configuration, generating a random `state` value of `RANDOM_STATE_LENGTH`
    /// alphanumeric characters for the authorize request.
    pub fn build(self) -> IncompleteAuthorizationCodeUserClient<C> {
        let Self {
            client_id,
            redirect_uri,
            scopes,
            show_dialog,
            pkce_verifier,
            http_client,
        } = self;
        IncompleteAuthorizationCodeUserClient {
            client_id,
            redirect_uri,
            state: Self::random_alphanumeric(RANDOM_STATE_LENGTH),
            scopes,
            show_dialog,
            pkce_verifier,
            http_client,
        }
    }
}
// Implements the crate-private `Sealed` marker trait for the client (presumably the
// sealed-trait pattern that keeps downstream crates from implementing the client traits;
// confirm in `crate::private`).
impl<C> crate::private::Sealed for AuthorizationCodeUserClient<C> where C: private::HttpClient + Clone {}
#[cfg(feature = "async")]
impl private::BuildHttpRequestAsync for AsyncAuthorizationCodeUserClient {
    /// Creates a request to `url` that carries the current access token as a bearer token.
    ///
    /// # Panics
    ///
    /// Panics if the access token lock has been poisoned.
    fn build_http_request<U>(&self, method: Method, url: U) -> reqwest::RequestBuilder
    where
        U: IntoUrl,
    {
        let token_guard = self.inner.access_token.read().expect("access token rwlock poisoned");
        let token: &str = &token_guard;
        let builder = self.http_client.request(method, url);
        builder.bearer_auth(token)
    }
}
#[cfg(feature = "sync")]
impl private::BuildHttpRequestSync for SyncAuthorizationCodeUserClient {
    /// Creates a blocking request to `url` that carries the current access token as a bearer
    /// token.
    ///
    /// # Panics
    ///
    /// Panics if the access token lock has been poisoned.
    fn build_http_request<U>(&self, method: Method, url: U) -> reqwest::blocking::RequestBuilder
    where
        U: IntoUrl,
    {
        let token_guard = self.inner.access_token.read().expect("access token rwlock poisoned");
        let token: &str = &token_guard;
        let builder = self.http_client.request(method, url);
        builder.bearer_auth(token)
    }
}
// Marker impls: the authorization code flow client is usable both for endpoints that require
// user scopes (`ScopedClient`) and for those that do not (`UnscopedClient`), for whichever of
// the async/sync HTTP backends is enabled.
#[cfg(feature = "async")]
impl super::ScopedClient for AsyncAuthorizationCodeUserClient {}
#[cfg(feature = "sync")]
impl super::ScopedClient for SyncAuthorizationCodeUserClient {}
#[cfg(feature = "async")]
impl super::UnscopedClient for AsyncAuthorizationCodeUserClient {}
#[cfg(feature = "sync")]
impl super::UnscopedClient for SyncAuthorizationCodeUserClient {}
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl super::AccessTokenRefreshAsync for AsyncAuthorizationCodeUserClient {
    /// Exchanges the stored refresh token for a new access token and stores the result.
    ///
    /// The refresh token is replaced only if the response contains a new one.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRefreshToken` if the server rejects the refresh token as an
    /// invalid grant. Transport and decoding errors are propagated.
    async fn refresh_access_token(&self) -> Result<()> {
        // The request is built and dispatched inside its own scope. That way the `std` read
        // guard (which is not `Send`) is released before the first `.await`.
        let pending_request = {
            let current_refresh_token = self.inner.refresh_token.read().expect("refresh token rwlock poisoned");
            debug!(
                "Attempting to refresh authorization code flow access token with refresh token: {}",
                current_refresh_token
            );
            let form = build_refresh_token_request_form(&current_refresh_token, self.inner.client_id.as_deref());
            self.http_client.post(ACCOUNTS_API_TOKEN_ENDPOINT).form(&form).send()
        };
        let raw_response = pending_request.await?;
        let checked_response = super::extract_authentication_error_async(raw_response)
            .await
            .map_err(map_refresh_token_error)?;
        let token_response: RefreshUserTokenResponse = checked_response.json().await?;
        self.update_access_and_refresh_tokens(token_response);
        Ok(())
    }
}
#[cfg(feature = "sync")]
impl super::AccessTokenRefreshSync for SyncAuthorizationCodeUserClient {
    /// Exchanges the stored refresh token for a new access token (blocking) and stores the
    /// result.
    ///
    /// The refresh token is replaced only if the response contains a new one.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidRefreshToken` if the server rejects the refresh token as an
    /// invalid grant. Transport and decoding errors are propagated.
    ///
    /// # Panics
    ///
    /// Panics if the access or refresh token lock has been poisoned.
    fn refresh_access_token(&self) -> Result<()> {
        // Build the request while holding the read lock, but release the lock *before* the
        // blocking network call. Otherwise any writer of the refresh token (e.g. a concurrent
        // refresh storing its result), and readers queued behind it, would stall for the whole
        // HTTP round-trip. `form` serializes the body into the builder, so the builder no
        // longer borrows the guard.
        let request = {
            let refresh_token = self.inner.refresh_token.read().expect("refresh token rwlock poisoned");
            debug!(
                "Attempting to refresh authorization code flow access token with refresh token: {}",
                refresh_token
            );
            let form = build_refresh_token_request_form(&refresh_token, self.inner.client_id.as_deref());
            self.http_client.post(ACCOUNTS_API_TOKEN_ENDPOINT).form(&form)
        };
        let response = request.send()?;
        let response = super::extract_authentication_error_sync(response).map_err(map_refresh_token_error)?;
        let token_response = response.json()?;
        self.update_access_and_refresh_tokens(token_response);
        Ok(())
    }
}
#[cfg(feature = "async")]
#[async_trait::async_trait]
impl private::AccessTokenExpiryAsync for AsyncAuthorizationCodeUserClient {
    /// Reacts to an expired access token by refreshing it. Any refresh error is propagated;
    /// on success, `AccessTokenExpiryResult::Ok` is reported.
    async fn handle_access_token_expired(&self) -> Result<private::AccessTokenExpiryResult> {
        self.refresh_access_token()
            .await
            .map(|()| private::AccessTokenExpiryResult::Ok)
    }
}
#[cfg(feature = "sync")]
impl private::AccessTokenExpirySync for SyncAuthorizationCodeUserClient {
    /// Reacts to an expired access token by refreshing it (blocking). Any refresh error is
    /// propagated; on success, `AccessTokenExpiryResult::Ok` is reported.
    fn handle_access_token_expired(&self) -> Result<private::AccessTokenExpiryResult> {
        self.refresh_access_token()
            .map(|()| private::AccessTokenExpiryResult::Ok)
    }
}
/// Builds the form body for a refresh token grant.
///
/// The form always contains `grant_type=refresh_token` and the `refresh_token`. When a
/// `client_id` is given, it is appended last, as the PKCE variant of the grant sends it.
fn build_refresh_token_request_form<'a>(refresh_token: &'a str, client_id: Option<&'a str>) -> Vec<(&'a str, &'a str)> {
    let mut form = Vec::with_capacity(3);
    form.push(("grant_type", "refresh_token"));
    form.push(("refresh_token", refresh_token));
    form.extend(client_id.map(|id| ("client_id", id)));
    form
}
fn map_authentication_error(err: Error) -> Error {
if let Error::UnhandledAuthenticationError(AuthenticationErrorKind::InvalidGrant, _) = err {
Error::InvalidAuthorizationCode
} else {
err
}
}
fn map_refresh_token_error(err: Error) -> Error {
if let Error::UnhandledAuthenticationError(AuthenticationErrorKind::InvalidGrant, description) = err {
Error::InvalidRefreshToken(description)
} else {
err
}
}