yup_oauth2/
installed.rs

1// Copyright (c) 2016 Google Inc (lewinb@google.com).
2//
3// Refer to the project root for licensing information.
4//
5use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate};
6use crate::client::SendRequest;
7use crate::error::Error;
8use crate::types::{ApplicationSecret, TokenInfo};
9
10use http_body_util::BodyExt;
11use std::convert::AsRef;
12use std::net::SocketAddr;
13use std::sync::Arc;
14
15use http::header;
16use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
17use tokio::sync::{oneshot, Mutex};
18use url::form_urlencoded;
19
20const QUERY_SET: AsciiSet = CONTROLS.add(b' ').add(b'"').add(b'#').add(b'<').add(b'>');
21
22const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob";
23
24/// Assembles a URL to request an authorization token (with user interaction).
25/// Note that the redirect_uri here has to be either None or some variation of
26/// http://localhost:{port}, or the authorization won't work (error "redirect_uri_mismatch")
27fn build_authentication_request_url<T>(
28    auth_uri: &str,
29    client_id: &str,
30    scopes: &[T],
31    redirect_uri: Option<&str>,
32    force_account_selection: bool,
33) -> String
34where
35    T: AsRef<str>,
36{
37    let mut url = String::new();
38    let scopes_string = crate::helper::join(scopes, " ");
39
40    url.push_str(auth_uri);
41
42    if !url.contains('?') {
43        url.push('?');
44    } else {
45        match url.chars().last() {
46            Some('?') | None => {}
47            Some(_) => url.push('&'),
48        }
49    }
50
51    let mut params = vec![
52        format!("scope={}", scopes_string),
53        "&access_type=offline".to_string(),
54        format!("&redirect_uri={}", redirect_uri.unwrap_or(OOB_REDIRECT_URI)),
55        "&response_type=code".to_string(),
56        format!("&client_id={}", client_id),
57    ];
58    if force_account_selection {
59        params.push("&prompt=select_account+consent".to_string());
60    }
61    params.into_iter().fold(url, |mut u, param| {
62        u.push_str(&percent_encode(param.as_ref(), &QUERY_SET).to_string());
63        u
64    })
65}
66
/// Method by which the user agent returns the authorization code to this application.
///
/// cf. <https://developers.google.com/identity/protocols/OAuth2InstalledApp#choosingredirecturi>
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InstalledFlowReturnMethod {
    /// Involves showing a URL to the user and asking to copy a code from their browser
    /// (default)
    Interactive,
    /// Involves spinning up a local HTTP server and Google redirecting the browser to
    /// the server with a URL containing the code (preferred, but not as reliable).
    HTTPRedirect,
    /// Identical to [Self::HTTPRedirect], but allows a port to be specified for the
    /// server, instead of choosing a port randomly.
    HTTPPortRedirect(u16),
}
81
82/// InstalledFlowImpl provides tokens for services that follow the "Installed" OAuth flow. (See
83/// <https://www.oauth.com/oauth2-servers/authorization/>,
84/// <https://developers.google.com/identity/protocols/OAuth2InstalledApp>).
85pub struct InstalledFlow {
86    pub(crate) app_secret: ApplicationSecret,
87    pub(crate) method: InstalledFlowReturnMethod,
88    pub(crate) flow_delegate: Box<dyn InstalledFlowDelegate>,
89    pub(crate) force_account_selection: bool,
90}
91
92impl InstalledFlow {
93    /// Create a new InstalledFlow with the provided secret and method.
94    ///
95    /// In order to specify the redirect URL to use (in the case of `HTTPRedirect` or
96    /// `HTTPPortRedirect` as method), either implement the `InstalledFlowDelegate` trait, or
97    /// use the `DefaultInstalledFlowDelegateWithRedirectURI`, which presents the URL on stdout.
98    /// The redirect URL to use is configured with the OAuth provider, and possible options are
99    /// given in the `ApplicationSecret.redirect_uris` field.
100    ///
101    /// The `InstalledFlowDelegate` implementation should be assigned to the `flow_delegate` field
102    /// of the `InstalledFlow` struct.
103    pub(crate) fn new(
104        app_secret: ApplicationSecret,
105        method: InstalledFlowReturnMethod,
106    ) -> InstalledFlow {
107        InstalledFlow {
108            app_secret,
109            method,
110            flow_delegate: Box::new(DefaultInstalledFlowDelegate),
111            force_account_selection: false,
112        }
113    }
114
115    /// Handles the token request flow; it consists of the following steps:
116    /// . Obtain a authorization code with user cooperation or internal redirect.
117    /// . Obtain a token and refresh token using that code.
118    /// . Return that token
119    ///
120    /// It's recommended not to use the DefaultInstalledFlowDelegate, but a specialized one.
121    pub(crate) async fn token<T>(
122        &self,
123        hyper_client: &impl SendRequest,
124        scopes: &[T],
125    ) -> Result<TokenInfo, Error>
126    where
127        T: AsRef<str>,
128    {
129        match self.method {
130            InstalledFlowReturnMethod::HTTPRedirect => {
131                self.ask_auth_code_via_http(hyper_client, None, &self.app_secret, scopes)
132                    .await
133            }
134            InstalledFlowReturnMethod::HTTPPortRedirect(port) => {
135                self.ask_auth_code_via_http(hyper_client, Some(port), &self.app_secret, scopes)
136                    .await
137            }
138            InstalledFlowReturnMethod::Interactive => {
139                self.ask_auth_code_interactively(hyper_client, &self.app_secret, scopes)
140                    .await
141            }
142        }
143    }
144
145    async fn ask_auth_code_interactively<T>(
146        &self,
147        hyper_client: &impl SendRequest,
148        app_secret: &ApplicationSecret,
149        scopes: &[T],
150    ) -> Result<TokenInfo, Error>
151    where
152        T: AsRef<str>,
153    {
154        let url = build_authentication_request_url(
155            &app_secret.auth_uri,
156            &app_secret.client_id,
157            scopes,
158            self.flow_delegate.redirect_uri(),
159            self.force_account_selection,
160        );
161        log::debug!("Presenting auth url to user: {}", url);
162        let auth_code = self
163            .flow_delegate
164            .present_user_url(&url, true /* need code */)
165            .await
166            .map_err(Error::UserError)?;
167        log::debug!("Received auth code: {}", auth_code);
168        self.exchange_auth_code(&auth_code, hyper_client, app_secret, None)
169            .await
170    }
171
172    async fn ask_auth_code_via_http<T>(
173        &self,
174        hyper_client: &impl SendRequest,
175        port: Option<u16>,
176        app_secret: &ApplicationSecret,
177        scopes: &[T],
178    ) -> Result<TokenInfo, Error>
179    where
180        T: AsRef<str>,
181    {
182        use std::borrow::Cow;
183        let server = InstalledFlowServer::run(port)?;
184        let server_addr = server.local_addr();
185
186        // Present url to user.
187        // The redirect URI must be this very localhost URL, otherwise authorization is refused
188        // by certain providers.
189        let redirect_uri: Cow<str> = match self.flow_delegate.redirect_uri() {
190            Some(uri) => uri.into(),
191            None => format!("http://localhost:{}", server_addr.port()).into(),
192        };
193        let url = build_authentication_request_url(
194            &app_secret.auth_uri,
195            &app_secret.client_id,
196            scopes,
197            Some(redirect_uri.as_ref()),
198            self.force_account_selection,
199        );
200        log::debug!("Presenting auth url to user: {}", url);
201        let _ = self
202            .flow_delegate
203            .present_user_url(&url, false /* need code */)
204            .await;
205        let auth_code = server.wait_for_auth_code().await;
206        self.exchange_auth_code(&auth_code, hyper_client, app_secret, Some(server_addr))
207            .await
208    }
209
210    async fn exchange_auth_code(
211        &self,
212        authcode: &str,
213        hyper_client: &impl SendRequest,
214        app_secret: &ApplicationSecret,
215        server_addr: Option<SocketAddr>,
216    ) -> Result<TokenInfo, Error> {
217        let redirect_uri = self.flow_delegate.redirect_uri();
218        let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr);
219        log::debug!("Sending request: {:?}", request);
220        let (head, body) = hyper_client.request(request).await?.into_parts();
221        let body = body.collect().await?.to_bytes();
222        log::debug!("Received response; head: {:?} body: {:?}", head, body);
223        TokenInfo::from_json(&body)
224    }
225
226    /// Sends the authorization code to the provider in order to obtain access and refresh tokens.
227    fn request_token(
228        app_secret: &ApplicationSecret,
229        authcode: &str,
230        custom_redirect_uri: Option<&str>,
231        server_addr: Option<SocketAddr>,
232    ) -> http::Request<String> {
233        use std::borrow::Cow;
234        let redirect_uri: Cow<str> = match (custom_redirect_uri, server_addr) {
235            (Some(uri), _) => uri.into(),
236            (None, Some(addr)) => format!("http://localhost:{}", addr.port()).into(),
237            (None, None) => OOB_REDIRECT_URI.into(),
238        };
239
240        let body = form_urlencoded::Serializer::new(String::new())
241            .extend_pairs(vec![
242                ("code", authcode),
243                ("client_id", app_secret.client_id.as_str()),
244                ("client_secret", app_secret.client_secret.as_str()),
245                ("redirect_uri", redirect_uri.as_ref()),
246                ("grant_type", "authorization_code"),
247            ])
248            .finish();
249
250        http::Request::post(&app_secret.token_uri)
251            .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
252            .body(body)
253            .unwrap() // TODO: error check
254    }
255}
256
257struct InstalledFlowServer {
258    addr: SocketAddr,
259    auth_code_rx: oneshot::Receiver<String>,
260    trigger_shutdown_tx: oneshot::Sender<()>,
261    shutdown_complete: tokio::task::JoinHandle<()>,
262}
263
264impl InstalledFlowServer {
265    fn run(port: Option<u16>) -> Result<Self, Error> {
266        let (auth_code_tx, auth_code_rx) = oneshot::channel::<String>();
267        let (trigger_shutdown_tx, mut trigger_shutdown_rx) = oneshot::channel::<()>();
268        let auth_code_tx = Arc::new(Mutex::new(Some(auth_code_tx)));
269
270        let service = hyper::service::service_fn(move |req| {
271            installed_flow_server::handle_req(req, auth_code_tx.clone())
272        });
273
274        let addr: std::net::SocketAddr = match port {
275            Some(port) => ([127, 0, 0, 1], port).into(),
276            None => ([127, 0, 0, 1], 0).into(),
277        };
278
279        let server =
280            hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
281                .http1_only();
282        let graceful = hyper_util::server::graceful::GracefulShutdown::new();
283
284        let std_listener = std::net::TcpListener::bind(addr)?;
285        std_listener.set_nonblocking(true)?;
286        let addr = std_listener.local_addr()?;
287        let tcp_server = tokio::net::TcpListener::from_std(std_listener)?;
288
289        log::debug!("HTTP server listening on {}", addr);
290
291        let shutdown_complete = tokio::spawn(async move {
292            loop {
293                let conn = tokio::select! {
294                    Ok((conn,_)) = tcp_server.accept() => conn,
295                    _ = &mut trigger_shutdown_rx => break,
296                    else => break,
297                };
298
299                let conn = server
300                    .serve_connection(hyper_util::rt::TokioIo::new(conn), service.clone())
301                    .into_owned();
302
303                let conn = graceful.watch(conn);
304
305                tokio::spawn(async move {
306                    if let Err(err) = conn.await {
307                        log::debug!("connection error: {err}");
308                    }
309                });
310            }
311
312            tokio::select! {
313                _ = graceful.shutdown() => {
314                     log::debug!("Gracefully shutdown!");
315                },
316                _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
317                     log::debug!("Waited 10 seconds for graceful shutdown, aborting...");
318                }
319            }
320        });
321
322        Ok(InstalledFlowServer {
323            addr,
324            auth_code_rx,
325            trigger_shutdown_tx,
326            shutdown_complete,
327        })
328    }
329
330    fn local_addr(&self) -> SocketAddr {
331        self.addr
332    }
333
334    async fn wait_for_auth_code(self) -> String {
335        log::debug!("Waiting for HTTP server to receive auth code");
336        // Wait for the auth code from the server.
337        let auth_code = self
338            .auth_code_rx
339            .await
340            .expect("server shutdown while waiting for auth_code");
341        log::debug!("HTTP server received auth code: {}", auth_code);
342        log::debug!("Shutting down HTTP server");
343        // auth code received. shutdown the server
344        let _ = self.trigger_shutdown_tx.send(());
345        let _ = self.shutdown_complete.await;
346        auth_code
347    }
348}
349
350mod installed_flow_server {
351    use http::{Request, Response, StatusCode, Uri};
352    use std::sync::Arc;
353    use tokio::sync::{oneshot, Mutex};
354    use url::form_urlencoded;
355
356    pub(super) async fn handle_req<B: hyper::body::Body>(
357        req: Request<B>,
358        auth_code_tx: Arc<Mutex<Option<oneshot::Sender<String>>>>,
359    ) -> Result<Response<String>, http::Error> {
360        match req.uri().path_and_query() {
361            Some(path_and_query) => {
362                // We use a fake URL because the redirect goes to a URL, meaning we
363                // can't use the url form decode (because there's slashes and hashes and stuff in
364                // it).
365                let url = Uri::builder()
366                    .scheme("http")
367                    .authority("example.com")
368                    .path_and_query(path_and_query.clone())
369                    .build();
370
371                match url {
372                    Err(_) => http::Response::builder()
373                        .status(StatusCode::BAD_REQUEST)
374                        .body(String::from("Unparseable URL")),
375                    Ok(url) => match auth_code_from_url(url) {
376                        Some(auth_code) => {
377                            if let Some(sender) = auth_code_tx.lock().await.take() {
378                                let _ = sender.send(auth_code);
379                            }
380                            http::Response::builder()
381                                .status(StatusCode::OK)
382                                .body(String::from(
383                                    "<html><head><title>Success</title></head><body>You may now \
384                                     close this window.</body></html>",
385                                ))
386                        }
387                        None => http::Response::builder()
388                            .status(StatusCode::BAD_REQUEST)
389                            .body(String::from("No `code` in URL")),
390                    },
391                }
392            }
393            None => http::Response::builder()
394                .status(StatusCode::BAD_REQUEST)
395                .body(String::from("Invalid Request!")),
396        }
397    }
398
399    fn auth_code_from_url(url: http::Uri) -> Option<String> {
400        // The provider redirects to the specified localhost URL, appending the authorization
401        // code, like this: http://localhost:8080/xyz/?code=4/731fJ3BheyCouCniPufAd280GHNV5Ju35yYcGs
402        form_urlencoded::parse(url.query().unwrap_or("").as_bytes()).find_map(|(param, val)| {
403            if param == "code" {
404                Some(val.into_owned())
405            } else {
406                None
407            }
408        })
409    }
410}
411
#[cfg(test)]
mod tests {
    use crate::client::LegacyClient;

