Skip to main content

odesign_derive/
lib.rs

1use quote::quote;
2use syn::{DeriveInput, parse_macro_input};
3
4#[proc_macro_derive(Feature, attributes(dimension))]
5pub fn derive_feature(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
6    let input = parse_macro_input!(input as DeriveInput);
7    match derive_feature_impl(input) {
8        Ok(token_stream) => token_stream,
9        Err(e) => e.to_compile_error().into(),
10    }
11}
12
13fn derive_feature_impl(input: DeriveInput) -> syn::Result<proc_macro::TokenStream> {
14    let name = input.ident;
15
16    let mut dimension: Option<usize> = None;
17
18    for attr in input.attrs {
19        if attr.path().is_ident("dimension")
20            && let syn::Meta::NameValue(meta) = attr.meta
21            && let syn::Expr::Lit(val) = meta.value
22            && let syn::Lit::Int(v) = val.lit
23        {
24            dimension = Some(v.base10_parse::<usize>().unwrap());
25        }
26    }
27
28    let dim = dimension.ok_or_else(|| {
29        syn::Error::new_spanned(
30            name.clone(),
31            "Missing #[dimension = <d>] attribute where d of type usize is equal to input \
32             dimension of feature function",
33        )
34    })?;
35
36    let expanded = quote! {
37        impl Feature<#dim> for #name {
38            fn val(&self, x: &nalgebra::SVector<f64, #dim>) -> f64 {
39                self.f(&x)
40            }
41
42            fn val_grad(&self, x: &nalgebra::SVector<f64, #dim>) -> (f64, nalgebra::SVector<f64,#dim>) {
43                num_dual::gradient(|v| self.f(&v), *x)
44            }
45
46            fn val_grad_hes(&self, x: &nalgebra::SVector<f64, #dim>) -> (f64, nalgebra::SVector<f64, #dim>, nalgebra::SMatrix<f64, #dim, #dim>) {
47                num_dual::hessian(|v| self.f(&v), *x)
48            }
49        }
50    };
51    Ok(proc_macro::TokenStream::from(expanded))
52}