use super::{AccessToken, ClientCredentials};
use crate::model::*;
use crate::CLIENT;
use lazy_static::lazy_static;
use rand::Rng;
use reqwest::Url;
use serde::Deserialize;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Display, Formatter};
use std::time::Instant;
use std::{error, str};
use tokio::sync::Mutex;
lazy_static! {
/// OAuth `state` values issued by `get_authorization_url` that have not yet been consumed.
///
/// `get_authorization_url` inserts each new state here, and `AuthCodeFlow::from_redirect`
/// removes it when a redirect presents it, so every state is accepted at most once.
///
/// NOTE(review): within this file, a state is only removed when a matching redirect arrives.
/// States from abandoned authorizations therefore stay here for the life of the process.
static ref VALID_STATES: Mutex<HashSet<String>> = Mutex::new(HashSet::new());
}
/// Number of characters in every generated OAuth `state` value.
const STATE_LEN: usize = 16;
/// Alphabet for `state` values: the RFC 3986 "unreserved" characters
/// (ALPHA / DIGIT / `-` / `.` / `_` / `~`), 66 symbols in total.
const STATE_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";

/// Produces a fresh `STATE_LEN`-character `state` string.
///
/// Each character is picked independently from `STATE_CHARS` using the thread-local RNG,
/// which gives roughly 16 * log2(66) ≈ 96 bits of randomness per state.
fn random_state() -> String {
    let mut generator = rand::thread_rng();
    (0..STATE_LEN)
        .map(|_| {
            let pick = generator.gen_range(0, STATE_CHARS.len());
            char::from(STATE_CHARS[pick])
        })
        .collect()
}
/// An OAuth scope that can be requested when asking a Spotify user for authorization.
///
/// Use [`Scope::as_str`] to get the identifier that goes into the authorization URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Scope {
    UgcImageUpload,
    UserReadPlaybackState,
    UserModifyPlaybackState,
    UserReadCurrentlyPlaying,
    Streaming,
    AppRemoteControl,
    UserReadEmail,
    UserReadPrivate,
    PlaylistReadCollaborative,
    PlaylistModifyPublic,
    PlaylistReadPrivate,
    PlaylistModifyPrivate,
    UserLibraryModify,
    UserLibraryRead,
    UserTopRead,
    UserReadRecentlyPlayed,
    UserFollowRead,
    UserFollowModify,
}

impl Scope {
    /// Returns this scope's kebab-case identifier, as placed (space-separated) in the `scope`
    /// parameter of the authorization URL.
    pub fn as_str(self) -> &'static str {
        match self {
            // Images
            Scope::UgcImageUpload => "ugc-image-upload",
            // Playback and Spotify Connect
            Scope::UserReadPlaybackState => "user-read-playback-state",
            Scope::UserModifyPlaybackState => "user-modify-playback-state",
            Scope::UserReadCurrentlyPlaying => "user-read-currently-playing",
            Scope::Streaming => "streaming",
            Scope::AppRemoteControl => "app-remote-control",
            // User profile
            Scope::UserReadEmail => "user-read-email",
            Scope::UserReadPrivate => "user-read-private",
            // Playlists
            Scope::PlaylistReadCollaborative => "playlist-read-collaborative",
            Scope::PlaylistModifyPublic => "playlist-modify-public",
            Scope::PlaylistReadPrivate => "playlist-read-private",
            Scope::PlaylistModifyPrivate => "playlist-modify-private",
            // Library
            Scope::UserLibraryModify => "user-library-modify",
            Scope::UserLibraryRead => "user-library-read",
            // Listening history
            Scope::UserTopRead => "user-top-read",
            Scope::UserReadRecentlyPlayed => "user-read-recently-played",
            // Follows
            Scope::UserFollowRead => "user-follow-read",
            Scope::UserFollowModify => "user-follow-modify",
        }
    }
}
/// Builds the Spotify authorization URL to send the user to.
///
/// A fresh `state` value is generated and recorded in `VALID_STATES`.
/// `AuthCodeFlow::from_redirect` will accept that state exactly once.
///
/// * `client_id` - the application's client ID, sent as `client_id`.
/// * `scopes` - scopes to request, sent space-separated as `scope`.
/// * `force_approve` - sent as Spotify's `show_dialog` parameter (`"true"` / `"false"`).
/// * `redirect_uri` - sent as `redirect_uri`, where Spotify redirects after the user responds.
pub async fn get_authorization_url(
    client_id: &str,
    scopes: &[Scope],
    force_approve: bool,
    redirect_uri: &str,
) -> String {
    // Hold the lock for the whole call so that checking for and inserting the new state
    // cannot interleave with another caller.
    let mut valid_states = VALID_STATES.lock().await;

    // Regenerate on collision so two outstanding authorizations never share a state.
    let state = loop {
        let candidate = random_state();
        if !valid_states.contains(&candidate) {
            break candidate;
        }
    };

    let scope = scopes
        .iter()
        .map(|scope| scope.as_str())
        .collect::<Vec<_>>()
        .join(" ");

    let url = Url::parse_with_params(
        "https://accounts.spotify.com/authorize",
        &[
            ("response_type", "code"),
            ("state", &state),
            ("client_id", client_id),
            ("scope", &scope),
            ("show_dialog", if force_approve { "true" } else { "false" }),
            ("redirect_uri", redirect_uri),
        ],
    )
    // The base is a constant absolute URL, so parsing cannot fail.
    .expect("the Spotify authorize endpoint is a valid absolute URL")
    // `Display` yields the serialized URL; this avoids `Url::into_string`, which is deprecated
    // in newer `url` releases.
    .to_string();

    valid_states.insert(state);
    url
}
/// An authorization-code-flow session.
///
/// Holds the application's client credentials, the user's refresh token and a cached access
/// token. [`AuthCodeFlow::send`] refreshes the cached token on demand.
#[derive(Debug)]
pub struct AuthCodeFlow<'cc> {
    /// Client credentials used for HTTP basic auth against the token endpoint.
    credentials: &'cc ClientCredentials,
    /// Refresh token exchanged for new access tokens.
    refresh_token: String,
    /// Most recently obtained access token; reused by `send` until its `expires` instant.
    cache: Mutex<AccessToken>,
}

