#![warn(future_incompatible, nonstandard_style, missing_docs)]
#![allow(clippy::needless_doctest_main)]
mod config;
mod error;
#[cfg(feature = "hyper_rustls_adapter")]
mod hyper_rustls_adapter;
#[cfg(feature = "hyper_rustls_adapter")]
pub use hyper_rustls_adapter::HyperRustlsAdapter;
pub use self::config::*;
pub use self::error::*;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
use log::{error, info, warn};
use rocket::fairing::{AdHoc, Fairing};
use rocket::form::{Form, FromForm};
use rocket::http::uri::Absolute;
use rocket::http::{Cookie, CookieJar, SameSite, Status};
use rocket::request::{self, FromRequest, Outcome, Request};
use rocket::response::Redirect;
use rocket::{Build, Ignite, Rocket, Sentinel};
use serde_json::Value;
const STATE_COOKIE_NAME: &str = "rocket_oauth2_state";
fn generate_state(rng: &mut impl rand::RngCore) -> Result<String, Error> {
let mut buf = [0; 16]; rng.try_fill_bytes(&mut buf).map_err(|_| {
Error::new_from(
ErrorKind::Other,
String::from("Failed to generate random data"),
)
})?;
Ok(base64::encode_config(&buf, base64::URL_SAFE_NO_PAD))
}
/// The grant sent to the token endpoint through `Adapter::exchange_code`.
// `Eq` is derived alongside `PartialEq`: both variants wrap `String`, so
// equality is a full equivalence relation.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum TokenRequest {
    /// Exchange the authorization `code` received on the OAuth2 callback.
    AuthorizationCode(String),
    /// Obtain a new token using a previously issued refresh token.
    RefreshToken(String),
}
#[derive(Clone, PartialEq, Debug)]
pub struct TokenResponse<K> {
data: Value,
_k: PhantomData<fn() -> K>,
}
impl<K> TokenResponse<K> {
    /// Reinterprets this response as belonging to key type `L`; the JSON data
    /// is moved over unchanged.
    pub fn cast<L>(self) -> TokenResponse<L> {
        TokenResponse {
            data: self.data,
            _k: PhantomData,
        }
    }

    /// Returns the full JSON body, for reading fields that have no dedicated
    /// accessor.
    pub fn as_value(&self) -> &Value {
        &self.data
    }

    /// Returns the `access_token` field.
    ///
    /// # Panics
    /// Panics if `access_token` is missing or not a string. `TryFrom<Value>`
    /// rejects such data, so this only happens if that invariant was bypassed.
    pub fn access_token(&self) -> &str {
        self.data
            .get("access_token")
            .and_then(Value::as_str)
            .expect("access_token required at construction")
    }

    /// Returns the `token_type` field.
    ///
    /// # Panics
    /// Panics if `token_type` is missing or not a string. `TryFrom<Value>`
    /// rejects such data, so this only happens if that invariant was bypassed.
    pub fn token_type(&self) -> &str {
        self.data
            .get("token_type")
            .and_then(Value::as_str)
            .expect("token_type required at construction")
    }

    /// Returns the `expires_in` field, if present (per RFC 6749, the token
    /// lifetime in seconds).
    ///
    /// Accepts either a JSON integer or a string containing an integer, since
    /// some providers send the latter. Any other value yields `None`.
    pub fn expires_in(&self) -> Option<i64> {
        let raw = self.data.get("expires_in")?;
        raw.as_i64()
            .or_else(|| raw.as_str().and_then(|s| s.parse::<i64>().ok()))
    }

    /// Returns the `refresh_token` field, if present and a string.
    pub fn refresh_token(&self) -> Option<&str> {
        self.data.get("refresh_token").and_then(Value::as_str)
    }

