// crates/axonml-jit/src/codegen.rs
//! Code Generation
//!
//! # File
//! `crates/axonml-jit/src/codegen.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::sync::Arc;

use crate::cache::FunctionCache;
use crate::error::{JitError, JitResult};
use crate::ir::{Graph, Node, NodeId, Op};
use crate::optimize::Optimizer;

/// A compiled function ready for execution.
///
/// Cloning is cheap: the graph is shared behind an [`Arc`] and the kind is
/// either a unit variant or a raw pointer plus a size.
#[derive(Clone)]
pub struct CompiledFunction {
    /// The original graph (shared so cached copies stay cheap to clone).
    graph: Arc<Graph>,
    /// Function kind (interpreted fallback or native code).
    kind: CompiledKind,
}

/// How a [`CompiledFunction`] executes.
#[derive(Clone)]
enum CompiledKind {
    /// Interpreted execution (fallback).
    Interpreted,
    /// Native code (future: Cranelift JIT).
    #[allow(dead_code)] // only constructed when native codegen is enabled
    Native {
        /// Pointer to compiled code.
        code_ptr: *const u8,
        /// Code size. Set to 0 in `compile_native`; the JIT module owns
        /// the code memory, so the size is informational only.
        code_size: usize,
    },
}

// SAFETY: `CompiledKind` is only non-trivially `Send`/`Sync` because of the
// raw `code_ptr` in the `Native` variant. The pointer targets finalized,
// immutable JIT code whose owning `JITModule` is intentionally leaked in
// `compile_native` (`std::mem::forget`), so sharing the pointer across
// threads cannot observe mutation or a dangling target.
// NOTE(review): this argument depends on the module never being freed —
// revisit if code memory is ever reclaimed.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    /// Creates a placeholder compiled function (for testing).
    ///
    /// Wraps an empty graph and always executes through the interpreter.
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Returns the graph.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Executes the compiled function with the given inputs.
    ///
    /// `inputs` pairs each graph input name with a flat `f32` buffer.
    ///
    /// # Errors
    /// On the interpreted path, returns [`JitError::InputNotFound`] for an
    /// unknown input name and [`JitError::OutputNotFound`] when the graph
    /// declares no outputs.
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native {
                code_ptr,
                code_size,
            } => {
                // Native execution via function pointer call
                // Safety: code_ptr points to valid compiled code from Cranelift
                unsafe {
                    let func: extern "C" fn(*const f32, *mut f32) = std::mem::transmute(code_ptr);
                    // NOTE(review): inputs are flattened in the caller-given
                    // order and the output buffer is over-allocated to a fixed
                    // 1024 elements per output. The compiled kernel's ABI must
                    // agree with both assumptions — confirm before enabling
                    // native execution (`use_native` is currently false).
                    let flat_inputs: Vec<f32> =
                        inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    let mut output = vec![0.0f32; self.graph.outputs().len() * 1024]; // Max output size
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size; // Used for memory management
                    Ok(output)
                }
            }
        }
    }

    /// Interpreted execution.
    ///
    /// Binds the named inputs, evaluates every node in the graph's node
    /// order (assumed topological), then returns the buffer feeding the
    /// first declared output.
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        // One value slot per node, filled as nodes are evaluated.
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        // Set input values
        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        // Execute in topological order
        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        // Get output value (only the first declared output is returned)
        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

    /// Evaluates a single node, reading operand values from `values`.
    ///
    /// NOTE(review): element-wise binary ops use `zip`, so mismatched operand
    /// lengths silently truncate to the shorter buffer — no broadcasting is
    /// performed here.
    fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
        // Fetch an already-computed operand or fail with a runtime error
        // (indicates a non-topological node order or a graph bug).
        let get = |id: NodeId| -> JitResult<&Vec<f32>> {
            values[id.index()]
                .as_ref()
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
        };

        match &node.op {
            Op::Input { .. } => {
                // Already set
                Ok(values[node.id.index()].clone().unwrap_or_default())
            }

            Op::Output { input, .. } => Ok(get(*input)?.clone()),

            Op::Constant { value } => {
                // Splat the scalar constant across the node's static shape.
                let numel = node.shape.numel();
                Ok(vec![*value as f32; numel])
            }

            // Binary ops
            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }

            Op::Pow { base, exp } => {
                let a = get(*base)?;
                let b = get(*exp)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
            }

            Op::Max { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
            }

            Op::Min { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
            }

            // Scalar ops
            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x + *scalar as f32).collect())
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x * *scalar as f32).collect())
            }

            // Unary ops
            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| -x).collect())
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.abs()).collect())
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }

            Op::Exp { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }

            Op::Log { input } => {
                // Natural logarithm.
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }

            Op::Sin { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sin()).collect())
            }

            Op::Cos { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.cos()).collect())
            }

            Op::Tanh { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }

            // Activations
            Op::Relu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }

            Op::Sigmoid { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }

            Op::Gelu { input } => {
                // Tanh approximation of GELU.
                let a = get(*input)?;
                const SQRT_2_OVER_PI: f32 = 0.797_884_6;
                Ok(a.iter()
                    .map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
                    .collect())
            }

            Op::Silu { input } => {
                // SiLU / swish: x * sigmoid(x).
                let a = get(*input)?;
                Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
            }

            // Reductions
            Op::Sum { input } => {
                let a = get(*input)?;
                Ok(vec![a.iter().sum()])
            }

            Op::Mean { input } => {
                let a = get(*input)?;
                let sum: f32 = a.iter().sum();
                Ok(vec![sum / a.len() as f32])
            }

            Op::SumAxis {
                input,
                axis,
                keepdim,
            } => {
                // Axis-wise sum over the input's static shape.
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
            }

            Op::MeanAxis {
                input,
                axis,
                keepdim,
            } => {
                // Axis-wise mean: sum along the axis, then divide by its size.
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();
                let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];

                let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
                Ok(sum.iter().map(|x| x / axis_size as f32).collect())
            }

            Op::MaxAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
            }

            // Shape ops - for interpreter, just pass through
            // NOTE(review): passing data through unchanged is only correct for
            // ops that don't reorder memory; Transpose/Broadcast would need a
            // real data shuffle — confirm upstream guarantees before relying
            // on these in the interpreter.
            Op::Reshape { input, .. }
            | Op::Transpose { input, .. }
            | Op::Squeeze { input, .. }
            | Op::Unsqueeze { input, .. }
            | Op::Broadcast { input, .. }
            | Op::Contiguous { input } => Ok(get(*input)?.clone()),

            // Matrix multiplication
            Op::MatMul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                let lhs_node = self.graph.node(*lhs);
                let rhs_node = self.graph.node(*rhs);

                let lhs_shape = lhs_node.shape.dims();
                let rhs_shape = rhs_node.shape.dims();

                matmul_impl(a, b, lhs_shape, rhs_shape)
            }

            // Comparisons (result encoded as 1.0 / 0.0 masks)
            Op::Gt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Lt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Eq { lhs, rhs } => {
                // Approximate float equality via epsilon comparison.
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| {
                        if (x - y).abs() < f32::EPSILON {
                            1.0
                        } else {
                            0.0
                        }
                    })
                    .collect())
            }

            Op::Where { condition, x, y } => {
                // Elementwise select: any non-zero condition picks `x`.
                let cond = get(*condition)?;
                let a = get(*x)?;
                let b = get(*y)?;
                Ok(cond
                    .iter()
                    .zip(a.iter().zip(b.iter()))
                    .map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
                    .collect())
            }

            Op::Cast { input, .. } => {
                // For f32, just pass through
                Ok(get(*input)?.clone())
            }
        }
    }
}

