use heck::ToUpperCamelCase;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
Expr, Meta, Type, ext::IdentExt, punctuated::Punctuated, spanned::Spanned, token::Comma,
};
use super::from_query_result::{
DeriveFromQueryResult, FromQueryResultItem, ItemType as FqrItemType,
};
use super::into_active_model::{DeriveIntoActiveModel, IntoActiveModelField};
use super::util::GetMeta;
/// Reasons why turning the derive input into a `DerivePartialModel` can fail.
#[derive(Debug)]
enum Error {
/// The input is not a struct with named fields.
InputNotStruct,
/// The entity is needed but no `entity = "..."` attribute was supplied.
EntityNotSpecified,
/// The input has generic parameters; carries the span of those parameters.
NotSupportGeneric(Span),
/// A field uses an unsupported combination of `from_col`, `from_expr` and
/// `nested`; carries the span of the offending field.
OverlappingAttributes(Span),
/// A string attribute value could not be parsed by `syn`.
Syn(syn::Error),
}
/// How a single struct field is selected by the generated partial model.
#[derive(Debug, PartialEq, Eq)]
enum ColumnAs {
/// The field is read from a column of the entity.
Col {
// Explicit column variant taken from `from_col` (converted to upper camel
// case); `None` means the variant is derived from the field name.
col: Option<syn::Ident>,
// Identifier of the struct field being filled.
field: syn::Ident,
},
/// The field is read from the expression given in `from_expr`.
Expr {
// Parsed expression from the `from_expr` attribute.
expr: syn::Expr,
// Identifier of the struct field being filled.
field: syn::Ident,
},
/// The field is marked `nested`; selection is delegated to the field type's
/// own `PartialModelTrait` implementation.
Nested {
// Declared type of the field.
typ: Type,
// Identifier of the struct field being filled.
field: syn::Ident,
// Optional field-level `alias = "..."` value.
alias: Option<String>,
},
/// The field is marked `skip`; no select code is generated for it.
Skip(syn::Ident),
}
/// Intermediate representation of a `DerivePartialModel` input, built by
/// `DerivePartialModel::new` and turned into tokens by `expand`.
struct DerivePartialModel {
// Type parsed from the struct-level `entity = "..."` attribute, if any.
entity: Option<syn::Type>,
// The entity's associated `ActiveModel` type; only set when
// `into_active_model` was requested.
active_model: Option<syn::Type>,
// Value of the struct-level `alias = "..."` attribute, if any.
model_alias: Option<String>,
// Name of the struct the derive is applied to.
ident: syn::Ident,
// One entry per struct field, in declaration order.
fields: Vec<ColumnAs>,
// Defaults to `true`; turned off by `from_query_result = "false"`.
from_query_result: bool,
// Set when the struct-level `into_active_model` flag is present.
into_active_model: bool,
}
impl DerivePartialModel {
/// Parses the derive input into a `DerivePartialModel`.
///
/// Struct-level `#[sea_orm(...)]` options read here:
/// - `entity = "..."`: type of the entity the columns belong to,
/// - `alias = "..."`: alias stored as `model_alias`,
/// - `from_query_result = "false"`: disables the `FromQueryResult` impl,
/// - `into_active_model`: additionally requests an `IntoActiveModel` impl.
///
/// Field-level `#[sea_orm(...)]` options read here: `skip`, `nested`,
/// `from_col = "..."`, `from_expr = "..."` and `alias = "..."`. `skip` only
/// takes effect when none of `from_col`, `from_expr` and `nested` is present.
///
/// `#[sea_orm(...)]` attributes whose arguments do not parse as a
/// comma-separated list of metas are ignored.
///
/// # Errors
///
/// - `Error::NotSupportGeneric` if the struct has generic parameters.
/// - `Error::InputNotStruct` if the input is not a struct with named fields.
/// - `Error::EntityNotSpecified` if a column-backed field or
///   `into_active_model` is used without an `entity`.
/// - `Error::OverlappingAttributes` if a field combines more than one of
///   `from_col`, `from_expr` and `nested`.
/// - `Error::Syn` if an `entity` or `from_expr` value fails to parse.
fn new(input: syn::DeriveInput) -> Result<Self, Error> {
    if !input.generics.params.is_empty() {
        return Err(Error::NotSupportGeneric(input.generics.params.span()));
    }

    let fields = match input.data {
        syn::Data::Struct(syn::DataStruct {
            fields: syn::Fields::Named(syn::FieldsNamed { named, .. }),
            ..
        }) => named,
        _ => return Err(Error::InputNotStruct),
    };

    let mut entity = None;
    let mut entity_string = String::new();
    let mut model_alias = None;
    let mut from_query_result = true;
    let mut into_active_model = false;

    for attr in input.attrs.iter() {
        if !attr.path().is_ident("sea_orm") {
            continue;
        }
        let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated) else {
            continue;
        };
        for meta in list {
            if let Some(s) = meta.get_as_kv("entity") {
                entity = Some(syn::parse_str::<syn::Type>(&s).map_err(Error::Syn)?);
                entity_string = s;
            } else if let Some(s) = meta.get_as_kv("alias") {
                model_alias = Some(s);
            } else if let Some(s) = meta.get_as_kv("from_query_result") {
                if s == "false" {
                    from_query_result = false;
                }
            } else if meta.exists("into_active_model") {
                into_active_model = true;
            }
        }
    }

