#![forbid(unsafe_code)]
#![deny(rust_2018_idioms)]
#![warn(missing_docs, clippy::pedantic)]
#![allow(
clippy::module_name_repetitions,
clippy::non_ascii_literal,
clippy::items_after_statements,
clippy::filter_map
)]
#![cfg_attr(test, allow(clippy::float_cmp))]
use std::collections::HashMap;
use std::env::{self, VarError};
use std::error::Error as StdError;
use std::ffi::OsStr;
use std::fmt::{self, Display, Formatter};
use std::time::{Duration, Instant};
use reqwest::{header, RequestBuilder, Url};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, MutexGuard};
pub use authorization_url::*;
pub use endpoints::*;
pub use isocountry::CountryCode;
pub use isolanguage_1::LanguageCode;
pub use model::*;
mod authorization_url;
pub mod endpoints;
pub mod model;
mod util;
/// A client for the Spotify Web API.
///
/// Holds the application's credentials, a reusable HTTP client, and a cached access token
/// (behind an async mutex) that is renewed once it has expired.
#[derive(Debug)]
pub struct Client {
    /// The application's client ID and secret, sent as basic auth on every token request.
    pub credentials: ClientCredentials,
    // HTTP client shared by token requests and API requests.
    client: reqwest::Client,
    // Cached access token, its expiry instant, and the optional refresh token.
    cache: Mutex<AccessToken>,
    // When true, outgoing requests and received responses are printed to stderr.
    debug: bool,
}
impl Client {
    /// Creates a client that authenticates with the client credentials flow.
    ///
    /// No network request is made until the first call that needs an access token.
    #[must_use]
    pub fn new(credentials: ClientCredentials) -> Self {
        Self {
            credentials,
            client: reqwest::Client::new(),
            cache: Mutex::new(AccessToken::new(None)),
            debug: false,
        }
    }

    /// Creates a client that obtains access tokens by redeeming `refresh_token`.
    ///
    /// No network request is made until the first call that needs an access token.
    #[must_use]
    pub fn with_refresh(credentials: ClientCredentials, refresh_token: String) -> Self {
        Self {
            credentials,
            client: reqwest::Client::new(),
            cache: Mutex::new(AccessToken::new(Some(refresh_token))),
            debug: false,
        }
    }

    /// Returns a copy of the refresh token currently stored, if any.
    pub async fn refresh_token(&self) -> Option<String> {
        self.cache.lock().await.refresh_token.clone()
    }

    /// Replaces the stored refresh token.
    ///
    /// With `None`, the client falls back to the client credentials flow the next time the
    /// current access token has expired.
    pub async fn set_refresh_token(&self, refresh_token: Option<String>) {
        self.cache.lock().await.refresh_token = refresh_token;
    }

    /// Returns the cached access token and the instant at which it expires.
    ///
    /// The token is an empty string if none has been obtained yet.
    pub async fn current_access_token(&self) -> (String, Instant) {
        let cache = self.cache.lock().await;
        (cache.token.clone(), cache.expires)
    }

    /// Overrides the cached access token and its expiry; the refresh token is left untouched.
    pub async fn set_current_access_token(&self, token: String, expires: Instant) {
        let mut cache = self.cache.lock().await;
        cache.token = token;
        cache.expires = expires;
    }

    /// Sends `params` to Spotify's token endpoint and parses the returned access token.
    ///
    /// A non-success status is turned into `Error::Auth` using the JSON error body.
    async fn token_request(&self, params: TokenRequest<'_>) -> Result<AccessToken, Error> {
        let request = self
            .client
            .post("https://accounts.spotify.com/api/token")
            .basic_auth(&self.credentials.id, Some(&self.credentials.secret))
            .form(&params)
            .build()?;
        if self.debug {
            dbg!(&request, body_str(&request));
        }
        let response = self.client.execute(request).await?;
        let status = response.status();
        let text = response.text().await?;
        if !status.is_success() {
            if self.debug {
                eprintln!(
                    "Authentication failed ({}). Response body is '{}'",
                    status, text
                );
            }
            return Err(Error::Auth(serde_json::from_str(&text)?));
        }
        if self.debug {
            dbg!(status);
            eprintln!("Authentication response body is '{}'", text);
        }
        Ok(serde_json::from_str(&text)?)
    }

