#![warn(missing_docs)]
mod failure;
use std::io::Cursor;
use std::marker::PhantomData;
use rocket::{Data, Request, Response};
use rocket::http::{ContentType, Status};
use rocket::http::hyper::header;
use rocket::request::FromRequest;
use rocket::response::{self, Responder};
use rocket::outcome::Outcome;
use oxide_auth::endpoint::{NormalizedParameter, WebRequest, WebResponse};
use oxide_auth::frontends::dev::*;
pub use oxide_auth::frontends::simple::endpoint::Generic;
pub use oxide_auth::frontends::simple::request::NoError;
pub use self::failure::OAuthFailure;
/// An OAuth request extracted from an incoming Rocket request.
///
/// Obtain one through its `FromRequest` implementation (so it can be used as a request
/// guard) or with [`OAuthRequest::new`]. When the request carries a urlencoded form body,
/// that body must additionally be supplied with [`OAuthRequest::add_body`]; until then,
/// `WebRequest::urlbody` reports [`WebError::BodyNeeded`].
pub struct OAuthRequest<'r> {
    // Value of the `Authorization` header, present only if exactly one such header was sent.
    auth: Option<String>,
    // Parsed query parameters, or the error encountered while parsing them.
    query: Result<NormalizedParameter, WebError>,
    // `Ok(None)`: a form body is expected but has not been added yet.
    // `Ok(Some(..))`: the form body was added and parsed.
    // `Err(..)`: the request is not a form, or its body failed to parse.
    body: Result<Option<NormalizedParameter>, WebError>,
    // Lifetime marker tying this value to the Rocket request lifetime `'r`.
    lifetime: PhantomData<&'r ()>,
}
/// A response produced by an OAuth endpoint, wrapping a Rocket [`Response`].
///
/// It implements `Responder`, so it can be returned directly from a Rocket handler.
#[derive(Debug)]
pub struct OAuthResponse<'r>(Response<'r>);
/// Errors that can occur while interpreting a Rocket request as an OAuth request.
///
/// When returned from a handler, `Encoding` and `NotAForm` produce `400 Bad Request`;
/// `BodyNeeded` produces `500 Internal Server Error`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WebError {
    /// The query string or the form body could not be parsed as urlencoded data.
    Encoding,

    /// A form body was expected, but none has been supplied through
    /// `OAuthRequest::add_body`.
    BodyNeeded,

    /// The body was requested, but the request does not declare the
    /// `application/x-www-form-urlencoded` content type.
    NotAForm,
}
impl<'r> OAuthRequest<'r> {
    /// Builds an `OAuthRequest` from the URI and headers of a Rocket request.
    ///
    /// The query string is parsed eagerly. A query that fails to parse is stored as
    /// [`WebError::Encoding`] and reported when the query is accessed.
    ///
    /// The body is *not* read here. If the request declares `ContentType::Form`, the body
    /// is marked as pending and must be supplied with [`OAuthRequest::add_body`]. Any other
    /// (or missing) content type is recorded as [`WebError::NotAForm`].
    ///
    /// The `Authorization` header is kept only when exactly one is present. A request
    /// carrying several `Authorization` headers is treated as having none.
    pub fn new(request: &Request<'r>) -> Self {
        let query = serde_urlencoded::from_str(request.uri().query().unwrap_or(""))
            .map_err(|_| WebError::Encoding);

        // The body arrives separately as `Data` (see `add_body`), so only record here
        // whether a form body is expected.
        let body = match request.content_type() {
            Some(ct) if *ct == ContentType::Form => Ok(None),
            _ => Err(WebError::NotAForm),
        };

        let mut authorizations = request.headers().get("Authorization");
        let first = authorizations.next();
        // Multiple Authorization headers are ambiguous, so none of them is used.
        let auth = if authorizations.next().is_some() {
            None
        } else {
            first.map(str::to_owned)
        };

        OAuthRequest {
            auth,
            query,
            body,
            lifetime: PhantomData,
        }
    }

    /// Reads the request body from `data` and parses it as urlencoded form data.
    ///
    /// This only has an effect while the body is pending, i.e. the request was declared as
    /// a form and no body has been added yet. A body that fails to parse is recorded as
    /// [`WebError::Encoding`]. In every other state the call is a no-op.
    pub fn add_body(&mut self, data: Data) {
        if let Ok(None) = self.body {
            // NOTE(review): the stream is consumed without any size limit; consider
            // enforcing one if this endpoint is reachable by untrusted clients.
            self.body = serde_urlencoded::from_reader(data.open())
                .map(Some)
                .map_err(|_| WebError::Encoding);
        }
    }
}
impl<'r> OAuthResponse<'r> {
pub fn new() -> Self {
Default::default()
}
pub fn from_response(response: Response<'r>) -> Self {
OAuthResponse(response)
}
}
impl<'r> WebRequest for OAuthRequest<'r> {
    type Error = WebError;
    type Response = OAuthResponse<'r>;

