#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::{string::String, vec::Vec};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(all(feature = "validator", any(feature = "std", feature = "alloc")))]
use validator::Validate;
/// Direction in which a sorted listing is ordered.
///
/// Defaults to [`SortDirection::Asc`]. With the `serde` feature the variants
/// (de)serialize as the lowercase strings `"asc"` and `"desc"`.
///
/// The enum carries no data, so it is `Copy` and `Hash`: it can be passed by value
/// and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub enum SortDirection {
    /// Ascending order (the default).
    #[default]
    Asc,
    /// Descending order.
    Desc,
}
#[cfg(all(feature = "serde", any(feature = "std", feature = "alloc")))]
/// Serde fallback for `SortParams::direction` when the field is absent from the input.
///
/// Delegates to the derived `Default` impl, whose `#[default]` variant is `Asc`.
fn default_sort_direction() -> SortDirection {
    SortDirection::default()
}
/// Single-column ordering request: the field to sort by and the direction.
///
/// With the `serde` feature, a missing `direction` is filled in by
/// `default_sort_direction` (ascending).
#[cfg(any(feature = "std", feature = "alloc"))]
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema, utoipa::IntoParams))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub struct SortParams {
    /// Name of the field to order results by.
    pub sort_by: String,
    /// Ordering direction; ascending when omitted from deserialized input.
    #[cfg_attr(feature = "serde", serde(default = "default_sort_direction"))]
    pub direction: SortDirection,
}
#[cfg(any(feature = "std", feature = "alloc"))]
impl SortParams {
    /// Builds sort parameters for `sort_by` with an explicitly chosen `direction`.
    #[must_use]
    pub fn new(sort_by: impl Into<String>, direction: SortDirection) -> Self {
        let sort_by: String = sort_by.into();
        Self { sort_by, direction }
    }

    /// Shorthand for ascending order on `sort_by`.
    #[must_use]
    pub fn asc(sort_by: impl Into<String>) -> Self {
        Self {
            sort_by: sort_by.into(),
            direction: SortDirection::Asc,
        }
    }

    /// Shorthand for descending order on `sort_by`.
    #[must_use]
    pub fn desc(sort_by: impl Into<String>) -> Self {
        Self {
            sort_by: sort_by.into(),
            direction: SortDirection::Desc,
        }
    }
}
/// One filter condition: a field name, an operator, and a comparison value.
///
/// All three parts are kept as plain strings; this type does not interpret or
/// validate the operator (examples used in this module's tests: `eq`, `gt`).
#[cfg(any(feature = "std", feature = "alloc"))]
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub struct FilterEntry {
    /// Name of the field the condition applies to.
    pub field: String,
    /// Comparison operator, stored verbatim.
    pub operator: String,
    /// Value to compare against, stored verbatim as text.
    pub value: String,
}
#[cfg(any(feature = "std", feature = "alloc"))]
impl FilterEntry {
    /// Builds a filter condition, converting each part into an owned `String`.
    #[must_use]
    pub fn new(
        field: impl Into<String>,
        operator: impl Into<String>,
        value: impl Into<String>,
    ) -> Self {
        let (field, operator, value) = (field.into(), operator.into(), value.into());
        Self {
            field,
            operator,
            value,
        }
    }
}
/// An ordered collection of [`FilterEntry`] conditions.
///
/// The default value holds no filters; with the `serde` feature a missing
/// `filters` key deserializes to an empty list.
#[cfg(any(feature = "std", feature = "alloc"))]
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema, utoipa::IntoParams))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub struct FilterParams {
    /// Filter conditions, in the order they were supplied.
    #[cfg_attr(feature = "serde", serde(default))]
    pub filters: Vec<FilterEntry>,
}
#[cfg(any(feature = "std", feature = "alloc"))]
impl FilterParams {
    /// Collects any iterable of [`FilterEntry`] values, preserving their order.
    #[must_use]
    pub fn new(filters: impl IntoIterator<Item = FilterEntry>) -> Self {
        let collected: Vec<FilterEntry> = filters.into_iter().collect();
        Self { filters: collected }
    }

