petite_ad/lib.rs
//! # petite AD
//!
//! A pure-Rust automatic differentiation library that supports both single-variable
//! and multi-variable functions using reverse-mode differentiation (backpropagation).
//!
//! ## Features
//!
//! - **Single-variable autodiff** — chain operations such as `sin`, `cos`, and `exp`
//! - **Multi-variable autodiff** — build computational graphs over multiple inputs
//! - **Zero-copy backward pass** — efficient gradient computation through chains of closures
//! - **Convenient macros** — `mono_ops![]` and `multi_ops![]` provide a concise notation
//!
//! ## Examples
//!
//! ### Single-variable function
//! ```
//! use petite_ad::{MonoAD, mono_ops};
//!
//! let ops = mono_ops![sin, cos, exp];
//! let (value, grad_fn) = MonoAD::compute_grad(&ops, 2.0);
//! println!("f(2.0) = {}", value);
//! println!("f'(2.0) = {}", grad_fn(1.0));
//! ```
//!
//! ### Multi-variable function
//! ```
//! use petite_ad::{MultiAD, multi_ops};
//!
//! let exprs = multi_ops![
//!     (inp, 0),    // node 0: x₁
//!     (inp, 1),    // node 1: x₂
//!     (add, 0, 1), // node 2: x₁ + x₂
//!     (sin, 0),    // node 3: sin(x₁)
//!     (mul, 2, 3), // node 4: (x₁ + x₂) * sin(x₁)
//! ];
//!
//! let (value, grad_fn) = MultiAD::compute_grad(&exprs, &[0.6, 1.4]).unwrap();
//! let gradients = grad_fn(1.0);
//! println!("f(0.6, 1.4) = {}", value);
//! println!("∇f = {:?}", gradients);
//! ```
42
43mod error;
44mod macros;
45
46#[cfg(test)]
47mod test_utils;
48
49mod mono;
50mod multi;
51
52// Core types
53pub use mono::MonoAD;
54pub use multi::builder::GraphBuilder;
55pub use multi::MultiAD;
56
57// Error handling
58pub use error::{AutodiffError, Result};
59
60/// Type definitions for autodiff results and gradient functions.
61///
62/// This module provides type aliases for working with gradient computation results.
63pub mod types {
64 pub use crate::mono::types::{
65 BackwardResultArc as MonoResultArc, BackwardResultBox as MonoResultBox,
66 DynMathFn as MonoGradientFn,
67 };
68 pub use crate::multi::types::{
69 BackwardResultArc as MultiResultArc, BackwardResultBox as MultiResultBox,
70 DynGradFn as MultiGradientFn,
71 };
72}
73
74/// Traits for implementing custom differentiable functions.
75///
76/// These traits allow you to define your own mathematical functions
77/// with analytical gradients for testing and comparison purposes.
78pub mod traits {
79 pub use crate::mono::MonoFn;
80 pub use crate::multi::MultiFn;
81}