# bevy_autodiff
Automatic differentiation using Bevy ECS as the computational graph backend.
Variables are ECS entities, operations are components, and derivatives are computed by symbolic graph differentiation with chain-rule constant folding. An exploration of what ECS can do for automatic differentiation.
## Features
- ECS as computation graph -- entities are variables, components define operations and connectivity
- Symbolic graph differentiation -- `differentiate(output, wrt)` creates new entities representing the derivative graph via the chain rule
- Successive differentiation -- higher-order and mixed partials by repeated differentiation: `d²f/dxdy = differentiate(differentiate(f, x), y)`
- Constant folding -- zero/one terms are eliminated during differentiation to prevent graph bloat
- CompiledGraph -- flattens the ECS graph into a `Vec<NodeOp>` for fast repeated evaluation without ECS overhead
- Reverse-mode gradient -- a single backward pass over the CompiledGraph computes the full gradient regardless of input count
- Forward-mode symbolic partials -- pre-compiled derivative subgraphs for higher-order derivatives
- 21 elementary operations -- 16 unary + 5 binary, all with differentiation rules and reverse-mode adjoints
## Installation

```toml
[dependencies]
bevy_autodiff = "0.3"
```
## Quick Start

```rust
use bevy_autodiff::AutoDiff;

let mut ad = AutoDiff::new();

// Create input variable
let x = ad.var(2.0);

// Build computation graph: f(x) = x² + 3x + 1
let x_squared = ad.square(x);
let three = ad.constant(3.0);
let three_x = ad.mul(three, x);
let one = ad.constant(1.0);
let sum = ad.add(x_squared, three_x);
let f = ad.add(sum, one);

// Evaluate
assert_eq!(ad.eval(f), 11.0); // f(2) = 4 + 6 + 1

// Symbolic differentiation
let dfdx = ad.differentiate(f, x);
assert_eq!(ad.eval(dfdx), 7.0); // f'(2) = 2·2 + 3

// Higher-order via successive differentiation
let d2fdx2 = ad.differentiate(dfdx, x);
assert_eq!(ad.eval(d2fdx2), 2.0); // f''(x) = 2
let d3fdx3 = ad.differentiate(d2fdx2, x);
assert_eq!(ad.eval(d3fdx3), 0.0); // f'''(x) = 0
```
## Gradients

### Reverse-mode (recommended for many inputs)

`compile_primal()` compiles only the function value. `gradient()` computes all partial derivatives in a single backward pass -- O(1) passes in the number of inputs.
```rust
use bevy_autodiff::AutoDiff;

let mut ad = AutoDiff::new();
let x = ad.var(1.0);
let y = ad.var(2.0);

// f(x, y) = x² + x·y + y²
let x2 = ad.square(x);
let xy = ad.mul(x, y);
let y2 = ad.square(y);
let sum = ad.add(x2, xy);
let f = ad.add(sum, y2);

let mut cg = ad.compile_primal(f, &[x, y]);
cg.eval(&[1.0, 2.0]);
let val = cg.value();     // 7.0
let grad = cg.gradient(); // [4.0, 5.0]

// Re-evaluate at a new point without recompiling
cg.eval(&[3.0, -1.0]);
let grad = cg.gradient(); // [5.0, 1.0]
```
### Forward-mode (supports higher-order)

`compile_order()` pre-compiles symbolic derivative subgraphs. Useful when you need second-order or mixed partial derivatives.
```rust
use bevy_autodiff::AutoDiff;

let mut ad = AutoDiff::new();
let x = ad.var(1.0);
let y = ad.var(2.0);

// f(x, y) = x² + x·y
let x2 = ad.square(x);
let xy = ad.mul(x, y);
let f = ad.add(x2, xy);

// Compile with all partials up to order 2
let mut cg = ad.compile_order(f, &[x, y], 2);
cg.eval(&[1.0, 2.0]);
let dfdx = cg.partial(&[x]);      // df/dx = 2x + y = 4
let dfdy = cg.partial(&[y]);      // df/dy = x = 1
let d2fdx2 = cg.partial(&[x, x]); // d²f/dx² = 2
let d2mix = cg.partial(&[x, y]);  // d²f/dxdy = 1
```
## Supported Operations
| Category | Operations |
|---|---|
| Arithmetic | add, sub, mul, div, neg, square |
| Powers | sqrt, pow, powi, powf |
| Trigonometric | sin, cos, tan, asin, acos, atan |
| Hyperbolic | sinh, cosh, tanh, asinh, acosh, atanh |
| Exponential | exp, ln |
## Expression Macros

The `expr!` macro provides natural mathematical syntax:
```rust
use bevy_autodiff::{expr, AutoDiff};

let mut ad = AutoDiff::new();
let x = ad.var(2.0);
let y = ad.var(1.0);
let f = expr!(ad, x * x + 3.0 * x);
assert_eq!(ad.eval(f), 10.0); // 4 + 6
```
With the `proc-macros` feature, the `#[autodiff]` attribute transforms regular functions:

```toml
[dependencies]
bevy_autodiff = { version = "0.3", features = ["proc-macros"] }
```
```rust
use bevy_autodiff::{autodiff, AutoDiff};

#[autodiff]
fn rosenbrock(x: f64, y: f64) -> f64 {
    (1.0 - x).powi(2) + 100.0 * (y - x * x).powi(2)
}

let mut ad = AutoDiff::new();
let x = ad.var(1.0);
let y = ad.var(1.0);
let f = rosenbrock(&mut ad, x, y);
```
## How It Works

### Symbolic Graph Differentiation
`differentiate(output, wrt)` walks the computation graph in topological order and applies the chain rule at every node, creating new ECS entities for the derivative subgraph:

- Topological sort from `output` back to the inputs
- Base cases: `d(wrt)/d(wrt) = 1`, `d(other_input)/d(wrt) = 0`, `d(constant)/d(wrt) = 0`
- Chain rule at each operation node creates derivative entities
- Constant folding via `smart_add`, `smart_mul`, etc. collapses zero/one terms

For higher-order derivatives, `differentiate(differentiate(f, x), y)` gives `d²f/dxdy`.
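The folding step can be sketched with a dependency-free toy graph (illustrative `Node`/`Graph` types, not the crate's actual internals):

```rust
// Toy graph: nodes live in a flat Vec, referenced by index.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Node {
    Const(f64),
    Var(usize),
    Add(usize, usize),
    Mul(usize, usize),
}

struct Graph {
    nodes: Vec<Node>,
}

impl Graph {
    fn push(&mut self, n: Node) -> usize {
        self.nodes.push(n);
        self.nodes.len() - 1
    }

    fn is_const(&self, i: usize, v: f64) -> bool {
        matches!(self.nodes[i], Node::Const(c) if c == v)
    }

    // smart_add drops zero terms instead of allocating an Add node.
    fn smart_add(&mut self, a: usize, b: usize) -> usize {
        if self.is_const(a, 0.0) { return b; }
        if self.is_const(b, 0.0) { return a; }
        self.push(Node::Add(a, b))
    }

    // smart_mul collapses multiplication by zero or one.
    fn smart_mul(&mut self, a: usize, b: usize) -> usize {
        if self.is_const(a, 0.0) || self.is_const(b, 0.0) {
            return self.push(Node::Const(0.0));
        }
        if self.is_const(a, 1.0) { return b; }
        if self.is_const(b, 1.0) { return a; }
        self.push(Node::Mul(a, b))
    }

    // Chain rule: builds d(node)/d(wrt) as new graph nodes.
    fn diff(&mut self, node: usize, wrt: usize) -> usize {
        let n = self.nodes[node];
        match n {
            Node::Const(_) => self.push(Node::Const(0.0)),
            Node::Var(_) => {
                let v = if node == wrt { 1.0 } else { 0.0 };
                self.push(Node::Const(v))
            }
            Node::Add(a, b) => {
                let da = self.diff(a, wrt);
                let db = self.diff(b, wrt);
                self.smart_add(da, db)
            }
            Node::Mul(a, b) => {
                // Product rule: d(ab) = da·b + a·db, folded as it is built.
                let da = self.diff(a, wrt);
                let db = self.diff(b, wrt);
                let t1 = self.smart_mul(da, b);
                let t2 = self.smart_mul(a, db);
                self.smart_add(t1, t2)
            }
        }
    }
}
```

Because `smart_mul(1, y)` returns the existing `y` node and `smart_add(t, 0)` drops the zero term, `d(x·y)/dx` folds all the way down to `y` instead of allocating `1·y + x·0`.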
### CompiledGraph

`compile()` flattens the ECS graph (and any pre-built derivative subgraphs) into a `Vec<NodeOp>` -- a topologically sorted array of simple operations. A single forward pass evaluates all values.

For first-order gradients, `compile_primal()` + `gradient()` uses reverse-mode: one forward pass caches values, then one backward pass propagates adjoints to compute all partial derivatives simultaneously.
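A minimal sketch of the tape idea, assuming a cut-down `NodeOp` with only a few operations (the crate's real enum covers all 21):

```rust
// Hypothetical flattened tape: topologically sorted ops over a value array.
#[derive(Clone, Copy)]
enum NodeOp {
    Input(usize),      // read the k-th input
    Const(f64),
    Add(usize, usize), // indices of earlier tape slots
    Mul(usize, usize),
}

// Forward pass: one sweep fills every tape slot with its value.
fn eval(tape: &[NodeOp], inputs: &[f64]) -> Vec<f64> {
    let mut v = vec![0.0; tape.len()];
    for (i, op) in tape.iter().enumerate() {
        v[i] = match *op {
            NodeOp::Input(k) => inputs[k],
            NodeOp::Const(c) => c,
            NodeOp::Add(a, b) => v[a] + v[b],
            NodeOp::Mul(a, b) => v[a] * v[b],
        };
    }
    v
}

// Backward pass: adjoints flow from the output (assumed to be the last
// tape slot) to every input, yielding the whole gradient at once.
fn gradient(tape: &[NodeOp], v: &[f64], n_inputs: usize) -> Vec<f64> {
    let mut adj = vec![0.0; tape.len()];
    *adj.last_mut().unwrap() = 1.0; // d(output)/d(output) = 1
    let mut grad = vec![0.0; n_inputs];
    for i in (0..tape.len()).rev() {
        match tape[i] {
            NodeOp::Input(k) => grad[k] += adj[i],
            NodeOp::Const(_) => {}
            NodeOp::Add(a, b) => {
                adj[a] += adj[i];
                adj[b] += adj[i];
            }
            NodeOp::Mul(a, b) => {
                // Adjoint of a product uses the cached forward values.
                adj[a] += adj[i] * v[b];
                adj[b] += adj[i] * v[a];
            }
        }
    }
    grad
}
```

For example, the tape `[Input(0), Input(1), Mul(0, 1), Add(2, 0)]` encodes f(x, y) = x·y + x; one forward and one backward sweep give both ∂f/∂x = y + 1 and ∂f/∂y = x.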
### ECS Architecture

The Bevy ECS world stores the computation graph:

- Entities represent variables (inputs, constants, intermediate results)
- Components store values (`Value`), operations (`UnaryOp`, `BinaryOp`), connectivity (`UnaryInput`, `BinaryInputs`), and dependency bitmasks (`Dependencies`)
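To make the layout concrete, here is a dependency-free sketch that mimics the component split, with entities as plain indices (the real crate stores these as `bevy_ecs` components in a `World`; all names here are illustrative):

```rust
// Entities are ids; each column mimics one ECS component from the list above.
type Entity = usize;

#[derive(Clone, Copy)]
enum BinaryOp { Add, Mul }

#[derive(Default)]
struct World {
    value: Vec<Option<f64>>,                      // Value (inputs/constants)
    binary_op: Vec<Option<BinaryOp>>,             // BinaryOp
    binary_inputs: Vec<Option<(Entity, Entity)>>, // BinaryInputs (connectivity)
}

impl World {
    fn spawn(&mut self) -> Entity {
        self.value.push(None);
        self.binary_op.push(None);
        self.binary_inputs.push(None);
        self.value.len() - 1
    }

    // An input variable carries only a Value component.
    fn var(&mut self, v: f64) -> Entity {
        let e = self.spawn();
        self.value[e] = Some(v);
        e
    }

    // An operation node carries an op component plus connectivity.
    fn binary(&mut self, op: BinaryOp, a: Entity, b: Entity) -> Entity {
        let e = self.spawn();
        self.binary_op[e] = Some(op);
        self.binary_inputs[e] = Some((a, b));
        e
    }

    // Evaluate by recursing through the connectivity components.
    fn eval(&self, e: Entity) -> f64 {
        if let Some(v) = self.value[e] {
            return v;
        }
        let (a, b) = self.binary_inputs[e].unwrap();
        match self.binary_op[e].unwrap() {
            BinaryOp::Add => self.eval(a) + self.eval(b),
            BinaryOp::Mul => self.eval(a) * self.eval(b),
        }
    }
}
```

The design choice mirrors ECS idioms: whether an entity is an input, a constant, or an operation is determined by which components it carries, not by a node type hierarchy.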
## Examples

See `examples/README.md` for descriptions. Run with:

```sh
cargo run --example <name>
```
## Testing

The oracle tests require a nightly toolchain with the autodiff feature enabled:

```sh
RUSTFLAGS="-Zautodiff=Enable" cargo +nightly test
```
The test suite (297 tests) validates correctness through:
| Test type | What it validates | Count |
|---|---|---|
| Unit tests | Graph construction, all 21 operations, derivative properties, constant folding, CompiledGraph eval, reverse-mode adjoint formulas, reverse-mode backward pass | 261 |
| Oracle (autodiff crate) | First derivatives against independent forward-mode AD | 22 |
| Doc-tests | Code examples in documentation | 14 |
| Cross-validation | Reverse-mode gradient matches forward-mode symbolic partials | 8 (within unit) |
## Documentation
- Architecture -- ECS graph representation, compilation pipeline, differentiation approaches
- Numerical Precision -- precision tiers, tolerance justification, known considerations
- API Reference -- rustdoc on docs.rs
## Development
This project was co-developed with Claude, an AI assistant by Anthropic.
## License
MIT