pub mod fields;
use std::borrow::Cow;
use std::fmt::{Debug, Display};
use async_trait::async_trait;
use bytes::Bytes;
pub use cot_macros::Form;
use thiserror::Error;
use crate::headers::FORM_CONTENT_TYPE;
use crate::request;
use crate::request::{Request, RequestExt};
/// Failure to obtain form data from a request.
///
/// Field-level validation problems are not represented here; those surface as
/// [`FormResult::ValidationError`] from [`Form::from_request`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FormError {
    /// Reading the underlying request (e.g. its body or content type) failed.
    #[error("Request error: {error}")]
    RequestError {
        /// The boxed request-level error that caused the failure.
        #[from]
        error: Box<crate::Error>,
    },
}
/// Outcome of parsing and validating a form of type `T`.
#[must_use]
#[derive(Debug, Clone)]
pub enum FormResult<T: Form> {
    /// The form was parsed successfully into a `T`.
    Ok(T),
    /// Validation failed. The payload is the form's context, presumably holding the submitted
    /// values together with the collected errors (see [`FormContext`]).
    ValidationError(T::Context),
}
impl<T: Form> FormResult<T> {
    /// Returns the parsed form, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`FormResult::ValidationError`]. The panic message includes the
    /// `Debug` representation of the form context.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            Self::ValidationError(context) => panic!("Form validation failed: {context:?}"),
            Self::Ok(form) => form,
        }
    }
}
/// A validation error attached to a single form field (or to the whole form).
///
/// The `Display` output of each variant is the human-readable message for that error.
// Every variant carries its own `#[error(...)]` message. The former enum-level
// `#[error("{message}")]` referred to a `message` field that no variant has, so it was dead
// at best and has been removed.
#[derive(Debug, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum FormFieldValidationError {
    /// A required field failed the "required" check.
    #[error("This field is required.")]
    Required,
    /// The value is longer than `max_length`.
    #[error("This exceeds the maximum length of {max_length}.")]
    MaximumLengthExceeded {
        /// The maximum allowed length.
        max_length: u32,
    },
    /// The value is shorter than `min_length`.
    #[error("This is below the minimum length of {min_length}.")]
    MinimumLengthNotMet {
        /// The minimum required length.
        min_length: u32,
    },
    /// The value is smaller than the allowed minimum.
    #[error("This is below the minimum value of {min_value}.")]
    MinimumValueNotMet {
        /// The minimum allowed value, already rendered as text.
        min_value: String,
    },
    /// The value is larger than the allowed maximum.
    #[error("This exceeds the maximum value of {max_value}.")]
    MaximumValueExceeded {
        /// The maximum allowed value, already rendered as text.
        max_value: String,
    },
    /// A boolean field had to be `true` (checked) but was not.
    #[error("This field must be checked.")]
    BooleanRequiredToBeTrue,
    /// The value could not be accepted for this field. The payload holds the rejected input;
    /// it is not included in the displayed message.
    #[error("Value is not valid for this field.")]
    InvalidValue(String),
    /// A free-form error whose message is displayed verbatim.
    #[error("{0}")]
    Custom(Cow<'static, str>),
}
impl FormFieldValidationError {
#[must_use]
pub fn invalid_value<T: Into<String>>(value: T) -> Self {
Self::InvalidValue(value.into())
}
#[must_use]
pub fn maximum_length_exceeded(max_length: u32) -> Self {
Self::MaximumLengthExceeded { max_length }
}
#[must_use]
pub fn minimum_length_not_met(min_length: u32) -> Self {
FormFieldValidationError::MinimumLengthNotMet { min_length }
}
#[must_use]
pub fn minimum_value_not_met<T: Display>(min_value: T) -> Self {
FormFieldValidationError::MinimumValueNotMet {
min_value: min_value.to_string(),
}
}
#[must_use]
pub fn maximum_value_exceeded<T: Display>(max_value: T) -> Self {
FormFieldValidationError::MaximumValueExceeded {
max_value: max_value.to_string(),
}
}
#[must_use]
pub const fn from_string(message: String) -> Self {
Self::Custom(Cow::Owned(message))
}
#[must_use]
pub const fn from_static(message: &'static str) -> Self {
Self::Custom(Cow::Borrowed(message))
}
}
impl From<email_address::Error> for FormFieldValidationError {
fn from(error: email_address::Error) -> Self {
FormFieldValidationError::from_string(error.to_string())
}
}
/// Identifies where a validation error should be attached.
///
/// The enum only borrows a field identifier, so it is `Copy`. One target can therefore be
/// reused for several `add_error` calls and compared cheaply.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormErrorTarget<'a> {
    /// The field with the given identifier.
    Field(&'a str),
    /// The form as a whole, not tied to any single field.
    Form,
}
/// A type that can be parsed from, and rendered back into, an HTML form.
///
/// Usually implemented with `#[derive(Form)]` rather than by hand.
#[async_trait]
#[diagnostic::on_unimplemented(
    message = "`{Self}` does not implement the `Form` trait",
    label = "`{Self}` is not a form",
    note = "add #[derive(cot::form::Form)] to the struct to automatically derive the trait"
)]
pub trait Form: Sized {
    /// The context type that holds this form's fields and validation errors.
    type Context: FormContext;

