petite-ad
A pure Rust automatic differentiation library supporting both single-variable and multi-variable functions with reverse-mode differentiation (backpropagation).
Features
- Single-variable autodiff (`MonoAD`) - Chain operations like `sin`, `cos`, `exp` with automatic gradient computation
- Multi-variable autodiff (`MultiAD`) - Build computational graphs for functions with multiple inputs
- Box-wrapped by default - Results use `Box<dyn Fn>` for flexibility; convert to `Arc` when needed for thread-safety
- Zero-copy backward pass - Gradients computed efficiently through closure chains
- Convenient macros - Use `mono_ops![]` for concise operation lists
- Builder API - Fluent interface for constructing computation graphs
- Comprehensive tests - 39 unit tests + 10 doctests covering all operations and edge cases
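To make the `Box<dyn Fn>` to `Arc` point concrete, here is a minimal sketch in plain Rust, independent of petite-ad: a hypothetical gradient closure is boxed, converted with the standard library's `Arc::from`, and shared across threads. The `Send + Sync` bounds are an assumption of this sketch; a boxed closure without them can still become an `Arc`, but that `Arc` cannot be sent to other threads.

```rust
use std::sync::Arc;
use std::thread;

// Hypothetical aliases for this sketch (not petite-ad's types).
// The `Send + Sync` bounds are what make the `Arc` shareable across threads.
type BoxedGrad = Box<dyn Fn(f64) -> f64 + Send + Sync>;
type SharedGrad = Arc<dyn Fn(f64) -> f64 + Send + Sync>;

fn make_grad(x: f64) -> BoxedGrad {
    // d/dx sin(x) = cos(x): scale the upstream gradient by the local derivative.
    Box::new(move |upstream: f64| upstream * x.cos())
}

fn main() {
    let boxed: BoxedGrad = make_grad(0.5);
    // `impl From<Box<T>> for Arc<T>` also covers unsized `T` such as trait objects.
    let shared: SharedGrad = Arc::from(boxed);

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let g = Arc::clone(&shared);
            thread::spawn(move || g(i as f64))
        })
        .collect();

    for (i, h) in handles.into_iter().enumerate() {
        let got = h.join().unwrap();
        assert!((got - i as f64 * 0.5f64.cos()).abs() < 1e-12);
    }
    println!("all threads agree");
}
```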
Installation
Add to your `Cargo.toml`:

```toml
[dependencies]
petite-ad = "0.1.0"
```
Quick Start
Single-Variable Functions
```rust
use …;

let exprs = mono_ops![…];
let (…) = compute(…);
let gradient = backprop(…);
println!(…); // exp(cos(sin(2.0)))
println!(…); // derivative
```
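The example above relies on the crate's `mono_ops!`, `compute`, and `backprop`. As a self-contained illustration of reverse mode over a chain of unary operations, the sketch below uses hypothetical `MonoOp` and `compute` definitions (not petite-ad's API): the forward pass records each op's local derivative inside a boxed closure, so calling the final closure with `1.0` multiplies the local derivatives in reverse order.

```rust
// Hypothetical stand-ins for the crate's operations; not petite-ad's API.
#[derive(Clone, Copy, Debug)]
enum MonoOp {
    Sin,
    Cos,
    Exp,
}

impl MonoOp {
    fn apply(self, x: f64) -> f64 {
        match self {
            MonoOp::Sin => x.sin(),
            MonoOp::Cos => x.cos(),
            MonoOp::Exp => x.exp(),
        }
    }

    // Local derivative, evaluated at this op's *input* x.
    fn deriv(self, x: f64) -> f64 {
        match self {
            MonoOp::Sin => x.cos(),
            MonoOp::Cos => -x.sin(),
            MonoOp::Exp => x.exp(),
        }
    }
}

/// Forward pass: apply ops left to right while building a backward closure
/// that applies the chain rule from the last op back to the first.
fn compute(ops: &[MonoOp], x0: f64) -> (f64, Box<dyn Fn(f64) -> f64>) {
    let mut x = x0;
    let mut backward: Box<dyn Fn(f64) -> f64> = Box::new(|g: f64| g);
    for &op in ops {
        let local = op.deriv(x); // derivative at this op's input
        let prev = backward;
        // Scale the upstream gradient by `local`, then hand it to earlier ops.
        backward = Box::new(move |g: f64| prev(g * local));
        x = op.apply(x);
    }
    (x, backward)
}

fn main() {
    let ops = [MonoOp::Sin, MonoOp::Cos, MonoOp::Exp];
    let (value, backprop) = compute(&ops, 2.0);
    let gradient = backprop(1.0);

    // Closed form: d/dx exp(cos(sin x)) = exp(cos(sin x)) * -sin(sin x) * cos(x)
    let s = 2.0f64.sin();
    let expected = s.cos().exp() * -s.sin() * 2.0f64.cos();
    assert!((value - s.cos().exp()).abs() < 1e-12);
    assert!((gradient - expected).abs() < 1e-12);
    println!("f(2.0) = {value}, f'(2.0) = {gradient}");
}
```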
Multi-Variable Functions
Using the GraphBuilder API (Recommended)
```rust
use …;

// Build: f(x, y) = sin(x) * (x + y)
let graph = GraphBuilder::new(…) // 2 inputs
    .add(…) // x + y at index 2
    .sin(…) // sin(x) at index 3
    .mul(…) // sin(x) * (x + y) at index 4
    .build(…);

let inputs = &[…];
let (…) = compute_grad(…).unwrap();
let gradients = backprop_fn(…);
println!(…);
println!(…); // [∂f/∂x, ∂f/∂y]
```
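The index comments follow a simple rule: inputs take the first slots, and every builder call appends exactly one node. A minimal sketch of that bookkeeping, assuming each call takes the indices of its operands; the `GraphBuilder` and `Node` types here are hypothetical and are not petite-ad's:

```rust
// Hypothetical builder, not petite-ad's GraphBuilder: each method appends
// one node, and a node's index is its position in `nodes`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Node {
    Input(usize),      // index into the input slice
    Add(usize, usize), // operands are indices of earlier nodes
    Mul(usize, usize),
    Sin(usize),
}

struct GraphBuilder {
    nodes: Vec<Node>,
}

impl GraphBuilder {
    fn new(n_inputs: usize) -> Self {
        // Inputs occupy indices 0..n_inputs.
        GraphBuilder { nodes: (0..n_inputs).map(Node::Input).collect() }
    }
    fn add(mut self, a: usize, b: usize) -> Self { self.nodes.push(Node::Add(a, b)); self }
    fn mul(mut self, a: usize, b: usize) -> Self { self.nodes.push(Node::Mul(a, b)); self }
    fn sin(mut self, a: usize) -> Self { self.nodes.push(Node::Sin(a)); self }
    fn build(self) -> Vec<Node> { self.nodes }
}

/// Forward pass: every operand index is smaller than the node's own index,
/// so a single left-to-right sweep evaluates the whole graph.
fn evaluate(graph: &[Node], inputs: &[f64]) -> Vec<f64> {
    let mut vals: Vec<f64> = Vec::with_capacity(graph.len());
    for node in graph {
        let v = match *node {
            Node::Input(i) => inputs[i],
            Node::Add(a, b) => vals[a] + vals[b],
            Node::Mul(a, b) => vals[a] * vals[b],
            Node::Sin(a) => vals[a].sin(),
        };
        vals.push(v);
    }
    vals
}

fn main() {
    // f(x, y) = sin(x) * (x + y)
    let graph = GraphBuilder::new(2) // x at 0, y at 1
        .add(0, 1) // x + y at index 2
        .sin(0) // sin(x) at index 3
        .mul(3, 2) // sin(x) * (x + y) at index 4
        .build();
    assert_eq!(graph[4], Node::Mul(3, 2));

    let vals = evaluate(&graph, &[1.0, 2.0]);
    assert!((vals[4] - 1.0f64.sin() * 3.0).abs() < 1e-12);
    println!("f(1, 2) = {}", vals[4]);
}
```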
Using Manual Graph Construction
```rust
use …::MultiAD;

// Build computational graph: f(x, y) = sin(x) * (x + y)
let exprs = &[…];
let inputs = &[…];
let (…) = compute_grad(…).unwrap();
let gradients = backprop_fn(…);
println!(…);
println!(…); // [∂f/∂x, ∂f/∂y]
```
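A manually built graph is a flat list in which each entry refers to earlier indices. Reverse-mode differentiation over such a list can be sketched as one forward sweep that stores every value, followed by one backward sweep that accumulates adjoints from the output down to the inputs. The `Expr` enum and `compute_grad` function below are hypothetical and cover only the four operations this example needs:

```rust
// Hypothetical flat expression list; operands refer to earlier indices.
// This mirrors the idea of a manually built graph, not petite-ad's types.
#[derive(Clone, Copy)]
enum Expr {
    Inp(usize),
    Add(usize, usize),
    Mul(usize, usize),
    Sin(usize),
}

/// Returns (f, [∂f/∂input_0, ∂f/∂input_1, ...]), where f is the last node.
fn compute_grad(exprs: &[Expr], inputs: &[f64]) -> (f64, Vec<f64>) {
    // Forward sweep: store every intermediate value.
    let mut vals: Vec<f64> = Vec::with_capacity(exprs.len());
    for e in exprs {
        let v = match *e {
            Expr::Inp(i) => inputs[i],
            Expr::Add(a, b) => vals[a] + vals[b],
            Expr::Mul(a, b) => vals[a] * vals[b],
            Expr::Sin(a) => vals[a].sin(),
        };
        vals.push(v);
    }

    // Reverse sweep: adj[k] holds ∂f/∂node_k, seeded with 1.0 at the output.
    let mut adj = vec![0.0f64; exprs.len()];
    let mut grads = vec![0.0f64; inputs.len()];
    *adj.last_mut().expect("empty graph") = 1.0;
    for k in (0..exprs.len()).rev() {
        let g = adj[k];
        match exprs[k] {
            Expr::Inp(i) => grads[i] += g,
            Expr::Add(a, b) => { adj[a] += g; adj[b] += g; }
            Expr::Mul(a, b) => { adj[a] += g * vals[b]; adj[b] += g * vals[a]; }
            Expr::Sin(a) => adj[a] += g * vals[a].cos(),
        }
    }
    (vals[vals.len() - 1], grads)
}

fn main() {
    // f(x, y) = sin(x) * (x + y)
    let exprs = [
        Expr::Inp(0),    // 0: x
        Expr::Inp(1),    // 1: y
        Expr::Add(0, 1), // 2: x + y
        Expr::Sin(0),    // 3: sin(x)
        Expr::Mul(3, 2), // 4: sin(x) * (x + y)
    ];
    let (x, y) = (1.0f64, 2.0f64);
    let (f, grads) = compute_grad(&exprs, &[x, y]);

    // Hand-derived partials: ∂f/∂x = cos(x)(x + y) + sin(x), ∂f/∂y = sin(x)
    assert!((f - x.sin() * (x + y)).abs() < 1e-12);
    assert!((grads[0] - (x.cos() * (x + y) + x.sin())).abs() < 1e-12);
    assert!((grads[1] - x.sin()).abs() < 1e-12);
    println!("f = {f}, grad = {grads:?}");
}
```

Because every operand index is smaller than the index of the node that uses it, by the time the backward sweep reaches node `k` all of its consumers have already added their contributions to `adj[k]`.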
Available Operations
MonoAD (Single-Variable)
| Operation | Description | Derivative |
|---|---|---|
| `Sin` | Sine | `cos(x)` |
| `Cos` | Cosine | `-sin(x)` |
| `Exp` | Exponential | `exp(x)` |
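One inexpensive way to sanity-check a derivative table is to compare it against a central finite difference, (f(x + h) - f(x - h)) / 2h, whose truncation error shrinks as h². A minimal sketch for the three MonoAD rows, with hypothetical names `Op`, `f`, `df`, and `central_diff`:

```rust
// Numerically checks the MonoAD derivative table with a central difference.
// `Op`, `f`, `df`, and `central_diff` are hypothetical names for this check only.
#[derive(Clone, Copy, Debug)]
enum Op {
    Sin,
    Cos,
    Exp,
}

fn f(op: Op, x: f64) -> f64 {
    match op {
        Op::Sin => x.sin(),
        Op::Cos => x.cos(),
        Op::Exp => x.exp(),
    }
}

// The "Derivative" column of the table.
fn df(op: Op, x: f64) -> f64 {
    match op {
        Op::Sin => x.cos(),
        Op::Cos => -x.sin(),
        Op::Exp => x.exp(),
    }
}

fn central_diff(op: Op, x: f64, h: f64) -> f64 {
    (f(op, x + h) - f(op, x - h)) / (2.0 * h)
}

fn main() {
    for &op in &[Op::Sin, Op::Cos, Op::Exp] {
        for &x in &[-1.3, 0.0, 0.7, 2.0] {
            let err = (df(op, x) - central_diff(op, x, 1e-5)).abs();
            assert!(err < 1e-8, "{op:?} at {x}: error {err}");
        }
    }
    println!("derivative table agrees with finite differences");
}
```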
MultiAD (Multi-Variable)
| Operation | Arity | Description |
|---|---|---|
| `Inp` | 1 | Input placeholder |
| `Add` | 2 | Addition: `a + b` |
| `Sub` | 2 | Subtraction: `a - b` |
| `Mul` | 2 | Multiplication: `a * b` |
| `Div` | 2 | Division: `a / b` |
| `Pow` | 2 | Power: `a^b` |
| `Sin` | 1 | Sine: `sin(x)` |
| `Cos` | 1 | Cosine: `cos(x)` |
| `Tan` | 1 | Tangent: `tan(x)` |
| `Exp` | 1 | Exponential: `exp(x)` |
| `Ln` | 1 | Natural log: `ln(x)` |
| `Sqrt` | 1 | Square root: `sqrt(x)` |
| `Abs` | 1 | Absolute value: `abs(x)` |
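For reverse mode, each operation in this table needs its local partial derivatives with respect to its operands. The sketch below writes them out and checks them against central differences at one point; the `Op` enum, `eval`, and `partials` are hypothetical mirrors of the table, not code from petite-ad's source. Two caveats appear in the comments: `Pow`'s partial with respect to the exponent involves `ln(a)`, so it is only real for `a > 0`, and `Abs` has no derivative at `0`.

```rust
// Hypothetical mirror of the MultiAD operations (Inp omitted: it is a leaf),
// used only to spell out local partial derivatives.
#[derive(Clone, Copy, Debug)]
enum Op { Add, Sub, Mul, Div, Pow, Sin, Cos, Tan, Exp, Ln, Sqrt, Abs }

// `b` is ignored by unary ops.
fn eval(op: Op, a: f64, b: f64) -> f64 {
    match op {
        Op::Add => a + b,
        Op::Sub => a - b,
        Op::Mul => a * b,
        Op::Div => a / b,
        Op::Pow => a.powf(b),
        Op::Sin => a.sin(),
        Op::Cos => a.cos(),
        Op::Tan => a.tan(),
        Op::Exp => a.exp(),
        Op::Ln => a.ln(),
        Op::Sqrt => a.sqrt(),
        Op::Abs => a.abs(),
    }
}

/// (∂/∂a, ∂/∂b) at (a, b); unary ops report 0.0 for ∂/∂b.
/// Pow's ∂/∂b uses ln(a), so it is only real for a > 0.
/// Abs uses signum, which is not a true derivative at a = 0.
fn partials(op: Op, a: f64, b: f64) -> (f64, f64) {
    match op {
        Op::Add => (1.0, 1.0),
        Op::Sub => (1.0, -1.0),
        Op::Mul => (b, a),
        Op::Div => (1.0 / b, -a / (b * b)),
        Op::Pow => (b * a.powf(b - 1.0), a.powf(b) * a.ln()),
        Op::Sin => (a.cos(), 0.0),
        Op::Cos => (-a.sin(), 0.0),
        Op::Tan => (1.0 / (a.cos() * a.cos()), 0.0),
        Op::Exp => (a.exp(), 0.0),
        Op::Ln => (1.0 / a, 0.0),
        Op::Sqrt => (0.5 / a.sqrt(), 0.0),
        Op::Abs => (a.signum(), 0.0),
    }
}

fn main() {
    use Op::*;
    let (a, b, h) = (1.3, 0.8, 1e-6);
    for op in [Add, Sub, Mul, Div, Pow, Sin, Cos, Tan, Exp, Ln, Sqrt, Abs] {
        let (da, db) = partials(op, a, b);
        let fd_a = (eval(op, a + h, b) - eval(op, a - h, b)) / (2.0 * h);
        let fd_b = (eval(op, a, b + h) - eval(op, a, b - h)) / (2.0 * h);
        assert!((da - fd_a).abs() < 1e-6, "{op:?} d/da");
        assert!((db - fd_b).abs() < 1e-6, "{op:?} d/db");
    }
    println!("all partials match central differences at (1.3, 0.8)");
}
```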
License
MIT
Contributing
Contributions are welcome! Areas for improvement:
- Higher-order derivatives (Hessian computation)
- Vector/matrix operations
- Optimization algorithms (SGD, Adam, etc.)
- Constant/literal values in computation graphs (e.g., `x^2` without needing a separate input)
- Additional mathematical operations