1use std::{
2 fmt::Debug,
3 sync::{Arc, RwLock},
4};
5
6use oauth2::{
7 basic::{
8 BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
9 BasicTokenType,
10 },
11 reqwest::async_http_client,
12 AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
13 RefreshToken, StandardRevocableToken, TokenUrl,
14};
15use reqwest::{header::CONTENT_LENGTH, Method, Url};
16use serde::{
17 de::{value::BytesDeserializer, DeserializeOwned, IntoDeserializer},
18 Serialize,
19};
20use tracing::info;
21
22use crate::{
23 auth::{
24 AuthCodeFlow, AuthCodePkceFlow, AuthFlow, AuthenticationState, ClientCredsFlow, Scopes,
25 Token, Unauthenticated, UnknownFlow,
26 },
27 error::{Error, Result, SpotifyError},
28};
29
/// Spotify's OAuth 2.0 authorisation endpoint; used to build the URL returned by the
/// authorisation-code `new` constructors.
const AUTHORISATION_URL: &str = "https://accounts.spotify.com/authorize";

/// Spotify's OAuth 2.0 token endpoint, where codes, refresh tokens and client
/// credentials are exchanged for access tokens.
const TOKEN_URL: &str = "https://accounts.spotify.com/api/token";

/// Base URL of the Spotify Web API. Endpoint paths are appended verbatim, so they are
/// expected to begin with `/`.
pub(crate) const API_URL: &str = "https://api.spotify.com/v1";
33
34pub(crate) type OAuthClient = oauth2::Client<
35 BasicErrorResponse,
36 Token,
37 BasicTokenType,
38 BasicTokenIntrospectionResponse,
39 StandardRevocableToken,
40 BasicRevocationErrorResponse,
41>;
42
43pub type AuthCodeClient<A> = Client<A, AuthCodeFlow>;
45
46pub type AuthCodePkceClient<A> = Client<A, AuthCodePkceFlow>;
48
49pub type ClientCredsClient<A> = Client<A, ClientCredsFlow>;
51
52#[doc(hidden)]
53#[derive(Debug)]
54pub(crate) enum Body<P: Serialize = ()> {
55 Json(P),
56 File(Vec<u8>),
57}
58
59#[derive(Clone, Debug)]
64pub struct Client<A: AuthenticationState, F: AuthFlow> {
65 pub auto_refresh: bool,
70 pub(crate) auth_state: Arc<RwLock<A>>,
73 pub(crate) auth_flow: F,
76 pub(crate) oauth: OAuthClient,
78 pub(crate) http: reqwest::Client,
80}
81
82impl Client<Token, UnknownFlow> {
83 pub async fn from_refresh_token(
87 client_id: impl Into<String>,
88 client_secret: Option<&str>,
89 scopes: Option<Scopes>,
90 auto_refresh: bool,
91 refresh_token: String,
92 ) -> Result<Self> {
93 let client_id = ClientId::new(client_id.into());
94 let client_secret = client_secret.map(|s| ClientSecret::new(s.to_owned()));
95
96 let oauth_client = OAuthClient::new(
97 client_id,
98 client_secret,
99 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
100 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
101 );
102
103 let refresh_token = RefreshToken::new(refresh_token);
104 let mut req = oauth_client.exchange_refresh_token(&refresh_token);
105
106 if let Some(scopes) = scopes {
107 req = req.add_scopes(scopes.0);
108 }
109
110 let token = req.request_async(async_http_client).await?.set_timestamps();
111
112 Ok(Self {
113 auto_refresh,
114 auth_state: Arc::new(RwLock::new(token)),
115 auth_flow: UnknownFlow,
116 oauth: oauth_client,
117 http: reqwest::Client::new(),
118 })
119 }
120}
121
122impl<F: AuthFlow> Client<Token, F> {
123 pub fn token(&self) -> Arc<RwLock<Token>> {
128 self.auth_state.clone()
129 }
130
131 pub fn access_token(&self) -> Result<String> {
138 let token = self
139 .auth_state
140 .read()
141 .expect("The lock holding the token has been poisoned.");
142
143 Ok(token.access_token.secret().clone())
144 }
145
146 pub fn refresh_token(&self) -> Result<Option<String>> {
153 let token = self
154 .auth_state
155 .read()
156 .expect("The lock holding the token has been poisoned.");
157
158 let refresh_token = token.refresh_token.as_ref().map(|t| t.secret().clone());
159
160 Ok(refresh_token)
161 }
162
163 pub async fn exchange_refresh_token(&self) -> Result<()> {
166 let refresh_token = {
167 let lock = self.auth_state.read().unwrap_or_else(|e| e.into_inner());
168
169 let Some(refresh_token) = &lock.refresh_token else {
170 return Err(Error::RefreshUnavailable);
171 };
172
173 refresh_token.clone()
174 };
175
176 let token = self
177 .oauth
178 .exchange_refresh_token(&refresh_token)
179 .request_async(async_http_client)
180 .await?
181 .set_timestamps();
182
183 let mut lock = self
184 .auth_state
185 .write()
186 .expect("The lock holding the token has been poisoned.");
187 *lock = token;
188 Ok(())
189 }
190
191 pub(crate) async fn request<P: Serialize + Debug, T: DeserializeOwned>(
192 &self,
193 method: Method,
194 endpoint: String,
195 query: Option<P>,
196 body: Option<Body<P>>,
197 ) -> Result<T> {
198 let (token_expired, secret) = {
199 let lock = self
200 .auth_state
201 .read()
202 .expect("The lock holding the token has been poisoned.");
203
204 (lock.is_expired(), lock.access_token.secret().to_owned())
205 };
206
207 if token_expired {
208 if self.auto_refresh {
209 info!("The token has expired, attempting to refresh...");
210
211 self.exchange_refresh_token().await?;
212
213 let lock = self
214 .auth_state
215 .read()
216 .expect("The lock holding the token has been poisoned.");
217
218 info!("The token has been successfully refreshed. The new token will expire in {} seconds", lock.expires_in);
219 } else {
220 info!("The token has expired, automatic refresh is disabled.");
221 return Err(Error::ExpiredToken);
222 }
223 }
224
225 let mut req = {
226 self.http
227 .request(method, format!("{API_URL}{endpoint}"))
228 .bearer_auth(secret)
229 };
230
231 if let Some(q) = query {
232 req = req.query(&q);
233 }
234
235 if let Some(b) = body {
236 match b {
237 Body::Json(j) => req = req.json(&j),
238 Body::File(f) => req = req.body(f),
239 }
240 } else {
241 req = req.header(CONTENT_LENGTH, 0);
245 }
246
247 let req = req.build()?;
248 info!(headers = ?req.headers(), "{} request sent to {}", req.method(), req.url());
249
250 let res = self.http.execute(req).await?;
251
252 if res.status().is_success() {
253 let bytes = res.bytes().await?;
254
255 let deserialized = serde_json::from_slice::<T>(&bytes).or_else(|e| {
257 let de: BytesDeserializer<'_, serde::de::value::Error> =
260 bytes.as_ref().into_deserializer();
261
262 T::deserialize(de).map_err(|_| e)
265 });
266 match deserialized {
269 Ok(content) => Ok(content),
270 Err(err) => {
271 let body = std::str::from_utf8(&bytes).map_err(|_| Error::InvalidResponse)?;
272
273 tracing::error!(
274 %body,
275 "Failed to deserialize the response body into an object or Nil."
276 );
277
278 Err(Error::Deserialization {
279 source: err,
280 body: body.to_owned(),
281 })
282 }
283 }
284 } else {
285 Err(res.json::<SpotifyError>().await?.into())
286 }
287 }
288
289 pub(crate) async fn get<P: Serialize + Debug, T: DeserializeOwned>(
290 &self,
291 endpoint: String,
292 query: impl Into<Option<P>>,
293 ) -> Result<T> {
294 self.request(Method::GET, endpoint, query.into(), None)
295 .await
296 }
297
298 pub(crate) async fn post<P: Serialize + Debug, T: DeserializeOwned>(
299 &self,
300 endpoint: String,
301 body: impl Into<Option<Body<P>>>,
302 ) -> Result<T> {
303 self.request(Method::POST, endpoint, None, body.into())
304 .await
305 }
306
307 pub(crate) async fn put<P: Serialize + Debug, T: DeserializeOwned>(
308 &self,
309 endpoint: String,
310 body: impl Into<Option<Body<P>>>,
311 ) -> Result<T> {
312 self.request(Method::PUT, endpoint, None, body.into()).await
313 }
314
315 pub(crate) async fn delete<P: Serialize + Debug, T: DeserializeOwned>(
316 &self,
317 endpoint: String,
318 body: impl Into<Option<Body<P>>>,
319 ) -> Result<T> {
320 self.request(Method::DELETE, endpoint, None, body.into())
321 .await
322 }
323}
324
325impl AuthCodeClient<Unauthenticated> {
326 pub fn new<S>(
333 client_id: impl Into<String>,
334 client_secret: impl Into<String>,
335 scopes: S,
336 redirect_uri: RedirectUrl,
337 auto_refresh: bool,
338 ) -> (Self, Url)
339 where
340 S: Into<Scopes>,
341 {
342 let client_id = ClientId::new(client_id.into());
343 let client_secret = Some(ClientSecret::new(client_secret.into()));
344
345 let oauth = OAuthClient::new(
346 client_id,
347 client_secret,
348 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
349 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
350 )
351 .set_redirect_uri(redirect_uri);
352
353 let (auth_url, csrf_token) = oauth
354 .authorize_url(CsrfToken::new_random)
355 .add_scopes(scopes.into().0)
356 .url();
357
358 (
359 Client {
360 auto_refresh,
361 auth_state: Arc::new(RwLock::new(Unauthenticated)),
362 auth_flow: AuthCodeFlow { csrf_token },
363 oauth,
364 http: reqwest::Client::new(),
365 },
366 auth_url,
367 )
368 }
369
370 pub async fn authenticate(
375 self,
376 auth_code: impl Into<String>,
377 csrf_state: impl AsRef<str>,
378 ) -> Result<Client<Token, AuthCodeFlow>> {
379 let auth_code = auth_code.into().trim().to_owned();
380 let csrf_state = csrf_state.as_ref().trim();
381
382 if csrf_state != self.auth_flow.csrf_token.secret() {
383 return Err(Error::InvalidStateParameter);
384 }
385
386 let token = self
387 .oauth
388 .exchange_code(AuthorizationCode::new(auth_code))
389 .request_async(async_http_client)
390 .await?
391 .set_timestamps();
392
393 Ok(Client {
394 auto_refresh: self.auto_refresh,
395 auth_state: Arc::new(RwLock::new(token)),
396 auth_flow: self.auth_flow,
397 oauth: self.oauth,
398 http: self.http,
399 })
400 }
401}
402
403impl AuthCodePkceClient<Unauthenticated> {
404 pub fn new<T, S>(
411 client_id: T,
412 scopes: S,
413 redirect_uri: RedirectUrl,
414 auto_refresh: bool,
415 ) -> (Self, Url)
416 where
417 T: Into<String>,
418 S: Into<Scopes>,
419 {
420 let client_id = ClientId::new(client_id.into());
421
422 let oauth = OAuthClient::new(
423 client_id,
424 None,
425 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
426 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
427 )
428 .set_redirect_uri(redirect_uri);
429
430 let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
431
432 let (auth_url, csrf_token) = oauth
433 .authorize_url(CsrfToken::new_random)
434 .add_scopes(scopes.into().0)
435 .set_pkce_challenge(pkce_challenge)
436 .url();
437
438 (
439 Client {
440 auto_refresh,
441 auth_state: Arc::new(RwLock::new(Unauthenticated)),
442 auth_flow: AuthCodePkceFlow {
443 csrf_token,
444 pkce_verifier: Some(pkce_verifier),
445 },
446 oauth,
447 http: reqwest::Client::new(),
448 },
449 auth_url,
450 )
451 }
452
453 pub async fn authenticate(
458 mut self,
459 auth_code: impl Into<String>,
460 csrf_state: impl AsRef<str>,
461 ) -> Result<Client<Token, AuthCodePkceFlow>> {
462 let auth_code = auth_code.into().trim().to_owned();
463 let csrf_state = csrf_state.as_ref().trim();
464
465 if csrf_state != self.auth_flow.csrf_token.secret() {
466 return Err(Error::InvalidStateParameter);
467 }
468
469 let Some(pkce_verifier) = self.auth_flow.pkce_verifier.take() else {
470 tracing::error!(client = ?self, "No PKCE code verifier present when authenticating the client.");
473 return Err(Error::InvalidClientState);
474 };
475
476 let token = self
477 .oauth
478 .exchange_code(AuthorizationCode::new(auth_code))
479 .set_pkce_verifier(pkce_verifier)
480 .request_async(async_http_client)
481 .await?
482 .set_timestamps();
483
484 Ok(Client {
485 auto_refresh: self.auto_refresh,
486 auth_state: Arc::new(RwLock::new(token)),
487 auth_flow: self.auth_flow,
488 oauth: self.oauth,
489 http: self.http,
490 })
491 }
492}
493
494impl ClientCredsClient<Unauthenticated> {
495 pub async fn authenticate(
501 client_id: impl Into<String>,
502 client_secret: impl Into<String>,
503 ) -> Result<ClientCredsClient<Token>> {
504 let client_id = ClientId::new(client_id.into());
505 let client_secret = Some(ClientSecret::new(client_secret.into()));
506
507 let oauth = OAuthClient::new(
508 client_id,
509 client_secret,
510 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
511 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
512 );
513
514 let token = oauth
515 .exchange_client_credentials()
516 .request_async(async_http_client)
517 .await?
518 .set_timestamps();
519
520 Ok(Client {
521 auto_refresh: false,
522 auth_state: Arc::new(RwLock::new(token)),
523 auth_flow: ClientCredsFlow,
524 oauth,
525 http: reqwest::Client::new(),
526 })
527 }
528}
529
530impl AuthCodeClient<Token> {
531 pub async fn from_access_token(
537 client_id: impl Into<String>,
538 client_secret: impl Into<String>,
539 auto_refresh: bool,
540 token: Token,
541 ) -> Result<Self> {
542 let client_id = ClientId::new(client_id.into());
543 let client_secret = Some(ClientSecret::new(client_secret.into()));
545
546 let oauth_client = OAuthClient::new(
547 client_id,
548 client_secret,
549 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
550 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
551 );
552
553 let http = reqwest::Client::new();
554
555 let res = http
557 .get(format!("{API_URL}/recommendations/available-genre-seeds"))
558 .bearer_auth(token.secret())
559 .header(CONTENT_LENGTH, 0)
560 .send()
561 .await?;
562
563 if !res.status().is_success() {
564 return Err(res.json::<SpotifyError>().await?.into());
565 }
566
567 let auth_flow = AuthCodeFlow {
568 csrf_token: CsrfToken::new("not needed".to_owned()),
569 };
570
571 let auto_refresh = auto_refresh && token.refresh_token.is_some();
572
573 Ok(Self {
574 auto_refresh,
575 auth_state: Arc::new(RwLock::new(token)),
576 auth_flow,
577 oauth: oauth_client,
578 http,
579 })
580 }
581}
582
583impl AuthCodePkceClient<Token> {
584 pub async fn from_access_token(
590 client_id: impl Into<String>,
591 auto_refresh: bool,
592 token: Token,
593 ) -> Result<Self> {
594 let client_id = ClientId::new(client_id.into());
595
596 let oauth_client = OAuthClient::new(
597 client_id,
598 None,
599 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
600 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
601 );
602
603 let http = reqwest::Client::new();
604
605 let res = http
607 .get(format!("{API_URL}/recommendations/available-genre-seeds"))
608 .bearer_auth(token.secret())
609 .header(CONTENT_LENGTH, 0)
610 .send()
611 .await?;
612
613 if !res.status().is_success() {
614 return Err(res.json::<SpotifyError>().await?.into());
615 }
616
617 let auth_flow = AuthCodePkceFlow {
618 csrf_token: CsrfToken::new("not needed".to_owned()),
619 pkce_verifier: None,
620 };
621
622 let auto_refresh = auto_refresh && token.refresh_token.is_some();
623
624 Ok(Self {
625 auto_refresh,
626 auth_state: Arc::new(RwLock::new(token)),
627 auth_flow,
628 oauth: oauth_client,
629 http,
630 })
631 }
632}
633
634impl ClientCredsClient<Token> {
635 pub async fn from_access_token(
641 client_id: impl Into<String>,
642 client_secret: impl Into<String>,
643 token: Token,
644 ) -> Result<Self> {
645 let client_id = ClientId::new(client_id.into());
646 let client_secret = Some(ClientSecret::new(client_secret.into()));
647
648 let oauth_client = OAuthClient::new(
649 client_id,
650 client_secret,
651 AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
652 Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
653 );
654
655 let http = reqwest::Client::new();
656
657 let res = http
659 .get(format!("{API_URL}/recommendations/available-genre-seeds"))
660 .bearer_auth(token.secret())
661 .header(CONTENT_LENGTH, 0)
662 .send()
663 .await?;
664
665 if !res.status().is_success() {
666 return Err(res.json::<SpotifyError>().await?.into());
667 }
668
669 Ok(Self {
670 auto_refresh: false,
671 auth_state: Arc::new(RwLock::new(token)),
672 auth_flow: ClientCredsFlow,
673 oauth: oauth_client,
674 http,
675 })
676 }
677}