1use crate::authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate};
6use crate::client::SendRequest;
7use crate::error::Error;
8use crate::types::{ApplicationSecret, TokenInfo};
9
10use http_body_util::BodyExt;
11use std::convert::AsRef;
12use std::net::SocketAddr;
13use std::sync::Arc;
14
15use http::header;
16use percent_encoding::{percent_encode, AsciiSet, CONTROLS};
17use tokio::sync::{oneshot, Mutex};
18use url::form_urlencoded;
19
20const QUERY_SET: AsciiSet = CONTROLS.add(b' ').add(b'"').add(b'#').add(b'<').add(b'>');
21
22const OOB_REDIRECT_URI: &str = "urn:ietf:wg:oauth:2.0:oob";
23
24fn build_authentication_request_url<T>(
28 auth_uri: &str,
29 client_id: &str,
30 scopes: &[T],
31 redirect_uri: Option<&str>,
32 force_account_selection: bool,
33) -> String
34where
35 T: AsRef<str>,
36{
37 let mut url = String::new();
38 let scopes_string = crate::helper::join(scopes, " ");
39
40 url.push_str(auth_uri);
41
42 if !url.contains('?') {
43 url.push('?');
44 } else {
45 match url.chars().last() {
46 Some('?') | None => {}
47 Some(_) => url.push('&'),
48 }
49 }
50
51 let mut params = vec![
52 format!("scope={}", scopes_string),
53 "&access_type=offline".to_string(),
54 format!("&redirect_uri={}", redirect_uri.unwrap_or(OOB_REDIRECT_URI)),
55 "&response_type=code".to_string(),
56 format!("&client_id={}", client_id),
57 ];
58 if force_account_selection {
59 params.push("&prompt=select_account+consent".to_string());
60 }
61 params.into_iter().fold(url, |mut u, param| {
62 u.push_str(&percent_encode(param.as_ref(), &QUERY_SET).to_string());
63 u
64 })
65}
66
/// Selects how the authorization code travels from the user's browser back to this
/// application during the installed-application flow.
pub enum InstalledFlowReturnMethod {
    /// The flow delegate is asked to present the authorization URL and to return the
    /// authorization code obtained by the user.
    Interactive,
    /// A local HTTP server is started on `127.0.0.1` with an OS-assigned port, and the
    /// code is read from the browser's redirect to it.
    HTTPRedirect,
    /// Same as [`InstalledFlowReturnMethod::HTTPRedirect`], except the local server binds
    /// to the given port.
    HTTPPortRedirect(u16),
}
81
82pub struct InstalledFlow {
86 pub(crate) app_secret: ApplicationSecret,
87 pub(crate) method: InstalledFlowReturnMethod,
88 pub(crate) flow_delegate: Box<dyn InstalledFlowDelegate>,
89 pub(crate) force_account_selection: bool,
90}
91
92impl InstalledFlow {
93 pub(crate) fn new(
104 app_secret: ApplicationSecret,
105 method: InstalledFlowReturnMethod,
106 ) -> InstalledFlow {
107 InstalledFlow {
108 app_secret,
109 method,
110 flow_delegate: Box::new(DefaultInstalledFlowDelegate),
111 force_account_selection: false,
112 }
113 }
114
115 pub(crate) async fn token<T>(
122 &self,
123 hyper_client: &impl SendRequest,
124 scopes: &[T],
125 ) -> Result<TokenInfo, Error>
126 where
127 T: AsRef<str>,
128 {
129 match self.method {
130 InstalledFlowReturnMethod::HTTPRedirect => {
131 self.ask_auth_code_via_http(hyper_client, None, &self.app_secret, scopes)
132 .await
133 }
134 InstalledFlowReturnMethod::HTTPPortRedirect(port) => {
135 self.ask_auth_code_via_http(hyper_client, Some(port), &self.app_secret, scopes)
136 .await
137 }
138 InstalledFlowReturnMethod::Interactive => {
139 self.ask_auth_code_interactively(hyper_client, &self.app_secret, scopes)
140 .await
141 }
142 }
143 }
144
145 async fn ask_auth_code_interactively<T>(
146 &self,
147 hyper_client: &impl SendRequest,
148 app_secret: &ApplicationSecret,
149 scopes: &[T],
150 ) -> Result<TokenInfo, Error>
151 where
152 T: AsRef<str>,
153 {
154 let url = build_authentication_request_url(
155 &app_secret.auth_uri,
156 &app_secret.client_id,
157 scopes,
158 self.flow_delegate.redirect_uri(),
159 self.force_account_selection,
160 );
161 log::debug!("Presenting auth url to user: {}", url);
162 let auth_code = self
163 .flow_delegate
164 .present_user_url(&url, true )
165 .await
166 .map_err(Error::UserError)?;
167 log::debug!("Received auth code: {}", auth_code);
168 self.exchange_auth_code(&auth_code, hyper_client, app_secret, None)
169 .await
170 }
171
172 async fn ask_auth_code_via_http<T>(
173 &self,
174 hyper_client: &impl SendRequest,
175 port: Option<u16>,
176 app_secret: &ApplicationSecret,
177 scopes: &[T],
178 ) -> Result<TokenInfo, Error>
179 where
180 T: AsRef<str>,
181 {
182 use std::borrow::Cow;
183 let server = InstalledFlowServer::run(port)?;
184 let server_addr = server.local_addr();
185
186 let redirect_uri: Cow<str> = match self.flow_delegate.redirect_uri() {
190 Some(uri) => uri.into(),
191 None => format!("http://localhost:{}", server_addr.port()).into(),
192 };
193 let url = build_authentication_request_url(
194 &app_secret.auth_uri,
195 &app_secret.client_id,
196 scopes,
197 Some(redirect_uri.as_ref()),
198 self.force_account_selection,
199 );
200 log::debug!("Presenting auth url to user: {}", url);
201 let _ = self
202 .flow_delegate
203 .present_user_url(&url, false )
204 .await;
205 let auth_code = server.wait_for_auth_code().await;
206 self.exchange_auth_code(&auth_code, hyper_client, app_secret, Some(server_addr))
207 .await
208 }
209
210 async fn exchange_auth_code(
211 &self,
212 authcode: &str,
213 hyper_client: &impl SendRequest,
214 app_secret: &ApplicationSecret,
215 server_addr: Option<SocketAddr>,
216 ) -> Result<TokenInfo, Error> {
217 let redirect_uri = self.flow_delegate.redirect_uri();
218 let request = Self::request_token(app_secret, authcode, redirect_uri, server_addr);
219 log::debug!("Sending request: {:?}", request);
220 let (head, body) = hyper_client.request(request).await?.into_parts();
221 let body = body.collect().await?.to_bytes();
222 log::debug!("Received response; head: {:?} body: {:?}", head, body);
223 TokenInfo::from_json(&body)
224 }
225
226 fn request_token(
228 app_secret: &ApplicationSecret,
229 authcode: &str,
230 custom_redirect_uri: Option<&str>,
231 server_addr: Option<SocketAddr>,
232 ) -> http::Request<String> {
233 use std::borrow::Cow;
234 let redirect_uri: Cow<str> = match (custom_redirect_uri, server_addr) {
235 (Some(uri), _) => uri.into(),
236 (None, Some(addr)) => format!("http://localhost:{}", addr.port()).into(),
237 (None, None) => OOB_REDIRECT_URI.into(),
238 };
239
240 let body = form_urlencoded::Serializer::new(String::new())
241 .extend_pairs(vec![
242 ("code", authcode),
243 ("client_id", app_secret.client_id.as_str()),
244 ("client_secret", app_secret.client_secret.as_str()),
245 ("redirect_uri", redirect_uri.as_ref()),
246 ("grant_type", "authorization_code"),
247 ])
248 .finish();
249
250 http::Request::post(&app_secret.token_uri)
251 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
252 .body(body)
253 .unwrap() }
255}
256
257struct InstalledFlowServer {
258 addr: SocketAddr,
259 auth_code_rx: oneshot::Receiver<String>,
260 trigger_shutdown_tx: oneshot::Sender<()>,
261 shutdown_complete: tokio::task::JoinHandle<()>,
262}
263
264impl InstalledFlowServer {
265 fn run(port: Option<u16>) -> Result<Self, Error> {
266 let (auth_code_tx, auth_code_rx) = oneshot::channel::<String>();
267 let (trigger_shutdown_tx, mut trigger_shutdown_rx) = oneshot::channel::<()>();
268 let auth_code_tx = Arc::new(Mutex::new(Some(auth_code_tx)));
269
270 let service = hyper::service::service_fn(move |req| {
271 installed_flow_server::handle_req(req, auth_code_tx.clone())
272 });
273
274 let addr: std::net::SocketAddr = match port {
275 Some(port) => ([127, 0, 0, 1], port).into(),
276 None => ([127, 0, 0, 1], 0).into(),
277 };
278
279 let server =
280 hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
281 .http1_only();
282 let graceful = hyper_util::server::graceful::GracefulShutdown::new();
283
284 let std_listener = std::net::TcpListener::bind(addr)?;
285 std_listener.set_nonblocking(true)?;
286 let addr = std_listener.local_addr()?;
287 let tcp_server = tokio::net::TcpListener::from_std(std_listener)?;
288
289 log::debug!("HTTP server listening on {}", addr);
290
291 let shutdown_complete = tokio::spawn(async move {
292 loop {
293 let conn = tokio::select! {
294 Ok((conn,_)) = tcp_server.accept() => conn,
295 _ = &mut trigger_shutdown_rx => break,
296 else => break,
297 };
298
299 let conn = server
300 .serve_connection(hyper_util::rt::TokioIo::new(conn), service.clone())
301 .into_owned();
302
303 let conn = graceful.watch(conn);
304
305 tokio::spawn(async move {
306 if let Err(err) = conn.await {
307 log::debug!("connection error: {err}");
308 }
309 });
310 }
311
312 tokio::select! {
313 _ = graceful.shutdown() => {
314 log::debug!("Gracefully shutdown!");
315 },
316 _ = tokio::time::sleep(std::time::Duration::from_secs(10)) => {
317 log::debug!("Waited 10 seconds for graceful shutdown, aborting...");
318 }
319 }
320 });
321
322 Ok(InstalledFlowServer {
323 addr,
324 auth_code_rx,
325 trigger_shutdown_tx,
326 shutdown_complete,
327 })
328 }
329
330 fn local_addr(&self) -> SocketAddr {
331 self.addr
332 }
333
334 async fn wait_for_auth_code(self) -> String {
335 log::debug!("Waiting for HTTP server to receive auth code");
336 let auth_code = self
338 .auth_code_rx
339 .await
340 .expect("server shutdown while waiting for auth_code");
341 log::debug!("HTTP server received auth code: {}", auth_code);
342 log::debug!("Shutting down HTTP server");
343 let _ = self.trigger_shutdown_tx.send(());
345 let _ = self.shutdown_complete.await;
346 auth_code
347 }
348}
349
350mod installed_flow_server {
351 use http::{Request, Response, StatusCode, Uri};
352 use std::sync::Arc;
353 use tokio::sync::{oneshot, Mutex};
354 use url::form_urlencoded;
355
356 pub(super) async fn handle_req<B: hyper::body::Body>(
357 req: Request<B>,
358 auth_code_tx: Arc<Mutex<Option<oneshot::Sender<String>>>>,
359 ) -> Result<Response<String>, http::Error> {
360 match req.uri().path_and_query() {
361 Some(path_and_query) => {
362 let url = Uri::builder()
366 .scheme("http")
367 .authority("example.com")
368 .path_and_query(path_and_query.clone())
369 .build();
370
371 match url {
372 Err(_) => http::Response::builder()
373 .status(StatusCode::BAD_REQUEST)
374 .body(String::from("Unparseable URL")),
375 Ok(url) => match auth_code_from_url(url) {
376 Some(auth_code) => {
377 if let Some(sender) = auth_code_tx.lock().await.take() {
378 let _ = sender.send(auth_code);
379 }
380 http::Response::builder()
381 .status(StatusCode::OK)
382 .body(String::from(
383 "<html><head><title>Success</title></head><body>You may now \
384 close this window.</body></html>",
385 ))
386 }
387 None => http::Response::builder()
388 .status(StatusCode::BAD_REQUEST)
389 .body(String::from("No `code` in URL")),
390 },
391 }
392 }
393 None => http::Response::builder()
394 .status(StatusCode::BAD_REQUEST)
395 .body(String::from("Invalid Request!")),
396 }
397 }
398
399 fn auth_code_from_url(url: http::Uri) -> Option<String> {
400 form_urlencoded::parse(url.query().unwrap_or("").as_bytes()).find_map(|(param, val)| {
403 if param == "code" {
404 Some(val.into_owned())
405 } else {
406 None
407 }
408 })
409 }
410}
411
#[cfg(test)]
mod tests {
    use crate::client::LegacyClient;

