// axonml_jit/compile.rs
//! torch.compile Equivalent - High-Level Compilation API
//!
//! Provides a PyTorch 2.0 torch.compile-like API for automatic optimization
//! of models and functions through tracing and compilation.
//!
//! # Example
//! ```rust,ignore
//! use axonml_jit::compile::{compile_fn, CompileConfig, Mode};
//!
//! // Compile with default settings
//! let compiled = compile_fn(|t| {
//!     let x = t.input("x", &[2, 3]);
//!     let y = x.relu();
//!     t.output("y", y)
//! }).unwrap();
//!
//! // Or with custom configuration
//! let compiled = compile_fn_with_config(f, CompileConfig::new()
//!     .mode(Mode::MaxAutotune)
//!     .fullgraph(true)).unwrap();
//! ```
//!
//! @version 0.1.0

25use crate::codegen::{CompiledFunction, JitCompiler};
26use crate::ir::{Graph, Node, Op};
27use crate::optimize::{OptimizationPass, Optimizer};
28use crate::trace::{trace, TracedValue, Tracer};
29use crate::{JitError, JitResult};
30use std::collections::HashMap;
31use std::sync::Mutex;
32
// =============================================================================
// Compilation Mode
// =============================================================================

/// Compilation mode controlling optimization level.
///
/// Mirrors the `mode` argument of PyTorch's `torch.compile`.
// The hand-written `impl Default` was derivable (clippy::derivable_impls);
// `#[derive(Default)]` with `#[default]` (Rust 1.62+) is equivalent and
// keeps the default variant visible on the enum itself.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Mode {
    /// Default mode: balanced optimization
    #[default]
    Default,
    /// Reduce overhead: minimize compilation time
    ReduceOverhead,
    /// Maximum autotune: try multiple implementations
    MaxAutotune,
}

/// Backend for code generation.
// The hand-written `impl Default` was derivable (clippy::derivable_impls);
// `#[derive(Default)]` with `#[default]` (Rust 1.62+) is equivalent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Backend {
    /// Default backend (Cranelift)
    #[default]
    Default,
    /// Eager mode (no compilation)
    Eager,
    /// AOT (Ahead of Time) compilation
    AOT,
    /// ONNX export
    ONNX,
}

// =============================================================================
// Compile Configuration
// =============================================================================

77/// Configuration for model compilation.
78#[derive(Debug, Clone)]
79pub struct CompileConfig {
80    /// Compilation mode
81    pub mode: Mode,
82    /// Backend for code generation
83    pub backend: Backend,
84    /// Whether to require full graph capture
85    pub fullgraph: bool,
86    /// Whether to enable dynamic shapes
87    pub dynamic: bool,
88    /// Disable compilation (for debugging)
89    pub disable: bool,
90    /// Optimization passes to apply
91    pub passes: Vec<OptimizationPass>,
92}
93
94impl Default for CompileConfig {
95    fn default() -> Self {
96        Self {
97            mode: Mode::Default,
98            backend: Backend::Default,
99            fullgraph: false,
100            dynamic: false,
101            disable: false,
102            passes: vec![
103                OptimizationPass::ConstantFolding,
104                OptimizationPass::DeadCodeElimination,
105                OptimizationPass::CommonSubexpressionElimination,
106            ],
107        }
108    }
109}
110
111impl CompileConfig {
112    /// Creates a new compile configuration with defaults.
113    pub fn new() -> Self {
114        Self::default()
115    }
116
117    /// Builder: set compilation mode.
118    pub fn mode(mut self, mode: Mode) -> Self {
119        self.mode = mode;
120        if mode == Mode::MaxAutotune {
121            // Add more aggressive optimizations
122            self.passes.push(OptimizationPass::ElementwiseFusion);
123            self.passes.push(OptimizationPass::AlgebraicSimplification);
124        }
125        self
126    }
127
128    /// Builder: set backend.
129    pub fn backend(mut self, backend: Backend) -> Self {
130        self.backend = backend;
131        self
132    }
133
134    /// Builder: require full graph capture.
135    pub fn fullgraph(mut self, fullgraph: bool) -> Self {
136        self.fullgraph = fullgraph;
137        self
138    }
139
140    /// Builder: enable dynamic shapes.
141    pub fn dynamic(mut self, dynamic: bool) -> Self {
142        self.dynamic = dynamic;
143        self
144    }
145
146    /// Builder: disable compilation.
147    pub fn disable(mut self, disable: bool) -> Self {
148        self.disable = disable;
149        self
150    }
151
152    /// Builder: add optimization pass.
153    pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
154        self.passes.push(pass);
155        self
156    }
157}
158
// =============================================================================
// Compiled Model
// =============================================================================

