axonml_jit/codegen.rs

//! Code Generation
//!
//! Generates native code from computation graphs using Cranelift.
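//!
//! Two execution strategies are provided: a reference interpreter that evaluates the
//! graph in topological order, and an experimental native path that lowers scalar
//! f32 operations through Cranelift. The native path is currently disabled by
//! default, so compilation falls back to the interpreter.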

use std::sync::Arc;

use crate::ir::{Graph, Node, NodeId, Op};
use crate::error::{JitError, JitResult};
use crate::cache::FunctionCache;
use crate::optimize::Optimizer;

/// A compiled function ready for execution.
#[derive(Clone)]
pub struct CompiledFunction {
    /// The original graph.
    graph: Arc<Graph>,
    /// Function kind.
    kind: CompiledKind,
}

#[derive(Clone)]
enum CompiledKind {
    /// Interpreted execution (fallback).
    Interpreted,
    /// Native code generated via Cranelift (currently disabled by default).
    #[allow(dead_code)]
    Native {
        /// Pointer to compiled code.
        code_ptr: *const u8,
        /// Code size.
        code_size: usize,
    },
}

// Safety: the native code pointer refers to finalized, read-only JIT memory; it is
// only ever read and called, never mutated, so sharing it across threads is sound.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    /// Creates a placeholder compiled function (for testing).
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Returns the graph.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Executes the compiled function with the given inputs.
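    ///
    /// Inputs are passed as `(name, data)` pairs and matched against the graph's
    /// named inputs; in interpreted mode the first registered output is returned
    /// as a flat `Vec<f32>`.
    ///
    /// Illustrative sketch only (mirrors the unit tests at the bottom of this file
    /// and assumes the `trace` helper from `crate::trace`):
    ///
    /// ```ignore
    /// let graph = trace(|tracer| {
    ///     let a = tracer.input("a", &[4]);
    ///     let b = tracer.input("b", &[4]);
    ///     tracer.output("sum", a.add(&b))
    /// });
    /// let func = JitCompiler::new().compile(&graph)?;
    /// let out = func.run(&[("a", &[1.0; 4]), ("b", &[2.0; 4])])?;
    /// assert_eq!(out, vec![3.0; 4]);
    /// ```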
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native { code_ptr, code_size } => {
                // Native execution via function pointer call
                // Safety: code_ptr points to valid compiled code from Cranelift
                unsafe {
                    let func: extern "C" fn(*const f32, *mut f32) = std::mem::transmute(*code_ptr);
                    let flat_inputs: Vec<f32> = inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    // Over-allocated scratch output; output shapes are not tracked for native code yet
                    let mut output = vec![0.0f32; self.graph.outputs().len() * 1024];
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size; // Unused for now; the leaked JITModule owns the code memory
                    Ok(output)
                }
            }
        }
    }

    /// Interpreted execution.
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        // Set input values
        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        // Execute in topological order
        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        // Return the first registered output
        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

    fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
        let get = |id: NodeId| -> JitResult<&Vec<f32>> {
            values[id.index()]
                .as_ref()
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
        };

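        // Binary elementwise ops below assume both operands already have identical
        // lengths; `zip` silently truncates to the shorter operand, so broadcasting
        // must have been made explicit in the graph before this point.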
        match &node.op {
            Op::Input { .. } => {
                // Already set
                Ok(values[node.id.index()].clone().unwrap_or_default())
            }

            Op::Output { input, .. } => {
                Ok(get(*input)?.clone())
            }

            Op::Constant { value } => {
                let numel = node.shape.numel();
                Ok(vec![*value as f32; numel])
            }

            // Binary ops
            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }

            Op::Pow { base, exp } => {
                let a = get(*base)?;
                let b = get(*exp)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
            }

            Op::Max { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
            }

            Op::Min { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
            }

            // Scalar ops
            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x + *scalar as f32).collect())
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x * *scalar as f32).collect())
            }

            // Unary ops
            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| -x).collect())
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.abs()).collect())
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }

            Op::Exp { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }

            Op::Log { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }

            Op::Sin { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sin()).collect())
            }

            Op::Cos { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.cos()).collect())
            }

            Op::Tanh { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }

            // Activations
            Op::Relu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }

            Op::Sigmoid { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }

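            // GELU, tanh approximation:
            //   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))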
            Op::Gelu { input } => {
                let a = get(*input)?;
                const SQRT_2_OVER_PI: f32 = 0.7978845608;
                Ok(a.iter().map(|x| {
                    0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh())
                }).collect())
            }

            Op::Silu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
            }

            // Reductions
            Op::Sum { input } => {
                let a = get(*input)?;
                Ok(vec![a.iter().sum()])
            }

            Op::Mean { input } => {
                let a = get(*input)?;
                let sum: f32 = a.iter().sum();
                Ok(vec![sum / a.len() as f32])
            }

            Op::SumAxis { input, axis, keepdim } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
            }

            Op::MeanAxis { input, axis, keepdim } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();
                let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];

                let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
                Ok(sum.iter().map(|x| x / axis_size as f32).collect())
            }

            Op::MaxAxis { input, axis, keepdim } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
            }

            // Shape ops - the interpreter just passes the data through
            Op::Reshape { input, .. } |
            Op::Transpose { input, .. } |
            Op::Squeeze { input, .. } |
            Op::Unsqueeze { input, .. } |
            Op::Broadcast { input, .. } |
            Op::Contiguous { input } => {
                Ok(get(*input)?.clone())
            }

            // Matrix multiplication
            Op::MatMul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                let lhs_node = self.graph.node(*lhs);
                let rhs_node = self.graph.node(*rhs);

                let lhs_shape = lhs_node.shape.dims();
                let rhs_shape = rhs_node.shape.dims();

                matmul_impl(a, b, lhs_shape, rhs_shape)
            }

            // Comparisons
            Op::Gt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| if x > y { 1.0 } else { 0.0 }).collect())
            }

            Op::Lt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| if x < y { 1.0 } else { 0.0 }).collect())
            }

            Op::Eq { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| if (x - y).abs() < f32::EPSILON { 1.0 } else { 0.0 }).collect())
            }

            Op::Where { condition, x, y } => {
                let cond = get(*condition)?;
                let a = get(*x)?;
                let b = get(*y)?;
                Ok(cond.iter().zip(a.iter().zip(b.iter())).map(|(c, (a, b))| {
                    if *c != 0.0 { *a } else { *b }
                }).collect())
            }

            Op::Cast { input, .. } => {
                // All interpreter values are f32, so a cast is a no-op
                Ok(get(*input)?.clone())
            }
        }
    }
}

fn normalize_axis(axis: i32, ndim: usize) -> usize {
    if axis < 0 {
        (ndim as i32 + axis) as usize
    } else {
        axis as usize
    }
}

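// Reduces `data` (with logical `shape`) along `axis`, folding elements with `op`
// starting from `init`. Each input element's linear index is decomposed into a
// multi-index, the reduced axis is dropped (or pinned to 0 when `keepdim` is set),
// and the element is accumulated into the corresponding output slot.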
fn reduce_axis(
    data: &[f32],
    shape: &[usize],
    axis: i32,
    keepdim: bool,
    op: fn(f32, f32) -> f32,
    init: f32,
) -> JitResult<Vec<f32>> {
    let axis = normalize_axis(axis, shape.len());

    // Row-major strides of the input shape
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    // Compute output shape
    let mut output_shape: Vec<usize> = shape.to_vec();
    if keepdim {
        output_shape[axis] = 1;
    } else {
        output_shape.remove(axis);
    }

    let output_numel: usize = output_shape.iter().product();
    let mut result = vec![init; output_numel];

    // Row-major strides of the output shape
    let mut out_strides = vec![1usize; output_shape.len()];
    for d in (0..output_shape.len().saturating_sub(1)).rev() {
        out_strides[d] = out_strides[d + 1] * output_shape[d + 1];
    }

    // Reduce
    for i in 0..data.len() {
        // Convert the linear index to a multi-index
        let mut multi_idx = vec![0usize; shape.len()];
        let mut idx = i;
        for d in 0..shape.len() {
            multi_idx[d] = idx / strides[d];
            idx %= strides[d];
        }

        // Map the multi-index to an output index, collapsing the reduced axis
        let mut out_idx = 0;
        if keepdim {
            for d in 0..output_shape.len() {
                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
                out_idx += dim_idx * out_strides[d];
            }
        } else {
            let mut out_d = 0;
            for d in 0..shape.len() {
                if d == axis {
                    continue;
                }
                out_idx += multi_idx[d] * out_strides[out_d];
                out_d += 1;
            }
        }

        if out_idx < result.len() {
            result[out_idx] = op(result[out_idx], data[i]);
        }
    }

    Ok(result)
}

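// Naive triple-loop matrix multiply over row-major buffers: `a` is (m, k), `b` is
// (k, n), and the result is (m, n).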
fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
    // Simple 2D matmul
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(JitError::UnsupportedOp("Only 2D matmul supported in interpreter".to_string()));
    }

    let m = a_shape[0];
    let k = a_shape[1];
    let n = b_shape[1];

    if k != b_shape[0] {
        return Err(JitError::ShapeMismatch {
            expected: vec![k],
            found: vec![b_shape[0]],
        });
    }

    let mut result = vec![0.0f32; m * n];

    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            result[i * n + j] = sum;
        }
    }

    Ok(result)
}

/// JIT compiler.
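///
/// `compile` first looks the graph up in the function cache, then validates it,
/// runs the optimizer passes, and finally emits either an interpreted or a native
/// (Cranelift) function, caching the result.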
pub struct JitCompiler {
    optimizer: Optimizer,
    cache: FunctionCache,
    use_native: bool,
}

impl JitCompiler {
    /// Creates a new JIT compiler.
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            use_native: false, // Fall back to the interpreter for now
        }
    }

    /// Creates a compiler with a custom optimizer.
    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    /// Compiles a graph into an executable function.
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        // Check cache
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        // Validate graph
        graph.validate().map_err(JitError::InvalidGraph)?;

        // Optimize
        let optimized = self.optimizer.optimize(graph.clone());

        // Generate code
        let func = if self.use_native {
            self.compile_native(&optimized)?
        } else {
            self.compile_interpreted(&optimized)
        };

        // Cache result
        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        // Initialize Cranelift JIT module
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Create function signature: fn(inputs: *const f32, outputs: *mut f32)
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input ptr
        sig.params.push(AbiParam::new(types::I64)); // output ptr
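        // Note: pointers are modeled as I64 values, which assumes a 64-bit target ISA.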

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| JitError::CompilationFailed(format!("Failed to declare function: {}", e)))?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        // Build function body
        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // Generate code for each operation in the graph
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store output
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        // Compile the function
        module
            .define_function(func_id, &mut ctx)
            .map_err(|e| JitError::CompilationFailed(format!("Failed to define function: {}", e)))?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        let code_size = 0; // JITModule manages memory

        // Leak the module to keep the code alive
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr: code_ptr as *const u8,
                code_size,
            },
        })
    }

    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                // Load from input pointer at appropriate offset
                let offset = self.get_input_offset(name);
                Ok(builder.ins().load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            // Operations without a straightforward scalar Cranelift lowering are
            // rejected here; the caller is expected to fall back to the interpreter.
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen",
                node.op
            ))),
        }
    }

    fn get_input_offset(&self, _name: &str) -> i32 {
        // Placeholder: every input currently reads from offset 0. A real implementation
        // would map each input name to its byte offset in the flattened input buffer.
        0
    }

    /// Returns cache statistics.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    /// Clears the compilation cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // relu([-1, 0, 1, 2]) = [0, 0, 1, 2]
        // * 2 = [0, 0, 2, 4]
        // + 1 = [1, 1, 3, 5]
        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5
        assert!((result[0] - 0.5).abs() < 0.01);
        // sigmoid(1) ≈ 0.731
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // Identity-like matrices
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // 2x3
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 3x2
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4); // 2x2
    }

    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Second compile should use cache
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
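
    // Sketch of a negative-path test (uses only the trace API exercised above):
    // supplying an input name the graph does not declare should surface as an
    // error rather than a panic.
    #[test]
    fn test_unknown_input_is_an_error() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let data = [1.0, 2.0, 3.0, 4.0];
        assert!(func.run(&[("not_an_input", &data)]).is_err());
    }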
}