use std::fmt;
use ring::rand::{SecureRandom, SystemRandom};
use rocket::fairing::{AdHoc, Fairing};
use rocket::handler;
use rocket::http::uri::Absolute;
use rocket::http::{Cookie, Cookies, Method, SameSite, Status};
use rocket::outcome::{IntoOutcome, Outcome};
use rocket::request::{FormItems, FromForm, Request};
use rocket::response::{Redirect, Responder};
use rocket::{Data, Route, State};
use serde_json::Value;
use crate::{Error, ErrorKind, OAuthConfig};
/// Name of the private cookie that holds the random `state` value between the
/// redirect to the provider (`OAuth2::get_redirect`) and the callback route
/// (`OAuth2::handle`), where the two are compared.
const STATE_COOKIE_NAME: &str = "rocket_oauth2_state";
fn generate_state(rng: &dyn SecureRandom) -> Result<String, Error> {
let mut buf = [0; 16];
rng.fill(&mut buf).map_err(|_| {
Error::new_from(
ErrorKind::Other,
String::from("Failed to generate random data"),
)
})?;
Ok(base64::encode_config(&buf, base64::URL_SAFE_NO_PAD))
}
/// A request sent to the provider's token endpoint through
/// `Adapter::exchange_code`.
///
/// Every payload is a plain `String`, so the type is fully `Eq`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TokenRequest {
    /// Exchange an authorization code received on the callback route.
    AuthorizationCode(String),
    /// Obtain a new token using a previously issued refresh token.
    RefreshToken(String),
}
/// The token data returned by a token exchange or refresh.
///
/// The raw JSON is kept as-is. Construction goes through `TryFrom<Value>`,
/// which only accepts an object whose `access_token` and `token_type`
/// members are strings.
#[derive(Debug, Clone, PartialEq)]
pub struct TokenResponse {
    /// Raw JSON object describing the token.
    data: Value,
}
impl std::convert::TryFrom<Value> for TokenResponse {
type Error = Error;
fn try_from(data: Value) -> Result<Self, Error> {
if !data.is_object() {
return Err(Error::new_from(
ErrorKind::ExchangeFailure,
String::from("TokenResponse data was not an object"),
));
}
match data.get("access_token") {
Some(val) if val.is_string() => (),
_ => {
return Err(Error::new_from(
ErrorKind::ExchangeFailure,
String::from("TokenResponse access_token was missing or not a string"),
))
}
}
match data.get("token_type") {
Some(val) if val.is_string() => (),
_ => {
return Err(Error::new_from(
ErrorKind::ExchangeFailure,
String::from("TokenResponse token_type was missing or not a string"),
))
}
}
Ok(Self { data })
}
}
impl TokenResponse {
    /// Returns the complete token data as raw JSON, including members that
    /// have no dedicated accessor.
    pub fn as_value(&self) -> &Value {
        &self.data
    }

    /// The `access_token` member.
    ///
    /// # Panics
    /// Only if the construction invariant is broken: `TryFrom<Value>` rejects
    /// data without a string `access_token`.
    pub fn access_token(&self) -> &str {
        self.data
            .get("access_token")
            .and_then(Value::as_str)
            .expect("access_token required at construction")
    }

    /// The `token_type` member.
    ///
    /// # Panics
    /// Only if the construction invariant is broken: `TryFrom<Value>` rejects
    /// data without a string `token_type`.
    pub fn token_type(&self) -> &str {
        self.data
            .get("token_type")
            .and_then(Value::as_str)
            .expect("token_type required at construction")
    }

    /// The token lifetime in seconds (`expires_in`), if present.
    ///
    /// RFC 6749 specifies a JSON number, but some providers send the lifetime
    /// as a numeric string (e.g. `"3600"`). Both forms are accepted. A missing
    /// member, or any other representation, yields `None`.
    pub fn expires_in(&self) -> Option<i64> {
        match self.data.get("expires_in")? {
            Value::Number(n) => n.as_i64(),
            // Lenient fallback for providers that quote the number.
            Value::String(s) => s.parse().ok(),
            _ => None,
        }
    }

    /// The `refresh_token` member, if present and a string.
    pub fn refresh_token(&self) -> Option<&str> {
        self.data.get("refresh_token").and_then(Value::as_str)
    }