163/// A compiled model or function.
164///
165/// Wraps traced computation with optimized execution.
166pub struct CompiledModel {
167    /// Original graph
168    graph: Graph,
169    /// Optimized graph
170    optimized_graph: Graph,
171    /// Compiled function (if available)
172    compiled_fn: Option<CompiledFunction>,
173    /// Configuration
174    config: CompileConfig,
175    /// Input names
176    input_names: Vec<String>,
177    /// Output names
178    output_names: Vec<String>,
179}
180
181impl CompiledModel {
182    /// Creates a new compiled model from a graph.
183    pub fn from_graph(graph: Graph, config: CompileConfig) -> JitResult<Self> {
184        // Apply optimizations
185        let mut optimizer = Optimizer::new();
186        for pass in &config.passes {
187            optimizer.add_pass(*pass);
188        }
189        let optimized_graph = optimizer.optimize(graph.clone());
190
191        // Compile if not disabled
192        let compiled_fn = if !config.disable && config.backend != Backend::Eager {
193            let compiler = JitCompiler::new();
194            compiler.compile(&optimized_graph).ok()
195        } else {
196            None
197        };
198
199        let input_names: Vec<String> = graph.inputs().keys().cloned().collect();
200        let output_names: Vec<String> = graph.outputs().keys().cloned().collect();
201
202        Ok(Self {
203            graph,
204            optimized_graph,
205            compiled_fn,
206            config,
207            input_names,
208            output_names,
209        })
210    }
211
212    /// Returns input names.
213    pub fn input_names(&self) -> &[String] {
214        &self.input_names
215    }
216
217    /// Returns output names.
218    pub fn output_names(&self) -> &[String] {
219        &self.output_names
220    }
221
222    /// Returns the original graph.
223    pub fn graph(&self) -> &Graph {
224        &self.graph
225    }
226
227    /// Returns the optimized graph.
228    pub fn optimized_graph(&self) -> &Graph {
229        &self.optimized_graph
230    }
231
232    /// Checks if compilation succeeded.
233    pub fn is_compiled(&self) -> bool {
234        self.compiled_fn.is_some()
235    }
236
237    /// Returns compilation statistics.
238    pub fn stats(&self) -> CompileStats {
239        CompileStats {
240            original_ops: self.graph.len(),
241            optimized_ops: self.optimized_graph.len(),
242            is_compiled: self.compiled_fn.is_some(),
243            passes_applied: self.config.passes.len(),
244        }
245    }
246
247    /// Runs the compiled model with named inputs.
248    pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
249        // Validate inputs
250        for name in &self.input_names {
251            if !inputs.contains_key(name) {
252                return Err(JitError::InputNotFound(name.clone()));
253            }
254        }
255
256        // Fall back to interpreted execution (compiled function API may differ)
257        self.interpret(inputs)
258    }
259
260    /// Interprets the graph (fallback when not compiled).
261    fn interpret(
262        &self,
263        inputs: &HashMap<String, Vec<f32>>,
264    ) -> JitResult<HashMap<String, Vec<f32>>> {
265        // Simple interpreter for the graph
266        let mut values: HashMap<String, Vec<f32>> = HashMap::new();
267
268        // Copy inputs
269        for (name, data) in inputs {
270            values.insert(name.clone(), data.clone());
271        }
272
273        for node in self.optimized_graph.nodes() {
274            let result = self.execute_node(node, &values)?;
275            // Use node id as key
276            let key = format!("node_{}", node.id.index());
277            values.insert(key, result);
278        }
279
280        // Collect outputs
281        let mut outputs = HashMap::new();
282        for name in &self.output_names {
283            // Find the output node and get its value
284            if let Some(node_id) = self.optimized_graph.output(name) {
285                let key = format!("node_{}", node_id.index());
286                if let Some(val) = values.get(&key) {
287                    outputs.insert(name.clone(), val.clone());
288                }
289            }
290        }
291
292        Ok(outputs)
293    }
294
295    /// Executes a single node.
296    fn execute_node(&self, node: &Node, values: &HashMap<String, Vec<f32>>) -> JitResult<Vec<f32>> {
297        match &node.op {
298            Op::Input { name } => values
299                .get(name)
300                .cloned()
301                .ok_or_else(|| JitError::InputNotFound(name.clone())),
302            Op::Output { input, .. } => {
303                let key = format!("node_{}", input.index());
304                values
305                    .get(&key)
306                    .cloned()
307                    .ok_or_else(|| JitError::InputNotFound(key))
308            }
309            Op::Constant { value } => Ok(vec![*value as f32]),
310            Op::Add { lhs, rhs } => {
311                let a = self.get_node_value(*lhs, values)?;
312                let b = self.get_node_value(*rhs, values)?;
313                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
314            }
315            Op::Sub { lhs, rhs } => {
316                let a = self.get_node_value(*lhs, values)?;
317                let b = self.get_node_value(*rhs, values)?;
318                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
319            }
320            Op::Mul { lhs, rhs } => {
321                let a = self.get_node_value(*lhs, values)?;
322                let b = self.get_node_value(*rhs, values)?;
323                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
324            }
325            Op::Div { lhs, rhs } => {
326                let a = self.get_node_value(*lhs, values)?;
327                let b = self.get_node_value(*rhs, values)?;
328                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
329            }
330            Op::Neg { input } => {
331                let a = self.get_node_value(*input, values)?;
332                Ok(a.iter().map(|x| -x).collect())
333            }
334            Op::Exp { input } => {
335                let a = self.get_node_value(*input, values)?;
336                Ok(a.iter().map(|x| x.exp()).collect())
337            }
338            Op::Log { input } => {
339                let a = self.get_node_value(*input, values)?;
340                Ok(a.iter().map(|x| x.ln()).collect())
341            }
342            Op::Sqrt { input } => {
343                let a = self.get_node_value(*input, values)?;
344                Ok(a.iter().map(|x| x.sqrt()).collect())
345            }
346            Op::Relu { input } => {
347                let a = self.get_node_value(*input, values)?;
348                Ok(a.iter().map(|x| x.max(0.0)).collect())
349            }
350            Op::Sigmoid { input } => {
351                let a = self.get_node_value(*input, values)?;
352                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
353            }
354            Op::Tanh { input } => {
355                let a = self.get_node_value(*input, values)?;
356                Ok(a.iter().map(|x| x.tanh()).collect())
357            }
358            _ => {
359                // For unsupported ops, return zeros with same shape
360                let numel = node.shape.numel();
361                Ok(vec![0.0; numel])
362            }
363        }
364    }
365
366    /// Gets value for a node by ID.
367    fn get_node_value(
368        &self,
369        node_id: crate::ir::NodeId,
370        values: &HashMap<String, Vec<f32>>,
371    ) -> JitResult<Vec<f32>> {
372        // Check if it's an input node
373        let node = self.optimized_graph.node(node_id);
374        if let Op::Input { name } = &node.op {
375            return values
376                .get(name)
377                .cloned()
378                .ok_or_else(|| JitError::InputNotFound(name.clone()));
379        }
380
381        // Otherwise use node key
382        let key = format!("node_{}", node_id.index());
383        values
384            .get(&key)
385            .cloned()
386            .ok_or_else(|| JitError::InputNotFound(key))
387    }
388}
389
/// Summary statistics describing a compilation run.
#[derive(Debug, Clone)]
pub struct CompileStats {
    /// Operation count of the graph before optimization
    pub original_ops: usize,
    /// Operation count after the optimization passes ran
    pub optimized_ops: usize,
    /// Whether native code generation succeeded
    pub is_compiled: bool,
    /// How many optimization passes were configured
    pub passes_applied: usize,
}

