1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
    parse_macro_input, punctuated::Punctuated, token::Comma, Data, DeriveInput, Error, Expr,
    ExprLit, Field, Fields, GenericArgument, Generics, Ident, Lit, Path,
    PathArguments::AngleBracketed, Type, TypePath,
};

/// Position, within a layer type's generic argument list, of the integer
/// literal used as the network's input size (read from the first field).
const INPUT_SIZE_ARG: usize = 1;

/// Position, within a layer type's generic argument list, of the integer
/// literal used as the network's output size (read from the last field).
const OUTPUT_SIZE_ARG: usize = 2;

/// Derives `NeuralNetwork<IN, OUT>` for a struct with named fields.
///
/// The generated `forward` feeds `input` through each field's `forward`
/// method in declaration order. `IN` is taken from generic argument
/// `INPUT_SIZE_ARG` of the first field's type and `OUT` from generic argument
/// `OUTPUT_SIZE_ARG` of the last field's type; both must be integer literals.
///
/// Unsupported input (enums, unions, tuple/unit structs, fields whose type
/// lacks the expected generic arguments) is reported as a compile error at
/// the offending tokens instead of panicking inside the macro.
///
/// The `NeuralNetwork` trait is referenced unqualified, so it must be in
/// scope wherever the derive is used.
#[proc_macro_derive(NeuralNetwork)]
pub fn derive_neural_network(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let expanded = named_fields(&input)
        .and_then(|fields| impl_neural_network(&input.ident, &input.generics, fields))
        .unwrap_or_else(|err| err.to_compile_error());
    proc_macro::TokenStream::from(expanded)
}

/// Returns the named fields of the struct being derived on.
///
/// # Errors
///
/// Fails when the input is an enum, a union, or a struct without named fields.
fn named_fields(input: &DeriveInput) -> syn::Result<&Punctuated<Field, Comma>> {
    match &input.data {
        Data::Struct(data) => match &data.fields {
            Fields::Named(fields) => Ok(&fields.named),
            Fields::Unnamed(_) | Fields::Unit => Err(Error::new_spanned(
                &input.ident,
                "NeuralNetwork can only be derived for structs with named fields",
            )),
        },
        Data::Enum(_) | Data::Union(_) => Err(Error::new_spanned(
            &input.ident,
            "NeuralNetwork can only be derived for structs",
        )),
    }
}

/// Returns the angle-bracketed generic arguments of `field`'s type.
///
/// The arguments are read from the last path segment so that qualified types
/// such as `layers::Dense<A, 3, 4>` work the same as `Dense<A, 3, 4>`.
///
/// # Errors
///
/// Fails when the field type is not a path or carries no `<...>` arguments.
fn get_field_type_args(field: &Field) -> syn::Result<&Punctuated<GenericArgument, Comma>> {
    let segments = match &field.ty {
        Type::Path(TypePath {
            path: Path { segments, .. },
            ..
        }) => segments,
        other => {
            return Err(Error::new_spanned(
                other,
                "expected a layer type written as a path with generic arguments",
            ))
        }
    };
    let segment = segments
        .last()
        .ok_or_else(|| Error::new_spanned(&field.ty, "layer type path is empty"))?;
    match &segment.arguments {
        AngleBracketed(args) => Ok(&args.args),
        _ => Err(Error::new_spanned(
            segment,
            "layer type must have angle-bracketed generic arguments",
        )),
    }
}

/// Parses a const generic argument that is an integer literal into a `usize`.
///
/// # Errors
///
/// Fails when the argument is not an integer literal or does not fit `usize`.
fn as_usize(arg: &GenericArgument) -> syn::Result<usize> {
    match arg {
        GenericArgument::Const(Expr::Lit(ExprLit {
            lit: Lit::Int(value),
            ..
        })) => value.base10_parse::<usize>(),
        other => Err(Error::new_spanned(
            other,
            "expected an integer literal as the layer size",
        )),
    }
}

/// Reads the generic argument at `index` of `field`'s type as a `usize`.
fn layer_size(field: &Field, index: usize) -> syn::Result<usize> {
    let arg = get_field_type_args(field)?
        .iter()
        .nth(index)
        .ok_or_else(|| {
            Error::new_spanned(
                &field.ty,
                format!("layer type needs at least {} generic arguments", index + 1),
            )
        })?;
    as_usize(arg)
}

/// Builds the `impl NeuralNetwork<IN, OUT> for #name` token stream.
///
/// # Errors
///
/// Fails when `fields` is empty or when the first/last field's type does not
/// expose the expected integer-literal size arguments.
fn impl_neural_network(
    name: &Ident,
    generics: &Generics,
    fields: &Punctuated<Field, Comma>,
) -> syn::Result<TokenStream> {
    let (first, last) = match (fields.first(), fields.last()) {
        (Some(first), Some(last)) => (first, last),
        _ => {
            return Err(Error::new_spanned(
                name,
                "NeuralNetwork requires at least one layer field",
            ))
        }
    };
    let input_size = layer_size(first, INPUT_SIZE_ARG)?;
    let output_size = layer_size(last, OUTPUT_SIZE_ARG)?;

    // Each step wraps the expression built so far, so the first field ends up
    // innermost and is applied to `input` first.
    let forward_chain = fields.iter().fold(quote!(input), |acc, field| {
        let layer = &field.ident;
        quote!(self.#layer.forward(#acc))
    });

    // Forward the struct's own generics so the impl is valid for generic structs.
    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    Ok(quote! {
        impl #impl_generics NeuralNetwork<#input_size, #output_size> for #name #ty_generics #where_clause {
            fn forward(&self, input: [f32; #input_size]) -> [f32; #output_size] {
                #forward_chain
            }
        }
    })
}