/// Maps a possibly-negative axis index to a concrete dimension index.
///
/// Negative axes count from the back (`-1` is the last dimension), matching
/// the usual tensor-library convention.
///
/// The caller must pass an axis in `-(ndim as i32)..(ndim as i32)`; an
/// out-of-range axis would silently wrap into a nonsensical `usize`, so the
/// precondition is checked in debug builds.
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    // Catch out-of-range axes early instead of producing a huge wrapped
    // index that panics later with an unrelated bounds error.
    debug_assert!(
        (-(ndim as i32)..ndim as i32).contains(&axis),
        "axis {} out of range for {} dimensions",
        axis,
        ndim
    );
    if axis < 0 {
        (ndim as i32 + axis) as usize
    } else {
        axis as usize
    }
}

390fn reduce_axis(
391    data: &[f32],
392    shape: &[usize],
393    axis: i32,
394    keepdim: bool,
395    op: fn(f32, f32) -> f32,
396    init: f32,
397) -> JitResult<Vec<f32>> {
398    let axis = normalize_axis(axis, shape.len());
399
400    // Compute strides
401    let mut strides = vec![1usize; shape.len()];
402    for i in (0..shape.len() - 1).rev() {
403        strides[i] = strides[i + 1] * shape[i + 1];
404    }
405
406    // Compute output shape
407    let mut output_shape: Vec<usize> = shape.to_vec();
408    if keepdim {
409        output_shape[axis] = 1;
410    } else {
411        output_shape.remove(axis);
412    }
413
414    let output_numel: usize = output_shape.iter().product();
415    let mut result = vec![init; output_numel];
416
417    // Reduce
418    for (i, &val) in data.iter().enumerate() {
419        // Convert linear index to multi-index
420        let mut multi_idx = vec![0usize; shape.len()];
421        let mut idx = i;
422        for (d, &st) in strides.iter().enumerate() {
423            multi_idx[d] = idx / st;
424            idx %= st;
425        }
426
427        // Compute output index
428        let out_idx = if keepdim {
429            let mut out_idx = 0;
430            let mut temp_strides = vec![1usize; output_shape.len()];
431            for d in (0..output_shape.len() - 1).rev() {
432                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
433            }
434            for d in 0..output_shape.len() {
435                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
436                out_idx += dim_idx * temp_strides[d];
437            }
438            out_idx
439        } else {
440            let mut out_idx = 0;
441            let mut temp_strides = vec![1usize; output_shape.len()];
442            if !output_shape.is_empty() {
443                for d in (0..output_shape.len() - 1).rev() {
444                    temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
445                }
446            }
447            let mut out_d = 0;
448            for (d, &mi) in multi_idx.iter().enumerate().take(shape.len()) {
449                if d == axis {
450                    continue;
451                }
452                if out_d < temp_strides.len() {
453                    out_idx += mi * temp_strides[out_d];
454                }
455                out_d += 1;
456            }
457            out_idx
458        };
459
460        if out_idx < result.len() {
461            result[out_idx] = op(result[out_idx], val);
462        }
463    }
464
465    Ok(result)
466}
467
468fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
469    // Simple 2D matmul
470    if a_shape.len() != 2 || b_shape.len() != 2 {
471        return Err(JitError::UnsupportedOp(
472            "Only 2D matmul supported in interpreter".to_string(),
473        ));
474    }
475
476    let m = a_shape[0];
477    let k = a_shape[1];
478    let n = b_shape[1];
479
480    if k != b_shape[0] {
481        return Err(JitError::ShapeMismatch {
482            expected: vec![k],
483            found: vec![b_shape[0]],
484        });
485    }
486
487    let mut result = vec![0.0f32; m * n];
488
489    for i in 0..m {
490        for j in 0..n {
491            let mut sum = 0.0;
492            for p in 0..k {
493                sum += a[i * k + p] * b[p * n + j];
494            }
495            result[i * n + j] = sum;
496        }
497    }
498
499    Ok(result)
500}
501
/// JIT compiler.
pub struct JitCompiler {
    /// Graph-level optimization passes run before codegen.
    optimizer: Optimizer,
    /// Cache of previously compiled functions, keyed by graph hash.
    cache: FunctionCache,
    /// Whether to attempt native (Cranelift) codegen instead of the
    /// interpreter. Currently always false in the provided constructors.
    use_native: bool,
}

