//! JIT compilation of traced computation graphs: a reference interpreter
//! plus an optional Cranelift-based native backend, fronted by a function
//! cache keyed on graph hashes.

use std::sync::Arc;

use crate::cache::FunctionCache;
use crate::error::{JitError, JitResult};
use crate::ir::{Graph, Node, NodeId, Op};
use crate::optimize::Optimizer;

/// A compiled computation graph, ready to execute.
#[derive(Clone)]
pub struct CompiledFunction {
    graph: Arc<Graph>,
    kind: CompiledKind,
}

#[derive(Clone)]
enum CompiledKind {
    /// Walk the graph node by node in Rust.
    Interpreted,
    /// Jump into Cranelift-generated machine code.
    Native {
        code_ptr: *const u8,
        code_size: usize,
    },
}

// SAFETY: `code_ptr` points into executable memory that is intentionally
// leaked by `compile_native`, so it remains valid for the life of the
// process and is never written to after finalization.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    /// An empty, interpreted function; useful as a default value.
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// The (optimized) graph this function was compiled from.
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Execute the function against named input buffers, returning the
    /// flattened output buffer.
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native {
                code_ptr,
                code_size,
            } => {
                // SAFETY: `code_ptr` was produced by `compile_native` for
                // exactly this signature, and the memory behind it is leaked
                // at compile time, so it is still valid here.
                unsafe {
                    let func: extern "C" fn(*const f32, *mut f32) =
                        std::mem::transmute(*code_ptr);
                    let flat_inputs: Vec<f32> =
                        inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    let output_size: usize = self
                        .graph
                        .outputs()
                        .values()
                        .map(|id| self.graph.node(*id).shape.numel())
                        .sum();
                    let mut output = vec![0.0f32; output_size];
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    // The code size is kept only as metadata.
                    let _ = code_size;
                    Ok(output)
                }
            }
        }
    }

    /// Evaluate the graph node by node, buffering intermediate values.
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        // Only the first declared output is returned for now.
        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

    /// Compute a single node from its already-computed operands.
    fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
        let get = |id: NodeId| -> JitResult<&Vec<f32>> {
            values[id.index()]
                .as_ref()
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
        };

        match &node.op {
            Op::Input { .. } => {
                Ok(values[node.id.index()].clone().unwrap_or_default())
            }

            Op::Output { input, .. } => Ok(get(*input)?.clone()),

            Op::Constant { value } => {
                let numel = node.shape.numel();
                Ok(vec![*value as f32; numel])
            }

            // Elementwise binary ops; operands are assumed to have been
            // broadcast to the same length upstream.
            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
            }

            Op::Pow { base, exp } => {
                let a = get(*base)?;
                let b = get(*exp)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
            }

            Op::Max { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
            }

            Op::Min { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x + *scalar as f32).collect())
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x * *scalar as f32).collect())
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| -x).collect())
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.abs()).collect())
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sqrt()).collect())
            }

            Op::Exp { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.exp()).collect())
            }

            Op::Log { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.ln()).collect())
            }

            Op::Sin { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.sin()).collect())
            }

            Op::Cos { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.cos()).collect())
            }

            Op::Tanh { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.tanh()).collect())
            }

            Op::Relu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x.max(0.0)).collect())
            }

            Op::Sigmoid { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
            }

            // Tanh approximation of GELU.
            Op::Gelu { input } => {
                let a = get(*input)?;
                const SQRT_2_OVER_PI: f32 = 0.797_884_6;
                Ok(a.iter()
                    .map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
                    .collect())
            }

            Op::Silu { input } => {
                let a = get(*input)?;
                Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
            }

            Op::Sum { input } => {
                let a = get(*input)?;
                Ok(vec![a.iter().sum()])
            }

            Op::Mean { input } => {
                let a = get(*input)?;
                let sum: f32 = a.iter().sum();
                Ok(vec![sum / a.len() as f32])
            }

            Op::SumAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
            }

            Op::MeanAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();
                let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];

                let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
                Ok(sum.iter().map(|x| x / axis_size as f32).collect())
            }

            Op::MaxAxis {
                input,
                axis,
                keepdim,
            } => {
                let a = get(*input)?;
                let input_node = self.graph.node(*input);
                let input_shape = input_node.shape.dims();

                reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
            }

            // Pure view changes: the flat buffer is unchanged.
            Op::Reshape { input, .. }
            | Op::Squeeze { input, .. }
            | Op::Unsqueeze { input, .. }
            | Op::Broadcast { input, .. }
            | Op::Contiguous { input } => Ok(get(*input)?.clone()),

            Op::Transpose { input, dim0, dim1 } => {
                let a = get(*input)?;
                let input_shape = &self.graph.node(*input).shape;
                let ndim = input_shape.dims().len();
                if ndim < 2 || *dim0 >= ndim || *dim1 >= ndim || dim0 == dim1 {
                    return Ok(a.clone());
                }
                let dims = input_shape.dims();
                let mut perm: Vec<usize> = (0..ndim).collect();
                perm.swap(*dim0, *dim1);
                let new_shape: Vec<usize> = perm.iter().map(|&d| dims[d]).collect();
                let numel: usize = dims.iter().product();
                let mut result = vec![0.0f32; numel];

                // Row-major strides for the input and output layouts.
                let mut in_strides = vec![1usize; ndim];
                for d in (0..ndim - 1).rev() {
                    in_strides[d] = in_strides[d + 1] * dims[d + 1];
                }
                let mut out_strides = vec![1usize; ndim];
                for d in (0..ndim - 1).rev() {
                    out_strides[d] = out_strides[d + 1] * new_shape[d + 1];
                }

                // For each output position, gather from the permuted input index.
                #[allow(clippy::needless_range_loop)]
                for flat in 0..numel {
                    let mut remaining = flat;
                    let mut out_idx = vec![0usize; ndim];
                    for d in 0..ndim {
                        out_idx[d] = remaining / out_strides[d];
                        remaining %= out_strides[d];
                    }
                    let mut in_flat = 0;
                    for d in 0..ndim {
                        in_flat += out_idx[d] * in_strides[perm[d]];
                    }
                    result[flat] = a[in_flat];
                }
                Ok(result)
            }

            Op::MatMul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                let lhs_node = self.graph.node(*lhs);
                let rhs_node = self.graph.node(*rhs);

                let lhs_shape = lhs_node.shape.dims();
                let rhs_shape = rhs_node.shape.dims();

                matmul_impl(a, b, lhs_shape, rhs_shape)
            }

            // Comparisons produce 1.0 / 0.0 masks.
            Op::Gt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
                    .collect())
            }

            Op::Lt { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
                    .collect())
            }

            // Approximate float equality, within machine epsilon.
            Op::Eq { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(a.iter()
                    .zip(b.iter())
                    .map(|(x, y)| {
                        if (x - y).abs() < f32::EPSILON {
                            1.0
                        } else {
                            0.0
                        }
                    })
                    .collect())
            }

            Op::Where { condition, x, y } => {
                let cond = get(*condition)?;
                let a = get(*x)?;
                let b = get(*y)?;
                Ok(cond
                    .iter()
                    .zip(a.iter().zip(b.iter()))
                    .map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
                    .collect())
            }

            // Everything is f32 in the interpreter, so casts are no-ops.
            Op::Cast { input, .. } => Ok(get(*input)?.clone()),
        }
    }
}

/// Map a possibly negative axis (e.g. -1 for the last dimension) to a
/// concrete dimension index: `normalize_axis(-1, 3) == 2`.
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    if axis < 0 {
        (ndim as i32 + axis) as usize
    } else {
        axis as usize
    }
}

/// Reduce `data` (row-major, shape `shape`) along `axis` with the binary
/// `op`, starting from `init`. With `keepdim` the reduced axis is kept as
/// size 1; otherwise it is removed from the output shape.
fn reduce_axis(
    data: &[f32],
    shape: &[usize],
    axis: i32,
    keepdim: bool,
    op: fn(f32, f32) -> f32,
    init: f32,
) -> JitResult<Vec<f32>> {
    let axis = normalize_axis(axis, shape.len());

    // Row-major strides of the input shape.
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut output_shape: Vec<usize> = shape.to_vec();
    if keepdim {
        output_shape[axis] = 1;
    } else {
        output_shape.remove(axis);
    }

    let output_numel: usize = output_shape.iter().product();
    let mut result = vec![init; output_numel];

    // Scatter every input element into its output slot, folding with `op`.
    for (i, &val) in data.iter().enumerate() {
        let mut multi_idx = vec![0usize; shape.len()];
        let mut idx = i;
        for (d, &st) in strides.iter().enumerate() {
            multi_idx[d] = idx / st;
            idx %= st;
        }

        let out_idx = if keepdim {
            // Same rank: zero out the reduced coordinate.
            let mut out_idx = 0;
            let mut temp_strides = vec![1usize; output_shape.len()];
            for d in (0..output_shape.len() - 1).rev() {
                temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
            }
            for d in 0..output_shape.len() {
                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
                out_idx += dim_idx * temp_strides[d];
            }
            out_idx
        } else {
            // Reduced rank: skip the reduced coordinate entirely.
            let mut out_idx = 0;
            let mut temp_strides = vec![1usize; output_shape.len()];
            if !output_shape.is_empty() {
                for d in (0..output_shape.len() - 1).rev() {
                    temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
                }
            }
            let mut out_d = 0;
            for (d, &mi) in multi_idx.iter().enumerate().take(shape.len()) {
                if d == axis {
                    continue;
                }
                if out_d < temp_strides.len() {
                    out_idx += mi * temp_strides[out_d];
                }
                out_d += 1;
            }
            out_idx
        };

        if out_idx < result.len() {
            result[out_idx] = op(result[out_idx], val);
        }
    }

    Ok(result)
}
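
// A worked example of `reduce_axis` (see also `test_reduce_axis_sum` in the
// tests below): summing a row-major [2, 3] buffer over axis -1 without
// keepdim folds each row, mapping [1, 2, 3, 4, 5, 6] to [6, 15]; with
// keepdim the output shape is [2, 1] instead of [2] but holds the same data.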

/// Naive O(m * n * k) row-major matrix multiply for the interpreter path.
fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(JitError::UnsupportedOp(
            "Only 2D matmul supported in interpreter".to_string(),
        ));
    }

    let m = a_shape[0];
    let k = a_shape[1];
    let n = b_shape[1];

    if k != b_shape[0] {
        return Err(JitError::ShapeMismatch {
            expected: vec![k],
            found: vec![b_shape[0]],
        });
    }

    let mut result = vec![0.0f32; m * n];

    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            result[i * n + j] = sum;
        }
    }

    Ok(result)
}

/// Compiles traced graphs into `CompiledFunction`s, running optimization
/// passes and caching results by graph hash.
pub struct JitCompiler {
    optimizer: Optimizer,
    cache: FunctionCache,
    use_native: bool,
}

impl JitCompiler {
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            // Native codegen is opt-in; the interpreter is the default.
            use_native: false,
        }
    }

    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    /// Toggle Cranelift native code generation (off by default).
    pub fn enable_native(&mut self, enable: bool) {
        self.use_native = enable;
    }

    /// Compile a graph: check the cache, validate, optimize, then build
    /// either a native or an interpreted function.
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        graph.validate().map_err(JitError::InvalidGraph)?;

        let optimized = self.optimizer.optimize(graph.clone());

        let func = if self.use_native {
            self.compile_native(&optimized)?
        } else {
            self.compile_interpreted(&optimized)
        };

        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

    /// Lower the graph to machine code with Cranelift. Only scalar,
    /// single-output graphs are handled by this backend so far.
    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Signature: extern "C" fn(input_ptr: *const f32, output_ptr: *mut f32).
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input pointer
        sig.params.push(AbiParam::new(types::I64)); // output pointer

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| {
                JitError::CompilationFailed(format!("Failed to declare function: {}", e))
            })?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // One SSA value per graph node, filled in topological order.
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store the first output's scalar value to the output buffer.
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        module.define_function(func_id, &mut ctx).map_err(|e| {
            JitError::CompilationFailed(format!("Failed to define function: {}", e))
        })?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        // Cranelift does not hand back the code size here; record zero.
        let code_size = 0;
        // Intentionally leak the module so the generated code outlives this
        // function; dropping it would free the memory behind `code_ptr`.
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr,
                code_size,
            },
        })
    }

    /// Emit Cranelift IR for one node. Scalar (single-element) values only.
    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                let offset = self.get_input_offset(name);
                Ok(builder
                    .ins()
                    .load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            // There is no interpreter fallback once native compilation has
            // started, so unsupported ops fail the whole compile.
            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen",
                node.op
            ))),
        }
    }


    /// Byte offset of a named input within the flat input buffer.
    ///
    /// Placeholder: every input currently maps to offset 0, so native code
    /// is only correct for graphs with a single scalar input.
    fn get_input_offset(&self, _name: &str) -> i32 {
        0
    }

    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    /// Drop all cached compiled functions.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}


impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

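// Typical end-to-end usage, mirroring the tests below (a sketch; `trace`
// comes from `crate::trace`):
//
//     let graph = trace(|t| {
//         let x = t.input("x", &[4]);
//         t.output("y", x.relu())
//     });
//     let func = JitCompiler::new().compile(&graph)?;
//     let y = func.run(&[("x", &[1.0, -2.0, 3.0, -4.0])])?;
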
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // relu -> *2 -> +1: [-1, 0, 1, 2] becomes [1, 1, 3, 5].
        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5, sigmoid(1) ~ 0.731.
        assert!((result[0] - 0.5).abs() < 0.01);
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]; // 2x3
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 3x2
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4); // 2x2 output
    }

    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Compiling the same graph again should hit the cache.
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
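
    // Added sketches exercising the interpreter helpers in this module
    // directly (they are private, so this is the natural place to test them).
    #[test]
    fn test_reduce_axis_sum() {
        // Sum a row-major [2, 3] buffer over its last axis, dropping it.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(normalize_axis(-1, 2), 1);
        let out = reduce_axis(&data, &[2, 3], -1, false, |x, y| x + y, 0.0).unwrap();
        assert_eq!(out, vec![6.0, 15.0]);
    }

    #[test]
    fn test_matmul_impl_direct() {
        // (2x3) x (3x2): this `b` picks out the first two columns of `a`.
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let c = matmul_impl(&a, &b, &[2, 3], &[3, 2]).unwrap();
        assert_eq!(c, vec![1.0, 2.0, 4.0, 5.0]);
    }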
}