use std::collections::HashMap;
use crate::core::{FieldSchema, FieldType, SqlValue};
#[cfg(feature = "csrf")]
pub mod csrf;
pub trait FormStruct: Sized {
fn parse(form: &std::collections::HashMap<String, String>) -> Result<Self, FormError>;
}
/// Errors produced while turning submitted form data, or primary-key
/// segments taken from URL paths, into [`SqlValue`]s.
#[derive(Debug, thiserror::Error)]
pub enum FormError {
    /// A required field was absent from the submission.
    #[error("required field `{field}` was missing from the form")]
    Missing {
        /// Name of the missing field.
        field: String,
    },

    /// A value could not be parsed as the field's declared type.
    #[error("field `{field}` has invalid {ty} value `{value}`: {detail}")]
    Parse {
        /// Name of the offending field.
        field: String,
        /// Short name of the expected type (for example `"i32"` or `"Uuid"`).
        ty: &'static str,
        /// The raw text that failed to parse.
        value: String,
        /// Display text of the underlying parse error.
        detail: String,
    },

    /// The field's type cannot be used as a primary key in a URL path.
    #[error("PK field `{field}` of type {ty} is not supported in URL paths")]
    UnsupportedPk {
        /// Name of the primary-key field.
        field: String,
        /// Name of the unsupported type.
        ty: &'static str,
    },
}
/// Converts a primary-key value taken from a URL path into a [`SqlValue`].
///
/// Only `I32`, `I64`, `String` and `Uuid` keys are accepted; `String` keys
/// are used verbatim with no validation.
///
/// # Errors
///
/// * [`FormError::Parse`] when `raw` is not a valid value of the key's type.
/// * [`FormError::UnsupportedPk`] when the key's [`FieldType`] is `Bool`,
///   `F32`, `F64`, `DateTime`, `Date` or `Json`.
pub fn parse_pk_string(field: &FieldSchema, raw: &str) -> Result<SqlValue, FormError> {
    // Wraps a type-specific parse failure together with the field name and raw input.
    let invalid = |ty: &'static str, detail: String| FormError::Parse {
        field: field.name.to_owned(),
        ty,
        value: raw.to_owned(),
        detail,
    };

    let value = match field.ty {
        FieldType::String => SqlValue::String(raw.to_owned()),
        FieldType::I32 => match raw.parse::<i32>() {
            Ok(n) => SqlValue::I32(n),
            Err(e) => return Err(invalid("i32", e.to_string())),
        },
        FieldType::I64 => match raw.parse::<i64>() {
            Ok(n) => SqlValue::I64(n),
            Err(e) => return Err(invalid("i64", e.to_string())),
        },
        FieldType::Uuid => match uuid::Uuid::parse_str(raw) {
            Ok(id) => SqlValue::Uuid(id),
            Err(e) => return Err(invalid("Uuid", e.to_string())),
        },
        FieldType::Bool
        | FieldType::F32
        | FieldType::F64
        | FieldType::DateTime
        | FieldType::Date
        | FieldType::Json => {
            return Err(FormError::UnsupportedPk {
                field: field.name.to_owned(),
                ty: field.ty.as_str(),
            });
        }
    };
    Ok(value)
}
/// Converts one submitted form value into a [`SqlValue`] according to the
/// field's declared [`FieldType`].
///
/// `raw` is `None` when the field's key was absent from the submission.
///
/// * An absent `Bool` field is read as `false`, even if the field is nullable.
/// * Otherwise an absent value on a nullable field, or an empty string on a
///   nullable field (including `Bool`), becomes [`SqlValue::Null`].
/// * `Bool` values are `false` for `""`, `"false"`, `"0"`, `"off"` and `"no"`
///   (ASCII case-insensitive) and `true` for anything else.
/// * `Date` expects `YYYY-MM-DD`.
/// * `DateTime` accepts RFC 3339 (converted to UTC), or a naive
///   `YYYY-MM-DDTHH:MM`, `YYYY-MM-DDTHH:MM:SS` or
///   `YYYY-MM-DDTHH:MM:SS.fraction` value, which is interpreted as UTC.
/// * `String` values are taken verbatim.
///
/// # Errors
///
/// * [`FormError::Missing`] when a non-nullable, non-`Bool` field is absent.
/// * [`FormError::Parse`] when the value does not parse as the field's type.
/// * [`FormError::UnsupportedPk`] for a `Json` field that carries a non-null
///   value, which this function does not parse.
pub fn parse_form_value(
    field: &FieldSchema,
    raw: Option<&str>,
) -> Result<SqlValue, FormError> {
    let Some(raw) = raw else {
        return match field.ty {
            // Presumably an unchecked checkbox: browsers omit those from the
            // submission entirely, so absence is read as `false`.
            FieldType::Bool => Ok(SqlValue::Bool(false)),
            _ if field.nullable => Ok(SqlValue::Null),
            _ => Err(FormError::Missing {
                field: field.name.to_owned(),
            }),
        };
    };
    if field.nullable && raw.is_empty() {
        return Ok(SqlValue::Null);
    }

    // Wraps a type-specific parse failure together with the field name and raw input.
    let make_parse_err = |ty: &'static str, e: &dyn std::fmt::Display| FormError::Parse {
        field: field.name.to_owned(),
        ty,
        value: raw.to_owned(),
        detail: e.to_string(),
    };

    match field.ty {
        FieldType::Bool => {
            // Case-insensitive comparison without allocating a lowercase copy.
            const FALSY: [&str; 5] = ["", "false", "0", "off", "no"];
            let v = !FALSY.iter().any(|f| raw.eq_ignore_ascii_case(f));
            Ok(SqlValue::Bool(v))
        }
        FieldType::I32 => raw
            .parse::<i32>()
            .map(SqlValue::I32)
            .map_err(|e| make_parse_err("i32", &e)),
        FieldType::I64 => raw
            .parse::<i64>()
            .map(SqlValue::I64)
            .map_err(|e| make_parse_err("i64", &e)),
        FieldType::F32 => raw
            .parse::<f32>()
            .map(SqlValue::F32)
            .map_err(|e| make_parse_err("f32", &e)),
        FieldType::F64 => raw
            .parse::<f64>()
            .map(SqlValue::F64)
            .map_err(|e| make_parse_err("f64", &e)),
        FieldType::String => Ok(SqlValue::String(raw.to_owned())),
        FieldType::Uuid => uuid::Uuid::parse_str(raw)
            .map(SqlValue::Uuid)
            .map_err(|e| make_parse_err("Uuid", &e)),
        FieldType::Date => chrono::NaiveDate::parse_from_str(raw, "%Y-%m-%d")
            .map(SqlValue::Date)
            .map_err(|e| make_parse_err("Date", &e)),
        FieldType::DateTime => {
            if let Ok(d) = chrono::DateTime::parse_from_rfc3339(raw) {
                return Ok(SqlValue::DateTime(d.with_timezone(&chrono::Utc)));
            }
            // HTML `datetime-local` values may carry seconds and a fractional
            // part (`HH:MM:SS.sss`). In chrono, `%.f` matches an optional
            // `.digits` suffix, so the first pattern covers both `HH:MM:SS`
            // and `HH:MM:SS.fff`. The second pattern handles the minute-only
            // form.
            let ndt = chrono::NaiveDateTime::parse_from_str(raw, "%Y-%m-%dT%H:%M:%S%.f")
                .or_else(|_| chrono::NaiveDateTime::parse_from_str(raw, "%Y-%m-%dT%H:%M"))
                .map_err(|e| make_parse_err("DateTime", &e))?;
            Ok(SqlValue::DateTime(ndt.and_utc()))
        }
        // NOTE(review): this reuses `UnsupportedPk`, whose message talks about
        // URL paths, although this is a form value. The variant is kept for
        // compatibility with callers; a dedicated variant would be clearer.
        FieldType::Json => Err(FormError::UnsupportedPk {
            field: field.name.to_owned(),
            ty: "Json",
        }),
    }
}
/// Parses every scalar field of `model` from the submitted `form`.
///
/// Each field is looked up in `form` by its `name` and converted with
/// [`parse_form_value`]. Fields whose `name` appears in `skip` are omitted
/// entirely. The result pairs each field's `column` with its parsed value, in
/// the order yielded by `model.scalar_fields()`.
///
/// # Errors
///
/// Returns the [`FormError`] from the first field that fails to parse.
/// Remaining fields are not examined.
pub fn collect_values(
    model: &'static crate::core::ModelSchema,
    form: &HashMap<String, String>,
    skip: &[&str],
) -> Result<Vec<(&'static str, SqlValue)>, FormError> {
    model
        .scalar_fields()
        .filter(|field| !skip.contains(&field.name))
        .map(|field| {
            let submitted = form.get(field.name).map(String::as_str);
            parse_form_value(field, submitted).map(|value| (field.column, value))
        })
        .collect()
}