1#![forbid(unsafe_code)]
28#![deny(rust_2018_idioms)]
29#![warn(missing_docs, clippy::pedantic)]
30#![allow(
31 clippy::module_name_repetitions,
32 clippy::non_ascii_literal,
33 clippy::items_after_statements,
34 clippy::filter_map
35)]
36#![cfg_attr(test, allow(clippy::float_cmp))]
37
38use std::collections::HashMap;
39use std::env::{self, VarError};
40use std::error::Error as StdError;
41use std::ffi::OsStr;
42use std::fmt::{self, Display, Formatter};
43use std::time::{Duration, Instant};
44
45use reqwest::{header, RequestBuilder, Url};
46use serde::de::DeserializeOwned;
47use serde::{Deserialize, Serialize};
48use tokio::sync::{Mutex, MutexGuard};
49
50pub use authorization_url::*;
51pub use endpoints::*;
52pub use isocountry::CountryCode;
54pub use isolanguage_1::LanguageCode;
56pub use model::*;
57
58mod authorization_url;
59pub mod endpoints;
60pub mod model;
61mod util;
62
/// A Spotify Web API client.
///
/// Holds the application's credentials, an HTTP client, and a cached access
/// token that is renewed on demand when it has expired.
// NOTE(review): the derived `Debug` also prints `credentials`, which includes
// the client secret — confirm that is acceptable wherever clients are logged.
#[derive(Debug)]
pub struct Client {
    /// The application credentials, sent via HTTP basic auth when requesting
    /// access tokens.
    pub credentials: ClientCredentials,
    // HTTP client used for both token requests and API requests.
    client: reqwest::Client,
    // Cached access token plus optional refresh token. The async mutex stays
    // locked while a replacement token is being fetched.
    cache: Mutex<AccessToken>,
    // When true, requests and responses are dumped to stderr.
    // NOTE(review): no setter is visible in this chunk — confirm how it is enabled.
    debug: bool,
}
79
80impl Client {
81 #[must_use]
83 pub fn new(credentials: ClientCredentials) -> Self {
84 Self {
85 credentials,
86 client: reqwest::Client::new(),
87 cache: Mutex::new(AccessToken::new(None)),
88 debug: false,
89 }
90 }
91 #[must_use]
93 pub fn with_refresh(credentials: ClientCredentials, refresh_token: String) -> Self {
94 Self {
95 credentials,
96 client: reqwest::Client::new(),
97 cache: Mutex::new(AccessToken::new(Some(refresh_token))),
98 debug: false,
99 }
100 }
101 pub async fn refresh_token(&self) -> Option<String> {
103 self.cache.lock().await.refresh_token.clone()
104 }
105 pub async fn set_refresh_token(&self, refresh_token: Option<String>) {
107 self.cache.lock().await.refresh_token = refresh_token;
108 }
109 pub async fn current_access_token(&self) -> (String, Instant) {
111 let cache = self.cache.lock().await;
112 (cache.token.clone(), cache.expires)
113 }
114 pub async fn set_current_access_token(&self, token: String, expires: Instant) {
117 let mut cache = self.cache.lock().await;
118 cache.token = token;
119 cache.expires = expires;
120 }
121
122 async fn token_request(&self, params: TokenRequest<'_>) -> Result<AccessToken, Error> {
123 let request = self
124 .client
125 .post("https://accounts.spotify.com/api/token")
126 .basic_auth(&self.credentials.id, Some(&self.credentials.secret))
127 .form(¶ms)
128 .build()?;
129
130 if self.debug {
131 dbg!(&request, body_str(&request));
132 }
133
134 let response = self.client.execute(request).await?;
135 let status = response.status();
136 let text = response.text().await?;
137 if !status.is_success() {
138 if self.debug {
139 eprintln!(
140 "Authentication failed ({}). Response body is '{}'",
141 status, text
142 );
143 }
144 return Err(Error::Auth(serde_json::from_str(&text)?));
145 }
146
147 if self.debug {
148 dbg!(status);
149 eprintln!("Authentication response body is '{}'", text);
150 }
151
152 Ok(serde_json::from_str(&text)?)
153 }
154
155 pub async fn redirected(&self, url: &str, state: &str) -> Result<(), RedirectedError> {
166 let url = Url::parse(url)?;
167
168 let pairs: HashMap<_, _> = url.query_pairs().collect();
169
170 if pairs
171 .get("state")
172 .map_or(true, |url_state| url_state != state)
173 {
174 return Err(RedirectedError::IncorrectState);
175 }
176
177 if let Some(error) = pairs.get("error") {
178 return Err(RedirectedError::AuthFailed(error.to_string()));
179 }
180
181 let code = pairs
182 .get("code")
183 .ok_or_else(|| RedirectedError::AuthFailed(String::new()))?;
184
185 let token = self
186 .token_request(TokenRequest::AuthorizationCode {
187 code: &*code,
188 redirect_uri: &url[..url::Position::AfterPath],
189 })
190 .await?;
191 *self.cache.lock().await = token;
192
193 Ok(())
194 }
195
196 async fn access_token(&self) -> Result<MutexGuard<'_, AccessToken>, Error> {
197 let mut cache = self.cache.lock().await;
198 if Instant::now() >= cache.expires {
199 *cache = match cache.refresh_token.take() {
200 Some(refresh_token) => {
202 let mut token = self
203 .token_request(TokenRequest::RefreshToken {
204 refresh_token: &refresh_token,
205 })
206 .await?;
207 token.refresh_token = Some(refresh_token);
208 token
209 }
210 None => self.token_request(TokenRequest::ClientCredentials).await?,
212 }
213 }
214 Ok(cache)
215 }
216
217 async fn send_text(&self, request: RequestBuilder) -> Result<Response<String>, Error> {
218 let request = request
219 .bearer_auth(&self.access_token().await?.token)
220 .build()?;
221
222 if self.debug {
223 dbg!(&request, body_str(&request));
224 }
225
226 let response = loop {
227 let response = self.client.execute(request.try_clone().unwrap()).await?;
228 if response.status() != 429 {
229 break response;
230 }
231 let wait = response
232 .headers()
233 .get(header::RETRY_AFTER)
234 .and_then(|val| val.to_str().ok())
235 .and_then(|secs| secs.parse::<u64>().ok());
236 let wait = wait.unwrap_or(2);
239 tokio::time::sleep(std::time::Duration::from_secs(wait)).await;
240 };
241 let status = response.status();
242 let cache_control = Duration::from_secs(
243 response
244 .headers()
245 .get_all(header::CACHE_CONTROL)
246 .iter()
247 .filter_map(|value| value.to_str().ok())
248 .flat_map(|value| value.split(|c| c == ','))
249 .find_map(|value| {
250 let mut parts = value.trim().splitn(2, '=');
251 if parts.next().unwrap().eq_ignore_ascii_case("max-age") {
252 parts.next().and_then(|max| max.parse::<u64>().ok())
253 } else {
254 None
255 }
256 })
257 .unwrap_or_default(),
258 );
259
260 let data = response.text().await?;
261 if !status.is_success() {
262 if self.debug {
263 eprintln!("Failed ({}). Response body is '{}'", status, data);
264 }
265 return Err(Error::Endpoint(serde_json::from_str(&data)?));
266 }
267
268 if self.debug {
269 dbg!(status);
270 eprintln!("Response body is '{}'", data);
271 }
272
273 Ok(Response {
274 data,
275 expires: Instant::now() + cache_control,
276 })
277 }
278
279 async fn send_empty(&self, request: RequestBuilder) -> Result<(), Error> {
280 self.send_text(request).await?;
281 Ok(())
282 }
283
284 async fn send_opt_json<T: DeserializeOwned>(
285 &self,
286 request: RequestBuilder,
287 ) -> Result<Response<Option<T>>, Error> {
288 let res = self.send_text(request).await?;
289 Ok(Response {
290 data: if res.data.is_empty() {
291 None
292 } else {
293 serde_json::from_str(&res.data)?
294 },
295 expires: res.expires,
296 })
297 }
298
299 async fn send_json<T: DeserializeOwned>(
300 &self,
301 request: RequestBuilder,
302 ) -> Result<Response<T>, Error> {
303 let res = self.send_text(request).await?;
304 Ok(Response {
305 data: serde_json::from_str(&res.data)?,
306 expires: res.expires,
307 })
308 }
309
310 async fn send_snapshot_id(&self, request: RequestBuilder) -> Result<String, Error> {
311 #[derive(Deserialize)]
312 struct SnapshotId {
313 snapshot_id: String,
314 }
315 Ok(self
316 .send_json::<SnapshotId>(request)
317 .await?
318 .data
319 .snapshot_id)
320 }
321}
322
/// Data returned by an API endpoint, together with how long it may be cached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Response<T> {
    /// The response payload.
    pub data: T,
    /// The instant until which the response may be reused: the time it was
    /// received plus the `max-age` of its `Cache-Control` header.
    pub expires: Instant,
}