    /// Completes the authorization code flow using the URL the user was redirected to.
    ///
    /// `state` must equal the URL's `state` query parameter. On success the obtained token
    /// (including any refresh token in the response) replaces the cached one.
    ///
    /// # Errors
    ///
    /// - [`RedirectedError::InvalidUrl`] if `url` cannot be parsed;
    /// - [`RedirectedError::IncorrectState`] if the `state` parameter is missing or differs;
    /// - [`RedirectedError::AuthFailed`] if the URL has an `error` parameter or no `code`;
    /// - [`RedirectedError::Token`] if exchanging the code for a token fails.
    pub async fn redirected(&self, url: &str, state: &str) -> Result<(), RedirectedError> {
        let url = Url::parse(url)?;
        let pairs: HashMap<_, _> = url.query_pairs().collect();
        // Check the state before anything else so a forged redirect is rejected outright.
        if pairs
            .get("state")
            .map_or(true, |url_state| url_state != state)
        {
            return Err(RedirectedError::IncorrectState);
        }
        if let Some(error) = pairs.get("error") {
            return Err(RedirectedError::AuthFailed(error.to_string()));
        }
        let code = pairs
            .get("code")
            .ok_or_else(|| RedirectedError::AuthFailed(String::new()))?;
        let token = self
            .token_request(TokenRequest::AuthorizationCode {
                code: &*code,
                // The redirect URI is this URL without its query string and fragment.
                redirect_uri: &url[..url::Position::AfterPath],
            })
            .await?;
        *self.cache.lock().await = token;
        Ok(())
    }