    /// Reports whether no filter conditions are present.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { filters } = self;
        filters.is_empty()
    }
}
/// Free-text search request: a query plus an optional list of fields to search in.
///
/// With the `validator` feature, `query` is length-checked (1 to 500). With the
/// `serde` feature, an empty `fields` list is omitted on output and defaults to
/// empty on input.
#[cfg(any(feature = "std", feature = "alloc"))]
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema, utoipa::IntoParams))]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[cfg_attr(feature = "validator", derive(Validate))]
#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
pub struct SearchParams {
    /// The search text.
    #[cfg_attr(
        feature = "validator",
        validate(length(
            min = 1,
            max = 500,
            message = "query must be between 1 and 500 characters"
        ))
    )]
    #[cfg_attr(feature = "proptest", proptest(strategy = "search_query_strategy()"))]
    pub query: String,
    /// Names of the fields to search; empty means no restriction was given.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    pub fields: Vec<String>,
}
#[cfg(any(feature = "std", feature = "alloc"))]
impl SearchParams {
    /// Smallest accepted `query` length, in characters (same bound as the `validator` rule).
    const QUERY_MIN_CHARS: usize = 1;
    /// Largest accepted `query` length, in characters (same bound as the `validator` rule).
    const QUERY_MAX_CHARS: usize = 500;

    /// Creates search parameters for `query` with no field restriction.
    ///
    /// The query is stored as-is and is **not** length-checked; use
    /// [`SearchParams::try_new`] for untrusted input.
    #[must_use]
    pub fn new(query: impl Into<String>) -> Self {
        Self {
            query: query.into(),
            fields: Vec::new(),
        }
    }

    /// Creates search parameters for `query`, rejecting queries outside 1..=500 characters.
    ///
    /// Length is measured in `char`s (Unicode scalar values), not bytes. Multi-byte
    /// text such as `"é"` therefore counts as one character, matching the wording of
    /// the error message.
    ///
    /// # Errors
    ///
    /// Returns a [`ValidationError`](crate::error::ValidationError) with field `"/query"`
    /// and rule `"length"` when `query` is empty or longer than 500 characters.
    pub fn try_new(query: impl Into<String>) -> Result<Self, crate::error::ValidationError> {
        let query = query.into();
        // Counting bytes (`str::len`) would wrongly reject e.g. 300 "é" characters
        // (600 bytes) that the `validator` length rule accepts. Counting stops one
        // past the maximum, so oversized input is not scanned to the end.
        let char_count = query.chars().take(Self::QUERY_MAX_CHARS + 1).count();
        if !(Self::QUERY_MIN_CHARS..=Self::QUERY_MAX_CHARS).contains(&char_count) {
            return Err(crate::error::ValidationError {
                field: "/query".into(),
                message: "must be between 1 and 500 characters".into(),
                rule: Some("length".into()),
            });
        }
        Ok(Self {
            query,
            fields: Vec::new(),
        })
    }

    /// Like [`SearchParams::try_new`], and also restricts the search to `fields`.
    ///
    /// `fields` is consumed only after the query passes validation.
    ///
    /// # Errors
    ///
    /// Same as [`SearchParams::try_new`]; `fields` itself is not validated.
    pub fn try_with_fields(
        query: impl Into<String>,
        fields: impl IntoIterator<Item = impl Into<String>>,
    ) -> Result<Self, crate::error::ValidationError> {
        let mut params = Self::try_new(query)?;
        params.fields = fields.into_iter().map(Into::into).collect();
        Ok(params)
    }

    /// Creates search parameters for `query` restricted to `fields`, without validation.
    ///
    /// Use [`SearchParams::try_with_fields`] for untrusted input.
    #[must_use]
    pub fn with_fields(
        query: impl Into<String>,
        fields: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self {
            query: query.into(),
            fields: fields.into_iter().map(Into::into).collect(),
        }
    }
}
#[cfg(feature = "axum")]
// The extractor's rejection type is `ApiError`, returned by value. Presumably it is
// large enough to trigger `clippy::result_large_err`. The lint is silenced here
// instead of boxing the error, because the rejection type is fixed by `ApiError`'s
// public API.
#[allow(clippy::result_large_err)]
mod axum_extractors {
    use super::SortParams;
    use crate::error::ApiError;
    use axum::extract::{FromRequestParts, Query};
    use axum::http::request::Parts;