    use super::*;
    use http::{StatusCode, Uri};

    #[test]
    fn test_request_url_builder() {
        assert_eq!(
            "https://accounts.google.\
             com/o/oauth2/auth?scope=email%20profile&access_type=offline&redirect_uri=urn:ietf:wg:oauth:2.0:\
             oob&response_type=code&client_id=812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5amr\
             f.apps.googleusercontent.com",
            build_authentication_request_url(
                "https://accounts.google.com/o/oauth2/auth",
                "812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5am\
                 rf.apps.googleusercontent.com",
                &["email", "profile"],
                None,
                false
            )
        );
    }

    #[test]
    fn test_request_url_builder_appends_queries() {
        assert_eq!(
            "https://accounts.google.\
             com/o/oauth2/auth?unknown=testing&scope=email%20profile&access_type=offline&redirect_uri=urn:ietf:wg:oauth:2.0:\
             oob&response_type=code&client_id=812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5amr\
             f.apps.googleusercontent.com",
            build_authentication_request_url(
                "https://accounts.google.com/o/oauth2/auth?unknown=testing",
                "812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5am\
                 rf.apps.googleusercontent.com",
                &["email", "profile"],
                None,
                false
            )
        );
    }

    #[tokio::test]
    async fn test_server_random_local_port() {
        let addr1 = InstalledFlowServer::run(None).unwrap().local_addr();
        let addr2 = InstalledFlowServer::run(None).unwrap().local_addr();
        assert_ne!(addr1.port(), addr2.port());
    }

