// axonml_jit/lib.rs
//! JIT Compilation for Axonml
//!
//! This crate provides Just-In-Time compilation for tensor operations,
//! enabling significant performance improvements through:
//!
//! - Operation tracing and graph construction
//! - Graph optimization (fusion, constant folding, dead code elimination)
//! - Native code generation via Cranelift
//! - Compiled function caching
//!
//! # Example
//!
//! ```ignore
//! use axonml_jit::{JitCompiler, trace};
//!
//! // Trace operations to build a computation graph
//! let graph = trace(|tracer| {
//!     let a = tracer.input("a", &[2, 3]);
//!     let b = tracer.input("b", &[2, 3]);
//!     let c = a.add(&b);
//!     let d = c.mul_scalar(2.0);
//!     tracer.output("result", d)
//! });
//!
//! // Compile the graph
//! let compiler = JitCompiler::new();
//! let compiled = compiler.compile(&graph)?;
//!
//! // Execute with real tensors
//! let a = Tensor::randn(&[2, 3]);
//! let b = Tensor::randn(&[2, 3]);
//! let result = compiled.run(&[("a", &a), ("b", &b)])?;
//! ```
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

38#![warn(missing_docs)]
39#![allow(clippy::module_name_repetitions)]
40
41pub mod cache;
42pub mod codegen;
43pub mod compile;
44pub mod error;
45pub mod ir;
46pub mod optimize;
47pub mod trace;
48
49pub use cache::FunctionCache;
50pub use codegen::{CompiledFunction, JitCompiler};
51pub use compile::{
52    compile_fn, compile_fn_with_config, compile_graph, compile_graph_with_config, Backend,
53    CompileConfig, CompileStats, CompiledModel, LazyCompiled, Mode,
54};
55pub use error::{JitError, JitResult};
56pub use ir::{DataType, Graph, Node, NodeId, Op, Shape};
57pub use optimize::{OptimizationPass, Optimizer};
58pub use trace::{trace, TracedValue, Tracer};
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Tracing two inputs through a single `add` must record both inputs
    /// and the one declared output in the resulting graph.
    #[test]
    fn test_simple_trace() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        assert_eq!(graph.inputs().len(), 2);
        assert_eq!(graph.outputs().len(), 1);
    }

    /// Running the constant-folding pass must leave the graph valid:
    /// the single runtime input (the constant is not an input) survives.
    #[test]
    fn test_optimization() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.constant(2.0, &[2, 3]);
            let c = a.mul(&b);
            tracer.output("result", c)
        });

        let mut optimizer = Optimizer::new();
        optimizer.add_pass(OptimizationPass::ConstantFolding);
        let optimized = optimizer.optimize(graph);

        // Graph should still be valid
        assert_eq!(optimized.inputs().len(), 1);
    }
}