impl_table 0.1.3

Generates table bindings and utilities for rust-postgres and rusqlite.
Documentation
use crate::parse_arguments::parse_arguments;
use crate::Argument;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::spanned::Spanned;

/// Returns the first segment of an attribute's path as a string, e.g. `"column"` for
/// `#[column(name = "x")]`, or `None` when the path has no segments.
fn extract_attribute_name(attr: &syn::Attribute) -> Option<String> {
    let leading = attr.path.segments.first()?;
    Some(leading.value().ident.to_string())
}

/// Returns `true` when the attribute's (first) path segment is `primary_key`.
///
/// Takes `&&syn::Attribute` so it can be passed directly to `Iterator::find` over
/// `attrs.iter()`.
fn is_pound_primary_key(attr: &&syn::Attribute) -> bool {
    match extract_attribute_name(attr) {
        Some(name) => name == "primary_key",
        None => false,
    }
}

/// Returns `true` when the attribute's (first) path segment is `column`.
///
/// Takes `&&syn::Attribute` so it can be passed directly to `Iterator::find` over
/// `attrs.iter()`.
fn is_pound_column(attr: &&syn::Attribute) -> bool {
    extract_attribute_name(attr).as_ref().map(String::as_str) == Some("column")
}

// Has to accept && to work with find.
fn is_pound(attr: &&syn::Attribute) -> bool {
    is_pound_column(attr) || is_pound_primary_key(attr)
}

fn extract_named_fields(
    struct_def: &syn::DeriveInput,
) -> Result<&syn::punctuated::Punctuated<syn::Field, syn::token::Comma>, syn::Error> {
    let data = &struct_def.data;
    if let syn::Data::Struct(data_struct) = data {
        let fields = &data_struct.fields;
        if let syn::Fields::Named(named_fields) = fields {
            Ok(&named_fields.named)
        } else {
            panic!("Expecting named fields within a struct.");
        }
    } else {
        Err(syn::Error::new(
            Span::call_site(),
            "impl_table can only be applied to structs.",
        ))
    }
}