    use super::*;
    use http::Uri;

    /// Client id shared by the URL-builder tests.
    const CLIENT_ID: &str =
        "812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5amrf.apps.googleusercontent.com";

    /// Query parameters the builder is expected to append for `CLIENT_ID`,
    /// the scopes `email` and `profile`, and no explicit redirect URI.
    const EXPECTED_PARAMS: &str = "scope=email%20profile&access_type=offline\
        &redirect_uri=urn:ietf:wg:oauth:2.0:oob&response_type=code\
        &client_id=812741506391-h38jh0j4fv0ce1krdkiq0hfvt6n5amrf.apps.googleusercontent.com";

    #[test]
    fn test_request_url_builder() {
        let expected = format!("https://accounts.google.com/o/oauth2/auth?{}", EXPECTED_PARAMS);
        let built = build_authentication_request_url(
            "https://accounts.google.com/o/oauth2/auth",
            CLIENT_ID,
            &["email", "profile"],
            None,
            false,
        );
        assert_eq!(expected, built);
    }

    #[test]
    fn test_request_url_builder_appends_queries() {
        let expected = format!(
            "https://accounts.google.com/o/oauth2/auth?unknown=testing&{}",
            EXPECTED_PARAMS
        );
        let built = build_authentication_request_url(
            "https://accounts.google.com/o/oauth2/auth?unknown=testing",
            CLIENT_ID,
            &["email", "profile"],
            None,
            false,
        );
        assert_eq!(expected, built);
    }

