acme_macros/
lib.rs

1/*
2    Appellation: acme-macros <library>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
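//! # acme-macros
//!
//! Procedural macros provided by this crate: the function-like `autodiff!` macro and the
//! `#[operator]` attribute macro.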
8extern crate proc_macro;
9
10pub(crate) use self::{error::Error, primitives::*, utils::*};
11
12pub(crate) mod ast;
13pub(crate) mod error;
14pub(crate) mod handle;
15pub(crate) mod ops;
16pub(crate) mod utils;
17
18pub(crate) mod autodiff;
19pub(crate) mod operator;
20
21use proc_macro::TokenStream;
22use syn::parse_macro_input;
23
24/// Compute the partial derivative of a given expression w.r.t a particular variable.
25/// At the moment, the macro only supports expressions defined within the same scope.
26///
27/// # Examples
28///
29/// ### Basic arithmetic
30///
31/// ```
32/// extern crate acme_macros as macros;
33///
34/// use macros::autodiff;
35///
36/// fn main() {
37///     let x = 3f64;
38///     let y = 4f64;
39///
40///     assert_eq!(y, autodiff!(x: x * y));
41///     assert_eq!(x, autodiff!(y: x * y));
42///     assert_eq!(1f64, autodiff!(x: x + y));
43/// }
44/// ```
45///
46/// ### Trigonometric functions
47///
48/// ```
49/// extern crate acme_macros as macros;
50///
51/// use macros::autodiff;
52///
53/// fn main() {
54///     let x = 2f64;
55///     assert_eq!(autodiff!(x: x.cos()), -x.sin());
56///     assert_eq!(autodiff!(x: x.sin()), x.cos());
57///     assert_eq!(autodiff!(x: x.tan()), x.cos().powi(2).recip());
58/// }
59/// ```
60#[proc_macro]
61pub fn autodiff(input: TokenStream) -> TokenStream {
62    // Parse the input expression into a syntax tree
63    let expr = parse_macro_input!(input as ast::AutodiffAst);
64
65    // Generate code to compute the gradient
66    let result = autodiff::impl_autodiff(&expr);
67
68    // Return the generated code as a token stream
69    TokenStream::from(result)
70}
71
72#[doc(hidden)]
73#[proc_macro_attribute]
74pub fn operator(args: TokenStream, item: TokenStream) -> TokenStream {
75    let mut attrs = ast::operator::OperatorAttr::new();
76    let op_parser = syn::meta::parser(|meta| attrs.parser(meta));
77    let _ = parse_macro_input!(args with op_parser);
78    let item = parse_macro_input!(item as syn::Item);
79    let ast = ast::OperatorAst::new(Some(attrs), item);
80    let result = operator::impl_operator(&ast);
81    TokenStream::from(result)
82}
83
84pub(crate) mod kw {
85    syn::custom_keyword!(eval);
86    syn::custom_keyword!(grad);
87
88    syn::custom_keyword!(cos);
89    syn::custom_keyword!(exp);
90    syn::custom_keyword!(ln);
91    syn::custom_keyword!(sin);
92    syn::custom_keyword!(tan);
93}
94
95pub(crate) mod primitives {
96    pub type Result<T = ()> = std::result::Result<T, crate::Error>;
97    pub type BoxError = Box<dyn std::error::Error>;
98}