use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use super::parse::{EntityDef, FilterType};
use crate::utils::marker;
/// Builds the `<Entity>Query` struct and its `<Entity>Filter` type alias.
///
/// Returns an empty token stream when the entity has neither filterable nor
/// sortable fields. Otherwise the emitted code contains, in order:
/// - the `<Entity>SortField` enum (only when some field is sortable);
/// - `<Entity>Query`, holding one `Option<T>` per `Eq`/`Like` filter field, a
///   `<name>_from` / `<name>_to` pair per `Range` filter field, an optional
///   `sort` member, and the `limit` / `offset` paging members;
/// - `type <Entity>Filter = <Entity>Query;`.
pub fn generate(entity: &EntityDef) -> TokenStream {
    // Nothing to filter or sort by: emit no query type at all.
    if !(entity.has_filters() || entity.has_sort_fields()) {
        return TokenStream::new();
    }

    let vis = &entity.vis;
    let query_ident = entity.ident_with("", "Query");
    let filterable = entity.filter_fields();

    let mut filter_members: Vec<TokenStream> = Vec::new();
    for field in filterable.iter() {
        let name = field.name();
        let ty = field.ty();
        match field.filter().filter_type {
            FilterType::Eq | FilterType::Like => {
                filter_members.push(quote! { pub #name: Option<#ty> });
            }
            FilterType::Range => {
                // A range filter becomes two optional bounds, `<name>_from` and `<name>_to`.
                let lower = format_ident!("{}_from", name);
                let upper = format_ident!("{}_to", name);
                filter_members.push(quote! { pub #lower: Option<#ty> });
                filter_members.push(quote! { pub #upper: Option<#ty> });
            }
            FilterType::None => {}
        }
    }

    let generated_marker = marker::generated();
    let filter_alias = entity.ident_with("", "Filter");
    // Both halves are empty token streams when nothing is sortable.
    let (sort_enum, sort_member) = generate_sort_enum(entity);

    quote! {
        #sort_enum
        #generated_marker
        #[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
        #[cfg_attr(feature = "api", derive(utoipa::ToSchema))]
        #vis struct #query_ident {
            #(#filter_members,)*
            #sort_member
            pub limit: Option<i64>,
            pub offset: Option<i64>,
        }
        #vis type #filter_alias = #query_ident;
    }
}
fn generate_sort_enum(entity: &EntityDef) -> (TokenStream, TokenStream) {
let sort_fields = entity.sort_fields();
if sort_fields.is_empty() {
return (TokenStream::new(), TokenStream::new());
}
let vis = &entity.vis;
let sort_name = entity.ident_with("", "SortField");
let mut variants = Vec::new();
let mut arms = Vec::new();
for field in &sort_fields {
let column = field.name_str();
for (suffix, direction) in [("Asc", "ASC"), ("Desc", "DESC")] {
let variant = format_ident!(
"{}{}",
convert_case::Casing::to_case(&column, convert_case::Case::Pascal),
suffix
);
let fragment = format!("{column} {direction}");
variants.push(quote! { #variant });
arms.push(quote! { Self::#variant => #fragment });
}
}
let doc = format!("Sortable columns for [`{}`] queries.", entity.name());
let marker = marker::generated();
let sort_enum = quote! {
#marker
#[doc = #doc]
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "api", derive(utoipa::ToSchema))]
#vis enum #sort_name {
#(#variants,)*
}
impl #sort_name {
#vis fn order_by(&self) -> &'static str {
match self {
#(#arms,)*
}
}
}
};
let sort_field = quote! {
pub sort: Option<#sort_name>,
};
(sort_enum, sort_field)
}
// Unit tests for sort-enum generation and how it combines with the query struct.
#[cfg(test)]
mod sort_tests {
    use quote::quote;
    use syn::DeriveInput;

    use super::*;

    /// Parses `tokens` as a derive input and converts it into an [`EntityDef`].
    fn parse_entity(tokens: proc_macro2::TokenStream) -> EntityDef {
        let input: DeriveInput = syn::parse2(tokens).expect("test entity must parse");
        EntityDef::from_derive_input(&input).expect("test entity must be valid")
    }

    /// `Post` entity with two sortable fields (`title`, `views`); `views` is also filterable.
    fn sortable_entity() -> EntityDef {
        parse_entity(quote! {
            #[entity(table = "posts")]
            pub struct Post {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, update, response)]
                #[sort]
                pub title: String,
                #[field(create, response)]
                #[sort]
                #[filter]
                pub views: i64,
            }
        })
    }

    #[test]
    fn sort_enum_generated_with_asc_desc_variants() {
        let code = generate(&sortable_entity()).to_string();
        assert!(code.contains("enum PostSortField"));
        // Every sortable field must get BOTH directions, not just some of them.
        for variant in ["TitleAsc", "TitleDesc", "ViewsAsc", "ViewsDesc"] {
            assert!(code.contains(variant), "missing variant {variant}");
        }
        for fragment in ["\"title ASC\"", "\"title DESC\"", "\"views ASC\"", "\"views DESC\""] {
            assert!(code.contains(fragment), "missing ORDER BY fragment {fragment}");
        }
    }

    #[test]
    fn order_by_arms_pair_variant_with_matching_direction() {
        let code = generate(&sortable_entity()).to_string();
        // Guards against an Asc variant being wired to a DESC fragment (or vice versa).
        assert!(code.contains("Self :: TitleAsc => \"title ASC\""));
        assert!(code.contains("Self :: TitleDesc => \"title DESC\""));
        assert!(code.contains("Self :: ViewsAsc => \"views ASC\""));
        assert!(code.contains("Self :: ViewsDesc => \"views DESC\""));
    }

    #[test]
    fn query_struct_gains_sort_field() {
        let code = generate(&sortable_entity()).to_string();
        assert!(code.contains("pub sort : Option < PostSortField >"));
    }

    #[test]
    fn sort_only_entity_still_generates_query() {
        let entity = parse_entity(quote! {
            #[entity(table = "posts")]
            pub struct Post {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[sort]
                pub title: String,
            }
        });
        let code = generate(&entity).to_string();
        assert!(code.contains("struct PostQuery"));
        assert!(code.contains("PostSortField"));
    }

    #[test]
    fn no_sort_no_enum() {
        let entity = parse_entity(quote! {
            #[entity(table = "posts")]
            pub struct Post {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                #[filter]
                pub title: String,
            }
        });
        let code = generate(&entity).to_string();
        assert!(!code.contains("SortField"));
    }

    #[test]
    fn no_filters_and_no_sort_generates_nothing() {
        // Exercises the early return in `generate`: no query struct, no enum.
        let entity = parse_entity(quote! {
            #[entity(table = "posts")]
            pub struct Post {
                #[id]
                pub id: uuid::Uuid,
                #[field(create, response)]
                pub title: String,
            }
        });
        assert!(generate(&entity).is_empty());
    }
}