    let active_model = if into_active_model {
        // The active model is derived from the entity. Without one, the path
        // below would be the unparsable `< as ...>::ActiveModel`, so report the
        // missing entity explicitly instead of an opaque parse error.
        if entity.is_none() {
            return Err(Error::EntityNotSpecified);
        }
        // Fully qualify the trait, like the rest of the generated code, so the
        // caller does not need `EntityTrait` in scope.
        Some(
            syn::parse_str::<syn::Type>(&format!(
                "<{entity_string} as sea_orm::EntityTrait>::ActiveModel"
            ))
            .map_err(Error::Syn)?,
        )
    } else {
        None
    };

    let mut column_as_list = Vec::with_capacity(fields.len());

    for field in fields {
        let field_span = field.span();

        let mut from_col = None;
        let mut from_expr = None;
        let mut nested = false;
        let mut nested_alias = None;
        let mut skip = false;

        for attr in field.attrs.iter() {
            if !attr.path().is_ident("sea_orm") {
                continue;
            }
            let Ok(list) = attr.parse_args_with(Punctuated::<Meta, Comma>::parse_terminated)
            else {
                continue;
            };
            for meta in list.iter() {
                if meta.exists("skip") {
                    skip = true;
                } else if meta.exists("nested") {
                    nested = true;
                } else if let Some(s) = meta.get_as_kv("from_col") {
                    from_col = Some(format_ident!("{}", s.to_upper_camel_case()));
                } else if let Some(s) = meta.get_as_kv("from_expr") {
                    from_expr = Some(syn::parse_str::<Expr>(&s).map_err(Error::Syn)?);
                } else if let Some(s) = meta.get_as_kv("alias") {
                    nested_alias = Some(s);
                }
            }
        }

        // Only named fields reach this point, so the identifier is present.
        let field_name = field
            .ident
            .expect("fields of a named struct always have an identifier");

        let col_as = match (from_col, from_expr, nested) {
            (Some(col), None, false) => {
                if entity.is_none() {
                    return Err(Error::EntityNotSpecified);
                }
                ColumnAs::Col {
                    col: Some(col),
                    field: field_name,
                }
            }
            (None, Some(expr), false) => ColumnAs::Expr {
                expr,
                field: field_name,
            },
            (None, None, true) => ColumnAs::Nested {
                typ: field.ty,
                field: field_name,
                alias: nested_alias,
            },
            // A skipped field generates no select code and never touches the
            // entity, so it must not require one.
            (None, None, false) if skip => ColumnAs::Skip(field_name),
            (None, None, false) => {
                if entity.is_none() {
                    return Err(Error::EntityNotSpecified);
                }
                ColumnAs::Col {
                    col: None,
                    field: field_name,
                }
            }
            (_, _, _) => return Err(Error::OverlappingAttributes(field_span)),
        };
        column_as_list.push(col_as);
    }

