mod field_value;
pub mod fields;
use std::borrow::Cow;
use std::fmt::Display;
use async_trait::async_trait;
use bytes::Bytes;
use chrono::NaiveDateTime;
use chrono_tz::Tz;
use cot_core::error::impl_into_cot_error;
use cot_core::headers::{MULTIPART_FORM_CONTENT_TYPE, URLENCODED_FORM_CONTENT_TYPE};
pub use cot_macros::Form;
use derive_more::with_trait::Debug;
pub use field_value::{FormFieldValue, FormFieldValueError};
use http_body_util::BodyExt;
use thiserror::Error;
use crate::request::{Request, RequestExt};
/// Prefix shared by the `Display` messages of every [`FormError`] variant.
const ERROR_PREFIX: &str = "failed to process a form:";
/// An error that occurred while extracting form data from a request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FormError {
    /// The request could not be read, or it does not carry form data at all.
    #[error("{ERROR_PREFIX} request error: {error}")]
    #[non_exhaustive]
    RequestError {
        /// The underlying request error (boxed).
        #[from]
        error: Box<crate::Error>,
    },
    /// A form field value could not be obtained, e.g. an unparsable multipart boundary
    /// or a multipart field without a name.
    #[error("{ERROR_PREFIX} multipart error: {error}")]
    #[non_exhaustive]
    MultipartError {
        /// The underlying field value error.
        #[from]
        error: FormFieldValueError,
    },
}
// Registers the conversion of `FormError` into the crate's error type with the
// `BAD_REQUEST` status (presumably HTTP 400 — confirm in `impl_into_cot_error!`).
impl_into_cot_error!(FormError, BAD_REQUEST);
/// The outcome of parsing a form from a request.
///
/// Either the fully parsed form, or — when validation failed — the form's context,
/// which holds the fields and the recorded errors (see [`FormContext`]).
#[must_use]
#[derive(Debug, Clone)]
pub enum FormResult<T: Form> {
    /// The form was parsed and validated successfully.
    Ok(T),
    /// Validation failed; the context describes what went wrong.
    ValidationError(T::Context),
}

