use crate::{
alphabets, auth_urls,
clients::{BaseClient, OAuthClient},
generate_random_string,
http::{Form, HttpClient},
join_scopes, params,
sync::Mutex,
ClientResult, Config, Credentials, OAuth, Token,
};
use base64::{engine::general_purpose, Engine as _};
use std::collections::HashMap;
use std::sync::Arc;
use maybe_async::maybe_async;
use sha2::{Digest, Sha256};
use url::Url;
/// Client for Spotify's Authorization Code flow with PKCE.
///
/// Typical use: call `get_authorize_url` (which generates and stores a fresh
/// code verifier), send the user to the returned URL, then pass the
/// authorization code obtained from the redirect to `request_token`.
#[derive(Clone, Debug, Default)]
pub struct AuthCodePkceSpotify {
/// Client credentials; only `creds.id` is read here (sent as `client_id`).
pub creds: Credentials,
/// OAuth settings: redirect URI, state and scopes used when building the
/// authorize URL and requesting tokens.
pub oauth: OAuth,
/// Client configuration; its `token_callback_fn`, if set, is invoked with
/// every newly fetched token.
pub config: Config,
/// Current token. `get_token` hands out clones of this `Arc`, so every
/// holder observes the same token.
pub token: Arc<Mutex<Option<Token>>>,
/// PKCE code verifier. Set by `get_authorize_url`; `request_token` panics
/// if it is still `None`.
pub verifier: Option<String>,
/// HTTP client used for requests; exposed through `get_http`.
pub(crate) http: HttpClient,
}
#[cfg_attr(target_arch = "wasm32", maybe_async(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), maybe_async)]
impl BaseClient for AuthCodePkceSpotify {
    /// Returns the HTTP client used for requests.
    fn get_http(&self) -> &HttpClient {
        &self.http
    }

    /// Returns a new handle to the shared token slot.
    fn get_token(&self) -> Arc<Mutex<Option<Token>>> {
        Arc::clone(&self.token)
    }

    /// Returns the client credentials.
    fn get_creds(&self) -> &Credentials {
        &self.creds
    }

    /// Returns the client configuration.
    fn get_config(&self) -> &Config {
        &self.config
    }

    /// Requests a new access token using the stored refresh token.
    ///
    /// Returns `Ok(None)` when no token is stored or the stored token has no
    /// refresh token. On success the new token is handed to
    /// `token_callback_fn` (if configured) and returned. It is *not* written
    /// into `self.token` here; that is left to the caller.
    ///
    /// # Errors
    /// Propagates errors from the token request and from the token callback.
    async fn refetch_token(&self) -> ClientResult<Option<Token>> {
        // Copy the refresh token out inside this `let` so the mutex guard (a
        // temporary of the scrutinee) is dropped before the network request.
        // Holding it across the request would block every other user of the
        // shared token for a full HTTP round trip. It could also deadlock a
        // token callback that locks the same mutex.
        let refresh_token = match self.token.lock().await.unwrap().as_ref() {
            Some(Token {
                refresh_token: Some(refresh_token),
                ..
            }) => refresh_token.clone(),
            _ => return Ok(None),
        };

        let mut data = Form::new();
        data.insert(params::GRANT_TYPE, params::GRANT_TYPE_REFRESH_TOKEN);
        data.insert(params::REFRESH_TOKEN, &refresh_token);
        data.insert(params::CLIENT_ID, &self.creds.id);

        let token = self.fetch_access_token(&data, None).await?;

        // Borrow the callback directly; cloning the smart pointer first only
        // bumped a reference count for no benefit.
        if let Some(callback_fn) = &*self.get_config().token_callback_fn {
            callback_fn.0(token.clone())?;
        }

        Ok(Some(token))
    }
}
#[cfg_attr(target_arch = "wasm32", maybe_async(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), maybe_async)]
impl OAuthClient for AuthCodePkceSpotify {
    /// Returns the OAuth settings of this client.
    fn get_oauth(&self) -> &OAuth {
        &self.oauth
    }