    /// Extracts [`SortParams`] from the request's query string.
    ///
    /// Any `Query` rejection (e.g. a missing `sort_by`) is converted into
    /// `ApiError::bad_request`, carrying the rejection's display text.
    impl<S: Send + Sync> FromRequestParts<S> for SortParams {
        type Rejection = ApiError;

        async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
            match Query::<Self>::from_request_parts(parts, state).await {
                Ok(Query(params)) => Ok(params),
                Err(rejection) => Err(ApiError::bad_request(rejection.to_string())),
            }
        }
    }
}
#[cfg(all(feature = "proptest", any(feature = "std", feature = "alloc")))]
/// Proptest strategy for `SearchParams::query`: 1 to 500 ASCII letters, digits or
/// spaces, so generated queries stay inside the `validator` length bounds.
fn search_query_strategy() -> impl proptest::strategy::Strategy<Value = String> {
    const QUERY_PATTERN: &str = "[a-zA-Z0-9 ]{1,500}";
    let strategy = proptest::string::string_regex(QUERY_PATTERN);
    strategy.expect("valid regex")
}
#[cfg(all(feature = "arbitrary", any(feature = "std", feature = "alloc")))]
impl<'a> arbitrary::Arbitrary<'a> for SearchParams {
    /// Builds a `SearchParams` whose `query` has 1 to 500 characters.
    ///
    /// Each query character is a printable ASCII byte (32..=126), so byte length and
    /// character count coincide. `fields` comes from the stock `Vec<String>` generator.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let query_len: usize = u.int_in_range(1usize..=500)?;
        let mut query = String::with_capacity(query_len);
        for _ in 0..query_len {
            let printable: u8 = u.int_in_range(32u8..=126)?;
            query.push(char::from(printable));
        }
        let fields = <Vec<String> as arbitrary::Arbitrary<'a>>::arbitrary(u)?;
        Ok(Self { query, fields })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- SortDirection ------------------------------------------------------

