extern crate proc_macro;
extern crate proc_macro2;
extern crate quote;
extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::*;
use quote::quote;
use std::collections::HashMap;
use syn::DeriveInput;
/// Entry point for `#[derive(GettersByType)]`.
///
/// Generates only shared-reference getters (`get_fields_<type>`). Despite this
/// function's name, no `&mut` getters are produced here.
#[proc_macro_derive(GettersByType)]
pub fn fields_getters_mutable_by_type(input: TokenStream) -> TokenStream {
    let with_mutability = false;
    fields_getters_by_type_impl(input, with_mutability)
}
/// Entry point for `#[derive(GettersMutByType)]`.
///
/// Generates both shared getters (`get_fields_<type>`) and exclusive getters
/// (`get_mut_fields_<type>`). Despite this function's name, it is the mutable
/// variant. Because it also emits the `get_fields_*` methods, deriving it
/// together with `GettersByType` on one struct would define those methods twice.
#[proc_macro_derive(GettersMutByType)]
pub fn fields_getters_immutable_by_type(input: TokenStream) -> TokenStream {
    let with_mutability = true;
    fields_getters_by_type_impl(input, with_mutability)
}
fn fields_getters_by_type_impl(input: TokenStream, with_mutability: bool) -> TokenStream {
let ast: DeriveInput = syn::parse(input).unwrap();
let (vis, ty, generics) = (&ast.vis, &ast.ident, &ast.generics);
let fields_by_type = match ast.data {
syn::Data::Struct(e) => read_fields(e.fields, with_mutability),
_ => panic!("{} can only be derived for structs.", if with_mutability { "GettersMutByType" } else { "GettersByType" }),
};
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let methods = fields_by_type.into_iter().fold(Vec::<TokenTree>::new(), |mut acc, (type_name, fields_sharing_type)| {
let ctx = MethodContext {
method_return_type: fields_sharing_type.type_ident,
type_name: fix_type_name(&type_name),
vis,
};
acc.extend(make_method_tokens("get_fields", &ctx, false, fields_sharing_type.immutable_fields));
if with_mutability {
acc.extend(make_method_tokens("get_mut_fields", &ctx, true, fields_sharing_type.mutable_fields));
}
acc
});
let tokens = quote! {
impl #impl_generics #ty #ty_generics
#where_clause
{
#(#methods)
*
}
};
tokens.into()
}
/// Inputs shared by the getter(s) generated for one group of same-typed fields.
struct MethodContext<'a> {
    /// Suffix appended to the method prefix to form the getter name.
    type_name: String,
    /// Element type of the returned reference array.
    method_return_type: syn::Type,
    /// Visibility of the deriving struct, reused for each generated getter.
    vis: &'a syn::Visibility,
}
/// Emits one getter method named `{method_prefix}_{ctx.type_name}`.
///
/// The getter returns a fixed-size array holding a reference to each field in
/// `field_names`, in order. When `mutability` is true it takes `&mut self` and
/// yields `&mut` references; otherwise it takes `&self` and yields shared ones.
///
/// # Panics
/// `syn::Ident::new` panics if the method name or any field name is not a
/// valid identifier.
fn make_method_tokens(method_prefix: &str, ctx: &MethodContext, mutability: bool, field_names: Vec<String>) -> proc_macro2::TokenStream {
    let span = proc_macro2::Span::call_site();
    let method_name = syn::Ident::new(&format!("{}_{}", method_prefix, ctx.type_name), span);
    let fields: Vec<syn::Ident> = field_names
        .iter()
        .map(|name| syn::Ident::new(name, span))
        .collect();
    let count = fields.len();
    let vis = ctx.vis;
    let elem = &ctx.method_return_type;

    if !mutability {
        return quote! {
            #vis fn #method_name(&self) -> [&#elem; #count] {
                [#(&self.#fields),*]
            }
        };
    }

    quote! {
        #vis fn #method_name(&mut self) -> [&mut #elem; #count] {
            [#(&mut self.#fields),*]
        }
    }
}
/// The names of all struct fields that share one type, plus that type.
struct FieldsSharingType {
    /// The shared type; used as the element type of the getters' return arrays.
    type_ident: syn::Type,
    /// Every field of this type; exposed through the shared-reference getter.
    immutable_fields: Vec<String>,
    /// Fields of this type that may also be exposed through the `&mut` getter.
    mutable_fields: Vec<String>,
}

