use quote::quote;
use syn::__private::TokenStream;
use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input};
/// Derive macro `Csv`: adds async CSV persistence helpers (`save_csv`,
/// `load_csv`, `save_field`) to the annotated struct.
///
/// The `migrations` attribute is accepted but unused by this derive.
#[proc_macro_derive(Csv, attributes(migrations))]
pub fn csv_macro(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    let type_name = ast.ident;
    let generated = quote! {
        impl #type_name {
            /// Serializes `records` to a CSV file at `file_path`.
            pub async fn save_csv(
                records: Vec<Self>,
                file_path: &str,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
            where
                Self: serde::Serialize,
            {
                let mut wtr = csv::Writer::from_path(file_path)?;
                for row in records {
                    wtr.serialize(row)?;
                }
                wtr.flush()?;
                Ok(())
            }

            /// Reads the CSV file at `file_path` into a `Vec<Self>`.
            pub async fn load_csv(
                file_path: &str,
            ) -> Result<Vec<Self>, Box<dyn std::error::Error + Send + Sync>>
            where
                Self: for<'de> serde::Deserialize<'de>,
            {
                let mut rdr = csv::Reader::from_path(file_path)?;
                let mut rows = Vec::new();
                for parsed in rdr.deserialize() {
                    let row: Self = parsed?;
                    rows.push(row);
                }
                Ok(rows)
            }

            /// Writes one value per record — extracted by `field_extractor` —
            /// as CSV to `file_path`.
            pub async fn save_field<F, T>(
                records: Vec<Self>,
                file_path: &str,
                field_extractor: F,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
            where
                F: Fn(&Self) -> T,
                T: serde::Serialize,
            {
                let mut wtr = csv::Writer::from_path(file_path)?;
                for row in records {
                    wtr.serialize(field_extractor(&row))?;
                }
                wtr.flush()?;
                Ok(())
            }
        }
    };
    TokenStream::from(generated)
}
/// Derive macro `Persistency`: generates `set_persistent` / `set_unlogged`
/// helpers that toggle the target table between `LOGGED` and `UNLOGGED`.
///
/// The target table is named by `#[migrations(schema = "...", table = "...")]`;
/// both keys are mandatory and must be string literals.
#[proc_macro_derive(Persistency, attributes(migrations))]
pub fn persistency_macro(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = input.ident;
    let mut schema = String::new();
    let mut table = String::new();
    for attr in input.attrs {
        if !attr.path().is_ident("migrations") {
            continue;
        }
        if let Meta::List(list) = attr.meta {
            // Propagate malformed attribute syntax as a compile error instead of
            // silently swallowing it (the old `unwrap_or_else(|_| Punctuated::new())`
            // hid the real error and then produced the misleading
            // "missing schema/table" diagnostic below).
            let nested = match list.parse_args_with(
                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
            ) {
                Ok(nested) => nested,
                Err(err) => return TokenStream::from(err.to_compile_error()),
            };
            for meta in nested {
                if let Meta::NameValue(nv) = meta {
                    // Only `key = "string literal"` entries are recognized.
                    if let syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(lit_str),
                        ..
                    }) = nv.value
                    {
                        if nv.path.is_ident("schema") {
                            schema = lit_str.value();
                        } else if nv.path.is_ident("table") {
                            table = lit_str.value();
                        }
                    }
                }
            }
        }
    }
    if schema.is_empty() || table.is_empty() {
        return TokenStream::from(quote! {
            compile_error!("Both `schema` and `table` must be provided as attributes for #[migrations(schema = \"...\", table = \"...\")].");
        });
    }
    // The SQL is fully determined at macro-expansion time, so build it once
    // here instead of re-formatting it on every call.
    // NOTE(review): the schema identifier is left unquoted (Postgres will
    // case-fold it) while the table name is quoted — confirm this asymmetry
    // is intended.
    let set_logged_sql = format!(r#"ALTER TABLE {}."{}" SET LOGGED;"#, schema, table);
    let set_unlogged_sql = format!(r#"ALTER TABLE {}."{}" SET UNLOGGED;"#, schema, table);
    let expanded = quote! {
        impl #struct_name {
            /// Marks the table `LOGGED` (WAL-protected, crash-safe).
            pub async fn set_persistent(
                pool: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                let client = pool.get().await?;
                client.execute(#set_logged_sql, &[]).await?;
                Ok(())
            }
            /// Marks the table `UNLOGGED` (faster writes, data lost on crash).
            pub async fn set_unlogged(
                pool: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                let client = pool.get().await?;
                client.execute(#set_unlogged_sql, &[]).await?;
                Ok(())
            }
        }
    };
    TokenStream::from(expanded)
}
/// Derive macro `Vacuum`: generates a `vacuum` helper that runs
/// `VACUUM FULL` on the table named by
/// `#[migrations(schema = "...", table = "...")]` (both keys mandatory,
/// string literals).
#[proc_macro_derive(Vacuum, attributes(migrations))]
pub fn vacuum_macro(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = input.ident;
    let mut schema = String::new();
    let mut table = String::new();
    for attr in input.attrs {
        if !attr.path().is_ident("migrations") {
            continue;
        }
        if let Meta::List(list) = attr.meta {
            // Propagate malformed attribute syntax as a compile error instead of
            // silently swallowing it (the old `unwrap_or_else(|_| Punctuated::new())`
            // hid the real error and then produced the misleading
            // "missing schema/table" diagnostic below).
            let nested = match list.parse_args_with(
                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
            ) {
                Ok(nested) => nested,
                Err(err) => return TokenStream::from(err.to_compile_error()),
            };
            for meta in nested {
                if let Meta::NameValue(nv) = meta {
                    // Only `key = "string literal"` entries are recognized.
                    if let syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(lit_str),
                        ..
                    }) = nv.value
                    {
                        if nv.path.is_ident("schema") {
                            schema = lit_str.value();
                        } else if nv.path.is_ident("table") {
                            table = lit_str.value();
                        }
                    }
                }
            }
        }
    }
    if schema.is_empty() || table.is_empty() {
        return TokenStream::from(quote! {
            compile_error!("Both `schema` and `table` must be provided as attributes for #[migrations(schema = \"...\", table = \"...\")].");
        });
    }
    // SQL is fully determined at macro-expansion time; build it once here.
    let vacuum_sql = format!(r#"VACUUM FULL {}."{}";"#, schema, table);
    let expanded = quote! {
        impl #struct_name {
            /// Runs `VACUUM FULL` on the table. Note that `VACUUM FULL`
            /// rewrites the table and takes an exclusive lock for its
            /// duration.
            pub async fn vacuum(
                pool: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                let client = pool.get().await?;
                client.execute(#vacuum_sql, &[]).await?;
                Ok(())
            }
        }
    };
    TokenStream::from(expanded)
}
/// Derive macro `Reindex`: generates a `reindex` helper that runs
/// `REINDEX TABLE` on the table named by
/// `#[migrations(schema = "...", table = "...")]` (both keys mandatory,
/// string literals).
#[proc_macro_derive(Reindex, attributes(migrations))]
pub fn reindex_macro(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = input.ident;
    let mut schema = String::new();
    let mut table = String::new();
    for attr in input.attrs {
        if !attr.path().is_ident("migrations") {
            continue;
        }
        if let Meta::List(list) = attr.meta {
            // Propagate malformed attribute syntax as a compile error instead of
            // silently swallowing it (the old `unwrap_or_else(|_| Punctuated::new())`
            // hid the real error and then produced the misleading
            // "missing schema/table" diagnostic below).
            let nested = match list.parse_args_with(
                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
            ) {
                Ok(nested) => nested,
                Err(err) => return TokenStream::from(err.to_compile_error()),
            };
            for meta in nested {
                if let Meta::NameValue(nv) = meta {
                    // Only `key = "string literal"` entries are recognized.
                    if let syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(lit_str),
                        ..
                    }) = nv.value
                    {
                        if nv.path.is_ident("schema") {
                            schema = lit_str.value();
                        } else if nv.path.is_ident("table") {
                            table = lit_str.value();
                        }
                    }
                }
            }
        }
    }
    if schema.is_empty() || table.is_empty() {
        return TokenStream::from(quote! {
            compile_error!("Both `schema` and `table` must be provided as attributes for #[migrations(schema = \"...\", table = \"...\")].");
        });
    }
    // SQL is fully determined at macro-expansion time; build it once here.
    let reindex_sql = format!(r#"REINDEX TABLE {}."{}";"#, schema, table);
    let expanded = quote! {
        impl #struct_name {
            /// Rebuilds all indexes of the table via `REINDEX TABLE`.
            pub async fn reindex(
                pool: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
            ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                let client = pool.get().await?;
                client.execute(#reindex_sql, &[]).await?;
                Ok(())
            }
        }
    };
    TokenStream::from(expanded)
}
#[proc_macro_derive(Database, attributes(migrations))]
pub fn database_macro(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let struct_name = input.ident;
let mut schema = String::new();
let mut table = String::new();
for attr in input.attrs.iter() {
if attr.path().is_ident("migrations") {
match attr.meta {
Meta::List(ref list) => {
let nested = list
.parse_args_with(
syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
)
.unwrap_or_else(|_| syn::punctuated::Punctuated::new());
for meta in nested {
match meta {
Meta::NameValue(ref nv) if nv.path.is_ident("schema") => {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(ref lit_str),
..
}) = nv.value
{
schema = lit_str.value();
}
}
Meta::NameValue(ref nv) if nv.path.is_ident("table") => {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(ref lit_str),
..
}) = nv.value
{
table = lit_str.value();
}
}
_ => {}
}
}
}
_ => {}
}
}
}
if schema.is_empty() || table.is_empty() {
return TokenStream::from(quote! {
compile_error!("Both `schema` and `table` must be provided as attributes for #[migrations(schema = \"...\", table = \"...\")].");
});
}
let mut field_conversions = Vec::new();
let mut field_names = Vec::new();
if let Data::Struct(data_struct) = &input.data {
if let Fields::Named(fields_named) = &data_struct.fields {
for field in fields_named.named.iter() {
if let Some(field_name) = &field.ident {
let mut db_field_name = field_name.to_string();
for attr in &field.attrs {
if attr.path().is_ident("serde") {
if let Meta::List(list) = &attr.meta {
let nested = list.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)
.unwrap_or_else(|_| syn::punctuated::Punctuated::new());
for meta in nested {
if let Meta::NameValue(nv) = meta {
if nv.path.is_ident("rename") {
if let syn::Expr::Lit(syn::ExprLit {
lit: syn::Lit::Str(lit_str),
..
}) = nv.value
{
db_field_name = lit_str.value();
}
}
}
}
}
}
}
let field_name_ident = field_name;
let db_field_name_literal = db_field_name.clone();
field_conversions.push(quote! {
#field_name_ident: row.get(#db_field_name_literal)
});
field_names.push(db_field_name);
}
}
}
}
let field_names_str = field_names
.iter()
.map(|name| format!("\"{}\"", name))
.collect::<Vec<_>>()
.join(", ");
let expanded = quote! {
impl #struct_name {
pub fn convert_from_row(row: bb8_postgres::tokio_postgres::Row) -> Self {
Self {
#(#field_conversions),*
}
}
pub async fn get(
connection: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
) -> Result<Vec<Self>, Box<dyn std::error::Error + Send + Sync>> {
let client = connection.get().await?;
let query = format!(
r#"SELECT {} FROM {}."{}"#,
#field_names_str, #schema, #table
);
let rows = client.query(&query, &[]).await?;
let results: Vec<Self> = rows.into_iter().map(Self::convert_from_row).collect();
Ok(results)
}
pub async fn get_filter_i64<F>(
connection: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
field_extractor: F,
filter_values: std::collections::HashSet<i64>,
) -> Result<Vec<Self>, Box<dyn std::error::Error + Send + Sync>>
where
F: Fn(&Self) -> i64,
{
let client = connection.get().await?;
let query = format!(
r#"SELECT {} FROM {}."{}"#,
#field_names_str, #schema, #table
);
let rows = client.query(&query, &[]).await?;
let results: Vec<Self> = rows
.into_iter()
.map(Self::convert_from_row)
.filter(|record| filter_values.contains(&field_extractor(record)))
.collect();
Ok(results)
}
pub async fn get_filter_i64_i64<F1, F2>(
connection: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
field_extractor1: F1,
filter_values1: std::collections::HashSet<i64>,
field_extractor2: F2,
filter_values2: std::collections::HashSet<i64>,
) -> Result<Vec<Self>, Box<dyn std::error::Error + Send + Sync>>
where
F1: Fn(&Self) -> Option<i64>,
F2: Fn(&Self) -> Option<i64>,
{
let client = connection.get().await?;
let query = format!(
r#"SELECT {} FROM {}."{}"#,
#field_names_str, #schema, #table
);
let rows = client.query(&query, &[]).await?;
let results: Vec<Self> = rows
.into_iter()
.map(Self::convert_from_row)
.filter(|record| {
let value1 = field_extractor1(record).unwrap_or(-1);
let value2 = field_extractor2(record).unwrap_or(-1);
filter_values1.contains(&value1) && filter_values2.contains(&value2)
})
.collect();
Ok(results)
}
pub async fn get_filter_uuid<F>(
connection: &bb8_postgres::bb8::Pool<bb8_postgres::PostgresConnectionManager<bb8_postgres::tokio_postgres::NoTls>>,
field_extractor: F,
filter_values: std::collections::HashSet<uuid>,
) -> Result<Vec<Self>, Box<dyn std::error::Error + Send + Sync>>
where
F: Fn(&Self) -> uuid,
{
let client = connection.get().await?;
let query = format!(
r#"SELECT {} FROM {}."{}"#,
#field_names_str, #schema, #table
);
let rows = client.query(&query, &[]).await?;
let results: Vec<Self> = rows
.into_iter()
.map(Self::convert_from_row)
.filter(|record| filter_values.contains(&field_extractor(record)))
.collect();
Ok(results)
}
}
};
TokenStream::from(expanded)
}