use inflector::Inflector;
use proc_macro::TokenStream;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::{Attribute, Field, Fields, ItemStruct, Visibility};
use crate::{DEFAULT_CREATE_TIME_NAME, DEFAULT_ID_NAME, DEFAULT_UPDATE_TIME_NAME};
pub fn impl_sql_helper(ast: &ItemStruct) -> TokenStream {
let mut filed_vec = Vec::new();
ast.fields.iter().for_each(|field| {
if is_vis_public_crate(&field.vis)
&& is_base_type(field)
&& !field_attr_exists(field, DEFAULT_ID_NAME)
{
if let Some(ident) = &field.ident {
if ident == DEFAULT_ID_NAME {
return;
}
}
filed_vec.push(field);
}
});
let table_filed_name_vec = filed_vec
.iter()
.map(|field| get_table_field_name(field))
.collect::<Vec<_>>();
let struct_name = &ast.ident;
let self_ident = format_ident!("self");
let struct_var_name = format_ident!("{}", struct_name.to_string().to_snake_case());
let id = get_ident(&ast.fields, DEFAULT_ID_NAME);
let create_time = get_ident(&ast.fields, DEFAULT_CREATE_TIME_NAME);
let update_time = get_ident(&ast.fields, DEFAULT_UPDATE_TIME_NAME);
let pool = quote!(&*db::POOL);
let query = quote!(sqlx::query);
let query_as = quote!(sqlx::query_as::<_, Self>);
let select_base_sql = format!(
"SELECT {}, {} FROM {}",
id,
table_filed_name_vec.join(", ").trim_end(),
struct_var_name
);
let count_base_sql = format!("SELECT count(1) FROM {}", struct_var_name);
let find_sql = format!("{} WHERE {} = ?", select_base_sql, id);
let find_fn = quote!(
pub async fn find(#id: i32) -> Result<Self, sqlx::Error> {
#query_as(#find_sql)
.bind(#id)
.fetch_one(#pool)
.await
}
);
let list_fn = quote!(
pub async fn list() -> Result<Vec<Self>, sqlx::Error> {
#query_as(#select_base_sql)
.fetch_all(#pool)
.await
}
);
let delete_sql = format!("DELETE FROM {} WHERE {} = ?", struct_var_name, id);
let delete_fn = quote!(
pub async fn delete(&self) -> Result<bool, sqlx::Error> {
#query(#delete_sql)
.bind(self.#id)
.execute(#pool)
.await
.map(|f| f.rows_affected() > 0)
}
);
let insert_sql = format!(
"INSERT INTO {} ({}) VALUES({})",
struct_var_name,
table_filed_name_vec.join(", ").trim_end(),
"?, "
.repeat(filed_vec.len())
.trim_end()
.trim_end_matches(',')
);
let insert_bind_quote_vec = fileds_to_bind_quote(&self_ident, &filed_vec);
let insert_auto_time_quote = get_auto_time_quote(&self_ident, Some(&create_time), &update_time);
let insert_fn = quote!(
pub async fn insert(&self) -> Result<Self, sqlx::Error> {
let sql = #insert_sql;
let last_id = #query(sql)
#(#insert_bind_quote_vec)*
.execute(#pool).await?.last_insert_id();
Self::find(last_id as i32).await
}
pub async fn insert_auto_time(&mut self) -> Result<Self, sqlx::Error> {
#insert_auto_time_quote
self.insert().await
}
);
let update_sql = format!(
"UPDATE {} SET {} WHERE {} = ?",
struct_var_name,
table_filed_name_vec
.iter()
.map(|filed_str| format!("{} = ?", filed_str))
.collect::<Vec<_>>()
.join(", "),
id
);
let update_bind_quote_vec = fileds_to_bind_quote(&self_ident, &filed_vec);
let update_auto_time_quote = get_auto_time_quote(&self_ident, None, &update_time);
let update_fn = quote!(
pub async fn update(&self) -> Result<bool, sqlx::Error> {
let sql = #update_sql;
#query(sql)
#(#update_bind_quote_vec)*
.bind(self.#id)
.execute(#pool).await.map(|f|f.rows_affected() > 0)
}
pub async fn update_auto_time(&mut self) -> Result<bool, sqlx::Error> {
#update_auto_time_quote
self.update().await
}
);
let save_or_update_fn = quote!(
pub async fn save_or_update(&self) -> Result<bool, sqlx::Error> {
match self.#id > 0 {
true => self.update().await,
false => self.insert().await.map(|_| true),
}
}
pub async fn save_or_update_auto_time(&mut self) -> Result<bool, sqlx::Error> {
match self.#id > 0 {
true => self.update_auto_time().await,
false => self.insert_auto_time().await.map(|_| true),
}
}
);
let mut new_auto_filed_vec = vec![];
for field in &filed_vec {
if !field_attr_exists(field, DEFAULT_CREATE_TIME_NAME)
&& !field_attr_exists(field, DEFAULT_UPDATE_TIME_NAME)
{
if let Some(ident) = &field.ident {
if ident == DEFAULT_CREATE_TIME_NAME || ident == DEFAULT_UPDATE_TIME_NAME {
continue;
}
}
new_auto_filed_vec.push(field);
}
}
let new_filed_vec = new_auto_filed_vec
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
let ty = &f.ty;
quote!(#ident: #ty)
})
.collect::<Vec<_>>();
let new_self_filed_vec = new_auto_filed_vec
.iter()
.map(|f| {
let ident = f.ident.as_ref().unwrap();
quote!(#ident)
})
.collect::<Vec<_>>();
let new_fn = quote!(
pub fn new(#(#new_filed_vec),*,#create_time: chrono::NaiveDateTime, #update_time: chrono::NaiveDateTime) -> Self {
Self{
#id: 0,
#(#new_self_filed_vec),*,
#create_time,
#update_time
}
}
pub fn new_common(#(#new_filed_vec),*) -> Self {
Self::new(
#(#new_self_filed_vec),*,
chrono::Local::now().naive_local(),
chrono::Local::now().naive_local()
)
}
);
let base_page_select_sql = format!("{} WHERE {{}} LIMIT {{}}, {{}}", select_base_sql);
let base_page_fn = quote!(
pub async fn base_page(
page_index: i32,
page_size: i32,
where_sql: &str,
args: sqlx::mysql::MySqlArguments,
) -> Result<(Vec<Self>, i32, i32, i32), sqlx::Error> {
let mut index = page_index - 1;
if index < 0 {
index = 0;
}
let rows = page_size;
let (count,) = Self::base_count(where_sql, args.clone()).await?;
let arr = match count > 0 {
true => {
let sql = format!(
#base_page_select_sql,
where_sql,
index * rows,
rows
);
sqlx::query_as_with::<_, Self, sqlx::mysql::MySqlArguments>(&sql, args)
.fetch_all(#pool)
.await?
}
false => Vec::new(),
};
let total_page = (count as f32 / page_size as f32).ceil();
Ok((arr, count, index + 1, total_page as i32))
}
);
let base_count_sql = format!("{} WHERE {{}}", count_base_sql);
let base_count_fn = quote!(
pub async fn base_count(
where_sql: &str,
args: sqlx::mysql::MySqlArguments,
) -> Result<(i32,), sqlx::Error> {
let count_sql = format!(#base_count_sql, where_sql);
sqlx::query_as_with::<_, (i32,), sqlx::mysql::MySqlArguments>(
&count_sql,
args,
)
.fetch_one(#pool)
.await
}
);
let gen = quote!(
impl #struct_name {
#find_fn
#list_fn
#delete_fn
#insert_fn
#update_fn
#save_or_update_fn
#new_fn
#base_page_fn
#base_count_fn
}
);
gen.into()
}
/// Builds one `.bind(...)` call per field, preserving the order of `fields`.
///
/// `struct_ident` is the receiver the generated code reads each field from (e.g. `self`).
/// Takes a slice rather than `&Vec` so any contiguous collection of fields can be passed;
/// existing `&Vec<&Field>` arguments coerce automatically.
fn fileds_to_bind_quote(struct_ident: &Ident, fields: &[&Field]) -> Vec<TokenStream2> {
    fields
        .iter()
        .map(|field| filed_to_bind_quote(struct_ident, field))
        .collect()
}
/// Builds the `.bind(...)` call that passes one field of `struct_ident` to a sqlx query.
///
/// * `Option<T>` fields are bound as `Option<&T>` via `.as_ref()`.
/// * Non-generic types other than `String` / `Decimal` are bound by value. This requires
///   `Copy`, because the generated methods only borrow `self`. The check uses the last path
///   segment, so qualified paths such as `chrono::NaiveDateTime` are handled too.
/// * `String`, `Decimal` and generic types such as `Vec<u8>` are bound by reference.
///
/// Returns an empty token stream (no `.bind` at all) when the field is unnamed or its type is
/// not a type path.
fn filed_to_bind_quote(struct_ident: &Ident, filed: &Field) -> TokenStream2 {
    let syn::Type::Path(type_path) = &filed.ty else {
        return quote!();
    };
    let Some(filed_var_name) = &filed.ident else {
        return quote!();
    };
    let filed_qoute = if extract_type_from_option(&filed.ty).is_some() {
        quote!(#struct_ident.#filed_var_name.as_ref())
    } else {
        // `path.get_ident()` is `None` for qualified (`a::B`) or generic (`Vec<u8>`) paths. The
        // original code then produced an argument-less `.bind()`, which does not compile.
        // Inspecting the last segment covers those cases.
        match type_path.path.segments.last() {
            Some(segment)
                if matches!(segment.arguments, syn::PathArguments::None)
                    && segment.ident != "String"
                    && segment.ident != "Decimal" =>
            {
                quote!(#struct_ident.#filed_var_name)
            }
            _ => quote!(&#struct_ident.#filed_var_name),
        }
    };
    quote!(
        .bind(#filed_qoute)
    )
}
/// Reports whether a field's visibility qualifies it as a column candidate: plain `pub`, or the
/// bare `crate` visibility keyword.
///
/// NOTE(review): in syn 1.x `pub(crate)` parses as `Visibility::Restricted`, not
/// `Visibility::Crate`, so `pub(crate)` fields are rejected here — confirm this is intended.
fn is_vis_public_crate(vis: &Visibility) -> bool {
    matches!(vis, Visibility::Public(_) | Visibility::Crate(_))
}
/// Reports whether `field` is a named field whose type is a `syn::Type::Path`
/// (e.g. `i32`, `String`, `Option<T>`), as opposed to a reference, tuple, array, etc.
fn is_base_type(field: &Field) -> bool {
    let is_path_type = matches!(field.ty, syn::Type::Path(_));
    is_path_type && field.ident.is_some()
}
fn field_attr_exists(field: &Field, attr_name: &str) -> bool {
get_field_attr(field, attr_name).is_some()
}
/// Finds the first attribute on `field` whose path is exactly the identifier `attr_name`.
///
/// Returns the field together with that attribute, or `None` when no such attribute exists.
fn get_field_attr<'a>(field: &'a Field, attr_name: &str) -> Option<(&'a Field, &'a Attribute)> {
    field
        .attrs
        .iter()
        .find(|attr| attr.path.is_ident(attr_name))
        .map(|attr| (field, attr))
}
/// Resolves the identifier of the field that plays the role named `ident_name`.
///
/// The first field carrying an attribute named `ident_name` wins. If no field carries it, or
/// that field is unnamed, the result is a fresh identifier spelled `ident_name`.
fn get_ident(fields: &Fields, ident_name: &str) -> Ident {
    fields
        .iter()
        .find(|field| field_attr_exists(field, ident_name))
        .and_then(|field| field.ident.clone())
        .unwrap_or_else(|| format_ident!("{}", ident_name))
}
/// Returns the SQL column name for `field`.
///
/// Uses the string literal from a `#[field_name("...")]` attribute when present and parseable;
/// a malformed attribute is ignored and the field identifier is used instead.
///
/// # Panics
/// Panics if the fallback is needed and the field is unnamed (tuple-struct field).
fn get_table_field_name(field: &Field) -> String {
    let renamed = get_field_attr(field, "field_name")
        .and_then(|(_, attr)| attr.parse_args::<syn::LitStr>().ok())
        .map(|lit| lit.value());
    match renamed {
        Some(column) => column,
        None => field.ident.as_ref().unwrap().to_string(),
    }
}
fn get_auto_time_quote(
struct_ident: &Ident,
create_time: Option<&Ident>,
update_time: &Ident,
) -> TokenStream2 {
let create_time_quote = if let Some(ct) = create_time {
get_update_time_quote(struct_ident, ct)
} else {
quote!()
};
let update_time_quote = get_update_time_quote(struct_ident, update_time);
quote!(
#create_time_quote
#update_time_quote
)
}
/// Emits `if <recv>.<field> == Default::default() { <recv>.<field> = <now>; }`, where `<now>` is
/// `chrono::Local::now().naive_local()`. An already-set timestamp is left untouched.
fn get_update_time_quote(struct_ident: &Ident, time_ident: &Ident) -> TokenStream2 {
    let target = quote!(#struct_ident.#time_ident);
    quote!(
        if #target == Default::default() {
            #target = chrono::Local::now().naive_local();
        }
    )
}
/// If `ty` is `Option<T>`, returns the inner `T`; otherwise returns `None`.
///
/// `Option` is recognised when spelled `Option`, `std::option::Option` or
/// `core::option::Option` (a leading `::` is allowed). Only segment names are compared, so
/// type aliases or re-exports of `Option` under another name are not detected.
fn extract_type_from_option(ty: &syn::Type) -> Option<&syn::Type> {
    use syn::{GenericArgument, Path, PathArguments, PathSegment};

    /// Accepted spellings of the `Option` path, one entry per path segment.
    const OPTION_PATHS: [&[&str]; 3] = [
        &["Option"],
        &["std", "option", "Option"],
        &["core", "option", "Option"],
    ];

    /// Returns the path of a plain (non-qualified-self) type path.
    fn extract_type_path(ty: &syn::Type) -> Option<&Path> {
        match ty {
            syn::Type::Path(typepath) if typepath.qself.is_none() => Some(&typepath.path),
            _ => None,
        }
    }

    /// Returns the final `Option` segment (which carries the generic argument) when `path`
    /// names `Option`.
    fn extract_option_segment(path: &Path) -> Option<&PathSegment> {
        // Compare segment identifiers in place, instead of joining them into a fresh `String`
        // and scanning a heap-allocated list on every call.
        let is_option = OPTION_PATHS.iter().any(|candidate| {
            path.segments.len() == candidate.len()
                && path
                    .segments
                    .iter()
                    .zip(candidate.iter())
                    .all(|(segment, name)| segment.ident == *name)
        });
        if is_option {
            path.segments.last()
        } else {
            None
        }
    }

    extract_type_path(ty)
        .and_then(extract_option_segment)
        .and_then(|segment| match &segment.arguments {
            PathArguments::AngleBracketed(params) => params.args.first(),
            _ => None,
        })
        .and_then(|generic_arg| match generic_arg {
            GenericArgument::Type(inner) => Some(inner),
            _ => None,
        })
}