// odesign_derive/lib.rs

use quote::quote;
use syn::{parse_macro_input, DeriveInput};

4#[proc_macro_derive(Feature, attributes(dimension))]
5pub fn derive_feature(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
6    let input = parse_macro_input!(input as DeriveInput);
7    match derive_feature_impl(input) {
8        Ok(token_stream) => token_stream,
9        Err(e) => e.to_compile_error().into(),
10    }
11}
12
13fn derive_feature_impl(input: DeriveInput) -> syn::Result<proc_macro::TokenStream> {
14    let name = input.ident;
15
16    let mut dimension: Option<usize> = None;
17
18    for attr in input.attrs {
19        if attr.path().is_ident("dimension") {
20            if let syn::Meta::NameValue(meta) = attr.meta {
21                if let syn::Expr::Lit(val) = meta.value {
22                    if let syn::Lit::Int(v) = val.lit {
23                        dimension = Some(v.base10_parse::<usize>().unwrap());
24                    }
25                }
26            }
27        }
28    }
29
30    let dim = dimension.ok_or_else(|| {
31        syn::Error::new_spanned(
32            name.clone(),
33            "Missing #[dimension = <d>] attribute where d of type usize is equal to input dimension of feature function",
34        )
35    })?;
36
37    let expanded = quote! {
38        impl Feature<#dim> for #name {
39            fn val(&self, x: &nalgebra::SVector<f64, #dim>) -> f64 {
40                self.f(&x)
41            }
42
43            fn val_grad(&self, x: &nalgebra::SVector<f64, #dim>) -> (f64, nalgebra::SVector<f64,#dim>) {
44                num_dual::gradient(|v| self.f(&v), *x)
45            }
46
47            fn val_grad_hes(&self, x: &nalgebra::SVector<f64, #dim>) -> (f64, nalgebra::SVector<f64, #dim>, nalgebra::SMatrix<f64, #dim, #dim>) {
48                num_dual::hessian(|v| self.f(&v), *x)
49            }
50        }
51    };
52    Ok(proc_macro::TokenStream::from(expanded))
53}