use futures::channel::oneshot;
use http::StatusCode;
use hyper::server::conn::AddrIncoming;
use hyper::{Body, Response, Server};
use hyperx::header::{ContentLength, ContentType};
use serde::Serialize;
use std::collections::HashMap;
use std::error::Error as StdError;
use std::io;
use std::net::{TcpStream, ToSocketAddrs};
use std::sync::mpsc;
use std::time::Duration;
use tokio::runtime::Runtime;
use url::Url;
use uuid::Uuid;
use crate::util::RequestExt;
use crate::errors::*;
/// Ports the local auth web server attempts to use, tried in the listed order by `try_bind`.
pub const VALID_PORTS: &[u16] = &[12731, 32492, 56909];
/// Remaining token lifetime (two days, expressed in seconds) below which a warning is emitted.
const MIN_TOKEN_VALIDITY: Duration = Duration::from_secs(2 * 24 * 60 * 60);
/// Human-readable form of `MIN_TOKEN_VALIDITY` interpolated into the warning messages.
const MIN_TOKEN_VALIDITY_WARNING: &str = "two days";
/// Extracts the query parameters of `url` into a map.
///
/// `url` may be relative (for example a request path with a query string); it is
/// resolved against a throwaway base URL purely so that it can be parsed. When a
/// key occurs more than once, the last occurrence wins.
///
/// # Errors
/// Fails when `url` cannot be joined onto the placeholder base.
fn query_pairs(url: &str) -> Result<HashMap<String, String>> {
    let base = Url::parse("http://unused_base").expect("Failed to parse fake url prefix");
    let parsed = base
        .join(url)
        .context("Failed to parse url while extracting query params")?;
    let mut pairs = HashMap::new();
    for (key, value) in parsed.query_pairs() {
        pairs.insert(key.into_owned(), value.into_owned());
    }
    Ok(pairs)
}
/// Wraps a static HTML document in a response carrying HTML content-type and
/// content-length headers.
fn html_response(body: &'static str) -> Response<Body> {
    let content_length = ContentLength(body.len() as u64);
    let builder = Response::builder()
        .set_header(ContentType::html())
        .set_header(content_length);
    builder.body(body.into()).unwrap()
}
/// Serializes `data` to JSON and wraps it in a response carrying JSON
/// content-type and content-length headers.
///
/// # Errors
/// Fails when `data` cannot be serialized.
fn json_response<T: Serialize>(data: &T) -> Result<Response<Body>> {
    let payload = serde_json::to_vec(data).context("Failed to serialize to JSON")?;
    let content_length = ContentLength(payload.len() as u64);
    let response: Response<Body> = Response::builder()
        .set_header(ContentType::json())
        .set_header(content_length)
        .body(payload.into())
        .unwrap();
    Ok(response)
}
/// Landing page served at `/` by both flows: its script fetches
/// `/auth_detail.json` and navigates the browser to the URL found in that
/// JSON document, writing progress and errors into the page body.
const REDIRECT_WITH_AUTH_JSON: &str = r##"<!doctype html>
<html lang="en">
<head><meta charset="utf-8"></head>
<body>
<script>
function writemsg(m) {
document.body.appendChild(document.createTextNode(m.toString()));
document.body.appendChild(document.createElement('br'));
}
function go() {
writemsg('Retrieving details of authenticator...');
fetch('/auth_detail.json').then(function (response) {
if (!response.ok) {
throw 'Error during retrieval - ' + response.status + ': ' + response.statusText;
}
writemsg('Using details to redirect to authentication page...');
return response.json()
}).then(function (auth_url) {
window.location.href = auth_url;
}).catch(writemsg);
}
go();
</script>
</body>
</html>
"##;
mod code_grant_pkce {
use super::{
html_response, json_response, query_pairs, MIN_TOKEN_VALIDITY, MIN_TOKEN_VALIDITY_WARNING,
REDIRECT_WITH_AUTH_JSON,
};
use futures::channel::oneshot;
use hyper::{Body, Method, Request, Response, StatusCode};
use rand::{rngs::OsRng, RngCore};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::mpsc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use url::Url;
use crate::errors::*;
// Query parameter names, and the fixed values that go with some of them,
// appended to the authorization URL by `finish_url`.
const CLIENT_ID_PARAM: &str = "client_id";
const CODE_CHALLENGE_PARAM: &str = "code_challenge";
const CODE_CHALLENGE_METHOD_PARAM: &str = "code_challenge_method";
const CODE_CHALLENGE_METHOD_VALUE: &str = "S256";
const REDIRECT_PARAM: &str = "redirect_uri";
const RESPONSE_TYPE_PARAM: &str = "response_type";
const RESPONSE_TYPE_PARAM_VALUE: &str = "code";
const STATE_PARAM: &str = "state";
// Query parameter names read back from the redirect request by
// `handle_code_response`.
const CODE_RESULT_PARAM: &str = "code";
const STATE_RESULT_PARAM: &str = "state";
/// JSON body posted to the token URL by `code_to_token` in order to exchange
/// an authorization code for a token. All fields borrow from the caller.
#[derive(Serialize)]
struct TokenRequest<'a> {
client_id: &'a str,
code_verifier: &'a str,
code: &'a str,
grant_type: &'a str,
redirect_uri: &'a str,
}
/// Value sent in the `grant_type` field of `TokenRequest`.
const GRANT_TYPE_PARAM_VALUE: &str = "authorization_code";
/// Fields deserialized from the JSON response of the token URL.
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
token_type: String,
// Interpreted as a number of seconds from now by `handle_token_response`.
expires_in: u64, }
/// Token type the response must carry; matched case-insensitively.
const TOKEN_TYPE_RESULT_PARAM_VALUE: &str = "bearer";
/// Number of random bytes generated for the code verifier (256 bits).
const NUM_CODE_VERIFIER_BYTES: usize = 256 / 8;
/// State consulted by `serve` while handling requests for the code grant flow.
pub struct State {
/// Complete authorization URL, served as JSON at `/auth_detail.json`.
pub auth_url: String,
/// Value the `state` parameter of the redirect request must equal.
pub auth_state_value: String,
/// Channel on which the authorization code from the redirect is sent.
pub code_tx: mpsc::SyncSender<String>,
/// One-shot sender fired after the code has been sent; taken (left `None`) on use.
pub shutdown_tx: Option<oneshot::Sender<()>>,
}
lazy_static! {
// Global slot for the flow state; starts out as `None` and is filled in by
// `get_token_oauth2_code_grant_pkce` before the server runs.
pub static ref STATE: Mutex<Option<State>> = Mutex::new(None);
}
/// Creates a fresh PKCE `(code_verifier, code_challenge)` pair.
///
/// The verifier is the unpadded URL-safe base64 encoding of
/// `NUM_CODE_VERIFIER_BYTES` bytes drawn from the OS random number generator;
/// the challenge is the unpadded URL-safe base64 encoding of the SHA-256
/// digest of the verifier string.
pub fn generate_verifier_and_challenge() -> Result<(String, String)> {
    let mut random_bytes = [0u8; NUM_CODE_VERIFIER_BYTES];
    OsRng.fill_bytes(&mut random_bytes);
    let code_verifier = base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD);

    let mut sha = Sha256::new();
    sha.update(&code_verifier);
    let digest = sha.finalize();
    let code_challenge = base64::encode_config(&digest, base64::URL_SAFE_NO_PAD);

    Ok((code_verifier, code_challenge))
}
/// Appends the query parameters of an authorization-code-with-PKCE request to `url`.
///
/// Parameters are appended in a fixed order: client id, code challenge,
/// challenge method, redirect URI, response type and state.
pub fn finish_url(
    client_id: &str,
    url: &mut Url,
    redirect_uri: &str,
    state: &str,
    code_challenge: &str,
) {
    let params = [
        (CLIENT_ID_PARAM, client_id),
        (CODE_CHALLENGE_PARAM, code_challenge),
        (CODE_CHALLENGE_METHOD_PARAM, CODE_CHALLENGE_METHOD_VALUE),
        (REDIRECT_PARAM, redirect_uri),
        (RESPONSE_TYPE_PARAM, RESPONSE_TYPE_PARAM_VALUE),
        (STATE_PARAM, state),
    ];
    let mut query = url.query_pairs_mut();
    for (name, value) in params.iter() {
        query.append_pair(name, value);
    }
}
/// Pulls the authorization code and the state value out of the redirect's
/// query parameters, returning them as `(code, state)`.
///
/// # Errors
/// Fails when either parameter is absent; the code is checked first.
fn handle_code_response(mut params: HashMap<String, String>) -> Result<(String, String)> {
    // The map is owned, so the values can be moved out rather than cloned.
    let code = params
        .remove(CODE_RESULT_PARAM)
        .context("No code found in response")?;
    let state = params
        .remove(STATE_RESULT_PARAM)
        .context("No state found in response")?;
    Ok((code, state))
}
fn handle_token_response(res: TokenResponse) -> Result<(String, Instant)> {
let token = res.access_token;
if res.token_type.to_lowercase() != TOKEN_TYPE_RESULT_PARAM_VALUE {
bail!(
"Token type in response is not {}",
TOKEN_TYPE_RESULT_PARAM_VALUE
)
}
let expires_at = Instant::now() + Duration::from_secs(res.expires_in);
Ok((token, expires_at))
}
/// Page returned by `serve` for `/redirect` once the authorization code has
/// been accepted.
const SUCCESS_AFTER_REDIRECT: &str = r##"<!doctype html>
<html lang="en">
<head><meta charset="utf-8"></head>
<body>In-browser step of authentication complete, you can now close this page!</body>
</html>
"##;
/// Request handler for the local web server of the code grant (PKCE) flow.
///
/// Routes:
/// * `GET /` – landing page that redirects the browser to the authorizer.
/// * `GET /auth_detail.json` – the authorization URL as JSON.
/// * `GET /redirect` – receives the authorization code, checks the state
///   value, hands the code to the waiting caller and requests shutdown.
/// * anything else – an empty `404 Not Found`.
///
/// # Errors
/// Fails when the global state is unavailable, when the redirect parameters
/// are missing or carry the wrong state value, or when a code has already been
/// accepted by an earlier request. No panic is raised while the state lock is
/// held, so one bad request cannot poison the mutex for the following ones.
pub fn serve(req: Request<Body>) -> Result<Response<Body>> {
    let mut guard = STATE
        .lock()
        .map_err(|_| anyhow!("Auth server state lock was poisoned"))?;
    let state = guard
        .as_mut()
        .context("Auth server state has not been initialised")?;
    debug!("Handling {} {}", req.method(), req.uri());
    let response = match (req.method(), req.uri().path()) {
        (&Method::GET, "/") => html_response(REDIRECT_WITH_AUTH_JSON),
        (&Method::GET, "/auth_detail.json") => json_response(&state.auth_url)?,
        (&Method::GET, "/redirect") => {
            let params = query_pairs(&req.uri().to_string())?;
            let (code, auth_state) = handle_code_response(params)
                .context("Failed to handle response from redirect")?;
            if auth_state != state.auth_state_value {
                return Err(anyhow!("Mismatched auth states after redirect"));
            }
            // Claim the one-shot shutdown handle before touching the channel: a
            // repeated redirect must fail here rather than block on the full
            // single-slot channel while holding the state lock.
            let shutdown_tx = state
                .shutdown_tx
                .take()
                .context("Authorization code was already received")?;
            // Format the error via `Display` so the code itself is not carried
            // inside the error value.
            state
                .code_tx
                .try_send(code)
                .map_err(|e| anyhow!("Failed to hand over authorization code: {}", e))?;
            if shutdown_tx.send(()).is_err() {
                warn!("Auth server shutdown receiver was already dropped");
            }
            html_response(SUCCESS_AFTER_REDIRECT)
        }
        _ => {
            warn!("Route not found");
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("".into())?
        }
    };
    Ok(response)
}
/// Exchanges an authorization code for an access token.
///
/// Posts a JSON `TokenRequest` to `token_url` and validates the response with
/// `handle_token_response`. A warning is logged and printed to stderr when the
/// token's remaining validity is below `MIN_TOKEN_VALIDITY`.
///
/// # Errors
/// Fails when the request cannot be sent, the server answers with a
/// non-success status, the body is not the expected JSON, or the token
/// response is rejected.
pub async fn code_to_token(
    token_url: &str,
    client_id: &str,
    code_verifier: &str,
    code: &str,
    redirect_uri: &str,
) -> Result<String> {
    let token_request = TokenRequest {
        client_id,
        code_verifier,
        code,
        grant_type: GRANT_TYPE_PARAM_VALUE,
        redirect_uri,
    };
    let client = reqwest::Client::new();
    let res = client.post(token_url).json(&token_request).send().await?;
    let status = res.status();
    if !status.is_success() {
        bail!(
            "Sending code to {} failed, HTTP error: {}",
            token_url,
            status
        );
    }
    let token_response = res
        .json()
        .await
        .context("Failed to parse token response as JSON")?;
    let (token, expires_at) = handle_token_response(token_response)?;
    // A token that has already expired (e.g. `expires_in` of 0) must count as
    // zero remaining validity; plain `Instant` subtraction is not guaranteed
    // to saturate.
    let remaining = expires_at.saturating_duration_since(Instant::now());
    if remaining < MIN_TOKEN_VALIDITY {
        warn!(
            "Token retrieved expires in under {}",
            MIN_TOKEN_VALIDITY_WARNING
        );
        eprintln!(
            "cachepot: Token retrieved expires in under {}",
            MIN_TOKEN_VALIDITY_WARNING
        );
    }
    Ok(token)
}
}
mod implicit {
use super::{
html_response, json_response, query_pairs, MIN_TOKEN_VALIDITY, MIN_TOKEN_VALIDITY_WARNING,
REDIRECT_WITH_AUTH_JSON,
};
use futures::channel::oneshot;
use hyper::{Body, Method, Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::mpsc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use url::Url;
use crate::errors::*;
// Query parameter names, and the fixed response type value, appended to the
// authorization URL by `finish_url`.
const CLIENT_ID_PARAM: &str = "client_id";
const REDIRECT_PARAM: &str = "redirect_uri";
const RESPONSE_TYPE_PARAM: &str = "response_type";
const RESPONSE_TYPE_PARAM_VALUE: &str = "token";
const STATE_PARAM: &str = "state";
// Parameter names (and the required token type) read by `handle_response`
// from the query string of the `/save_auth` request.
const TOKEN_RESULT_PARAM: &str = "access_token";
const TOKEN_TYPE_RESULT_PARAM: &str = "token_type";
const TOKEN_TYPE_RESULT_PARAM_VALUE: &str = "bearer"; const EXPIRES_IN_RESULT_PARAM: &str = "expires_in"; const STATE_RESULT_PARAM: &str = "state";
/// State consulted by `serve` while handling requests for the implicit flow.
pub struct State {
/// Complete authorization URL, served as JSON at `/auth_detail.json`.
pub auth_url: String,
/// Value the `state` parameter posted to `/save_auth` must equal.
pub auth_state_value: String,
/// Channel on which the received access token is sent.
pub token_tx: mpsc::SyncSender<String>,
/// One-shot sender fired after the token has been sent; taken (left `None`) on use.
pub shutdown_tx: Option<oneshot::Sender<()>>,
}
lazy_static! {
// Global slot for the flow state; starts out as `None` and is filled in by
// `get_token_oauth2_implicit` before the server runs.
pub static ref STATE: Mutex<Option<State>> = Mutex::new(None);
}
/// Appends the query parameters of an implicit-grant authorization request to `url`.
///
/// Parameters are appended in a fixed order: client id, redirect URI,
/// response type and state.
pub fn finish_url(client_id: &str, url: &mut Url, redirect_uri: &str, state: &str) {
    let params = [
        (CLIENT_ID_PARAM, client_id),
        (REDIRECT_PARAM, redirect_uri),
        (RESPONSE_TYPE_PARAM, RESPONSE_TYPE_PARAM_VALUE),
        (STATE_PARAM, state),
    ];
    let mut query = url.query_pairs_mut();
    for (name, value) in params.iter() {
        query.append_pair(name, value);
    }
}
/// Validates the parameters forwarded by the browser after an implicit-grant
/// redirect and returns `(token, expires_at, state)`.
///
/// # Errors
/// Fails when the token, token type, expiry or state parameter is missing,
/// when the token type is not `TOKEN_TYPE_RESULT_PARAM_VALUE` (compared
/// ignoring ASCII case), when the expiry is not an unsigned integer, or when
/// it is so large that the expiry instant cannot be represented.
fn handle_response(params: HashMap<String, String>) -> Result<(String, Instant, String)> {
    let token = params
        .get(TOKEN_RESULT_PARAM)
        .context("No token found in response")?;
    let token_type = params
        .get(TOKEN_TYPE_RESULT_PARAM)
        .context("No token type found in response")?;
    if !token_type.eq_ignore_ascii_case(TOKEN_TYPE_RESULT_PARAM_VALUE) {
        bail!(
            "Token type in response is not {}",
            TOKEN_TYPE_RESULT_PARAM_VALUE
        );
    }
    let expires_in: u64 = params
        .get(EXPIRES_IN_RESULT_PARAM)
        .context("No expiry found in response")?
        .parse()
        .map_err(|_| anyhow!("Failed to parse expiry as integer"))?;
    // The value arrives in a request query string; `Instant + Duration` panics
    // on overflow, so use the checked form and report an error instead.
    let expires_at = Instant::now()
        .checked_add(Duration::from_secs(expires_in))
        .context("Token expiry in response is too large to represent")?;
    let state = params
        .get(STATE_RESULT_PARAM)
        .context("No state found in response")?;
    Ok((token.to_owned(), expires_at, state.to_owned()))
}
/// Page served at `/redirect` in the implicit flow: its script takes the URL
/// hash of the page and posts it, as a query string, to `/save_auth`, writing
/// progress and errors into the page body.
const SAVE_AUTH_AFTER_REDIRECT: &str = r##"<!doctype html>
<html lang="en">
<head><meta charset="utf-8"></head>
<body>
<script>
function writemsg(m) {
document.body.appendChild(document.createTextNode(m.toString()));
document.body.appendChild(document.createElement('br'));
}
function go() {
writemsg('Saving authentication details...');
var qs = window.location.hash.slice(1);
if (qs.length === 0) {
writemsg("ERROR: No URL hash returned from authorizer");
return
}
fetch('/save_auth?' + qs, { method: 'POST' }).then(function (response) {
if (!response.ok) {
throw 'Error during saving authentication - ' + response.status + ': ' + response.statusText;
}
writemsg('Authentication complete, you can now close this page!');
}).catch(writemsg);
}
go();
</script>
</body>
</html>
"##;
/// Request handler for the local web server of the implicit flow.
///
/// Routes:
/// * `GET /` – landing page that redirects the browser to the authorizer.
/// * `GET /auth_detail.json` – the authorization URL as JSON.
/// * `GET /redirect` – page whose script forwards the URL hash to `/save_auth`.
/// * `POST /save_auth` – validates the forwarded parameters, hands the token
///   to the waiting caller and requests shutdown.
/// * anything else – an empty `404 Not Found`.
///
/// # Errors
/// Fails when the global state is unavailable, when the forwarded parameters
/// are invalid or carry the wrong state value, or when a token has already
/// been accepted by an earlier request. No panic is raised while the state
/// lock is held, so one bad request cannot poison the mutex for later ones.
pub fn serve(req: Request<Body>) -> Result<Response<Body>> {
    let mut guard = STATE
        .lock()
        .map_err(|_| anyhow!("Auth server state lock was poisoned"))?;
    let state = guard
        .as_mut()
        .context("Auth server state has not been initialised")?;
    debug!("Handling {} {}", req.method(), req.uri());
    let response = match (req.method(), req.uri().path()) {
        (&Method::GET, "/") => html_response(REDIRECT_WITH_AUTH_JSON),
        (&Method::GET, "/auth_detail.json") => json_response(&state.auth_url)?,
        (&Method::GET, "/redirect") => html_response(SAVE_AUTH_AFTER_REDIRECT),
        (&Method::POST, "/save_auth") => {
            let params = query_pairs(&req.uri().to_string())?;
            let (token, expires_at, auth_state) =
                handle_response(params).context("Failed to save auth after redirect")?;
            if auth_state != state.auth_state_value {
                return Err(anyhow!("Mismatched auth states after redirect"));
            }
            // An already-expired token counts as zero remaining validity;
            // plain `Instant` subtraction is not guaranteed to saturate.
            let remaining = expires_at.saturating_duration_since(Instant::now());
            if remaining < MIN_TOKEN_VALIDITY {
                warn!(
                    "Token retrieved expires in under {}",
                    MIN_TOKEN_VALIDITY_WARNING
                );
                eprintln!(
                    "cachepot: Token retrieved expires in under {}",
                    MIN_TOKEN_VALIDITY_WARNING
                );
            }
            // Claim the one-shot shutdown handle before touching the channel: a
            // repeated save must fail here rather than block on the full
            // single-slot channel while holding the state lock.
            let shutdown_tx = state
                .shutdown_tx
                .take()
                .context("Token was already received")?;
            // Format the error via `Display` so the token itself is not carried
            // inside the error value.
            state
                .token_tx
                .try_send(token)
                .map_err(|e| anyhow!("Failed to hand over token: {}", e))?;
            if shutdown_tx.send(()).is_err() {
                warn!("Auth server shutdown receiver was already dropped");
            }
            json_response(&"")?
        }
        _ => {
            warn!("Route not found");
            Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body("".into())?
        }
    };
    Ok(response)
}
}
// Builds a hyper "make service" value from a synchronous handler `$serve_fn`
// that takes a `Request<Body>` and returns a `Result`.
//
// The per-connection factory itself never fails (`Infallible`). For each
// request the URI is cloned up front, because the request is moved into the
// handler, and an `Err` from the handler is converted into a response by
// `error_code_response`.
macro_rules! make_service {
($serve_fn: expr) => {{
use core::convert::Infallible;
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request};
make_service_fn(|_socket: &AddrStream| async {
Ok::<_, Infallible>(service_fn(|req: Request<Body>| async move {
let uri = req.uri().clone();
$serve_fn(req).or_else(|e| error_code_response(uri, e))
}))
})
}};
}
/// Converts a handler error into a plain-text `500 Internal Server Error`
/// response whose body is the `Debug` rendering of `e`; the same text is also
/// printed to stderr together with the request `uri`.
///
/// Always returns `Ok`; the `Result` wrapper exists so the function can be
/// used directly with `or_else` in `make_service!`.
#[allow(clippy::unnecessary_wraps)]
fn error_code_response<E>(uri: hyper::Uri, e: E) -> hyper::Result<Response<Body>>
where
    E: std::fmt::Debug,
{
    let message = format!("{:?}", e);
    eprintln!(
        "cachepot: Error during a request to {} on the client auth web server\n{}",
        uri, message
    );
    let content_length = ContentLength(message.len() as u64);
    let response: Response<Body> = Response::builder()
        .status(StatusCode::INTERNAL_SERVER_ERROR)
        .set_header(ContentType::text())
        .set_header(content_length)
        .body(message.into())
        .unwrap();
    Ok(response)
}
fn try_bind() -> Result<hyper::server::Builder<AddrIncoming>> {
for &port in VALID_PORTS {
let mut addrs = ("localhost", port)
.to_socket_addrs()
.expect("Failed to interpret localhost address to listen on");
let addr = addrs
.next()
.expect("Expected at least one address in parsed socket address");
match TcpStream::connect(addr) {
Ok(_) => continue,
Err(ref e) if e.kind() == io::ErrorKind::ConnectionRefused => (),
Err(e) => {
return Err(e)
.with_context(|| format!("Failed to check {} is available for binding", addr))
}
}
match Server::try_bind(&addr) {
Ok(s) => return Ok(s),
Err(ref err)
if err
.source()
.and_then(|err| err.downcast_ref::<io::Error>())
.map(|err| err.kind() == io::ErrorKind::AddrInUse)
.unwrap_or(false) =>
{
continue
}
Err(e) => return Err(e).with_context(|| format!("Failed to bind to {}", addr)),
}
}
bail!("Could not bind to any valid port: ({:?})", VALID_PORTS)
}
/// Runs the OAuth2 authorization code grant (with PKCE) against a local web
/// server and returns the resulting access token.
///
/// A server is bound via `try_bind`, `auth_url` is completed with the request
/// parameters, and the user is asked to open the local address in a browser.
/// The call blocks until the redirect handler has delivered a code and the
/// server has shut down; the code is then exchanged at `token_url`.
///
/// # Errors
/// Fails when the runtime or server cannot be set up, when the server
/// terminates with an error, or when the code cannot be turned into a token.
///
/// # Panics
/// Panics when the server shuts down without a code having been delivered.
pub fn get_token_oauth2_code_grant_pkce(
    client_id: &str,
    mut auth_url: Url,
    token_url: &str,
) -> Result<String> {
    let rt = Runtime::new()?;
    let bound = {
        // Binding is done inside the runtime's context.
        let _entered = rt.enter();
        try_bind()?
    };
    let server = bound.serve(make_service!(code_grant_pkce::serve));
    let port = server.local_addr().port();

    let redirect_uri = format!("http://localhost:{}/redirect", port);
    let auth_state_value = Uuid::new_v4().to_simple_ref().to_string();
    let (verifier, challenge) = code_grant_pkce::generate_verifier_and_challenge()?;
    code_grant_pkce::finish_url(
        client_id,
        &mut auth_url,
        &redirect_uri,
        &auth_state_value,
        &challenge,
    );

    info!("Listening on http://localhost:{} with 1 thread.", port);
    println!(
        "cachepot: Please visit http://localhost:{} in your browser",
        port
    );

    let (shutdown_tx, shutdown_rx) = oneshot::channel();
    let (code_tx, code_rx) = mpsc::sync_channel(1);
    // Publish the state before the server starts handling requests.
    *code_grant_pkce::STATE.lock().unwrap() = Some(code_grant_pkce::State {
        auth_url: auth_url.to_string(),
        auth_state_value,
        code_tx,
        shutdown_tx: Some(shutdown_tx),
    });

    let shutdown_signal = async move {
        if let Err(e) = shutdown_rx.await {
            warn!(
                "Something went wrong while waiting for auth server shutdown: {}",
                e
            );
        }
    };
    rt.block_on(server.with_graceful_shutdown(shutdown_signal))?;

    info!("Server finished, using code to request token");
    let code = code_rx
        .try_recv()
        .expect("Hyper shutdown but code not available - internal error");
    let exchange =
        code_grant_pkce::code_to_token(token_url, client_id, &verifier, &code, &redirect_uri);
    rt.block_on(exchange)
        .context("Failed to convert oauth2 code into a token")
}
/// Runs the OAuth2 implicit grant against a local web server and returns the
/// resulting access token.
///
/// A server is bound via `try_bind`, `auth_url` is completed with the request
/// parameters, and the user is asked to open the local address in a browser.
/// The call blocks until the `/save_auth` handler has delivered a token and
/// the server has shut down.
///
/// # Errors
/// Fails when the runtime or server cannot be set up, or when the server
/// terminates with an error.
///
/// # Panics
/// Panics when the server shuts down without a token having been delivered.
pub fn get_token_oauth2_implicit(client_id: &str, mut auth_url: Url) -> Result<String> {
    let rt = Runtime::new()?;
    let bound = {
        // Binding is done inside the runtime's context.
        let _entered = rt.enter();
        try_bind()?
    };
    let server = bound.serve(make_service!(implicit::serve));
    let port = server.local_addr().port();

    let redirect_uri = format!("http://localhost:{}/redirect", port);
    let auth_state_value = Uuid::new_v4().to_simple_ref().to_string();
    implicit::finish_url(client_id, &mut auth_url, &redirect_uri, &auth_state_value);

    info!("Listening on http://localhost:{} with 1 thread.", port);
    println!(
        "cachepot: Please visit http://localhost:{} in your browser",
        port
    );

    let (shutdown_tx, shutdown_rx) = oneshot::channel();
    let (token_tx, token_rx) = mpsc::sync_channel(1);
    // Publish the state before the server starts handling requests.
    *implicit::STATE.lock().unwrap() = Some(implicit::State {
        auth_url: auth_url.to_string(),
        auth_state_value,
        token_tx,
        shutdown_tx: Some(shutdown_tx),
    });

    let shutdown_signal = async move {
        if let Err(e) = shutdown_rx.await {
            warn!(
                "Something went wrong while waiting for auth server shutdown: {}",
                e
            );
        }
    };
    rt.block_on(server.with_graceful_shutdown(shutdown_signal))?;

    info!("Server finished, returning token");
    let token = token_rx
        .try_recv()
        .expect("Hyper shutdown but token not available - internal error");
    Ok(token)
}