1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
    Appellation: acme-macros <library>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
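//! # acme-macros
//!
//! Procedural macros for automatic differentiation, most notably `autodiff!`, which
//! computes the partial derivative of an expression with respect to a chosen variable.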
extern crate proc_macro;

pub(crate) use self::{primitives::*, utils::*};

pub(crate) mod ast;
pub(crate) mod handle;
pub(crate) mod ops;
pub(crate) mod utils;

pub(crate) mod autodiff;
pub(crate) mod operator;
pub(crate) mod partial;

use ast::partials::PartialAst;
use proc_macro::TokenStream;
use syn::parse_macro_input;

/// Differentiates an expression with respect to one of its variables.
///
/// The macro input has the form `var: expr`. The expansion evaluates the partial
/// derivative of `expr` with respect to `var`, using the values the variables hold
/// at the call site. Currently only expressions whose variables are defined within
/// the same scope as the macro call are supported.
///
/// # Examples
///
/// ### Basic arithmetic
///
/// ```
/// extern crate acme_macros as macros;
///
/// use macros::autodiff;
///
/// fn main() {
///     let x = 3f64;
///     let y = 4f64;
///
///     assert_eq!(y, autodiff!(x: x * y));
///     assert_eq!(x, autodiff!(y: x * y));
///     assert_eq!(1f64, autodiff!(x: x + y));
/// }
/// ```
///
/// ### Trigonometric functions
///
/// ```
/// extern crate acme_macros as macros;
///
/// use macros::autodiff;
///
/// fn main() {
///     let x = 2f64;
///     assert_eq!(autodiff!(x: x.cos()), -x.sin());
///     assert_eq!(autodiff!(x: x.sin()), x.cos());
///     assert_eq!(autodiff!(x: x.tan()), x.cos().powi(2).recip());
/// }
/// ```
#[proc_macro]
pub fn autodiff(input: TokenStream) -> TokenStream {
    // On a parse failure, `parse_macro_input!` returns early with a compile error.
    let partial = parse_macro_input!(input as PartialAst);
    // Hand the parsed `var: expr` pair to the code generator and convert its output
    // back into the compiler's token-stream type.
    autodiff::impl_autodiff(&partial).into()
}

/// Attribute macro that passes the annotated item to `operator::impl_operator`.
/// The generated tokens replace the item. Attribute arguments are currently ignored.
#[doc(hidden)]
#[proc_macro_attribute]
pub fn operator(_attr: TokenStream, item: TokenStream) -> TokenStream {
    // Any Rust item is accepted; a parse failure becomes a compile error at the call site.
    let target = parse_macro_input!(item as syn::Item);
    operator::impl_operator(target).into()
}

/// Attribute macro for functions: the annotated `fn` item is handed to
/// `partial::handle_item_fn`, and the generated tokens replace it.
/// Attribute arguments are currently ignored.
#[doc(hidden)]
#[proc_macro_attribute]
pub fn partial(_attr: TokenStream, item: TokenStream) -> TokenStream {
    // Only function items parse here; anything else is reported as a compile error.
    let func = parse_macro_input!(item as syn::ItemFn);
    partial::handle_item_fn(&func).into()
}

pub(crate) mod kw {
    syn::custom_keyword!(eval);
    syn::custom_keyword!(grad);

    syn::custom_keyword!(cos);
    syn::custom_keyword!(exp);
    syn::custom_keyword!(ln);
    syn::custom_keyword!(sin);
    syn::custom_keyword!(tan);
}

/// Crate-wide primitive type aliases.
pub(crate) mod primitives {
    use std::error::Error;

    /// A boxed, dynamically typed error.
    pub type BoxError = Box<dyn Error>;
}