# axonml-jit

## Overview

`axonml-jit` is the AxonML tracing JIT. A `Tracer` records tensor operations into a typed `Graph` IR; an `Optimizer` runs a stack of passes (constant folding, DCE, CSE, algebraic simplification, elementwise fusion, strength reduction); a `JitCompiler` emits either an interpreter-backed `CompiledFunction` or a native Cranelift-compiled one; and a graph-hash `FunctionCache` (LRU) lets repeated compilations hit cache. Cranelift 0.111 is the code-gen backend.
## Features

- **Operation Tracing** — `trace(|tracer| { ... })` and `Tracer` APIs build a `Graph` from recorded operations using thread-local state
- **Typed IR** — `Graph`, `Node`, `NodeId`, `Op` (40+ variants), `Shape` (with broadcast checks and broadcast-shape computation), `DataType`
- **Optimizer** — `Optimizer::default_passes()` plus six `OptimizationPass` variants (`ConstantFolding`, `DeadCodeElimination`, `CommonSubexpressionElimination`, `AlgebraicSimplification`, `ElementwiseFusion`, `StrengthReduction`)
- **JIT Compiler** — `JitCompiler` with interpreter execution and optional Cranelift native codegen (`enable_native(true)`)
- **Higher-Level Facade** — `compile_fn`, `compile_fn_with_config`, `compile_graph`, `compile_graph_with_config`, `CompiledModel`, `LazyCompiled` (deferred compilation) with `CompileConfig` (`Mode::{Default, ReduceOverhead, MaxAutotune}`, `Backend::{Default, Eager, AOT, ONNX}`, `fullgraph`, `dynamic`, `disable`, custom passes)
- **Function Caching** — `FunctionCache` with LRU eviction and `Self::hash_graph`-based keying; `CacheStats` with `utilization`
- **Shape Inference** — automatic shape propagation including broadcast semantics (`Shape::broadcast_shape`)
- **Thread-Local Tracing** — safe concurrent tracing via per-thread tracer state
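To make the broadcast semantics concrete, here is a standalone sketch of NumPy-style broadcast-shape computation — the same rule `Shape::broadcast_shape` applies (this reimplementation is illustrative; the crate's actual signature may differ):

```rust
/// Compute the broadcast shape of two shapes, or `None` if they are
/// incompatible. Trailing dimensions are aligned; a missing or size-1
/// dimension stretches to match the other operand.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = Vec::with_capacity(rank);
    for i in 0..rank {
        // Dimensions missing on the left act as 1.
        let da = if i < rank - a.len() { 1 } else { a[i - (rank - a.len())] };
        let db = if i < rank - b.len() { 1 } else { b[i - (rank - b.len())] };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible dimensions
        }
    }
    Some(out)
}
```

For example, `[2, 1, 4]` broadcast against `[3, 4]` yields `[2, 3, 4]`, while `[2, 3]` against `[4, 3]` fails.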
## Modules

| Module | Description |
|---|---|
| `ir` | `Graph`, `Node`, `NodeId`, `Op`, `Shape`, `DataType`, topological order, validation |
| `trace` | `Tracer`, `TracedValue`, `trace` entry point, thread-local state |
| `optimize` | `Optimizer`, `OptimizationPass` (6 variants), `default_passes` |
| `codegen` | `JitCompiler`, `CompiledFunction` (interpreted and Cranelift-native kinds) |
| `compile` | `compile_fn`, `compile_graph`, `CompiledModel`, `LazyCompiled`, `CompileConfig`, `CompileStats`, `Mode`, `Backend` |
| `cache` | `FunctionCache` (LRU), `CacheStats` |
| `error` | `JitError`, `JitResult` |
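The `ir` module's topological ordering can be illustrated with a standalone Kahn's-algorithm sketch over a node graph where each node lists its input nodes (the crate's real `Graph` API differs; this shows the idea only):

```rust
use std::collections::VecDeque;

/// Topologically order nodes given, for each node, the indices of its
/// input nodes. Returns `None` if the graph contains a cycle.
fn topo_order(inputs: &[Vec<usize>]) -> Option<Vec<usize>> {
    let n = inputs.len();
    let mut indegree = vec![0usize; n];
    let mut users: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (node, ins) in inputs.iter().enumerate() {
        indegree[node] = ins.len();
        for &src in ins {
            users[src].push(node); // src feeds node
        }
    }
    // Start from nodes with no inputs (e.g. Input/Constant nodes).
    let mut ready: VecDeque<usize> =
        (0..n).filter(|&i| indegree[i] == 0).collect();
    let mut order = Vec::with_capacity(n);
    while let Some(node) = ready.pop_front() {
        order.push(node);
        for &u in &users[node] {
            indegree[u] -= 1;
            if indegree[u] == 0 {
                ready.push_back(u);
            }
        }
    }
    // A cycle leaves some nodes unscheduled.
    (order.len() == n).then_some(order)
}
```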
## Usage

Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-jit = "0.6.1"
```
### Basic Tracing and Compilation

The tracer closure below is a representative sketch — exact `Tracer` method names may differ in your version of the crate:

```rust
use axonml_jit::{trace, JitCompiler};

// Trace operations to build a computation graph
let graph = trace(|tracer| {
    let a = tracer.input("a", &[2, 2]);
    let b = tracer.input("b", &[2, 2]);
    let c = a.matmul(&b).relu();
    tracer.output(c)
});

// Compile the graph (interpreter-backed by default)
let compiler = JitCompiler::new();
let compiled = compiler.compile(&graph)?;

// Execute with real data — inputs are name/slice tuples
let a_data = vec![1.0_f32, 2.0, 3.0, 4.0];
let b_data = vec![5.0_f32, 6.0, 7.0, 8.0];
let result = compiled.run(&[("a", a_data.as_slice()), ("b", b_data.as_slice())])?;
```
### Cranelift Native Codegen

```rust
use axonml_jit::JitCompiler;

let mut compiler = JitCompiler::new();
compiler.enable_native(true); // opt in to Cranelift codegen
let compiled = compiler.compile(&graph)?;
```
### Traced Operations

A sketch recording several `Op` kinds inside one trace (traced-value method names are illustrative):

```rust
use axonml_jit::trace;

let graph = trace(|tracer| {
    let x = tracer.input("x", &[4]);
    let y = (x.exp() + x.abs()).sigmoid(); // unary ops and an activation
    tracer.output(y.sum())                 // reduction
});
```
### Custom Optimization

The particular passes chosen below are illustrative — any of the six `OptimizationPass` variants can be stacked:

```rust
use axonml_jit::{JitCompiler, OptimizationPass, Optimizer};

// Build a custom pass pipeline
let mut optimizer = Optimizer::new();
optimizer.add_pass(OptimizationPass::ConstantFolding);
optimizer.add_pass(OptimizationPass::DeadCodeElimination);
optimizer.add_pass(OptimizationPass::CommonSubexpressionElimination);
optimizer.add_pass(OptimizationPass::ElementwiseFusion);

// Apply optimizations directly
let optimized_graph = optimizer.optimize(graph);

// Or hand the optimizer to the compiler
let compiler = JitCompiler::with_optimizer(optimizer);
let compiled = compiler.compile(&optimized_graph)?;
```
### Higher-Level compile_fn / CompiledModel

The closure bodies and the stats accessor below are illustrative sketches:

```rust
use axonml_jit::{compile_fn, compile_fn_with_config, Backend, CompileConfig, Mode};
use std::collections::HashMap;

// Zero-config
let model = compile_fn(|tracer| { /* ... trace operations ... */ })?;

// With config
let cfg = CompileConfig::new()
    .mode(Mode::ReduceOverhead)
    .backend(Backend::Default)
    .fullgraph(true);
let model = compile_fn_with_config(|tracer| { /* ... */ }, cfg)?;

// CompiledModel runs on HashMap<String, Vec<f32>>
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), vec![1.0, 2.0, 3.0]);
let outputs = model.run(&inputs)?;

// Inspect compilation
println!("{:?}", model.stats());
```
### Lazy Compilation

```rust
use axonml_jit::LazyCompiled;

let lazy = LazyCompiled::new(|tracer| { /* ... trace operations ... */ });

// Compiled on first call, cached thereafter.
let outputs = lazy.run(&inputs)?;
```
### Cache Management

A sketch of the cache API — the exact fields printed from `CacheStats` may differ:

```rust
use axonml_jit::JitCompiler;

let compiler = JitCompiler::new();

// Compile multiple graphs
let _ = compiler.compile(&graph_a)?;
let _ = compiler.compile(&graph_b)?;

// Check cache statistics
let stats = compiler.cache_stats();
println!("{:?}", stats);
println!("utilization: {}", stats.utilization());

// Clear cache if needed
compiler.clear_cache();
```
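Conceptually, `FunctionCache` is a hash-keyed LRU map: graphs are hashed (via `hash_graph`) into a `u64` key, and when the cache is full the least-recently-used entry is evicted. A standalone sketch of that behavior (names here are illustrative, not the crate's API):

```rust
use std::collections::{HashMap, VecDeque};

struct LruCache<V> {
    capacity: usize,
    map: HashMap<u64, V>,
    // Front = least recently used, back = most recently used.
    order: VecDeque<u64>,
}

impl<V> LruCache<V> {
    fn new(capacity: usize) -> Self {
        Self { capacity, map: HashMap::new(), order: VecDeque::new() }
    }

    fn get(&mut self, key: u64) -> Option<&V> {
        if self.map.contains_key(&key) {
            // Mark as most recently used.
            self.order.retain(|&k| k != key);
            self.order.push_back(key);
        }
        self.map.get(&key)
    }

    fn insert(&mut self, key: u64, value: V) {
        if self.map.len() >= self.capacity && !self.map.contains_key(&key) {
            // Evict the least recently used entry.
            if let Some(old) = self.order.pop_front() {
                self.map.remove(&old);
            }
        }
        self.order.retain(|&k| k != key);
        self.order.push_back(key);
        self.map.insert(key, value);
    }

    // Mirrors the `utilization` figure `CacheStats` reports.
    fn utilization(&self) -> f64 {
        self.map.len() as f64 / self.capacity as f64
    }
}
```

With capacity 2, inserting keys 1 and 2, touching 1, then inserting 3 evicts key 2 — the least recently used entry.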
## Supported Operations

All recorded via `Op` enum variants:

- **Binary**: `Add`, `Sub`, `Mul`, `Div`, `Pow`, `Max`, `Min`
- **Unary**: `Neg`, `Abs`, `Sqrt`, `Exp`, `Log`, `Sin`, `Cos`, `Tanh`
- **Activations**: `Relu`, `Sigmoid`, `Gelu`, `Silu`
- **Scalar**: `AddScalar`, `MulScalar`
- **Reductions**: `Sum`, `SumAxis`, `Mean`, `MeanAxis`, `MaxAxis`
- **Shape**: `Reshape`, `Transpose`, `Squeeze`, `Unsqueeze`, `Broadcast`
- **Matrix**: `MatMul`
- **Comparison / Conditional**: `Gt`, `Lt`, `Eq`, `Where`
- **Special**: `Cast` (change `DataType`), `Contiguous`, `Input`, `Output`, `Constant`
## Optimization Passes

| Pass | Description |
|---|---|
| `ConstantFolding` | Evaluate constant expressions at compile time |
| `DeadCodeElimination` | Remove nodes that don't feed an output |
| `CommonSubexpressionElimination` | Reuse identical subexpressions |
| `AlgebraicSimplification` | `x * 1 = x`, `x + 0 = x`, etc. |
| `ElementwiseFusion` | Fuse consecutive elementwise ops |
| `StrengthReduction` | Replace expensive ops with cheaper equivalents |

The default `CompileConfig` enables `ConstantFolding`, `DeadCodeElimination`, and `CommonSubexpressionElimination`. `Mode::MaxAutotune` additionally appends `ElementwiseFusion` and `AlgebraicSimplification`.
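To show what constant folding and algebraic simplification do, here is a toy rewrite over a miniature expression IR (not the crate's actual `Op`/`Graph` types — a minimal sketch of the two rewrites only):

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Const(f32),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Bottom-up rewrite applying constant folding and two
/// algebraic identities (`x + 0 = x`, `x * 1 = x`).
fn simplify(e: Expr) -> Expr {
    use Expr::*;
    match e {
        Add(a, b) => match (simplify(*a), simplify(*b)) {
            // Constant folding: evaluate at compile time.
            (Const(x), Const(y)) => Const(x + y),
            // Algebraic simplification: x + 0 = x.
            (x, Const(c)) | (Const(c), x) if c == 0.0 => x,
            (x, y) => Add(Box::new(x), Box::new(y)),
        },
        Mul(a, b) => match (simplify(*a), simplify(*b)) {
            (Const(x), Const(y)) => Const(x * y),
            // Algebraic simplification: x * 1 = x.
            (x, Const(c)) | (Const(c), x) if c == 1.0 => x,
            (x, y) => Mul(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}
```

For instance, `(2 * 3) + 0` folds to the constant `6`, and `x * 1` collapses to `x` without touching the variable.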
## Tests

Run the test suite with `cargo test`.
## License

Licensed under either of:

- MIT License
- Apache License, Version 2.0

at your option.