    /// Exchanges an authorization `code` for an access token (PKCE variant).
    ///
    /// The request carries the code verifier previously stored by
    /// `get_authorize_url`. The fetched token is first handed to
    /// `token_callback_fn` (if configured), then stored in `self.token`, and
    /// finally the token cache is written.
    ///
    /// # Panics
    /// Panics if `self.verifier` is `None`, i.e. no verifier was generated or set.
    ///
    /// # Errors
    /// Propagates failures of the token request, the callback, or the cache write.
    async fn request_token(&self, code: &str) -> ClientResult<()> {
        log::info!("Requesting PKCE Auth Code token");

        let code_verifier = self.verifier.as_deref().expect(
            "Unknown code verifier. Try calling \
            `AuthCodePkceSpotify::get_authorize_url` first or setting it \
            yourself.",
        );

        let mut form = Form::new();
        form.insert(params::CLIENT_ID, &self.creds.id);
        form.insert(params::GRANT_TYPE, params::GRANT_TYPE_AUTH_CODE);
        form.insert(params::CODE, code);
        form.insert(params::REDIRECT_URI, &self.oauth.redirect_uri);
        form.insert(params::CODE_VERIFIER, code_verifier);

        let fresh = self.fetch_access_token(&form, None).await?;

        let callback = self.config.token_callback_fn.clone();
        if let Some(callback_fn) = &*callback {
            callback_fn.0(fresh.clone())?;
        }

        // Scope the guard so the lock is released before the cache write.
        {
            let mut slot = self.token.lock().await.unwrap();
            *slot = Some(fresh);
        }

        self.write_token_cache().await
    }
}
impl AuthCodePkceSpotify {
    /// Creates a client from credentials and OAuth settings, with the default
    /// configuration and no token.
    #[must_use]
    pub fn new(creds: Credentials, oauth: OAuth) -> Self {
        Self {
            creds,
            oauth,
            ..Self::default()
        }
    }

    /// Creates a client around an existing `token`. All other fields
    /// (credentials, OAuth settings, config) take their default values.
    #[must_use]
    pub fn from_token(token: Token) -> Self {
        let shared = Arc::new(Mutex::new(Some(token)));
        Self {
            token: shared,
            ..Self::default()
        }
    }

    /// Creates a client from credentials, OAuth settings and a custom `config`.
    #[must_use]
    pub fn with_config(creds: Credentials, oauth: OAuth, config: Config) -> Self {
        Self {
            creds,
            oauth,
            config,
            ..Self::default()
        }
    }

    /// Generates a PKCE `(code_verifier, code_challenge)` pair.
    ///
    /// The verifier is a random string of length `verifier_bytes` drawn from
    /// `alphabets::PKCE_CODE_VERIFIER`. The challenge is the SHA-256 digest of
    /// the verifier, encoded as URL-safe base64 without padding (the `S256`
    /// method).
    ///
    /// The 43..=128 range for `verifier_bytes` is enforced only by
    /// `debug_assert!`. Release builds accept any value.
    fn generate_codes(verifier_bytes: usize) -> (String, String) {
        log::info!("Generating PKCE codes");
        debug_assert!(verifier_bytes >= 43);
        debug_assert!(verifier_bytes <= 128);

        let verifier = generate_random_string(verifier_bytes, alphabets::PKCE_CODE_VERIFIER);
        let digest = Sha256::digest(verifier.as_bytes());
        let challenge = general_purpose::URL_SAFE_NO_PAD.encode(digest);

        (verifier, challenge)
    }

    /// Builds the authorization URL the user must visit.
    ///
    /// A fresh code verifier is generated and stored in `self.verifier`,
    /// replacing any previous one; `request_token` needs it later. Its
    /// challenge is sent in the URL together with the client ID, redirect URI,
    /// state and scopes. `verifier_bytes` is the verifier length and defaults
    /// to 43 when `None`.
    ///
    /// # Errors
    /// Returns an error if the authorize URL cannot be parsed.
    pub fn get_authorize_url(&mut self, verifier_bytes: Option<usize>) -> ClientResult<String> {
        log::info!("Building auth URL");

        let scopes = join_scopes(&self.oauth.scopes);
        let (verifier, challenge) = Self::generate_codes(verifier_bytes.unwrap_or(43));
        self.verifier = Some(verifier);

        let mut query: HashMap<&str, &str> = HashMap::new();
        query.insert(params::CLIENT_ID, &self.creds.id);
        query.insert(params::RESPONSE_TYPE, params::RESPONSE_TYPE_CODE);
        query.insert(params::REDIRECT_URI, &self.oauth.redirect_uri);
        query.insert(
            params::CODE_CHALLENGE_METHOD,
            params::CODE_CHALLENGE_METHOD_S256,
        );
        query.insert(params::CODE_CHALLENGE, &challenge);
        query.insert(params::STATE, &self.oauth.state);
        query.insert(params::SCOPE, &scopes);

        let base = self.auth_url(auth_urls::AUTHORIZE);
        let url = Url::parse_with_params(&base, query)?;
        Ok(String::from(url))
    }
}