    /// Returns the `scope` field, if present and a string.
    ///
    /// For responses obtained through the `TokenResponse` request guard, this
    /// may have been filled in from the callback's `scope` query parameter
    /// when the token endpoint did not return one.
    pub fn scope(&self) -> Option<&str> {
        self.data.get("scope").and_then(Value::as_str)
    }
}
impl std::convert::TryFrom<Value> for TokenResponse<()> {
type Error = Error;
fn try_from(data: Value) -> Result<Self, Error> {
if !data.is_object() {
return Err(Error::new_from(
ErrorKind::ExchangeFailure,
String::from("TokenResponse data was not an object"),
));
}
match data.get("access_token") {
Some(val) if val.is_string() => (),
_ => {
return Err(Error::new_from(
ErrorKind::ExchangeFailure,
String::from("TokenResponse access_token was missing or not a string"),
))
}
}
match data.get("token_type") {
Some(val) if val.is_string() => (),
_ => {
return Err(Error::new_from(
ErrorKind::ExchangeFailure,
String::from("TokenResponse token_type was missing or not a string"),
))
}
}
Ok(Self {
data,
_k: PhantomData,
})
}
}
#[rocket::async_trait]
impl<'r, K: 'static> FromRequest<'r> for TokenResponse<K> {
type Error = Error;
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let oauth2 = request
.rocket()
.state::<Arc<Shared<K>>>()
.expect("OAuth2 fairing was not attached for this key type!");
let query = match request.uri().query() {
Some(q) => q,
None => {
return Outcome::Failure((
Status::BadRequest,
Error::new_from(
ErrorKind::ExchangeFailure,
"Missing query string in request",
),
))
}
};
#[derive(FromForm)]
struct CallbackQuery {
code: String,
state: String,
scope: Option<String>,
}
let params = match Form::<CallbackQuery>::parse_encoded(&query) {
Ok(p) => p,
Err(e) => {
warn!("Failed to parse OAuth2 query string: {:?}", e);
return Outcome::Failure((
Status::BadRequest,
Error::new_from(ErrorKind::ExchangeFailure, format!("{:?}", e)),
));
}
};
{
let cookies = request
.guard::<&CookieJar<'_>>()
.await
.expect("request cookies");
match cookies.get_private(STATE_COOKIE_NAME) {
Some(ref cookie) if cookie.value() == params.state => {
cookies.remove(cookie.clone());
}
other => {
if other.is_some() {
warn!("The OAuth2 state returned from the server did not match the stored state.");
} else {
error!("The OAuth2 state cookie was missing. It may have been blocked by the client?");
}
return Outcome::Failure((
Status::BadRequest,
Error::new_from(
ErrorKind::ExchangeFailure,
"The OAuth2 state returned from the server did match the stored state.",
),
));
}
}
}
match oauth2
.adapter
.exchange_code(&oauth2.config, TokenRequest::AuthorizationCode(params.code))
.await
{
Ok(mut token) => {
let data = token
.data
.as_object_mut()
.expect("data is guaranteed to be an Object");
if let (None, Some(scope)) = (data.get("scope"), params.scope) {
data.insert(String::from("scope"), Value::String(scope));
}
Outcome::Success(token.cast())
}
Err(e) => {
warn!("OAuth2 token exchange failed: {}", e);
Outcome::Failure((Status::BadRequest, e))
}
}
}
}
fn sentinel_abort<K: 'static>(rocket: &Rocket<Ignite>, wrapper: &str) -> bool {
if rocket.state::<Arc<Shared<K>>>().is_some() {
return false;
}
let type_name = std::any::type_name::<K>();
error!("{}<{}> was used in a mounted route without attaching a matching fairing", wrapper, type_name);
info!("attach either OAuth2::<{0}>::fairing() or OAuth2::<{0}>::custom()", type_name);
true
}
impl<K: 'static> Sentinel for TokenResponse<K> {
    /// Aborts launch when a mounted route uses `TokenResponse<K>` but no
    /// `OAuth2::<K>` state has been registered.
    fn abort(r: &Rocket<Ignite>) -> bool {
        let wrapper_name = "TokenResponse";
        sentinel_abort::<K>(r, wrapper_name)
    }
}
/// Provider backend used by `OAuth2` to build authorization URIs and perform
/// token requests.
///
/// Any implementor can be installed with `OAuth2::custom`. It is stored as a
/// `Box<dyn Adapter>` inside an `Arc` that Rocket manages as state.
#[async_trait::async_trait]
pub trait Adapter: Send + Sync + 'static {
    /// Builds the URI the user agent is redirected to for authorization.
    ///
    /// `OAuth2::get_redirect_extras` calls this with the configured
    /// `OAuthConfig`, a freshly generated `state` value (also saved in a
    /// private cookie after this call succeeds), the requested `scopes`, and
    /// any `extra_params`. An `Err` is returned to the caller of
    /// `get_redirect*`, and no state cookie is set.
    fn authorization_uri(
        &self,
        config: &OAuthConfig,
        state: &str,
        scopes: &[&str],
        extra_params: &[(&str, &str)],
    ) -> Result<Absolute<'static>, Error>;

    /// Performs a token request against the provider.
    ///
    /// It is called with `TokenRequest::AuthorizationCode` by the
    /// `TokenResponse` request guard, and with `TokenRequest::RefreshToken`
    /// by `OAuth2::refresh`. In the request guard, an `Err` fails the request
    /// with `Status::BadRequest`; in `refresh`, it is returned to the caller.
    async fn exchange_code(
        &self,
        config: &OAuthConfig,
        token: TokenRequest,
    ) -> Result<TokenResponse<()>, Error>;
}
/// Per-key-type state that the `OAuth2::<K>` fairings register with Rocket.
struct Shared<K> {
    /// Backend used for authorization URIs and token exchanges.
    adapter: Box<dyn Adapter>,
    /// Provider configuration passed to every adapter call.
    config: OAuthConfig,
    /// Ties this state to the key type `K` without storing a `K`.
    _k: PhantomData<fn() -> TokenResponse<K>>,
}
/// Handle to the adapter and configuration registered for key type `K`.
///
/// Register it with `OAuth2::<K>::custom()` (or `OAuth2::<K>::fairing()` when
/// the `hyper_rustls_adapter` feature is enabled). Then take `OAuth2<K>` as a
/// request guard to start the flow with `get_redirect` or to refresh tokens
/// with `refresh`.
pub struct OAuth2<K>(Arc<Shared<K>>);
impl<K: 'static> OAuth2<K> {
#[cfg(feature = "hyper_rustls_adapter")]
pub fn fairing(config_name: impl AsRef<str> + Send + 'static) -> impl Fairing {
AdHoc::try_on_ignite("rocket_oauth2::fairing", |rocket| async move {
let config = match OAuthConfig::from_figment(rocket.figment(), config_name.as_ref()) {
Ok(c) => c,
Err(e) => {
log::error!("Invalid configuration: {:?}", e);
return Err(rocket);
}
};
Ok(Self::_init(rocket, hyper_rustls_adapter::HyperRustlsAdapter::default(), config))
})
}
fn _init<A: Adapter>(rocket: Rocket<Build>, adapter: A, config: OAuthConfig) -> Rocket<Build> {
rocket.manage(Arc::new(Shared::<K> {
adapter: Box::new(adapter),
config,
_k: PhantomData,
}))
}
pub fn custom<A: Adapter>(adapter: A, config: OAuthConfig) -> impl Fairing {
AdHoc::on_ignite("rocket_oauth2::custom", |rocket| async {
Self::_init(rocket, adapter, config)
})
}
pub fn get_redirect(
&self,
cookies: &CookieJar<'_>,
scopes: &[&str],
) -> Result<Redirect, Error> {
self.get_redirect_extras(cookies, scopes, &[])
}
pub fn get_redirect_extras(
&self,
cookies: &CookieJar<'_>,
scopes: &[&str],
extras: &[(&str, &str)],
) -> Result<Redirect, Error> {
let state = generate_state(&mut rand::thread_rng())?;
let uri = self
.0
.adapter
.authorization_uri(&self.0.config, &state, scopes, extras)?;
cookies.add_private(
Cookie::build(STATE_COOKIE_NAME, state)
.same_site(SameSite::Lax)
.finish(),
);
Ok(Redirect::to(uri))
}
pub async fn refresh(&self, refresh_token: &str) -> Result<TokenResponse<K>, Error> {
self.0
.adapter
.exchange_code(
&self.0.config,
TokenRequest::RefreshToken(refresh_token.to_string()),
)
.await
.map(TokenResponse::cast)
}
}
/// Request guard yielding the `OAuth2<K>` handle registered by a fairing.
///
/// Always succeeds. It panics if no `OAuth2::<K>` fairing was attached; the
/// `Sentinel` impl for `OAuth2<K>` is meant to prevent launch in that case.
#[rocket::async_trait]
impl<'r, K: 'static> FromRequest<'r> for OAuth2<K> {
    type Error = ();

    async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
        let shared = request
            .rocket()
            .state::<Arc<Shared<K>>>()
            .expect("OAuth2 fairing was not attached for this key type!");
        Outcome::Success(OAuth2(Arc::clone(shared)))
    }
}
impl<K: 'static> Sentinel for OAuth2<K> {
    /// Aborts launch when a mounted route uses `OAuth2<K>` but no
    /// `OAuth2::<K>` state has been registered.
    fn abort(r: &Rocket<Ignite>) -> bool {
        let wrapper_name = "OAuth2";
        sentinel_abort::<K>(r, wrapper_name)
    }
}
/// Debug output shows the configuration. The adapter is a `dyn Adapter`
/// without a `Debug` bound, so it is rendered as `..`.
///
/// No bound is placed on the key type: it is never formatted, so requiring
/// `K: Debug` would only reject marker types that do not derive it.
impl<K> fmt::Debug for OAuth2<K> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OAuth2")
            .field("adapter", &(..))
            .field("config", &self.0.config)
            .finish()
    }
}