impl JitCompiler {
    /// Creates a new JIT compiler.
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            use_native: false, // Fallback to interpreter for now
        }
    }

    /// Creates a compiler with custom optimizer.
    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    /// Compiles a graph into an executable function.
    ///
    /// Results are cached by graph hash, so a second compile of an identical
    /// graph returns the cached function without re-optimizing.
    ///
    /// # Errors
    /// Returns [`JitError::InvalidGraph`] when validation fails, or a
    /// compilation error from native codegen when `use_native` is set.
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        // Check cache
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        // Validate graph
        graph.validate().map_err(JitError::InvalidGraph)?;

        // Optimize
        let optimized = self.optimizer.optimize(graph.clone());

        // Generate code
        let func = if self.use_native {
            self.compile_native(&optimized)?
        } else {
            self.compile_interpreted(&optimized)
        };

        // Cache result
        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    /// Wraps the (optimized) graph for interpreted execution.
    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Compiles the graph to native code via Cranelift.
    ///
    /// NOTE(review): codegen is scalar-only — each node value is a single
    /// f32 SSA value and only one store is emitted for the output — so the
    /// generated kernel processes element 0 of each buffer only. Vector or
    /// loop-based codegen is future work; `use_native` stays false until
    /// then.
    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        // Initialize Cranelift JIT module
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Create function signature: fn(inputs: *const f32, outputs: *mut f32)
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input ptr
        sig.params.push(AbiParam::new(types::I64)); // output ptr

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| {
                JitError::CompilationFailed(format!("Failed to declare function: {}", e))
            })?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        // Build function body
        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // Generate code for each operation in the graph
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store output (first declared output only, single f32)
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        // Compile the function
        module.define_function(func_id, &mut ctx).map_err(|e| {
            JitError::CompilationFailed(format!("Failed to define function: {}", e))
        })?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        let code_size = 0; // JITModule manages memory

        // Leak the module to keep the code alive. Intentional: `code_ptr`
        // must outlive the module (see the Send/Sync safety comment on
        // `CompiledKind`), at the cost of never reclaiming the code memory.
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr,
                code_size,
            },
        })
    }

    /// Emits Cranelift IR for a single node, reading operand SSA values
    /// from `values`.
    ///
    /// Unsupported ops return `UnsupportedOp` so the caller can fall back
    /// to interpreted execution for the whole graph.
    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        // Fetch an already-emitted operand or fail (non-topological order).
        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                // Load from input pointer at appropriate offset
                // NOTE(review): `get_input_offset` currently always returns 0,
                // so every input loads the same first element — see below.
                let offset = self.get_input_offset(name);
                Ok(builder
                    .ins()
                    .load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            // For operations not easily supported by Cranelift scalars,
            // fall back to interpreted execution for the whole graph
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen, using interpreter",
                node.op
            ))),
        }
    }

    /// Byte offset of the named input within the flat input buffer.
    ///
    /// NOTE(review): placeholder — always returns 0. A real implementation
    /// needs a name-to-offset mapping built from the graph's input layout.
    fn get_input_offset(&self, _name: &str) -> i32 {
        // Simple offset calculation - in practice would use a mapping
        0
    }

    /// Returns cache statistics.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    /// Clears the compilation cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for JitCompiler {
    /// Equivalent to [`JitCompiler::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Element-wise add of two traced inputs through the interpreter.
    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    /// Chained unary/scalar ops: relu -> *2 -> +1.
    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // relu([-1, 0, 1, 2]) = [0, 0, 1, 2]
        // * 2 = [0, 0, 2, 4]
        // + 1 = [1, 1, 3, 5]
        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    /// Sigmoid checked against known values with a loose tolerance.
    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5
        assert!((result[0] - 0.5).abs() < 0.01);
        // sigmoid(1) ≈ 0.731
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    /// 2D matmul shape check (values not asserted, only output size).
    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // Identity-like matrices
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // 2x3
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 3x2
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4); // 2x2
    }

    /// Recompiling an identical graph must hit the cache (entry count
    /// stays at 1).
    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Second compile should use cache
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
}