spotify_rs/
client.rs

1use std::{
2    fmt::Debug,
3    sync::{Arc, RwLock},
4};
5
6use oauth2::{
7    basic::{
8        BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
9        BasicTokenType,
10    },
11    reqwest::async_http_client,
12    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
13    RefreshToken, StandardRevocableToken, TokenUrl,
14};
15use reqwest::{header::CONTENT_LENGTH, Method, Url};
16use serde::{
17    de::{value::BytesDeserializer, DeserializeOwned, IntoDeserializer},
18    Serialize,
19};
20use tracing::info;
21
22use crate::{
23    auth::{
24        AuthCodeFlow, AuthCodePkceFlow, AuthFlow, AuthenticationState, ClientCredsFlow, Scopes,
25        Token, Unauthenticated, UnknownFlow,
26    },
27    error::{Error, Result, SpotifyError},
28};
29
/// Spotify's OAuth2 authorisation endpoint; users are sent here to grant access.
const AUTHORISATION_URL: &str = "https://accounts.spotify.com/authorize";
/// Spotify's OAuth2 token endpoint, used for the code, refresh-token and
/// client-credentials exchanges.
const TOKEN_URL: &str = "https://accounts.spotify.com/api/token";
/// Base URL of the Spotify Web API; request endpoints are appended directly to it.
pub(crate) const API_URL: &str = "https://api.spotify.com/v1";
33
/// The `oauth2` client used by every [`Client`].
///
/// Token responses are deserialized into the crate's own [`Token`] type; every
/// other parameter uses `oauth2`'s basic/standard implementations.
pub(crate) type OAuthClient = oauth2::Client<
    BasicErrorResponse,
    Token,
    BasicTokenType,
    BasicTokenIntrospectionResponse,
    StandardRevocableToken,
    BasicRevocationErrorResponse,
>;
42
/// A client created using the Authorisation Code Flow.
///
/// `A` is the authentication state, e.g. [`Unauthenticated`] before
/// authenticating or [`Token`] afterwards.
pub type AuthCodeClient<A> = Client<A, AuthCodeFlow>;

/// A client created using the Authorisation Code with PKCE Flow.
///
/// `A` is the authentication state, e.g. [`Unauthenticated`] before
/// authenticating or [`Token`] afterwards.
pub type AuthCodePkceClient<A> = Client<A, AuthCodePkceFlow>;