impl<T: Form> FormResult<T> {
    /// Returns the parsed form.
    ///
    /// # Panics
    ///
    /// Panics with the `Debug` rendering of the context if this is a
    /// [`FormResult::ValidationError`]. The panic location is reported at the caller.
    #[track_caller]
    pub fn unwrap(self) -> T {
        match self {
            Self::ValidationError(ctx) => panic!("Form validation failed: {ctx:?}"),
            Self::Ok(value) => value,
        }
    }
}
/// An error produced while validating the value of a single form field.
///
/// Every variant declares its own user-facing message via `#[error(...)]`; there is
/// deliberately no enum-level fallback message, so a new variant without one is
/// rejected by `thiserror` with a clear "missing display attribute" error.
#[derive(Debug, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum FormFieldValidationError {
    /// No value was provided for a required field.
    #[error("This field is required.")]
    Required,
    /// The value is longer than the allowed maximum length.
    #[error("This exceeds the maximum length of {max_length}.")]
    MaximumLengthExceeded {
        /// The maximum allowed length.
        max_length: u32,
    },
    /// The value is shorter than the required minimum length.
    #[error("This is below the minimum length of {min_length}.")]
    MinimumLengthNotMet {
        /// The minimum required length.
        min_length: u32,
    },
    /// The value is below the allowed minimum.
    #[error("This is below the minimum value of {min_value}.")]
    MinimumValueNotMet {
        /// The minimum value, already rendered as a string.
        min_value: String,
    },
    /// The value is above the allowed maximum.
    #[error("This exceeds the maximum value of {max_value}.")]
    MaximumValueExceeded {
        /// The maximum value, already rendered as a string.
        max_value: String,
    },
    /// A local datetime maps to more than one instant.
    #[error("The datetime value `{datetime}` is ambiguous.")]
    AmbiguousDateTime {
        /// The ambiguous local datetime.
        datetime: NaiveDateTime,
    },
    /// A local datetime does not exist in the given timezone.
    #[error("Local datetime {datetime} does not exist for the specified timezone {timezone}.")]
    NonExistentLocalDateTime {
        /// The non-existent local datetime.
        datetime: NaiveDateTime,
        /// The timezone the datetime was interpreted in.
        timezone: Tz,
    },
    /// A boolean field had to be checked but was not.
    #[error("This field must be checked.")]
    BooleanRequiredToBeTrue,
    /// The value is not valid for this field; carries the offending value.
    #[error("Value is not valid for this field.")]
    InvalidValue(String),
    /// Reading the submitted field value failed.
    #[error("Error getting field value: {0}")]
    FormFieldValueError(#[from] FormFieldValueError),
    /// A custom, caller-supplied message.
    #[error("{0}")]
    Custom(Cow<'static, str>),
}
impl FormFieldValidationError {
#[must_use]
pub fn invalid_value<T: Into<String>>(value: T) -> Self {
Self::InvalidValue(value.into())
}
#[must_use]
pub fn maximum_length_exceeded(max_length: u32) -> Self {
Self::MaximumLengthExceeded { max_length }
}
#[must_use]
pub fn minimum_length_not_met(min_length: u32) -> Self {
FormFieldValidationError::MinimumLengthNotMet { min_length }
}
#[must_use]
pub fn minimum_value_not_met<T: Display>(min_value: T) -> Self {
FormFieldValidationError::MinimumValueNotMet {
min_value: min_value.to_string(),
}
}
#[must_use]
pub fn maximum_value_exceeded<T: Display>(max_value: T) -> Self {
FormFieldValidationError::MaximumValueExceeded {
max_value: max_value.to_string(),
}
}
#[must_use]
pub fn ambiguous_datetime(datetime: NaiveDateTime) -> Self {
FormFieldValidationError::AmbiguousDateTime { datetime }
}
#[must_use]
pub fn non_existent_local_datetime(datetime: NaiveDateTime, timezone: Tz) -> Self {
FormFieldValidationError::NonExistentLocalDateTime { datetime, timezone }
}
#[must_use]
pub const fn from_string(message: String) -> Self {
Self::Custom(Cow::Owned(message))
}
#[must_use]
pub const fn from_static(message: &'static str) -> Self {
Self::Custom(Cow::Borrowed(message))
}
}
/// Identifies what a validation error is attached to within a form context.
///
/// The type only holds a borrowed field ID, so it is `Copy`. The same target can be
/// passed to several calls (e.g. reading errors and then adding one) without
/// reconstructing it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormErrorTarget<'a> {
    /// A single field, identified by its field ID.
    Field(&'a str),
    /// The form as a whole rather than any particular field.
    Form,
}
#[async_trait]
#[diagnostic::on_unimplemented(
message = "`{Self}` does not implement the `Form` trait",
label = "`{Self}` is not a form",
note = "add #[derive(cot::form::Form)] to the struct to automatically derive the trait"
)]
pub trait Form: Sized {
type Context: FormContext + Send;
async fn from_request(request: &mut Request) -> Result<FormResult<Self>, FormError>;
async fn to_context(&self) -> Self::Context;
async fn build_context(request: &mut Request) -> Result<Self::Context, FormError> {
let mut context = Self::Context::new();
let mut form_data = form_data(request).await?;
while let Some((field_id, value)) = form_data.next_value().await? {
if let Err(err) = context.set_value(&field_id, value).await {
context.add_error(FormErrorTarget::Field(&field_id), err);
}
}
Ok(context)
}
}
/// Creates a reader over the form data submitted with `request`.
///
/// Requests whose media type is `multipart/form-data` are read as a streaming
/// multipart body. Everything else is delegated to [`urlencoded_form_data`] and
/// parsed as `application/x-www-form-urlencoded` data.
///
/// # Errors
///
/// Returns an error if the multipart boundary cannot be parsed, or if the urlencoded
/// payload cannot be extracted (see [`urlencoded_form_data`]).
async fn form_data(request: &mut Request) -> Result<FormData<'_>, FormError> {
    // Only the media type before any parameters (such as `boundary=...`) identifies the
    // format, and its type/subtype tokens are case-insensitive (RFC 9110 §8.3.1). A plain
    // prefix check would reject `Multipart/Form-Data` and accept `multipart/form-dataX`.
    let is_multipart = content_type_str(request)
        .split(';')
        .next()
        .unwrap_or_default()
        .trim()
        .eq_ignore_ascii_case(MULTIPART_FORM_CONTENT_TYPE);

    let form_data = if is_multipart {
        FormData::Multipart {
            inner: multipart_form_data(request)?,
        }
    } else {
        FormData::new_urlencoded(urlencoded_form_data(request).await?)
    };
    Ok(form_data)
}
/// Turns the body of a `multipart/form-data` request into a streaming
/// [`multer::Multipart`] reader.
///
/// The boundary is parsed from the request's content type. The body is moved out of
/// the request with [`std::mem::take`], leaving the body type's default value in its
/// place.
///
/// # Errors
///
/// Returns [`FormError::MultipartError`] when no valid boundary can be parsed from the
/// content type.
fn multipart_form_data(request: &mut Request) -> Result<multer::Multipart<'_>, FormError> {
    let boundary = multer::parse_boundary(content_type_str(request))
        .map_err(FormFieldValueError::from_multer)?;
    let stream = std::mem::take(request.body_mut()).into_data_stream();
    Ok(multer::Multipart::new(stream, boundary))
}
async fn urlencoded_form_data(request: &mut Request) -> Result<Bytes, FormError> {
let result = if request.method() == http::Method::GET || request.method() == http::Method::HEAD
{
if let Some(query) = request.uri().query() {
Bytes::copy_from_slice(query.as_bytes())
} else {
Bytes::new()
}
} else if content_type_str(request) == URLENCODED_FORM_CONTENT_TYPE {
let body = std::mem::take(request.body_mut());
body.into_bytes()
.await
.map_err(|e| FormError::RequestError { error: Box::new(e) })?
} else {
return Err(FormError::RequestError {
error: Box::new(crate::Error::from(ExpectedForm)),
});
};
Ok(result)
}
/// Error raised when a request cannot be interpreted as a form submission.
///
/// Produced by `urlencoded_form_data` for a request that is neither `GET`/`HEAD` nor
/// carries the urlencoded form content type.
// NOTE(review): the message says "a POST request", but any method other than
// GET/HEAD is accepted when the content type matches (e.g. PUT/PATCH) — consider
// rewording. The message is user-facing and asserted (by substring) in tests.
#[derive(Debug, Error)]
#[error(
    "request does not contain a form (expected a POST request with \
the `application/x-www-form-urlencoded` or `multipart/form-data` content type, \
or a GET or HEAD request)"
)]
struct ExpectedForm;
// Registers the conversion into the crate's error type with the `BAD_REQUEST` status
// (presumably HTTP 400 — confirm in `impl_into_cot_error!`).
impl_into_cot_error!(ExpectedForm, BAD_REQUEST);
/// Returns the request's content type (as reported by `RequestExt::content_type`)
/// as an owned string, or an empty string when there is none.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD (lossy decoding), so this
/// never fails.
fn content_type_str(request: &mut Request) -> String {
    match request.content_type() {
        Some(header) => String::from_utf8_lossy(header.as_bytes()).into_owned(),
        None => String::new(),
    }
}
#[derive(Debug)]
enum FormData<'a> {
Form {
#[debug("..")]
inner: form_urlencoded::Parse<'a>,
_data: Bytes,
},
Multipart {
inner: multer::Multipart<'a>,
},
}
impl<'a> FormData<'a> {
fn new_urlencoded(data: Bytes) -> Self {
#[expect(unsafe_code)]
let slice = unsafe {
std::slice::from_raw_parts(data.as_ptr(), data.len())
};
FormData::Form {
inner: form_urlencoded::parse(slice),
_data: data,
}
}
async fn next_value(
&mut self,
) -> Result<Option<(String, FormFieldValue<'a>)>, FormFieldValueError> {
match self {
FormData::Form { inner, .. } => Ok(inner
.next()
.map(|(key, value)| (key.into_owned(), FormFieldValue::new_text(value)))),
FormData::Multipart { inner } => {
let next_field = inner.next_field().await;
match next_field {
Ok(Some(field)) => {
let name = field
.name()
.ok_or_else(FormFieldValueError::no_name)?
.to_owned();
let value = FormFieldValue::new_multipart(field);
Ok(Some((name, value)))
}
Ok(None) => Ok(None),
Err(err) => Err(FormFieldValueError::from_multer(err)),
}
}
}
}
}
/// Per-form state: the form's fields together with the validation errors recorded
/// for the fields and for the form as a whole.
#[async_trait]
pub trait FormContext: Debug {
    /// Creates an empty context.
    fn new() -> Self
    where
        Self: Sized;