impl<T> Response<T> {
    /// Transforms the payload with `f`, keeping the same expiry time.
    #[must_use]
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Response<U> {
        let Response { data, expires } = self;
        Response {
            data: f(data),
            expires,
        }
    }
}
341
/// The client ID and client secret of a registered Spotify application.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
    /// The application's client ID.
    pub id: String,
    /// The application's client secret.
    pub secret: String,
}

impl ClientCredentials {
    /// Reads the credentials from the two named environment variables.
    ///
    /// # Errors
    ///
    /// Returns the [`VarError`] from [`env::var`] when a variable is unset or
    /// not valid Unicode. The ID variable is checked first.
    pub fn from_env_vars<I: AsRef<OsStr>, S: AsRef<OsStr>>(
        client_id: I,
        client_secret: S,
    ) -> Result<Self, VarError> {
        let id = env::var(client_id)?;
        let secret = env::var(client_secret)?;
        Ok(Self { id, secret })
    }

    /// Reads the credentials from the `CLIENT_ID` and `CLIENT_SECRET`
    /// environment variables.
    ///
    /// # Errors
    ///
    /// Same as [`from_env_vars`](Self::from_env_vars).
    pub fn from_env() -> Result<Self, VarError> {
        const ID_VAR: &str = "CLIENT_ID";
        const SECRET_VAR: &str = "CLIENT_SECRET";
        Self::from_env_vars(ID_VAR, SECRET_VAR)
    }
}
402
/// Error returned by [`Client::redirected`].
#[derive(Debug)]
pub enum RedirectedError {
    /// The redirect URL could not be parsed.
    InvalidUrl(url::ParseError),
    /// The `state` query parameter was missing or did not match the expected
    /// value.
    IncorrectState,
    /// The redirect reported a failure.
    ///
    /// Holds the value of the URL's `error` query parameter, or an empty
    /// string when the URL had neither `error` nor `code`.
    AuthFailed(String),
    /// Exchanging the authorization code for an access token failed.
    Token(Error),
}
417
418impl From<url::ParseError> for RedirectedError {
419 fn from(error: url::ParseError) -> Self {
420 Self::InvalidUrl(error)
421 }
422}
423impl From<Error> for RedirectedError {
424 fn from(error: Error) -> Self {
425 Self::Token(error)
426 }
427}
428
429impl Display for RedirectedError {
430 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
431 match self {
432 Self::InvalidUrl(_) => f.write_str("malformed redirect URL"),
433 Self::IncorrectState => f.write_str("state parameter not found or is incorrect"),
434 Self::AuthFailed(_) => f.write_str("authorization failed"),
435 Self::Token(e) => e.fmt(f),
436 }
437 }
438}
439
440impl StdError for RedirectedError {
441 fn source(&self) -> Option<&(dyn StdError + 'static)> {
442 Some(match self {
443 Self::InvalidUrl(e) => e,
444 Self::Token(e) => e,
445 _ => return None,
446 })
447 }
448}
449
450#[derive(Debug, Serialize)]
451#[serde(tag = "grant_type", rename_all = "snake_case")]
452enum TokenRequest<'a> {
453 RefreshToken {
454 refresh_token: &'a String,
455 },
456 ClientCredentials,
457 AuthorizationCode {
458 code: &'a str,
459 redirect_uri: &'a str,
460 },
461}
462
/// An access token as parsed from Spotify's token endpoint response, together
/// with the refresh token (if any) used to renew it.
#[derive(Debug, Deserialize)]
struct AccessToken {
    /// The bearer token, read from the `access_token` field.
    #[serde(rename = "access_token")]
    token: String,
    /// When the token stops being valid, built from the `expires_in` field by
    /// `util::deserialize_instant_seconds`.
    // NOTE(review): presumably `expires_in` is seconds from now — confirm in `util`.
    #[serde(
        rename = "expires_in",
        deserialize_with = "util::deserialize_instant_seconds"
    )]
    expires: Instant,
    /// Refresh token, if the response included one (`None` when absent).
    #[serde(default)]
    refresh_token: Option<String>,
}
475
476impl AccessToken {
477 fn new(refresh_token: Option<String>) -> Self {
478 Self {
479 token: String::new(),
480 expires: Instant::now() - Duration::from_secs(1),
481 refresh_token,
482 }
483 }
484}
485
486fn body_str(req: &reqwest::Request) -> Option<&str> {
488 req.body().map(|body| {
489 body.as_bytes().map_or("stream", |bytes| {
490 std::str::from_utf8(bytes).unwrap_or("opaque bytes")
491 })
492 })
493}