fn extract_column_name(tts: TokenStream) -> Result<Option<String>, syn::Error> {
    if tts.is_empty() {
        return Ok(None);
    }

    // tts includes the parentheses around them, which is not supported by parse_argument.
    let all_tts = quote! { column #tts };
    let args_vec = parse_arguments(all_tts, tts.span())?;

    assert!(
        args_vec.len() == 1,
        "Argumement list of 'column` is too long: column {}",
        tts
    );

    let outer_arg = args_vec.first().unwrap();

    // Function name is always "column"
    if let Argument::Function { name: _, args } = outer_arg {
        if args.len() > 1 {
            return Err(syn::Error::new(
                tts.span(),
                "`column` only accepts one name argument.",
            ));
        }

        args.first()
            .map(|arg| {
                if let Argument::Flag { key, value } = arg {
                    if key == "name" {
                        return Ok(value.to_string());
                    }
                }
                // Should really be a warning, if we could add one.
                return Err(syn::Error::new(
                    tts.span(),
                    "Only `name` option is supported in `column`.",
                ));
            })
            // Map Option<Result<A,E>> to Result<Option<A>, E>
            .transpose()
    } else {
        panic!("Unexpected `column` argument parsing result: {:?}", outer_arg);
    }
}

pub fn derive_table(item: proc_macro::TokenStream) -> Result<TokenStream, syn::Error> {
    let struct_def = syn::parse_macro_input::parse::<syn::DeriveInput>(item)?;
    let named_fields = extract_named_fields(&struct_def)?;

    let mut fields = vec![];
    let mut columns = vec![];
    let mut non_columns = vec![];
    let mut primary_key = None;
    for field in named_fields {
        let field_name = if let Some(name) = &field.ident {
            name.to_string()
        } else {
            panic!(
                "Expecting named fields to have names in struct {}.",
                struct_def.ident
            );
        };

        let optional_pound = field.attrs.iter().find(is_pound);
        if let Some(pound) = optional_pound {
            let column_name = extract_column_name(pound.tts.clone())?;

            fields.push(field_name.clone());
            columns.push(column_name.unwrap_or(field_name.clone()));

            if is_pound_primary_key(&pound) {
                if primary_key.is_none() {
                    primary_key = Some(field_name);
                } else {
                    return Err(syn::Error::new(
                        pound.span(),
                        "Expecting no more than one primary_key field.",
                    ));
                }
            }
        } else {
            non_columns.push(field_name.clone())
        }
    }
    if primary_key.is_none() {
        return Err(syn::Error::new(
            Span::call_site(),
            "Expecting at least one primary_key field.",
        ));
    }

    let struct_name = &struct_def.ident;
    let _columns_list = columns.join(", ");
    let field_values = fields.iter().enumerate().map(|(index, field)| {
        // TODO: remember span of the fields.
        let ident = syn::Ident::new(field, Span::call_site());
        quote! {
            #ident: row.get(#index)?
        }
    });
    let mut field_list: Vec<TokenStream> = field_values.collect();
    // TODO: add an option to add defaults and raise warning when there are unknown fields.
    if !non_columns.is_empty() {
        field_list.push(quote! { ..Default::default() });
    }

    let expr = quote! {
        impl #struct_name {
            fn table_name() -> &'static str {
                #struct_name::TABLE_NAME
            }

            fn all_columns() -> &'static [&'static str] {
                &[ #(#columns),* ]
            }

            fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
                Ok(Self {
                    #(#field_list),*
                })
            }
        }
    };
    Ok(expr.into())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the given attribute tokens onto a throwaway field `i: i64` and returns the
    /// last attribute attached to it.
    macro_rules! build_attr {
        ($($attr:tt)*) => {{
            let mut wrapper: syn::FieldsNamed = syn::parse_quote! {
                {
                    $($attr)*
                    i: i64,
                }
            };
            let field = wrapper.named.pop().unwrap().into_value();
            let mut attrs = field.attrs;
            attrs.pop().unwrap()
        }};
    }

    /// Runs `extract_column_name` on input that must be rejected and renders the error as
    /// the `compile_error!` invocation the macro would emit.
    fn rendered_error(tts: TokenStream) -> String {
        extract_column_name(tts)
            .err()
            .expect("expected extract_column_name to fail")
            .to_compile_error()
            .to_string()
    }

    #[test]
    fn test_pound_primary_key() {
        let primary = build_attr! { #[primary_key] };
        assert!(is_pound_primary_key(&&primary));

        let column = build_attr! { #[column] };
        assert!(!is_pound_primary_key(&&column));
    }

    #[test]
    fn test_pound_column() {
        let bare = build_attr! { #[column] };
        assert!(is_pound_column(&&bare));

        let named = build_attr! { #[column(name = "test")] };
        assert!(is_pound_column(&&named));

        let primary = build_attr! { #[primary_key] };
        assert!(!is_pound_column(&&primary));
    }

    #[test]
    fn test_extract_column_name() {
        assert_eq!(extract_column_name(quote! { () }).unwrap(), None);
        assert_eq!(
            extract_column_name(quote! { (name = "abcdxyz") }).unwrap(),
            Some("abcdxyz".to_string())
        );

        let rejected = vec![
            (
                quote! { (name = "abcdxyz", other) },
                "compile_error ! { \"`column` only accepts one name argument.\" }",
            ),
            (
                quote! { (other = "x") },
                "compile_error ! { \"Only `name` option is supported in `column`.\" }",
            ),
            (
                quote! { (other) },
                "compile_error ! { \"Only `name` option is supported in `column`.\" }",
            ),
        ];
        for (input, expected) in rejected.into_iter() {
            assert_eq!(rendered_error(input), expected);
        }
    }

    #[test]
    #[should_panic]
    fn test_extract_column_name_non_func() {
        let _ = extract_column_name(quote! { = "name" });
    }

    #[test]
    #[should_panic]
    fn test_extract_column_name_longer_list() {
        let _ = extract_column_name(quote! { , name });
    }
}