    /// Parses and validates an instance of the form from `request`.
    ///
    /// # Errors
    ///
    /// Returns [`FormError`] when the request data itself cannot be obtained.
    async fn from_request(request: &mut Request) -> Result<FormResult<Self>, FormError>;

    /// Builds a context that represents the current values of `self`.
    fn to_context(&self) -> Self::Context;

    /// Builds a fresh context populated with the raw form data found in `request`.
    ///
    /// Each key/value pair is fed to [`FormContext::set_value`]. A pair that is rejected
    /// records its error against that field and does not abort the loop.
    ///
    /// # Errors
    ///
    /// Returns [`FormError::RequestError`] if [`form_data`] fails.
    async fn build_context(request: &mut Request) -> Result<Self::Context, FormError> {
        // `FormError` has a `#[from]` conversion for `Box<crate::Error>`, so boxing the error
        // is all `?` needs to produce `FormError::RequestError`.
        let raw = form_data(request).await.map_err(Box::new)?;

        let mut context = Self::Context::new();
        for (name, value) in request::query_pairs(&raw) {
            let name: &str = name.as_ref();
            if let Err(validation_error) = context.set_value(name, value) {
                context.add_error(FormErrorTarget::Field(name), validation_error);
            }
        }
        Ok(context)
    }
}
/// Extracts the raw URL-encoded form data from `request`.
///
/// For `GET` and `HEAD` requests the data is the URI query string, or empty bytes if there is
/// none. For every other method the content type is checked with `expect_content_type`
/// against [`FORM_CONTENT_TYPE`]. The body is then moved out of the request (leaving a
/// default body behind) and read fully.
///
/// # Errors
///
/// Returns an error if `expect_content_type` rejects the request or reading the body fails.
pub async fn form_data(request: &mut Request) -> crate::Result<Bytes> {
    let reads_from_query =
        request.method() == http::Method::GET || request.method() == http::Method::HEAD;

    if !reads_from_query {
        request.expect_content_type(FORM_CONTENT_TYPE)?;
        let body = std::mem::take(request.body_mut());
        return Ok(body.into_bytes().await?);
    }

    Ok(request
        .uri()
        .query()
        .map_or_else(Bytes::new, |query| Bytes::copy_from_slice(query.as_bytes())))
}
/// Runtime state of a form: its fields, their raw values, and any validation errors.
pub trait FormContext: Debug {
    /// Creates an empty context.
    fn new() -> Self
    where
        Self: Sized;

    /// Iterates over the form's fields as type-erased [`DynFormField`]s.
    fn fields(&self) -> Box<dyn DoubleEndedIterator<Item = &dyn DynFormField> + '_>;

    /// Stores the raw `value` for the field identified by `field_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`FormFieldValidationError`] when the implementation rejects the value.
    /// The exact conditions are implementation-defined.
    fn set_value(
        &mut self,
        field_id: &str,
        value: Cow<'_, str>,
    ) -> Result<(), FormFieldValidationError>;

    /// Appends `error` to the error list of `target`.
    fn add_error(&mut self, target: FormErrorTarget<'_>, error: FormFieldValidationError) {
        let errors = self.errors_for_mut(target);
        errors.push(error);
    }

