use std::ops::{Deref, DerefMut};
use std::{error, fmt, io};
use rocket::data::{Data, FromData, Limits, Outcome};
use rocket::error_;
use rocket::form::prelude as form;
use rocket::http::uri::fmt::{Formatter as UriFormatter, FromUriParam, Query, UriDisplay};
use rocket::http::{ContentType, Status};
use rocket::request::{local_cache, Request};
use rocket::response::{self, content, Responder};
use serde::{Deserialize, Serialize};
/// Newtype wrapper whose inner value is (de)serialized with `serde_urlencoded`.
///
/// This module implements, for `UrlEncoded<T>`: a request-body guard
/// (`FromData`), a form field (`FromFormField`), a `Responder` that emits
/// `ContentType::Form`, and a `UriDisplay<Query>` for use in query strings.
/// The wrapped value is reachable via `.0`, `into_inner`, or `Deref`/`DerefMut`.
#[repr(transparent)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UrlEncoded<T>(pub T);
/// Error produced while reading or deserializing a [`UrlEncoded`] value.
#[derive(Debug)]
pub enum Error<'a> {
/// Reading the body failed. A body that exceeded the configured data limit
/// is also reported here, as an `io::ErrorKind::UnexpectedEof` error.
Io(io::Error),
/// Deserialization failed; holds the raw input string together with the
/// underlying `serde_urlencoded` error.
Parse(&'a str, ::serde_urlencoded::de::Error),
}
impl<'a> fmt::Display for Error<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Io(err) => write!(f, "i/o error: {}", err),
Self::Parse(_, err) => write!(f, "parse error: {}", err),
}
}
}
impl<'a> error::Error for Error<'a> {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Self::Io(err) => Some(err),
Self::Parse(_, err) => Some(err),
}
}
}
impl<T> UrlEncoded<T> {
#[inline(always)]
pub fn into_inner(self) -> T {
self.0
}
}
impl<'r, T: Deserialize<'r>> UrlEncoded<T> {
fn from_str(s: &'r str) -> Result<Self, Error<'r>> {
::serde_urlencoded::from_str(s)
.map(UrlEncoded)
.map_err(|e| Error::Parse(s, e))
}
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Result<Self, Error<'r>> {
let limit = req.limits().get("form").unwrap_or(Limits::FORM);
let string = match data.open(limit).into_string().await {
Ok(s) if s.is_complete() => s.into_inner(),
Ok(_) => {
let eof = io::ErrorKind::UnexpectedEof;
return Err(Error::Io(io::Error::new(eof, "data limit exceeded")));
}
Err(e) => return Err(Error::Io(e)),
};
Self::from_str(local_cache!(req, string))
}
}
#[rocket::async_trait]
impl<'r, T: Deserialize<'r>> FromData<'r> for UrlEncoded<T> {
type Error = Error<'r>;
async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> {
match Self::from_data(req, data).await {
Ok(value) => Outcome::Success(value),
Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
Outcome::Failure((Status::PayloadTooLarge, Error::Io(e)))
}
Err(Error::Parse(s, e)) => {
error_!("{:?}", e);
Outcome::Failure((Status::UnprocessableEntity, Error::Parse(s, e)))
}
Err(e) => Outcome::Failure((Status::BadRequest, e)),
}
}
}
impl<'r, T: Serialize> Responder<'r, 'static> for UrlEncoded<T> {
    /// Serializes the inner value and responds with it as a
    /// `ContentType::Form` body.
    ///
    /// A serialization failure is logged and answered with
    /// `InternalServerError`.
    fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> {
        match ::serde_urlencoded::to_string(&self.0) {
            Ok(body) => content::Custom(ContentType::Form, body).respond_to(req),
            Err(e) => {
                error_!("UrlEncoding failed to serialize: {:?}", e);
                Err(Status::InternalServerError)
            }
        }
    }
}
impl<T: Serialize> UriDisplay<Query> for UrlEncoded<T> {
    /// Serializes the inner value with `serde_urlencoded` and hands the
    /// resulting string to `write_value` as one query value.
    ///
    /// A serialization failure is reported as `fmt::Error`.
    fn fmt(&self, f: &mut UriFormatter<'_, Query>) -> fmt::Result {
        match ::serde_urlencoded::to_string(&self.0) {
            Ok(encoded) => f.write_value(&encoded),
            Err(_) => Err(fmt::Error),
        }
    }
}
// Generates `FromUriParam<Query, $T> for UrlEncoded<T>`. `$T` is the argument
// type a caller passes to `uri!`. The generated impl wraps that argument in
// `UrlEncoded`, so the target type is `UrlEncoded<$T>`. The optional lifetime
// is declared on the impl for the reference cases.
macro_rules! impl_from_uri_param_from_inner_type {
($($lt:lifetime)?, $T:ty) => (
impl<$($lt,)? T: Serialize> FromUriParam<Query, $T> for UrlEncoded<T> {
type Target = UrlEncoded<$T>;
#[inline(always)]
fn from_uri_param(param: $T) -> Self::Target {
UrlEncoded(param)
}
}
)
}
// Accept the bare inner value: by value, by shared reference, and by
// mutable reference.
impl_from_uri_param_from_inner_type!(, T);
impl_from_uri_param_from_inner_type!('a, &'a T);
impl_from_uri_param_from_inner_type!('a, &'a mut T);
// Presumably lets a `UrlEncoded<T>` itself be passed where a `UrlEncoded<T>`
// query parameter is expected (identity conversions) — see Rocket's
// `impl_from_uri_param_identity!` for the exact impls generated.
rocket::http::impl_from_uri_param_identity!([Query] (T: Serialize) UrlEncoded<T>);
impl<T> From<T> for UrlEncoded<T> {
fn from(value: T) -> Self {
UrlEncoded(value)
}
}
impl<T> Deref for UrlEncoded<T> {
type Target = T;
#[inline(always)]
fn deref(&self) -> &T {
&self.0
}
}
impl<T> DerefMut for UrlEncoded<T> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut T {
&mut self.0
}
}
impl From<Error<'_>> for form::Error<'_> {
fn from(e: Error<'_>) -> Self {
match e {
Error::Io(e) => e.into(),
Error::Parse(_, e) => form::Error::custom(e),
}
}
}
#[rocket::async_trait]
impl<'v, T: Deserialize<'v> + Send> form::FromFormField<'v> for UrlEncoded<T> {
    /// Parses a text form field's value as url-encoded data.
    fn from_value(field: form::ValueField<'v>) -> Result<Self, form::Errors<'v>> {
        let parsed = Self::from_str(field.value)?;
        Ok(parsed)
    }

    /// Reads a streamed form field (capped by the `form` limit) and parses
    /// it as url-encoded data.
    async fn from_data(f: form::DataField<'v, '_>) -> Result<Self, form::Errors<'v>> {
        let form::DataField { request, data, .. } = f;
        let parsed = Self::from_data(request, data).await?;
        Ok(parsed)
    }
}
/// Deserializes a `T` from url-encoded bytes.
///
/// Thin wrapper over `serde_urlencoded::from_bytes`; `T` may borrow from
/// `slice`.
///
/// # Errors
/// Returns the `serde_urlencoded` error if the input cannot be deserialized.
#[inline(always)]
pub fn from_slice<'a, T>(slice: &'a [u8]) -> Result<T, ::serde_urlencoded::de::Error>
where
    T: Deserialize<'a>,
{
    let value: T = ::serde_urlencoded::from_bytes(slice)?;
    Ok(value)
}
/// Deserializes a `T` from a url-encoded string.
///
/// Thin wrapper over `serde_urlencoded::from_str`; `T` may borrow from
/// `string`.
///
/// # Errors
/// Returns the `serde_urlencoded` error if the input cannot be deserialized.
#[inline(always)]
pub fn from_str<'a, T>(string: &'a str) -> Result<T, ::serde_urlencoded::de::Error>
where
    T: Deserialize<'a>,
{
    let value: T = ::serde_urlencoded::from_str(string)?;
    Ok(value)
}