use super::{
AuthError, AuthResult,
private::{AsyncAuthFlow, AuthFlow},
};
use crate::{
RestError,
api::{ApiError, FormParams, QueryParams},
auth::scopes::{self, Scope},
model::Token,
};
use async_trait::async_trait;
use reqwest::blocking::Client;
use std::collections::HashSet;
use url::Url;
/// Client-side state for Spotify's Authorization Code flow with PKCE.
///
/// `state` and `code_verifier` start out as `None` and are filled in by
/// `user_authorization_url`; until then the redirect-verification and
/// code-exchange methods return `NoState` / `NoCodeVerifier` errors.
// NOTE(review): the derived `Debug` prints `state` and `code_verifier` verbatim;
// consider a manual impl that redacts them if instances can end up in logs.
#[derive(Debug, Clone)]
pub struct AuthCodePKCE {
    /// Client ID sent in the authorization URL and in every token request.
    client_id: String,
    /// Redirect URI sent in the authorization URL and in the code-exchange request.
    redirect_uri: String,
    /// Scopes to request; `None` omits the `scope` query parameter entirely.
    scopes: Option<HashSet<Scope>>,
    /// `state` value of the most recently generated authorization URL; compared
    /// against the `state` returned in the redirect URL.
    state: Option<String>,
    /// PKCE code verifier of the most recently generated authorization URL; sent
    /// when exchanging the authorization code for a token.
    code_verifier: Option<String>,
}
impl AuthCodePKCE {
pub fn new(
client_id: impl Into<String>,
redirect_uri: impl Into<String>,
scopes: impl Into<Option<HashSet<Scope>>>,
) -> Self {
Self {
client_id: client_id.into(),
redirect_uri: redirect_uri.into(),
scopes: scopes.into(),
state: None,
code_verifier: None,
}
}
pub fn set_scopes(&mut self, scopes: Option<HashSet<Scope>>) {
self.scopes = scopes;
}
pub fn user_authorization_url(&mut self) -> String {
let code_verifier = crypto::generate_code_verifier(128);
let code_challenge = crypto::generate_code_challenge(&code_verifier);
let state = crypto::random_string(16);
let mut params = QueryParams::default();
params
.push("client_id", &self.client_id)
.push("response_type", &"code")
.push("redirect_uri", &self.redirect_uri)
.push("state", &state)
.push_opt("scope", self.scopes.as_ref().map(scopes::to_string))
.push("code_challenge_method", &"S256")
.push("code_challenge", &code_challenge);
let mut url =
Url::parse("https://accounts.spotify.com/authorize").expect("This URL is always valid");
params.add_to_url(&mut url);
self.state = Some(state);
self.code_verifier = Some(code_verifier);
url.as_str().to_owned()
}
pub fn verify_authorization_code(&self, url: &str) -> AuthResult<String> {
let self_state = self.state.as_ref().ok_or(AuthError::NoState)?;
let url = Url::parse(url)?;
let mut code = None;
let mut state = None;
for (key, value) in url.query_pairs() {
match key.as_ref() {
"code" => code = Some(value),
"state" => state = Some(value),
_ => {}
}
}
let code = code.ok_or(AuthError::CodeNotFound)?;
let state = state.ok_or(AuthError::InvalidState {
expected: self_state.to_owned(),
got: "None".to_owned(),
})?;
if self_state.eq(&state) {
Ok(code.to_string())
} else {
Err(AuthError::InvalidState {
expected: self_state.to_owned(),
got: state.to_string(),
})
}
}
pub fn request_token(&self, code: &str, client: &Client) -> Result<Token, ApiError<RestError>> {
let code_verifier = self
.code_verifier
.as_ref()
.ok_or(AuthError::NoCodeVerifier)?;
let params = self.token_request_params(code, code_verifier);
super::request_token(client, None, params)
}
pub async fn request_token_async(
&self,
code: &str,
client: &reqwest::Client,
) -> Result<Token, ApiError<RestError>> {
let code_verifier = self
.code_verifier
.as_ref()
.ok_or(AuthError::NoCodeVerifier)?;
let params = self.token_request_params(code, code_verifier);
super::request_token_async(client, None, params).await
}
pub fn request_token_from_redirect_url(
&self,
url: &str,
client: &Client,
) -> Result<Token, ApiError<RestError>> {
let code = self.verify_authorization_code(url)?;
let code_verifier = self
.code_verifier
.as_ref()
.ok_or(AuthError::NoCodeVerifier)?;
let params = self.token_request_params(&code, code_verifier);
super::request_token(client, None, params)
}
pub async fn request_token_from_redirect_url_async(
&self,
url: &str,
client: &reqwest::Client,
) -> Result<Token, ApiError<RestError>> {
let code = self.verify_authorization_code(url)?;
let code_verifier = self
.code_verifier
.as_ref()
.ok_or(AuthError::NoCodeVerifier)?;
let params = self.token_request_params(&code, code_verifier);
super::request_token_async(client, None, params).await
}
fn token_request_params<'a>(&self, code: &'a str, code_verifier: &'a str) -> FormParams<'a> {
let mut params = FormParams::default();
params.push("grant_type", &"authorization_code");
params.push("code", &code);
params.push("redirect_uri", &self.redirect_uri);
params.push("client_id", &self.client_id);
params.push("code_verifier", &code_verifier);
params
}
fn refresh_token_request_params<'a>(&self, refresh_token: &'a str) -> FormParams<'a> {
let mut params = FormParams::default();
params.push("grant_type", &"refresh_token");
params.push("refresh_token", &refresh_token);
params.push("client_id", &self.client_id);
params
}
}
impl AuthFlow for AuthCodePKCE {
    /// Exchanges `refresh_token` for a new access token over a blocking client.
    ///
    /// The form body carries `grant_type=refresh_token`, the refresh token and
    /// the client ID. `None` is passed as the first argument of
    /// `init_http_request_and_data` (presumably "no client-secret
    /// authorization" — confirm against that helper).
    fn refresh_token(
        &self,
        client: &Client,
        refresh_token: &str,
    ) -> Result<Token, ApiError<RestError>> {
        let (http_request, body) = super::init_http_request_and_data(
            None,
            self.refresh_token_request_params(refresh_token),
        )?;
        let http_response =
            super::send_http_request(client, http_request, body).map_err(ApiError::client)?;
        super::parse_http_response(&http_response)
    }
}
#[async_trait]
impl AsyncAuthFlow for AuthCodePKCE {
    /// Asynchronous counterpart of `AuthFlow::refresh_token`.
    ///
    /// Builds the same refresh-token form body, sends it with the async
    /// client, and parses the response into a [`Token`].
    async fn refresh_token_async(
        &self,
        client: &reqwest::Client,
        refresh_token: &str,
    ) -> Result<Token, ApiError<RestError>> {
        let (http_request, body) = super::init_http_request_and_data(
            None,
            self.refresh_token_request_params(refresh_token),
        )?;
        let http_response = super::send_http_request_async(client, http_request, body)
            .await
            .map_err(ApiError::client)?;
        super::parse_http_response(&http_response)
    }
}
mod crypto {
    //! Random-string and S256 challenge helpers used by the PKCE flow.

