use std::sync::Arc;

use crate::ir::{Graph, Node, NodeId, Op};
use crate::error::{JitError, JitResult};
use crate::cache::FunctionCache;
use crate::optimize::Optimizer;

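/// A compiled computation graph, ready to be executed via [`CompiledFunction::run`].
///
/// The function is either interpreted directly over the IR graph or backed by
/// native code emitted through Cranelift (see `JitCompiler::compile_native`).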
#[derive(Clone)]
pub struct CompiledFunction {
    graph: Arc<Graph>,
    kind: CompiledKind,
}

#[derive(Clone)]
enum CompiledKind {
    Interpreted,
    #[allow(dead_code)]
    Native {
        code_ptr: *const u8,
        code_size: usize,
    },
}

// SAFETY: `code_ptr` points at JIT-compiled code whose backing `JITModule` is
// intentionally leaked in `compile_native`, so the code stays valid and
// immutable for the lifetime of the process.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}

impl CompiledFunction {
    pub fn placeholder() -> Self {
        Self {
            graph: Arc::new(Graph::new()),
            kind: CompiledKind::Interpreted,
        }
    }

    pub fn graph(&self) -> &Graph {
        &self.graph
    }

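    /// Runs the compiled function on the given named input buffers and returns
    /// the flattened output values.
    ///
    /// A minimal usage sketch, mirroring the tests at the bottom of this file
    /// (the `trace` helper comes from `crate::trace`):
    ///
    /// ```ignore
    /// let func = JitCompiler::new().compile(&graph)?;
    /// let a = [1.0f32, 2.0, 3.0, 4.0];
    /// let out = func.run(&[("a", &a)])?;
    /// ```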
    pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        match &self.kind {
            CompiledKind::Interpreted => self.run_interpreted(inputs),
            CompiledKind::Native { code_ptr, code_size } => {
                unsafe {
                    // The native kernel is emitted with the signature
                    // `extern "C" fn(*const f32, *mut f32)`; see `compile_native`.
                    let func: extern "C" fn(*const f32, *mut f32) =
                        std::mem::transmute(*code_ptr);
                    let flat_inputs: Vec<f32> =
                        inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
                    // The native path does not yet track exact output sizes, so
                    // over-allocate a fixed-size buffer per output.
                    let mut output = vec![0.0f32; self.graph.outputs().len() * 1024];
                    func(flat_inputs.as_ptr(), output.as_mut_ptr());
                    let _ = code_size;
                    Ok(output)
                }
            }
        }
    }

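    /// Reference interpreter: evaluates the graph node by node on the CPU and
    /// returns the buffer feeding the first registered output.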
    fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
        let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];

        for (name, data) in inputs {
            if let Some(id) = self.graph.input(name) {
                values[id.index()] = Some(data.to_vec());
            } else {
                return Err(JitError::InputNotFound(name.to_string()));
            }
        }

        for node in self.graph.nodes() {
            let result = self.eval_node(node, &values)?;
            values[node.id.index()] = Some(result);
        }

        if let Some((_, output_id)) = self.graph.outputs().iter().next() {
            let output_node = self.graph.node(*output_id);
            if let Op::Output { input, .. } = &output_node.op {
                return Ok(values[input.index()].clone().unwrap_or_default());
            }
        }

        Err(JitError::OutputNotFound("no output".to_string()))
    }

101
102 fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
103 let get = |id: NodeId| -> JitResult<&Vec<f32>> {
104 values[id.index()]
105 .as_ref()
106 .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
107 };
108
109 match &node.op {
110 Op::Input { .. } => {
111 Ok(values[node.id.index()].clone().unwrap_or_default())
113 }
114
115 Op::Output { input, .. } => {
116 Ok(get(*input)?.clone())
117 }
118
119 Op::Constant { value } => {
120 let numel = node.shape.numel();
121 Ok(vec![*value as f32; numel])
122 }
123
124 Op::Add { lhs, rhs } => {
126 let a = get(*lhs)?;
127 let b = get(*rhs)?;
128 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
129 }
130
131 Op::Sub { lhs, rhs } => {
132 let a = get(*lhs)?;
133 let b = get(*rhs)?;
134 Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
135 }
136
137 Op::Mul { lhs, rhs } => {
138 let a = get(*lhs)?;
139 let b = get(*rhs)?;
140 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
141 }
142
143 Op::Div { lhs, rhs } => {
144 let a = get(*lhs)?;
145 let b = get(*rhs)?;
146 Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
147 }
148
149 Op::Pow { base, exp } => {
150 let a = get(*base)?;
151 let b = get(*exp)?;
152 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
153 }
154
155 Op::Max { lhs, rhs } => {
156 let a = get(*lhs)?;
157 let b = get(*rhs)?;
158 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
159 }
160
161 Op::Min { lhs, rhs } => {
162 let a = get(*lhs)?;
163 let b = get(*rhs)?;
164 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
165 }
166
167 Op::AddScalar { input, scalar } => {
169 let a = get(*input)?;
170 Ok(a.iter().map(|x| x + *scalar as f32).collect())
171 }
172
173 Op::MulScalar { input, scalar } => {
174 let a = get(*input)?;
175 Ok(a.iter().map(|x| x * *scalar as f32).collect())
176 }
177
178 Op::Neg { input } => {
180 let a = get(*input)?;
181 Ok(a.iter().map(|x| -x).collect())
182 }
183
184 Op::Abs { input } => {
185 let a = get(*input)?;
186 Ok(a.iter().map(|x| x.abs()).collect())
187 }
188
189 Op::Sqrt { input } => {
190 let a = get(*input)?;
191 Ok(a.iter().map(|x| x.sqrt()).collect())
192 }
193
194 Op::Exp { input } => {
195 let a = get(*input)?;
196 Ok(a.iter().map(|x| x.exp()).collect())
197 }
198
199 Op::Log { input } => {
200 let a = get(*input)?;
201 Ok(a.iter().map(|x| x.ln()).collect())
202 }
203
204 Op::Sin { input } => {
205 let a = get(*input)?;
206 Ok(a.iter().map(|x| x.sin()).collect())
207 }
208
209 Op::Cos { input } => {
210 let a = get(*input)?;
211 Ok(a.iter().map(|x| x.cos()).collect())
212 }
213
214 Op::Tanh { input } => {
215 let a = get(*input)?;
216 Ok(a.iter().map(|x| x.tanh()).collect())
217 }
218
219 Op::Relu { input } => {
221 let a = get(*input)?;
222 Ok(a.iter().map(|x| x.max(0.0)).collect())
223 }
224
225 Op::Sigmoid { input } => {
226 let a = get(*input)?;
227 Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
228 }
229
230 Op::Gelu { input } => {
231 let a = get(*input)?;
232 const SQRT_2_OVER_PI: f32 = 0.7978845608;
233 Ok(a.iter().map(|x| {
234 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh())
235 }).collect())
236 }
237
238 Op::Silu { input } => {
239 let a = get(*input)?;
240 Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
241 }
242
243 Op::Sum { input } => {
245 let a = get(*input)?;
246 Ok(vec![a.iter().sum()])
247 }
248
249 Op::Mean { input } => {
250 let a = get(*input)?;
251 let sum: f32 = a.iter().sum();
252 Ok(vec![sum / a.len() as f32])
253 }
254
255 Op::SumAxis { input, axis, keepdim } => {
256 let a = get(*input)?;
258 let input_node = self.graph.node(*input);
259 let input_shape = input_node.shape.dims();
260
261 reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
262 }
263
264 Op::MeanAxis { input, axis, keepdim } => {
265 let a = get(*input)?;
266 let input_node = self.graph.node(*input);
267 let input_shape = input_node.shape.dims();
268 let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];
269
270 let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
271 Ok(sum.iter().map(|x| x / axis_size as f32).collect())
272 }
273
274 Op::MaxAxis { input, axis, keepdim } => {
275 let a = get(*input)?;
276 let input_node = self.graph.node(*input);
277 let input_shape = input_node.shape.dims();
278
279 reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
280 }
281
282 Op::Reshape { input, .. } |
284 Op::Transpose { input, .. } |
285 Op::Squeeze { input, .. } |
286 Op::Unsqueeze { input, .. } |
287 Op::Broadcast { input, .. } |
288 Op::Contiguous { input } => {
289 Ok(get(*input)?.clone())
290 }
291
292 Op::MatMul { lhs, rhs } => {
294 let a = get(*lhs)?;
295 let b = get(*rhs)?;
296 let lhs_node = self.graph.node(*lhs);
297 let rhs_node = self.graph.node(*rhs);
298
299 let lhs_shape = lhs_node.shape.dims();
300 let rhs_shape = rhs_node.shape.dims();
301
302 matmul_impl(a, b, lhs_shape, rhs_shape)
303 }
304
305 Op::Gt { lhs, rhs } => {
307 let a = get(*lhs)?;
308 let b = get(*rhs)?;
309 Ok(a.iter().zip(b.iter()).map(|(x, y)| if x > y { 1.0 } else { 0.0 }).collect())
310 }
311
312 Op::Lt { lhs, rhs } => {
313 let a = get(*lhs)?;
314 let b = get(*rhs)?;
315 Ok(a.iter().zip(b.iter()).map(|(x, y)| if x < y { 1.0 } else { 0.0 }).collect())
316 }
317
318 Op::Eq { lhs, rhs } => {
319 let a = get(*lhs)?;
320 let b = get(*rhs)?;
321 Ok(a.iter().zip(b.iter()).map(|(x, y)| if (x - y).abs() < f32::EPSILON { 1.0 } else { 0.0 }).collect())
322 }
323
324 Op::Where { condition, x, y } => {
325 let cond = get(*condition)?;
326 let a = get(*x)?;
327 let b = get(*y)?;
328 Ok(cond.iter().zip(a.iter().zip(b.iter())).map(|(c, (a, b))| {
329 if *c != 0.0 { *a } else { *b }
330 }).collect())
331 }
332
333 Op::Cast { input, .. } => {
334 Ok(get(*input)?.clone())
336 }
337 }
338 }
339}
340
/// Maps a possibly negative axis (counting from the end) to a 0-based index.
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    if axis < 0 {
        (ndim as i32 + axis) as usize
    } else {
        axis as usize
    }
}

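/// Reduces `data` (a dense row-major buffer of the given `shape`) along one
/// axis with the binary `op`, starting from `init`. With `keepdim` the reduced
/// axis is kept as size 1, otherwise it is dropped; e.g. summing a `[2, 3]`
/// buffer over axis 1 yields 2 values (or shape `[2, 1]` with `keepdim`).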
fn reduce_axis(
    data: &[f32],
    shape: &[usize],
    axis: i32,
    keepdim: bool,
    op: fn(f32, f32) -> f32,
    init: f32,
) -> JitResult<Vec<f32>> {
    let axis = normalize_axis(axis, shape.len());

    // Row-major strides of the input shape.
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    // Output shape: the reduced axis is either kept as size 1 or removed.
    let mut output_shape: Vec<usize> = shape.to_vec();
    if keepdim {
        output_shape[axis] = 1;
    } else {
        output_shape.remove(axis);
    }

    let output_numel: usize = output_shape.iter().product();
    let mut result = vec![init; output_numel];

    // Row-major strides of the output shape.
    let mut out_strides = vec![1usize; output_shape.len()];
    for d in (1..output_shape.len()).rev() {
        out_strides[d - 1] = out_strides[d] * output_shape[d];
    }

    for (i, &value) in data.iter().enumerate() {
        // Unravel the flat input index into a multi-dimensional index.
        let mut multi_idx = vec![0usize; shape.len()];
        let mut idx = i;
        for d in 0..shape.len() {
            multi_idx[d] = idx / strides[d];
            idx %= strides[d];
        }

        // Flatten the multi-index into the output, collapsing the reduced axis.
        let mut out_idx = 0;
        if keepdim {
            for d in 0..output_shape.len() {
                let dim_idx = if d == axis { 0 } else { multi_idx[d] };
                out_idx += dim_idx * out_strides[d];
            }
        } else {
            let mut out_d = 0;
            for d in 0..shape.len() {
                if d == axis {
                    continue;
                }
                out_idx += multi_idx[d] * out_strides[out_d];
                out_d += 1;
            }
        }

        if out_idx < result.len() {
            result[out_idx] = op(result[out_idx], value);
        }
    }

    Ok(result)
}

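/// Naive dense matrix multiply for the interpreter: `a` is `[m, k]`, `b` is
/// `[k, n]`, both row-major, and the result is `[m, n]`. Only rank-2 inputs
/// are supported here.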
fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
    if a_shape.len() != 2 || b_shape.len() != 2 {
        return Err(JitError::UnsupportedOp("Only 2D matmul supported in interpreter".to_string()));
    }

    let m = a_shape[0];
    let k = a_shape[1];
    let n = b_shape[1];

    if k != b_shape[0] {
        return Err(JitError::ShapeMismatch {
            expected: vec![k],
            found: vec![b_shape[0]],
        });
    }

    let mut result = vec![0.0f32; m * n];

    for i in 0..m {
        for j in 0..n {
            let mut sum = 0.0;
            for p in 0..k {
                sum += a[i * k + p] * b[p * n + j];
            }
            result[i * n + j] = sum;
        }
    }

    Ok(result)
}

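/// Compiles IR graphs into [`CompiledFunction`]s, running the optimizer first
/// and memoizing results in a [`FunctionCache`].
///
/// A minimal usage sketch, assuming the `trace` API exercised by the tests at
/// the bottom of this file:
///
/// ```ignore
/// let graph = trace(|tracer| {
///     let x = tracer.input("x", &[4]);
///     tracer.output("y", x.relu())
/// });
/// let compiler = JitCompiler::new();
/// let func = compiler.compile(&graph)?;
/// let x = [-1.0f32, 0.0, 1.0, 2.0];
/// let y = func.run(&[("x", &x)])?;
/// ```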
pub struct JitCompiler {
    optimizer: Optimizer,
    cache: FunctionCache,
    use_native: bool,
}

impl JitCompiler {
    pub fn new() -> Self {
        Self {
            optimizer: Optimizer::default_passes(),
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

    pub fn with_optimizer(optimizer: Optimizer) -> Self {
        Self {
            optimizer,
            cache: FunctionCache::default_size(),
            use_native: false,
        }
    }

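    /// Compiles a graph, returning a cached function when an identical graph
    /// has been compiled before. The graph is validated, optimized, and then
    /// lowered either to native code or to the interpreter backend.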
    pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        let cache_key = FunctionCache::hash_graph(graph);
        if let Some(cached) = self.cache.get(cache_key) {
            return Ok(cached);
        }

        graph.validate().map_err(JitError::InvalidGraph)?;

        let optimized = self.optimizer.optimize(graph.clone());

        let func = if self.use_native {
            self.compile_native(&optimized)?
        } else {
            self.compile_interpreted(&optimized)
        };

        self.cache.insert(cache_key, func.clone());

        Ok(func)
    }

    fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
        CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Interpreted,
        }
    }

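    /// Lowers the graph to native code with Cranelift: build an ISA for the
    /// host, declare a `jit_kernel(*const f32, *mut f32)` function, emit one
    /// Cranelift value per node, store the final value to the output pointer,
    /// then finalize the module and keep its code alive by leaking it.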
    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Signature: jit_kernel(input_ptr: *const f32, output_ptr: *mut f32).
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64)); // input pointer
        sig.params.push(AbiParam::new(types::I64)); // output pointer

        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| JitError::CompilationFailed(format!("Failed to declare function: {}", e)))?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        let mut builder_ctx = FunctionBuilderContext::new();
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Store the value feeding the first registered output.
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        module
            .define_function(func_id, &mut ctx)
            .map_err(|e| JitError::CompilationFailed(format!("Failed to define function: {}", e)))?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        // The emitted code size is not tracked yet.
        let code_size = 0;
        // Leak the module so the finalized machine code stays valid for the
        // lifetime of the returned function.
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr: code_ptr as *const u8,
                code_size,
            },
        })
    }

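    /// Emits Cranelift IR for a single node. Only scalar f32 arithmetic and a
    /// few unary ops are handled; anything else returns an `UnsupportedOp`
    /// error, which `compile` propagates to the caller.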
    fn codegen_node(
        &self,
        builder: &mut cranelift::prelude::FunctionBuilder,
        node: &Node,
        values: &[Option<cranelift::prelude::Value>],
        input_ptr: cranelift::prelude::Value,
    ) -> JitResult<cranelift::prelude::Value> {
        use cranelift::prelude::*;

        let get = |id: NodeId| -> JitResult<Value> {
            values[id.index()]
                .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
        };

        match &node.op {
            Op::Input { name, .. } => {
                let offset = self.get_input_offset(name);
                Ok(builder.ins().load(types::F32, MemFlags::new(), input_ptr, offset))
            }

            Op::Output { input, .. } => get(*input),

            Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),

            Op::Add { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fadd(a, b))
            }

            Op::Sub { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fsub(a, b))
            }

            Op::Mul { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fmul(a, b))
            }

            Op::Div { lhs, rhs } => {
                let a = get(*lhs)?;
                let b = get(*rhs)?;
                Ok(builder.ins().fdiv(a, b))
            }

            Op::Neg { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fneg(a))
            }

            Op::Abs { input } => {
                let a = get(*input)?;
                Ok(builder.ins().fabs(a))
            }

            Op::Sqrt { input } => {
                let a = get(*input)?;
                Ok(builder.ins().sqrt(a))
            }

            Op::AddScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fadd(a, s))
            }

            Op::MulScalar { input, scalar } => {
                let a = get(*input)?;
                let s = builder.ins().f32const(*scalar as f32);
                Ok(builder.ins().fmul(a, s))
            }

            _ => Err(JitError::UnsupportedOp(format!(
                "Operation {:?} not supported in native codegen",
                node.op
            ))),
        }
    }

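    /// Byte offset of a named input within the flattened input buffer. This is
    /// still a placeholder: every input currently maps to offset 0.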
    fn get_input_offset(&self, _name: &str) -> i32 {
        0
    }

    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }

    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        assert!((result[0] - 0.5).abs() < 0.01);
        assert!((result[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // a = [[1, 0, 0], [0, 1, 0]] (2x3) and b = [[1, 0], [0, 1], [0, 0]] (3x2),
        // so the product is the 2x2 identity.
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Compiling the same graph again should hit the cache, not add an entry.
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
}