    Ok(Self {
        entity,
        active_model,
        model_alias,
        ident: input.ident,
        fields: column_as_list,
        from_query_result,
        into_active_model,
    })
}
/// Produces every impl requested by the derive.
///
/// The `PartialModelTrait` impl is always emitted. The `FromQueryResult`
/// impl is emitted when `from_query_result` is set, and the
/// `IntoActiveModel` impl when `into_active_model` is set; both are delegated
/// to the respective derive helpers.
fn expand(&self) -> syn::Result<TokenStream> {
    let partial_model_tokens = self.impl_partial_model();

    let from_query_result_tokens = if self.from_query_result {
        let items = self.fields.iter().map(|col_as| {
            // Classify the field and pick out its identifier in one pass.
            let (typ, field) = match col_as {
                ColumnAs::Col { field, .. } | ColumnAs::Expr { field, .. } => {
                    (FqrItemType::Flat, field)
                }
                ColumnAs::Nested { field, .. } => (FqrItemType::Nested, field),
                ColumnAs::Skip(field) => (FqrItemType::Skip, field),
            };
            FromQueryResultItem {
                typ,
                ident: field.to_owned(),
                alias: None,
            }
        });

        DeriveFromQueryResult {
            ident: self.ident.clone(),
            generics: Default::default(),
            fields: items.collect(),
        }
        .impl_from_query_result(true)
    } else {
        quote!()
    };

    let into_active_model_tokens = if self.into_active_model {
        // Nested and skipped fields are left out of the active model.
        let normal_fields = self.fields.iter().filter_map(|col_as| match col_as {
            ColumnAs::Col { field, .. } | ColumnAs::Expr { field, .. } => {
                Some(IntoActiveModelField::Normal(field.clone()))
            }
            ColumnAs::Nested { .. } | ColumnAs::Skip(_) => None,
        });

        DeriveIntoActiveModel {
            ident: self.ident.clone(),
            active_model: self.active_model.clone(),
            fields: normal_fields.collect(),
            set_fields: Vec::new(),
            exhaustive: false,
        }
        .impl_into_active_model()
    } else {
        quote!()
    };

    Ok(quote! {
        #partial_model_tokens
        #from_query_result_tokens
        #into_active_model_tokens
    })
}
/// Builds the `impl sea_orm::PartialModelTrait for <struct>` block.
///
/// The generated `select_cols_nested` rebinds the `select` value once per
/// field, in declaration order, and returns it at the end:
/// - column fields call `QuerySelect::column_as` with the entity column,
///   qualified by the runtime `nested_alias` if given, otherwise by the
///   struct-level alias if one was configured;
/// - expression fields call `QuerySelect::column_as` with the expression;
/// - nested fields delegate to the field type's `select_cols_nested`, passing
///   a prefix made of the current prefix plus `<field>_`;
/// - skipped fields emit nothing.
///
/// Result aliases are the unraw field name, prefixed with `pre` when present.
fn impl_partial_model(&self) -> TokenStream {
    let select_ident = format_ident!("select");
    let model_ident = &self.ident;

    let mut select_stmts: Vec<TokenStream> = Vec::with_capacity(self.fields.len());

    for col_as in &self.fields {
        let stmt = match col_as {
            ColumnAs::Col { col, field } => {
                let field_name = field.unraw().to_string();
                // `new` only builds `Col` entries when an entity is present.
                let entity_ty = self.entity.as_ref().unwrap();
                let variant = match col {
                    Some(explicit) => explicit.clone(),
                    None => format_ident!("{}", field_name.to_upper_camel_case()),
                };
                let column = quote! {
                    <#entity_ty as sea_orm::EntityTrait>::Column::#variant
                };

                // Code used when no runtime `nested_alias` is supplied.
                let without_nested_alias = if let Some(model_alias) = &self.model_alias {
                    quote! {
                        let col_expr = sea_orm::sea_query::Expr::col((#model_alias, #column));
                        let casted = sea_orm::ColumnTrait::select_as(&#column, col_expr);
                        sea_orm::QuerySelect::column_as(#select_ident, casted, col_alias)
                    }
                } else {
                    quote! {
                        sea_orm::QuerySelect::column_as(#select_ident, #column, col_alias)
                    }
                };

                quote! {
                    let #select_ident = {
                        let col_alias = pre.map_or(#field_name.to_string(), |pre| format!("{pre}{}", #field_name));
                        if let Some(nested_alias) = nested_alias {
                            let alias = sea_orm::sea_query::SeaRc::new(nested_alias);
                            let col_expr = sea_orm::sea_query::Expr::col(
                                (alias, #column)
                            );
                            let casted = sea_orm::ColumnTrait::select_as(&#column, col_expr);
                            sea_orm::QuerySelect::column_as(#select_ident, casted, col_alias)
                        } else {
                            #without_nested_alias
                        }
                    };
                }
            }
            ColumnAs::Expr { expr, field } => {
                let field_name = field.unraw().to_string();
                quote! {
                    let #select_ident =
                        if let Some(prefix) = pre {
                            let ident = format!("{prefix}{}", #field_name);
                            sea_orm::QuerySelect::column_as(#select_ident, #expr, ident)
                        } else {
                            sea_orm::QuerySelect::column_as(#select_ident, #expr, #field_name)
                        };
                }
            }
            ColumnAs::Nested { typ, field, alias } => {
                let field_name = field.unraw().to_string();
                // The field-level alias is forwarded as a literal `Option`.
                let alias_arg = match alias.as_deref() {
                    Some(s) => quote! { Some(#s) },
                    None => quote! { None },
                };
                quote! {
                    let #select_ident =
                        <#typ as sea_orm::PartialModelTrait>::select_cols_nested(#select_ident,
                            Some(&if let Some(prefix) = pre {
                                    format!("{prefix}{}_", #field_name)
                                } else {
                                    format!("{}_", #field_name)
                                }
                            ),
                            #alias_arg
                        );
                }
            }
            ColumnAs::Skip(_) => quote!(),
        };
        select_stmts.push(stmt);
    }

    quote! {
        #[automatically_derived]
        impl sea_orm::PartialModelTrait for #model_ident {
            fn select_cols_nested<S: sea_orm::QuerySelect>(#select_ident: S, pre: Option<&str>, nested_alias: Option<&'static str>) -> S {
                #(#select_stmts)*
                #select_ident
            }
        }
    }
}
}
/// Entry point of the `DerivePartialModel` derive.
///
/// Parses `input` and, on success, returns the generated impls. Input errors
/// are reported as a `compile_error!` invocation placed at the relevant span
/// (the generic parameters, the offending field, or the struct name), so the
/// function still returns `Ok` for them.
///
/// # Errors
///
/// Returns the underlying `syn::Error` when an attribute value fails to parse,
/// or whatever error `DerivePartialModel::expand` produces.
pub fn expand_derive_partial_model(input: syn::DeriveInput) -> syn::Result<TokenStream> {
    // Captured up front because `input` is consumed by `new`.
    let ident_span = input.ident.span();

    match DerivePartialModel::new(input) {
        Ok(partial_model) => partial_model.expand(),
        Err(Error::NotSupportGeneric(span)) => Ok(quote_spanned! {
            span => compile_error!("you can only derive `DerivePartialModel` on concrete struct");
        }),
        Err(Error::OverlappingAttributes(span)) => Ok(quote_spanned! {
            span => compile_error!("you can only use one of `from_col`, `from_expr`, `nested`");
        }),
        // The trailing `;` is required: in item position a parenthesised macro
        // call without it is itself a syntax error, which would add a second,
        // unrelated diagnostic to the intended message.
        Err(Error::EntityNotSpecified) => Ok(quote_spanned! {
            ident_span => compile_error!("you need specific which entity you are using");
        }),
        Err(Error::InputNotStruct) => Ok(quote_spanned! {
            ident_span => compile_error!("you can only derive `DerivePartialModel` on named struct");
        }),
        Err(Error::Syn(err)) => Err(err),
    }
}
#[cfg(test)]
mod test {
    use quote::format_ident;
    use syn::{DeriveInput, Type, parse_str};

    use super::{ColumnAs, DerivePartialModel};

    type StdResult<T> = Result<T, Box<dyn std::error::Error>>;

    // Covers a default column, an explicit `from_col` and a `from_expr` field.
    const CODE_SNIPPET_1: &str = r#"
#[sea_orm(entity = "Entity")]
struct PartialModel {
default_field: i32,
#[sea_orm(from_col = "bar")]
alias_field: i32,
#[sea_orm(from_expr = "Expr::val(1).add(1)")]
expr_field : i32
}
"#;

    /// Each supported field kind is parsed into the matching `ColumnAs`.
    #[test]
    fn test_load_macro_input_1() -> StdResult<()> {
        let derive_input = parse_str::<DeriveInput>(CODE_SNIPPET_1)?;
        let parsed = DerivePartialModel::new(derive_input).unwrap();

        let expected_fields = vec![
            ColumnAs::Col {
                col: None,
                field: format_ident!("default_field"),
            },
            ColumnAs::Col {
                col: Some(format_ident!("Bar")),
                field: format_ident!("alias_field"),
            },
            ColumnAs::Expr {
                expr: parse_str::<syn::Expr>("Expr::val(1).add(1)")?,
                field: format_ident!("expr_field"),
            },
        ];

        assert_eq!(parsed.entity, Some(parse_str::<Type>("Entity")?));
        assert_eq!(parsed.ident, format_ident!("PartialModel"));
        assert_eq!(parsed.fields, expected_fields);
        assert!(parsed.from_query_result);

        Ok(())
    }

    // Covers opting out of the `FromQueryResult` impl.
    const CODE_SNIPPET_2: &str = r#"
#[sea_orm(entity = "MyEntity", from_query_result = "false")]
struct PartialModel {
default_field: i32,
}
"#;

    /// `from_query_result = "false"` clears the flag and leaves fields intact.
    #[test]
    fn test_load_macro_input_2() -> StdResult<()> {
        let derive_input = parse_str::<DeriveInput>(CODE_SNIPPET_2)?;
        let parsed = DerivePartialModel::new(derive_input).unwrap();

        let expected_fields = vec![ColumnAs::Col {
            col: None,
            field: format_ident!("default_field"),
        }];

        assert_eq!(parsed.entity, Some(parse_str::<Type>("MyEntity")?));
        assert_eq!(parsed.ident, format_ident!("PartialModel"));
        assert_eq!(parsed.fields, expected_fields);
        assert!(!parsed.from_query_result);

        Ok(())
    }
}