// petite_ad/lib.rs

//! # petite AD
//!
//! A pure Rust automatic differentiation library supporting both single-variable
//! and multi-variable functions with reverse-mode differentiation (backpropagation).
//!
//! ## Features
//!
//! - **Single-variable autodiff** - Chain operations like `sin`, `cos`, `exp`
//! - **Multi-variable autodiff** - Build computational graphs for multiple inputs
//! - **Zero-copy backward pass** - Efficient gradient computation through closure chains
//! - **Convenient macros** - Use `mono_ops![]` and `multi_ops![]` for concise notation
//!
//! ## Examples
//!
//! ### Single-variable function
//! ```
//! use petite_ad::{MonoAD, mono_ops};
//!
//! let ops = mono_ops![sin, cos, exp];
//! let (value, grad_fn) = MonoAD::compute_grad(&ops, 2.0);
//! println!("f(2.0) = {}", value);
//! println!("f'(2.0) = {}", grad_fn(1.0));
//! ```
//!
//! ### Multi-variable function
//! ```
//! use petite_ad::{MultiAD, multi_ops};
//!
//! let exprs = multi_ops![
//!     (inp, 0),    // x₁
//!     (inp, 1),    // x₂
//!     (add, 0, 1), // x₁ + x₂
//!     (sin, 0),    // sin(x₁)
//!     (mul, 2, 3), // sin(x₁) * (x₁ + x₂)
//! ];
//!
//! let (value, grad_fn) = MultiAD::compute_grad(&exprs, &[0.6, 1.4]).unwrap();
//! let gradients = grad_fn(1.0);
//! println!("f(0.6, 1.4) = {}", value);
//! println!("∇f = {:?}", gradients);
//! ```

43mod error;
44mod macros;
45
46#[cfg(test)]
47mod test_utils;
48
49mod mono;
50mod multi;
51
52// Core types
53pub use mono::MonoAD;
54pub use multi::builder::GraphBuilder;
55pub use multi::MultiAD;
56
57// Error handling
58pub use error::{AutodiffError, Result};
59
60/// Type definitions for autodiff results and gradient functions.
61///
62/// This module provides type aliases for working with gradient computation results.
63pub mod types {
64    pub use crate::mono::types::{
65        BackwardResultArc as MonoResultArc, BackwardResultBox as MonoResultBox,
66        DynMathFn as MonoGradientFn,
67    };
68    pub use crate::multi::types::{
69        BackwardResultArc as MultiResultArc, BackwardResultBox as MultiResultBox,
70        DynGradFn as MultiGradientFn,
71    };
72}
73
74/// Traits for implementing custom differentiable functions.
75///
76/// These traits allow you to define your own mathematical functions
77/// with analytical gradients for testing and comparison purposes.
78pub mod traits {
79    pub use crate::mono::MonoFn;
80    pub use crate::multi::MultiFn;
81}