1use quote::quote;
2use syn::{DeriveInput, parse_macro_input};
3
4#[proc_macro_derive(Feature, attributes(dimension))]
5pub fn derive_feature(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
6 let input = parse_macro_input!(input as DeriveInput);
7 match derive_feature_impl(input) {
8 Ok(token_stream) => token_stream,
9 Err(e) => e.to_compile_error().into(),
10 }
11}
12
13fn derive_feature_impl(input: DeriveInput) -> syn::Result<proc_macro::TokenStream> {
14 let name = input.ident;
15
16 let mut dimension: Option<usize> = None;
17
18 for attr in input.attrs {
19 if attr.path().is_ident("dimension") {
20 if let syn::Meta::NameValue(meta) = attr.meta {
21 if let syn::Expr::Lit(val) = meta.value {
22 if let syn::Lit::Int(v) = val.lit {
23 dimension = Some(v.base10_parse::<usize>().unwrap());
24 }
25 }
26 }
27 }
28 }
29
30 let dim = dimension.ok_or_else(|| {
31 syn::Error::new_spanned(
32 name.clone(),
33 "Missing #[dimension = <d>] attribute where d of type usize is equal to input dimension of feature function",
34 )
35 })?;
36
37 let expanded = quote! {
38 impl Feature<#dim> for #name {
39 fn val(&self, x: &nalgebra::SVector<f64, #dim>) -> f64 {
40 self.f(&x)
41 }
42
43 fn val_grad(&self, x: &nalgebra::SVector<f64, #dim>) -> (f64, nalgebra::SVector<f64,#dim>) {
44 num_dual::gradient(|v| self.f(&v), *x)
45 }
46
47 fn val_grad_hes(&self, x: &nalgebra::SVector<f64, #dim>) -> (f64, nalgebra::SVector<f64, #dim>, nalgebra::SMatrix<f64, #dim, #dim>) {
48 num_dual::hessian(|v| self.f(&v), *x)
49 }
50 }
51 };
52 Ok(proc_macro::TokenStream::from(expanded))
53}