// axonml_jit/codegen.rs

//! Code Generation
//!
//! Generates native code from computation graphs using Cranelift.
use std::sync::Arc;

use crate::cache::FunctionCache;
use crate::error::{JitError, JitResult};
use crate::ir::{Graph, Node, NodeId, Op};
use crate::optimize::Optimizer;
11
/// A compiled function ready for execution.
///
/// Produced by [`JitCompiler::compile`]; pairs the (optimized) graph with
/// the strategy used to execute it (interpreter or native code).
#[derive(Clone)]
pub struct CompiledFunction {
    /// The original graph.
    graph: Arc<Graph>,
    /// Function kind (how `run` executes the graph).
    kind: CompiledKind,
}

/// Execution strategy for a [`CompiledFunction`].
#[derive(Clone)]
enum CompiledKind {
    /// Interpreted execution (fallback).
    Interpreted,
    /// Native code (future: Cranelift JIT).
    #[allow(dead_code)]
    Native {
        /// Pointer to compiled code.
        ///
        /// NOTE(review): the backing JIT module is leaked in
        /// `compile_native`, so this pointer is presumably valid for the
        /// life of the process — confirm before enabling native execution.
        code_ptr: *const u8,
        /// Code size.
        code_size: usize,
    },
}

// Safety: The native code pointer is never dereferenced without proper synchronization.
// The pointee is immutable JIT-compiled code whose backing module is never
// freed (see `compile_native`), so sharing the raw pointer across threads
// cannot create data races.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    /// Creates a placeholder compiled function (for testing).
    ///
    /// Wraps an empty graph and executes via the interpreter; running it
    /// takes the "no output" error path.
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Returns the graph this function was compiled from.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Executes the compiled function with the given inputs.
    ///
    /// `inputs` pairs each input name with its flat `f32` buffer. Returns
    /// the flat buffer of the graph's first output.
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native {
                code_ptr,
                code_size,
            } => {
                // Native execution via function pointer call
                // Safety: code_ptr points to valid compiled code from Cranelift
                // (the JIT module is leaked in `compile_native`, so the code
                // outlives this call).
                // NOTE(review): the output buffer is sized with a hard-coded
                // 1024-element cap per output — a larger real output would be
                // a buffer overrun; confirm against the native codegen before
                // `use_native` is ever enabled.
                unsafe {
                    let func: extern "C" fn(*const f32, *mut f32) = std::mem::transmute(code_ptr);
                    let flat_inputs: Vec<f32> =
                        inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    let mut output = vec![0.0f32; self.graph.outputs().len() * 1024]; // Max output size
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size; // Used for memory management
                    Ok(output)
                }
            }
        }
    }

    /// Interpreted execution.
    ///
    /// Materializes every node's value as a `Vec<f32>` in topological order,
    /// then returns the value feeding the first declared output.
    ///
    /// # Errors
    /// - [`JitError::InputNotFound`] if a provided input name is not a graph input.
    /// - [`JitError::OutputNotFound`] if the graph declares no output.
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        // One slot per node, filled as nodes are evaluated.
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        // Set input values
        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        // Execute in topological order
        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        // Get output value
        // NOTE(review): only the first declared output is returned; any
        // additional outputs are silently dropped.
        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

    /// Evaluates a single node given the values computed so far.
    ///
    /// Element-wise binary ops assume equal-length operands (`zip` silently
    /// truncates to the shorter). Shape ops are pass-through because the
    /// interpreter works on flat, contiguous buffers.
    fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
        // Fetches a dependency's already-computed value.
        let get = |id: NodeId| -> JitResult<&Vec<f32>> {
            values[id.index()]
                .as_ref()
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
        };

        match &node.op {
            Op::Input { .. } => {
                // Already set by `run_interpreted`; re-clone the stored value.
                Ok(values[node.id.index()].clone().unwrap_or_default())
            }

            Op::Output { input, .. } => Ok(get(*input)?.clone()),

            Op::Constant { value } => {
                // Broadcast the scalar constant over the node's shape.
                let numel = node.shape.numel();
                Ok(vec![*value as f32; numel])
            }

            // Binary ops
            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }

            Op::Pow { base, exp } => {
                let a = get(*base)?;
                let b = get(*exp)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
            }

            Op::Max { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
            }

            Op::Min { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
            }

            // Scalar ops
            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x + *scalar as f32).collect())
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x * *scalar as f32).collect())
            }

            // Unary ops
            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| -x).collect())
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.abs()).collect())
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }

            Op::Exp { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }

            Op::Log { input } => {
                // Natural logarithm.
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }

            Op::Sin { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sin()).collect())
            }

            Op::Cos { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.cos()).collect())
            }

            Op::Tanh { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }

            // Activations
            Op::Relu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }

            Op::Sigmoid { input } => {
                // sigmoid(x) = 1 / (1 + e^-x)
                let a = get(*input)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }

            Op::Gelu { input } => {
                // Tanh approximation of GELU.
                let a = get(*input)?;
                const SQRT_2_OVER_PI: f32 = 0.7978845608;
                Ok(a.iter()
                    .map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
                    .collect())
            }

            Op::Silu { input } => {
                // silu(x) = x * sigmoid(x)
                let a = get(*input)?;
                Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
            }

            // Reductions
            Op::Sum { input } => {
                let a = get(*input)?;
                Ok(vec![a.iter().sum()])
            }

            Op::Mean { input } => {
                let a = get(*input)?;
                let sum: f32 = a.iter().sum();
                Ok(vec![sum / a.len() as f32])
            }

            Op::SumAxis {
                input,
                axis,
                keepdim,
            } => {
                // Axis-wise sum over the input's recorded shape.
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
            }

            Op::MeanAxis {
                input,
                axis,
                keepdim,
            } => {
                // Mean = axis-wise sum divided by the size of the reduced axis.
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();
                let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];

                let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
                Ok(sum.iter().map(|x| x / axis_size as f32).collect())
            }

            Op::MaxAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
            }

            // Shape ops - for interpreter, just pass through
            // (data is flat and contiguous; only metadata would change).
            Op::Reshape { input, .. }
            | Op::Transpose { input, .. }
            | Op::Squeeze { input, .. }
            | Op::Unsqueeze { input, .. }
            | Op::Broadcast { input, .. }
            | Op::Contiguous { input } => Ok(get(*input)?.clone()),

            // Matrix multiplication
            Op::MatMul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                let lhs_node = self.graph.node(*lhs);
                let rhs_node = self.graph.node(*rhs);

                let lhs_shape = lhs_node.shape.dims();
                let rhs_shape = rhs_node.shape.dims();

                matmul_impl(a, b, lhs_shape, rhs_shape)
            }

            // Comparisons (1.0 = true, 0.0 = false)
            Op::Gt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Lt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Eq { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                // Approximate float equality within EPSILON.
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| {
                        if (x - y).abs() < f32::EPSILON {
                            1.0
                        } else {
                            0.0
                        }
                    })
                    .collect())
            }

            Op::Where { condition, x, y } => {
                // Element-wise select: non-zero condition picks `x`, else `y`.
                let cond = get(*condition)?;
                let a = get(*x)?;
                let b = get(*y)?;
                Ok(cond
                    .iter()
                    .zip(a.iter().zip(b.iter()))
                    .map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
                    .collect())
            }

            Op::Cast { input, .. } => {
                // For f32, just pass through (interpreter is f32-only).
                Ok(get(*input)?.clone())
            }
        }
    }
}

