use std::collections::{BTreeMap, BTreeSet};
use darling::{FromDeriveInput, FromField, Result, ast, util::Flag};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{Expr, Generics, Ident, Type, parse_quote};
use crate::common::{
BindingMode, FieldIdentity, NamedIndexResolution, With, impl_generics_with_lifetime,
process_fields,
};
/// Parsed input of the `Columns` derive, read by darling from the deriving struct.
///
/// Container options come from `#[squire(...)]` attributes. The `supports(...)` list makes
/// darling reject enums and unit structs before this value is built; `fields()` also
/// re-checks both shapes defensively.
#[derive(FromDeriveInput, Debug)]
#[darling(
attributes(squire),
supports(struct_named, struct_newtype, struct_tuple)
)]
pub struct ColumnsDerive {
/// Name of the deriving struct; becomes the `Self` type of the generated impls.
ident: Ident,
/// Generics of the deriving struct, reused (plus `'row`) for the generated impls.
generics: Generics,
/// Struct body. The enum payload type is `()` because enums are unsupported.
data: ast::Data<(), FieldDerive>,
/// `#[squire(named)]`: passed with `sequential` and the struct style to
/// `BindingMode::from_flags_and_style`; presumably requests name-based column binding.
named: Flag,
/// `#[squire(sequential)]`: passed with `named` and the struct style to
/// `BindingMode::from_flags_and_style`; presumably requests index-based column binding.
sequential: Flag,
}
impl ColumnsDerive {
pub fn derive(self) -> Result<TokenStream> {
let (fields, style) = self.fields()?;
let binding_mode = BindingMode::from_flags_and_style(&self.named, &self.sequential, style)?;
let field_metas = process_fields(&fields, |i, field| field.build_meta(i, binding_mode))?;
let meta = Columns {
ident: self.ident,
generics: self.generics,
fields: field_metas,
binding_mode,
};
meta.generate_impl()
}
fn fields(&self) -> Result<(Vec<&FieldDerive>, ast::Style)> {
match &self.data {
ast::Data::Struct(contents) => match contents.style {
ast::Style::Struct | ast::Style::Tuple => {
let fields = contents
.fields
.iter()
.filter(|field| !field.skip.is_present())
.collect();
Ok((fields, contents.style))
}
ast::Style::Unit => Err(darling::Error::unsupported_shape("unit struct")),
},
ast::Data::Enum(_) => Err(darling::Error::unsupported_shape("enum")),
}
}
}
/// Per-field options parsed by darling from `#[squire(...)]` field attributes.
#[derive(FromField, Debug)]
#[darling(attributes(squire))]
struct FieldDerive {
/// Field name; `None` for tuple-struct fields.
ident: Option<Ident>,
/// Declared field type; used as the `squire::Fetch<'row>` implementor.
ty: Type,
/// `#[squire(borrow)]`: the type must be a reference. Its lifetime is recorded so the
/// `Columns` impl can require `'row` to outlive it.
borrow: Flag,
/// `#[squire(index = ...)]`: forwarded to `FieldIdentity::from_field`; presumably an
/// explicit column index — TODO confirm against `FieldIdentity`.
index: Option<i32>,
/// `#[squire(rename = ...)]`: forwarded to `FieldIdentity::from_field`; presumably
/// overrides the column name — TODO confirm against `FieldIdentity`.
rename: Option<Ident>,
/// `#[squire(skip)]`: the field is filtered out before code generation.
/// NOTE(review): skipped fields are also left out of the generated `Self { .. }` /
/// `Self(..)` constructor, so deriving on a struct with a skipped field looks like it
/// cannot compile — confirm the intended semantics (e.g. `Default::default()`).
skip: Flag,
/// `#[squire(result)]`: appends `?` to the (possibly `fetch_with`-wrapped) fetch expression.
result: Flag,
/// `#[squire(fetch_with = ...)]`: wraps the fetch expression via `With::wrap`.
fetch_with: Option<With>,
/// `#[squire(json)]`: fetch through `squire::Json<T>` and take `.0`; excludes `jsonb`.
json: Flag,
/// `#[squire(jsonb)]`: fetch through `squire::Jsonb<T>` and take `.0`; excludes `json`.
jsonb: Flag,
}
impl FieldDerive {
fn build_meta(&self, field_index: usize, binding_mode: BindingMode) -> Result<Column> {
let sequential = binding_mode == BindingMode::Sequential;
let identity = FieldIdentity::from_field(
&self.ident,
field_index,
self.rename.as_ref(),
self.index,
sequential,
);
let fetch_expr = self.build_fetch_expr(field_index)?;
let borrow_bound = self.borrow_bound();
Ok(Column {
ident: self.ident.clone(),
identity,
fetch_expr,
borrow_bound,
})
}
fn build_fetch_expr(&self, _field_index: usize) -> Result<Expr> {
if self.json.is_present() && self.jsonb.is_present() {
return Err(
darling::Error::custom("cannot use both json and jsonb attributes")
.with_span(&self.jsonb.span()),
);
}
let ty = &self.ty;
let column_var: Ident = parse_quote!(column);
let mut expr: Expr = if self.json.is_present() {
parse_quote!(<squire::Json<#ty> as squire::Fetch<'row>>::fetch_column(statement, #column_var)?.0)
} else if self.jsonb.is_present() {
parse_quote!(<squire::Jsonb<#ty> as squire::Fetch<'row>>::fetch_column(statement, #column_var)?.0)
} else {
parse_quote!(<#ty as squire::Fetch<'row>>::fetch_column(statement, #column_var)?)
};
if let Some(ref with) = self.fetch_with {
expr = with.wrap(&expr);
}
if self.result.is_present() {
expr = parse_quote!(#expr?);
}
if self.borrow.is_present() && !matches!(&self.ty, Type::Reference(_)) {
return Err(self.borrow_error());
}
Ok(expr)
}
fn borrow_bound(&self) -> Option<syn::Lifetime> {
if self.borrow.is_present() {
if let Type::Reference(syn::TypeReference {
lifetime: Some(ref lifetime),
..
}) = self.ty
{
Some(lifetime.clone())
} else {
None
}
} else {
None
}
}
fn borrow_error(&self) -> darling::Error {
darling::Error::custom("borrow can only be used with references")
.with_span(&self.borrow.span())
}
}
/// Validated model of a `Columns` derive, consumed by `generate_impl`.
struct Columns {
/// Name of the deriving struct.
ident: Ident,
/// Generics of the deriving struct.
generics: Generics,
/// One entry per non-skipped field, in declaration order.
fields: Vec<Column>,
/// Whether columns are resolved by name or by sequential index.
binding_mode: BindingMode,
}
impl Columns {
fn generate_impl(self) -> Result<TokenStream> {
let ident = &self.ident;
let (indexes_impl_generics, ty_generics, indexes_where_clause) =
self.generics.split_for_impl();
let columns_impl_generics = impl_generics_with_lifetime(&self.generics, "'row");
let lifetime_bounds: BTreeSet<_> = self
.fields
.iter()
.filter_map(|f| f.borrow_bound.clone())
.collect();
let mut columns_where_clause = indexes_where_clause.cloned();
if !lifetime_bounds.is_empty() {
let lifetime_predicates: Vec<syn::WherePredicate> = lifetime_bounds
.iter()
.map(|lt| parse_quote!('row: #lt))
.collect();
if columns_where_clause.is_none() {
columns_where_clause = Some(parse_quote!(where));
}
if let Some(ref mut where_clause) = columns_where_clause {
where_clause.predicates.extend(lifetime_predicates);
}
}
let column_names: BTreeMap<&str, usize> = self
.fields
.iter()
.enumerate()
.filter_map(|(i, field)| field.identity.name().map(|name| (name, i)))
.collect();
if self.binding_mode.is_named() && column_names.len() < self.fields.len() {
return Err(darling::Error::custom("not all fields have names"));
}
let NamedIndexResolution { indexes, resolve } =
if self.binding_mode.is_named() && !column_names.is_empty() {
NamedIndexResolution::derive(
&column_names,
quote!(columns),
quote!(squire::ColumnIndex),
)
} else {
NamedIndexResolution::empty()
};
let fetch_statements = self.generate_fetch_statements(&column_names);
Ok(quote! {
impl #indexes_impl_generics squire::ColumnIndexes for #ident #ty_generics
#indexes_where_clause
{
#indexes
fn resolve<'connection>(statement: &squire::Statement<'connection>) -> Option<Self::Indexes> {
#resolve
}
}
impl #columns_impl_generics squire::Columns<'row> for #ident #ty_generics
#columns_where_clause
{
fn fetch<'connection>(statement: &'row squire::Statement<'connection>, indexes: Self::Indexes) -> squire::Result<Self>
where
'connection: 'row,
{
#fetch_statements
}
}
})
}
fn generate_fetch_statements(&self, column_names: &BTreeMap<&str, usize>) -> TokenStream {
let field_bindings: Vec<_> = self
.fields
.iter()
.enumerate()
.map(|(i, field)| {
let var_name = field
.ident
.as_ref()
.map(|id| quote!(#id))
.unwrap_or_else(|| {
let var_ident = Ident::new(&format!("field_{}", i), Span::call_site());
quote!(#var_ident)
});
let column_expr = match &field.identity {
FieldIdentity::Named(name) => {
let offset = column_names.get(name.as_str()).unwrap();
quote! { indexes[#offset] }
}
FieldIdentity::Sequential(index) => {
quote! { squire::ColumnIndex::try_from(#index)? }
}
};
let fetch_expr = &field.fetch_expr;
quote! {
let #var_name = {
let column = #column_expr;
#fetch_expr
};
}
})
.collect();
let field_names: Vec<_> = self
.fields
.iter()
.enumerate()
.map(|(i, field)| {
field
.ident
.as_ref()
.map(|id| quote!(#id))
.unwrap_or_else(|| {
let var_ident = Ident::new(&format!("field_{}", i), Span::call_site());
quote!(#var_ident)
})
})
.collect();
if self.fields.iter().any(|f| f.ident.is_some()) {
quote! {
#(#field_bindings)*
Ok(Self { #(#field_names),* })
}
} else {
quote! {
#(#field_bindings)*
Ok(Self(#(#field_names),*))
}
}
}
}
/// Code-generation model for a single non-skipped field.
struct Column {
/// Field name; `None` for tuple-struct fields.
ident: Option<Ident>,
/// How the column is located: by name (resolved via `indexes`) or by sequential index.
identity: FieldIdentity<i32>,
/// Fetch expression; refers to `statement` and `column`, which the generated code binds.
fetch_expr: Expr,
/// Lifetime borrowed by a `#[squire(borrow)]` field, used for a `'row: 'lt` bound.
borrow_bound: Option<syn::Lifetime>,
}