1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102
/*
Appellation: acme-macros <library>
Contrib: FL03 <jo3mccain@icloud.com>
*/
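//! # acme-macros
//!
//! Procedural macros for `acme`. The public entry point is the function-like
//! `autodiff!` macro, which differentiates an expression with respect to a
//! single variable. The `#[operator]` and `#[partial]` attribute macros are
//! internal helpers and are hidden from the generated documentation.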
extern crate proc_macro;
pub(crate) use self::{primitives::*, utils::*};
pub(crate) mod ast;
pub(crate) mod handle;
pub(crate) mod ops;
pub(crate) mod utils;
pub(crate) mod autodiff;
pub(crate) mod operator;
pub(crate) mod partial;
use ast::partials::PartialAst;
use proc_macro::TokenStream;
use syn::parse_macro_input;
/// Differentiates an expression with respect to a single variable.
///
/// The macro input has the form `var: expr`, where `var` is the variable to
/// differentiate against and `expr` is the expression being differentiated.
/// Currently only expressions defined within the same scope as the macro call
/// are supported.
///
/// # Examples
///
/// ### Basic arithmetic
///
/// ```
/// extern crate acme_macros as macros;
///
/// use macros::autodiff;
///
/// fn main() {
///     let x = 3f64;
///     let y = 4f64;
///
///     assert_eq!(y, autodiff!(x: x * y));
///     assert_eq!(x, autodiff!(y: x * y));
///     assert_eq!(1f64, autodiff!(x: x + y));
/// }
/// ```
///
/// ### Trigonometric functions
///
/// ```
/// extern crate acme_macros as macros;
///
/// use macros::autodiff;
///
/// fn main() {
///     let x = 2f64;
///     assert_eq!(autodiff!(x: x.cos()), -x.sin());
///     assert_eq!(autodiff!(x: x.sin()), x.cos());
///     assert_eq!(autodiff!(x: x.tan()), x.cos().powi(2).recip());
/// }
/// ```
#[proc_macro]
pub fn autodiff(input: TokenStream) -> TokenStream {
    // On malformed input, `parse_macro_input!` returns early from this
    // function with the parse error rendered as a compile error.
    let ast = parse_macro_input!(input as PartialAst);
    // Emit the derivative code, converted back into a compiler token stream.
    autodiff::impl_autodiff(&ast).into()
}
/// Internal attribute macro backing `#[operator]`.
///
/// Any arguments passed to the attribute are currently ignored. The annotated
/// item is parsed as a `syn::Item` (a parse failure becomes a compile error)
/// and handed by value to `operator::impl_operator`, whose output replaces the
/// item.
#[doc(hidden)]
#[proc_macro_attribute]
pub fn operator(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed_item = parse_macro_input!(item as syn::Item);
    operator::impl_operator(parsed_item).into()
}
/// Internal attribute macro backing `#[partial]`.
///
/// Any arguments passed to the attribute are currently ignored. The annotated
/// item must parse as a function (`syn::ItemFn`); anything else is reported
/// as a compile error. The parsed function is passed by reference to
/// `partial::handle_item_fn`, whose output replaces the item.
#[doc(hidden)]
#[proc_macro_attribute]
pub fn partial(_attr: TokenStream, item: TokenStream) -> TokenStream {
    let func = parse_macro_input!(item as syn::ItemFn);
    TokenStream::from(partial::handle_item_fn(&func))
}
/// Custom keywords for this crate's `syn`-based parsers.
///
/// Each `syn::custom_keyword!` invocation defines a keyword token type, plus a
/// same-named function, that can be peeked for and parsed like a built-in
/// Rust keyword.
pub(crate) mod kw {
    // Names of elementary functions.
    syn::custom_keyword!(cos);
    syn::custom_keyword!(exp);
    syn::custom_keyword!(ln);
    syn::custom_keyword!(sin);
    syn::custom_keyword!(tan);

    // Other keywords.
    syn::custom_keyword!(eval);
    syn::custom_keyword!(grad);
}
/// Crate-wide primitive type aliases (glob re-exported at the crate root).
pub(crate) mod primitives {
    use std::error::Error;

    /// A heap-allocated, type-erased error.
    ///
    /// No `Send`/`Sync` bounds are imposed, so any `Error` type can be boxed
    /// into it, but values of this type cannot be sent across threads.
    pub type BoxError = Box<dyn Error>;
}