oauth/v2_0/
authorization_code_grant.rs

//! Authorization Code Grant flow helper, as defined in
//! [RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749#section-1.3.1).
3
4#[cfg(feature = "async-std")]
5use async_std::{
6    io::{BufReadExt, BufReader, WriteExt},
7    net::TcpListener,
8};
9use oauth2::{
10    url::Url, AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, RequestTokenError,
11    Scope, TokenResponse,
12};
13#[cfg(feature = "tokio")]
14use tokio::{
15    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
16    net::TcpListener,
17};
18
19use super::{Client, Error, Result};
20
21/// OAuth 2.0 Authorization Code Grant flow builder.
22///
23/// The first step (once the builder is configured) is to build a
24/// [`crate::Client`].
25///
26/// The second step is to get the redirect URL by calling
27/// [`AuthorizationCodeGrant::get_redirect_url`].
28///
29/// The last step is to spawn a redirect server and wait for the user
30/// to click on the redirect URL in order to extract the access token
31/// and the refresh token by calling
32/// [`AuthorizationCodeGrant::wait_for_redirection`].
33#[derive(Debug, Default)]
34pub struct AuthorizationCodeGrant {
35    pub scopes: Vec<Scope>,
36    pub pkce: Option<(PkceCodeChallenge, PkceCodeVerifier)>,
37}
38
39impl AuthorizationCodeGrant {
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    pub fn with_scope<T>(mut self, scope: T) -> Self
45    where
46        T: ToString,
47    {
48        self.scopes.push(Scope::new(scope.to_string()));
49        self
50    }
51
52    pub fn with_pkce(mut self) -> Self {
53        self.pkce = Some(PkceCodeChallenge::new_random_sha256());
54        self
55    }
56
57    /// Generate the redirect URL used to complete the OAuth 2.0
58    /// Authorization Code Grant flow.
59    pub fn get_redirect_url(&self, client: &Client) -> (Url, CsrfToken) {
60        let mut redirect = client
61            .authorize_url(CsrfToken::new_random)
62            .add_scopes(self.scopes.clone());
63
64        if let Some((pkce_challenge, _)) = &self.pkce {
65            redirect = redirect.set_pkce_challenge(pkce_challenge.clone());
66        }
67
68        redirect.url()
69    }
70
71    /// Wait for the user to click on the redirect URL generated by
72    /// [`AuthorizationCodeGrant::get_redirect_url`], then exchange
73    /// the received code with an access token and maybe a refresh
74    /// token.
75    pub async fn wait_for_redirection(
76        self,
77        client: &Client,
78        csrf_state: CsrfToken,
79    ) -> Result<(String, Option<String>)> {
80        // listen for one single connection
81        let (mut stream, _) =
82            TcpListener::bind((client.redirect_host.as_str(), client.redirect_port))
83                .await
84                .map_err(|err| {
85                    Error::BindRedirectServerError(
86                        client.redirect_host.clone(),
87                        client.redirect_port,
88                        err,
89                    )
90                })?
91                .accept()
92                .await
93                .map_err(Error::AcceptRedirectServerError)?;
94
95        // extract the code from the url
96        let code = {
97            let mut reader = BufReader::new(&mut stream);
98
99            let mut request_line = String::new();
100            reader.read_line(&mut request_line).await?;
101
102            let redirect_url = request_line
103                .split_whitespace()
104                .nth(1)
105                .ok_or_else(|| Error::MissingRedirectUrlError(request_line.clone()))?;
106            let redirect_url = format!("http://localhost{redirect_url}");
107            let redirect_url = Url::parse(&redirect_url)
108                .map_err(|err| Error::ParseRedirectUrlError(err, redirect_url.clone()))?;
109
110            let (_, state) = redirect_url
111                .query_pairs()
112                .find(|(key, _)| key == "state")
113                .ok_or_else(|| Error::FindStateInRedirectUrlError(redirect_url.clone()))?;
114            let state = CsrfToken::new(state.into_owned());
115
116            if state.secret() != csrf_state.secret() {
117                return Err(Error::InvalidStateError(
118                    state.secret().to_owned(),
119                    csrf_state.secret().to_owned(),
120                ));
121            }
122
123            let (_, code) = redirect_url
124                .query_pairs()
125                .find(|(key, _)| key == "code")
126                .ok_or_else(|| Error::FindCodeInRedirectUrlError(redirect_url.clone()))?;
127
128            AuthorizationCode::new(code.into_owned())
129        };
130
131        // write a basic http response in plain text
132        let res = "Authentication successful!";
133        let res = format!(
134            "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
135            res.len(),
136            res
137        );
138        stream.write_all(res.as_bytes()).await?;
139
140        // exchange the code for an access token and a refresh token
141        let mut res = client.exchange_code(code);
142
143        if let Some((_, pkce_verifier)) = self.pkce {
144            res = res.set_pkce_verifier(pkce_verifier);
145        }
146
147        let res = res
148            .request_async(&Client::send_oauth2_request)
149            .await
150            .map_err(|err| match err {
151                RequestTokenError::Request(req) => Error::ExchangeCodeError(req.to_string()),
152                RequestTokenError::ServerResponse(res) => Error::ExchangeCodeError(res.to_string()),
153                RequestTokenError::Parse(err, _) => Error::ExchangeCodeError(err.to_string()),
154                RequestTokenError::Other(err) => Error::ExchangeCodeError(err),
155            })?;
156
157        let access_token = res.access_token().secret().to_owned();
158        let refresh_token = res.refresh_token().map(|t| t.secret().clone());
159
160        Ok((access_token, refresh_token))
161    }
162}