
axonml_jit/compile.rs

//! torch.compile Equivalent - High-Level Compilation API
//!
//! Provides a PyTorch 2.0 torch.compile-like API for automatic optimization
//! of models and functions through tracing and compilation.
//!
//! # Example
//! ```rust,ignore
//! use axonml_jit::compile::{compile_fn, compile_fn_with_config, CompileConfig, Mode};
//!
//! // Compile with default settings
//! let compiled = compile_fn(|t| {
//!     let x = t.input("x", &[2, 3]);
//!     let y = x.relu();
//!     t.output("y", y)
//! }).unwrap();
//!
//! // Or with custom configuration (f: any traceable closure like the one above)
//! let compiled = compile_fn_with_config(f, CompileConfig::new()
//!     .mode(Mode::MaxAutotune)
//!     .fullgraph(true)).unwrap();
//! ```
//!
//! @version 0.1.0

use crate::codegen::{CompiledFunction, JitCompiler};
use crate::ir::{Graph, Node, Op};
use crate::optimize::{OptimizationPass, Optimizer};
use crate::trace::{trace, TracedValue, Tracer};
use crate::{JitError, JitResult};
use std::collections::HashMap;
use std::sync::Mutex;

// =============================================================================
// Compilation Mode
// =============================================================================

/// Compilation mode controlling optimization level.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Default mode: balanced optimization
    Default,
    /// Reduce overhead: minimize compilation time
    ReduceOverhead,
    /// Maximum autotune: try multiple implementations
    MaxAutotune,
}

impl Default for Mode {
    fn default() -> Self {
        Self::Default
    }
}

/// Backend for code generation.
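///
/// Note: in `CompiledModel::from_graph`, only `Eager` (or `disable = true`)
/// skips JIT compilation; `AOT` and `ONNX` are accepted but currently take
/// the same default compilation path.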
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Default backend (Cranelift)
    Default,
    /// Eager mode (no compilation)
    Eager,
    /// AOT (Ahead of Time) compilation
    AOT,
    /// ONNX export
    ONNX,
}

impl Default for Backend {
    fn default() -> Self {
        Self::Default
    }
}

// =============================================================================
// Compile Configuration
// =============================================================================

/// Configuration for model compilation.
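///
/// A minimal sketch chaining the builder setters defined below:
///
/// ```rust,ignore
/// let config = CompileConfig::new()
///     .mode(Mode::ReduceOverhead)
///     .fullgraph(true)
///     .add_pass(OptimizationPass::ElementwiseFusion);
/// ```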
#[derive(Debug, Clone)]
pub struct CompileConfig {
    /// Compilation mode
    pub mode: Mode,
    /// Backend for code generation
    pub backend: Backend,
    /// Whether to require full graph capture
    pub fullgraph: bool,
    /// Whether to enable dynamic shapes
    pub dynamic: bool,
    /// Disable compilation (for debugging)
    pub disable: bool,
    /// Optimization passes to apply
    pub passes: Vec<OptimizationPass>,
}

impl Default for CompileConfig {
    fn default() -> Self {
        Self {
            mode: Mode::Default,
            backend: Backend::Default,
            fullgraph: false,
            dynamic: false,
            disable: false,
            passes: vec![
                OptimizationPass::ConstantFolding,
                OptimizationPass::DeadCodeElimination,
                OptimizationPass::CommonSubexpressionElimination,
            ],
        }
    }
}

impl CompileConfig {
    /// Creates a new compile configuration with defaults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: set compilation mode.
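    ///
    /// Note: selecting `Mode::MaxAutotune` also appends `ElementwiseFusion`
    /// and `AlgebraicSimplification` to the pass list.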
    pub fn mode(mut self, mode: Mode) -> Self {
        self.mode = mode;
        if mode == Mode::MaxAutotune {
            // Add more aggressive optimizations
            self.passes.push(OptimizationPass::ElementwiseFusion);
            self.passes.push(OptimizationPass::AlgebraicSimplification);
        }
        self
    }

    /// Builder: set backend.
    pub fn backend(mut self, backend: Backend) -> Self {
        self.backend = backend;
        self
    }

    /// Builder: require full graph capture.
    pub fn fullgraph(mut self, fullgraph: bool) -> Self {
        self.fullgraph = fullgraph;
        self
    }

    /// Builder: enable dynamic shapes.
    pub fn dynamic(mut self, dynamic: bool) -> Self {
        self.dynamic = dynamic;
        self
    }

    /// Builder: disable compilation.
    pub fn disable(mut self, disable: bool) -> Self {
        self.disable = disable;
        self
    }

    /// Builder: add optimization pass.
    pub fn add_pass(mut self, pass: OptimizationPass) -> Self {
        self.passes.push(pass);
        self
    }
}

// =============================================================================
// Compiled Model
// =============================================================================

/// A compiled model or function.
///
/// Wraps traced computation with optimized execution.
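///
/// Note: `run` currently executes through the graph interpreter even when a
/// JIT-compiled function is available; `is_compiled` only reports whether
/// code generation succeeded.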
pub struct CompiledModel {
    /// Original graph
    graph: Graph,
    /// Optimized graph
    optimized_graph: Graph,
    /// Compiled function (if available)
    compiled_fn: Option<CompiledFunction>,
    /// Configuration
    config: CompileConfig,
    /// Input names
    input_names: Vec<String>,
    /// Output names
    output_names: Vec<String>,
}

impl CompiledModel {
    /// Creates a new compiled model from a graph.
    pub fn from_graph(graph: Graph, config: CompileConfig) -> JitResult<Self> {
        // Apply optimizations
        let mut optimizer = Optimizer::new();
        for pass in &config.passes {
            optimizer.add_pass(*pass);
        }
        let optimized_graph = optimizer.optimize(graph.clone());

        // Compile if not disabled
        let compiled_fn = if !config.disable && config.backend != Backend::Eager {
            let compiler = JitCompiler::new();
            compiler.compile(&optimized_graph).ok()
        } else {
            None
        };

        let input_names: Vec<String> = graph.inputs().keys().cloned().collect();
        let output_names: Vec<String> = graph.outputs().keys().cloned().collect();

        Ok(Self {
            graph,
            optimized_graph,
            compiled_fn,
            config,
            input_names,
            output_names,
        })
    }

    /// Returns input names.
    pub fn input_names(&self) -> &[String] {
        &self.input_names
    }

    /// Returns output names.
    pub fn output_names(&self) -> &[String] {
        &self.output_names
    }

    /// Returns the original graph.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Returns the optimized graph.
    pub fn optimized_graph(&self) -> &Graph {
        &self.optimized_graph
    }

    /// Checks if compilation succeeded.
    pub fn is_compiled(&self) -> bool {
        self.compiled_fn.is_some()
    }

    /// Returns compilation statistics.
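    ///
    /// A minimal sketch of reading the stats:
    ///
    /// ```rust,ignore
    /// let stats = compiled.stats();
    /// println!("{} -> {} ops after {} passes",
    ///     stats.original_ops, stats.optimized_ops, stats.passes_applied);
    /// ```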
    pub fn stats(&self) -> CompileStats {
        CompileStats {
            original_ops: self.graph.len(),
            optimized_ops: self.optimized_graph.len(),
            is_compiled: self.compiled_fn.is_some(),
            passes_applied: self.config.passes.len(),
        }
    }

    /// Runs the compiled model with named inputs.
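    ///
    /// # Example
    /// ```rust,ignore
    /// let mut inputs = HashMap::new();
    /// inputs.insert("x".to_string(), vec![-1.0, 2.0]);
    /// let outputs = compiled.run(&inputs)?;
    /// assert_eq!(outputs["y"], vec![0.0, 2.0]); // assuming a ReLU graph on "x"
    /// ```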
    pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
        // Validate inputs
        for name in &self.input_names {
            if !inputs.contains_key(name) {
                return Err(JitError::InputNotFound(name.clone()));
            }
        }

        // Execute via the interpreter; the JIT-compiled entry point is not
        // invoked here yet.
        self.interpret(inputs)
    }

    /// Interprets the graph (fallback when not compiled).
    fn interpret(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
        // Simple interpreter for the graph
        let mut values: HashMap<String, Vec<f32>> = HashMap::new();

        // Copy inputs
        for (name, data) in inputs {
            values.insert(name.clone(), data.clone());
        }

        for node in self.optimized_graph.nodes() {
            let result = self.execute_node(node, &values)?;
            // Use node id as key
            let key = format!("node_{}", node.id.index());
            values.insert(key, result);
        }

        // Collect outputs
        let mut outputs = HashMap::new();
        for name in &self.output_names {
            // Find the output node and get its value
            if let Some(node_id) = self.optimized_graph.output(name) {
                let key = format!("node_{}", node_id.index());
                if let Some(val) = values.get(&key) {
                    outputs.insert(name.clone(), val.clone());
                }
            }
        }

        Ok(outputs)
    }

    /// Executes a single node.
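    ///
    /// Binary ops are evaluated elementwise; operands are assumed to have
    /// equal lengths (`zip` silently truncates to the shorter one otherwise).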
    fn execute_node(
        &self,
        node: &Node,
        values: &HashMap<String, Vec<f32>>,
    ) -> JitResult<Vec<f32>> {
        match &node.op {
            Op::Input { name } => {
                values
                    .get(name)
                    .cloned()
                    .ok_or_else(|| JitError::InputNotFound(name.clone()))
            }
            Op::Output { input, .. } => {
                let key = format!("node_{}", input.index());
                values
                    .get(&key)
                    .cloned()
                    .ok_or_else(|| JitError::InputNotFound(key))
            }
            // Broadcast the scalar constant to the node's declared shape.
            Op::Constant { value } => Ok(vec![*value as f32; node.shape.numel()]),
            Op::Add { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }
            Op::Sub { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }
            Op::Mul { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }
            Op::Div { lhs, rhs } => {
                let a = self.get_node_value(*lhs, values)?;
                let b = self.get_node_value(*rhs, values)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }
            Op::Neg { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| -x).collect())
            }
            Op::Exp { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }
            Op::Log { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }
            Op::Sqrt { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }
            Op::Relu { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }
            Op::Sigmoid { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }
            Op::Tanh { input } => {
                let a = self.get_node_value(*input, values)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }
            _ => {
                // For unsupported ops, return zeros with same shape
                let numel = node.shape.numel();
                Ok(vec![0.0; numel])
            }
        }
    }

    /// Gets value for a node by ID.
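    ///
    /// Input nodes resolve by their user-facing name; all other nodes by
    /// their interpreter key `node_{id}`.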
    fn get_node_value(
        &self,
        node_id: crate::ir::NodeId,
        values: &HashMap<String, Vec<f32>>,
    ) -> JitResult<Vec<f32>> {
        // Check if it's an input node
        let node = self.optimized_graph.node(node_id);
        if let Op::Input { name } = &node.op {
            return values
                .get(name)
                .cloned()
                .ok_or_else(|| JitError::InputNotFound(name.clone()));
        }

        // Otherwise use node key
        let key = format!("node_{}", node_id.index());
        values
            .get(&key)
            .cloned()
            .ok_or_else(|| JitError::InputNotFound(key))
    }
}

/// Compilation statistics.
#[derive(Debug, Clone)]
pub struct CompileStats {
    /// Number of operations in original graph
    pub original_ops: usize,
    /// Number of operations after optimization
    pub optimized_ops: usize,
    /// Whether compilation succeeded
    pub is_compiled: bool,
    /// Number of optimization passes applied
    pub passes_applied: usize,
}

impl CompileStats {
    /// Returns the optimization ratio.
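    ///
    /// Computed as `optimized_ops / original_ops`, so values below 1.0 mean
    /// the optimizer removed operations; an empty original graph reports 1.0.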
    pub fn optimization_ratio(&self) -> f32 {
        if self.original_ops == 0 {
            1.0
        } else {
            self.optimized_ops as f32 / self.original_ops as f32
        }
    }
}

// =============================================================================
// Compile Functions
// =============================================================================

/// Compiles a traced function with default settings.
///
/// # Example
/// ```rust,ignore
/// let graph = trace(|t| {
///     let x = t.input("x", &[2, 3]);
///     let y = x.relu();
///     t.output("y", y)
/// });
///
/// let compiled = compile_graph(graph)?;
/// ```
pub fn compile_graph(graph: Graph) -> JitResult<CompiledModel> {
    CompiledModel::from_graph(graph, CompileConfig::default())
}

/// Compiles with custom configuration.
pub fn compile_graph_with_config(graph: Graph, config: CompileConfig) -> JitResult<CompiledModel> {
    CompiledModel::from_graph(graph, config)
}

/// Traces and compiles a function in one step.
///
/// # Example
/// ```rust,ignore
/// let compiled = compile_fn(|t| {
///     let x = t.input("x", &[2, 3]);
///     let y = x.add(&t.constant(1.0, &[2, 3]));
///     t.output("y", y)
/// })?;
/// ```
pub fn compile_fn<F>(f: F) -> JitResult<CompiledModel>
where
    F: FnOnce(&Tracer) -> TracedValue,
{
    let graph = trace(f);
    compile_graph(graph)
}

/// Traces and compiles with custom configuration.
pub fn compile_fn_with_config<F>(f: F, config: CompileConfig) -> JitResult<CompiledModel>
where
    F: FnOnce(&Tracer) -> TracedValue,
{
    let graph = trace(f);
    compile_graph_with_config(graph, config)
}

// =============================================================================
// Dynamo-style Decorators
// =============================================================================

/// Wrapper for lazy compilation.
///
/// Traces and compiles on first call, caches for subsequent calls.
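///
/// A minimal sketch of lazy use (same input map as `CompiledModel::run`):
///
/// ```rust,ignore
/// let lazy = LazyCompiled::new(|t| {
///     let x = t.input("x", &[2]);
///     t.output("y", x.relu())
/// });
/// // First call traces and compiles; later calls reuse the cached model.
/// let outputs = lazy.run(&inputs)?;
/// ```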
pub struct LazyCompiled<F> {
    func: F,
    compiled: Mutex<Option<CompiledModel>>,
    config: CompileConfig,
}

impl<F> LazyCompiled<F>
where
    F: Fn(&Tracer) -> TracedValue,
{
    /// Creates a new lazy compiled wrapper.
    pub fn new(func: F) -> Self {
        Self {
            func,
            compiled: Mutex::new(None),
            config: CompileConfig::default(),
        }
    }

    /// Creates with custom config.
    pub fn with_config(func: F, config: CompileConfig) -> Self {
        Self {
            func,
            compiled: Mutex::new(None),
            config,
        }
    }

    /// Runs the function, compiling on first call.
    pub fn run(&self, inputs: &HashMap<String, Vec<f32>>) -> JitResult<HashMap<String, Vec<f32>>> {
        let mut compiled = self.compiled.lock().unwrap();

        if compiled.is_none() {
            let graph = trace(&self.func);
            *compiled = Some(CompiledModel::from_graph(graph, self.config.clone())?);
        }

        compiled.as_ref().unwrap().run(inputs)
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_config_default() {
        let config = CompileConfig::default();
        assert_eq!(config.mode, Mode::Default);
        assert!(!config.fullgraph);
        assert!(!config.disable);
    }

    #[test]
    fn test_compile_config_builder() {
        let config = CompileConfig::new()
            .mode(Mode::MaxAutotune)
            .fullgraph(true)
            .dynamic(true);

        assert_eq!(config.mode, Mode::MaxAutotune);
        assert!(config.fullgraph);
        assert!(config.dynamic);
    }

    #[test]
    fn test_compile_simple_graph() {
        let graph = trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        });

        let compiled = compile_graph(graph).unwrap();
        assert!(compiled.input_names().contains(&"x".to_string()));
    }

    #[test]
    fn test_compile_stats() {
        let graph = trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        });

        let compiled = compile_graph(graph).unwrap();
        let stats = compiled.stats();

        assert!(stats.original_ops > 0);
        assert!(stats.passes_applied > 0);
    }

    #[test]
    fn test_mode_enum() {
        assert_eq!(Mode::default(), Mode::Default);
        assert_ne!(Mode::MaxAutotune, Mode::ReduceOverhead);
    }

    #[test]
    fn test_backend_enum() {
        assert_eq!(Backend::default(), Backend::Default);
    }

    #[test]
    fn test_compiled_model_run() {
        let graph = trace(|t| {
            let x = t.input("x", &[2]);
            let y = x.relu();
            t.output("y", y)
        });

        let compiled = compile_graph_with_config(
            graph,
            CompileConfig::new().disable(true), // Use interpreter
        )
        .unwrap();

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), vec![-1.0, 2.0]);

        let outputs = compiled.run(&inputs).unwrap();
        let y = outputs.get("y").unwrap();
        assert_eq!(y, &vec![0.0, 2.0]); // ReLU zeros the negative input
    }
}