impl CompileStats {
    /// Ratio of optimized to original operation count.
    ///
    /// Values below 1.0 mean the optimizer removed operations; an empty
    /// original graph reports 1.0 to avoid dividing by zero.
    pub fn optimization_ratio(&self) -> f32 {
        match self.original_ops {
            0 => 1.0,
            n => self.optimized_ops as f32 / n as f32,
        }
    }
}

// =============================================================================
// Compile Functions
// =============================================================================

418/// Compiles a traced function with default settings.
419///
420/// # Example
421/// ```rust,ignore
422/// let graph = trace(|t| {
423///     let x = t.input("x", &[2, 3]);
424///     let y = x.relu();
425///     t.output("y", y)
426/// });
427///
428/// let compiled = compile_graph(graph)?;
429/// ```
430pub fn compile_graph(graph: Graph) -> JitResult<CompiledModel> {
431    CompiledModel::from_graph(graph, CompileConfig::default())
432}
433
434/// Compiles with custom configuration.
435pub fn compile_graph_with_config(graph: Graph, config: CompileConfig) -> JitResult<CompiledModel> {
436    CompiledModel::from_graph(graph, config)
437}
438
439/// Traces and compiles a function in one step.
440///
441/// # Example
442/// ```rust,ignore
443/// let compiled = compile_fn(|t| {
444///     let x = t.input("x", &[2, 3]);
445///     let y = x.add(&t.constant(1.0, &[2, 3]));
446///     t.output("y", y)
447/// })?;
448/// ```
449pub fn compile_fn<F>(f: F) -> JitResult<CompiledModel>
450where
451    F: FnOnce(&Tracer) -> TracedValue,
452{
453    let graph = trace(f);
454    compile_graph(graph)
455}
456
457/// Traces and compiles with custom configuration.
458pub fn compile_fn_with_config<F>(f: F, config: CompileConfig) -> JitResult<CompiledModel>
459where
460    F: FnOnce(&Tracer) -> TracedValue,
461{
462    let graph = trace(f);
463    compile_graph_with_config(graph, config)
464}
465
// =============================================================================
// Dynamo-style Decorators
// =============================================================================