impl<'cc> AuthCodeFlow<'cc> {
    /// Completes the authorization code flow from the URL Spotify redirected the user to.
    ///
    /// The steps are:
    /// 1. Check the `state` query parameter against the states issued by
    ///    `get_authorization_url`, consuming it.
    /// 2. Check the redirect for an `error` parameter.
    /// 3. Exchange the `code` parameter at the token endpoint for a refresh token and an
    ///    access token.
    ///
    /// # Errors
    ///
    /// * [`FromRedirectError::InvalidRedirect`] if the URL cannot be parsed, has no query
    ///   string, or its `state` is missing or unknown, or its `code` is missing.
    /// * [`FromRedirectError::SpotifyError`] if the redirect carries an `error` parameter, or
    ///   the token endpoint responds with a non-success status.
    /// * [`FromRedirectError::HttpError`] if the token request fails in transport.
    /// * [`FromRedirectError::ParseError`] if a response body is not JSON of the expected shape.
    pub async fn from_redirect(
        credentials: &'cc ClientCredentials,
        redirected_to: &str,
    ) -> Result<AuthCodeFlow<'cc>, FromRedirectError> {
        let url = Url::parse(redirected_to).map_err(|_| FromRedirectError::InvalidRedirect)?;
        let pairs: HashMap<_, _> = url.query_pairs().collect();

        // Only states issued by `get_authorization_url` are accepted. Removing the state makes
        // it single-use, so a replayed redirect is rejected.
        let state = pairs
            .get("state")
            .ok_or(FromRedirectError::InvalidRedirect)?;
        if !VALID_STATES.lock().await.remove(&state[..]) {
            return Err(FromRedirectError::InvalidRedirect);
        }

        // Spotify reported a failure in the redirect itself.
        if let Some(error) = pairs.get("error") {
            return Err(FromRedirectError::SpotifyError(SpotifyRedirectError::from(
                error.to_string(),
            )));
        }

        let code = pairs
            .get("code")
            .ok_or(FromRedirectError::InvalidRedirect)?;

        // The token request must repeat the redirect URI used for authorization. Recover it by
        // cutting the query string off the redirect URL.
        let full_url = url.as_str();
        let query_start = full_url
            .find('?')
            .ok_or(FromRedirectError::InvalidRedirect)?;
        let orig_url = &full_url[..query_start];

        let response = CLIENT
            .post("https://accounts.spotify.com/api/token")
            .form(&[
                ("grant_type", "authorization_code"),
                ("code", code),
                ("redirect_uri", orig_url),
            ])
            .basic_auth(&credentials.id, Some(&credentials.secret))
            .send()
            .await?;
        let status = response.status();
        let text = response.text().await?;

        if !status.is_success() {
            // OAuth 2.0 token-endpoint errors carry `error` and an optional `error_description`.
            // Report them as a Spotify error instead of failing to parse them as a token.
            #[derive(Deserialize)]
            struct TokenError {
                error: String,
                error_description: Option<String>,
            }
            let TokenError {
                error,
                error_description,
            } = serde_json::from_str(&text)?;
            let message = match error_description {
                Some(description) => format!("{}: {}", error, description),
                None => error,
            };
            return Err(FromRedirectError::SpotifyError(SpotifyRedirectError::from(
                message,
            )));
        }

        #[derive(Deserialize)]
        struct Response {
            refresh_token: String,
            #[serde(flatten)]
            token: AccessToken,
        }
        let Response {
            refresh_token,
            token,
        } = serde_json::from_str(&text)?;

