// axonml_jit/lib.rs

//! JIT Compilation for Axonml
//!
//! This crate provides Just-In-Time (JIT) compilation for tensor operations,
//! enabling significant performance improvements through:
//!
//! - Operation tracing and graph construction
//! - Graph optimization (fusion, constant folding, dead code elimination)
//! - Native code generation via Cranelift
//! - Compiled function caching
//!
//! # Example
//!
//! ```ignore
//! use axonml_jit::{JitCompiler, trace};
//!
//! // Trace operations to build a computation graph
//! let graph = trace(|tracer| {
//!     let a = tracer.input("a", &[2, 3]);
//!     let b = tracer.input("b", &[2, 3]);
//!     let c = a.add(&b);
//!     let d = c.mul_scalar(2.0);
//!     tracer.output("result", d)
//! });
//!
//! // Compile the graph
//! let compiler = JitCompiler::new();
//! let compiled = compiler.compile(&graph)?;
//!
//! // Execute with real tensors
//! let a = Tensor::randn(&[2, 3]);
//! let b = Tensor::randn(&[2, 3]);
//! let result = compiled.run(&[("a", &a), ("b", &b)])?;
//! ```
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

38#![warn(missing_docs)]
39#![allow(clippy::module_name_repetitions)]
40
41pub mod ir;
42pub mod trace;
43pub mod optimize;
44pub mod codegen;
45pub mod cache;
46pub mod error;
47
48pub use ir::{Graph, Node, NodeId, Op, DataType, Shape};
49pub use trace::{Tracer, TracedValue, trace};
50pub use optimize::{Optimizer, OptimizationPass};
51pub use codegen::{JitCompiler, CompiledFunction};
52pub use cache::FunctionCache;
53pub use error::{JitError, JitResult};
54
#[cfg(test)]
mod tests {
    // Exercise the crate through its root re-exports (`trace`, `Optimizer`, ...),
    // i.e. the same paths downstream users see.
    use super::*;

    /// Tracing two inputs through an element-wise `add` yields a graph with
    /// both inputs registered and a single named output.
    #[test]
    fn test_simple_trace() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        assert_eq!(graph.inputs().len(), 2);
        assert_eq!(graph.outputs().len(), 1);
    }

    /// Running the constant-folding pass over `input * constant` must leave
    /// the graph's interface intact: one runtime input and one output.
    #[test]
    fn test_optimization() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.constant(2.0, &[2, 3]);
            let c = a.mul(&b);
            tracer.output("result", c)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        // Only one operand of the multiply is constant, so the pass must not
        // remove the input or the output the caller relies on. Checking the
        // inputs alone would miss a pass that drops the graph's output.
        assert_eq!(optimized.inputs().len(), 1);
        assert_eq!(optimized.outputs().len(), 1);
    }
}