    /// Returns an iterator over the form's fields. The iterator is double-ended, so it
    /// can be walked in reverse.
    fn fields(&self) -> Box<dyn DoubleEndedIterator<Item = &dyn DynFormField> + '_>;

    /// Stores a submitted `value` for the field identified by `field_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`FormFieldValidationError`] if the value is rejected.
    async fn set_value(
        &mut self,
        field_id: &str,
        value: FormFieldValue<'_>,
    ) -> Result<(), FormFieldValidationError>;

    /// Returns the errors recorded for `target`.
    fn errors_for(&self, target: FormErrorTarget<'_>) -> &[FormFieldValidationError];

    /// Returns the mutable list of errors recorded for `target`.
    fn errors_for_mut(&mut self, target: FormErrorTarget<'_>)
    -> &mut Vec<FormFieldValidationError>;

    /// Records `error` against `target`.
    ///
    /// The default implementation appends it to the list returned by
    /// [`errors_for_mut`](Self::errors_for_mut).
    fn add_error(&mut self, target: FormErrorTarget<'_>, error: FormFieldValidationError) {
        let errors = self.errors_for_mut(target);
        errors.push(error);
    }

    /// Returns `true` if any errors have been recorded.
    fn has_errors(&self) -> bool;
}
/// Options common to every form field, passed to `FormField::with_options`.
#[derive(Debug)]
pub struct FormFieldOptions {
    /// The field's ID; returned by `FormField::id`.
    pub id: String,
    /// The field's name; returned by `FormField::name`.
    pub name: String,
    /// Whether the field must be filled in (presumably enforced by field
    /// implementations via `FormFieldValidationError::Required` — confirm).
    pub required: bool,
}
/// A single form field.
///
/// Implementors must also implement [`Display`]. Presumably this renders the field
/// (e.g. as HTML); confirm against the field implementations.
pub trait FormField: Display {
    /// Field-type-specific options passed to [`with_options`](Self::with_options).
    type CustomOptions: Default;

    /// Creates a field from its common options and its field-type-specific options.
    fn with_options(options: FormFieldOptions, custom_options: Self::CustomOptions) -> Self
    where
        Self: Sized;

    /// Returns the common options this field was created with.
    fn options(&self) -> &FormFieldOptions;

    /// Returns the field's ID, i.e. [`FormFieldOptions::id`].
    fn id(&self) -> &str {
        self.options().id.as_str()
    }

    /// Returns the field's name, i.e. [`FormFieldOptions::name`].
    fn name(&self) -> &str {
        self.options().name.as_str()
    }

    /// Returns the field's current value as a string, if it has one.
    fn value(&self) -> Option<&str>;

    /// Sets the field's value from a submitted form value. The returned future is `Send`.
    ///
    /// # Errors
    ///
    /// Returns a [`FormFieldValueError`] if the submitted value cannot be obtained.
    fn set_value(
        &mut self,
        field: FormFieldValue<'_>,
    ) -> impl Future<Output = Result<(), FormFieldValueError>> + Send;
}
/// Dyn-compatible counterpart of `FormField`, used where fields are handled as
/// `&dyn DynFormField` (e.g. `FormContext::fields`).
///
/// Implemented for every `T: FormField + Send` by a blanket implementation; each
/// `dyn_*` method mirrors the `FormField` method of the same name.
#[async_trait]
pub trait DynFormField: Display {
    /// Returns the field's common options.
    fn dyn_options(&self) -> &FormFieldOptions;
    /// Returns the field's ID.
    fn dyn_id(&self) -> &str;
    /// Returns the field's current value, if any.
    fn dyn_value(&self) -> Option<&str>;
    /// Sets the field's value from a submitted form value.
    async fn dyn_set_value(&mut self, field: FormFieldValue<'_>)
    -> Result<(), FormFieldValueError>;
}
/// Blanket implementation that forwards every [`DynFormField`] method to the type's
/// [`FormField`] implementation.
///
/// Calls use fully-qualified syntax so they always resolve to `FormField`, even if
/// other traits with same-named methods are in scope.
#[async_trait]
impl<T: FormField + Send> DynFormField for T {
    fn dyn_options(&self) -> &FormFieldOptions {
        <T as FormField>::options(self)
    }

    fn dyn_id(&self) -> &str {
        <T as FormField>::id(self)
    }

    fn dyn_value(&self) -> Option<&str> {
        <T as FormField>::value(self)
    }

    async fn dyn_set_value(
        &mut self,
        field: FormFieldValue<'_>,
    ) -> Result<(), FormFieldValueError> {
        <T as FormField>::set_value(self, field).await
    }
}
/// A value type that can be edited through a form field.
///
/// Links a Rust type to the [`FormField`] used to edit it, and converts between the
/// two representations.
pub trait AsFormField {
    /// The form field type used for this value.
    type Type: FormField;

    /// Creates the field for this value type.
    ///
    /// The default implementation forwards to [`FormField::with_options`].
    fn new_field(
        options: FormFieldOptions,
        custom_options: <Self::Type as FormField>::CustomOptions,
    ) -> Self::Type {
        <Self::Type as FormField>::with_options(options, custom_options)
    }

    /// Converts the field's current state into a value of this type.
    ///
    /// # Errors
    ///
    /// Returns a [`FormFieldValidationError`] if the field's value is not a valid `Self`.
    fn clean_value(field: &Self::Type) -> Result<Self, FormFieldValidationError>
    where
        Self: Sized;

    /// Renders this value as a field value string.
    fn to_field_value(&self) -> String;
}
// Unit tests for extracting form data from requests. They cover:
// - query strings for GET/HEAD;
// - urlencoded bodies;
// - multipart bodies, with and without file parts;
// - rejection of non-form content types.
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use cot_core::headers::{MULTIPART_FORM_CONTENT_TYPE, URLENCODED_FORM_CONTENT_TYPE};

    use super::*;
    use crate::Body;

    // A GET request without a query string yields an empty payload.
    #[cot::test]
    async fn urlencoded_form_data_extract_get_empty() {
        let mut request = http::Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com")
            .body(Body::empty())
            .unwrap();

        let bytes = urlencoded_form_data(&mut request).await.unwrap();
        assert_eq!(bytes, Bytes::from_static(b""));
    }

    // For GET, the payload is the raw query string.
    #[cot::test]
    async fn urlencoded_form_data_extract_get() {
        let mut request = http::Request::builder()
            .method(http::Method::GET)
            .uri("https://example.com/?hello=world")
            .body(Body::empty())
            .unwrap();

        let bytes = urlencoded_form_data(&mut request).await.unwrap();
        assert_eq!(bytes, Bytes::from_static(b"hello=world"));
    }

    // HEAD is treated like GET: the payload comes from the query string.
    #[cot::test]
    async fn urlencoded_form_data_extract_head() {
        let mut request = http::Request::builder()
            .method(http::Method::HEAD)
            .uri("https://example.com/?hello=world")
            .body(Body::empty())
            .unwrap();

        let bytes = urlencoded_form_data(&mut request).await.unwrap();
        assert_eq!(bytes, Bytes::from_static(b"hello=world"));
    }

    // A POST with the urlencoded content type yields the request body verbatim.
    #[cot::test]
    async fn urlencoded_form_data_extract_urlencoded() {
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, URLENCODED_FORM_CONTENT_TYPE)
            .body(Body::fixed("hello=world"))
            .unwrap();

        let result = urlencoded_form_data(&mut request).await.unwrap();
        assert_eq!(result, Bytes::from_static(b"hello=world"));
    }

    // Text-only multipart parts are returned in order, with their names and contents.
    #[cot::test]
    async fn form_data_extract_multipart() {
        let boundary = "boundary";
        // Line continuations (`\` at end of line) keep the CRLF-delimited body on one
        // logical string.
        let body = format!(
            "--{boundary}\r\n\
Content-Disposition: form-data; name=\"hello\"\r\n\
\r\n\
world\r\n\
--{boundary}\r\n\
Content-Disposition: form-data; name=\"test\"\r\n\
\r\n\
123\r\n\
--{boundary}--\r\n"
        );
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                format!("{MULTIPART_FORM_CONTENT_TYPE}; boundary={boundary}"),
            )
            .body(Body::fixed(body))
            .unwrap();

        let mut form_data = form_data(&mut request).await.unwrap();
        let mut values = Vec::new();
        while let Some((field_id, value)) = form_data.next_value().await.unwrap() {
            values.push((field_id, value.into_text().await.unwrap()));
        }

        assert_eq!(values.len(), 2);
        assert_eq!(values[0].0, "hello");
        assert_eq!(values[0].1, "world");
        assert_eq!(values[1].0, "test");
        assert_eq!(values[1].1, "123");
    }