/// Resolves a possibly-negative axis into a non-negative dimension index.
///
/// Negative axes count from the end (`-1` is the last dimension). The
/// result is not bounds-checked against `ndim`.
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    let resolved = if axis.is_negative() {
        ndim as i32 + axis
    } else {
        axis
    };
    resolved as usize
}

378fn reduce_axis(
379    data: &[f32],
380    shape: &[usize],
381    axis: i32,
382    keepdim: bool,
383    op: fn(f32, f32) -> f32,
384    init: f32,
385) -> JitResult<Vec<f32>> {
386    let axis = normalize_axis(axis, shape.len());
387
388    // Compute strides
389    let mut strides = vec![1usize; shape.len()];
390    for i in (0..shape.len() - 1).rev() {
391        strides[i] = strides[i + 1] * shape[i + 1];
392    }
393
394    // Compute output shape
395    let mut output_shape: Vec<usize> = shape.to_vec();
396    if keepdim {
397        output_shape[axis] = 1;
398    } else {
399        output_shape.remove(axis);
400    }
401
402    let output_numel: usize = output_shape.iter().product();
403    let mut result = vec![init; output_numel];
404
405    // Reduce
406    for i in 0..data.len() {
407        // Convert linear index to multi-index
408        let mut multi_idx = vec![0usize; shape.len()];
409        let mut idx = i;
410        for d in 0..shape.len() {
411            multi_idx[d] = idx / strides[d];
412            idx %= strides[d];
413        }
414
415        // Compute output index
416        let out_idx = if keepdim {
417            let mut out_idx = 0;
418            let mut temp_strides = vec![1usize; output_shape.len()];
419            for d in (0..output_shape.len() - 1).rev() {
420                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
421            }
422            for d in 0..output_shape.len() {
423                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
424                out_idx += dim_idx * temp_strides[d];
425            }
426            out_idx
427        } else {
428            let mut out_idx = 0;
429            let mut temp_strides = vec![1usize; output_shape.len()];
430            if !output_shape.is_empty() {
431                for d in (0..output_shape.len() - 1).rev() {
432                    temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
433                }
434            }
435            let mut out_d = 0;
436            for d in 0..shape.len() {
437                if d == axis {
438                    continue;
439                }
440                if out_d < temp_strides.len() {
441                    out_idx += multi_idx[d] * temp_strides[out_d];
442                }
443                out_d += 1;
444            }
445            out_idx
446        };
447
448        if out_idx < result.len() {
449            result[out_idx] = op(result[out_idx], data[i]);
450        }
451    }
452
453    Ok(result)
454}
455
456fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
457    // Simple 2D matmul
458    if a_shape.len() != 2 || b_shape.len() != 2 {
459        return Err(JitError::UnsupportedOp(
460            "Only 2D matmul supported in interpreter".to_string(),
461        ));
462    }
463
464    let m = a_shape[0];
465    let k = a_shape[1];
466    let n = b_shape[1];
467
468    if k != b_shape[0] {
469        return Err(JitError::ShapeMismatch {
470            expected: vec![k],
471            found: vec![b_shape[0]],
472        });
473    }
474
475    let mut result = vec![0.0f32; m * n];
476
477    for i in 0..m {
478        for j in 0..n {
479            let mut sum = 0.0;
480            for p in 0..k {
481                sum += a[i * k + p] * b[p * n + j];
482            }
483            result[i * n + j] = sum;
484        }
485    }
486
487    Ok(result)
488}
489
/// JIT compiler.
///
/// Validates, optimizes, and lowers computation graphs into executable
/// [`CompiledFunction`]s, caching results by graph hash.
pub struct JitCompiler {
    /// Optimization pipeline applied before codegen.
    optimizer: Optimizer,
    /// Cache of previously compiled functions, keyed by graph hash.
    cache: FunctionCache,
    /// When true, lower to native code via Cranelift; otherwise interpret.
    use_native: bool,
}