    #[tokio::test]
    async fn test_server_random_local_port() {
        let first = InstalledFlowServer::run(None).unwrap().local_addr();
        let second = InstalledFlowServer::run(None).unwrap().local_addr();
        assert_ne!(first.port(), second.port());
    }

    #[tokio::test]
    async fn test_http_handle_url() {
        let (tx, rx) = oneshot::channel();
        let uri: Uri = "http://example.com:1234/?code=ab/c%2Fd#".parse().unwrap();
        let req = http::Request::get(uri).body(String::new()).unwrap();
        let shared_tx = Arc::new(Mutex::new(Some(tx)));
        installed_flow_server::handle_req(req, shared_tx)
            .await
            .unwrap();
        assert_eq!(rx.await.unwrap().as_str(), "ab/c/d");
    }

    #[tokio::test]
    async fn test_server() {
        let client: LegacyClient<hyper_util::client::legacy::connect::HttpConnector> =
            hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
                .build_http();
        let server = InstalledFlowServer::run(None).unwrap();
        let addr = server.local_addr();

        // A request without `code` must still get an answer (its status is not checked).
        if let Err(err) = client
            .get(format!("http://{}/", addr).parse().unwrap())
            .await
        {
            panic!("Failed to request from local server: {:?}", err);
        }

        let response = client
            .get(format!("http://{}/?code=ab/c%2Fd#", addr).parse().unwrap())
            .await
            .unwrap_or_else(|err| panic!("Failed to request from local server: {:?}", err));
        assert!(response.status().is_success());

        assert_eq!(server.wait_for_auth_code().await.as_str(), "ab/c/d");
    }
}