    // A file part exposes its filename and content type; a plain part has neither.
    #[cot::test]
    async fn form_data_extract_multipart_with_file() {
        let boundary = "boundary";
        let body = format!(
            "--{boundary}\r\n\
Content-Disposition: form-data; name=\"hello\"\r\n\
\r\n\
world\r\n\
--{boundary}\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"test.txt\"\r\n\
Content-Type: text/plain\r\n\
\r\n\
file content\r\n\
--{boundary}--\r\n"
        );
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(
                http::header::CONTENT_TYPE,
                format!("{MULTIPART_FORM_CONTENT_TYPE}; boundary={boundary}"),
            )
            .body(Body::fixed(body))
            .unwrap();

        let mut form_data = form_data(&mut request).await.unwrap();
        let mut values = Vec::new();
        while let Some((field_id, value)) = form_data.next_value().await.unwrap() {
            assert!(value.is_multipart());
            // Metadata is copied out before `into_text` consumes the value.
            values.push((
                field_id,
                value.filename().map(ToOwned::to_owned),
                value.content_type().map(ToOwned::to_owned),
                value.into_text().await.unwrap(),
            ));
        }

        assert_eq!(values.len(), 2);
        assert_eq!(values[0].0, "hello");
        assert_eq!(values[0].1, None);
        assert_eq!(values[0].2, None);
        assert_eq!(values[0].3, "world");
        assert_eq!(values[1].0, "file");
        assert_eq!(values[1].1, Some("test.txt".to_owned()));
        assert_eq!(values[1].2, Some("text/plain".to_owned()));
        assert_eq!(values[1].3, "file content");
    }

    // A POST with a non-form content type is rejected with the `ExpectedForm` message
    // wrapped in `FormError::RequestError`.
    #[cot::test]
    async fn form_data_extract_invalid_content_type() {
        let mut request = http::Request::builder()
            .method(http::Method::POST)
            .header(http::header::CONTENT_TYPE, "application/json")
            .body(Body::fixed("{}"))
            .unwrap();

        let result = form_data(&mut request).await;
        assert!(result.is_err());
        if let Err(FormError::RequestError { error }) = result {
            assert!(
                error
                    .to_string()
                    .contains("request does not contain a form"),
                "{}",
                error
            );
        } else {
            panic!("Expected RequestError");
        }
    }
}