        Ok(Self {
            credentials,
            refresh_token,
            cache: Mutex::new(token),
        })
    }

    /// Creates a session from a previously obtained refresh token.
    ///
    /// The cached access token starts as `AccessToken::default()`.
    /// NOTE(review): presumably that default is already expired, so the first `send` refreshes
    /// it. Confirm against `AccessToken`'s `Default` impl.
    pub fn from_refresh(credentials: &'cc ClientCredentials, refresh_token: String) -> Self {
        Self {
            credentials,
            refresh_token,
            cache: Mutex::default(),
        }
    }

    /// Returns the client credentials this session authenticates with.
    pub fn get_credentials(&self) -> &ClientCredentials {
        self.credentials
    }

    /// Returns the refresh token. Persist it to resume later via
    /// [`AuthCodeFlow::from_refresh`].
    pub fn get_refresh_token(&self) -> &str {
        &self.refresh_token
    }

    /// Consumes the session and returns its refresh token.
    pub fn into_refresh_token(self) -> String {
        self.refresh_token
    }

    /// Returns an access token.
    ///
    /// The cached token is returned while `Instant::now()` is before its `expires`. Otherwise
    /// a new token is requested with the refresh token, and the result is cached and returned.
    ///
    /// The cache lock is released during the refresh request. Concurrent callers that find
    /// the cache expired may therefore each issue their own refresh; whichever writes last
    /// wins.
    /// NOTE(review): `self.refresh_token` is never updated here, so any new refresh token in
    /// the refresh response is not retained.
    ///
    /// # Errors
    ///
    /// * Transport failures and unparsable bodies are returned as the corresponding
    ///   `EndpointError` variants.
    /// * A non-success status is returned as `EndpointError::SpotifyError` with the parsed body.
    pub async fn send(&self) -> Result<AccessToken, EndpointError<AuthenticationError>> {
        {
            let cache = self.cache.lock().await;
            if Instant::now() < cache.expires {
                return Ok(cache.clone());
            }
        }

        let response = CLIENT
            .post("https://accounts.spotify.com/api/token")
            .form(&[
                ("grant_type", "refresh_token"),
                ("refresh_token", &self.refresh_token),
            ])
            .basic_auth(&self.credentials.id, Some(&self.credentials.secret))
            .send()
            .await?;
        let status = response.status();
        let text = response.text().await?;
        if !status.is_success() {
            return Err(EndpointError::SpotifyError(serde_json::from_str(&text)?));
        }

        let token = serde_json::from_str::<AccessToken>(&text)?;
        *self.cache.lock().await = token.clone();
        Ok(token)
    }
}
/// An error reported by Spotify during the authorization redirect flow.
///
/// An example is the value of the `error` query parameter on the redirect URL. The wrapped
/// string is used verbatim as the error's `Display` text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpotifyRedirectError(pub String);

impl From<String> for SpotifyRedirectError {
    fn from(message: String) -> Self {
        SpotifyRedirectError(message)
    }
}

impl Display for SpotifyRedirectError {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let SpotifyRedirectError(message) = self;
        f.write_str(message)
    }
}

impl error::Error for SpotifyRedirectError {}
/// Error returned when completing the authorization code flow from a redirect URL.
#[derive(Debug)]
pub enum FromRedirectError {
    /// The redirect URL was malformed, or its `state` or `code` parameter was missing or not
    /// recognised.
    InvalidRedirect,
    /// A JSON body from Spotify could not be deserialized into the expected type.
    ParseError(serde_json::error::Error),
    /// Spotify reported an error.
    SpotifyError(SpotifyRedirectError),
    /// The HTTP request to Spotify failed.
    HttpError(reqwest::Error),
}

impl From<serde_json::error::Error> for FromRedirectError {
    fn from(e: serde_json::error::Error) -> Self {
        Self::ParseError(e)
    }
}

impl From<reqwest::Error> for FromRedirectError {
    fn from(e: reqwest::Error) -> Self {
        Self::HttpError(e)
    }
}

impl Display for FromRedirectError {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            Self::InvalidRedirect => f.write_str("Invalid redirect URL"),
            Self::ParseError(e) => write!(f, "{}", e),
            Self::SpotifyError(e) => write!(f, "{}", e),
            Self::HttpError(e) => write!(f, "{}", e),
        }
    }
}

impl error::Error for FromRedirectError {
    /// Exposes the wrapped error for every variant that has one.
    ///
    /// The match is exhaustive on purpose, so adding a variant forces a decision here.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        match self {
            Self::InvalidRedirect => None,
            // Previously fell into the `_ => None` arm even though it wraps a real error.
            Self::ParseError(e) => Some(e),
            Self::SpotifyError(e) => Some(e),
            Self::HttpError(e) => Some(e),
        }
    }
}