axonml_jit/lib.rs
//! JIT Compilation for Axonml
//!
//! This crate provides Just-In-Time compilation for tensor operations,
//! enabling significant performance improvements through:
//!
//! - Operation tracing and graph construction
//! - Graph optimization (fusion, constant folding, dead code elimination)
//! - Native code generation via Cranelift
//! - Compiled function caching
//!
//! # Example
//!
//! ```ignore
//! use axonml_jit::{JitCompiler, trace};
//!
//! // Trace operations to build a computation graph
//! let graph = trace(|tracer| {
//!     let a = tracer.input("a", &[2, 3]);
//!     let b = tracer.input("b", &[2, 3]);
//!     let c = a.add(&b);
//!     let d = c.mul_scalar(2.0);
//!     tracer.output("result", d)
//! });
//!
//! // Compile the graph
//! let compiler = JitCompiler::new();
//! let compiled = compiler.compile(&graph)?;
//!
//! // Execute with real tensors
//! let a = Tensor::randn(&[2, 3]);
//! let b = Tensor::randn(&[2, 3]);
//! let result = compiled.run(&[("a", &a), ("b", &b)])?;
//! ```
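//!
//! # Optimization
//!
//! Graphs can be optimized before compilation. A minimal sketch using the
//! same `Optimizer` API exercised in this crate's tests below; only the
//! `ConstantFolding` pass is shown, and `graph` is assumed to come from a
//! prior `trace` call:
//!
//! ```ignore
//! use axonml_jit::{Optimizer, OptimizationPass};
//!
//! let mut optimizer = Optimizer::new();
//! optimizer.add_pass(OptimizationPass::ConstantFolding);
//! let optimized = optimizer.optimize(graph);
//! ```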
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

#![warn(missing_docs)]
#![allow(clippy::module_name_repetitions)]

pub mod ir;
pub mod trace;
pub mod optimize;
pub mod codegen;
pub mod cache;
pub mod error;

pub use ir::{Graph, Node, NodeId, Op, DataType, Shape};
pub use trace::{Tracer, TracedValue, trace};
pub use optimize::{Optimizer, OptimizationPass};
pub use codegen::{JitCompiler, CompiledFunction};
pub use cache::FunctionCache;
pub use error::{JitError, JitResult};

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_trace() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        assert_eq!(graph.inputs().len(), 2);
        assert_eq!(graph.outputs().len(), 1);
    }

    #[test]
    fn test_optimization() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.constant(2.0, &[2, 3]);
            let c = a.mul(&b);
            tracer.output("result", c)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        // Optimization must preserve the graph's single real input
        // (the constant operand is not counted as an input)
        assert_eq!(optimized.inputs().len(), 1);
    }
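
    // End-to-end sketch of the trace -> compile -> run pipeline from the
    // crate docs. `axonml_core::Tensor` is an assumed import path, and
    // `Tensor::randn` / `result.shape()` are assumed APIs; the test is
    // ignored by default so the suite does not depend on them.
    #[test]
    #[ignore]
    fn test_compile_and_run() {
        use axonml_core::Tensor; // assumed location of the Tensor type

        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            tracer.output("result", a.add(&b))
        });

        let compiler = JitCompiler::new();
        let compiled = compiler.compile(&graph).expect("compilation failed");

        let a = Tensor::randn(&[2, 3]);
        let b = Tensor::randn(&[2, 3]);
        let result = compiled
            .run(&[("a", &a), ("b", &b)])
            .expect("execution failed");
        assert_eq!(result.shape(), &[2, 3]); // assumed shape accessor
    }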
}