mod client_credentials;
mod pkce;
pub mod scopes;
use crate::{
RestError,
api::{ApiError, FormParams, query},
model::Token,
};
use bytes::Bytes;
pub(crate) use client_credentials::ClientCredentials;
use http::{HeaderMap, HeaderValue, Request, Response as HttpResponse, header, request::Builder};
pub(crate) use pkce::AuthCodePKCE;
use reqwest::blocking::Client;
use thiserror::Error;
use url::Url;
/// Result type used by the authorization helpers in this module, with [`AuthError`]
/// as the error type.
pub type AuthResult<T> = Result<T, AuthError>;
/// Errors raised while preparing or validating authorization requests.
///
/// Values of this type are converted into `ApiError<RestError>` with `.into()` / `?`
/// elsewhere in this module.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum AuthError {
/// A string could not be turned into an HTTP header value, e.g. the `Authorization`
/// header built by `set_authorization_header`.
#[error("header value error: {0}")]
HeaderValue(#[from] header::InvalidHeaderValue),
/// A URL failed to parse.
#[error("failed to parse url: {0}")]
UrlParse(#[from] url::ParseError),
/// An authorization code was expected but not found.
/// NOTE(review): presumably looked up in the redirect URL's query; confirm in the flow modules.
#[error("authorization code not found")]
CodeNotFound,
/// The returned `state` value did not match the expected one.
#[error("invalid state parameter: expected {expected} got {got}")]
InvalidState { expected: String, got: String },
/// `AuthCodePKCE` has no stored state; `user_authorization_url()` was not called first.
#[error(
"AuthCodePKCE's state is None. Make sure to generate a user authorization URL by calling user_authorization_url()"
)]
NoState,
/// `AuthCodePKCE` has no code verifier; `user_authorization_url()` was not called first.
#[error(
"AuthCodePKCE's code_verifier is None. Make sure to generate a code verifier by calling user_authorization_url()"
)]
NoCodeVerifier,
/// An access token was required but is empty.
#[error("access token is empty")]
EmptyAccessToken,
/// A refresh token was required but is empty. The default `refresh_token` /
/// `refresh_token_async` trait methods in this module also return this variant.
#[error("refresh token is empty")]
EmptyRefreshToken,
}
pub(crate) mod private {
use super::AuthError;
use crate::{RestError, api::ApiError, model::Token};
use async_trait::async_trait;
use reqwest::blocking::Client;
pub trait AuthFlow {
fn refresh_token(
&self,
client: &Client,
refresh_token: &str,
) -> Result<Token, ApiError<RestError>> {
let _ = client;
let _ = refresh_token;
Err(AuthError::EmptyRefreshToken.into())
}
}
#[async_trait]
pub trait AsyncAuthFlow {
async fn refresh_token_async(
&self,
client: &reqwest::Client,
refresh_token: &str,
) -> Result<Token, ApiError<RestError>> {
let _ = client;
let _ = refresh_token;
Err(AuthError::EmptyRefreshToken.into())
}
}
}
/// POSTs `params` to the token endpoint with the blocking `client` and decodes the reply
/// into a [`Token`].
///
/// `authorization_header`, when present, is sent verbatim as the `Authorization` header.
///
/// # Errors
///
/// - Failures while building the request are propagated.
/// - Transport failures are wrapped with `ApiError::client`.
/// - Error responses and undecodable bodies are reported by `parse_http_response`.
fn request_token(
    client: &Client,
    authorization_header: Option<String>,
    params: FormParams<'_>,
) -> Result<Token, ApiError<RestError>> {
    let (builder, body) = init_http_request_and_data(authorization_header, params)?;
    send_http_request(client, builder, body)
        .map_err(ApiError::client)
        .and_then(|reply| parse_http_response(&reply))
}
/// Async counterpart of `request_token`: POSTs `params` to the token endpoint with the
/// async `client` and decodes the reply into a [`Token`].
///
/// # Errors
///
/// - Failures while building the request are propagated.
/// - Transport failures are wrapped with `ApiError::client`.
/// - Error responses and undecodable bodies are reported by `parse_http_response`.
async fn request_token_async(
    client: &reqwest::Client,
    authorization_header: Option<String>,
    params: FormParams<'_>,
) -> Result<Token, ApiError<RestError>> {
    let (builder, body) = init_http_request_and_data(authorization_header, params)?;
    let outcome = send_http_request_async(client, builder, body).await;
    outcome
        .map_err(ApiError::client)
        .and_then(|reply| parse_http_response(&reply))
}
/// Stores `value` in `headers` as the `Authorization` header, flagged as sensitive via
/// `HeaderValue::set_sensitive`, and hands the same map back for chaining.
///
/// Any `Authorization` value already present in `headers` is replaced.
///
/// # Errors
///
/// Returns [`AuthError::HeaderValue`] when `value` is not a valid HTTP header value.
fn set_authorization_header<'a>(
    headers: &'a mut HeaderMap<HeaderValue>,
    value: &str,
) -> AuthResult<&'a mut HeaderMap<HeaderValue>> {
    let credential = HeaderValue::from_str(value).map(|mut hv| {
        hv.set_sensitive(true);
        hv
    })?;
    headers.insert(header::AUTHORIZATION, credential);
    Ok(headers)
}
/// Prepares a `POST https://accounts.spotify.com/api/token` request and its body.
///
/// - `authorization_header`, if given, becomes a sensitive `Authorization` header.
/// - `params` is serialized into the request body. When it yields a body, its MIME type
///   is sent as `Content-Type`.
/// - `Content-Length` is always set, and is `0` when there is no body.
///
/// Returns the unfinished request builder together with the body bytes. The caller
/// attaches the body when sending.
///
/// # Errors
///
/// Fails if the endpoint URL cannot be parsed, if the authorization value is not a valid
/// header value, or if `params` cannot be serialized.
fn init_http_request_and_data(
    authorization_header: Option<String>,
    params: FormParams<'_>,
) -> Result<(Builder, Vec<u8>), ApiError<RestError>> {
    let url = Url::parse("https://accounts.spotify.com/api/token")?;
    let mut req = Request::builder()
        .method(http::Method::POST)
        .uri(query::url_to_http_uri(&url));
    if let Some(value) = authorization_header {
        // The builder only hides its headers after a prior step failed; method and URI
        // here are already-typed values, so this is an invariant, not a runtime error.
        let headers = req
            .headers_mut()
            .expect("failed to get headers on the request builder");
        set_authorization_header(headers, &value)?;
    }
    // `into_body` hands over an owned buffer, so it is moved out rather than cloned.
    let data = match params.into_body()? {
        Some((mime, body)) => {
            req = req.header(header::CONTENT_TYPE, mime);
            body
        }
        None => Vec::new(),
    };
    req = req.header(header::CONTENT_LENGTH, data.len().to_string());
    Ok((req, data))
}
/// Finishes `request` with `data` as its body, sends it with the blocking `client`, and
/// buffers the reply into an [`http::Response<Bytes>`].
///
/// The status, HTTP version, every header value and the full body are copied over. This
/// lets the blocking and async paths feed the same `parse_http_response`.
///
/// # Errors
///
/// Returns a [`RestError`] if:
/// - the request cannot be built or converted into a `reqwest` request,
/// - sending it fails,
/// - the body cannot be read, or
/// - the response cannot be assembled.
fn send_http_request(
    client: &Client,
    request: Builder,
    data: Vec<u8>,
) -> Result<http::Response<Bytes>, RestError> {
    let http_request = request.body(data)?;
    let request = http_request.try_into()?;
    let response = client.execute(request)?;
    let mut http_response = HttpResponse::builder()
        .status(response.status())
        .version(response.version());
    // Status and version are already-typed values, so the builder cannot be in an error state.
    let headers = http_response
        .headers_mut()
        .expect("failed to get headers on the response builder");
    // `append` keeps every value of a repeated header (e.g. `Set-Cookie`);
    // `insert` would keep only the last one.
    for (key, value) in response.headers() {
        headers.append(key, value.clone());
    }
    Ok(http_response.body(response.bytes()?)?)
}
/// Async counterpart of `send_http_request`.
///
/// Finishes `request` with `data` as its body, sends it with the async `client`, and
/// buffers the reply (status, version, every header value, full body) into an
/// [`http::Response<Bytes>`].
///
/// # Errors
///
/// Returns a [`RestError`] if:
/// - the request cannot be built or converted into a `reqwest` request,
/// - sending it fails,
/// - the body cannot be read, or
/// - the response cannot be assembled.
async fn send_http_request_async(
    client: &reqwest::Client,
    request: Builder,
    data: Vec<u8>,
) -> Result<http::Response<Bytes>, RestError> {
    let http_request = request.body(data)?;
    let request = http_request.try_into()?;
    let response = client.execute(request).await?;
    let mut http_response = HttpResponse::builder()
        .status(response.status())
        .version(response.version());
    // Status and version are already-typed values, so the builder cannot be in an error state.
    let headers = http_response
        .headers_mut()
        .expect("failed to get headers on the response builder");
    // `append` keeps every value of a repeated header (e.g. `Set-Cookie`);
    // `insert` would keep only the last one.
    for (key, value) in response.headers() {
        headers.append(key, value.clone());
    }
    Ok(http_response.body(response.bytes().await?)?)
}
/// Decodes a buffered token-endpoint response into `T`.
///
/// The checks run in this order:
/// 1. A `301 Moved Permanently` is reported before the body is examined.
/// 2. The body is parsed as JSON.
/// 3. Any other non-2xx status becomes a Spotify API error carrying that JSON.
/// 4. A successful body is deserialized into `T`.
///
/// # Errors
///
/// - `ApiError::moved_permanently` for a 301, with the `Location` header if present.
/// - `ApiError::server_error` when the body is not valid JSON.
/// - `ApiError::from_spotify_with_status` for any other non-success status.
/// - `ApiError::data_type` when the JSON does not match `T`.
fn parse_http_response<T>(response: &http::Response<Bytes>) -> Result<T, ApiError<RestError>>
where
    T: serde::de::DeserializeOwned,
{
    let status = response.status();
    // Must come first: 301 is not a 2xx status, so the `!is_success()` branch below
    // would otherwise capture it. A redirect body may also not be JSON at all.
    if status == http::StatusCode::MOVED_PERMANENTLY {
        return Err(ApiError::moved_permanently(
            response.headers().get(header::LOCATION),
        ));
    }
    let v = serde_json::from_slice(response.body())
        .map_err(|_| ApiError::server_error(status, response.body()))?;
    if !status.is_success() {
        return Err(ApiError::from_spotify_with_status(status, v));
    }
    serde_json::from_value(v).map_err(ApiError::data_type::<T>)
}