use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Field, Fields, Meta};
enum ActionType {
Summary,
Queryable,
Action(String),
ActionById(String),
}
fn parse_action_attr(attr: &Attribute) -> Vec<ActionType> {
let mut types = Vec::new();
if let Meta::List(meta_list) = attr.meta.clone() {
if meta_list.path.is_ident("api_model") {
let mut selected_action = ActionType::Summary;
let mut selected_at = 0;
let mut i = 0;
for nested in meta_list.tokens.clone() {
if let proc_macro2::TokenTree::Ident(iden) = nested {
let id = iden.to_string();
match id.as_str() {
"summary" => {
types.push(ActionType::Summary);
}
"queryable" => {
types.push(ActionType::Queryable);
}
"action" => {
selected_at = i;
selected_action = ActionType::Action("".to_string());
}
"action_by_id" => {
selected_at = i;
selected_action = ActionType::ActionById("".to_string());
}
_ => {
if selected_at == (i - 2) {
match &selected_action {
ActionType::Action(_) => {
types.push(ActionType::Action(id.to_string()));
}
ActionType::ActionById(_) => {
types.push(ActionType::ActionById(id.to_string()));
}
_ => {}
}
} else {
panic!("Unexpected attribute key: {}", id);
}
}
}
}
i += 1;
}
}
}
types
}
pub fn api_model_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_cloned = item.clone();
let input = parse_macro_input!(item as DeriveInput);
let struct_name = &input.ident;
let fields = match &input.data {
Data::Struct(data_struct) => &data_struct.fields,
_ => panic!("api_mode can only be applied to structs"),
};
let attr_args = attr.to_string();
let mut iter_type = "CommonQueryResponse".to_string();
let mut base_endpoint = String::new();
for arg in attr_args.split(',') {
let parts: Vec<&str> = arg.split('=').collect();
if parts.len() == 2 {
let key = parts[0].trim();
let value = parts[1].trim().trim_matches('"');
match key {
"base" => base_endpoint = value.to_string(),
"iter_type" => iter_type = value.to_string(),
_ => panic!("Unexpected attribute key: {}", key),
}
}
}
let mut summary_fields = Vec::new();
let mut queryable_fields = Vec::new();
let mut action_names = HashMap::<String, Vec<Field>>::new();
let mut action_by_id_names = HashMap::<String, Vec<Field>>::new();
if let Fields::Named(named_fields) = fields {
for field in &named_fields.named {
for attr in &field.attrs {
for t in parse_action_attr(attr) {
match t {
ActionType::Summary => {
summary_fields.push(field.clone());
}
ActionType::Queryable => {
queryable_fields.push(field.clone());
}
ActionType::Action(action_name) => {
action_names
.entry(action_name)
.or_insert_with(Vec::new)
.push(field.clone());
}
ActionType::ActionById(action_name) => {
action_by_id_names
.entry(action_name)
.or_insert_with(Vec::new)
.push(field.clone());
}
}
}
}
}
}
let summary_struct = generate_summary_struct(&struct_name, &summary_fields);
let query_struct = generate_query_struct(&struct_name, &queryable_fields);
let action_struct = generate_action_struct(&struct_name, &action_names, "Action");
let action_by_id_struct =
generate_action_struct(&struct_name, &action_by_id_names, "ActionById");
let client_impl = generate_client_impl(struct_name, &base_endpoint, &iter_type);
let input = parse_macro_input!(input_cloned as syn::ItemStruct);
let stripped_input = strip_struct_attributes(&input);
let output = quote! {
#stripped_input
#action_struct
#action_by_id_struct
#summary_struct
#query_struct
#client_impl
};
output.into()
}
fn generate_action_struct(
struct_name: &syn::Ident,
actions: &HashMap<String, Vec<Field>>,
action_type: &str,
) -> proc_macro2::TokenStream {
let action_name = syn::Ident::new(
&format!("{}{}Request", struct_name, action_type),
struct_name.span(),
);
let mut action_fields = vec![];
let mut action_requests = vec![];
for (k, v) in actions.iter() {
let act = syn::Ident::new(&k.to_case(Case::Pascal), struct_name.span());
let request_struct_name = syn::Ident::new(
&format!("{}{}Request", struct_name, act),
struct_name.span(),
);
let fields = v.iter().map(|field| {
let field_name = &field.ident;
let field_type = &field.ty;
quote! { pub #field_name: #field_type, }
});
action_fields.push(quote! {
#act(#request_struct_name),
});
action_requests.push(quote! {
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default, Eq, PartialEq)]
#[cfg_attr(feature = "server", derive(schemars::JsonSchema, aide::OperationIo))]
pub struct #request_struct_name {
#(#fields)*
}
});
}
quote! {
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
#[cfg_attr(feature = "server", derive(schemars::JsonSchema, aide::OperationIo))]
pub enum #action_name {
#(#action_fields)*
}
#(#action_requests)*
}
}
fn generate_summary_struct(
struct_name: &syn::Ident,
summary_fields: &[syn::Field],
) -> proc_macro2::TokenStream {
let summary_name = syn::Ident::new(&format!("{}Summary", struct_name), struct_name.span());
let fields = summary_fields.iter().map(|field| {
let field_name = &field.ident;
let field_type = &field.ty;
quote! { pub #field_name: #field_type, }
});
quote! {
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default, Eq, PartialEq)]
#[cfg_attr(feature = "server", derive(schemars::JsonSchema, aide::OperationIo))]
pub struct #summary_name {
#(#fields)*
}
}
}
fn generate_query_struct(
struct_name: &syn::Ident,
queryable_fields: &[syn::Field],
) -> proc_macro2::TokenStream {
let query_fields = queryable_fields.iter().map(|field| {
let field_name = &field.ident;
let field_type = &field.ty;
quote! { pub #field_name: Option<#field_type>, }
});
let query_name = syn::Ident::new(&format!("{}Query", struct_name), struct_name.span());
quote! {
#[derive(Debug, Clone, Serialize, Deserialize, Default, Eq, PartialEq, by_macros::QueryDisplay)]
#[cfg_attr(feature = "server", derive(schemars::JsonSchema, aide::OperationIo))]
pub struct #query_name {
pub size: usize,
pub bookmark: Option<String>,
#(#query_fields)*
}
}
}
fn generate_client_impl(
struct_name: &syn::Ident,
base_endpoint: &str,
iter_type: &str,
) -> proc_macro2::TokenStream {
let client_name = syn::Ident::new(&format!("{}Client", struct_name), struct_name.span());
let query_name = syn::Ident::new(&format!("{}Query", struct_name), struct_name.span());
let summary_name = syn::Ident::new(&format!("{}Summary", struct_name), struct_name.span());
let base_endpoint_lit = syn::LitStr::new(base_endpoint, struct_name.span());
let iter_type_with_summary = format!("{}<{}>", iter_type, summary_name);
let iter_type_tokens: proc_macro2::TokenStream = iter_type_with_summary.parse().unwrap();
quote! {
impl #struct_name {
pub fn get_client(endpoint: &str) -> #client_name {
#client_name { endpoint: endpoint.to_string() }
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, Eq, PartialEq)]
pub struct #client_name {
pub endpoint: String,
}
impl #client_name {
pub async fn query(
&self,
params: #query_name,
) -> crate::Result<#iter_type_tokens> {
let endpoint = format!("{}{}", self.endpoint, #base_endpoint_lit);
let query = format!("{}?{}", endpoint, params);
rest_api::get(&query).await
}
pub async fn get(&self, id: &str) -> crate::Result<#struct_name> {
let endpoint = format!("{}{}/{}", self.endpoint, #base_endpoint_lit, id);
rest_api::get(&endpoint).await
}
}
}
}
fn strip_struct_attributes(input: &syn::ItemStruct) -> syn::ItemStruct {
let mut stripped_struct = input.clone();
stripped_struct.fields = strip_api_model_attributes(&input.fields);
stripped_struct
}
fn strip_api_model_attributes(fields: &syn::Fields) -> syn::Fields {
match fields {
syn::Fields::Named(named_fields) => {
let stripped_fields = named_fields
.named
.iter()
.map(|field| {
let mut field = field.clone();
field.attrs = field
.attrs
.iter()
.filter(|attr| !attr.path().is_ident("api_model"))
.cloned()
.collect();
field
})
.collect();
syn::Fields::Named(syn::FieldsNamed {
brace_token: named_fields.brace_token,
named: stripped_fields,
})
}
syn::Fields::Unnamed(_) | syn::Fields::Unit => fields.clone(),
}
}