oauth2/lib.rs
1//! [<img alt="github" src="https://img.shields.io/badge/github-udoprog/async--oauth2-8da0cb?style=for-the-badge&logo=github" height="20">](https://github.com/udoprog/async-oauth2)
2//! [<img alt="crates.io" src="https://img.shields.io/crates/v/async-oauth2.svg?style=for-the-badge&color=fc8d62&logo=rust" height="20">](https://crates.io/crates/async-oauth2)
3//! [<img alt="docs.rs" src="https://img.shields.io/badge/docs.rs-async--oauth2-66c2a5?style=for-the-badge&logoColor=white&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgdmlld0JveD0iMCAwIDUxMiA1MTIiPjxwYXRoIGZpbGw9IiNmNWY1ZjUiIGQ9Ik00ODguNiAyNTAuMkwzOTIgMjE0VjEwNS41YzAtMTUtOS4zLTI4LjQtMjMuNC0zMy43bC0xMDAtMzcuNWMtOC4xLTMuMS0xNy4xLTMuMS0yNS4zIDBsLTEwMCAzNy41Yy0xNC4xIDUuMy0yMy40IDE4LjctMjMuNCAzMy43VjIxNGwtOTYuNiAzNi4yQzkuMyAyNTUuNSAwIDI2OC45IDAgMjgzLjlWMzk0YzAgMTMuNiA3LjcgMjYuMSAxOS45IDMyLjJsMTAwIDUwYzEwLjEgNS4xIDIyLjEgNS4xIDMyLjIgMGwxMDMuOS01MiAxMDMuOSA1MmMxMC4xIDUuMSAyMi4xIDUuMSAzMi4yIDBsMTAwLTUwYzEyLjItNi4xIDE5LjktMTguNiAxOS45LTMyLjJWMjgzLjljMC0xNS05LjMtMjguNC0yMy40LTMzLjd6TTM1OCAyMTQuOGwtODUgMzEuOXYtNjguMmw4NS0zN3Y3My4zek0xNTQgMTA0LjFsMTAyLTM4LjIgMTAyIDM4LjJ2LjZsLTEwMiA0MS40LTEwMi00MS40di0uNnptODQgMjkxLjFsLTg1IDQyLjV2LTc5LjFsODUtMzguOHY3NS40em0wLTExMmwtMTAyIDQxLjQtMTAyLTQxLjR2LS42bDEwMi0zOC4yIDEwMiAzOC4ydi42em0yNDAgMTEybC04NSA0Mi41di03OS4xbDg1LTM4Ljh2NzUuNHptMC0xMTJsLTEwMiA0MS40LTEwMi00MS40di0uNmwxMDItMzguMiAxMDIgMzguMnYuNnoiPjwvcGF0aD48L3N2Zz4K" height="20">](https://docs.rs/async-oauth2)
4//!
5//! An asynchronous OAuth2 flow implementation, trying to adhere as much as
6//! possible to [RFC 6749].
7//!
8//! <br>
9//!
10//! ## Examples
11//!
12//! To see the library in action, you can go to one of our examples:
13//!
14//! - [Google]
15//! - [Spotify]
16//! - [Twitch]
17//!
18//! If you've checked out the project they can be run like this:
19//!
20//! ```sh
21//! cargo run --manifest-path=examples/Cargo.toml --bin spotify --
22//! --client-id <client-id> --client-secret <client-secret>
23//! cargo run --manifest-path=examples/Cargo.toml --bin google --
24//! --client-id <client-id> --client-secret <client-secret>
25//! cargo run --manifest-path=examples/Cargo.toml --bin twitch --
26//! --client-id <client-id> --client-secret <client-secret>
27//! ```
28//!
29//! > Note: You need to configure your client integration to permit redirects to
30//! > `http://localhost:8080/api/auth/redirect` for these to work. How this is
31//! > done depends on the integration used.
32//!
33//! <br>
34//!
35//! ## Authorization Code Grant
36//!
37//! This is the most common OAuth2 flow.
38//!
39//! ```no_run
40//! use oauth2::*;
41//! use url::Url;
42//!
43//! pub struct ReceivedCode {
44//! pub code: AuthorizationCode,
45//! pub state: State,
46//! }
47//!
48//! # async fn listen_for_code(port: u32) -> Result<ReceivedCode, Box<dyn std::error::Error>> { todo!() }
49//! # #[tokio::main]
50//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
51//! let reqwest_client = reqwest::Client::new();
52//!
53//! // Create an OAuth2 client by specifying the client ID, client secret,
54//! // authorization URL and token URL.
55//! let mut client = Client::new(
56//! "client_id",
57//! Url::parse("http://authorize")?,
58//! Url::parse("http://token")?
59//! );
60//!
61//! client.set_client_secret("client_secret");
62//! // Set the URL the user will be redirected to after the authorization
63//! // process.
64//! client.set_redirect_url(Url::parse("http://redirect")?);
65//! // Set the desired scopes.
66//! client.add_scope("read");
67//! client.add_scope("write");
68//!
69//! // Generate the full authorization URL.
70//! let state = State::new_random();
71//! let auth_url = client.authorize_url(&state);
72//!
73//! // This is the URL you should redirect the user to, in order to trigger the
74//! // authorization process.
75//! println!("Browse to: {}", auth_url);
76//!
//! // Once the user has been redirected to the redirect URL, you'll have the
//! // authorization code. For security reasons, your code should verify that the
79//! // `state` parameter returned by the server matches `state`.
80//! let received: ReceivedCode = listen_for_code(8080).await?;
81//!
82//! if received.state != state {
83//! panic!("CSRF token mismatch :(");
84//! }
85//!
86//! // Now you can trade it for an access token.
87//! let token = client.exchange_code(received.code)
88//! .with_client(&reqwest_client)
89//! .execute::<StandardToken>()
90//! .await?;
91//!
92//! # Ok(())
93//! # }
94//! ```
95//!
96//! <br>
97//!
98//! ## Implicit Grant
99//!
100//! This flow fetches an access token directly from the authorization endpoint.
101//!
102//! Be sure to understand the security implications of this flow before using
103//! it. In most cases the Authorization Code Grant flow above is preferred to
104//! the Implicit Grant flow.
105//!
106//! ```no_run
107//! use oauth2::*;
108//! use url::Url;
109//!
110//! pub struct ReceivedCode {
111//! pub code: AuthorizationCode,
112//! pub state: State,
113//! }
114//!
115//! # async fn get_code() -> Result<ReceivedCode, Box<dyn std::error::Error>> { todo!() }
116//! # #[tokio::main]
117//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
118//! let mut client = Client::new(
119//! "client_id",
120//! Url::parse("http://authorize")?,
121//! Url::parse("http://token")?
122//! );
123//!
124//! client.set_client_secret("client_secret");
125//!
126//! // Generate the full authorization URL.
127//! let state = State::new_random();
128//! let auth_url = client.authorize_url_implicit(&state);
129//!
130//! // This is the URL you should redirect the user to, in order to trigger the
131//! // authorization process.
132//! println!("Browse to: {}", auth_url);
133//!
//! // Once the user has been redirected to the redirect URL, the access token is
//! // delivered in the URL fragment. For security reasons, your code should verify
//! // that the `state` parameter returned by the server matches `state`.
137//! let received: ReceivedCode = get_code().await?;
138//!
139//! if received.state != state {
140//! panic!("CSRF token mismatch :(");
141//! }
142//!
143//! # Ok(()) }
144//! ```
145//!
146//! <br>
147//!
148//! ## Resource Owner Password Credentials Grant
149//!
150//! You can ask for a *password* access token by calling the
151//! `Client::exchange_password` method, while including the username and
152//! password.
153//!
154//! ```no_run
155//! use oauth2::*;
156//! use url::Url;
157//!
158//! # #[tokio::main]
159//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
160//! let reqwest_client = reqwest::Client::new();
161//!
162//! let mut client = Client::new(
163//! "client_id",
164//! Url::parse("http://authorize")?,
165//! Url::parse("http://token")?
166//! );
167//!
168//! client.set_client_secret("client_secret");
169//! client.add_scope("read");
170//!
171//! let token = client
172//! .exchange_password("user", "pass")
173//! .with_client(&reqwest_client)
174//! .execute::<StandardToken>()
175//! .await?;
176//!
177//! # Ok(()) }
178//! ```
179//!
180//! <br>
181//!
182//! ## Client Credentials Grant
183//!
184//! You can ask for a *client credentials* access token by calling the
185//! `Client::exchange_client_credentials` method.
186//!
187//! ```no_run
188//! use oauth2::*;
189//! use url::Url;
190//!
191//! # #[tokio::main]
192//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
193//! let reqwest_client = reqwest::Client::new();
194//! let mut client = Client::new(
195//! "client_id",
196//! Url::parse("http://authorize")?,
197//! Url::parse("http://token")?
198//! );
199//!
200//! client.set_client_secret("client_secret");
201//! client.add_scope("read");
202//!
//! let token = client.exchange_client_credentials()
//!     .with_client(&reqwest_client)
//!     .execute::<StandardToken>()
//!     .await?;
206//!
207//! # Ok(()) }
208//! ```
209//!
210//! <br>
211//!
212//! ## Relationship to oauth2-rs
213//!
214//! This is a fork of [oauth2-rs].
215//!
216//! The main differences are:
217//! * Removal of unnecessary type parameters on Client ([see discussion here]).
218//! * Only support one client implementation ([reqwest]).
219//! * Remove most newtypes except `Scope` and the secret ones since they made the API harder to use.
220//!
221//! [RFC 6749]: https://tools.ietf.org/html/rfc6749
222//! [Google]: https://github.com/udoprog/async-oauth2/blob/master/examples/src/bin/google.rs
223//! [oauth2-rs]: https://github.com/ramosbugs/oauth2-rs
224//! [reqwest]: https://docs.rs/reqwest
225//! [see discussion here]: https://github.com/ramosbugs/oauth2-rs/issues/44#issuecomment-50158653
226//! [Spotify]: https://github.com/udoprog/async-oauth2/blob/master/examples/src/bin/spotify.rs
227//! [Twitch]: https://github.com/udoprog/async-oauth2/blob/master/examples/src/bin/twitch.rs
228
229#![deny(missing_docs)]
230
231use std::{borrow::Cow, error, fmt, time::Duration};
232
233use base64::prelude::{Engine as _, BASE64_URL_SAFE_NO_PAD};
234use rand::{thread_rng, Rng};
235use serde::{Deserialize, Serialize};
236use serde_aux::prelude::*;
237use sha2::{Digest, Sha256};
238use thiserror::Error;
239pub use url::Url;
240
241/// Indicates whether requests to the authorization server should use basic authentication or
242/// include the parameters in the request body for requests in which either is valid.
243///
244/// The default AuthType is *BasicAuth*, following the recommendation of
245/// [Section 2.3.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-2.3.1).
246#[derive(Clone, Copy, Debug)]
247pub enum AuthType {
248 /// The client_id and client_secret will be included as part of the request body.
249 RequestBody,
250 /// The client_id and client_secret will be included using the basic auth authentication scheme.
251 BasicAuth,
252}
253
254macro_rules! redacted_debug {
255 ($name:ident) => {
256 impl fmt::Debug for $name {
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 write!(f, concat!(stringify!($name), "([redacted])"))
259 }
260 }
261 };
262}
263
264/// borrowed newtype plumbing
265macro_rules! borrowed_newtype {
266 ($name:ident, $borrowed:ty) => {
267 impl std::ops::Deref for $name {
268 type Target = $borrowed;
269
270 #[inline]
271 fn deref(&self) -> &Self::Target {
272 &self.0
273 }
274 }
275
276 impl<'a> From<&'a $name> for Cow<'a, $borrowed> {
277 #[inline]
278 fn from(value: &'a $name) -> Cow<'a, $borrowed> {
279 Cow::Borrowed(&value.0)
280 }
281 }
282
283 impl AsRef<$borrowed> for $name {
284 #[inline]
285 fn as_ref(&self) -> &$borrowed {
286 self
287 }
288 }
289 };
290}
291
292/// newtype plumbing
293macro_rules! newtype {
294 ($name:ident, $owned:ty, $borrowed:ty) => {
295 borrowed_newtype!($name, $borrowed);
296
297 impl<'a> From<&'a $borrowed> for $name {
298 #[inline]
299 fn from(value: &'a $borrowed) -> Self {
300 Self(value.to_owned())
301 }
302 }
303
304 impl From<$owned> for $name {
305 #[inline]
306 fn from(value: $owned) -> Self {
307 Self(value)
308 }
309 }
310
311 impl<'a> From<&'a $owned> for $name {
312 #[inline]
313 fn from(value: &'a $owned) -> Self {
314 Self(value.to_owned())
315 }
316 }
317
318 impl From<$name> for $owned {
319 #[inline]
320 fn from(value: $name) -> $owned {
321 value.0
322 }
323 }
324 };
325}
326
327/// Access token scope, as defined by the authorization server.
328#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
329pub struct Scope(String);
330newtype!(Scope, String, str);
331
332/// Code Challenge used for [PKCE]((https://tools.ietf.org/html/rfc7636)) protection via the
333/// `code_challenge` parameter.
334#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
335pub struct PkceCodeChallengeS256(String);
336newtype!(PkceCodeChallengeS256, String, str);
337
338/// Code Challenge Method used for [PKCE]((https://tools.ietf.org/html/rfc7636)) protection
339/// via the `code_challenge_method` parameter.
340#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)]
341pub struct PkceCodeChallengeMethod(String);
342newtype!(PkceCodeChallengeMethod, String, str);
343
344/// Client password issued to the client during the registration process described by
345/// [Section 2.2](https://tools.ietf.org/html/rfc6749#section-2.2).
346#[derive(Clone, Deserialize, Serialize)]
347pub struct ClientSecret(String);
348redacted_debug!(ClientSecret);
349newtype!(ClientSecret, String, str);
350
351/// Value used for [CSRF]((https://tools.ietf.org/html/rfc6749#section-10.12)) protection
352/// via the `state` parameter.
353#[must_use]
354#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
355pub struct State([u8; 16]);
356redacted_debug!(State);
357borrowed_newtype!(State, [u8]);
358
359impl State {
360 /// Generate a new random, base64-encoded 128-bit CSRF token.
361 pub fn new_random() -> Self {
362 let mut random_bytes = [0u8; 16];
363 thread_rng().fill(&mut random_bytes);
364 State(random_bytes)
365 }
366
367 /// Convert into base64.
368 pub fn to_base64(&self) -> String {
369 BASE64_URL_SAFE_NO_PAD.encode(self.0)
370 }
371}
372
373impl serde::Serialize for State {
374 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
375 where
376 S: serde::Serializer,
377 {
378 self.to_base64().serialize(serializer)
379 }
380}
381
382impl<'de> serde::Deserialize<'de> for State {
383 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
384 where
385 D: serde::Deserializer<'de>,
386 {
387 let s = String::deserialize(deserializer)?;
388 let bytes = BASE64_URL_SAFE_NO_PAD
389 .decode(s)
390 .map_err(serde::de::Error::custom)?;
391 let mut buf = [0u8; 16];
392 buf.copy_from_slice(&bytes);
393 Ok(Self(buf))
394 }
395}
396
397/// Code Verifier used for [PKCE]((https://tools.ietf.org/html/rfc7636)) protection via the
398/// `code_verifier` parameter. The value must have a minimum length of 43 characters and a
399/// maximum length of 128 characters. Each character must be ASCII alphanumeric or one of
400/// the characters "-" / "." / "_" / "~".
401#[derive(Deserialize, Serialize)]
402pub struct PkceCodeVerifierS256(String);
403newtype!(PkceCodeVerifierS256, String, str);
404
405impl PkceCodeVerifierS256 {
406 /// Generate a new random, base64-encoded code verifier.
407 pub fn new_random() -> Self {
408 PkceCodeVerifierS256::new_random_len(32)
409 }
410
411 /// Generate a new random, base64-encoded code verifier.
412 ///
413 /// # Arguments
414 ///
415 /// * `num_bytes` - Number of random bytes to generate, prior to base64-encoding.
416 /// The value must be in the range 32 to 96 inclusive in order to generate a verifier
417 /// with a suitable length.
418 pub fn new_random_len(num_bytes: u32) -> Self {
419 // The RFC specifies that the code verifier must have "a minimum length of 43
420 // characters and a maximum length of 128 characters".
421 // This implies 32-96 octets of random data to be base64 encoded.
422 assert!((32..=96).contains(&num_bytes));
423 let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().gen::<u8>()).collect();
424 let code = BASE64_URL_SAFE_NO_PAD.encode(random_bytes);
425 assert!(code.len() >= 43 && code.len() <= 128);
426 PkceCodeVerifierS256(code)
427 }
428
429 /// Return the code challenge for the code verifier.
430 pub fn code_challenge(&self) -> PkceCodeChallengeS256 {
431 let digest = Sha256::digest(self.as_bytes());
432 PkceCodeChallengeS256::from(BASE64_URL_SAFE_NO_PAD.encode(digest))
433 }
434
435 /// Return the code challenge method for this code verifier.
436 pub fn code_challenge_method() -> PkceCodeChallengeMethod {
437 PkceCodeChallengeMethod::from("S256".to_string())
438 }
439
440 /// Return the extension params used for authorize_url.
441 pub fn authorize_url_params(&self) -> Vec<(&'static str, String)> {
442 vec![
443 (
444 "code_challenge_method",
445 PkceCodeVerifierS256::code_challenge_method().into(),
446 ),
447 ("code_challenge", self.code_challenge().into()),
448 ]
449 }
450}
451
452/// Authorization code returned from the authorization endpoint.
453#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
454pub struct AuthorizationCode(String);
455redacted_debug!(AuthorizationCode);
456newtype!(AuthorizationCode, String, str);
457
458/// Refresh token used to obtain a new access token (if supported by the authorization server).
459#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
460pub struct RefreshToken(String);
461redacted_debug!(RefreshToken);
462newtype!(RefreshToken, String, str);
463
464/// Access token returned by the token endpoint and used to access protected resources.
465#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
466pub struct AccessToken(String);
467redacted_debug!(AccessToken);
468newtype!(AccessToken, String, str);
469
470/// Resource owner's password used directly as an authorization grant to obtain an access
471/// token.
472pub struct ResourceOwnerPassword(String);
473newtype!(ResourceOwnerPassword, String, str);
474
475/// Stores the configuration for an OAuth2 client.
476#[derive(Clone, Debug)]
477pub struct Client {
478 client_id: String,
479 client_secret: Option<ClientSecret>,
480 auth_url: Url,
481 auth_type: AuthType,
482 token_url: Url,
483 scopes: Vec<Scope>,
484 redirect_url: Option<Url>,
485}
486
487impl Client {
488 /// Initializes an OAuth2 client with the fields common to most OAuth2 flows.
489 ///
490 /// # Arguments
491 ///
492 /// * `client_id` - Client ID
493 /// * `auth_url` - Authorization endpoint: used by the client to obtain authorization from
494 /// the resource owner via user-agent redirection. This URL is used in all standard OAuth2
495 /// flows except the [Resource Owner Password Credentials
496 /// Grant](https://tools.ietf.org/html/rfc6749#section-4.3) and the
497 /// [Client Credentials Grant](https://tools.ietf.org/html/rfc6749#section-4.4).
498 /// * `token_url` - Token endpoint: used by the client to exchange an authorization grant
499 /// (code) for an access token, typically with client authentication. This URL is used in
500 /// all standard OAuth2 flows except the
501 /// [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2). If this value is set
502 /// to `None`, the `exchange_*` methods will return `Err(ExecuteError::Other(_))`.
503 pub fn new(client_id: impl AsRef<str>, auth_url: Url, token_url: Url) -> Self {
504 Client {
505 client_id: client_id.as_ref().to_string(),
506 client_secret: None,
507 auth_url,
508 auth_type: AuthType::BasicAuth,
509 token_url,
510 scopes: Vec::new(),
511 redirect_url: None,
512 }
513 }
514
515 /// Configure the client secret to use.
516 pub fn set_client_secret(&mut self, client_secret: impl Into<ClientSecret>) {
517 self.client_secret = Some(client_secret.into());
518 }
519
520 /// Appends a new scope to the authorization URL.
521 pub fn add_scope(&mut self, scope: impl Into<Scope>) {
522 self.scopes.push(scope.into());
523 }
524
525 /// Configures the type of client authentication used for communicating with the authorization
526 /// server.
527 ///
528 /// The default is to use HTTP Basic authentication, as recommended in
529 /// [Section 2.3.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-2.3.1).
530 pub fn set_auth_type(&mut self, auth_type: AuthType) {
531 self.auth_type = auth_type;
532 }
533
534 /// Sets the the redirect URL used by the authorization endpoint.
535 pub fn set_redirect_url(&mut self, redirect_url: Url) {
536 self.redirect_url = Some(redirect_url);
537 }
538
539 /// Produces the full authorization URL used by the
540 /// [Authorization Code Grant](https://tools.ietf.org/html/rfc6749#section-4.1)
541 /// flow, which is the most common OAuth2 flow.
542 ///
543 /// # Arguments
544 ///
545 /// * `state` - A state value to include in the request. The authorization
546 /// server includes this value when redirecting the user-agent back to the
547 /// client.
548 ///
549 /// # Security Warning
550 ///
551 /// Callers should use a fresh, unpredictable `state` for each authorization
552 /// request and verify that this value matches the `state` parameter passed
553 /// by the authorization server to the redirect URI. Doing so mitigates
554 /// [Cross-Site Request Forgery](https://tools.ietf.org/html/rfc6749#section-10.12)
555 /// attacks.
556 pub fn authorize_url(&self, state: &State) -> Url {
557 self.authorize_url_impl("code", state)
558 }
559
560 /// Produces the full authorization URL used by the
561 /// [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2) flow.
562 ///
563 /// # Arguments
564 ///
565 /// * `state` - A state value to include in the request. The authorization
566 /// server includes this value when redirecting the user-agent back to the
567 /// client.
568 ///
569 /// # Security Warning
570 ///
571 /// Callers should use a fresh, unpredictable `state` for each authorization request and verify
572 /// that this value matches the `state` parameter passed by the authorization server to the
573 /// redirect URI. Doing so mitigates
574 /// [Cross-Site Request Forgery](https://tools.ietf.org/html/rfc6749#section-10.12)
575 /// attacks.
576 pub fn authorize_url_implicit(&self, state: &State) -> Url {
577 self.authorize_url_impl("token", state)
578 }
579
580 fn authorize_url_impl(&self, response_type: &str, state: &State) -> Url {
581 let scopes = self
582 .scopes
583 .iter()
584 .map(|s| s.to_string())
585 .collect::<Vec<_>>()
586 .join(" ");
587
588 let mut url = self.auth_url.clone();
589
590 {
591 let mut query = url.query_pairs_mut();
592
593 query.append_pair("response_type", response_type);
594 query.append_pair("client_id", &self.client_id);
595
596 if let Some(ref redirect_url) = self.redirect_url {
597 query.append_pair("redirect_uri", redirect_url.as_str());
598 }
599
600 if !scopes.is_empty() {
601 query.append_pair("scope", &scopes);
602 }
603
604 query.append_pair("state", &state.to_base64());
605 }
606
607 url
608 }
609
610 /// Exchanges a code produced by a successful authorization process with an access token.
611 ///
612 /// Acquires ownership of the `code` because authorization codes may only be used to retrieve
613 /// an access token from the authorization server.
614 ///
615 /// See https://tools.ietf.org/html/rfc6749#section-4.1.3
616 pub fn exchange_code(&self, code: impl Into<AuthorizationCode>) -> Request<'_> {
617 let code = code.into();
618
619 self.request_token()
620 .param("grant_type", "authorization_code")
621 .param("code", code.to_string())
622 }
623
624 /// Requests an access token for the *password* grant type.
625 ///
626 /// See https://tools.ietf.org/html/rfc6749#section-4.3.2
627 pub fn exchange_password(
628 &self,
629 username: impl AsRef<str>,
630 password: impl AsRef<str>,
631 ) -> Request<'_> {
632 let username = username.as_ref();
633 let password = password.as_ref();
634
635 let mut builder = self
636 .request_token()
637 .param("grant_type", "password")
638 .param("username", username.to_string())
639 .param("password", password.to_string());
640
641 // Generate the space-delimited scopes String before initializing params so that it has
642 // a long enough lifetime.
643 if !self.scopes.is_empty() {
644 let scopes = self
645 .scopes
646 .iter()
647 .map(|s| s.to_string())
648 .collect::<Vec<_>>()
649 .join(" ");
650
651 builder = builder.param("scope", scopes);
652 }
653
654 builder
655 }
656
657 /// Requests an access token for the *client credentials* grant type.
658 ///
659 /// See https://tools.ietf.org/html/rfc6749#section-4.4.2
660 pub fn exchange_client_credentials(&self) -> Request<'_> {
661 let mut builder = self
662 .request_token()
663 .param("grant_type", "client_credentials");
664
665 // Generate the space-delimited scopes String before initializing params so that it has
666 // a long enough lifetime.
667 if !self.scopes.is_empty() {
668 let scopes = self
669 .scopes
670 .iter()
671 .map(|s| s.to_string())
672 .collect::<Vec<_>>()
673 .join(" ");
674
675 builder = builder.param("scopes", scopes);
676 }
677
678 builder
679 }
680
681 /// Exchanges a refresh token for an access token
682 ///
683 /// See https://tools.ietf.org/html/rfc6749#section-6
684 pub fn exchange_refresh_token(&self, refresh_token: &RefreshToken) -> Request<'_> {
685 self.request_token()
686 .param("grant_type", "refresh_token")
687 .param("refresh_token", refresh_token.to_string())
688 }
689
690 /// Construct a request builder for the token URL.
691 fn request_token(&self) -> Request<'_> {
692 Request {
693 token_url: &self.token_url,
694 auth_type: self.auth_type,
695 client_id: &self.client_id,
696 client_secret: self.client_secret.as_ref(),
697 redirect_url: self.redirect_url.as_ref(),
698 params: vec![],
699 }
700 }
701}
702
703/// A request wrapped in a client, ready to be executed.
704pub struct ClientRequest<'a> {
705 request: Request<'a>,
706 client: &'a reqwest::Client,
707}
708
709impl<'a> ClientRequest<'a> {
710 /// Execute the token request.
711 pub async fn execute<T>(self) -> Result<T, ExecuteError>
712 where
713 T: for<'de> Deserialize<'de>,
714 {
715 use reqwest::{header, Method};
716
717 let mut request = self
718 .client
719 .request(Method::POST, self.request.token_url.clone());
720
721 // Section 5.1 of RFC 6749 (https://tools.ietf.org/html/rfc6749#section-5.1) only permits
722 // JSON responses for this request. Some providers such as GitHub have off-spec behavior
723 // and not only support different response formats, but have non-JSON defaults. Explicitly
724 // request JSON here.
725 request = request.header(
726 header::ACCEPT,
727 header::HeaderValue::from_static(CONTENT_TYPE_JSON),
728 );
729
730 let request = {
731 let mut form = url::form_urlencoded::Serializer::new(String::new());
732
733 // FIXME: add support for auth extensions? e.g., client_secret_jwt and private_key_jwt
734 match self.request.auth_type {
735 AuthType::RequestBody => {
736 form.append_pair("client_id", self.request.client_id);
737
738 if let Some(client_secret) = self.request.client_secret {
739 form.append_pair("client_secret", client_secret);
740 }
741 }
742 AuthType::BasicAuth => {
743 // Section 2.3.1 of RFC 6749 requires separately url-encoding the id and secret
744 // before using them as HTTP Basic auth username and password. Note that this is
745 // not standard for ordinary Basic auth, so curl won't do it for us.
746 let username = url_encode(self.request.client_id);
747
748 let password = self
749 .request
750 .client_secret
751 .map(|client_secret| url_encode(client_secret));
752
753 request = request.basic_auth(username, password.as_ref());
754 }
755 }
756
757 for (key, value) in self.request.params {
758 form.append_pair(key.as_ref(), value.as_ref());
759 }
760
761 if let Some(redirect_url) = &self.request.redirect_url {
762 form.append_pair("redirect_uri", redirect_url.as_str());
763 }
764
765 request = request.header(
766 header::CONTENT_TYPE,
767 header::HeaderValue::from_static("application/x-www-form-urlencoded"),
768 );
769
770 request.body(form.finish().into_bytes())
771 };
772
773 let res = request
774 .send()
775 .await
776 .map_err(|error| ExecuteError::RequestError { error })?;
777
778 let status = res.status();
779
780 let body = res
781 .bytes()
782 .await
783 .map_err(|error| ExecuteError::RequestError { error })?;
784
785 if body.is_empty() {
786 return Err(ExecuteError::EmptyResponse { status });
787 }
788
789 if !status.is_success() {
790 let error = match serde_json::from_slice::<ErrorResponse>(body.as_ref()) {
791 Ok(error) => error,
792 Err(error) => {
793 return Err(ExecuteError::BadResponse {
794 status,
795 error,
796 body,
797 });
798 }
799 };
800
801 return Err(ExecuteError::ErrorResponse { status, error });
802 }
803
804 return serde_json::from_slice(body.as_ref()).map_err(|error| ExecuteError::BadResponse {
805 status,
806 error,
807 body,
808 });
809
810 fn url_encode(s: &str) -> String {
811 url::form_urlencoded::byte_serialize(s.as_bytes()).collect::<String>()
812 }
813
814 const CONTENT_TYPE_JSON: &str = "application/json";
815 }
816}
817
818/// A token request that is in progress.
819pub struct Request<'a> {
820 token_url: &'a Url,
821 auth_type: AuthType,
822 client_id: &'a str,
823 client_secret: Option<&'a ClientSecret>,
824 /// Configured redirect URL.
825 redirect_url: Option<&'a Url>,
826 /// Extra parameters.
827 params: Vec<(Cow<'a, str>, Cow<'a, str>)>,
828}
829
830impl<'a> Request<'a> {
831 /// Set an additional request param.
832 pub fn param(mut self, key: impl Into<Cow<'a, str>>, value: impl Into<Cow<'a, str>>) -> Self {
833 self.params.push((key.into(), value.into()));
834 self
835 }
836
837 /// Wrap the request in a client.
838 pub fn with_client(self, client: &'a reqwest::Client) -> ClientRequest<'a> {
839 ClientRequest {
840 client,
841 request: self,
842 }
843 }
844}
845
846/// Basic OAuth2 authorization token types.
847#[derive(Clone, Debug, PartialEq, Serialize)]
848#[serde(rename_all = "lowercase")]
849pub enum TokenType {
850 /// Bearer token
851 /// ([OAuth 2.0 Bearer Tokens - RFC 6750](https://tools.ietf.org/html/rfc6750)).
852 Bearer,
853 /// MAC ([OAuth 2.0 Message Authentication Code (MAC)
854 /// Tokens](https://tools.ietf.org/html/draft-ietf-oauth-v2-http-mac-05)).
855 Mac,
856}
857
858impl<'de> serde::de::Deserialize<'de> for TokenType {
859 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
860 where
861 D: serde::de::Deserializer<'de>,
862 {
863 let value = String::deserialize(deserializer)?.to_lowercase();
864
865 return match value.as_str() {
866 "bearer" => Ok(TokenType::Bearer),
867 "mac" => Ok(TokenType::Mac),
868 other => Err(serde::de::Error::custom(UnknownVariantError(
869 other.to_string(),
870 ))),
871 };
872
873 #[derive(Debug)]
874 struct UnknownVariantError(String);
875
876 impl error::Error for UnknownVariantError {}
877
878 impl fmt::Display for UnknownVariantError {
879 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
880 write!(fmt, "unsupported variant: {}", self.0)
881 }
882 }
883 }
884}
885
886/// Common methods shared by all OAuth2 token implementations.
887///
888/// The methods in this trait are defined in
889/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1). This trait exists
890/// separately from the `StandardToken` struct to support customization by clients,
891/// such as supporting interoperability with non-standards-complaint OAuth2 providers.
892pub trait Token
893where
894 Self: for<'a> serde::de::Deserialize<'a>,
895{
896 /// REQUIRED. The access token issued by the authorization server.
897 fn access_token(&self) -> &AccessToken;
898
899 /// REQUIRED. The type of the token issued as described in
900 /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1).
901 /// Value is case insensitive and deserialized to the generic `TokenType` parameter.
902 fn token_type(&self) -> &TokenType;
903
904 /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600
905 /// denotes that the access token will expire in one hour from the time the response was
906 /// generated. If omitted, the authorization server SHOULD provide the expiration time via
907 /// other means or document the default value.
908 fn expires_in(&self) -> Option<Duration>;
909
910 /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same
911 /// authorization grant as described in
912 /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6).
913 fn refresh_token(&self) -> Option<&RefreshToken>;
914
915 /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The
916 /// scipe of the access token as described by
917 /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response,
918 /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from
919 /// the response, this field is `None`.
920 fn scopes(&self) -> Option<&Vec<Scope>>;
921}
922
923/// Standard OAuth2 token response.
924///
925/// This struct includes the fields defined in
926/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1), as well as
927/// extensions defined by the `EF` type parameter.
928#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
929pub struct StandardToken {
930 access_token: AccessToken,
931 token_type: TokenType,
932 #[serde(
933 skip_serializing_if = "Option::is_none",
934 deserialize_with = "deserialize_option_number_from_string"
935 )]
936 expires_in: Option<u64>,
937 #[serde(skip_serializing_if = "Option::is_none")]
938 refresh_token: Option<RefreshToken>,
939 #[serde(rename = "scope")]
940 #[serde(deserialize_with = "helpers::deserialize_space_delimited_vec")]
941 #[serde(serialize_with = "helpers::serialize_space_delimited_vec")]
942 #[serde(skip_serializing_if = "Option::is_none")]
943 #[serde(default)]
944 scopes: Option<Vec<Scope>>,
945}
946
947impl Token for StandardToken {
948 /// REQUIRED. The access token issued by the authorization server.
949 fn access_token(&self) -> &AccessToken {
950 &self.access_token
951 }
952
953 /// REQUIRED. The type of the token issued as described in
954 /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1).
955 /// Value is case insensitive and deserialized to the generic `TokenType` parameter.
956 fn token_type(&self) -> &TokenType {
957 &self.token_type
958 }
959
960 /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600
961 /// denotes that the access token will expire in one hour from the time the response was
962 /// generated. If omitted, the authorization server SHOULD provide the expiration time via
963 /// other means or document the default value.
964 fn expires_in(&self) -> Option<Duration> {
965 self.expires_in.map(Duration::from_secs)
966 }
967
968 /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same
969 /// authorization grant as described in
970 /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6).
971 fn refresh_token(&self) -> Option<&RefreshToken> {
972 self.refresh_token.as_ref()
973 }
974
975 /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The
976 /// scipe of the access token as described by
977 /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response,
978 /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from
979 /// the response, this field is `None`.
980 fn scopes(&self) -> Option<&Vec<Scope>> {
981 self.scopes.as_ref()
982 }
983}
984
985/// These error types are defined in
986/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2).
987#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Serialize)]
988#[serde(rename_all = "snake_case")]
989pub enum ErrorField {
990 /// The request is missing a required parameter, includes an unsupported parameter value
991 /// (other than grant type), repeats a parameter, includes multiple credentials, utilizes
992 /// more than one mechanism for authenticating the client, or is otherwise malformed.
993 InvalidRequest,
994 /// Client authentication failed (e.g., unknown client, no client authentication included,
995 /// or unsupported authentication method).
996 InvalidClient,
997 /// The provided authorization grant (e.g., authorization code, resource owner credentials)
998 /// or refresh token is invalid, expired, revoked, does not match the redirection URI used
999 /// in the authorization request, or was issued to another client.
1000 InvalidGrant,
1001 /// The authenticated client is not authorized to use this authorization grant type.
1002 UnauthorizedClient,
1003 /// The authorization grant type is not supported by the authorization server.
1004 UnsupportedGrantType,
1005 /// The requested scope is invalid, unknown, malformed, or exceeds the scope granted by the
1006 /// resource owner.
1007 InvalidScope,
1008 /// Other error type.
1009 Other(String),
1010}
1011
1012impl fmt::Display for ErrorField {
1013 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1014 use self::ErrorField::*;
1015
1016 match *self {
1017 InvalidRequest => "invalid_request".fmt(fmt),
1018 InvalidClient => "invalid_client".fmt(fmt),
1019 InvalidGrant => "invalid_grant".fmt(fmt),
1020 UnauthorizedClient => "unauthorized_client".fmt(fmt),
1021 UnsupportedGrantType => "unsupported_grant_type".fmt(fmt),
1022 InvalidScope => "invalid_scope".fmt(fmt),
1023 Other(ref value) => value.fmt(fmt),
1024 }
1025 }
1026}
1027
1028/// Error response returned by server after requesting an access token.
1029///
1030/// The fields in this structure are defined in
1031/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2).
1032#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
1033pub struct ErrorResponse {
1034 /// A single ASCII error code.
1035 pub error: ErrorField,
1036 #[serde(default)]
1037 #[serde(skip_serializing_if = "Option::is_none")]
1038 /// Human-readable ASCII text providing additional information, used to assist
1039 /// the client developer in understanding the error that occurred.
1040 pub error_description: Option<String>,
1041 #[serde(default)]
1042 #[serde(skip_serializing_if = "Option::is_none")]
1043 /// A URI identifying a human-readable web page with information about the error,
1044 /// used to provide the client developer with additional information about the error.
1045 pub error_uri: Option<String>,
1046}
1047
1048impl fmt::Display for ErrorResponse {
1049 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1050 let mut formatted = self.error.to_string();
1051
1052 if let Some(error_description) = self.error_description.as_ref() {
1053 formatted.push_str(": ");
1054 formatted.push_str(error_description);
1055 }
1056
1057 if let Some(error_uri) = self.error_uri.as_ref() {
1058 formatted.push_str(" / See ");
1059 formatted.push_str(error_uri);
1060 }
1061
1062 write!(f, "{}", formatted)
1063 }
1064}
1065
/// Marks [`ErrorResponse`] as a standard error type (using its `Debug` and `Display`
/// implementations), which allows it to be the `#[source]` of
/// [`ExecuteError::ErrorResponse`].
impl error::Error for ErrorResponse {}
1067
1068/// Errors when creating new clients.
1069#[derive(Debug, Error)]
1070#[non_exhaustive]
1071pub enum NewClientError {
1072 /// Error creating underlying reqwest client.
1073 #[error("Failed to construct client")]
1074 Reqwest(#[source] reqwest::Error),
1075}
1076
1077impl From<reqwest::Error> for NewClientError {
1078 fn from(error: reqwest::Error) -> Self {
1079 Self::Reqwest(error)
1080 }
1081}
1082
1083/// Error encountered while requesting access token.
1084#[derive(Debug, Error)]
1085#[non_exhaustive]
1086pub enum ExecuteError {
1087 /// A client error that occured.
1088 #[error("reqwest error")]
1089 RequestError {
1090 /// Original request error.
1091 #[source]
1092 error: reqwest::Error,
1093 },
1094 /// Failed to parse server response. Parse errors may occur while parsing either successful
1095 /// or error responses.
1096 #[error("malformed server response: {status}")]
1097 BadResponse {
1098 /// The status code associated with the response.
1099 status: http::status::StatusCode,
1100 /// The body that couldn't be deserialized.
1101 body: bytes::Bytes,
1102 /// Deserialization error.
1103 #[source]
1104 error: serde_json::error::Error,
1105 },
1106 /// Response with non-successful status code and a body that could be
1107 /// successfully deserialized as an [ErrorResponse].
1108 #[error("request resulted in error response: {status}")]
1109 ErrorResponse {
1110 /// The status code associated with the response.
1111 status: http::status::StatusCode,
1112 /// The deserialized response.
1113 #[source]
1114 error: ErrorResponse,
1115 },
1116 /// Server response was empty.
1117 #[error("request resulted in empty response: {status}")]
1118 EmptyResponse {
1119 /// The status code associated with the empty response.
1120 status: http::status::StatusCode,
1121 },
1122}
1123
1124impl ExecuteError {
1125 /// Access the status code of the error if available.
1126 pub fn status(&self) -> Option<http::status::StatusCode> {
1127 match *self {
1128 Self::RequestError { ref error, .. } => error.status(),
1129 Self::BadResponse { status, .. } => Some(status),
1130 Self::ErrorResponse { status, .. } => Some(status),
1131 Self::EmptyResponse { status, .. } => Some(status),
1132 }
1133 }
1134
1135 /// The original response body if available.
1136 pub fn body(&self) -> Option<&bytes::Bytes> {
1137 match *self {
1138 Self::BadResponse { ref body, .. } => Some(body),
1139 _ => None,
1140 }
1141 }
1142}
1143
1144/// Helper methods used by OAuth2 implementations/extensions.
1145pub mod helpers {
1146 use serde::{Deserialize, Deserializer, Serializer};
1147 use url::Url;
1148
1149 /// Serde space-delimited string deserializer for a `Vec<String>`.
1150 ///
1151 /// This function splits a JSON string at each space character into a `Vec<String>` .
1152 ///
1153 /// # Example
1154 ///
1155 /// In example below, the JSON value `{"items": "foo bar baz"}` would deserialize to:
1156 ///
1157 /// ```
1158 /// # struct GroceryBasket {
1159 /// # items: Vec<String>,
1160 /// # }
1161 /// # fn main() {
1162 /// GroceryBasket {
1163 /// items: vec!["foo".to_string(), "bar".to_string(), "baz".to_string()]
1164 /// };
1165 /// # }
1166 /// ```
1167 ///
1168 /// Note: this example does not compile automatically due to
1169 /// [Rust issue #29286](https://github.com/rust-lang/rust/issues/29286).
1170 ///
1171 /// ```
1172 /// # /*
1173 /// use serde::Deserialize;
1174 ///
1175 /// #[derive(Deserialize)]
1176 /// struct GroceryBasket {
1177 /// #[serde(deserialize_with = "helpers::deserialize_space_delimited_vec")]
1178 /// items: Vec<String>,
1179 /// }
1180 /// # */
1181 /// ```
1182 pub fn deserialize_space_delimited_vec<'de, T, D>(deserializer: D) -> Result<T, D::Error>
1183 where
1184 T: Default + Deserialize<'de>,
1185 D: Deserializer<'de>,
1186 {
1187 use serde::de::Error;
1188 use serde_json::Value;
1189
1190 if let Some(space_delimited) = Option::<String>::deserialize(deserializer)? {
1191 let entries = space_delimited
1192 .split(' ')
1193 .map(|s| Value::String(s.to_string()))
1194 .collect();
1195 return T::deserialize(Value::Array(entries)).map_err(Error::custom);
1196 }
1197
1198 // If the JSON value is null, use the default value.
1199 Ok(T::default())
1200 }
1201
1202 /// Serde space-delimited string serializer for an `Option<Vec<String>>`.
1203 ///
1204 /// This function serializes a string vector into a single space-delimited string.
1205 /// If `string_vec_opt` is `None`, the function serializes it as `None` (e.g., `null`
1206 /// in the case of JSON serialization).
1207 pub fn serialize_space_delimited_vec<T, S>(
1208 vec_opt: &Option<Vec<T>>,
1209 serializer: S,
1210 ) -> Result<S::Ok, S::Error>
1211 where
1212 T: AsRef<str>,
1213 S: Serializer,
1214 {
1215 if let Some(ref vec) = *vec_opt {
1216 let space_delimited = vec.iter().map(|s| s.as_ref()).collect::<Vec<_>>().join(" ");
1217 serializer.serialize_str(&space_delimited)
1218 } else {
1219 serializer.serialize_none()
1220 }
1221 }
1222
1223 /// Serde string deserializer for a `Url`.
1224 pub fn deserialize_url<'de, D>(deserializer: D) -> Result<Url, D::Error>
1225 where
1226 D: Deserializer<'de>,
1227 {
1228 use serde::de::Error;
1229 let url_str = String::deserialize(deserializer)?;
1230 Url::parse(url_str.as_ref()).map_err(Error::custom)
1231 }
1232
1233 /// Serde string serializer for a `Url`.
1234 pub fn serialize_url<S>(url: &Url, serializer: S) -> Result<S::Ok, S::Error>
1235 where
1236 S: Serializer,
1237 {
1238 serializer.serialize_str(url.as_str())
1239 }
1240}