impl FieldsSharingType {
    /// Creates an empty group for `type_ident`.
    fn new(type_ident: syn::Type) -> Self {
        Self {
            type_ident,
            immutable_fields: Vec::new(),
            mutable_fields: Vec::new(),
        }
    }
}
fn read_fields(fields: syn::Fields, with_mutability: bool) -> HashMap<String, FieldsSharingType> {
let mut fields_by_type = HashMap::<String, FieldsSharingType>::new();
for field in fields.iter() {
if let Some(ref ident) = field.ident {
match get_data_from_field(&field) {
Ok(FieldInfo { is_mutable, type_ident, type_name }) => {
let fields_by_type = fields_by_type.entry(type_name).or_insert_with(|| FieldsSharingType::new(type_ident));
if is_mutable && with_mutability {
fields_by_type.mutable_fields.push(ident.to_string());
}
fields_by_type.immutable_fields.push(ident.to_string());
}
Err(err) => {
eprintln!("[WARNING::GetterByType] {} for field: {}", err, ident);
}
}
}
}
fields_by_type
}
/// What `get_data_from_field` extracts from a single named field.
struct FieldInfo {
    /// Rendered type name; used as the grouping key and, once sanitized, as the method-name suffix.
    type_name: String,
    /// Type returned by the getters. For reference fields this is the referent path type.
    type_ident: syn::Type,
    /// True for owned path types and `&mut` references; false for `&` references.
    is_mutable: bool,
}
/// Extracts the grouping information for one field.
///
/// Supported field types are a plain path type (`T`, `Vec<u8>`, ...), which is
/// treated as mutable, and a reference to a path type (`&T` / `&mut T`). For a
/// reference, the referent type is returned and mutability follows the
/// reference's `mut`.
///
/// # Errors
/// - "Reference not covered" for a reference to a non-path type.
/// - "Type not covered" for any other field type.
/// - Any error from `get_type_string` while rendering the type name.
fn get_data_from_field(field: &syn::Field) -> Result<FieldInfo, &'static str> {
    let (path, type_ident, is_mutable) = match field.ty {
        syn::Type::Path(ref owned) => (owned, field.ty.clone(), true),
        syn::Type::Reference(ref reference) => {
            let referent = match *reference.elem {
                syn::Type::Path(ref inner) => inner,
                _ => return Err("Reference not covered"),
            };
            (referent, syn::Type::Path(referent.clone()), reference.mutability.is_some())
        }
        _ => return Err("Type not covered"),
    };
    let type_name = get_type_string(path)?;
    Ok(FieldInfo {
        is_mutable,
        type_ident,
        type_name,
    })
}
fn get_type_string(path: &syn::TypePath) -> Result<String, &'static str> {
let mut error = None;
let operation = path
.path
.segments
.iter()
.map(|segment| {
segment.ident.to_string()
+ match get_argument_string(&segment.arguments) {
Ok(ref string) => string,
Err(err) => {
error = Some(err);
""
}
}
})
.collect::<String>();
match error {
None => Ok(operation),
Some(err) => Err(err),
}
}
/// Renders a path segment's generic arguments for use in a type key.
///
/// - No arguments render as "".
/// - Angle-bracketed arguments render as `<`, then each type argument rendered
///   by `get_type_path_string`, then `>`. Lifetimes and other non-type
///   arguments contribute nothing.
/// - Parenthesized (`Fn`-style) arguments render as `(`, then each input type,
///   then `)`. If there is a return type, "arrow_" and the rendered return
///   type follow.
///
/// # Errors
/// Propagates any error from `get_type_path_string` (unsupported argument types).
fn get_argument_string(arguments: &syn::PathArguments) -> Result<String, &'static str> {
    let mut type_name = String::new();
    match *arguments {
        syn::PathArguments::None => {}
        syn::PathArguments::AngleBracketed(ref angle) => {
            type_name.push('<');
            for arg in &angle.args {
                if let syn::GenericArgument::Type(ref ty) = *arg {
                    type_name += &get_type_path_string(ty)?;
                }
            }
            type_name.push('>');
        }
        syn::PathArguments::Parenthesized(ref paren) => {
            type_name.push('(');
            // The rendered inputs/output must be appended. Discarding them makes
            // every `Fn(..) -> ..` signature share one key, so fields with
            // different closure types would be grouped under one getter.
            for arg in &paren.inputs {
                type_name += &get_type_path_string(arg)?;
            }
            type_name.push(')');
            if let syn::ReturnType::Type(_, ref output) = paren.output {
                type_name += "arrow_";
                type_name += &get_type_path_string(output)?;
            }
        }
    }
    Ok(type_name)
}
fn get_type_path_string(type_argument: &syn::Type) -> Result<String, &'static str> {
match type_argument {
syn::Type::Path(ref argpath) => {
let mut error = None;
let operation = argpath
.path
.segments
.iter()
.map(|argsegment| {
format!("{}_", argsegment.ident)
+ match get_argument_string(&argsegment.arguments) {
Ok(ref string) => string,
Err(err) => {
error = Some(err);
""
}
}
})
.collect::<String>();
match error {
None => Ok(operation),
Some(err) => Err(err),
}
}
_ => Err("Type argument not covered"),
}
}
/// Turns a rendered type key into a lowercase identifier fragment.
///
/// The input is lowercased, and every ASCII punctuation character other than
/// `_` (for example `<`, `>`, `(`, `)`, `-`, `#`) is replaced with `_`, as is
/// any whitespace. This covers the raw-identifier prefix `r#`, which would
/// otherwise reach `syn::Ident::new` and make it panic.
fn fix_type_name(name: &str) -> String {
    name.to_lowercase()
        .chars()
        .map(|c| {
            if c != '_' && (c.is_ascii_punctuation() || c.is_whitespace()) {
                '_'
            } else {
                c
            }
        })
        .collect()
}