use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, Ident, LitStr, Type};
pub fn derive_model_impl(input: &DeriveInput) -> Result<TokenStream, syn::Error> {
let name = &input.ident;
let module_name = format_ident!("{}", name.to_string().to_case(Case::Snake));
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return Err(syn::Error::new_spanned(
input,
"Model derive only supports structs with named fields",
));
}
},
_ => {
return Err(syn::Error::new_spanned(
input,
"Model derive only supports structs",
));
}
};
let struct_attrs = parse_struct_attrs(input)?;
let table_name = struct_attrs
.table_name
.unwrap_or_else(|| name.to_string().to_case(Case::Snake));
let field_infos: Vec<FieldInfo> = fields.iter().map(parse_field).collect::<Result<_, _>>()?;
let pk_fields: Vec<_> = field_infos
.iter()
.filter(|f| f.is_id)
.map(|f| f.column_name.as_str())
.collect();
if pk_fields.is_empty() {
return Err(syn::Error::new_spanned(
input,
"Model must have at least one field marked with #[prax(id)]",
));
}
let field_modules: Vec<_> = field_infos
.iter()
.map(generate_field_module_from_derive)
.collect();
let where_variants: Vec<_> = field_infos
.iter()
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
let field_mod = &f.name;
quote! { #variant_name(#field_mod::WhereOp) }
})
.collect();
let select_variants: Vec<_> = field_infos
.iter()
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
quote! { #variant_name }
})
.collect();
let order_variants: Vec<_> = field_infos
.iter()
.filter(|f| !f.is_list)
.map(|f| {
let variant_name = format_ident!("{}", f.name.to_string().to_case(Case::Pascal));
quote! { #variant_name(super::_prax_prelude::SortOrder) }
})
.collect();
Ok(quote! {
pub mod #module_name {
use super::*;
pub const TABLE_NAME: &str = #table_name;
pub const PRIMARY_KEY: &[&str] = &[#(#pk_fields),*];
impl super::_prax_prelude::PraxModel for #name {
const TABLE_NAME: &'static str = TABLE_NAME;
const PRIMARY_KEY: &'static [&'static str] = PRIMARY_KEY;
}
#(#field_modules)*
#[derive(Debug, Clone)]
pub enum WhereParam {
#(#where_variants,)*
And(Vec<WhereParam>),
Or(Vec<WhereParam>),
Not(Box<WhereParam>),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SelectParam {
#(#select_variants,)*
}
#[derive(Debug, Clone, Copy)]
pub enum OrderByParam {
#(#order_variants,)*
}
#[derive(Debug, Default)]
pub struct Query {
pub select: Vec<SelectParam>,
pub where_: Vec<WhereParam>,
pub order_by: Vec<OrderByParam>,
pub skip: Option<usize>,
pub take: Option<usize>,
}
impl Query {
pub fn new() -> Self {
Self::default()
}
pub fn r#where(mut self, param: WhereParam) -> Self {
self.where_.push(param);
self
}
pub fn order_by(mut self, param: OrderByParam) -> Self {
self.order_by.push(param);
self
}
pub fn skip(mut self, n: usize) -> Self {
self.skip = Some(n);
self
}
pub fn take(mut self, n: usize) -> Self {
self.take = Some(n);
self
}
}
pub struct Actions;
impl Actions {
pub fn find_many() -> Query {
Query::new()
}
pub fn find_unique() -> Query {
Query::new().take(1)
}
pub fn find_first() -> Query {
Query::new().take(1)
}
}
}
})
}
/// Struct-level options gathered from the `#[prax(...)]` attributes on a model.
///
/// Both fields stay `None` when the corresponding key is absent.
#[derive(Debug, Default)]
struct StructAttrs {
/// Value of `#[prax(table = "...")]`; used by `derive_model_impl` to override
/// the default table name.
table_name: Option<String>,
/// Value of `#[prax(schema = "...")]`. It is parsed and stored, but no code
/// visible in this file reads it afterwards.
schema_name: Option<String>,
}
/// Reads the struct-level `#[prax(...)]` attributes of `input`.
///
/// Recognised keys are `table = "..."` and `schema = "..."`; both expect a
/// string literal. Attributes whose path is not `prax` are skipped, and keys
/// other than the two above are left untouched by this function. When a key
/// appears more than once the last occurrence wins.
///
/// # Errors
///
/// Propagates the `syn::Error` raised by `parse_nested_meta`, for instance when
/// `table` or `schema` is not followed by `= "<string literal>"`.
fn parse_struct_attrs(input: &DeriveInput) -> Result<StructAttrs, syn::Error> {
    let mut parsed = StructAttrs::default();

    let prax_attrs = input
        .attrs
        .iter()
        .filter(|candidate| candidate.path().is_ident("prax"));

    for prax_attr in prax_attrs {
        prax_attr.parse_nested_meta(|meta| {
            // Pick the destination first; unrecognised keys end the callback early.
            let slot = if meta.path.is_ident("table") {
                &mut parsed.table_name
            } else if meta.path.is_ident("schema") {
                &mut parsed.schema_name
            } else {
                return Ok(());
            };

            let literal: LitStr = meta.value()?.parse()?;
            *slot = Some(literal.value());
            Ok(())
        })?;
    }

    Ok(parsed)
}
/// Everything the code generator knows about one named field of the model.
///
/// Built by `parse_field`. Several flags are recorded but not read by the code
/// visible in this file, which is presumably why dead-code warnings are
/// silenced here.
#[derive(Debug)]
#[allow(dead_code)]
struct FieldInfo {
/// The field identifier as written in the struct.
name: Ident,
/// The declared type of the field.
ty: Type,
/// Column name: derived from the field name in snake_case unless overridden
/// with `#[prax(column = "...")]`.
column_name: String,
/// Set by `#[prax(id)]`; such fields make up `PRIMARY_KEY`.
is_id: bool,
/// Set by `#[prax(auto)]`.
is_auto: bool,
/// Set by `#[prax(unique)]`.
is_unique: bool,
/// Result of `is_option_type` on the field type.
is_optional: bool,
/// Result of `is_vec_type` on the field type; list fields get no
/// `OrderByParam` variant.
is_list: bool,
}
fn parse_field(field: &syn::Field) -> Result<FieldInfo, syn::Error> {
let name = field
.ident
.clone()
.ok_or_else(|| syn::Error::new_spanned(field, "Fields must be named"))?;
let ty = field.ty.clone();
let mut column_name = name.to_string().to_case(Case::Snake);
let mut is_id = false;
let mut is_auto = false;
let mut is_unique = false;
let is_optional = is_option_type(&ty);
let is_list = is_vec_type(&ty);
for attr in &field.attrs {
if !attr.path().is_ident("prax") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("id") {
is_id = true;
} else if meta.path.is_ident("auto") {
is_auto = true;
} else if meta.path.is_ident("unique") {
is_unique = true;
} else if meta.path.is_ident("column") {
let value: LitStr = meta.value()?.parse()?;
column_name = value.value();
}
Ok(())
})?;
}
Ok(FieldInfo {
name,
ty,
column_name,
is_id,
is_auto,
is_unique,
is_optional,
is_list,
})
}
/// Returns `true` when `ty` is syntactically an `Option<...>`.
///
/// The check looks at the *last* segment of a plain (non-`<T as Trait>::`)
/// path, so `Option<T>`, `option::Option<T>` and `std::option::Option<T>` are
/// all recognised. Parenthesised types and the invisible groups that
/// `macro_rules!` wraps around `$ty:ty` fragments are looked through.
///
/// This is a purely syntactic heuristic: a type alias for `Option` is not
/// detected, and an unrelated type that happens to be named `Option` is.
fn is_option_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) if type_path.qself.is_none() => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "Option"),
        Type::Group(group) => is_option_type(&group.elem),
        Type::Paren(paren) => is_option_type(&paren.elem),
        _ => false,
    }
}
/// Returns `true` when `ty` is syntactically a `Vec<...>`.
///
/// The check looks at the *last* segment of a plain (non-`<T as Trait>::`)
/// path, so `Vec<T>`, `vec::Vec<T>` and `std::vec::Vec<T>` are all recognised.
/// Parenthesised types and the invisible groups that `macro_rules!` wraps
/// around `$ty:ty` fragments are looked through.
///
/// This is a purely syntactic heuristic: a type alias for `Vec` is not
/// detected, and an unrelated type that happens to be named `Vec` is.
fn is_vec_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) if type_path.qself.is_none() => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "Vec"),
        Type::Group(group) => is_vec_type(&group.elem),
        Type::Paren(paren) => is_vec_type(&paren.elem),
        _ => false,
    }
}
fn generate_field_module_from_derive(field: &FieldInfo) -> TokenStream {
let field_name = &field.name;
let column_name = &field.column_name;
let ty = &field.ty;
let where_ops = quote! {
#[derive(Debug, Clone)]
pub enum WhereOp {
Equals(#ty),
Not(#ty),
IsNull,
IsNotNull,
}
pub fn equals(value: #ty) -> super::WhereParam {
super::WhereParam::#field_name(WhereOp::Equals(value))
}
pub fn not(value: #ty) -> super::WhereParam {
super::WhereParam::#field_name(WhereOp::Not(value))
}
};
quote! {
pub mod #field_name {
use super::*;
pub const COLUMN: &str = #column_name;
#where_ops
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    // A well-formed model expands successfully and mentions the expected names.
    #[test]
    fn test_parse_simple_model() {
        let model: DeriveInput = parse_quote! {
            #[prax(table = "users")]
            struct User {
                #[prax(id, auto)]
                id: i32,
                #[prax(unique)]
                email: String,
                name: Option<String>,
            }
        };

        let expansion = derive_model_impl(&model);
        assert!(expansion.is_ok(), "Failed: {:?}", expansion.err());

        let rendered = expansion.unwrap().to_string();
        assert!(rendered.contains("pub mod user"));
        assert!(rendered.contains("TABLE_NAME"));
        assert!(rendered.contains("users"));
    }

    // A model with no `#[prax(id)]` field is rejected.
    #[test]
    fn test_parse_model_without_id() {
        let model: DeriveInput = parse_quote! {
            struct NoId {
                name: String,
            }
        };

        assert!(derive_model_impl(&model).is_err());
    }

    // `Option<_>` is detected, a plain type is not.
    #[test]
    fn test_is_option_type() {
        let optional: Type = parse_quote!(Option<String>);
        let plain: Type = parse_quote!(String);

        assert!(is_option_type(&optional));
        assert!(!is_option_type(&plain));
    }

    // `Vec<_>` is detected, a plain type is not.
    #[test]
    fn test_is_vec_type() {
        let list: Type = parse_quote!(Vec<i32>);
        let scalar: Type = parse_quote!(i32);

        assert!(is_vec_type(&list));
        assert!(!is_vec_type(&scalar));
    }
}