    #[tokio::test]
    async fn test_http_handle_url() {
        let (tx, rx) = oneshot::channel();
        // URLs are usually a bit botched
        let url: Uri = "http://example.com:1234/?code=ab/c%2Fd#".parse().unwrap();
        let req = http::Request::get(url).body(String::new()).unwrap();
        installed_flow_server::handle_req(req, Arc::new(Mutex::new(Some(tx))))
            .await
            .unwrap();
        assert_eq!(rx.await.unwrap().as_str(), "ab/c/d");
    }

    #[tokio::test]
    async fn test_server() {
        let client: LegacyClient<hyper_util::client::legacy::connect::HttpConnector> =
            hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
                .build_http();
        let server = InstalledFlowServer::run(None).unwrap();

        // A redirect without a `code` parameter is rejected, but must not end the flow.
        let status = client
            .get(format!("http://{}/", server.local_addr()).parse().unwrap())
            .await
            .unwrap_or_else(|err| panic!("Failed to request from local server: {:?}", err))
            .status();
        assert_eq!(status, StatusCode::BAD_REQUEST);

        let status = client
            .get(
                format!("http://{}/?code=ab/c%2Fd#", server.local_addr())
                    .parse()
                    .unwrap(),
            )
            .await
            .unwrap_or_else(|err| panic!("Failed to request from local server: {:?}", err))
            .status();
        assert!(status.is_success());

        assert_eq!(server.wait_for_auth_code().await.as_str(), "ab/c/d");
    }
}