/// A client created using the Client Credentials Flow.
///
/// `A` is the authentication state, e.g. [`Unauthenticated`] before
/// authenticating or [`Token`] afterwards.
pub type ClientCredsClient<A> = Client<A, ClientCredsFlow>;
51
/// The body attached to an API request sent by `Client::request`.
///
/// When no body is given at all, `request` instead sends an explicit
/// `Content-Length: 0` header.
#[doc(hidden)]
#[derive(Debug)]
pub(crate) enum Body<P: Serialize = ()> {
    /// A value serialized as the JSON request body.
    Json(P),
    /// Raw bytes sent as-is as the request body.
    File(Vec<u8>),
}
58
/// The client which handles the authentication and all the Spotify API requests.
///
/// It is recommended to use one of the following: [`AuthCodeClient`], [`AuthCodePkceClient`]
/// or [`ClientCredsClient`], depending on the chosen authentication flow.
#[derive(Clone, Debug)]
pub struct Client<A: AuthenticationState, F: AuthFlow> {
    /// Dictates whether or not the client will request a new token when the
    /// current one is about to expire.
    ///
    /// It will check if the token has expired in every request.
    pub auto_refresh: bool,
    // This is used for the typestate pattern, to differentiate an authenticated
    // client from an unauthenticated one, but it also holds the Token.
    // It sits behind an `Arc`, so clones of the client share the same token:
    // a refresh performed through one clone is visible to all of them.
    pub(crate) auth_state: Arc<RwLock<A>>,
    // This is used for the typestate pattern to differentiate between different
    // authorisation flows, as well as hold the CSRF/PKCE verifiers.
    pub(crate) auth_flow: F,
    // The OAuth2 client, used for the token exchanges.
    pub(crate) oauth: OAuthClient,
    // The HTTP client, used for the Web API requests.
    pub(crate) http: reqwest::Client,
}
81
82impl Client<Token, UnknownFlow> {
83    /// Create a new authenticated and authorised client from a refresh token.
84    ///
85    /// This method will fail if the refresh token is invalid or a new one cannot be obtained.
86    pub async fn from_refresh_token(
87        client_id: impl Into<String>,
88        client_secret: Option<&str>,
89        scopes: Option<Scopes>,
90        auto_refresh: bool,
91        refresh_token: String,
92    ) -> Result<Self> {
93        let client_id = ClientId::new(client_id.into());
94        let client_secret = client_secret.map(|s| ClientSecret::new(s.to_owned()));
95
96        let oauth_client = OAuthClient::new(
97            client_id,
98            client_secret,
99            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
100            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
101        );
102
103        let refresh_token = RefreshToken::new(refresh_token);
104        let mut req = oauth_client.exchange_refresh_token(&refresh_token);
105
106        if let Some(scopes) = scopes {
107            req = req.add_scopes(scopes.0);
108        }
109
110        let mut token = req.request_async(async_http_client).await?.set_timestamps();
111        if token.refresh_token.is_none() {
112            // "When a refresh token is not returned, continue using the existing token."
113            // https://developer.spotify.com/documentation/web-api/tutorials/refreshing-tokens
114            token.refresh_token = Some(refresh_token);
115        }
116
117        Ok(Self {
118            auto_refresh,
119            auth_state: Arc::new(RwLock::new(token)),
120            auth_flow: UnknownFlow,
121            oauth: oauth_client,
122            http: reqwest::Client::new(),
123        })
124    }
125}
126
127impl<F: AuthFlow> Client<Token, F> {
128    /// Get a reference to the client's token.
129    ///
130    /// Please note that the [RwLock] used here is **not** async-aware, and thus
131    /// the read/write guard should not be held across await points.
132    pub fn token(&self) -> Arc<RwLock<Token>> {
133        self.auth_state.clone()
134    }
135
136    /// Get the access token secret as an owned (cloned) string.
137    /// If you only need a reference, you can use [`token`](Self::token)
138    /// yourself and get a reference from the returned [RwLock].
139    ///
140    /// This method will fail if the `RwLock` that holds the token has
141    /// been poisoned.
142    pub fn access_token(&self) -> Result<String> {
143        let token = self
144            .auth_state
145            .read()
146            .expect("The lock holding the token has been poisoned.");
147
148        Ok(token.access_token.secret().clone())
149    }
150
151    /// Get the refresh token secret as an owned (cloned) string.
152    /// If you only need a reference, you can use [`token`](Self::token)
153    /// yourself and get a reference from the returned [RwLock].
154    ///
155    /// This method will fail if the `RwLock` that holds the token has
156    /// been poisoned.
157    pub fn refresh_token(&self) -> Result<Option<String>> {
158        let token = self
159            .auth_state
160            .read()
161            .expect("The lock holding the token has been poisoned.");
162
163        let refresh_token = token.refresh_token.as_ref().map(|t| t.secret().clone());
164
165        Ok(refresh_token)
166    }
167
168    /// Exchange the refresh token for a new access token and updates it in the client.
169    /// Only some auth flows allow for token refreshing.
170    pub async fn exchange_refresh_token(&self) -> Result<()> {
171        let refresh_token = {
172            let lock = self.auth_state.read().unwrap_or_else(|e| e.into_inner());
173
174            let Some(refresh_token) = &lock.refresh_token else {
175                return Err(Error::RefreshUnavailable);
176            };
177
178            refresh_token.clone()
179        };
180
181        let token = self
182            .oauth
183            .exchange_refresh_token(&refresh_token)
184            .request_async(async_http_client)
185            .await?
186            .set_timestamps();
187
188        let mut lock = self
189            .auth_state
190            .write()
191            .expect("The lock holding the token has been poisoned.");
192        *lock = token;
193        Ok(())
194    }
195
196    pub(crate) async fn request<P: Serialize + Debug, T: DeserializeOwned>(
197        &self,
198        method: Method,
199        endpoint: String,
200        query: Option<P>,
201        body: Option<Body<P>>,
202    ) -> Result<T> {
203        let (token_expired, secret) = {
204            let lock = self
205                .auth_state
206                .read()
207                .expect("The lock holding the token has been poisoned.");
208
209            (lock.is_expired(), lock.access_token.secret().to_owned())
210        };
211
212        if token_expired {
213            if self.auto_refresh {
214                info!("The token has expired, attempting to refresh...");
215
216                self.exchange_refresh_token().await?;
217
218                let lock = self
219                    .auth_state
220                    .read()
221                    .expect("The lock holding the token has been poisoned.");
222
223                info!("The token has been successfully refreshed. The new token will expire in {} seconds", lock.expires_in);
224            } else {
225                info!("The token has expired, automatic refresh is disabled.");
226                return Err(Error::ExpiredToken);
227            }
228        }
229
230        let mut req = {
231            self.http
232                .request(method, format!("{API_URL}{endpoint}"))
233                .bearer_auth(secret)
234        };
235
236        if let Some(q) = query {
237            req = req.query(&q);
238        }
239
240        if let Some(b) = body {
241            match b {
242                Body::Json(j) => req = req.json(&j),
243                Body::File(f) => req = req.body(f),
244            }
245        } else {
246            // Used because Spotify wants a Content-Length header for the PUT /audiobooks/me endpoint even though there is no body
247            // If not supplied, it will return an error in the form of HTML (not JSON), which I believe to be an issue on their end.
248            // No other endpoints so far behave this way.
249            req = req.header(CONTENT_LENGTH, 0);
250        }
251
252        let req = req.build()?;
253        info!(headers = ?req.headers(), "{} request sent to {}", req.method(), req.url());
254
255        let res = self.http.execute(req).await?;
256
257        if res.status().is_success() {
258            let bytes = res.bytes().await?;
259
260            // Try to deserialize from bytes of JSON text;
261            let deserialized = serde_json::from_slice::<T>(&bytes).or_else(|e| {
262                // if the previous operation fails, try deserializing straight
263                // from the bytes, which works for Nil.
264                let de: BytesDeserializer<'_, serde::de::value::Error> =
265                    bytes.as_ref().into_deserializer();
266
267                // This line also converts the serde::de::value::Error to a serde_json::Error
268                // to make it clearer to the end user that deserialization failed.
269                T::deserialize(de).map_err(|_| e)
270            });
271            // .context(DeserializationSnafu { body });
272
273            match deserialized {
274                Ok(content) => Ok(content),
275                Err(err) => {
276                    let body = std::str::from_utf8(&bytes).map_err(|_| Error::InvalidResponse)?;
277
278                    tracing::error!(
279                        %body,
280                        "Failed to deserialize the response body into an object or Nil."
281                    );
282
283                    Err(Error::Deserialization {
284                        source: err,
285                        body: body.to_owned(),
286                    })
287                }
288            }
289        } else {
290            Err(res.json::<SpotifyError>().await?.into())
291        }
292    }
293
294    pub(crate) async fn get<P: Serialize + Debug, T: DeserializeOwned>(
295        &self,
296        endpoint: String,
297        query: impl Into<Option<P>>,
298    ) -> Result<T> {
299        self.request(Method::GET, endpoint, query.into(), None)
300            .await
301    }
302
303    pub(crate) async fn post<P: Serialize + Debug, T: DeserializeOwned>(
304        &self,
305        endpoint: String,
306        body: impl Into<Option<Body<P>>>,
307    ) -> Result<T> {
308        self.request(Method::POST, endpoint, None, body.into())
309            .await
310    }
311
312    pub(crate) async fn put<P: Serialize + Debug, T: DeserializeOwned>(
313        &self,
314        endpoint: String,
315        body: impl Into<Option<Body<P>>>,
316    ) -> Result<T> {
317        self.request(Method::PUT, endpoint, None, body.into()).await
318    }
319
320    pub(crate) async fn delete<P: Serialize + Debug, T: DeserializeOwned>(
321        &self,
322        endpoint: String,
323        body: impl Into<Option<Body<P>>>,
324    ) -> Result<T> {
325        self.request(Method::DELETE, endpoint, None, body.into())
326            .await
327    }
328}
329
330impl AuthCodeClient<Unauthenticated> {
331    /// Create a new client and generate an authorisation URL
332    ///
333    /// You must redirect the user to the returned URL, which in turn redirects them to
334    /// the `redirect_uri` you provided, along with a `code` and `state` parameter in the URl.
335    ///
336    /// They are required for the next step in the auth process.
337    pub fn new<S>(
338        client_id: impl Into<String>,
339        client_secret: impl Into<String>,
340        scopes: S,
341        redirect_uri: RedirectUrl,
342        auto_refresh: bool,
343    ) -> (Self, Url)
344    where
345        S: Into<Scopes>,
346    {
347        let client_id = ClientId::new(client_id.into());
348        let client_secret = Some(ClientSecret::new(client_secret.into()));
349
350        let oauth = OAuthClient::new(
351            client_id,
352            client_secret,
353            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
354            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
355        )
356        .set_redirect_uri(redirect_uri);
357
358        let (auth_url, csrf_token) = oauth
359            .authorize_url(CsrfToken::new_random)
360            .add_scopes(scopes.into().0)
361            .url();
362
363        (
364            Client {
365                auto_refresh,
366                auth_state: Arc::new(RwLock::new(Unauthenticated)),
367                auth_flow: AuthCodeFlow { csrf_token },
368                oauth,
369                http: reqwest::Client::new(),
370            },
371            auth_url,
372        )
373    }
374
375    /// This will exchange the `auth_code` for a token which will allow the client
376    /// to make requests.
377    ///
378    /// `csrf_state` is used for CSRF protection.
379    pub async fn authenticate(
380        self,
381        auth_code: impl Into<String>,
382        csrf_state: impl AsRef<str>,
383    ) -> Result<Client<Token, AuthCodeFlow>> {
384        let auth_code = auth_code.into().trim().to_owned();
385        let csrf_state = csrf_state.as_ref().trim();
386
387        if csrf_state != self.auth_flow.csrf_token.secret() {
388            return Err(Error::InvalidStateParameter);
389        }
390
391        let token = self
392            .oauth
393            .exchange_code(AuthorizationCode::new(auth_code))
394            .request_async(async_http_client)
395            .await?
396            .set_timestamps();
397
398        Ok(Client {
399            auto_refresh: self.auto_refresh,
400            auth_state: Arc::new(RwLock::new(token)),
401            auth_flow: self.auth_flow,
402            oauth: self.oauth,
403            http: self.http,
404        })
405    }
406}
407
408impl AuthCodePkceClient<Unauthenticated> {
409    /// Create a new client and generate an authorisation URL
410    ///
411    /// You must redirect the user to the received URL, which in turn redirects them to
412    /// the redirect URI you provided, along with a `code` and `state` parameter in the URl.
413    ///
414    /// They are required for the next step in the auth process.
415    pub fn new<T, S>(
416        client_id: T,
417        scopes: S,
418        redirect_uri: RedirectUrl,
419        auto_refresh: bool,
420    ) -> (Self, Url)
421    where
422        T: Into<String>,
423        S: Into<Scopes>,
424    {
425        let client_id = ClientId::new(client_id.into());
426
427        let oauth = OAuthClient::new(
428            client_id,
429            None,
430            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
431            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
432        )
433        .set_redirect_uri(redirect_uri);
434
435        let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
436
437        let (auth_url, csrf_token) = oauth
438            .authorize_url(CsrfToken::new_random)
439            .add_scopes(scopes.into().0)
440            .set_pkce_challenge(pkce_challenge)
441            .url();
442
443        (
444            Client {
445                auto_refresh,
446                auth_state: Arc::new(RwLock::new(Unauthenticated)),
447                auth_flow: AuthCodePkceFlow {
448                    csrf_token,
449                    pkce_verifier: Some(pkce_verifier),
450                },
451                oauth,
452                http: reqwest::Client::new(),
453            },
454            auth_url,
455        )
456    }
457
458    /// This will exchange the `auth_code` for a token which will allow the client
459    /// to make requests.
460    ///
461    /// `csrf_state` is used for CSRF protection.
462    pub async fn authenticate(
463        mut self,
464        auth_code: impl Into<String>,
465        csrf_state: impl AsRef<str>,
466    ) -> Result<Client<Token, AuthCodePkceFlow>> {
467        let auth_code = auth_code.into().trim().to_owned();
468        let csrf_state = csrf_state.as_ref().trim();
469
470        if csrf_state != self.auth_flow.csrf_token.secret() {
471            return Err(Error::InvalidStateParameter);
472        }
473
474        let Some(pkce_verifier) = self.auth_flow.pkce_verifier.take() else {
475            // This should never be reached realistically, but an error
476            // will be thrown and log issued just in case.
477            tracing::error!(client = ?self, "No PKCE code verifier present when authenticating the client.");
478            return Err(Error::InvalidClientState);
479        };
480
481        let token = self
482            .oauth
483            .exchange_code(AuthorizationCode::new(auth_code))
484            .set_pkce_verifier(pkce_verifier)
485            .request_async(async_http_client)
486            .await?
487            .set_timestamps();
488
489        Ok(Client {
490            auto_refresh: self.auto_refresh,
491            auth_state: Arc::new(RwLock::new(token)),
492            auth_flow: self.auth_flow,
493            oauth: self.oauth,
494            http: self.http,
495        })
496    }
497}
498
499impl ClientCredsClient<Unauthenticated> {
500    /// This will exchange the client credentials for an access token used
501    /// to make requests.
502    ///
503    /// This authentication method doesn't allow for token refreshing or to access
504    /// user resources.
505    pub async fn authenticate(
506        client_id: impl Into<String>,
507        client_secret: impl Into<String>,
508    ) -> Result<ClientCredsClient<Token>> {
509        let client_id = ClientId::new(client_id.into());
510        let client_secret = Some(ClientSecret::new(client_secret.into()));
511
512        let oauth = OAuthClient::new(
513            client_id,
514            client_secret,
515            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
516            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
517        );
518
519        let token = oauth
520            .exchange_client_credentials()
521            .request_async(async_http_client)
522            .await?
523            .set_timestamps();
524
525        Ok(Client {
526            auto_refresh: false,
527            auth_state: Arc::new(RwLock::new(token)),
528            auth_flow: ClientCredsFlow,
529            oauth,
530            http: reqwest::Client::new(),
531        })
532    }
533}
534
535impl AuthCodeClient<Token> {
536    /// Create a new authenticated client from an access token.
537    /// This client will be able to access user data.
538    ///
539    /// This method will fail if the access token is invalid (a request will
540    /// be sent to check the token).
541    pub async fn from_access_token(
542        client_id: impl Into<String>,
543        client_secret: impl Into<String>,
544        auto_refresh: bool,
545        token: Token,
546    ) -> Result<Self> {
547        let client_id = ClientId::new(client_id.into());
548        // client_secret.map(|s| ClientSecret::new(s.to_owned()));
549        let client_secret = Some(ClientSecret::new(client_secret.into()));
550
551        let oauth_client = OAuthClient::new(
552            client_id,
553            client_secret,
554            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
555            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
556        );
557
558        let http = reqwest::Client::new();
559
560        // This is just a bogus request to check if the token is valid.
561        let res = http
562            .get(format!("{API_URL}/markets"))
563            .bearer_auth(token.secret())
564            .header(CONTENT_LENGTH, 0)
565            .send()
566            .await?;
567
568        if !res.status().is_success() {
569            return Err(res.json::<SpotifyError>().await?.into());
570        }
571
572        let auth_flow = AuthCodeFlow {
573            csrf_token: CsrfToken::new("not needed".to_owned()),
574        };
575
576        let auto_refresh = auto_refresh && token.refresh_token.is_some();
577
578        Ok(Self {
579            auto_refresh,
580            auth_state: Arc::new(RwLock::new(token)),
581            auth_flow,
582            oauth: oauth_client,
583            http,
584        })
585    }
586}
587
588impl AuthCodePkceClient<Token> {
589    /// Create a new authenticated client from an access token.
590    /// This client will be able to access user data.
591    ///
592    /// This method will fail if the access token is invalid (a request will
593    /// be sent to check the token).
594    pub async fn from_access_token(
595        client_id: impl Into<String>,
596        auto_refresh: bool,
597        token: Token,
598    ) -> Result<Self> {
599        let client_id = ClientId::new(client_id.into());
600
601        let oauth_client = OAuthClient::new(
602            client_id,
603            None,
604            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
605            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
606        );
607
608        let http = reqwest::Client::new();
609
610        // This is just a bogus request to check if the token is valid.
611        let res = http
612            .get(format!("{API_URL}/recommendations/available-genre-seeds"))
613            .bearer_auth(token.secret())
614            .header(CONTENT_LENGTH, 0)
615            .send()
616            .await?;
617
618        if !res.status().is_success() {
619            return Err(res.json::<SpotifyError>().await?.into());
620        }
621
622        let auth_flow = AuthCodePkceFlow {
623            csrf_token: CsrfToken::new("not needed".to_owned()),
624            pkce_verifier: None,
625        };
626
627        let auto_refresh = auto_refresh && token.refresh_token.is_some();
628
629        Ok(Self {
630            auto_refresh,
631            auth_state: Arc::new(RwLock::new(token)),
632            auth_flow,
633            oauth: oauth_client,
634            http,
635        })
636    }
637}
638
639impl ClientCredsClient<Token> {
640    /// Create a new authenticated client from an access token.
641    /// This client will not be able to access user data.
642    ///
643    /// This method will fail if the access token is invalid (a request will
644    /// be sent to check the token).
645    pub async fn from_access_token(
646        client_id: impl Into<String>,
647        client_secret: impl Into<String>,
648        token: Token,
649    ) -> Result<Self> {
650        let client_id = ClientId::new(client_id.into());
651        let client_secret = Some(ClientSecret::new(client_secret.into()));
652
653        let oauth_client = OAuthClient::new(
654            client_id,
655            client_secret,
656            AuthUrl::new(AUTHORISATION_URL.to_owned()).unwrap(),
657            Some(TokenUrl::new(TOKEN_URL.to_owned()).unwrap()),
658        );
659
660        let http = reqwest::Client::new();
661
662        // This is just a bogus request to check if the token is valid.
663        let res = http
664            .get(format!("{API_URL}/recommendations/available-genre-seeds"))
665            .bearer_auth(token.secret())
666            .header(CONTENT_LENGTH, 0)
667            .send()
668            .await?;
669
670        if !res.status().is_success() {
671            return Err(res.json::<SpotifyError>().await?.into());
672        }
673
674        Ok(Self {
675            auto_refresh: false,
676            auth_state: Arc::new(RwLock::new(token)),
677            auth_flow: ClientCredsFlow,
678            oauth: oauth_client,
679            http,
680        })
681    }
682}