use std::collections::BTreeSet;
use ::{
anyhow::{self, Context},
proc_macro2::TokenStream,
quote::{ToTokens, quote},
syn::{self, parse::Parse, punctuated::Punctuated},
};
use easy_macros::{
TokensBuilder, always_context, context, get_attributes, has_attributes, parse_macro_input,
};
use easy_sql_compilation_data::CompilationData;
use quote::quote_spanned;
use crate::{
CUSTOM_SELECT_ALIAS_PREFIX,
derive_components::{
OUTPUT_FIELD_KEYS, OUTPUT_STRUCT_KEYS, supported_drivers, validate_sql_attribute_keys,
},
macros_components::{CollectedData, ProvidedDrivers, expr::Expr, joined_field::JoinedField},
sql_crate,
};
/// Generates the `Output<#table, D>` implementation (plus the marker impl and,
/// when needed, the `__easy_sql_select*` helpers) for a query-result struct.
///
/// * `item_name` – the struct being derived.
/// * `fields` – the struct's non-joined fields (the `output` entry point moves
///   fields carrying `#[sql(field = ...)]` into `joined_fields` before calling this).
/// * `joined_fields` – fields read from a column of another table; each one is
///   selected under a `___easy_sql_joined_field_{i}` alias and read back by that alias.
/// * `table` – the table type used in `Output<#table, D>` / `OutputData<#table>`.
/// * `drivers` – concrete driver paths for which extra compile-time checks are emitted.
///
/// Fields carrying `#[sql(select = <expr>)]` are selected through that expression
/// and aliased as `CUSTOM_SELECT_ALIAS_PREFIX` + field name. If any of those
/// expressions reference argument indices, the struct gets the `WithArgsSelect`
/// marker instead of `NormalSelect`.
///
/// # Errors
/// Fails if a field has more than one `#[sql(select = ...)]` attribute, if a select
/// attribute does not parse, or if the argument indices used by the select
/// expressions are not sequential starting from `arg0`.
#[always_context]
pub fn sql_output_base(
    item_name: &syn::Ident,
    fields: &Punctuated<syn::Field, syn::Token![,]>,
    joined_fields: Vec<JoinedField>,
    table: &TokenStream,
    drivers: &[syn::Path],
) -> anyhow::Result<TokenStream> {
    let sql_crate = sql_crate();
    let macro_support = quote! { #sql_crate::macro_support };

    // A field whose value comes from a `#[sql(select = ...)]` expression.
    struct FieldWithSelect {
        field: syn::Field,
        expr: Expr,
    }

    // Split the fields into plain columns and custom-select expressions.
    let mut regular_fields = Punctuated::<syn::Field, syn::Token![,]>::new();
    let mut fields_with_select = Vec::<FieldWithSelect>::new();
    for field in fields.clone() {
        let mut select_attr = None;
        for attr_tokens in get_attributes!(field, #[sql(select = __unknown__)]) {
            if select_attr.is_some() {
                anyhow::bail!(
                    "Only one #[sql(select = ...)] attribute is allowed per field: {}",
                    field.ident.as_ref().unwrap()
                );
            }
            select_attr = Some(attr_tokens);
        }
        if let Some(attr_tokens) = select_attr {
            let parsed_attr: SelectAttribute = syn::parse2(attr_tokens.clone())?;
            fields_with_select.push(FieldWithSelect {
                field: field.clone(),
                expr: parsed_attr.expr,
            });
        } else {
            regular_fields.push(field);
        }
    }

    // Argument indices referenced by any custom select expression.
    let mut indices = BTreeSet::new();
    for fws in fields_with_select.iter() {
        fws.expr.collect_indices_impl(&mut indices);
    }
    let has_custom_select = !fields_with_select.is_empty();
    let has_custom_select_args = !indices.is_empty();

    // Column aliases under which joined fields are selected and later read back.
    let joined_field_aliases = (0..joined_fields.len())
        .map(|i| format!("___easy_sql_joined_field_{}", i))
        .collect::<Vec<_>>();
    let joined_checks_field_names = joined_fields
        .iter()
        .map(|joined_field| joined_field.field.ident.as_ref().unwrap())
        .collect::<Vec<_>>();
    let context_strs2 = joined_fields
        .iter()
        .map(|joined_field| {
            format!(
                "Getting joined field `{}` with type {} for struct `{}` from table `{}`",
                joined_field.field.ident.as_ref().unwrap(),
                joined_field.field.ty.to_token_stream(),
                item_name,
                joined_field.table.to_token_stream()
            )
        })
        .collect::<Vec<_>>();

    // One closure per non-joined field: given a driver path it emits the
    // `field: <value read from the row>,` initializer used inside `Self { ... }`.
    let mut fields_quotes: Vec<Box<dyn Fn(&syn::Path) -> TokenStream>> = Vec::new();
    for field in regular_fields.iter() {
        let field_name = field.ident.as_ref().unwrap();
        let field_name_str = field_name.to_string();
        let context_str = format!(
            "Getting field `{}` with type {} for struct `{}`",
            field.ident.as_ref().unwrap(),
            field.ty.to_token_stream(),
            item_name
        );
        // Shadow with a reference so the `move` closures capture a borrow.
        let macro_support = &macro_support;
        if has_attributes!(field, #[sql(bytes)]) {
            let context_str2 = format!(
                "Getting field `{}` with type {} for struct `{}` (Converting from binary)",
                field.ident.as_ref().unwrap(),
                field.ty.to_token_stream(),
                item_name
            );
            fields_quotes.push(Box::new(move |driver|quote! {
                #field_name: #macro_support::from_binary_vec( <#macro_support::DriverRow<#driver> as #macro_support::SqlxRow>::try_get(&data, #field_name_str).with_context(
                    #macro_support::context!(#context_str),
                )?).with_context(
                    #macro_support::context!(#context_str2),
                )?,
            }));
        } else {
            fields_quotes.push(Box::new(move |driver|quote! {
                #field_name: <#macro_support::DriverRow<#driver> as #macro_support::SqlxRow>::try_get(&data, #field_name_str).with_context(
                    #macro_support::context!(#context_str),
                )?,
            }));
        }
    }
    // Custom-select fields are read back by their prefixed alias.
    for field_with_sel in &fields_with_select {
        let field = &field_with_sel.field;
        let field_name = field.ident.clone().unwrap();
        let field_name_str = field_name.to_string();
        let aliased_name = format!("{}{}", CUSTOM_SELECT_ALIAS_PREFIX, field_name_str);
        let context_str = format!(
            "Getting field `{}` with type {} for struct `{}` (with custom select)",
            field_name,
            field.ty.to_token_stream(),
            item_name
        );
        // Shadow with a reference so the `move` closures capture a borrow.
        let macro_support = &macro_support;
        if has_attributes!(field, #[sql(bytes)]) {
            let context_str2 = format!(
                "Getting field `{}` with type {} for struct `{}` (Converting from binary)",
                field_name,
                field.ty.to_token_stream(),
                item_name
            );
            fields_quotes.push(Box::new(move |driver|quote! {
                #field_name: #macro_support::from_binary_vec( <#macro_support::DriverRow<#driver> as #macro_support::SqlxRow>::try_get(&data, #aliased_name).with_context(
                    #macro_support::context!(#context_str),
                )?).with_context(
                    #macro_support::context!(#context_str2),
                )?,
            }));
        } else {
            fields_quotes.push(Box::new(move |driver|quote! {
                #field_name: <#macro_support::DriverRow<#driver> as #macro_support::SqlxRow>::try_get(&data, #aliased_name).with_context(
                    #macro_support::context!(#context_str),
                )?,
            }));
        }
    }

    // Column list for the non-joined fields. `{delimeter}` is left as a named
    // placeholder that the generated `format!` captures at run time.
    let select_str = regular_fields
        .iter()
        .map(|field| {
            let field_name = field.ident.as_ref().unwrap();
            format!("{{delimeter}}{}{{delimeter}}", field_name)
        })
        .collect::<Vec<_>>()
        .join(", ");
    let select_str_call = if !select_str.is_empty() {
        quote! {
            current_query.push_str(&format!(
                #select_str
            ));
        }
    } else {
        quote! {}
    };
    let select_joined = joined_fields
        .iter()
        .zip(joined_field_aliases.iter())
        .enumerate()
        .map(|(i, (joined_field, alias))| {
            let ref_table = &joined_field.table;
            let table_field = &joined_field.table_field;
            // No leading separator only for the very first selected column.
            let comma = if i == 0 && select_str.is_empty() {
                ""
            } else {
                ", "
            };
            let format_str = format!(
                "{comma}{{delimeter}}{{}}{{delimeter}}.{{delimeter}}{}{{delimeter}} AS {}",
                table_field, alias
            );
            quote! {
                current_query.push_str(&format!(
                    #format_str,
                    <#ref_table as #sql_crate::Table<D>>::table_name(),
                ));
            }
        });
    let select_body = if has_custom_select && !has_custom_select_args {
        quote! {
            current_query.push_str(&Self::__easy_sql_select::<D>(delimeter));
        }
    } else {
        quote! {
            #select_str_call
            #(#select_joined)*
        }
    };
    let trait_impl = if !has_custom_select_args {
        quote! {
            impl #sql_crate::markers::NormalSelect for #item_name {}
        }
    } else {
        quote! {
            impl #sql_crate::markers::WithArgsSelect for #item_name {}
        }
    };

    let custom_select_impl = if has_custom_select {
        let max_idx = indices.iter().max().copied().unwrap_or_default();
        let mut types_driver_support_needed = Vec::new();
        if has_custom_select_args {
            // The indices used must be exactly 0..=max_idx, with no gaps.
            for i in 0..=max_idx {
                if !indices.contains(&i) {
                    anyhow::bail!(
                        "Missing argument in #[sql(select = ...)] expressions: arg{} is required but not used. \
                        Used arguments must be sequential starting from arg0.",
                        i
                    );
                }
            }
        }
        let (arg_params, arg_params_in_call) = if indices.is_empty() {
            (vec![], vec![])
        } else {
            (
                (0..=max_idx)
                    .map(|i| {
                        let arg_name = quote::format_ident!("arg{}", i);
                        quote! { #arg_name: &str }
                    })
                    .collect::<Vec<_>>(),
                (0..=max_idx)
                    .map(|i| {
                        let arg_name = quote::format_ident!("arg{}", i);
                        quote! { #arg_name }
                    })
                    .collect::<Vec<_>>(),
            )
        };
        let select_generation_code = {
            let mut field_generation = Vec::new();
            for field in regular_fields.iter() {
                let field_name = field.ident.as_ref().unwrap();
                let field_str = field_name.to_string();
                field_generation.push(quote! {
                    parts.push(format!("{delimeter}{}{delimeter}", #field_str));
                });
            }
            for (joined_field, alias) in joined_fields.iter().zip(joined_field_aliases.iter()) {
                let ref_table_ts = &joined_field.table;
                let field_name = &joined_field.table_field;
                let parts_format_str = format!(
                    "{{delimeter}}{{}}{{delimeter}}.{{delimeter}}{}{{delimeter}} AS {{delimeter}}{}{{delimeter}}",
                    field_name, alias
                );
                field_generation.push(quote_spanned! {field_name.span()=>
                    let _ = || {
                        let ___t___ = #macro_support::never_any::<#ref_table_ts>();
                        let _ = ___t___.#field_name;
                    };
                    parts.push(format!(#parts_format_str,
                        <#ref_table_ts as #sql_crate::Table<D>>::table_name(),
                    ));
                });
            }
            // Loop-invariant inputs for expression translation, built once.
            let drivers_for_checks = drivers
                .iter()
                .map(|e| e.to_token_stream())
                .collect::<Vec<_>>();
            let output_type_ts = quote! { #item_name };
            let driver_for_select = ProvidedDrivers::SingleWithChecks {
                driver: quote! { D },
                checks: drivers_for_checks,
            };
            for field_with_sel in fields_with_select {
                let field_name = field_with_sel.field.ident.as_ref().unwrap();
                let field_str = field_name.to_string();
                let expr = field_with_sel.expr;
                let alias = format!("{}{}", CUSTOM_SELECT_ALIAS_PREFIX, field_str);
                let mut checks = Vec::new();
                let mut format_params = Vec::new();
                let mut format_str = String::new();
                let mut binds = Vec::new();
                let mut before_param_n = quote! {};
                let mut before_format = Vec::new();
                let mut current_param_n = 0usize;
                let mut data = CollectedData::new(
                    &mut format_str,
                    &mut format_params,
                    &mut binds,
                    &mut checks,
                    &sql_crate,
                    &driver_for_select,
                    &mut current_param_n,
                    &mut before_param_n,
                    &mut before_format,
                    Some(&output_type_ts),
                    Some(table),
                    &mut types_driver_support_needed,
                );
                let sql_template = expr.into_query_string(&mut data, false, true);
                let parts_format_str = format!("{{}} AS {{delimeter}}{}{{delimeter}}", alias);
                field_generation.push(quote! {
                    {
                        let _ = || {
                            let ___t___ = #macro_support::never_any::<#table>();
                            #(#checks)*
                        };
                        let formatted_expr = format!(#sql_template, #(#format_params),*);
                        parts.push(format!(#parts_format_str, formatted_expr));
                    }
                });
            }
            quote! {
                let mut parts = Vec::new();
                #(#field_generation)*
                parts.join(", ")
            }
        };
        // NOTE(review): these bounds are placed on `__easy_sql_select*` only, not on
        // the `Output` impl below, whose `select()` calls `__easy_sql_select` when
        // there are no select args — confirm this cannot leave them unsatisfied.
        let where_clause_types = types_driver_support_needed
            .into_iter()
            .map(|ty| {
                quote! {
                    #ty: #macro_support::Type<#macro_support::InternalDriver<D>>,
                }
            })
            .collect::<Vec<_>>();
        quote! {
            impl #item_name {
                pub fn __easy_sql_select<D: #sql_crate::Driver>(delimeter: &str, #(#arg_params),*) -> String
                where
                    Self: #sql_crate::Output<#table, D>,
                    #(#where_clause_types)*
                {
                    #select_generation_code
                }
                pub fn __easy_sql_select_driver_from_conn<D: #sql_crate::Driver>(
                    _conn: &impl #sql_crate::EasyExecutor<D>,
                    delimeter: &str,
                    #(#arg_params),*
                ) -> String
                where
                    Self: #sql_crate::Output<#table, D>,
                    #(#where_clause_types)*
                {
                    Self::__easy_sql_select::<D>(delimeter, #(#arg_params_in_call),*)
                }
            }
        }
    } else {
        quote! {}
    };

    // For every concrete driver, emit an unused closure that performs the same
    // conversion with that driver's row type.
    let driver_checks = drivers.iter().map(|driver| {
        let fields_quotes = fields_quotes.iter().map(|f| f(driver));
        quote! {
            let _ = |data: #macro_support::DriverRow<#driver>| {
                #macro_support::Result::<Self>::Ok(Self {
                    #(
                        #fields_quotes
                    )*
                    #(
                        #joined_checks_field_names: <#macro_support::DriverRow<#driver> as #macro_support::SqlxRow>::try_get(&data, #joined_field_aliases).with_context(
                            #macro_support::context!(#context_strs2),
                        )?,
                    )*
                })
            };
        }
    }).collect::<Vec<_>>();
    // Decode/Type bounds per field; `#[sql(bytes)]` fields are decoded as `Vec<u8>`.
    let where_clauses_types = joined_fields
        .iter()
        .map(|e| &e.field)
        .chain(fields.iter())
        .map(|field| {
            let bytes = has_attributes!(field, #[sql(bytes)]);
            if bytes {
                let bound_ty = quote! { Vec<u8> };
                quote! {
                    for<'__easy_sql_x> #bound_ty: #macro_support::Decode<'__easy_sql_x, #macro_support::InternalDriver<D>>,
                    #bound_ty: #macro_support::Type<#macro_support::InternalDriver<D>>,
                }
            } else {
                let field_ty = &field.ty;
                quote! {
                    for<'__easy_sql_x> #field_ty: #macro_support::Decode<'__easy_sql_x, #macro_support::InternalDriver<D>>,
                    #field_ty: #macro_support::Type<#macro_support::InternalDriver<D>>,
                }
            }
        });
    let fields_quotes = fields_quotes.into_iter().map(|f| f(&syn::parse_quote! {D}));
    Ok(quote! {
        impl<D: #sql_crate::Driver> #sql_crate::Output<#table, D> for #item_name
        where #macro_support::DriverRow<D>: #macro_support::ToConvert<D>,
        for<'__easy_sql_x> &'__easy_sql_x str: #macro_support::ColumnIndex<#macro_support::DriverRow<D>>,
        #(#where_clauses_types)*
        {
            type DataToConvert = #macro_support::DriverRow<D>;
            type UsedForChecks = Self;
            fn select(current_query: &mut String) {
                use #macro_support::Context;
                let delimeter = <D as #sql_crate::Driver>::identifier_delimiter();
                #select_body
            }
            fn convert(data: #macro_support::DriverRow<D>) -> #macro_support::Result<Self> {
                use #macro_support::{Context,context};
                #(#driver_checks)*
                Ok(Self {
                    #(
                        #fields_quotes
                    )*
                    #(
                        #joined_checks_field_names: <#macro_support::DriverRow<D> as #macro_support::SqlxRow>::try_get(&data, #joined_field_aliases).with_context(
                            #macro_support::context!(#context_strs2),
                        )?,
                    )*
                })
            }
        }
        impl #macro_support::OutputData<#table> for #item_name {
            type SelectProvider = Self;
        }
        #trait_impl
        #custom_select_impl
    })
}
/// Argument of a `#[sql(field = Table.column)]` field attribute.
struct FieldAttribute {
    /// Path to the other table (everything before the `.`).
    table: syn::Path,
    /// Column on that table (the identifier after the `.`).
    table_field: syn::Ident,
}
#[always_context]
impl Parse for FieldAttribute {
    /// Parses `<path> . <ident>`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let table = input.parse::<syn::Path>()?;
        let _dot = input.parse::<syn::Token![.]>()?;
        let table_field = input.parse::<syn::Ident>()?;
        Ok(Self { table, table_field })
    }
}
/// Argument of a `#[sql(select = <expr>)]` field attribute.
struct SelectAttribute {
    /// The SQL expression used to compute the field's value.
    expr: Expr,
}
#[always_context]
impl Parse for SelectAttribute {
    /// Parses the whole attribute argument as a single `Expr`.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        Ok(Self {
            expr: input.parse::<Expr>()?,
        })
    }
}
#[always_context]
pub fn output(item: proc_macro::TokenStream) -> anyhow::Result<proc_macro::TokenStream> {
let item = parse_macro_input!(item as syn::ItemStruct);
if let Some(error_tokens) =
validate_sql_attribute_keys(&item, "Output", OUTPUT_STRUCT_KEYS, OUTPUT_FIELD_KEYS)
{
return Ok(error_tokens.into());
}
let item_name = &item.ident;
let fields = match &item.fields {
syn::Fields::Named(fields_named) => fields_named.named.clone(),
syn::Fields::Unnamed(_) => {
anyhow::bail!("Unnamed struct fields is not supported")
}
syn::Fields::Unit => anyhow::bail!("Unit struct is not supported"),
};
let mut joined_fields = Vec::new();
let mut fields2: Punctuated<syn::Field, syn::token::Comma> = Punctuated::new();
for field in fields.into_iter() {
let mut attr = None;
for a in get_attributes!(field, #[sql(field = __unknown__)]) {
if attr.is_some() {
anyhow::bail!("Only one #[sql(field = ...)] attribute is allowed per field!");
}
attr = Some(a);
}
if let Some(attr) = attr {
let attr: FieldAttribute = syn::parse2(attr.clone())?;
joined_fields.push(JoinedField {
field,
table: attr.table,
table_field: attr.table_field,
});
} else {
fields2.push(field);
}
}
let fields = fields2;
let mut table = None;
for attr in get_attributes!(item, #[sql(table = __unknown__)]) {
if table.is_some() {
anyhow::bail!("Only one table attribute is allowed");
}
table = Some(attr);
}
#[no_context]
let table = table.with_context(context!("Table attribute is required"))?;
let compilation_data = CompilationData::load_in_macro()?;
let supported_drivers = supported_drivers(&item, &compilation_data, true)?;
let mut result = TokensBuilder::default();
result.add(sql_output_base(
item_name,
&fields,
joined_fields.clone(),
&table,
&supported_drivers,
)?);
Ok(result.finalize().into())
}