    use base64::{Engine as _, engine::general_purpose};
    use rand::Rng as _;
    use sha2::{Digest, Sha256};

    /// Alphabet for generated strings: letters, digits and `-._~`.
    const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~";

    /// Generates a random code verifier whose length is `length` clamped to `43..=128`.
    pub fn generate_code_verifier(length: usize) -> String {
        random_string(length.clamp(43, 128))
    }

    /// Derives the S256 code challenge: unpadded URL-safe base64 of SHA-256(verifier).
    pub fn generate_code_challenge(code_verifier: &str) -> String {
        let digest = Sha256::digest(code_verifier.as_bytes());
        general_purpose::URL_SAFE_NO_PAD.encode(digest)
    }

    /// Returns `length` characters drawn uniformly from [`CHARSET`] using `rand::rng()`.
    pub fn random_string(length: usize) -> String {
        let mut rng = rand::rng();
        std::iter::repeat_with(|| CHARSET[rng.random_range(0..CHARSET.len())] as char)
            .take(length)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Example code verifier from RFC 7636, Appendix B.
    const RFC7636_VERIFIER: &str = "dBjftJeZ4CVP-mJ92K9mYjnDqjPtLnwe4s8NQr6QKfI";
    /// S256 challenge RFC 7636, Appendix B derives from `RFC7636_VERIFIER`.
    const RFC7636_CHALLENGE: &str = "E9Melhoa2OwvFrEMTJguCgoWL-7oxTD9BYT_ggOMXpM";
    /// Redirect URI shared by the flow-level tests.
    const REDIRECT: &str = "http://localhost:8888/callback";

    /// Builds a flow with no scopes and a loopback redirect URI.
    fn test_flow() -> AuthCodePKCE {
        AuthCodePKCE::new("client-id", REDIRECT, None::<HashSet<Scope>>)
    }

    #[test]
    fn random_string() {
        let length = 16;
        let random_string = crypto::random_string(length);
        assert_eq!(random_string.len(), length);
        assert!(
            random_string
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b"-._~".contains(&b))
        );
    }

    #[test]
    fn code_verifier_length_is_clamped() {
        assert_eq!(crypto::generate_code_verifier(0).len(), 43);
        assert_eq!(crypto::generate_code_verifier(64).len(), 64);
        assert_eq!(crypto::generate_code_verifier(1_000).len(), 128);
    }

    #[test]
    fn code_challenge_matches_rfc7636_example() {
        assert_eq!(
            crypto::generate_code_challenge(RFC7636_VERIFIER),
            RFC7636_CHALLENGE
        );
    }

    #[test]
    fn authorization_url_carries_challenge_of_stored_verifier() {
        let mut flow = test_flow();
        let url = Url::parse(&flow.user_authorization_url()).expect("generated URL is valid");
        let challenge = url
            .query_pairs()
            .find(|(key, _)| key == "code_challenge")
            .map(|(_, value)| value.into_owned());
        let verifier = flow.code_verifier.as_deref().expect("verifier is stored");
        assert_eq!(challenge, Some(crypto::generate_code_challenge(verifier)));
    }

    #[test]
    fn verify_requires_generated_state() {
        let flow = test_flow();
        let result = flow.verify_authorization_code(&format!("{REDIRECT}?code=abc&state=xyz"));
        assert!(matches!(result, Err(AuthError::NoState)));
    }

    #[test]
    fn verify_accepts_matching_state_and_rejects_others() {
        let mut flow = test_flow();
        let _auth_url = flow.user_authorization_url();
        let state = flow.state.clone().expect("state is stored");

        let accepted = flow.verify_authorization_code(&format!("{REDIRECT}?code=abc&state={state}"));
        assert!(matches!(accepted, Ok(ref code) if code == "abc"));

        let mismatched = flow.verify_authorization_code(&format!("{REDIRECT}?code=abc&state=wrong"));
        assert!(matches!(mismatched, Err(AuthError::InvalidState { .. })));

        let missing_state = flow.verify_authorization_code(&format!("{REDIRECT}?code=abc"));
        assert!(matches!(missing_state, Err(AuthError::InvalidState { .. })));

        let missing_code = flow.verify_authorization_code(&format!("{REDIRECT}?state={state}"));
        assert!(matches!(missing_code, Err(AuthError::CodeNotFound)));
    }
}