use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::spanned::Spanned;
use syn::{Data, DeriveInput, Fields, Ident, LitStr, Type};
pub fn derive_model_impl(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
let name = &input.ident;
let module_name = format_ident!("{}", name.to_string().to_case(Case::Snake));
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return Err(syn::Error::new_spanned(
input,
"Model derive only supports structs with named fields",
));
}
},
_ => {
return Err(syn::Error::new_spanned(
input,
"Model derive only supports structs",
));
}
};
let struct_attrs = parse_struct_attrs(input)?;
let table_name = struct_attrs
.table_name
.unwrap_or_else(|| name.to_string().to_case(Case::Snake));
let field_infos: Vec<FieldInfo> = fields.iter().map(parse_field).collect::<Result<_, _>>()?;
let pk_fields: Vec<_> = field_infos
.iter()
.filter(|f| f.is_id)
.map(|f| f.column_name.as_str())
.collect();
if pk_fields.is_empty() {
return Err(syn::Error::new_spanned(
input,
"Model must have at least one field marked with #[prax(id)]",
));
}
let field_modules: Vec<_> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(generate_field_module_from_derive)
.collect();
let where_variants: Vec<_> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
let field_mod = &f.name;
quote! { #variant_name(#field_mod::WhereOp) }
})
.collect();
let from_filter_arms: Vec<_> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
quote! { WhereParam::#variant_name(op) => op.to_filter(), }
})
.collect();
let select_variants: Vec<_> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
quote! { #variant_name }
})
.collect();
let order_variants: Vec<_> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
quote! { #variant_name(::prax_orm::_prax_prelude::SortOrder) }
})
.collect();
let all_columns: Vec<String> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| f.column_name.clone())
.collect();
let pk_columns_owned: Vec<String> = field_infos
.iter()
.filter(|f| f.is_id)
.map(|f| f.column_name.clone())
.collect();
let from_row_fields: Vec<(Ident, Type, String)> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| (f.name.clone(), f.ty.clone(), f.column_name.clone()))
.collect();
let from_row_relation_fields: Vec<Ident> = field_infos
.iter()
.filter(|f| f.is_list)
.map(|f| f.name.clone())
.collect();
let model_with_pk_fields: Vec<(Ident, Type, String, bool)> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| (f.name.clone(), f.ty.clone(), f.column_name.clone(), f.is_id))
.collect();
let model_trait_impl = super::derive_model_trait::emit(
name,
&name.to_string(),
&table_name,
&pk_columns_owned,
&all_columns,
);
let from_row_impl =
super::derive_from_row::emit(name, &from_row_fields, &from_row_relation_fields);
let model_with_pk_impl = super::derive_model_with_pk::emit(name, &model_with_pk_fields);
let client_impl = super::derive_client::emit(quote! { super::#name });
let relation_mods: Vec<_> = field_infos
.iter()
.filter_map(|f| {
f.relation.as_ref().map(|rel| {
let kind = if f.is_list {
super::relation_accessors::RelationKindTokens::HasMany
} else if f.is_optional {
super::relation_accessors::RelationKindTokens::HasOne
} else {
super::relation_accessors::RelationKindTokens::BelongsTo
};
super::relation_accessors::emit(super::relation_accessors::RelationSpec {
field_name: &f.name,
owner: name,
target: &rel.target,
kind,
local_key: &rel.local_key,
foreign_key: &rel.foreign_key,
})
})
})
.collect();
let loader_relations: Vec<super::derive_relation_loader::LoaderRelation<'_>> = field_infos
.iter()
.filter_map(|f| {
f.relation.as_ref().map(|rel| {
let kind = if f.is_list {
super::derive_relation_loader::LoaderKind::HasMany
} else {
super::derive_relation_loader::LoaderKind::HasOne
};
super::derive_relation_loader::LoaderRelation {
field_name: &f.name,
target: &rel.target,
kind,
}
})
})
.collect();
let model_relation_loader_impl = super::derive_relation_loader::emit(name, &loader_relations);
Ok(quote! {
pub mod #module_name {
use super::*;
pub const TABLE_NAME: &str = #table_name;
pub const PRIMARY_KEY: &[&str] = &[#(#pk_fields),*];
impl ::prax_orm::_prax_prelude::PraxModel for #name {
const TABLE_NAME: &'static str = TABLE_NAME;
const PRIMARY_KEY: &'static [&'static str] = PRIMARY_KEY;
}
#(#field_modules)*
#(#relation_mods)*
#[derive(Debug, Clone)]
pub enum WhereParam {
#(#where_variants,)*
And(Vec<WhereParam>),
Or(Vec<WhereParam>),
Not(Box<WhereParam>),
}
impl From<WhereParam> for ::prax_query::filter::Filter {
fn from(p: WhereParam) -> Self {
match p {
#(#from_filter_arms)*
WhereParam::And(ps) => ::prax_query::filter::Filter::And(
ps.into_iter().map(Into::into).collect::<Vec<_>>().into_boxed_slice()
),
WhereParam::Or(ps) => ::prax_query::filter::Filter::Or(
ps.into_iter().map(Into::into).collect::<Vec<_>>().into_boxed_slice()
),
WhereParam::Not(p) => ::prax_query::filter::Filter::Not(Box::new((*p).into())),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SelectParam {
#(#select_variants,)*
}
#[derive(Debug, Clone, Copy)]
pub enum OrderByParam {
#(#order_variants,)*
}
#client_impl
}
#model_trait_impl
#from_row_impl
#model_with_pk_impl
#model_relation_loader_impl
})
}
/// Struct-level options read from `#[prax(...)]` by `parse_struct_attrs`.
#[derive(Debug, Default)]
struct StructAttrs {
    /// `table = "..."`: overrides the default (snake_cased struct name) table name.
    table_name: Option<String>,
    /// `schema = "..."`: accepted and recorded, but not yet consumed by the code
    /// generator. The `allow` silences the "field is never read" lint; the derived
    /// `Debug` impl does not count as a read.
    #[allow(dead_code)]
    schema_name: Option<String>,
}
/// Collects the struct-level `#[prax(...)]` options of `input`.
///
/// Recognised keys are `table = "..."` and `schema = "..."`; a later occurrence
/// overwrites an earlier one. Attributes whose path is not `prax` are skipped, and
/// unrecognised flag-style keys inside `#[prax(...)]` are ignored.
///
/// # Errors
///
/// Propagates the `syn` error produced when a `table`/`schema` value is not a
/// string literal or the attribute list is otherwise malformed.
fn parse_struct_attrs(input: &DeriveInput) -> Result<StructAttrs, syn::Error> {
    let mut parsed = StructAttrs::default();
    let prax_attrs = input
        .attrs
        .iter()
        .filter(|candidate| candidate.path().is_ident("prax"));
    for attr in prax_attrs {
        attr.parse_nested_meta(|meta| {
            // Pick the destination first, then read the `= "..."` value into it.
            let slot = if meta.path.is_ident("table") {
                &mut parsed.table_name
            } else if meta.path.is_ident("schema") {
                &mut parsed.schema_name
            } else {
                return Ok(());
            };
            let literal: LitStr = meta.value()?.parse()?;
            *slot = Some(literal.value());
            Ok(())
        })?;
    }
    Ok(parsed)
}
/// Per-field metadata collected by `parse_field`.
#[derive(Debug)]
struct FieldInfo {
    /// The Rust field identifier (possibly a raw identifier such as `r#type`).
    name: Ident,
    /// The declared field type.
    ty: Type,
    /// Database column: `#[prax(column = "...")]`, else the snake_cased field name.
    column_name: String,
    /// `#[prax(id)]`: the column is part of the primary key.
    is_id: bool,
    /// `#[prax(auto)]`: parsed, but not yet consumed by this generator.
    #[allow(dead_code)]
    is_auto: bool,
    /// `#[prax(unique)]`: parsed, but not yet consumed by this generator.
    #[allow(dead_code)]
    is_unique: bool,
    /// The field type was recognised as `Option<...>`.
    is_optional: bool,
    /// The field type was recognised as `Vec<...>`; such fields are not mapped to
    /// columns by `derive_model_impl`.
    is_list: bool,
    /// The parsed `#[prax(relation(...))]`, if present.
    relation: Option<RelationAttr>,
}
#[derive(Debug)]
struct RelationAttr {
target: syn::Ident,
foreign_key: String,
local_key: String,
}
fn parse_field(field: &syn::Field) -> Result<FieldInfo, syn::Error> {
let name = field
.ident
.clone()
.ok_or_else(|| syn::Error::new_spanned(field, "Fields must be named"))?;
let ty = field.ty.clone();
let mut column_name = name.to_string().to_case(Case::Snake);
let mut is_id = false;
let mut is_auto = false;
let mut is_unique = false;
let mut relation: Option<RelationAttr> = None;
let is_optional = is_option_type(&ty);
let is_list = is_vec_type(&ty);
for attr in &field.attrs {
if !attr.path().is_ident("prax") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("id") {
is_id = true;
} else if meta.path.is_ident("auto") {
is_auto = true;
} else if meta.path.is_ident("unique") {
is_unique = true;
} else if meta.path.is_ident("column") {
let value: LitStr = meta.value()?.parse()?;
column_name = value.value();
} else if meta.path.is_ident("relation") {
let mut target: Option<syn::Ident> = None;
let mut fk: Option<String> = None;
let mut lk: Option<String> = None;
meta.parse_nested_meta(|inner| {
if inner.path.is_ident("target") {
let s: LitStr = inner.value()?.parse()?;
target = Some(format_ident!("{}", s.value()));
} else if inner.path.is_ident("foreign_key") {
let s: LitStr = inner.value()?.parse()?;
fk = Some(s.value());
} else if inner.path.is_ident("local_key") {
let s: LitStr = inner.value()?.parse()?;
lk = Some(s.value());
}
Ok(())
})?;
let target = target.ok_or_else(|| {
syn::Error::new(meta.path.span(), "relation requires target = \"ModelName\"")
})?;
let foreign_key = fk.ok_or_else(|| {
syn::Error::new(meta.path.span(), "relation requires foreign_key = \"...\"")
})?;
relation = Some(RelationAttr {
target,
foreign_key,
local_key: lk.unwrap_or_else(|| "id".to_string()),
});
}
Ok(())
})?;
}
Ok(FieldInfo {
name,
ty,
column_name,
is_id,
is_auto,
is_unique,
is_optional,
is_list,
relation,
})
}
fn is_option_type(ty: &Type) -> bool {
if let Type::Path(type_path) = ty
&& let Some(segment) = type_path.path.segments.first()
{
return segment.ident == "Option";
}
false
}
fn is_vec_type(ty: &Type) -> bool {
if let Type::Path(type_path) = ty
&& let Some(segment) = type_path.path.segments.first()
{
return segment.ident == "Vec";
}
false
}
/// Coarse classification of a field's type (see `classify_field_type`), used to pick
/// the filter operations generated for that field.
#[derive(Copy, Clone, Eq, PartialEq)]
enum TypeCategory {
    /// Integers, floats, `Decimal` and the `NaiveDate`/`NaiveDateTime`/`NaiveTime`/
    /// `DateTime` types: these get ordered comparisons.
    Numeric,
    /// `String`/`str`: these get substring-style filters.
    String,
    /// `bool`: equality filters only.
    Boolean,
    /// Anything unrecognised: equality and membership filters.
    Other,
}
/// Classifies `ty` to decide which filter operations are generated for a field.
///
/// `Option<T>` (bare or path-qualified) is classified as `T`. Otherwise the decision
/// is made on the last path segment's name, so `i64`, `std::string::String` and
/// `chrono::NaiveDate` are recognised however they are qualified. Parentheses and
/// `Type::Group` wrappers (from `macro_rules!` `$t:ty` fragments) are looked
/// through. Anything unrecognised is [`TypeCategory::Other`].
fn classify_field_type(ty: &Type) -> TypeCategory {
    let segment = match ty {
        // Wrappers that do not change the underlying type.
        Type::Group(group) => return classify_field_type(&group.elem),
        Type::Paren(paren) => return classify_field_type(&paren.elem),
        Type::Path(type_path) => type_path.path.segments.last(),
        _ => None,
    };
    let Some(segment) = segment else {
        return TypeCategory::Other;
    };
    // `Option<T>` filters like `T`; nullability is handled separately through
    // `FieldInfo::is_optional`.
    if segment.ident == "Option"
        && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
        && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
    {
        return classify_field_type(inner);
    }
    match segment.ident.to_string().as_str() {
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
        | "usize" | "f32" | "f64" | "Decimal" | "NaiveDate" | "NaiveDateTime" | "NaiveTime"
        | "DateTime" => TypeCategory::Numeric,
        "String" | "str" => TypeCategory::String,
        "bool" => TypeCategory::Boolean,
        _ => TypeCategory::Other,
    }
}
fn generate_field_module_from_derive(field: &FieldInfo) -> TokenStream {
let field_name = &field.name;
let column_name = &field.column_name;
let ty = &field.ty;
let variant_name = format_ident!("{}", field.name.to_string().to_case(Case::Pascal));
let is_optional = field.is_optional;
let category = classify_field_type(ty);
let mut variants = vec![quote! { Equals(#ty) }, quote! { Not(#ty) }];
if is_optional {
variants.push(quote! { IsNull });
variants.push(quote! { IsNotNull });
}
match category {
TypeCategory::Numeric => {
variants.push(quote! { In(Vec<#ty>) });
variants.push(quote! { NotIn(Vec<#ty>) });
variants.push(quote! { Gt(#ty) });
variants.push(quote! { Gte(#ty) });
variants.push(quote! { Lt(#ty) });
variants.push(quote! { Lte(#ty) });
}
TypeCategory::String => {
variants.push(quote! { In(Vec<#ty>) });
variants.push(quote! { NotIn(Vec<#ty>) });
variants.push(quote! { Contains(String) });
variants.push(quote! { StartsWith(String) });
variants.push(quote! { EndsWith(String) });
}
TypeCategory::Boolean => {}
TypeCategory::Other => {
variants.push(quote! { In(Vec<#ty>) });
variants.push(quote! { NotIn(Vec<#ty>) });
}
}
let mut arms = vec![
quote! { Self::Equals(v) => Filter::Equals(col, v.into()) },
quote! { Self::Not(v) => Filter::NotEquals(col, v.into()) },
];
if is_optional {
arms.push(quote! { Self::IsNull => Filter::IsNull(col) });
arms.push(quote! { Self::IsNotNull => Filter::IsNotNull(col) });
}
match category {
TypeCategory::Numeric => {
arms.push(
quote! { Self::In(vs) => Filter::In(col, vs.into_iter().map(Into::into).collect()) },
);
arms.push(
quote! { Self::NotIn(vs) => Filter::NotIn(col, vs.into_iter().map(Into::into).collect()) },
);
arms.push(quote! { Self::Gt(v) => Filter::Gt(col, v.into()) });
arms.push(quote! { Self::Gte(v) => Filter::Gte(col, v.into()) });
arms.push(quote! { Self::Lt(v) => Filter::Lt(col, v.into()) });
arms.push(quote! { Self::Lte(v) => Filter::Lte(col, v.into()) });
}
TypeCategory::String => {
arms.push(
quote! { Self::In(vs) => Filter::In(col, vs.into_iter().map(Into::into).collect()) },
);
arms.push(
quote! { Self::NotIn(vs) => Filter::NotIn(col, vs.into_iter().map(Into::into).collect()) },
);
arms.push(
quote! { Self::Contains(v) => Filter::Contains(col, FilterValue::String(v)) },
);
arms.push(
quote! { Self::StartsWith(v) => Filter::StartsWith(col, FilterValue::String(v)) },
);
arms.push(
quote! { Self::EndsWith(v) => Filter::EndsWith(col, FilterValue::String(v)) },
);
}
TypeCategory::Boolean => {}
TypeCategory::Other => {
arms.push(
quote! { Self::In(vs) => Filter::In(col, vs.into_iter().map(Into::into).collect()) },
);
arms.push(
quote! { Self::NotIn(vs) => Filter::NotIn(col, vs.into_iter().map(Into::into).collect()) },
);
}
}
let mut ctors = vec![quote! {
pub fn equals(value: #ty) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Equals(value))
}
pub fn not(value: #ty) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Not(value))
}
}];
if is_optional {
ctors.push(quote! {
pub fn is_null() -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::IsNull)
}
pub fn is_not_null() -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::IsNotNull)
}
});
}
match category {
TypeCategory::Numeric => {
ctors.push(quote! {
pub fn in_(values: Vec<#ty>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::In(values))
}
pub fn not_in(values: Vec<#ty>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::NotIn(values))
}
pub fn gt(value: #ty) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Gt(value))
}
pub fn gte(value: #ty) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Gte(value))
}
pub fn lt(value: #ty) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Lt(value))
}
pub fn lte(value: #ty) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Lte(value))
}
});
}
TypeCategory::String => {
ctors.push(quote! {
pub fn in_(values: Vec<#ty>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::In(values))
}
pub fn not_in(values: Vec<#ty>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::NotIn(values))
}
pub fn contains(value: impl Into<String>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::Contains(value.into()))
}
pub fn starts_with(value: impl Into<String>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::StartsWith(value.into()))
}
pub fn ends_with(value: impl Into<String>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::EndsWith(value.into()))
}
});
}
TypeCategory::Boolean => {}
TypeCategory::Other => {
ctors.push(quote! {
pub fn in_(values: Vec<#ty>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::In(values))
}
pub fn not_in(values: Vec<#ty>) -> super::WhereParam {
super::WhereParam::#variant_name(WhereOp::NotIn(values))
}
});
}
}
quote! {
pub mod #field_name {
use super::*;
pub const COLUMN: &str = #column_name;
#[derive(Debug, Clone)]
pub enum WhereOp {
#(#variants,)*
}
impl WhereOp {
pub fn to_filter(self) -> ::prax_query::filter::Filter {
use ::prax_query::filter::{Filter, FilterValue};
use ::std::borrow::Cow;
let col: Cow<'static, str> = Cow::Borrowed(COLUMN);
match self {
#(#arms,)*
}
}
}
#(#ctors)*
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Returns the first field of a named-field struct fixture.
    fn first_field(input: &DeriveInput) -> &syn::Field {
        match &input.data {
            Data::Struct(data) => data
                .fields
                .iter()
                .next()
                .expect("fixture struct has at least one field"),
            _ => panic!("fixture must be a struct"),
        }
    }

    #[test]
    fn test_parse_simple_model() {
        let input: DeriveInput = parse_quote! {
            #[prax(table = "users")]
            struct User {
                #[prax(id, auto)]
                id: i32,
                #[prax(unique)]
                email: String,
                name: Option<String>,
            }
        };
        let result = derive_model_impl(&input);
        assert!(result.is_ok(), "Failed: {:?}", result.err());
        let code = result.unwrap().to_string();
        assert!(code.contains("pub mod user"));
        assert!(code.contains("TABLE_NAME"));
        assert!(code.contains("users"));
    }

    #[test]
    fn test_parse_model_without_id() {
        let input: DeriveInput = parse_quote! {
            struct NoId {
                name: String,
            }
        };
        let result = derive_model_impl(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_rejects_enum_input() {
        let input: DeriveInput = parse_quote! {
            enum Status { Active, Inactive }
        };
        let err = derive_model_impl(&input).unwrap_err();
        assert_eq!(err.to_string(), "Model derive only supports structs");
    }

    #[test]
    fn test_rejects_tuple_and_unit_structs() {
        let tuple: DeriveInput = parse_quote! { struct Pair(i32, i32); };
        let err = derive_model_impl(&tuple).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Model derive only supports structs with named fields"
        );

        let unit: DeriveInput = parse_quote! { struct Marker; };
        let err = derive_model_impl(&unit).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Model derive only supports structs with named fields"
        );
    }

    #[test]
    fn test_field_flags_and_column_override() {
        let input: DeriveInput = parse_quote! {
            struct User {
                #[prax(id, unique, column = "user_id")]
                id: i64,
            }
        };
        let info = parse_field(first_field(&input)).unwrap();
        assert_eq!(info.column_name, "user_id");
        assert!(info.is_id);
        assert!(info.is_unique);
        assert!(!info.is_auto);
        assert!(!info.is_optional);
        assert!(!info.is_list);
        assert!(info.relation.is_none());
    }

    #[test]
    fn test_default_column_is_snake_case() {
        let input: DeriveInput = parse_quote! {
            struct Event {
                #[allow(non_snake_case)]
                createdAt: Option<i64>,
            }
        };
        let info = parse_field(first_field(&input)).unwrap();
        assert_eq!(info.column_name, "created_at");
        assert!(info.is_optional);
    }

    #[test]
    fn test_relation_local_key_defaults_to_id() {
        let input: DeriveInput = parse_quote! {
            struct User {
                #[prax(relation(target = "Post", foreign_key = "author_id"))]
                posts: Vec<Post>,
            }
        };
        let info = parse_field(first_field(&input)).unwrap();
        assert!(info.is_list);
        let rel = info.relation.expect("relation should be parsed");
        assert_eq!(rel.target, "Post");
        assert_eq!(rel.foreign_key, "author_id");
        assert_eq!(rel.local_key, "id");
    }

    #[test]
    fn test_relation_requires_target_and_foreign_key() {
        let no_target: DeriveInput = parse_quote! {
            struct Post {
                #[prax(relation(foreign_key = "author_id"))]
                author: Option<User>,
            }
        };
        let err = parse_field(first_field(&no_target)).unwrap_err();
        assert!(err.to_string().contains("target"));

        let no_fk: DeriveInput = parse_quote! {
            struct Post {
                #[prax(relation(target = "User"))]
                author: Option<User>,
            }
        };
        let err = parse_field(first_field(&no_fk)).unwrap_err();
        assert!(err.to_string().contains("foreign_key"));
    }

    #[test]
    fn test_classify_field_type() {
        let numeric: [Type; 3] = [
            parse_quote!(i64),
            parse_quote!(Option<f64>),
            parse_quote!(DateTime<Utc>),
        ];
        for ty in &numeric {
            assert!(classify_field_type(ty) == TypeCategory::Numeric);
        }
        let string: [Type; 2] = [parse_quote!(String), parse_quote!(Option<String>)];
        for ty in &string {
            assert!(classify_field_type(ty) == TypeCategory::String);
        }
        let boolean: Type = parse_quote!(bool);
        assert!(classify_field_type(&boolean) == TypeCategory::Boolean);
        let other: Type = parse_quote!(Uuid);
        assert!(classify_field_type(&other) == TypeCategory::Other);
    }

    #[test]
    fn test_is_option_type() {
        let ty: Type = parse_quote!(Option<String>);
        assert!(is_option_type(&ty));
        let ty: Type = parse_quote!(String);
        assert!(!is_option_type(&ty));
    }

    #[test]
    fn test_is_vec_type() {
        let ty: Type = parse_quote!(Vec<i32>);
        assert!(is_vec_type(&ty));
        let ty: Type = parse_quote!(i32);
        assert!(!is_vec_type(&ty));
    }
}