//! Code Generation — Interpreter and Cranelift Native Backends
//!
//! Provides `CompiledFunction`, the executable output of the JIT, holding an
//! `Arc<Graph>` and a `CompiledKind` enum that is either `Interpreted`
//! (tree-walking evaluator) or `Native` (a raw `*const u8` pointer plus size
//! returned by Cranelift). `CompiledFunction::run` dispatches on kind; native
//! execution transmutes the pointer to `extern "C" fn(*const f32, *mut f32)`,
//! flattens inputs, and sizes the output buffer from the graph's output node
//! shapes. The interpreter path (`run_interpreted` / `eval_node`) walks the
//! graph in topological order, implementing Add/Sub/Mul/Div/Pow/Max/Min scalar
//! and tensor ops, unary math (Neg/Abs/Sqrt/Exp/Log/Sin/Cos/Tanh), activations
//! (Relu, Sigmoid, Gelu via the tanh approximation, and Silu), reductions (Sum,
//! Mean, and axis-aware `SumAxis`/`MeanAxis`/`MaxAxis`), shape ops (Reshape,
//! Squeeze, Unsqueeze, Broadcast, Contiguous — pass-through — and a full strided
//! Transpose), 2D matmul, comparisons (Gt/Lt/Eq), and Where selection. Helpers
//! `normalize_axis`, `reduce_axis` (generic fn-pointer reducer with keepdim
//! handling), and `matmul_impl` (validated 2D GEMM) support the interpreter.
//! `JitCompiler` owns the `Optimizer` and `FunctionCache`, gates native codegen
//! behind `enable_native`, and in `compile_native` builds a Cranelift `JITModule`
//! with a `fn(*const f32, *mut f32)` signature, lowers each IR node through
//! `codegen_node` (Input load, Output, Constant, Add/Sub/Mul/Div, Neg/Abs/Sqrt,
//! AddScalar/MulScalar — unsupported ops bubble up as `UnsupportedOp`), stores
//! the final output, then finalizes and leaks the module so the code remains
//! live. Tests cover elementwise add, activation chains, sigmoid accuracy,
//! 2D matmul, and cache-hit behaviour.
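//!
//! # Example
//!
//! A minimal usage sketch mirroring the unit tests at the bottom of this file
//! (the `use` paths assume the crate's default module layout; adjust them to
//! whatever the crate actually re-exports):
//!
//! ```ignore
//! use axonml_jit::codegen::JitCompiler;
//! use axonml_jit::trace::trace;
//!
//! // Trace an elementwise-add graph, compile it, and run it on two buffers.
//! let graph = trace(|tracer| {
//!     let a = tracer.input("a", &[4]);
//!     let b = tracer.input("b", &[4]);
//!     tracer.output("sum", a.add(&b))
//! });
//!
//! let compiler = JitCompiler::new(); // interpreter backend by default
//! let func = compiler.compile(&graph).unwrap();
//!
//! let a = [1.0, 2.0, 3.0, 4.0];
//! let b = [5.0, 6.0, 7.0, 8.0];
//! let result = func.run(&[("a", &a), ("b", &b)]).unwrap();
//! assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
//! ```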
//!
//! # File
//! `crates/axonml-jit/src/codegen.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 16, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// =============================================================================
// Imports
// =============================================================================

use std::sync::Arc;

use crate::cache::FunctionCache;
use crate::error::{JitError, JitResult};
use crate::ir::{Graph, Node, NodeId, Op};
use crate::optimize::Optimizer;

// =============================================================================
// CompiledFunction
// =============================================================================

/// A compiled function ready for execution.
#[derive(Clone)]
pub struct CompiledFunction {
    /// The original graph.
    graph: Arc<Graph>,
    /// Function kind.
    kind: CompiledKind,
}

#[derive(Clone)]
enum CompiledKind {
    /// Interpreted execution (fallback).
    Interpreted,
    /// Native code via Cranelift JIT (enabled via `JitCompiler::enable_native(true)`).
    Native {
        /// Pointer to compiled code.
        code_ptr: *const u8,
        /// Code size.
        code_size: usize,
    },
}

// Safety: the code pointer targets immutable, executable memory owned by a
// JITModule that is intentionally leaked after finalization, so it stays valid
// for the program's lifetime and may be shared and called across threads.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    /// Creates a placeholder compiled function (for testing).
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Returns the graph.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    // -------------------------------------------------------------------------
    // Execution Dispatch
    // -------------------------------------------------------------------------

    /// Executes the compiled function with the given inputs.
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native {
                code_ptr,
                code_size,
            } => {
                // Native execution via function pointer call
                // Safety: code_ptr points to valid compiled code produced by Cranelift
                unsafe {
                    // `code_ptr` is bound by reference here, so dereference it
                    // before reinterpreting the code address as a function pointer.
                    let func: extern "C" fn(*const f32, *mut f32) =
                        std::mem::transmute(*code_ptr);
                    let flat_inputs: Vec<f32> =
                        inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    // Allocate exact output size from graph shapes
                    let output_size: usize = self
                        .graph
                        .outputs()
                        .values()
                        .map(|id| self.graph.node(*id).shape.numel())
                        .sum();
                    let mut output = vec![0.0f32; output_size];
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size; // Retained for bookkeeping; unused at call time.
                    Ok(output)
                }
            }
        }
    }

    // -------------------------------------------------------------------------
    // Interpreter
    // -------------------------------------------------------------------------

    /// Interpreted execution.
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        // Set input values
        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        // Execute in topological order
        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        // Get output value
        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

    fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
        let get = |id: NodeId| -> JitResult<&Vec<f32>> {
            values[id.index()]
                .as_ref()
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
        };

        match &node.op {
            Op::Input { .. } => {
                // Already set
                Ok(values[node.id.index()].clone().unwrap_or_default())
            }

            Op::Output { input, .. } => Ok(get(*input)?.clone()),

            Op::Constant { value } => {
                let numel = node.shape.numel();
                Ok(vec![*value as f32; numel])
            }

            // Binary ops
            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }

            Op::Pow { base, exp } => {
                let a = get(*base)?;
                let b = get(*exp)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
            }

            Op::Max { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
            }

            Op::Min { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
            }

            // Scalar ops
            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x + *scalar as f32).collect())
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x * *scalar as f32).collect())
            }

            // Unary ops
            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| -x).collect())
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.abs()).collect())
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }

            Op::Exp { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }

            Op::Log { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }

            Op::Sin { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sin()).collect())
            }

            Op::Cos { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.cos()).collect())
            }

            Op::Tanh { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }

            // Activations
            Op::Relu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }

            Op::Sigmoid { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }

            Op::Gelu { input } => {
                let a = get(*input)?;
                const SQRT_2_OVER_PI: f32 = 0.797_884_6;
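                // tanh approximation: gelu(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
                // where SQRT_2_OVER_PI ≈ sqrt(2/pi) ≈ 0.7978846.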
                Ok(a.iter()
                    .map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
                    .collect())
            }

            Op::Silu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
            }

            // Reductions
            Op::Sum { input } => {
                let a = get(*input)?;
                Ok(vec![a.iter().sum()])
            }

            Op::Mean { input } => {
                let a = get(*input)?;
                let sum: f32 = a.iter().sum();
                Ok(vec![sum / a.len() as f32])
            }

            Op::SumAxis {
                input,
                axis,
                keepdim,
            } => {
                // Axis-aware reduction over the input shape (see `reduce_axis`).
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
            }

            Op::MeanAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();
                let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];

                let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
                Ok(sum.iter().map(|x| x / axis_size as f32).collect())
            }

            Op::MaxAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
            }

            // Shape ops
            Op::Reshape { input, .. }
            | Op::Squeeze { input, .. }
            | Op::Unsqueeze { input, .. }
            | Op::Broadcast { input, .. }
            | Op::Contiguous { input } => Ok(get(*input)?.clone()),

            // Transpose actually reorders data
            Op::Transpose { input, dim0, dim1 } => {
                let a = get(*input)?;
                let input_shape = &self.graph.node(*input).shape;
                let ndim = input_shape.dims().len();
                if ndim < 2 || *dim0 >= ndim || *dim1 >= ndim || dim0 == dim1 {
                    return Ok(a.clone());
                }
                // Build transposed strides and reindex
                let dims = input_shape.dims();
                let mut perm: Vec<usize> = (0..ndim).collect();
                perm.swap(*dim0, *dim1);
                let new_shape: Vec<usize> = perm.iter().map(|&d| dims[d]).collect();
                let numel: usize = dims.iter().product();
                let mut result = vec![0.0f32; numel];

                // Compute strides for input
                let mut in_strides = vec![1usize; ndim];
                for d in (0..ndim - 1).rev() {
                    in_strides[d] = in_strides[d + 1] * dims[d + 1];
                }
                // Compute strides for output
                let mut out_strides = vec![1usize; ndim];
                for d in (0..ndim - 1).rev() {
                    out_strides[d] = out_strides[d + 1] * new_shape[d + 1];
                }

                #[allow(clippy::needless_range_loop)]
                for flat in 0..numel {
                    // Convert flat index to multi-dim in output space
                    let mut remaining = flat;
                    let mut out_idx = vec![0usize; ndim];
                    for d in 0..ndim {
                        out_idx[d] = remaining / out_strides[d];
                        remaining %= out_strides[d];
                    }
                    // Map to input space via inverse permutation
                    let mut in_flat = 0;
                    for d in 0..ndim {
                        in_flat += out_idx[d] * in_strides[perm[d]];
                    }
                    result[flat] = a[in_flat];
                }
                Ok(result)
            }

            // Matrix multiplication
            Op::MatMul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                let lhs_node = self.graph.node(*lhs);
                let rhs_node = self.graph.node(*rhs);

                let lhs_shape = lhs_node.shape.dims();
                let rhs_shape = rhs_node.shape.dims();

                matmul_impl(a, b, lhs_shape, rhs_shape)
            }

            // Comparisons
            Op::Gt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Lt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Eq { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| {
                        if (x - y).abs() < f32::EPSILON {
                            1.0
                        } else {
                            0.0
                        }
                    })
                    .collect())
            }

            Op::Where { condition, x, y } => {
                let cond = get(*condition)?;
                let a = get(*x)?;
                let b = get(*y)?;
                Ok(cond
                    .iter()
                    .zip(a.iter().zip(b.iter()))
                    .map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
                    .collect())
            }

            Op::Cast { input, .. } => {
                // For f32, just pass through
                Ok(get(*input)?.clone())
            }
        }
    }
}

// =============================================================================
// Reduction and MatMul Helpers
// =============================================================================

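/// Maps a possibly negative axis onto an absolute dimension index for a tensor
/// with `ndim` dimensions, e.g. `normalize_axis(-1, 3) == 2`.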
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    if axis < 0 {
        (ndim as i32 + axis) as usize
    } else {
        axis as usize
    }
}

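/// Reduces a row-major buffer `data` of logical `shape` along `axis`, folding
/// elements with `op` starting from `init` and honouring `keepdim` when
/// building the output shape. For example, summing a `[2, 3]` buffer along
/// axis 1 without `keepdim` produces 2 values, one per row.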
fn reduce_axis(
    data: &[f32],
    shape: &[usize],
    axis: i32,
    keepdim: bool,
    op: fn(f32, f32) -> f32,
    init: f32,
) -> JitResult<Vec<f32>> {
    let axis = normalize_axis(axis, shape.len());

    // Compute strides
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    // Compute output shape
    let mut output_shape: Vec<usize> = shape.to_vec();
    if keepdim {
        output_shape[axis] = 1;
    } else {
        output_shape.remove(axis);
    }

    let output_numel: usize = output_shape.iter().product();
    let mut result = vec![init; output_numel];

    // Reduce
    for (i, &val) in data.iter().enumerate() {
        // Convert linear index to multi-index
        let mut multi_idx = vec![0usize; shape.len()];
        let mut idx = i;
        for (d, &st) in strides.iter().enumerate() {
            multi_idx[d] = idx / st;
            idx %= st;
        }

        // Compute output index
        let out_idx = if keepdim {
            let mut out_idx = 0;
            let mut temp_strides = vec![1usize; output_shape.len()];
            for d in (0..output_shape.len() - 1).rev() {
                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
            }
            for d in 0..output_shape.len() {
                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
                out_idx += dim_idx * temp_strides[d];
            }
            out_idx
        } else {
            let mut out_idx = 0;
            let mut temp_strides = vec![1usize; output_shape.len()];
            if !output_shape.is_empty() {
                for d in (0..output_shape.len() - 1).rev() {
                    temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
                }
            }
            let mut out_d = 0;
            for (d, &mi) in multi_idx.iter().enumerate().take(shape.len()) {
                if d == axis {
                    continue;
                }
                if out_d < temp_strides.len() {
                    out_idx += mi * temp_strides[out_d];
                }
                out_d += 1;
            }
            out_idx
        };

        if out_idx < result.len() {
            result[out_idx] = op(result[out_idx], val);
        }
    }

    Ok(result)
}

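/// Row-major 2D GEMM: multiplies an `[m, k]` buffer by a `[k, n]` buffer into
/// an `[m, n]` result, rejecting non-2D operands (`UnsupportedOp`) and
/// mismatched inner dimensions (`ShapeMismatch`).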
fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
    // Simple 2D matmul
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(JitError::UnsupportedOp(
            "Only 2D matmul supported in interpreter".to_string(),
        ));
    }

    let m = a_shape[0];
    let k = a_shape[1];
    let n = b_shape[1];

    if k != b_shape[0] {
        return Err(JitError::ShapeMismatch {
            expected: vec![k],
            found: vec![b_shape[0]],
        });
    }

    let mut result = vec![0.0f32; m * n];

    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            result[i * n + j] = sum;
        }
    }

    Ok(result)
}

// =============================================================================
// JitCompiler
// =============================================================================

/// JIT compiler.
pub struct JitCompiler {
    optimizer: Optimizer,
    cache: FunctionCache,
    use_native: bool,
}

impl JitCompiler {
    /// Creates a new JIT compiler.
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            use_native: false, // Fallback to interpreter for now
        }
    }

    /// Creates a compiler with a custom optimizer.
    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    /// Enables native code generation via Cranelift.
    ///
    /// When enabled, the compiler will attempt to generate machine code
    /// for supported operations. Unsupported ops fall back to the interpreter.
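    ///
    /// A brief sketch (`graph` stands in for any traced graph):
    ///
    /// ```ignore
    /// let mut compiler = JitCompiler::new();
    /// compiler.enable_native(true); // attempt Cranelift codegen
    /// let func = compiler.compile(&graph)?;
    /// ```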
    pub fn enable_native(&mut self, enable: bool) {
        self.use_native = enable;
    }

    // -------------------------------------------------------------------------
    // Compilation Pipeline
    // -------------------------------------------------------------------------

    /// Compiles a graph into an executable function.
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        // Check cache
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        // Validate graph
        graph.validate().map_err(JitError::InvalidGraph)?;

        // Optimize
        let optimized = self.optimizer.optimize(graph.clone());

        // Generate code
        let func = if self.use_native {
            // Fall back to the interpreter on `UnsupportedOp` (see `enable_native`).
            match self.compile_native(&optimized) {
                Ok(f) => f,
                Err(JitError::UnsupportedOp(_)) => self.compile_interpreted(&optimized),
                Err(e) => return Err(e),
            }
        } else {
            self.compile_interpreted(&optimized)
        };

        // Cache result
        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

    // -------------------------------------------------------------------------
    // Cranelift Native Codegen
    // -------------------------------------------------------------------------

    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        // Initialize Cranelift JIT module
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Create function signature: fn(inputs: *const f32, outputs: *mut f32)
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input ptr
        sig.params.push(AbiParam::new(types::I64)); // output ptr

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| {
                JitError::CompilationFailed(format!("Failed to declare function: {}", e))
            })?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        // Build function body
        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // Generate code for each operation in the graph
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store output
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        // Compile the function
        module.define_function(func_id, &mut ctx).map_err(|e| {
            JitError::CompilationFailed(format!("Failed to define function: {}", e))
        })?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        let code_size = 0; // JITModule manages memory

        // Leak the module to keep the code alive
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr,
                code_size,
            },
        })
    }

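    /// Lowers a single IR node to Cranelift instructions and returns the SSA
    /// `Value` it produces. Only scalar f32 operations are handled here;
    /// anything else returns `UnsupportedOp` so `compile` can fall back to the
    /// interpreter.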
    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                // Load from input pointer at appropriate offset
                let offset = self.get_input_offset(name);
                Ok(builder
                    .ins()
                    .load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            // For operations not easily supported by Cranelift scalars,
            // fall back to interpreted execution for the whole graph
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen, using interpreter",
                node.op
            ))),
        }
    }

    fn get_input_offset(&self, _name: &str) -> i32 {
        // Input offset calculation for native codegen.
        // Currently only supports single-input graphs (offset=0).
        // Multi-input support requires storing the graph in JitCompiler
        // or passing input layout info to compile_native().
        0
    }

    // -------------------------------------------------------------------------
    // Cache Accessors
    // -------------------------------------------------------------------------

    /// Returns cache statistics.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    /// Clears the compilation cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // relu([-1, 0, 1, 2]) = [0, 0, 1, 2]
        // * 2 = [0, 0, 2, 4]
        // + 1 = [1, 1, 3, 5]
        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5
        assert!((result[0] - 0.5).abs() < 0.01);
        // sigmoid(1) ≈ 0.731
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // Identity-like matrices
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // 2x3
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 3x2
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4); // 2x2
    }

    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Second compile should use cache
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
}