470/// Wrapper for lazy compilation.
471///
472/// Traces and compiles on first call, caches for subsequent calls.
473pub struct LazyCompiled<F> {
474    func: F,
475    compiled: Mutex<Option<CompiledModel>>,
476    config: CompileConfig,
477}
478
479impl<F> LazyCompiled<F>
480where
481    F: Fn(&Tracer) -> TracedValue,
482{
483    /// Creates a new lazy compiled wrapper.
484    pub fn new(func: F) -> Self {
485        Self {
486            func,
487            compiled: Mutex::new(None),
488            config: CompileConfig::default(),
489        }
490    }
491
492    /// Creates with custom config.
493    pub fn with_config(func: F, config: CompileConfig) -> Self {
494        Self {
495            func,
496            compiled: Mutex::new(None),
497            config,
498        }
499    }
500
501    /// Runs the function, compiling on first call.
502    pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
503        let mut compiled = self.compiled.lock().unwrap();
504
505        if compiled.is_none() {
506            let graph = trace(&self.func);
507            *compiled = Some(CompiledModel::from_graph(graph, self.config.clone())?);
508        }
509
510        compiled.as_ref().unwrap().run(inputs)
511    }
512}
513
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Default config: balanced mode, no full-graph requirement, enabled.
    #[test]
    fn test_compile_config_default() {
        let config = CompileConfig::default();
        assert_eq!(config.mode, Mode::Default);
        assert!(!config.fullgraph);
        assert!(!config.disable);
    }

    // Builder methods chain and each one sets its corresponding field.
    #[test]
    fn test_compile_config_builder() {
        let config = CompileConfig::new()
            .mode(Mode::MaxAutotune)
            .fullgraph(true)
            .dynamic(true);

        assert_eq!(config.mode, Mode::MaxAutotune);
        assert!(config.fullgraph);
        assert!(config.dynamic);
    }

    // A traced relu graph compiles and preserves the input name.
    #[test]
    fn test_compile_simple_graph() {
        let graph = trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        });

        let compiled = compile_graph(graph).unwrap();
        assert!(compiled.input_names().contains(&"x".to_string()));
    }

    // Stats report non-zero op and pass counts for a non-empty graph.
    #[test]
    fn test_compile_stats() {
        let graph = trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        });

        let compiled = compile_graph(graph).unwrap();
        let stats = compiled.stats();

        assert!(stats.original_ops > 0);
        assert!(stats.passes_applied > 0);
    }

    // Mode derives Default (= Mode::Default) and Eq.
    #[test]
    fn test_mode_enum() {
        assert_eq!(Mode::default(), Mode::Default);
        assert_ne!(Mode::MaxAutotune, Mode::ReduceOverhead);
    }

    // Backend derives Default (= Backend::Default).
    #[test]
    fn test_backend_enum() {
        assert_eq!(Backend::default(), Backend::Default);
    }

    // End-to-end: interpreted execution applies ReLU to named inputs.
    #[test]
    fn test_compiled_model_run() {
        let graph = trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        });

        let compiled = compile_graph_with_config(
            graph,
            CompileConfig::new().disable(true), // Use interpreter
        )
        .unwrap();

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), vec![-1.0, 2.0]);

        let outputs = compiled.run(&inputs).unwrap();
        let y = outputs.get("y").unwrap();
        assert_eq!(y, &vec![0.0, 2.0]); // ReLU
    }
}