use crate::parse_arguments::parse_arguments;
use crate::Argument;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::spanned::Spanned;
/// Returns the first segment of the attribute's path as a string
/// (e.g. `"column"` for `#[column(...)]`), or `None` when the path has no
/// segments.
fn extract_attribute_name(attr: &syn::Attribute) -> Option<String> {
    let leading = attr.path.segments.first()?;
    Some(leading.value().ident.to_string())
}
/// True when the attribute's leading path segment is `primary_key`.
///
/// Takes `&&Attribute` so it can be passed directly to `Iterator::find`
/// over `attrs.iter()`.
fn is_pound_primary_key(attr: &&syn::Attribute) -> bool {
    match extract_attribute_name(attr) {
        Some(ident) => ident == "primary_key",
        None => false,
    }
}
/// True when the attribute's leading path segment is `column`.
///
/// Takes `&&Attribute` so it can be passed directly to `Iterator::find`
/// over `attrs.iter()`.
fn is_pound_column(attr: &&syn::Attribute) -> bool {
    match extract_attribute_name(attr) {
        Some(ident) => ident == "column",
        None => false,
    }
}
/// True for either attribute this derive recognises on a field:
/// `#[column ...]` or `#[primary_key ...]`.
fn is_pound(attr: &&syn::Attribute) -> bool {
    if is_pound_column(attr) {
        return true;
    }
    is_pound_primary_key(attr)
}
fn extract_named_fields(
struct_def: &syn::DeriveInput,
) -> Result<&syn::punctuated::Punctuated<syn::Field, syn::token::Comma>, syn::Error> {
let data = &struct_def.data;
if let syn::Data::Struct(data_struct) = data {
let fields = &data_struct.fields;
if let syn::Fields::Named(named_fields) = fields {
Ok(&named_fields.named)
} else {
panic!("Expecting named fields within a struct.");
}
} else {
Err(syn::Error::new(
Span::call_site(),
"impl_table can only be applied to structs.",
))
}
}
fn extract_column_name(tts: TokenStream) -> Result<Option<String>, syn::Error> {
if tts.is_empty() {
return Ok(None);
}
let all_tts = quote! { column #tts };
let args_vec = parse_arguments(all_tts, tts.span())?;
assert!(
args_vec.len() == 1,
"Argumement list of 'column` is too long: column {}",
tts
);
let outer_arg = args_vec.first().unwrap();
if let Argument::Function { name: _, args } = outer_arg {
if args.len() > 1 {
return Err(syn::Error::new(
tts.span(),
"`column` only accepts one name argument.",
));
}
args.first()
.map(|arg| {
if let Argument::Flag { key, value } = arg {
if key == "name" {
return Ok(value.to_string());
}
}
return Err(syn::Error::new(
tts.span(),
"Only `name` option is supported in `column`.",
));
})
.transpose()
} else {
panic!("Unexpected `column` argument parsing result: {:?}", outer_arg);
}
}
/// Expands the table derive for a struct with named fields.
///
/// Each field carrying `#[column]` or `#[primary_key]` (optionally with
/// `(name = "...")`) becomes a column. Only the first such attribute on a
/// field is inspected. The column name defaults to the field name. Fields
/// without either attribute are filled via `..Default::default()` in the
/// generated `from_row`.
///
/// The generated inherent impl provides:
/// - `table_name()`, which returns `Self::TABLE_NAME` (the struct must define
///   that constant elsewhere);
/// - `all_columns()`, the column names in field order;
/// - `from_row(&rusqlite::Row)`, which reads each column by its position in
///   that same order.
///
/// # Errors
/// Returns a `syn::Error` if the input fails to parse, is not a struct, has a
/// malformed column attribute, or does not have exactly one `#[primary_key]`.
pub fn derive_table(item: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
    let struct_def = syn::parse_macro_input::parse::<syn::DeriveInput>(item)?;
    let named_fields = extract_named_fields(&struct_def)?;
    let mut fields = vec![];
    let mut columns = vec![];
    let mut has_non_columns = false;
    let mut primary_key = None;
    for field in named_fields {
        let field_name = match &field.ident {
            Some(name) => name.to_string(),
            None => panic!(
                "Expecting named fields to have names in struct {}.",
                struct_def.ident
            ),
        };
        let pound = match field.attrs.iter().find(is_pound) {
            Some(pound) => pound,
            None => {
                // Not mapped to a column; filled from `Default` in `from_row`.
                has_non_columns = true;
                continue;
            }
        };
        let column_name = extract_column_name(pound.tts.clone())?;
        columns.push(column_name.unwrap_or_else(|| field_name.clone()));
        if is_pound_primary_key(&pound) {
            if primary_key.is_some() {
                return Err(syn::Error::new(
                    pound.span(),
                    "Expecting no more than one primary_key field.",
                ));
            }
            primary_key = Some(field_name.clone());
        }
        fields.push(field_name);
    }
    if primary_key.is_none() {
        return Err(syn::Error::new(
            Span::call_site(),
            "Expecting at least one primary_key field.",
        ));
    }
    let struct_name = &struct_def.ident;
    // Positional reads: the index matches the order of `all_columns()`.
    let mut field_list: Vec<TokenStream> = fields
        .iter()
        .enumerate()
        .map(|(index, field)| {
            let ident = syn::Ident::new(field, Span::call_site());
            quote! {
                #ident: row.get(#index)?
            }
        })
        .collect();
    if has_non_columns {
        field_list.push(quote! { ..Default::default() });
    }
    let expr = quote! {
        impl #struct_name {
            fn table_name() -> &'static str {
                #struct_name::TABLE_NAME
            }
            fn all_columns() -> &'static [&'static str] {
                &[ #(#columns),* ]
            }
            fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
                Ok(Self {
                    #(#field_list),*
                })
            }
        }
    };
    Ok(expr)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds one `syn::Attribute` from the given attribute tokens by attaching
    // them to a throwaway named field and popping the last attribute back out.
    macro_rules! build_attr {
        ($($attr:tt)*) => {
            {
                let mut outer_fields: syn::FieldsNamed = syn::parse_quote! {
                    {
                        $($attr)*
                        i: i64,
                    }
                };
                outer_fields.named.pop().unwrap().into_value().attrs.pop().unwrap()
            }
        }
    }

    #[test]
    fn test_pound_primary_key() {
        let attrs = build_attr! { #[primary_key] };
        assert!(is_pound_primary_key(&&attrs));
        let attrs_column = build_attr! { #[column] };
        assert!(!is_pound_primary_key(&&attrs_column));
    }

    #[test]
    fn test_pound_column() {
        let attrs = build_attr! { #[column] };
        assert!(is_pound_column(&&attrs));
        let attrs_option = build_attr! { #[column(name = "test")] };
        assert!(is_pound_column(&&attrs_option));
        let attrs_primary_key = build_attr! { #[primary_key] };
        assert!(!is_pound_column(&&attrs_primary_key));
    }

    #[test]
    fn test_extract_column_name() {
        assert_eq!(None, extract_column_name(quote! { () }).unwrap());
        assert_eq!(
            Some("abcdxyz".to_string()),
            extract_column_name(quote! { (name = "abcdxyz") }).unwrap()
        );
        assert_eq!(
            "compile_error ! { \"`column` only accepts one name argument.\" }",
            extract_column_name(quote! { (name = "abcdxyz", other) })
                .err()
                .unwrap()
                .to_compile_error()
                .to_string()
        );
        assert_eq!(
            "compile_error ! { \"Only `name` option is supported in `column`.\" }",
            extract_column_name(quote! { (other = "x") })
                .err()
                .unwrap()
                .to_compile_error()
                .to_string()
        );
        assert_eq!(
            "compile_error ! { \"Only `name` option is supported in `column`.\" }",
            extract_column_name(quote! { (other) })
                .err()
                .unwrap()
                .to_compile_error()
                .to_string()
        );
    }

    #[test]
    #[should_panic]
    fn test_extract_column_name_non_func() {
        // The result is irrelevant; only the panic matters.
        let _ = extract_column_name(quote! { = "name" });
    }

    #[test]
    #[should_panic]
    fn test_extract_column_name_longer_list() {
        // The result is irrelevant; only the panic matters.
        let _ = extract_column_name(quote! { , name });
    }
}