impl JitCompiler {
    /// Creates a new JIT compiler.
    ///
    /// Uses the default optimization passes and cache size. Native codegen
    /// is disabled: compiled functions execute through the interpreter.
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            use_native: false, // Fallback to interpreter for now
        }
    }

    /// Creates a compiler with custom optimizer.
    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    /// Compiles a graph into an executable function.
    ///
    /// Pipeline: cache lookup → validation → optimization → codegen
    /// (native or interpreted) → cache insert.
    ///
    /// # Errors
    /// - [`JitError::InvalidGraph`] if graph validation fails.
    /// - Compilation errors from the native codegen path.
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        // Check cache
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        // Validate graph
        graph.validate().map_err(JitError::InvalidGraph)?;

        // Optimize
        let optimized = self.optimizer.optimize(graph.clone());

        // Generate code
        let func = if self.use_native {
            self.compile_native(&optimized)?
        } else {
            self.compile_interpreted(&optimized)
        };

        // Cache result
        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    /// Wraps the graph for interpreted execution (no code generation).
    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Lowers the graph to native code with Cranelift, emitting a single
    /// exported function `jit_kernel(inputs: *const f32, outputs: *mut f32)`.
    ///
    /// NOTE(review): codegen is scalar-only — each node produces one SSA
    /// `f32` value and only a single f32 is stored to the output pointer, so
    /// this path appears correct only for single-element tensors. Confirm
    /// before ever setting `use_native = true`.
    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        // Initialize Cranelift JIT module
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Create function signature: fn(inputs: *const f32, outputs: *mut f32)
        // Pointers are passed as I64 machine words.
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input ptr
        sig.params.push(AbiParam::new(types::I64)); // output ptr

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| {
                JitError::CompilationFailed(format!("Failed to declare function: {}", e))
            })?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        // Build function body
        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // Generate code for each operation in the graph
            // (topological order, one SSA value per node).
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store output (only the first declared output, as a single f32).
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        // Compile the function
        module.define_function(func_id, &mut ctx).map_err(|e| {
            JitError::CompilationFailed(format!("Failed to define function: {}", e))
        })?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        let code_size = 0; // JITModule manages memory

        // Leak the module to keep the code alive
        // (the Send/Sync impls on CompiledKind rely on this).
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr: code_ptr as *const u8,
                code_size,
            },
        })
    }

    /// Generates Cranelift IR for one node, returning its SSA value.
    ///
    /// Supports only scalar arithmetic; any other op returns
    /// [`JitError::UnsupportedOp`] so compilation can fall back to the
    /// interpreter.
    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        // Fetches a dependency's already-generated SSA value.
        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                // Load from input pointer at appropriate offset
                // NOTE(review): `get_input_offset` always returns 0, so every
                // input currently reads the same first element — confirm.
                let offset = self.get_input_offset(name);
                Ok(builder
                    .ins()
                    .load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            // For operations not easily supported by Cranelift scalars,
            // fall back to interpreted execution for the whole graph
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen, using interpreter",
                node.op
            ))),
        }
    }

    /// Byte offset of a named input within the flat input buffer.
    ///
    /// NOTE(review): placeholder — always returns 0; a real implementation
    /// would map input names to cumulative offsets.
    fn get_input_offset(&self, _name: &str) -> i32 {
        // Simple offset calculation - in practice would use a mapping
        0
    }

    /// Returns cache statistics.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    /// Clears the compilation cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for JitCompiler {
    /// Equivalent to [`JitCompiler::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// End-to-end: trace, compile, and run an element-wise add.
    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    /// Chained unary/scalar ops: relu -> *2 -> +1.
    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // relu([-1, 0, 1, 2]) = [0, 0, 1, 2]
        // * 2 = [0, 0, 2, 4]
        // + 1 = [1, 1, 3, 5]
        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    /// Sigmoid values checked against known reference points.
    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5
        assert!((result[0] - 0.5).abs() < 0.01);
        // sigmoid(1) ≈ 0.731
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    /// 2x3 · 3x2 matmul produces a 2x2 result (shape-level check only).
    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // Identity-like matrices
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // 2x3
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 3x2
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4); // 2x2
    }

    /// Recompiling the same graph should hit the cache, not add an entry.
    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Second compile should use cache
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
}