rsdiff_macros/
lib.rs

1/*
2    Appellation: rsdiff-macros <library>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
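//! # rsdiff-macros
//!
//! Procedural macros for `rsdiff`: the [`autodiff!`](macro@autodiff) macro generates code
//! that computes the partial derivative of an expression with respect to a variable, and a
//! (hidden) `#[operator]` attribute macro.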
8#![allow(clippy::module_inception, clippy::needless_doctest_main)]
9// #![cfg_attr(not(feature = "std"), no_std)]
10
11extern crate proc_macro;
12
13pub(crate) use self::{error::Error, primitives::*, utils::*};
14
15pub(crate) mod ast;
16pub(crate) mod error;
17pub(crate) mod handle;
18pub(crate) mod ops;
19pub(crate) mod utils;
20
21pub(crate) mod autodiff;
22pub(crate) mod operator;
23
24use proc_macro::TokenStream;
25use syn::parse_macro_input;
26
27/// Compute the partial derivative of a given expression w.r.t a particular variable.
28/// At the moment, the macro only supports expressions defined within the same scope.
29///
30/// # Examples
31///
32/// ### Basic arithmetic
33///
34/// ```
35/// extern crate rsdiff_macros as macros;
36///
37/// use macros::autodiff;
38///
39/// fn main() {
40///     let x = 3f64;
41///     let y = 4f64;
42///
43///     assert_eq!(y, autodiff!(x: x * y));
44///     assert_eq!(x, autodiff!(y: x * y));
45///     assert_eq!(1f64, autodiff!(x: x + y));
46/// }
47/// ```
48///
49/// ### Trigonometric functions
50///
51/// ```
52/// extern crate rsdiff_macros as macros;
53///
54/// use macros::autodiff;
55///
56/// fn main() {
57///     let x = 2f64;
58///     assert_eq!(autodiff!(x: x.cos()), -x.sin());
59///     assert_eq!(autodiff!(x: x.sin()), x.cos());
60///     assert_eq!(autodiff!(x: x.tan()), x.cos().powi(2).recip());
61/// }
62/// ```
63#[proc_macro]
64pub fn autodiff(input: TokenStream) -> TokenStream {
65    // Parse the input expression into a syntax tree
66    let expr = parse_macro_input!(input as ast::AutodiffAst);
67
68    // Generate code to compute the gradient
69    let result = autodiff::impl_autodiff(&expr);
70
71    // Return the generated code as a token stream
72    TokenStream::from(result)
73}
74
75#[doc(hidden)]
76#[proc_macro_attribute]
77pub fn operator(args: TokenStream, item: TokenStream) -> TokenStream {
78    let mut attrs = ast::operator::OperatorAttr::new();
79    let op_parser = syn::meta::parser(|meta| attrs.parser(meta));
80    let _ = parse_macro_input!(args with op_parser);
81    let item = parse_macro_input!(item as syn::Item);
82    let ast = ast::OperatorAst::new(Some(dbg!(attrs)), item);
83    let result = operator::impl_operator(&ast);
84    TokenStream::from(result)
85}
86
87pub(crate) mod kw {
88    syn::custom_keyword!(eval);
89    syn::custom_keyword!(grad);
90
91    syn::custom_keyword!(cos);
92    syn::custom_keyword!(exp);
93    syn::custom_keyword!(ln);
94    syn::custom_keyword!(sin);
95    syn::custom_keyword!(tan);
96}
97
98#[allow(unused)]
99pub(crate) mod primitives {
100    pub type Result<T = ()> = std::result::Result<T, crate::Error>;
101}