oauth2/
lib.rs

1//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/async--oauth2-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/async-oauth2)
2//! [<img alt="crates.io" src="https://img.shields.io/crates/v/async-oauth2.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/async-oauth2)
3//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-async--oauth2-66c2a5?style=for-the-badge&logoColor=white&logo=" height="20">](https://docs.rs/async-oauth2)
4//!
5//! An asynchronous OAuth2 flow implementation, trying to adhere as much as
6//! possible to [RFC 6749].
7//!
8//! <br>
9//!
10//! ## Examples
11//!
12//! To see the library in action, you can go to one of our examples:
13//!
14//! - [Google]
15//! - [Spotify]
16//! - [Twitch]
17//!
18//! If you've checked out the project they can be run like this:
19//!
20//! ```sh
21//! cargo run --manifest-path=examples/Cargo.toml --bin spotify --
22//!     --client-id <client-id> --client-secret <client-secret>
23//! cargo run --manifest-path=examples/Cargo.toml --bin google --
24//!     --client-id <client-id> --client-secret <client-secret>
25//! cargo run --manifest-path=examples/Cargo.toml --bin twitch --
26//!     --client-id <client-id> --client-secret <client-secret>
27//! ```
28//!
29//! > Note: You need to configure your client integration to permit redirects to
30//! > `http://localhost:8080/api/auth/redirect` for these to work. How this is
31//! > done depends on the integration used.
32//!
33//! <br>
34//!
35//! ## Authorization Code Grant
36//!
37//! This is the most common OAuth2 flow.
38//!
39//! ```no_run
40//! use oauth2::*;
41//! use url::Url;
42//!
43//! pub struct ReceivedCode {
44//!     pub code: AuthorizationCode,
45//!     pub state: State,
46//! }
47//!
48//! # async fn listen_for_code(port: u32) -> Result<ReceivedCode, Box<dyn std::error::Error>> { todo!() }
49//! # #[tokio::main]
50//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
51//! let reqwest_client = reqwest::Client::new();
52//!
53//! // Create an OAuth2 client by specifying the client ID, client secret,
54//! // authorization URL and token URL.
55//! let mut client = Client::new(
56//!     "client_id",
57//!     Url::parse("http://authorize")?,
58//!     Url::parse("http://token")?
59//! );
60//!
61//! client.set_client_secret("client_secret");
62//! // Set the URL the user will be redirected to after the authorization
63//! // process.
64//! client.set_redirect_url(Url::parse("http://redirect")?);
65//! // Set the desired scopes.
66//! client.add_scope("read");
67//! client.add_scope("write");
68//!
69//! // Generate the full authorization URL.
70//! let state = State::new_random();
71//! let auth_url = client.authorize_url(&state);
72//!
73//! // This is the URL you should redirect the user to, in order to trigger the
74//! // authorization process.
75//! println!("Browse to: {}", auth_url);
76//!
77//! // Once the user has been redirected to the redirect URL, you'll have the
78//! // access code. For security reasons, your code should verify that the
79//! // `state` parameter returned by the server matches `state`.
80//! let received: ReceivedCode = listen_for_code(8080).await?;
81//!
82//! if received.state != state {
83//!    panic!("CSRF token mismatch :(");
84//! }
85//!
86//! // Now you can trade it for an access token.
87//! let token = client.exchange_code(received.code)
88//!     .with_client(&reqwest_client)
89//!     .execute::<StandardToken>()
90//!     .await?;
91//!
92//! # Ok(())
93//! # }
94//! ```
95//!
96//! <br>
97//!
98//! ## Implicit Grant
99//!
100//! This flow fetches an access token directly from the authorization endpoint.
101//!
102//! Be sure to understand the security implications of this flow before using
103//! it. In most cases the Authorization Code Grant flow above is preferred to
104//! the Implicit Grant flow.
105//!
106//! ```no_run
107//! use oauth2::*;
108//! use url::Url;
109//!
110//! pub struct ReceivedCode {
111//!     pub code: AuthorizationCode,
112//!     pub state: State,
113//! }
114//!
115//! # async fn get_code() -> Result<ReceivedCode, Box<dyn std::error::Error>> { todo!() }
116//! # #[tokio::main]
117//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
118//! let mut client = Client::new(
119//!     "client_id",
120//!     Url::parse("http://authorize")?,
121//!     Url::parse("http://token")?
122//! );
123//!
124//! client.set_client_secret("client_secret");
125//!
126//! // Generate the full authorization URL.
127//! let state = State::new_random();
128//! let auth_url = client.authorize_url_implicit(&state);
129//!
130//! // This is the URL you should redirect the user to, in order to trigger the
131//! // authorization process.
132//! println!("Browse to: {}", auth_url);
133//!
134//! // Once the user has been redirected to the redirect URL, you'll have the
135//! // access code. For security reasons, your code should verify that the
136//! // `state` parameter returned by the server matches `state`.
137//! let received: ReceivedCode = get_code().await?;
138//!
139//! if received.state != state {
140//!     panic!("CSRF token mismatch :(");
141//! }
142//!
143//! # Ok(()) }
144//! ```
145//!
146//! <br>
147//!
148//! ## Resource Owner Password Credentials Grant
149//!
150//! You can ask for a *password* access token by calling the
151//! `Client::exchange_password` method, while including the username and
152//! password.
153//!
154//! ```no_run
155//! use oauth2::*;
156//! use url::Url;
157//!
158//! # #[tokio::main]
159//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
160//! let reqwest_client = reqwest::Client::new();
161//!
162//! let mut client = Client::new(
163//!     "client_id",
164//!     Url::parse("http://authorize")?,
165//!     Url::parse("http://token")?
166//! );
167//!
168//! client.set_client_secret("client_secret");
169//! client.add_scope("read");
170//!
171//! let token = client
172//!     .exchange_password("user", "pass")
173//!     .with_client(&reqwest_client)
174//!     .execute::<StandardToken>()
175//!     .await?;
176//!
177//! # Ok(()) }
178//! ```
179//!
180//! <br>
181//!
182//! ## Client Credentials Grant
183//!
184//! You can ask for a *client credentials* access token by calling the
185//! `Client::exchange_client_credentials` method.
186//!
187//! ```no_run
188//! use oauth2::*;
189//! use url::Url;
190//!
191//! # #[tokio::main]
192//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
193//! let reqwest_client = reqwest::Client::new();
194//! let mut client = Client::new(
195//!     "client_id",
196//!     Url::parse("http://authorize")?,
197//!     Url::parse("http://token")?
198//! );
199//!
200//! client.set_client_secret("client_secret");
201//! client.add_scope("read");
202//!
//! let token = client.exchange_client_credentials()
//!     .with_client(&reqwest_client)
//!     .execute::<StandardToken>()
//!     .await?;
206//!
207//! # Ok(()) }
208//! ```
209//!
210//! <br>
211//!
212//! ## Relationship to oauth2-rs
213//!
214//! This is a fork of [oauth2-rs].
215//!
216//! The main differences are:
217//! * Removal of unnecessary type parameters on Client ([see discussion here]).
218//! * Only support one client implementation ([reqwest]).
219//! * Remove most newtypes except `Scope` and the secret ones since they made the API harder to use.
220//!
221//! [RFC 6749]: https://tools.ietf.org/html/rfc6749
222//! [Google]: https://github.com/udoprog/async-oauth2/blob/master/examples/src/bin/google.rs
223//! [oauth2-rs]: https://github.com/ramosbugs/oauth2-rs
224//! [reqwest]: https://docs.rs/reqwest
225//! [see discussion here]: https://github.com/ramosbugs/oauth2-rs/issues/44#issuecomment-50158653
226//! [Spotify]: https://github.com/udoprog/async-oauth2/blob/master/examples/src/bin/spotify.rs
227//! [Twitch]: https://github.com/udoprog/async-oauth2/blob/master/examples/src/bin/twitch.rs
228
229#![deny(missing_docs)]
230
231use std::{borrow::Cow, error, fmt, time::Duration};
232
233use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
234use rand::{thread_rng, Rng};
235use serde::{Deserialize, Serialize};
236use serde_aux::prelude::*;
237use sha2::{Digest, Sha256};
238use thiserror::Error;
239pub use url::Url;
240
/// Selects how client credentials are transmitted to the authorization server for requests
/// in which either HTTP Basic authentication or request-body parameters are valid.
///
/// A freshly constructed `Client` uses *BasicAuth*, following the recommendation of
/// [Section 2.3.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-2.3.1).
#[derive(Debug, Clone, Copy)]
pub enum AuthType {
    /// Send `client_id` and `client_secret` as parameters in the request body.
    RequestBody,
    /// Send `client_id` and `client_secret` using the HTTP Basic authentication scheme.
    BasicAuth,
}
253
// Implements `fmt::Debug` for `$name` so that formatting never reveals the wrapped value:
// the output is always the type name followed by `([redacted])`.
macro_rules! redacted_debug {
    ($name:ident) => {
        impl fmt::Debug for $name {
            fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
                // The whole message is assembled at compile time, so it can be written as-is.
                out.write_str(concat!(stringify!($name), "([redacted])"))
            }
        }
    };
}
263
/// Borrowing plumbing for a tuple newtype: `Deref`, `AsRef` and a borrowed `Cow` conversion,
/// all exposing the wrapped field as `$borrowed`.
macro_rules! borrowed_newtype {
    ($name:ident, $borrowed:ty) => {
        impl std::ops::Deref for $name {
            type Target = $borrowed;

            #[inline]
            fn deref(&self) -> &$borrowed {
                &self.0
            }
        }

        impl AsRef<$borrowed> for $name {
            #[inline]
            fn as_ref(&self) -> &$borrowed {
                &self.0
            }
        }

        impl<'a> From<&'a $name> for Cow<'a, $borrowed> {
            #[inline]
            fn from(wrapper: &'a $name) -> Cow<'a, $borrowed> {
                // Borrow the inner value for the full lifetime of the wrapper reference.
                let inner: &'a $borrowed = &wrapper.0;
                Cow::Borrowed(inner)
            }
        }
    };
}
291
/// Full newtype plumbing: everything `borrowed_newtype!` provides, plus conversions from
/// `$owned`, `&$owned` and `&$borrowed` into `$name`, and from `$name` back into `$owned`.
macro_rules! newtype {
    ($name:ident, $owned:ty, $borrowed:ty) => {
        borrowed_newtype!($name, $borrowed);

        impl From<$owned> for $name {
            #[inline]
            fn from(owned: $owned) -> Self {
                Self(owned)
            }
        }

        impl<'a> From<&'a $owned> for $name {
            #[inline]
            fn from(owned: &'a $owned) -> Self {
                Self(owned.to_owned())
            }
        }

        impl<'a> From<&'a $borrowed> for $name {
            #[inline]
            fn from(borrowed: &'a $borrowed) -> Self {
                Self(borrowed.to_owned())
            }
        }

        impl From<$name> for $owned {
            #[inline]
            fn from(wrapper: $name) -> $owned {
                wrapper.0
            }
        }
    };
}
326
327/// Access token scope, as defined by the authorization server.
328#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
329pub struct Scope(String);
330newtype!(Scope, String, str);
331
332/// Code Challenge used for [PKCE]((https://tools.ietf.org/html/rfc7636)) protection via the
333/// `code_challenge` parameter.
334#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
335pub struct PkceCodeChallengeS256(String);
336newtype!(PkceCodeChallengeS256, String, str);
337
338/// Code Challenge Method used for [PKCE]((https://tools.ietf.org/html/rfc7636)) protection
339/// via the `code_challenge_method` parameter.
340#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
341pub struct PkceCodeChallengeMethod(String);
342newtype!(PkceCodeChallengeMethod, String, str);
343
344/// Client password issued to the client during the registration process described by
345/// [Section 2.2](https://tools.ietf.org/html/rfc6749#section-2.2).
346#[derive(Clone, Deserialize, Serialize)]
347pub struct ClientSecret(String);
348redacted_debug!(ClientSecret);
349newtype!(ClientSecret, String, str);
350
351/// Value used for [CSRF]((https://tools.ietf.org/html/rfc6749#section-10.12)) protection
352/// via the `state` parameter.
353#[must_use]
354#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
355pub struct State([u8; 16]);
356redacted_debug!(State);
357borrowed_newtype!(State, [u8]);
358
359impl State {
360    /// Generate a new random, base64-encoded 128-bit CSRF token.
361    pub fn new_random() -> Self {
362        let mut random_bytes = [0u8; 16];
363        thread_rng().fill(&mut random_bytes);
364        State(random_bytes)
365    }
366
367    /// Convert into base64.
368    pub fn to_base64(&self) -> String {
369        BASE64_URL_SAFE_NO_PAD.encode(self.0)
370    }
371}
372
373impl serde::Serialize for State {
374    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
375    where
376        S: serde::Serializer,
377    {
378        self.to_base64().serialize(serializer)
379    }
380}
381
382impl<'de> serde::Deserialize<'de> for State {
383    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
384    where
385        D: serde::Deserializer<'de>,
386    {
387        let s = String::deserialize(deserializer)?;
388        let bytes = BASE64_URL_SAFE_NO_PAD
389            .decode(s)
390            .map_err(serde::de::Error::custom)?;
391        let mut buf = [0u8; 16];
392        buf.copy_from_slice(&bytes);
393        Ok(Self(buf))
394    }
395}
396
397/// Code Verifier used for [PKCE]((https://tools.ietf.org/html/rfc7636)) protection via the
398/// `code_verifier` parameter. The value must have a minimum length of 43 characters and a
399/// maximum length of 128 characters.  Each character must be ASCII alphanumeric or one of
400/// the characters "-" / "." / "_" / "~".
401#[derive(Deserialize, Serialize)]
402pub struct PkceCodeVerifierS256(String);
403newtype!(PkceCodeVerifierS256, String, str);
404
405impl PkceCodeVerifierS256 {
406    /// Generate a new random, base64-encoded code verifier.
407    pub fn new_random() -> Self {
408        PkceCodeVerifierS256::new_random_len(32)
409    }
410
411    /// Generate a new random, base64-encoded code verifier.
412    ///
413    /// # Arguments
414    ///
415    /// * `num_bytes` - Number of random bytes to generate, prior to base64-encoding.
416    ///   The value must be in the range 32 to 96 inclusive in order to generate a verifier
417    ///   with a suitable length.
418    pub fn new_random_len(num_bytes: u32) -> Self {
419        // The RFC specifies that the code verifier must have "a minimum length of 43
420        // characters and a maximum length of 128 characters".
421        // This implies 32-96 octets of random data to be base64 encoded.
422        assert!((32..=96).contains(&num_bytes));
423        let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().gen::<u8>()).collect();
424        let code = BASE64_URL_SAFE_NO_PAD.encode(random_bytes);
425        assert!(code.len() >= 43 && code.len() <= 128);
426        PkceCodeVerifierS256(code)
427    }
428
429    /// Return the code challenge for the code verifier.
430    pub fn code_challenge(&self) -> PkceCodeChallengeS256 {
431        let digest = Sha256::digest(self.as_bytes());
432        PkceCodeChallengeS256::from(BASE64_URL_SAFE_NO_PAD.encode(digest))
433    }
434
435    /// Return the code challenge method for this code verifier.
436    pub fn code_challenge_method() -> PkceCodeChallengeMethod {
437        PkceCodeChallengeMethod::from("S256".to_string())
438    }
439
440    /// Return the extension params used for authorize_url.
441    pub fn authorize_url_params(&self) -> Vec<(&'static str, String)> {
442        vec![
443            (
444                "code_challenge_method",
445                PkceCodeVerifierS256::code_challenge_method().into(),
446            ),
447            ("code_challenge", self.code_challenge().into()),
448        ]
449    }
450}
451
452/// Authorization code returned from the authorization endpoint.
453#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
454pub struct AuthorizationCode(String);
455redacted_debug!(AuthorizationCode);
456newtype!(AuthorizationCode, String, str);
457
458/// Refresh token used to obtain a new access token (if supported by the authorization server).
459#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
460pub struct RefreshToken(String);
461redacted_debug!(RefreshToken);
462newtype!(RefreshToken, String, str);
463
464/// Access token returned by the token endpoint and used to access protected resources.
465#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
466pub struct AccessToken(String);
467redacted_debug!(AccessToken);
468newtype!(AccessToken, String, str);
469
470/// Resource owner's password used directly as an authorization grant to obtain an access
471/// token.
472pub struct ResourceOwnerPassword(String);
473newtype!(ResourceOwnerPassword, String, str);
474
475/// Stores the configuration for an OAuth2 client.
476#[derive(Clone, Debug)]
477pub struct Client {
478    client_id: String,
479    client_secret: Option<ClientSecret>,
480    auth_url: Url,
481    auth_type: AuthType,
482    token_url: Url,
483    scopes: Vec<Scope>,
484    redirect_url: Option<Url>,
485}
486
487impl Client {
488    /// Initializes an OAuth2 client with the fields common to most OAuth2 flows.
489    ///
490    /// # Arguments
491    ///
492    /// * `client_id` -  Client ID
493    /// * `auth_url` -  Authorization endpoint: used by the client to obtain authorization from
494    ///   the resource owner via user-agent redirection. This URL is used in all standard OAuth2
495    ///   flows except the [Resource Owner Password Credentials
496    ///   Grant](https://tools.ietf.org/html/rfc6749#section-4.3) and the
497    ///   [Client Credentials Grant](https://tools.ietf.org/html/rfc6749#section-4.4).
498    /// * `token_url` - Token endpoint: used by the client to exchange an authorization grant
499    ///   (code) for an access token, typically with client authentication. This URL is used in
500    ///   all standard OAuth2 flows except the
501    ///   [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2). If this value is set
502    ///   to `None`, the `exchange_*` methods will return `Err(ExecuteError::Other(_))`.
503    pub fn new(client_id: impl AsRef<str>, auth_url: Url, token_url: Url) -> Self {
504        Client {
505            client_id: client_id.as_ref().to_string(),
506            client_secret: None,
507            auth_url,
508            auth_type: AuthType::BasicAuth,
509            token_url,
510            scopes: Vec::new(),
511            redirect_url: None,
512        }
513    }
514
515    /// Configure the client secret to use.
516    pub fn set_client_secret(&mut self, client_secret: impl Into<ClientSecret>) {
517        self.client_secret = Some(client_secret.into());
518    }
519
520    /// Appends a new scope to the authorization URL.
521    pub fn add_scope(&mut self, scope: impl Into<Scope>) {
522        self.scopes.push(scope.into());
523    }
524
525    /// Configures the type of client authentication used for communicating with the authorization
526    /// server.
527    ///
528    /// The default is to use HTTP Basic authentication, as recommended in
529    /// [Section 2.3.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-2.3.1).
530    pub fn set_auth_type(&mut self, auth_type: AuthType) {
531        self.auth_type = auth_type;
532    }
533
534    /// Sets the the redirect URL used by the authorization endpoint.
535    pub fn set_redirect_url(&mut self, redirect_url: Url) {
536        self.redirect_url = Some(redirect_url);
537    }
538
539    /// Produces the full authorization URL used by the
540    /// [Authorization Code Grant](https://tools.ietf.org/html/rfc6749#section-4.1)
541    /// flow, which is the most common OAuth2 flow.
542    ///
543    /// # Arguments
544    ///
545    /// * `state` - A state value to include in the request. The authorization
546    ///   server includes this value when redirecting the user-agent back to the
547    ///   client.
548    ///
549    /// # Security Warning
550    ///
551    /// Callers should use a fresh, unpredictable `state` for each authorization
552    /// request and verify that this value matches the `state` parameter passed
553    /// by the authorization server to the redirect URI. Doing so mitigates
554    /// [Cross-Site Request Forgery](https://tools.ietf.org/html/rfc6749#section-10.12)
555    /// attacks.
556    pub fn authorize_url(&self, state: &State) -> Url {
557        self.authorize_url_impl("code", state)
558    }
559
560    /// Produces the full authorization URL used by the
561    /// [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2) flow.
562    ///
563    /// # Arguments
564    ///
565    /// * `state` - A state value to include in the request. The authorization
566    ///   server includes this value when redirecting the user-agent back to the
567    ///   client.
568    ///
569    /// # Security Warning
570    ///
571    /// Callers should use a fresh, unpredictable `state` for each authorization request and verify
572    /// that this value matches the `state` parameter passed by the authorization server to the
573    /// redirect URI. Doing so mitigates
574    /// [Cross-Site Request Forgery](https://tools.ietf.org/html/rfc6749#section-10.12)
575    ///  attacks.
576    pub fn authorize_url_implicit(&self, state: &State) -> Url {
577        self.authorize_url_impl("token", state)
578    }
579
580    fn authorize_url_impl(&self, response_type: &str, state: &State) -> Url {
581        let scopes = self
582            .scopes
583            .iter()
584            .map(|s| s.to_string())
585            .collect::<Vec<_>>()
586            .join(" ");
587
588        let mut url = self.auth_url.clone();
589
590        {
591            let mut query = url.query_pairs_mut();
592
593            query.append_pair("response_type", response_type);
594            query.append_pair("client_id", &self.client_id);
595
596            if let Some(ref redirect_url) = self.redirect_url {
597                query.append_pair("redirect_uri", redirect_url.as_str());
598            }
599
600            if !scopes.is_empty() {
601                query.append_pair("scope", &scopes);
602            }
603
604            query.append_pair("state", &state.to_base64());
605        }
606
607        url
608    }
609
610    /// Exchanges a code produced by a successful authorization process with an access token.
611    ///
612    /// Acquires ownership of the `code` because authorization codes may only be used to retrieve
613    /// an access token from the authorization server.
614    ///
615    /// See https://tools.ietf.org/html/rfc6749#section-4.1.3
616    pub fn exchange_code(&self, code: impl Into<AuthorizationCode>) -> Request<'_> {
617        let code = code.into();
618
619        self.request_token()
620            .param("grant_type", "authorization_code")
621            .param("code", code.to_string())
622    }
623
624    /// Requests an access token for the *password* grant type.
625    ///
626    /// See https://tools.ietf.org/html/rfc6749#section-4.3.2
627    pub fn exchange_password(
628        &self,
629        username: impl AsRef<str>,
630        password: impl AsRef<str>,
631    ) -> Request<'_> {
632        let username = username.as_ref();
633        let password = password.as_ref();
634
635        let mut builder = self
636            .request_token()
637            .param("grant_type", "password")
638            .param("username", username.to_string())
639            .param("password", password.to_string());
640
641        // Generate the space-delimited scopes String before initializing params so that it has
642        // a long enough lifetime.
643        if !self.scopes.is_empty() {
644            let scopes = self
645                .scopes
646                .iter()
647                .map(|s| s.to_string())
648                .collect::<Vec<_>>()
649                .join(" ");
650
651            builder = builder.param("scope", scopes);
652        }
653
654        builder
655    }
656
657    /// Requests an access token for the *client credentials* grant type.
658    ///
659    /// See https://tools.ietf.org/html/rfc6749#section-4.4.2
660    pub fn exchange_client_credentials(&self) -> Request<'_> {
661        let mut builder = self
662            .request_token()
663            .param("grant_type", "client_credentials");
664
665        // Generate the space-delimited scopes String before initializing params so that it has
666        // a long enough lifetime.
667        if !self.scopes.is_empty() {
668            let scopes = self
669                .scopes
670                .iter()
671                .map(|s| s.to_string())
672                .collect::<Vec<_>>()
673                .join(" ");
674
675            builder = builder.param("scopes", scopes);
676        }
677
678        builder
679    }
680
681    /// Exchanges a refresh token for an access token
682    ///
683    /// See https://tools.ietf.org/html/rfc6749#section-6
684    pub fn exchange_refresh_token(&self, refresh_token: &RefreshToken) -> Request<'_> {
685        self.request_token()
686            .param("grant_type", "refresh_token")
687            .param("refresh_token", refresh_token.to_string())
688    }
689
690    /// Construct a request builder for the token URL.
691    fn request_token(&self) -> Request<'_> {
692        Request {
693            token_url: &self.token_url,
694            auth_type: self.auth_type,
695            client_id: &self.client_id,
696            client_secret: self.client_secret.as_ref(),
697            redirect_url: self.redirect_url.as_ref(),
698            params: vec![],
699        }
700    }
701}
702
703/// A request wrapped in a client, ready to be executed.
704pub struct ClientRequest<'a> {
705    request: Request<'a>,
706    client: &'a reqwest::Client,
707}
708
709impl<'a> ClientRequest<'a> {
710    /// Execute the token request.
711    pub async fn execute<T>(self) -> Result<T, ExecuteError>
712    where
713        T: for<'de> Deserialize<'de>,
714    {
715        use reqwest::{header, Method};
716
717        let mut request = self
718            .client
719            .request(Method::POST, self.request.token_url.clone());
720
721        // Section 5.1 of RFC 6749 (https://tools.ietf.org/html/rfc6749#section-5.1) only permits
722        // JSON responses for this request. Some providers such as GitHub have off-spec behavior
723        // and not only support different response formats, but have non-JSON defaults. Explicitly
724        // request JSON here.
725        request = request.header(
726            header::ACCEPT,
727            header::HeaderValue::from_static(CONTENT_TYPE_JSON),
728        );
729
730        let request = {
731            let mut form = url::form_urlencoded::Serializer::new(String::new());
732
733            // FIXME: add support for auth extensions? e.g., client_secret_jwt and private_key_jwt
734            match self.request.auth_type {
735                AuthType::RequestBody => {
736                    form.append_pair("client_id", self.request.client_id);
737
738                    if let Some(client_secret) = self.request.client_secret {
739                        form.append_pair("client_secret", client_secret);
740                    }
741                }
742                AuthType::BasicAuth => {
743                    // Section 2.3.1 of RFC 6749 requires separately url-encoding the id and secret
744                    // before using them as HTTP Basic auth username and password. Note that this is
745                    // not standard for ordinary Basic auth, so curl won't do it for us.
746                    let username = url_encode(self.request.client_id);
747
748                    let password = self
749                        .request
750                        .client_secret
751                        .map(|client_secret| url_encode(client_secret));
752
753                    request = request.basic_auth(username, password.as_ref());
754                }
755            }
756
757            for (key, value) in self.request.params {
758                form.append_pair(key.as_ref(), value.as_ref());
759            }
760
761            if let Some(redirect_url) = &self.request.redirect_url {
762                form.append_pair("redirect_uri", redirect_url.as_str());
763            }
764
765            request = request.header(
766                header::CONTENT_TYPE,
767                header::HeaderValue::from_static("application/x-www-form-urlencoded"),
768            );
769
770            request.body(form.finish().into_bytes())
771        };
772
773        let res = request
774            .send()
775            .await
776            .map_err(|error| ExecuteError::RequestError { error })?;
777
778        let status = res.status();
779
780        let body = res
781            .bytes()
782            .await
783            .map_err(|error| ExecuteError::RequestError { error })?;
784
785        if body.is_empty() {
786            return Err(ExecuteError::EmptyResponse { status });
787        }
788
789        if !status.is_success() {
790            let error = match serde_json::from_slice::<ErrorResponse>(body.as_ref()) {
791                Ok(error) => error,
792                Err(error) => {
793                    return Err(ExecuteError::BadResponse {
794                        status,
795                        error,
796                        body,
797                    });
798                }
799            };
800
801            return Err(ExecuteError::ErrorResponse { status, error });
802        }
803
804        return serde_json::from_slice(body.as_ref()).map_err(|error| ExecuteError::BadResponse {
805            status,
806            error,
807            body,
808        });
809
810        fn url_encode(s: &str) -> String {
811            url::form_urlencoded::byte_serialize(s.as_bytes()).collect::<String>()
812        }
813
814        const CONTENT_TYPE_JSON: &str = "application/json";
815    }
816}
817
818/// A token request that is in progress.
819pub struct Request<'a> {
820    token_url: &'a Url,
821    auth_type: AuthType,
822    client_id: &'a str,
823    client_secret: Option<&'a ClientSecret>,
824    /// Configured redirect URL.
825    redirect_url: Option<&'a Url>,
826    /// Extra parameters.
827    params: Vec<(Cow<'a, str>, Cow<'a, str>)>,
828}
829
830impl<'a> Request<'a> {
831    /// Set an additional request param.
832    pub fn param(mut self, key: impl Into<Cow<'a, str>>, value: impl Into<Cow<'a, str>>) -> Self {
833        self.params.push((key.into(), value.into()));
834        self
835    }
836
837    /// Wrap the request in a client.
838    pub fn with_client(self, client: &'a reqwest::Client) -> ClientRequest<'a> {
839        ClientRequest {
840            client,
841            request: self,
842        }
843    }
844}
845
846/// Basic OAuth2 authorization token types.
847#[derive(Clone, Debug, PartialEq, Serialize)]
848#[serde(rename_all = "lowercase")]
849pub enum TokenType {
850    /// Bearer token
851    /// ([OAuth 2.0 Bearer Tokens - RFC 6750](https://tools.ietf.org/html/rfc6750)).
852    Bearer,
853    /// MAC ([OAuth 2.0 Message Authentication Code (MAC)
854    /// Tokens](https://tools.ietf.org/html/draft-ietf-oauth-v2-http-mac-05)).
855    Mac,
856}
857
858impl<'de> serde::de::Deserialize<'de> for TokenType {
859    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
860    where
861        D: serde::de::Deserializer<'de>,
862    {
863        let value = String::deserialize(deserializer)?.to_lowercase();
864
865        return match value.as_str() {
866            "bearer" => Ok(TokenType::Bearer),
867            "mac" => Ok(TokenType::Mac),
868            other => Err(serde::de::Error::custom(UnknownVariantError(
869                other.to_string(),
870            ))),
871        };
872
873        #[derive(Debug)]
874        struct UnknownVariantError(String);
875
876        impl error::Error for UnknownVariantError {}
877
878        impl fmt::Display for UnknownVariantError {
879            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
880                write!(fmt, "unsupported variant: {}", self.0)
881            }
882        }
883    }
884}
885
/// Common methods shared by all OAuth2 token implementations.
///
/// The methods in this trait are defined in
/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1). This trait exists
/// separately from the `StandardToken` struct to support customization by clients,
/// such as supporting interoperability with non-standards-compliant OAuth2 providers.
pub trait Token
where
    Self: for<'a> serde::de::Deserialize<'a>,
{
    /// REQUIRED. The access token issued by the authorization server.
    fn access_token(&self) -> &AccessToken;

    /// REQUIRED. The type of the token issued as described in
    /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1).
    /// Value is case insensitive and deserialized to a [`TokenType`].
    fn token_type(&self) -> &TokenType;

    /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600
    /// denotes that the access token will expire in one hour from the time the response was
    /// generated. If omitted, the authorization server SHOULD provide the expiration time via
    /// other means or document the default value.
    fn expires_in(&self) -> Option<Duration>;

    /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same
    /// authorization grant as described in
    /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6).
    fn refresh_token(&self) -> Option<&RefreshToken>;

    /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The
    /// scope of the access token as described by
    /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response,
    /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from
    /// the response, this field is `None`.
    fn scopes(&self) -> Option<&Vec<Scope>>;
}
922
923/// Standard OAuth2 token response.
924///
925/// This struct includes the fields defined in
926/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1), as well as
927/// extensions defined by the `EF` type parameter.
928#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
929pub struct StandardToken {
930    access_token: AccessToken,
931    token_type: TokenType,
932    #[serde(
933        skip_serializing_if = "Option::is_none",
934        deserialize_with = "deserialize_option_number_from_string"
935    )]
936    expires_in: Option<u64>,
937    #[serde(skip_serializing_if = "Option::is_none")]
938    refresh_token: Option<RefreshToken>,
939    #[serde(rename = "scope")]
940    #[serde(deserialize_with = "helpers::deserialize_space_delimited_vec")]
941    #[serde(serialize_with = "helpers::serialize_space_delimited_vec")]
942    #[serde(skip_serializing_if = "Option::is_none")]
943    #[serde(default)]
944    scopes: Option<Vec<Scope>>,
945}
946
947impl Token for StandardToken {
948    /// REQUIRED. The access token issued by the authorization server.
949    fn access_token(&self) -> &AccessToken {
950        &self.access_token
951    }
952
953    /// REQUIRED. The type of the token issued as described in
954    /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1).
955    /// Value is case insensitive and deserialized to the generic `TokenType` parameter.
956    fn token_type(&self) -> &TokenType {
957        &self.token_type
958    }
959
960    /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600
961    /// denotes that the access token will expire in one hour from the time the response was
962    /// generated. If omitted, the authorization server SHOULD provide the expiration time via
963    /// other means or document the default value.
964    fn expires_in(&self) -> Option<Duration> {
965        self.expires_in.map(Duration::from_secs)
966    }
967
968    /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same
969    /// authorization grant as described in
970    /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6).
971    fn refresh_token(&self) -> Option<&RefreshToken> {
972        self.refresh_token.as_ref()
973    }
974
975    /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The
976    /// scipe of the access token as described by
977    /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response,
978    /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from
979    /// the response, this field is `None`.
980    fn scopes(&self) -> Option<&Vec<Scope>> {
981        self.scopes.as_ref()
982    }
983}
984
/// These error types are defined in
/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2).
///
/// Unit variants are (de)serialized under their `snake_case` names, e.g.
/// `invalid_request`.
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorField {
    /// The request is missing a required parameter, includes an unsupported parameter value
    /// (other than grant type), repeats a parameter, includes multiple credentials, utilizes
    /// more than one mechanism for authenticating the client, or is otherwise malformed.
    InvalidRequest,
    /// Client authentication failed (e.g., unknown client, no client authentication included,
    /// or unsupported authentication method).
    InvalidClient,
    /// The provided authorization grant (e.g., authorization code, resource owner credentials)
    /// or refresh token is invalid, expired, revoked, does not match the redirection URI used
    /// in the authorization request, or was issued to another client.
    InvalidGrant,
    /// The authenticated client is not authorized to use this authorization grant type.
    UnauthorizedClient,
    /// The authorization grant type is not supported by the authorization server.
    UnsupportedGrantType,
    /// The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the
    /// resource owner.
    InvalidScope,
    /// Other error type.
    ///
    /// NOTE(review): with the derived serde impls this variant presumably uses serde's
    /// default externally tagged form (`{"other": "..."}`) rather than a bare string, so
    /// an unrecognised bare error code would not land here; confirm this is intended.
    Other(String),
}
1011
1012impl fmt::Display for ErrorField {
1013    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1014        use self::ErrorField::*;
1015
1016        match *self {
1017            InvalidRequest => "invalid_request".fmt(fmt),
1018            InvalidClient => "invalid_client".fmt(fmt),
1019            InvalidGrant => "invalid_grant".fmt(fmt),
1020            UnauthorizedClient => "unauthorized_client".fmt(fmt),
1021            UnsupportedGrantType => "unsupported_grant_type".fmt(fmt),
1022            InvalidScope => "invalid_scope".fmt(fmt),
1023            Other(ref value) => value.fmt(fmt),
1024        }
1025    }
1026}
1027
/// Error response returned by server after requesting an access token.
///
/// The fields in this structure are defined in
/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2).
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct ErrorResponse {
    /// A single ASCII error code.
    pub error: ErrorField,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    /// Human-readable ASCII text providing additional information, used to assist
    /// the client developer in understanding the error that occurred.
    ///
    /// Defaults to `None` when absent from the input and is left out of the
    /// serialized output when `None`.
    pub error_description: Option<String>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    /// A URI identifying a human-readable web page with information about the error,
    /// used to provide the client developer with additional information about the error.
    ///
    /// Defaults to `None` when absent from the input and is left out of the
    /// serialized output when `None`.
    pub error_uri: Option<String>,
}
1047
1048impl fmt::Display for ErrorResponse {
1049    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1050        let mut formatted = self.error.to_string();
1051
1052        if let Some(error_description) = self.error_description.as_ref() {
1053            formatted.push_str(": ");
1054            formatted.push_str(error_description);
1055        }
1056
1057        if let Some(error_uri) = self.error_uri.as_ref() {
1058            formatted.push_str(" / See ");
1059            formatted.push_str(error_uri);
1060        }
1061
1062        write!(f, "{}", formatted)
1063    }
1064}
1065
// Empty impl relying on the trait's default methods; lets `ErrorResponse` be
// used as the `#[source]` of `ExecuteError::ErrorResponse`.
impl error::Error for ErrorResponse {}
1067
/// Errors when creating new clients.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum NewClientError {
    /// Error creating underlying reqwest client.
    ///
    /// The wrapped `reqwest::Error` is exposed as the error source.
    #[error("Failed to construct client")]
    Reqwest(#[source] reqwest::Error),
}
1076
1077impl From<reqwest::Error> for NewClientError {
1078    fn from(error: reqwest::Error) -> Self {
1079        Self::Reqwest(error)
1080    }
1081}
1082
/// Error encountered while requesting access token.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ExecuteError {
    /// A client error that occurred.
    #[error("reqwest error")]
    RequestError {
        /// Original request error.
        #[source]
        error: reqwest::Error,
    },
    /// Failed to parse server response. Parse errors may occur while parsing either successful
    /// or error responses.
    #[error("malformed server response: {status}")]
    BadResponse {
        /// The status code associated with the response.
        status: http::status::StatusCode,
        /// The body that couldn't be deserialized.
        body: bytes::Bytes,
        /// Deserialization error.
        #[source]
        error: serde_json::error::Error,
    },
    /// Response with non-successful status code and a body that could be
    /// successfully deserialized as an [ErrorResponse].
    #[error("request resulted in error response: {status}")]
    ErrorResponse {
        /// The status code associated with the response.
        status: http::status::StatusCode,
        /// The deserialized response.
        #[source]
        error: ErrorResponse,
    },
    /// Server response was empty.
    #[error("request resulted in empty response: {status}")]
    EmptyResponse {
        /// The status code associated with the empty response.
        status: http::status::StatusCode,
    },
}
1123
1124impl ExecuteError {
1125    /// Access the status code of the error if available.
1126    pub fn status(&self) -> Option<http::status::StatusCode> {
1127        match *self {
1128            Self::RequestError { ref error, .. } => error.status(),
1129            Self::BadResponse { status, .. } => Some(status),
1130            Self::ErrorResponse { status, .. } => Some(status),
1131            Self::EmptyResponse { status, .. } => Some(status),
1132        }
1133    }
1134
1135    /// The original response body if available.
1136    pub fn body(&self) -> Option<&bytes::Bytes> {
1137        match *self {
1138            Self::BadResponse { ref body, .. } => Some(body),
1139            _ => None,
1140        }
1141    }
1142}
1143
1144/// Helper methods used by OAuth2 implementations/extensions.
1145pub mod helpers {
1146    use serde::{Deserialize, Deserializer, Serializer};
1147    use url::Url;
1148
1149    /// Serde space-delimited string deserializer for a `Vec<String>`.
1150    ///
1151    /// This function splits a JSON string at each space character into a `Vec<String>` .
1152    ///
1153    /// # Example
1154    ///
1155    /// In example below, the JSON value `{"items": "foo bar baz"}` would deserialize to:
1156    ///
1157    /// ```
1158    /// # struct GroceryBasket {
1159    /// #     items: Vec<String>,
1160    /// # }
1161    /// # fn main() {
1162    /// GroceryBasket {
1163    ///     items: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()]
1164    /// };
1165    /// # }
1166    /// ```
1167    ///
1168    /// Note: this example does not compile automatically due to
1169    /// [Rust issue #29286](https://github.com/rust-lang/rust/issues/29286).
1170    ///
1171    /// ```
1172    /// # /*
1173    /// use serde::Deserialize;
1174    ///
1175    /// #[derive(Deserialize)]
1176    /// struct GroceryBasket {
1177    ///     #[serde(deserialize_with = "helpers::deserialize_space_delimited_vec")]
1178    ///     items: Vec<String>,
1179    /// }
1180    /// # */
1181    /// ```
1182    pub fn deserialize_space_delimited_vec<'de, T, D>(deserializer: D) -> Result<T, D::Error>
1183    where
1184        T: Default + Deserialize<'de>,
1185        D: Deserializer<'de>,
1186    {
1187        use serde::de::Error;
1188        use serde_json::Value;
1189
1190        if let Some(space_delimited) = Option::<String>::deserialize(deserializer)? {
1191            let entries = space_delimited
1192                .split(' ')
1193                .map(|s| Value::String(s.to_string()))
1194                .collect();
1195            return T::deserialize(Value::Array(entries)).map_err(Error::custom);
1196        }
1197
1198        // If the JSON value is null, use the default value.
1199        Ok(T::default())
1200    }
1201
1202    /// Serde space-delimited string serializer for an `Option<Vec<String>>`.
1203    ///
1204    /// This function serializes a string vector into a single space-delimited string.
1205    /// If `string_vec_opt` is `None`, the function serializes it as `None` (e.g., `null`
1206    /// in the case of JSON serialization).
1207    pub fn serialize_space_delimited_vec<T, S>(
1208        vec_opt: &Option<Vec<T>>,
1209        serializer: S,
1210    ) -> Result<S::Ok, S::Error>
1211    where
1212        T: AsRef<str>,
1213        S: Serializer,
1214    {
1215        if let Some(ref vec) = *vec_opt {
1216            let space_delimited = vec.iter().map(|s| s.as_ref()).collect::<Vec<_>>().join(" ");
1217            serializer.serialize_str(&space_delimited)
1218        } else {
1219            serializer.serialize_none()
1220        }
1221    }
1222
1223    /// Serde string deserializer for a `Url`.
1224    pub fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
1225    where
1226        D: Deserializer<'de>,
1227    {
1228        use serde::de::Error;
1229        let url_str = String::deserialize(deserializer)?;
1230        Url::parse(url_str.as_ref()).map_err(Error::custom)
1231    }
1232
1233    /// Serde string serializer for a `Url`.
1234    pub fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
1235    where
1236        S: Serializer,
1237    {
1238        serializer.serialize_str(url.as_str())
1239    }
1240}