# axonml-jit

## Overview

axonml-jit provides Just-In-Time compilation for tensor operations, enabling significant performance improvements through operation tracing, graph optimization, and compiled function caching. It builds computation graphs from traced operations and optimizes them before execution.
## Features

- **Operation Tracing**: Record tensor operations to build computation graphs automatically
- **Graph Optimization**: Constant folding, dead code elimination, algebraic simplification, and CSE
- **Function Caching**: LRU cache for compiled functions with configurable size
- **Comprehensive IR**: Rich intermediate representation supporting 40+ tensor operations
- **Shape Inference**: Automatic shape propagation including broadcast semantics
- **Native Compilation**: Prepared for Cranelift code generation (interpreter fallback available)
- **Thread-Local Tracing**: Safe concurrent tracing with thread-local state
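The broadcast semantics mentioned above follow the usual trailing-dimension alignment rule (two dimensions are compatible if they are equal or one of them is 1). A minimal standalone sketch, using plain `Vec<usize>` rather than the crate's `Shape` type:

```rust
// Sketch of broadcast shape inference. Dimensions are aligned from the
// trailing end; a missing leading dimension behaves like 1.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        let da = if i < n - a.len() { 1 } else { a[i - (n - a.len())] };
        let db = if i < n - b.len() { 1 } else { b[i - (n - b.len())] };
        if da == db || da == 1 || db == 1 {
            out.push(da.max(db));
        } else {
            return None; // incompatible dimensions
        }
    }
    Some(out)
}

fn main() {
    // [2, 1, 4] broadcast with [3, 4] yields [2, 3, 4]
    assert_eq!(broadcast_shape(&[2, 1, 4], &[3, 4]), Some(vec![2, 3, 4]));
    // Trailing 3 vs 4 is incompatible
    assert_eq!(broadcast_shape(&[2, 3], &[4]), None);
    println!("ok");
}
```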
## Modules

| Module | Description |
|---|---|
| `ir` | Graph-based intermediate representation with `Node`, `Op`, `Shape`, and `DataType` definitions |
| `trace` | Operation tracing functionality with `TracedValue` and `Tracer` for graph construction |
| `optimize` | Optimization passes including constant folding, DCE, CSE, and algebraic simplification |
| `codegen` | JIT compiler and compiled function execution with interpreter fallback |
| `cache` | Function cache with LRU eviction and graph hashing |
| `error` | Error types and `Result` alias for JIT operations |
## Usage

Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-jit = "0.1.0"
```
### Basic Tracing and Compilation

The tracing closure and tensor constructors below are illustrative; see the `trace` module for exact signatures.

```rust
use axonml_jit::{trace, JitCompiler};

// Trace operations to build a computation graph
let graph = trace(|t| {
    let a = t.input(&[2, 2]);
    let b = t.input(&[2, 2]);
    a.matmul(&b)
});

// Compile the graph
let compiler = JitCompiler::new();
let compiled = compiler.compile(&graph)?;

// Execute with real data
let a_data = vec![1.0_f32, 2.0, 3.0, 4.0];
let b_data = vec![5.0_f32, 6.0, 7.0, 8.0];
let result = compiled.run(&[&a_data, &b_data])?;
```
### Traced Operations

Elementwise ops, activations, and reductions can be chained within a single trace (method names follow the supported-operations list; the closure API is illustrative):

```rust
use axonml_jit::trace;

// Build a graph combining an activation, a scalar op, and a reduction
let graph = trace(|t| {
    let x = t.input(&[4]);
    x.relu().add_scalar(1.0).sum()
});
```
### Custom Optimization

```rust
use axonml_jit::{
    JitCompiler, Optimizer, ConstantFolding, DeadCodeElimination,
    AlgebraicSimplification, CommonSubexpressionElimination,
};

// Create optimizer with custom passes
// (Box-ing each pass is illustrative of a trait-object pass list)
let mut optimizer = Optimizer::new();
optimizer.add_pass(Box::new(ConstantFolding));
optimizer.add_pass(Box::new(DeadCodeElimination));
optimizer.add_pass(Box::new(AlgebraicSimplification));
optimizer.add_pass(Box::new(CommonSubexpressionElimination));

// Apply optimizations
let optimized_graph = optimizer.optimize(graph);

// Compile optimized graph
let compiler = JitCompiler::with_optimizer(optimizer);
let compiled = compiler.compile(&optimized_graph)?;
```
### Cache Management

```rust
use axonml_jit::JitCompiler;

let compiler = JitCompiler::new();

// Compile multiple graphs; recompiling an identical graph hits the cache
let _ = compiler.compile(&graph_a)?;
let _ = compiler.compile(&graph_b)?;

// Check cache statistics (field names are illustrative)
let stats = compiler.cache_stats();
println!("hits: {}", stats.hits);
println!("misses: {}", stats.misses);

// Clear cache if needed
compiler.clear_cache();
```
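The cache keys compiled functions by a hash of the graph and evicts the least recently used entry when full. A toy standalone sketch of that eviction policy (this `LruCache` is illustrative, not the crate's type):

```rust
use std::collections::{HashMap, VecDeque};

// Minimal LRU cache keyed by a graph hash (u64).
struct LruCache<V> {
    capacity: usize,
    map: HashMap<u64, V>,
    order: VecDeque<u64>, // front = least recently used
}

impl<V> LruCache<V> {
    fn new(capacity: usize) -> Self {
        Self { capacity, map: HashMap::new(), order: VecDeque::new() }
    }

    fn get(&mut self, key: u64) -> Option<&V> {
        if self.map.contains_key(&key) {
            // Move the key to the back (most recently used)
            self.order.retain(|&k| k != key);
            self.order.push_back(key);
            self.map.get(&key)
        } else {
            None
        }
    }

    fn put(&mut self, key: u64, value: V) {
        if self.map.contains_key(&key) {
            self.order.retain(|&k| k != key);
        } else if self.map.len() == self.capacity {
            // Evict the least recently used entry
            if let Some(old) = self.order.pop_front() {
                self.map.remove(&old);
            }
        }
        self.order.push_back(key);
        self.map.insert(key, value);
    }
}

fn main() {
    let mut cache: LruCache<&str> = LruCache::new(2);
    cache.put(1, "graph-a");
    cache.put(2, "graph-b");
    cache.get(1);            // touch graph-a, so graph-b becomes LRU
    cache.put(3, "graph-c"); // evicts graph-b
    assert!(cache.get(2).is_none());
    assert!(cache.get(1).is_some());
    println!("ok");
}
```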
## Supported Operations

### Binary Operations
`add`, `sub`, `mul`, `div`, `pow`, `max`, `min`

### Unary Operations
`neg`, `abs`, `sqrt`, `exp`, `log`, `sin`, `cos`, `tanh`

### Activations
`relu`, `sigmoid`, `gelu`, `silu`

### Scalar Operations
`add_scalar`, `mul_scalar`

### Reductions
`sum`, `mean`, `sum_axis`, `mean_axis`

### Shape Operations
`reshape`, `transpose`, `squeeze`, `unsqueeze`

### Matrix Operations
`matmul`

### Comparison Operations
`gt`, `lt`, `eq`, `where`
## Optimization Passes

| Pass | Description |
|---|---|
| `ConstantFolding` | Evaluate constant expressions at compile time |
| `DeadCodeElimination` | Remove nodes that do not contribute to outputs |
| `AlgebraicSimplification` | Simplify expressions (`x * 1 = x`, `x + 0 = x`, etc.) |
| `CommonSubexpressionElimination` | Reuse identical subexpressions |
| `ElementwiseFusion` | Fuse consecutive elementwise operations |
| `StrengthReduction` | Replace expensive ops with cheaper equivalents |
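Passes like `AlgebraicSimplification` are rewrite rules applied bottom-up over the graph. A standalone sketch on a toy expression tree (the crate's real pass operates on IR `Node`s, but the rules are the same):

```rust
// Toy expression tree; illustrative, not the crate's IR.
#[derive(Debug, PartialEq)]
enum Expr {
    Var(String),
    Const(f64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

// Simplify children first, then apply identity/annihilator rules.
fn simplify(e: Expr) -> Expr {
    match e {
        Expr::Add(a, b) => match (simplify(*a), simplify(*b)) {
            (x, Expr::Const(c)) if c == 0.0 => x, // x + 0 = x
            (Expr::Const(c), x) if c == 0.0 => x, // 0 + x = x
            (x, y) => Expr::Add(Box::new(x), Box::new(y)),
        },
        Expr::Mul(a, b) => match (simplify(*a), simplify(*b)) {
            (x, Expr::Const(c)) if c == 1.0 => x, // x * 1 = x
            (Expr::Const(c), x) if c == 1.0 => x, // 1 * x = x
            (_, Expr::Const(c)) | (Expr::Const(c), _) if c == 0.0 => {
                Expr::Const(0.0) // x * 0 = 0
            }
            (x, y) => Expr::Mul(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}

fn main() {
    // (x * 1) + 0 simplifies to x
    let e = Expr::Add(
        Box::new(Expr::Mul(
            Box::new(Expr::Var("x".into())),
            Box::new(Expr::Const(1.0)),
        )),
        Box::new(Expr::Const(0.0)),
    );
    assert_eq!(simplify(e), Expr::Var("x".into()));
    println!("ok");
}
```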
## Tests

Run the test suite:

```bash
cargo test
```
## License

Licensed under either of:

- MIT License
- Apache License, Version 2.0

at your option.