extern crate proc_macro;
extern crate sqlx_plus_builder;
extern crate sqlx_plus_core;
use proc_macro::TokenStream;
use quote::quote;
use sqlx_plus_core::generate_create_table_sql;
use syn::{parse_macro_input, DeriveInput, Lit, Meta, MetaNameValue};
#[proc_macro_derive(SqlxPlus, attributes(table_name))]
pub fn derive_sqlx_crud(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let table_name = input
.attrs
.iter()
.find_map(|attr| {
if attr.path.is_ident("table_name") {
if let Meta::NameValue(MetaNameValue {
lit: Lit::Str(lit_str),
..
}) = attr.parse_meta().ok()?
{
Some(lit_str.value())
} else {
None
}
} else {
None
}
})
.expect("Expected a table_name attribute");
let fields: Vec<_> = if let syn::Data::Struct(data) = &input.data {
data.fields.iter().collect()
} else {
panic!("SqlxPlus can only be derived for structs");
};
let columns: Vec<_> = fields
.iter()
.map(|f| f.ident.as_ref().unwrap().to_string())
.collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
let columns_str = columns.join(", ");
let fields: Vec<syn::Field> = fields.into_iter().cloned().collect();
let create_table_sql = generate_create_table_sql(&fields, &table_name);
let expanded = quote! {
use sqlx_plus_core::{DbPool, SqlxBase, DbTransaction};
use sqlx_plus_builder::SqlBuilder;
impl #name {
pub async fn create_table_if_not_exists(pool: &DbPool) -> Result<(), sqlx::Error> {
pool.create_table_if_not_exists(#table_name, #create_table_sql).await
}
pub async fn save(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, #columns_str, values_str);
pool.insert(&sql, transaction).await
}
pub async fn saver(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<Self, sqlx::Error> {
let values_str = format!("{}",vec![#(format!("'{}'", entity.#field_names)),*].join(", "));
let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", #table_name, #columns_str, values_str);
pool.insert_and_return_key(&sql, transaction).await
}
pub async fn get(pool: &DbPool, id: i64) -> Result<Option<Self>, sqlx::Error> {
let sql = format!("SELECT * FROM {} WHERE id = {}", #table_name, id);
pool.select_one(&sql).await
}
pub async fn update(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, entity: &Self) -> Result<u64, sqlx::Error> {
let sql = format!("UPDATE {} SET ... WHERE id = {}", #table_name, entity.id);
pool.update(&sql, transaction).await
}
pub async fn delete(pool: &DbPool, transaction: Option<&mut DbTransaction<'_>>, id: i64) -> Result<u64, sqlx::Error> {
let sql = format!("DELETE FROM {} WHERE id = {}", #table_name, id);
pool.delete(&sql, transaction).await
}
pub async fn list(pool: &DbPool) -> Result<Vec<Self>, sqlx::Error> {
let sql = format!("SELECT * FROM {}", #table_name);
pool.select_all(&sql).await
}
pub async fn page(pool: &DbPool, limit: i64, offset: i64) -> Result<Vec<Self>, sqlx::Error> {
let sql = format!("SELECT * FROM {} LIMIT {} OFFSET {}", #table_name, limit, offset);
pool.select_page(&sql).await
}
pub async fn count(pool: &DbPool) -> Result<i64, sqlx::Error> {
let sql = format!("SELECT COUNT(*) FROM {}", #table_name);
pool.count(&sql).await
}
}
};
TokenStream::from(expanded)
}