    /// Returns a guard over a valid access token, fetching a new one if the cached one expired.
    ///
    /// A stored refresh token is used when present, otherwise the client credentials flow.
    /// The cache is only replaced after a successful request, so a failed refresh (e.g. a
    /// network error) keeps the refresh token and can be retried on the next call.
    async fn access_token(&self) -> Result<MutexGuard<'_, AccessToken>, Error> {
        let mut cache = self.cache.lock().await;
        if Instant::now() >= cache.expires {
            let token = match &cache.refresh_token {
                Some(refresh_token) => {
                    let mut token = self
                        .token_request(TokenRequest::RefreshToken { refresh_token })
                        .await?;
                    // Prefer a newly issued refresh token; otherwise keep using the current one.
                    if token.refresh_token.is_none() {
                        token.refresh_token = Some(refresh_token.clone());
                    }
                    token
                }
                None => self.token_request(TokenRequest::ClientCredentials).await?,
            };
            *cache = token;
        }
        Ok(cache)
    }

    /// Sends an authenticated API request and returns the response body as text.
    ///
    /// While Spotify answers `429 Too Many Requests`, the request is retried after sleeping
    /// for the `Retry-After` number of seconds (2 if absent or unparsable). The returned
    /// expiry is derived from the `Cache-Control: max-age` directive (0 if absent).
    async fn send_text(&self, request: RequestBuilder) -> Result<Response<String>, Error> {
        // The token guard is a temporary, so the cache lock is released once this is built.
        let request = request
            .bearer_auth(&self.access_token().await?.token)
            .build()?;
        if self.debug {
            dbg!(&request, body_str(&request));
        }
        let response = loop {
            // `try_clone` only fails for streaming bodies; assumes endpoint requests never
            // carry one (TODO confirm if streaming uploads are ever added).
            let attempt = request
                .try_clone()
                .expect("API request bodies are expected to be in-memory, not streams");
            let response = self.client.execute(attempt).await?;
            if response.status() != reqwest::StatusCode::TOO_MANY_REQUESTS {
                break response;
            }
            let wait = response
                .headers()
                .get(header::RETRY_AFTER)
                .and_then(|val| val.to_str().ok())
                .and_then(|secs| secs.parse::<u64>().ok())
                .unwrap_or(2);
            tokio::time::sleep(Duration::from_secs(wait)).await;
        };
        let status = response.status();
        let cache_control = Duration::from_secs(
            response
                .headers()
                .get_all(header::CACHE_CONTROL)
                .iter()
                .filter_map(|value| value.to_str().ok())
                .flat_map(|value| value.split(','))
                .find_map(|directive| {
                    let (name, max_age) = directive.trim().split_once('=')?;
                    if name.eq_ignore_ascii_case("max-age") {
                        max_age.parse::<u64>().ok()
                    } else {
                        None
                    }
                })
                .unwrap_or_default(),
        );
        let data = response.text().await?;
        if !status.is_success() {
            if self.debug {
                eprintln!("Failed ({}). Response body is '{}'", status, data);
            }
            return Err(Error::Endpoint(serde_json::from_str(&data)?));
        }
        if self.debug {
            dbg!(status);
            eprintln!("Response body is '{}'", data);
        }
        Ok(Response {
            data,
            expires: Instant::now() + cache_control,
        })
    }

    /// Sends an authenticated request and discards the response body.
    async fn send_empty(&self, request: RequestBuilder) -> Result<(), Error> {
        self.send_text(request).await?;
        Ok(())
    }

    /// Sends an authenticated request and parses the body as JSON; an empty body yields `None`.
    async fn send_opt_json<T: DeserializeOwned>(
        &self,
        request: RequestBuilder,
    ) -> Result<Response<Option<T>>, Error> {
        let res = self.send_text(request).await?;
        Ok(Response {
            data: if res.data.is_empty() {
                None
            } else {
                serde_json::from_str(&res.data)?
            },
            expires: res.expires,
        })
    }

    /// Sends an authenticated request and parses the body as JSON.
    async fn send_json<T: DeserializeOwned>(
        &self,
        request: RequestBuilder,
    ) -> Result<Response<T>, Error> {
        let res = self.send_text(request).await?;
        Ok(Response {
            data: serde_json::from_str(&res.data)?,
            expires: res.expires,
        })
    }

    /// Sends an authenticated request and extracts the `snapshot_id` field of the JSON body.
    async fn send_snapshot_id(&self, request: RequestBuilder) -> Result<String, Error> {
        #[derive(Deserialize)]
        struct SnapshotId {
            snapshot_id: String,
        }
        Ok(self
            .send_json::<SnapshotId>(request)
            .await?
            .data
            .snapshot_id)
    }
}
/// Data returned by an endpoint, together with how long it may be cached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Response<T> {
    /// The payload returned by the endpoint.
    pub data: T,
    /// The instant after which the payload should be treated as stale.
    pub expires: Instant,
}
impl<T> Response<T> {
    /// Converts the payload with `f` while keeping the same expiry instant.
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Response<U> {
        let Self { data, expires } = self;
        let mapped = f(data);
        Response {
            data: mapped,
            expires,
        }
    }
}
/// The client ID and client secret identifying a Spotify application.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
    /// The application's client ID.
    pub id: String,
    /// The application's client secret.
    pub secret: String,
}
impl ClientCredentials {
    /// Reads the credentials from the environment variables named `client_id` and
    /// `client_secret`, in that order.
    ///
    /// # Errors
    ///
    /// Returns the [`VarError`] reported by [`env::var`] for the first variable that is
    /// missing or not valid Unicode.
    pub fn from_env_vars<I: AsRef<OsStr>, S: AsRef<OsStr>>(
        client_id: I,
        client_secret: S,
    ) -> Result<Self, VarError> {
        let id = env::var(client_id)?;
        let secret = env::var(client_secret)?;
        Ok(Self { id, secret })
    }