    /// Returns the errors recorded for `target`.
    fn errors_for(&self, target: FormErrorTarget<'_>) -> &[FormFieldValidationError];

    /// Returns a mutable handle to the error list of `target`.
    fn errors_for_mut(&mut self, target: FormErrorTarget<'_>)
        -> &mut Vec<FormFieldValidationError>;

    /// Returns `true` if any error has been recorded in this context.
    fn has_errors(&self) -> bool;
}
/// Options shared by every form field.
///
/// `Clone`, `PartialEq` and `Eq` are derived so options can be duplicated (e.g. to build
/// several fields from one template) and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FormFieldOptions {
    /// The field's identifier within the form.
    pub id: String,
    /// The field's name, presumably its human-readable label (confirm against renderers).
    pub name: String,
    /// Whether the field is marked as required.
    pub required: bool,
}
/// A single field of a form. Its `Display` impl renders it.
pub trait FormField: Display {
    /// Field-type-specific options beyond the common [`FormFieldOptions`].
    type CustomOptions: Default;

    /// Constructs the field from its common and custom options.
    fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self
    where
        Self: Sized;

    /// Returns the common options this field was created with.
    fn options(&self) -> &FormFieldOptions;

    /// Returns the field's identifier, taken from [`FormFieldOptions::id`].
    fn id(&self) -> &str {
        self.options().id.as_str()
    }

    /// Returns the field's name, taken from [`FormFieldOptions::name`].
    fn name(&self) -> &str {
        self.options().name.as_str()
    }

    /// Returns the current raw value, if any.
    fn value(&self) -> Option<&str>;

    /// Replaces the field's raw value.
    fn set_value(&mut self, value: Cow<'_, str>);
}
/// Object-safe counterpart of [`FormField`], usable as `&dyn DynFormField`.
///
/// It is implemented automatically for every [`FormField`].
pub trait DynFormField: Display {
    /// Returns the field's common options.
    fn dyn_options(&self) -> &FormFieldOptions;

    /// Returns the field's identifier.
    fn dyn_id(&self) -> &str;

    /// Returns the field's current raw value, if any.
    fn dyn_value(&self) -> Option<&str>;

    /// Replaces the field's raw value.
    fn dyn_set_value(&mut self, value: Cow<'_, str>);
}
/// Blanket bridge from the statically-typed [`FormField`] API to the object-safe
/// [`DynFormField`] one. Every method simply forwards.
impl<T: FormField> DynFormField for T {
    fn dyn_options(&self) -> &FormFieldOptions {
        <T as FormField>::options(self)
    }

    fn dyn_id(&self) -> &str {
        <T as FormField>::id(self)
    }

    fn dyn_value(&self) -> Option<&str> {
        <T as FormField>::value(self)
    }

    fn dyn_set_value(&mut self, value: Cow<'_, str>) {
        <T as FormField>::set_value(self, value);
    }
}
/// Links a Rust value type to the form field type used to edit it.
pub trait AsFormField {
    /// The form field type that represents `Self`.
    type Type: FormField;

    /// Creates a new field of type [`Self::Type`] from the given options.
    fn new_field(
        options: FormFieldOptions,
        custom_options: <Self::Type as FormField>::CustomOptions,
    ) -> Self::Type {
        <Self::Type as FormField>::with_options(options, custom_options)
    }

    /// Converts the field's current state into a `Self`.
    ///
    /// # Errors
    ///
    /// Returns a [`FormFieldValidationError`] if the field's value is not acceptable.
    fn clean_value(field: &Self::Type) -> Result<Self, FormFieldValidationError>
    where
        Self: Sized;

    /// Renders `self` as the raw string value to place in a field.
    fn to_field_value(&self) -> String;
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::Body;
    use crate::form::form_data;
    use crate::headers::FORM_CONTENT_TYPE;
    use crate::request::Request;

    /// Builds a body-less `GET` request targeting `uri`.
    fn get_request(uri: &str) -> Request {
        http::Request::builder()
            .method(http::Method::GET)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    #[cot::test]
    async fn form_data_extract_get_empty() {
        let mut request = get_request("https://example.com");

        let extracted = form_data(&mut request).await.unwrap();

        assert_eq!(extracted, Bytes::from_static(b""));
    }

    #[cot::test]
    async fn form_data_extract_get() {
        let mut request = get_request("https://example.com/?hello=world");

        let extracted = form_data(&mut request).await.unwrap();

        assert_eq!(extracted, Bytes::from_static(b"hello=world"));
    }

    #[cot::test]
    async fn form_data_extract() {
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, FORM_CONTENT_TYPE)
            .body(Body::fixed("hello=world"))
            .unwrap();

        let extracted = form_data(&mut request).await.unwrap();

        assert_eq!(extracted, Bytes::from_static(b"hello=world"));
    }
}