    #[test]
    fn sort_direction_default_is_asc() {
        assert_eq!(SortDirection::default(), SortDirection::Asc);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn sort_direction_serde_lowercase() {
        let asc = serde_json::to_value(SortDirection::Asc).unwrap();
        assert_eq!(asc, serde_json::json!("asc"));
        let desc = serde_json::to_value(SortDirection::Desc).unwrap();
        assert_eq!(desc, serde_json::json!("desc"));
        let back: SortDirection = serde_json::from_value(asc).unwrap();
        assert_eq!(back, SortDirection::Asc);
    }

    /// `rename_all = "lowercase"` means only the lowercase spellings are accepted.
    #[cfg(feature = "serde")]
    #[test]
    fn sort_direction_serde_rejects_other_casing() {
        assert!(serde_json::from_value::<SortDirection>(serde_json::json!("ASC")).is_err());
        assert!(serde_json::from_value::<SortDirection>(serde_json::json!("Desc")).is_err());
    }

    // ---- SortParams ---------------------------------------------------------

    #[test]
    fn sort_params_new_keeps_explicit_direction() {
        let p = SortParams::new(String::from("updated_at"), SortDirection::Desc);
        assert_eq!(p.sort_by, "updated_at");
        assert_eq!(p.direction, SortDirection::Desc);
    }

    #[test]
    fn sort_params_asc_helper() {
        let p = SortParams::asc("created_at");
        assert_eq!(p.sort_by, "created_at");
        assert_eq!(p.direction, SortDirection::Asc);
    }

    #[test]
    fn sort_params_desc_helper() {
        let p = SortParams::desc("name");
        assert_eq!(p.sort_by, "name");
        assert_eq!(p.direction, SortDirection::Desc);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn sort_params_serde_round_trip() {
        let p = SortParams::desc("created_at");
        let json = serde_json::to_value(&p).unwrap();
        assert_eq!(json["sort_by"], "created_at");
        assert_eq!(json["direction"], "desc");
        let back: SortParams = serde_json::from_value(json).unwrap();
        assert_eq!(back, p);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn sort_params_serde_default_direction() {
        let json = serde_json::json!({"sort_by": "name"});
        let p: SortParams = serde_json::from_value(json).unwrap();
        assert_eq!(p.direction, SortDirection::Asc);
    }

    // ---- FilterEntry / FilterParams -----------------------------------------

    #[test]
    fn filter_entry_new_accepts_owned_and_borrowed() {
        let e = FilterEntry::new(String::from("age"), "gte", String::from("21"));
        assert_eq!(e.field, "age");
        assert_eq!(e.operator, "gte");
        assert_eq!(e.value, "21");
    }

    #[test]
    fn filter_params_default_is_empty() {
        let f = FilterParams::default();
        assert!(f.is_empty());
    }

    #[test]
    fn filter_params_new() {
        let f = FilterParams::new([FilterEntry::new("status", "eq", "active")]);
        assert!(!f.is_empty());
        assert_eq!(f.filters.len(), 1);
        assert_eq!(f.filters[0].field, "status");
        assert_eq!(f.filters[0].operator, "eq");
        assert_eq!(f.filters[0].value, "active");
    }

    #[test]
    fn filter_params_new_from_vec_preserves_order() {
        let f = FilterParams::new(vec![
            FilterEntry::new("a", "eq", "1"),
            FilterEntry::new("b", "ne", "2"),
        ]);
        let fields: Vec<&str> = f.filters.iter().map(|e| e.field.as_str()).collect();
        assert_eq!(fields, ["a", "b"]);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn filter_params_serde_round_trip() {
        let f = FilterParams::new([FilterEntry::new("age", "gt", "18")]);
        let json = serde_json::to_value(&f).unwrap();
        let back: FilterParams = serde_json::from_value(json).unwrap();
        assert_eq!(back, f);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn filter_params_serde_empty_filters_default() {
        let json = serde_json::json!({});
        let f: FilterParams = serde_json::from_value(json).unwrap();
        assert!(f.is_empty());
    }

    // ---- SearchParams: unchecked constructors and serde --------------------

    #[test]
    fn search_params_new() {
        let s = SearchParams::new("annual report");
        assert_eq!(s.query, "annual report");
        assert!(s.fields.is_empty());
    }

    #[test]
    fn search_params_with_fields() {
        let s = SearchParams::with_fields("report", ["title", "description"]);
        assert_eq!(s.query, "report");
        assert_eq!(s.fields, vec!["title", "description"]);
    }

    #[test]
    fn search_params_with_fields_empty_iterator() {
        let s = SearchParams::with_fields("report", Vec::<String>::new());
        assert!(s.fields.is_empty());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn search_params_serde_round_trip() {
        let s = SearchParams::with_fields("hello", ["name"]);
        let json = serde_json::to_value(&s).unwrap();
        assert_eq!(json["query"], "hello");
        assert_eq!(json["fields"], serde_json::json!(["name"]));
        let back: SearchParams = serde_json::from_value(json).unwrap();
        assert_eq!(back, s);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn search_params_serde_omits_empty_fields() {
        let s = SearchParams::new("test");
        let json = serde_json::to_value(&s).unwrap();
        assert!(json.get("fields").is_none());
    }

    #[cfg(feature = "serde")]
    #[test]
    fn search_params_serde_missing_fields_defaults_to_empty() {
        let s: SearchParams = serde_json::from_value(serde_json::json!({"query": "q"})).unwrap();
        assert_eq!(s, SearchParams::new("q"));
    }

    // ---- SearchParams: `validator` rule -------------------------------------

    #[cfg(feature = "validator")]
    #[test]
    fn search_params_validate_ok() {
        use validator::Validate;
        let s = SearchParams::new("valid query");
        assert!(s.validate().is_ok());
    }

    #[cfg(feature = "validator")]
    #[test]
    fn search_params_validate_empty_query_fails() {
        use validator::Validate;
        let s = SearchParams::new("");
        assert!(s.validate().is_err());
    }

    #[cfg(feature = "validator")]
    #[test]
    fn search_params_validate_boundary_min() {
        use validator::Validate;
        assert!(SearchParams::new("a").validate().is_ok());
    }

    #[cfg(feature = "validator")]
    #[test]
    fn search_params_validate_too_long_fails() {
        use validator::Validate;
        let s = SearchParams::new("a".repeat(501));
        assert!(s.validate().is_err());
    }

    #[cfg(feature = "validator")]
    #[test]
    fn search_params_validate_boundary_max() {
        use validator::Validate;
        let s = SearchParams::new("a".repeat(500));
        assert!(s.validate().is_ok());
    }

    // ---- SearchParams: checked constructors ---------------------------------

    #[test]
    fn search_params_try_new_valid() {
        let s = SearchParams::try_new("report").unwrap();
        assert_eq!(s.query, "report");
        assert!(s.fields.is_empty());
    }

    #[test]
    fn search_params_try_new_boundary_min() {
        assert!(SearchParams::try_new("a").is_ok());
    }

    #[test]
    fn search_params_try_new_boundary_max() {
        assert!(SearchParams::try_new("a".repeat(500)).is_ok());
    }

    #[test]
    fn search_params_try_new_empty_fails() {
        let err = SearchParams::try_new("").unwrap_err();
        assert_eq!(err.field, "/query");
        assert_eq!(err.rule.as_deref(), Some("length"));
    }

    #[test]
    fn search_params_try_new_too_long_fails() {
        let err = SearchParams::try_new("a".repeat(501)).unwrap_err();
        assert_eq!(err.field, "/query");
        assert_eq!(err.rule.as_deref(), Some("length"));
        assert_eq!(err.message, "must be between 1 and 500 characters");
    }

    #[test]
    fn search_params_try_with_fields_valid() {
        let s = SearchParams::try_with_fields("report", ["title", "body"]).unwrap();
        assert_eq!(s.query, "report");
        assert_eq!(s.fields, vec!["title", "body"]);
    }

    #[test]
    fn search_params_try_with_fields_empty_query_fails() {
        assert!(SearchParams::try_with_fields("", ["title"]).is_err());
    }

    #[test]
    fn search_params_try_with_fields_too_long_query_fails() {
        let err = SearchParams::try_with_fields("a".repeat(501), ["title"]).unwrap_err();
        assert_eq!(err.field, "/query");
        assert_eq!(err.rule.as_deref(), Some("length"));
    }

    // ---- axum extractor -----------------------------------------------------

    #[cfg(feature = "axum")]
    mod axum_extractor_tests {
        use super::super::{SortDirection, SortParams};
        use axum::extract::FromRequestParts;
        use axum::http::Request;

        /// Runs the `SortParams` extractor against `/?{q}` and returns the rejection's
        /// HTTP status on failure.
        async fn extract(q: &str) -> Result<SortParams, u16> {
            let req = Request::builder().uri(format!("/?{q}")).body(()).unwrap();
            let (mut parts, ()) = req.into_parts();
            SortParams::from_request_parts(&mut parts, &())
                .await
                .map_err(|e| e.status)
        }

        #[tokio::test]
        async fn sort_default_direction() {
            let p = extract("sort_by=name").await.unwrap();
            assert_eq!(p.sort_by, "name");
            assert_eq!(p.direction, SortDirection::Asc);
        }

        #[tokio::test]
        async fn sort_custom_direction() {
            let p = extract("sort_by=created_at&direction=desc").await.unwrap();
            assert_eq!(p.direction, SortDirection::Desc);
        }

        #[tokio::test]
        async fn sort_missing_sort_by_rejected() {
            assert_eq!(extract("").await.unwrap_err(), 400);
        }

        #[tokio::test]
        async fn sort_unknown_direction_rejected() {
            assert_eq!(
                extract("sort_by=name&direction=sideways").await.unwrap_err(),
                400
            );
        }
    }
}