    /// Reads the credentials from the `CLIENT_ID` and `CLIENT_SECRET` environment variables.
    ///
    /// # Errors
    ///
    /// Fails if either variable is missing or not valid Unicode.
    pub fn from_env() -> Result<Self, VarError> {
        let (id_var, secret_var) = ("CLIENT_ID", "CLIENT_SECRET");
        Self::from_env_vars(id_var, secret_var)
    }
}
/// An error returned by [`Client::redirected`].
#[derive(Debug)]
pub enum RedirectedError {
    /// The redirect URL could not be parsed.
    InvalidUrl(url::ParseError),
    /// The URL's `state` parameter was missing or did not match the expected state.
    IncorrectState,
    /// Authorization did not succeed. Holds the URL's `error` parameter, or an empty string
    /// when the URL had neither an `error` nor a `code` parameter.
    AuthFailed(String),
    /// Exchanging the authorization code for an access token failed.
    Token(Error),
}
impl From<url::ParseError> for RedirectedError {
    fn from(error: url::ParseError) -> Self {
        Self::InvalidUrl(error)
    }
}
impl From<Error> for RedirectedError {
    fn from(error: Error) -> Self {
        Self::Token(error)
    }
}
impl Display for RedirectedError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidUrl(_) => f.write_str("malformed redirect URL"),
            Self::IncorrectState => f.write_str("state parameter not found or is incorrect"),
            Self::AuthFailed(_) => f.write_str("authorization failed"),
            // NOTE(review): this forwards the inner message while `source()` also returns the
            // inner error, so chained error reporters print it twice. Kept for compatibility
            // with callers that downcast `source()`.
            Self::Token(e) => e.fmt(f),
        }
    }
}
impl StdError for RedirectedError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        // Exhaustive on purpose: a new variant must decide explicitly whether it has a source.
        Some(match self {
            Self::InvalidUrl(e) => e,
            Self::Token(e) => e,
            Self::IncorrectState | Self::AuthFailed(_) => return None,
        })
    }
}
/// Form parameters for a request to the token endpoint.
///
/// The variant is serialized as the `grant_type` field in snake case (`refresh_token`,
/// `client_credentials`, `authorization_code`), alongside the variant's own fields.
#[derive(Debug, Serialize)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
enum TokenRequest<'a> {
    /// Redeem a refresh token for a new access token.
    RefreshToken {
        // `&str` rather than `&String`: callers holding a `&String` still coerce to it.
        refresh_token: &'a str,
    },
    /// Obtain an app-only access token using just the client credentials.
    ClientCredentials,
    /// Exchange an authorization code received on the redirect URI.
    AuthorizationCode {
        code: &'a str,
        redirect_uri: &'a str,
    },
}
/// An access token as parsed from the token endpoint's JSON response, plus the refresh
/// token (if any) used to renew it.
#[derive(Debug, Deserialize)]
struct AccessToken {
    /// The bearer token attached to API requests (JSON field `access_token`).
    #[serde(rename = "access_token")]
    token: String,
    /// When the token stops being valid, built from the `expires_in` field by
    /// `util::deserialize_instant_seconds` (presumably seconds from now — confirm in `util`).
    #[serde(
        rename = "expires_in",
        deserialize_with = "util::deserialize_instant_seconds"
    )]
    expires: Instant,
    /// The refresh token, if the response contained one; `None` when the field is absent.
    #[serde(default)]
    refresh_token: Option<String>,
}
impl AccessToken {
    /// Creates an empty placeholder token that is already expired, so the first call that
    /// needs an access token fetches a real one (using `refresh_token` if provided).
    fn new(refresh_token: Option<String>) -> Self {
        let now = Instant::now();
        Self {
            token: String::new(),
            // Back-date by one second when representable. Plain `now - 1s` panics if the
            // result would precede the platform's clock origin. `now` itself already counts
            // as expired, because the expiry check is `Instant::now() >= expires`.
            expires: now.checked_sub(Duration::from_secs(1)).unwrap_or(now),
            refresh_token,
        }
    }
}
/// Describes a request body for debug output.
///
/// Returns `None` when there is no body, `"stream"` for a streaming body, the text itself
/// for a UTF-8 body, and `"opaque bytes"` otherwise.
fn body_str(req: &reqwest::Request) -> Option<&str> {
    let body = req.body()?;
    let description = match body.as_bytes() {
        None => "stream",
        Some(bytes) => match std::str::from_utf8(bytes) {
            Ok(text) => text,
            Err(_) => "opaque bytes",
        },
    };
    Some(description)
}