    /// The `scope` member, if present and a string.
    ///
    /// For tokens obtained through the callback route, a missing `scope` is
    /// filled in from the redirect's `scope` query parameter when one is given.
    pub fn scope(&self) -> Option<&str> {
        self.data.get("scope").and_then(Value::as_str)
    }
}
/// Provider-specific logic: building the authorization URI and performing
/// token requests.
///
/// The `Send + Sync + 'static` bounds are required because the adapter is
/// boxed inside `OAuth2`, which is placed in Rocket managed state.
pub trait Adapter: Send + Sync + 'static {
    /// Builds the URI the user agent is redirected to in order to authorize
    /// the given `scopes`.
    ///
    /// `state` must reach the provider unchanged. The callback route rejects
    /// any redirect whose `state` query parameter differs from the value
    /// stored in the state cookie.
    fn authorization_uri(
        &self,
        config: &OAuthConfig,
        state: &str,
        scopes: &[&str],
    ) -> Result<Absolute<'static>, Error>;

    /// Sends `token` (an authorization code or a refresh token) to the
    /// provider and returns the resulting token data.
    fn exchange_code(
        &self,
        config: &OAuthConfig,
        token: TokenRequest,
    ) -> Result<TokenResponse, Error>;
}
/// Application hook invoked by the callback route after a successful token
/// exchange.
pub trait Callback: Send + Sync + 'static {
    /// The response produced for the user agent.
    type Responder: Responder<'static>;

    /// Builds the response for `request` from the freshly obtained `token`.
    fn callback(&self, request: &Request<'_>, token: TokenResponse) -> Self::Responder;
}
/// Any thread-safe function or closure taking the request and token can be
/// used directly as a [`Callback`].
impl<F, R> Callback for F
where
    R: Responder<'static>,
    F: Fn(&Request<'_>, TokenResponse) -> R + Send + Sync + 'static,
{
    type Responder = R;

    fn callback(&self, request: &Request<'_>, token: TokenResponse) -> Self::Responder {
        // `&F` is itself callable because `F: Fn`.
        self(request, token)
    }
}
/// Managed state used by the OAuth2 callback and login routes for one
/// provider.
pub struct OAuth2<C> {
    /// Provider-specific authorization-URI and token logic.
    adapter: Box<dyn Adapter>,
    /// Hook run after a successful token exchange.
    callback: C,
    /// Client configuration passed to every adapter call.
    config: OAuthConfig,
    /// Scopes requested by the login route (empty when there is none).
    login_scopes: Vec<String>,
    /// Random source for `state` values.
    rng: SystemRandom,
}
impl<C: Callback> OAuth2<C> {
    /// Returns a fairing that loads the OAuth configuration named
    /// `config_name` when attached, then attaches [`OAuth2::custom`] with it.
    ///
    /// `callback_uri` is the route that receives the provider's redirect.
    /// `login`, if given, is a `(uri, scopes)` pair for a route that starts
    /// the authorization flow with those scopes.
    ///
    /// If the configuration cannot be loaded, the error is logged and the
    /// fairing fails by returning `Err(rocket)`.
    pub fn fairing<A: Adapter>(
        adapter: A,
        callback: C,
        config_name: &str,
        callback_uri: &str,
        login: Option<(&str, Vec<String>)>,
    ) -> impl Fairing {
        let config_name = config_name.to_string();
        let callback_uri = callback_uri.to_string();
        // Own the login URI so it can be moved into the 'static attach closure.
        let mut login = login.map(|(lu, ls)| (lu.to_string(), ls));
        AdHoc::on_attach("OAuth Init", move |rocket| {
            let config = match OAuthConfig::from_config(rocket.config(), &config_name) {
                Ok(c) => c,
                Err(e) => {
                    log::error!("Invalid configuration: {:?}", e);
                    return Err(rocket);
                }
            };
            // Borrow the owned URI and move the scope list out of the tuple.
            let new_login = login
                .as_mut()
                .map(|(lu, ls)| (lu.as_str(), std::mem::replace(ls, Vec::new())));
            Ok(rocket.attach(Self::custom(
                adapter,
                callback,
                config,
                &callback_uri,
                new_login,
            )))
        })
    }

    /// Returns a fairing that mounts the OAuth2 routes at `/` and manages the
    /// resulting `OAuth2` instance.
    ///
    /// A `GET` route at `callback_uri` handles the provider's redirect. When
    /// `login` is `Some((uri, scopes))`, a `GET` route at `uri` redirects to
    /// the provider requesting `scopes`.
    pub fn custom<A: Adapter>(
        adapter: A,
        callback: C,
        config: OAuthConfig,
        callback_uri: &str,
        login: Option<(&str, Vec<String>)>,
    ) -> impl Fairing {
        let mut routes = vec![Route::new(Method::Get, callback_uri, redirect_handler::<C>)];
        let mut login_scopes = Vec::new();
        if let Some((uri, scopes)) = login {
            routes.push(Route::new(Method::Get, uri, login_handler::<C>));
            login_scopes = scopes;
        }
        let oauth2 = Self {
            adapter: Box::new(adapter),
            callback,
            config,
            login_scopes,
            rng: SystemRandom::new(),
        };
        AdHoc::on_attach("OAuth Mount", |rocket| {
            Ok(rocket.manage(oauth2).mount("/", routes))
        })
    }

    /// Builds a redirect to the provider's authorization page for `scopes`.
    ///
    /// A fresh random `state` is generated, passed to the adapter, and stored
    /// in a private cookie for verification by the callback route.
    ///
    /// # Errors
    /// Fails if state generation or the adapter's `authorization_uri` fails.
    pub fn get_redirect(
        &self,
        cookies: &mut Cookies<'_>,
        scopes: &[&str],
    ) -> Result<Redirect, Error> {
        let state = generate_state(&self.rng)?;
        let uri = self
            .adapter
            .authorization_uri(&self.config, &state, scopes)?;
        cookies.add_private(
            Cookie::build(STATE_COOKIE_NAME, state)
                // Lax, not Strict: the cookie must be sent on the cross-site
                // top-level GET navigation back to the callback route.
                .same_site(SameSite::Lax)
                .finish(),
        );
        Ok(Redirect::to(uri))
    }

    /// Exchanges `refresh_token` for new token data via the adapter.
    pub fn refresh(&self, refresh_token: &str) -> Result<TokenResponse, Error> {
        self.adapter.exchange_code(
            &self.config,
            TokenRequest::RefreshToken(refresh_token.to_string()),
        )
    }

    /// Handles the provider's redirect to the callback route.
    ///
    /// Processing steps:
    /// - Parses `code`, `state` and the optional `scope` from the query string.
    /// - Checks `state` against the private state cookie and deletes the
    ///   cookie on a match.
    /// - Exchanges the code for a token and hands the token to the callback.
    ///
    /// Responds with `400 Bad Request` if the query is missing or malformed,
    /// the state does not match, or the exchange fails.
    fn handle<'r>(&self, request: &'r Request<'_>, _data: Data) -> handler::Outcome<'r> {
        let query = request.uri().query().into_outcome(Status::BadRequest)?;

        #[derive(FromForm)]
        struct CallbackQuery {
            code: String,
            state: String,
            scope: Option<String>,
        }

        // Non-strict parsing: unknown query parameters are ignored.
        let params = match CallbackQuery::from_form(&mut FormItems::from(query), false) {
            Ok(p) => p,
            Err(_) => return handler::Outcome::failure(Status::BadRequest),
        };

        {
            // Scoped so the `Cookies` guard is dropped before the token
            // exchange and the callback run (presumably so the callback can
            // obtain `Cookies` itself).
            let mut cookies = request.guard::<Cookies<'_>>().expect("request cookies");
            match cookies.get_private(STATE_COOKIE_NAME) {
                Some(ref cookie) if cookie.value() == params.state => {
                    // The cookie was created by `add_private`, which defaults
                    // its path to "/". A cookie read from the request carries
                    // no path, so a plain `remove` would emit a removal cookie
                    // scoped to the callback's directory and leave the
                    // original in place; `remove_private` sets path "/".
                    cookies.remove_private(cookie.clone());
                }
                _ => return handler::Outcome::failure(Status::BadRequest),
            }
        }

        let token = match self
            .adapter
            .exchange_code(&self.config, TokenRequest::AuthorizationCode(params.code))
        {
            Ok(mut token) => {
                let data = token
                    .data
                    .as_object_mut()
                    .expect("data is guaranteed to be an Object");
                // Fill in `scope` from the redirect when the provider omitted it.
                if let (None, Some(scope)) = (data.get("scope"), params.scope) {
                    data.insert(String::from("scope"), Value::String(scope));
                }
                token
            }
            Err(e) => {
                log::error!("Token exchange failed: {:?}", e);
                return handler::Outcome::failure(Status::BadRequest);
            }
        };

        let responder = self.callback.callback(request, token);
        handler::Outcome::from(request, responder)
    }
}
/// `Debug` output for `OAuth2`.
///
/// `dyn Adapter` has no `Debug` bound, so the adapter is rendered as `..`.
/// The random number generator is not shown.
impl<C: fmt::Debug> fmt::Debug for OAuth2<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("OAuth2");
        out.field("adapter", &(..));
        out.field("callback", &self.callback);
        out.field("config", &self.config);
        out.field("login_scopes", &self.login_scopes);
        out.finish()
    }
}
/// Route handler for the callback route. It delegates to the managed
/// `OAuth2<C>` instance.
///
/// Responds with `500 Internal Server Error` if the `OAuth2<C>` state guard
/// fails.
fn redirect_handler<'r, C: Callback>(request: &'r Request<'_>, data: Data) -> handler::Outcome<'r> {
    match request.guard::<State<'_, OAuth2<C>>>() {
        Outcome::Success(oauth) => oauth.handle(request, data),
        Outcome::Failure(_) => handler::Outcome::failure(Status::InternalServerError),
        // The `State` guard never forwards.
        Outcome::Forward(()) => unreachable!(),
    }
}
/// Route handler for the optional login route.
///
/// Redirects the user agent to the provider's authorization page, requesting
/// the configured login scopes. Responds with `500 Internal Server Error` if
/// the `OAuth2<C>` state guard fails.
fn login_handler<'r, C: Callback>(request: &'r Request<'_>, _data: Data) -> handler::Outcome<'r> {
    let oauth = match request.guard::<State<'_, OAuth2<C>>>() {
        Outcome::Success(managed) => managed,
        Outcome::Failure(_) => return handler::Outcome::failure(Status::InternalServerError),
        // The `State` guard never forwards.
        Outcome::Forward(()) => unreachable!(),
    };
    let scopes = oauth
        .login_scopes
        .iter()
        .map(|s| s.as_str())
        .collect::<Vec<&str>>();
    let mut cookies = request.guard::<Cookies<'_>>().expect("request cookies");
    let redirect = oauth.get_redirect(&mut cookies, &scopes);
    handler::Outcome::from(request, redirect)
}