
axonml_jit/codegen.rs

//! Code Generation
//!
//! Generates native code from computation graphs using Cranelift.

use std::sync::Arc;

use crate::ir::{Graph, Node, NodeId, Op};
use crate::error::{JitError, JitResult};
use crate::cache::FunctionCache;
use crate::optimize::Optimizer;

/// A compiled function ready for execution.
#[derive(Clone)]
pub struct CompiledFunction {
    /// The original graph.
    graph: Arc<Graph>,
    /// Function kind.
    kind: CompiledKind,
}

#[derive(Clone)]
enum CompiledKind {
    /// Interpreted execution (fallback).
    Interpreted,
    /// Native code (future: Cranelift JIT).
    #[allow(dead_code)]
    Native {
        /// Pointer to compiled code.
        code_ptr: *const u8,
        /// Code size.
        code_size: usize,
    },
}

// Safety: the native code pointer refers to finalized, immutable JIT-compiled code
// whose backing memory is intentionally leaked (see `compile_native`), so it can be
// shared and called from any thread.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    /// Creates a placeholder compiled function (for testing).
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Returns the graph.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Executes the compiled function with the given inputs.
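    ///
    /// Illustrative use, assuming a graph produced by `trace` (see the tests
    /// at the bottom of this file):
    ///
    /// ```ignore
    /// let compiler = JitCompiler::new();
    /// let func = compiler.compile(&graph)?;
    /// let x = [1.0f32, 2.0, 3.0, 4.0];
    /// let result = func.run(&[("x", &x)])?;
    /// ```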
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native { code_ptr, code_size } => {
                // Native execution via function pointer call
                // Safety: code_ptr points to valid compiled code from Cranelift
                unsafe {
                    let func: extern "C" fn(*const f32, *mut f32) = std::mem::transmute(*code_ptr);
                    let flat_inputs: Vec<f32> = inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    let mut output = vec![0.0f32; self.graph.outputs().len() * 1024]; // assumed maximum output size
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size; // currently unused; the leaked JITModule owns the code memory
                    Ok(output)
                }
            }
        }
    }

    /// Interpreted execution.
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        // Set input values
        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        // Execute in topological order
        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        // Get output value
        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

    fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
        let get = |id: NodeId| -> JitResult<&Vec<f32>> {
            values[id.index()]
                .as_ref()
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
        };

        match &node.op {
            Op::Input { .. } => {
                // Already set
                Ok(values[node.id.index()].clone().unwrap_or_default())
            }

            Op::Output { input, .. } => {
                Ok(get(*input)?.clone())
            }

            Op::Constant { value } => {
                let numel = node.shape.numel();
                Ok(vec![*value as f32; numel])
            }

            // Binary ops
            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }

            Op::Pow { base, exp } => {
                let a = get(*base)?;
                let b = get(*exp)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
            }

            Op::Max { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
            }

            Op::Min { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
            }

            // Scalar ops
            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x + *scalar as f32).collect())
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x * *scalar as f32).collect())
            }

            // Unary ops
            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| -x).collect())
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.abs()).collect())
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }

            Op::Exp { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }

            Op::Log { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }

            Op::Sin { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sin()).collect())
            }

            Op::Cos { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.cos()).collect())
            }

            Op::Tanh { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }

            // Activations
            Op::Relu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }

            Op::Sigmoid { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }

            Op::Gelu { input } => {
                let a = get(*input)?;
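                // Tanh approximation of GELU:
                //   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))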
                const SQRT_2_OVER_PI: f32 = 0.7978845608;
                Ok(a.iter().map(|x| {
                    0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh())
                }).collect())
            }

            Op::Silu { input } => {
                let a = get(*input)?;
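                // SiLU(x) = x * sigmoid(x) = x / (1 + e^(-x))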
                Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
            }

            // Reductions
            Op::Sum { input } => {
                let a = get(*input)?;
                Ok(vec![a.iter().sum()])
            }

            Op::Mean { input } => {
                let a = get(*input)?;
                let sum: f32 = a.iter().sum();
                Ok(vec![sum / a.len() as f32])
            }

            Op::SumAxis { input, axis, keepdim } => {
                // Reduce along the requested axis (see `reduce_axis` below)
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
            }

            Op::MeanAxis { input, axis, keepdim } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();
                let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];

                let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
                Ok(sum.iter().map(|x| x / axis_size as f32).collect())
            }

            Op::MaxAxis { input, axis, keepdim } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
            }

            // Shape ops - for interpreter, just pass through
            Op::Reshape { input, .. } |
            Op::Transpose { input, .. } |
            Op::Squeeze { input, .. } |
            Op::Unsqueeze { input, .. } |
            Op::Broadcast { input, .. } |
            Op::Contiguous { input } => {
                Ok(get(*input)?.clone())
            }

            // Matrix multiplication
            Op::MatMul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                let lhs_node = self.graph.node(*lhs);
                let rhs_node = self.graph.node(*rhs);

                let lhs_shape = lhs_node.shape.dims();
                let rhs_shape = rhs_node.shape.dims();

                matmul_impl(a, b, lhs_shape, rhs_shape)
            }

            // Comparisons
            Op::Gt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| if x > y { 1.0 } else { 0.0 }).collect())
            }

            Op::Lt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| if x < y { 1.0 } else { 0.0 }).collect())
            }

            Op::Eq { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| if (x - y).abs() < f32::EPSILON { 1.0 } else { 0.0 }).collect())
            }

            Op::Where { condition, x, y } => {
                let cond = get(*condition)?;
                let a = get(*x)?;
                let b = get(*y)?;
                Ok(cond.iter().zip(a.iter().zip(b.iter())).map(|(c, (a, b))| {
                    if *c != 0.0 { *a } else { *b }
                }).collect())
            }

            Op::Cast { input, .. } => {
                // For f32, just pass through
                Ok(get(*input)?.clone())
            }
        }
    }
}

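/// Maps a possibly negative axis (e.g. -1 for the last dimension) to a
/// non-negative index in `0..ndim`.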
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    if axis < 0 {
        (ndim as i32 + axis) as usize
    } else {
        axis as usize
    }
}

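/// Reduces row-major `data` of shape `shape` along `axis`, folding elements
/// with `op` starting from `init`. For example, summing a `[2, 3]` tensor
/// along axis 1 produces two values, one per row.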
fn reduce_axis(
    data: &[f32],
    shape: &[usize],
    axis: i32,
    keepdim: bool,
    op: fn(f32, f32) -> f32,
    init: f32,
) -> JitResult<Vec<f32>> {
    let axis = normalize_axis(axis, shape.len());

    // Compute strides
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    // Compute output shape
    let mut output_shape: Vec<usize> = shape.to_vec();
    if keepdim {
        output_shape[axis] = 1;
    } else {
        output_shape.remove(axis);
    }

    let output_numel: usize = output_shape.iter().product();
    let mut result = vec![init; output_numel];

    // Reduce
    for i in 0..data.len() {
        // Convert linear index to multi-index
        let mut multi_idx = vec![0usize; shape.len()];
        let mut idx = i;
        for d in 0..shape.len() {
            multi_idx[d] = idx / strides[d];
            idx %= strides[d];
        }

        // Compute output index
        let out_idx = if keepdim {
            let mut out_idx = 0;
            let mut temp_strides = vec![1usize; output_shape.len()];
            for d in (0..output_shape.len() - 1).rev() {
                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
            }
            for d in 0..output_shape.len() {
                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
                out_idx += dim_idx * temp_strides[d];
            }
            out_idx
        } else {
            let mut out_idx = 0;
            let mut temp_strides = vec![1usize; output_shape.len()];
            if !output_shape.is_empty() {
                for d in (0..output_shape.len() - 1).rev() {
                    temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
                }
            }
            let mut out_d = 0;
            for d in 0..shape.len() {
                if d == axis {
                    continue;
                }
                if out_d < temp_strides.len() {
                    out_idx += multi_idx[d] * temp_strides[out_d];
                }
                out_d += 1;
            }
            out_idx
        };

        if out_idx < result.len() {
            result[out_idx] = op(result[out_idx], data[i]);
        }
    }

    Ok(result)
}

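/// Naive row-major 2D matrix multiply: `(m x k) * (k x n) -> (m x n)`.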
fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
    // Simple 2D matmul
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(JitError::UnsupportedOp("Only 2D matmul supported in interpreter".to_string()));
    }

    let m = a_shape[0];
    let k = a_shape[1];
    let n = b_shape[1];

    if k != b_shape[0] {
        return Err(JitError::ShapeMismatch {
            expected: vec![k],
            found: vec![b_shape[0]],
        });
    }

    let mut result = vec![0.0f32; m * n];

    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            result[i * n + j] = sum;
        }
    }

    Ok(result)
}

/// JIT compiler.
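///
/// `compile` validates the graph, runs the optimizer passes, and then either
/// emits native code (when enabled) or wraps the graph for interpreted
/// execution; compiled functions are cached by graph hash.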
pub struct JitCompiler {
    optimizer: Optimizer,
    cache: FunctionCache,
    use_native: bool,
}

impl JitCompiler {
    /// Creates a new JIT compiler.
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            use_native: false, // Fallback to interpreter for now
        }
    }

    /// Creates a compiler with a custom optimizer.
    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    /// Compiles a graph into an executable function.
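    ///
    /// Results are cached by graph hash, so compiling an identical graph again
    /// returns the cached function (see `test_caching` below).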
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        // Check cache
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        // Validate graph
        graph.validate().map_err(JitError::InvalidGraph)?;

        // Optimize
        let optimized = self.optimizer.optimize(graph.clone());

        // Generate code
        let func = if self.use_native {
            self.compile_native(&optimized)?
        } else {
            self.compile_interpreted(&optimized)
        };

        // Cache result
        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        // Initialize Cranelift JIT module
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Create function signature: fn(inputs: *const f32, outputs: *mut f32)
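        // NOTE: this signature must stay in sync with the
        // `extern "C" fn(*const f32, *mut f32)` pointer type used in `CompiledFunction::run`.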
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input ptr
        sig.params.push(AbiParam::new(types::I64)); // output ptr

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| JitError::CompilationFailed(format!("Failed to declare function: {}", e)))?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        // Build function body
        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // Generate code for each operation in the graph
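            // NOTE: codegen_node currently produces one scalar SSA value per node;
            // ops that would need loops or libcalls return UnsupportedOp instead.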
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store output
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        // Compile the function
        module
            .define_function(func_id, &mut ctx)
            .map_err(|e| JitError::CompilationFailed(format!("Failed to define function: {}", e)))?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        let code_size = 0; // JITModule manages memory

        // Leak the module to keep the code alive
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr: code_ptr as *const u8,
                code_size,
            },
        })
    }

    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                // Load from input pointer at appropriate offset
                let offset = self.get_input_offset(name);
                Ok(builder.ins().load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            // Operations without a direct scalar Cranelift lowering are rejected here;
            // the caller is expected to fall back to interpreted execution.
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen",
                node.op
            ))),
        }
    }

    fn get_input_offset(&self, _name: &str) -> i32 {
        // Simple offset calculation - in practice would use a mapping
        0
    }

    /// Returns cache statistics.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    /// Clears the compilation cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // relu([-1, 0, 1, 2]) = [0, 0, 1, 2]
        // * 2 = [0, 0, 2, 4]
        // + 1 = [1, 1, 3, 5]
        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5
        assert!((result[0] - 0.5).abs() < 0.01);
        // sigmoid(1) ≈ 0.731
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // Identity-like matrices
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // 2x3
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 3x2
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4); // 2x2
    }

    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Second compile should use cache
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
814    }
815}