    /// Borrows the query parameters parsed at construction, or returns the parse error.
    fn query(&mut self) -> Result<Cow<dyn QueryParameter + 'static>, Self::Error> {
        let params = self.query.as_ref().map_err(|failure| *failure)?;
        Ok(Cow::Borrowed(params as &dyn QueryParameter))
    }

    /// Borrows the parsed form body.
    ///
    /// Fails with `BodyNeeded` if a form body is expected but has not been added. Fails
    /// with the stored error if the request is not a form or its body failed to parse.
    fn urlbody(&mut self) -> Result<Cow<dyn QueryParameter + 'static>, Self::Error> {
        let params = match self.body {
            Ok(Some(ref parsed)) => parsed,
            Ok(None) => return Err(WebError::BodyNeeded),
            Err(failure) => return Err(failure),
        };
        Ok(Cow::Borrowed(params as &dyn QueryParameter))
    }

    /// Borrows the single `Authorization` header value, if one was captured.
    fn authheader(&mut self) -> Result<Option<Cow<str>>, Self::Error> {
        let header = self.auth.as_ref().map(|value| Cow::Borrowed(value.as_str()));
        Ok(header)
    }
}
impl<'r> WebResponse for OAuthResponse<'r> {
    type Error = WebError;

    /// Sets the status to `200 OK`.
    fn ok(&mut self) -> Result<(), Self::Error> {
        let response = &mut self.0;
        response.set_status(Status::Ok);
        Ok(())
    }

    /// Sets the status to `302 Found` and points `Location` at `url`.
    fn redirect(&mut self, url: Url) -> Result<(), Self::Error> {
        let response = &mut self.0;
        let location = header::Location(url.into());
        response.set_status(Status::Found);
        response.set_header(location);
        Ok(())
    }

    /// Sets the status to `400 Bad Request`.
    fn client_error(&mut self) -> Result<(), Self::Error> {
        let response = &mut self.0;
        response.set_status(Status::BadRequest);
        Ok(())
    }

    /// Sets the status to `401 Unauthorized` with `kind` as the `WWW-Authenticate` value.
    fn unauthorized(&mut self, kind: &str) -> Result<(), Self::Error> {
        let response = &mut self.0;
        let challenge = String::from(kind);
        response.set_status(Status::Unauthorized);
        response.set_raw_header("WWW-Authenticate", challenge);
        Ok(())
    }

    /// Replaces the body with `text`, labelled as plain text.
    fn body_text(&mut self, text: &str) -> Result<(), Self::Error> {
        let response = &mut self.0;
        let payload = Cursor::new(String::from(text));
        response.set_sized_body(payload);
        response.set_header(ContentType::Plain);
        Ok(())
    }

    /// Replaces the body with the serialized JSON in `data`, labelled as JSON.
    fn body_json(&mut self, data: &str) -> Result<(), Self::Error> {
        let response = &mut self.0;
        let payload = Cursor::new(String::from(data));
        response.set_sized_body(payload);
        response.set_header(ContentType::JSON);
        Ok(())
    }
}
impl<'a, 'r> FromRequest<'a, 'r> for OAuthRequest<'r> {
    type Error = NoError;

    /// Never fails. Problems with the query, content type or headers are stored inside
    /// the returned `OAuthRequest` and surface when its accessors are called.
    fn from_request(request: &'a Request<'r>) -> Outcome<Self, (Status, Self::Error), ()> {
        let extracted = Self::new(request);
        Outcome::Success(extracted)
    }
}
impl<'r> Responder<'r> for OAuthResponse<'r> {
fn respond_to(self, _: &Request) -> response::Result<'r> {
Ok(self.0)
}
}
impl<'r> Responder<'r> for WebError {
    /// Maps each error to a failing status.
    ///
    /// Malformed or non-form input is answered with `400`. A form body that was never
    /// added via `OAuthRequest::add_body` is answered with `500`.
    fn respond_to(self, _: &Request) -> response::Result<'r> {
        let status = match self {
            WebError::Encoding | WebError::NotAForm => Status::BadRequest,
            WebError::BodyNeeded => Status::InternalServerError,
        };
        Err(status)
    }
}
impl<'r> Default for OAuthResponse<'r> {
fn default() -> Self {
OAuthResponse(Default::default())
}
}
impl<'r> From<Response<'r>> for OAuthResponse<'r> {
fn from(r: Response<'r>) -> Self {
OAuthResponse::from_response(r)
}
}
/// Unwraps the underlying Rocket response.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket impl still
/// provides `Into<Response<'r>> for OAuthResponse<'r>`, so existing `.into()` calls keep
/// working. `Response::from(..)` becomes available as well.
impl<'r> From<OAuthResponse<'r>> for Response<'r> {
    fn from(wrapper: OAuthResponse<'r>) -> Self {
        wrapper.0
    }
}