use std::error::Error as StdError;
use std::fmt;
use rocket::request::{self, FromRequest, Request};
use rocket::Outcome;
use rocket::http::Status;
use Digest;
use error::Error as DigestError;
/// Errors produced while extracting header values from a request.
#[derive(Clone, Debug)]
pub enum Error {
/// The header value was found but parsing it as a `Digest` failed.
Digest(DigestError),
/// A header was missing or unusable; the payload is the header's name
/// (in this file: `"Digest"` or `"Content-Length"`).
Header(&'static str),
}
impl Error {
pub fn into_rocket_failure(self) -> (Status, Self) {
match self {
Error::Digest(digest) => (Status::BadRequest, Error::Digest(digest)),
Error::Header(which) => (Status::BadRequest, Error::Header(which)),
}
}
}
impl From<DigestError> for Error {
fn from(e: DigestError) -> Self {
Error::Digest(e)
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Error::Digest(ref e) => write!(f, "Digest: {}", e),
_ => write!(f, "{}", self.description()),
}
}
}
impl StdError for Error {
fn description(&self) -> &str {
match *self {
Error::Digest(ref e) => e.description(),
Error::Header(_) => "Expected exactly one header",
}
}
fn cause(&self) -> Option<&StdError> {
match *self {
Error::Digest(ref e) => Some(e),
_ => None,
}
}
}
/// Request guard wrapping the `Digest` parsed from a request's `Digest` header.
pub struct DigestHeader(pub Digest);
impl DigestHeader {
pub fn new(digest: Digest) -> Self {
DigestHeader(digest)
}
pub fn into_inner(self) -> Digest {
self.0
}
}
impl<'a, 'r> FromRequest<'a, 'r> for DigestHeader {
type Error = Error;
fn from_request(request: &'a Request<'r>) -> request::Outcome<DigestHeader, Self::Error> {
let res = request
.headers()
.get_one("Digest")
.ok_or(Error::Header("Digest"))
.and_then(|raw_header| raw_header.parse::<Digest>().map_err(Error::from))
.map(DigestHeader);
match res {
Ok(success) => Outcome::Success(success),
Err(error) => Outcome::Failure(error.into_rocket_failure()),
}
}
}
/// Request guard wrapping the `Content-Length` header value parsed as a `usize`.
pub struct ContentLengthHeader(pub usize);
impl<'a, 'r> FromRequest<'a, 'r> for ContentLengthHeader {
type Error = Error;
fn from_request(
request: &'a Request<'r>,
) -> request::Outcome<ContentLengthHeader, Self::Error> {
let res = request
.headers()
.get_one("Content-Length")
.ok_or(Error::Header("Content-Length"))
.and_then(|raw_header| {
raw_header
.parse::<usize>()
.map_err(|_| Error::Header("Content-Length"))
})
.map(ContentLengthHeader);
match res {
Ok(success) => Outcome::Success(success),
Err(error) => Outcome::Failure(error.into_rocket_failure()),
}
}
}