1use std::sync::Arc;
6
7use crate::cache::FunctionCache;
8use crate::error::{JitError, JitResult};
9use crate::ir::{Graph, Node, NodeId, Op};
10use crate::optimize::Optimizer;
11
/// A compiled computation graph, executed either by the built-in interpreter or by a
/// natively generated kernel (see `CompiledKind`).
///
/// Cloning shares the graph through its `Arc` instead of copying it.
#[derive(Clone)]
pub struct CompiledFunction {
    /// The optimized graph this function evaluates, shared between clones.
    graph: Arc<Graph>,
    /// Which backend executes the graph.
    kind: CompiledKind,
}
20
/// Execution backend of a [`CompiledFunction`].
#[derive(Clone)]
enum CompiledKind {
    /// Evaluate the graph node by node with the Rust interpreter.
    Interpreted,
    /// Machine code generated by Cranelift.
    // `code_size` is never read yet (`compile_native` always stores 0), hence the allow.
    #[allow(dead_code)]
    Native {
        /// Finalized entry point of the generated kernel.
        code_ptr: *const u8,
        /// Size of the generated code in bytes (placeholder; not tracked yet).
        code_size: usize,
    },
}

// SAFETY: `code_ptr` points at finalized, never-mutated machine code whose JIT module is
// intentionally leaked (`std::mem::forget`) by `JitCompiler::compile_native`, so the
// pointer stays valid for the life of the process and may be moved to another thread.
unsafe impl Send for CompiledKind {}
// SAFETY: shared access only ever reads the pointer value; the code it points to is
// immutable after finalization, so concurrent calls through `&CompiledKind` are sound.
unsafe impl Sync for CompiledKind {}
38
39impl CompiledFunction {
40 pub fn placeholder() -> Self {
42 Self {
43 graph: Arc::new(Graph::new()),
44 kind: CompiledKind::Interpreted,
45 }
46 }
47
48 pub fn graph(&self) -> &Graph {
50 &self.graph
51 }
52
53 pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
55 match &self.kind {
56 CompiledKind::Interpreted => self.run_interpreted(inputs),
57 CompiledKind::Native {
58 code_ptr,
59 code_size,
60 } => {
61 unsafe {
64 let func: extern "C" fn(*const f32, *mut f32) = std::mem::transmute(code_ptr);
65 let flat_inputs: Vec<f32> =
66 inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
67 let mut output = vec![0.0f32; self.graph.outputs().len() * 1024]; func(flat_inputs.as_ptr(), output.as_mut_ptr());
69 let _ = code_size; Ok(output)
71 }
72 }
73 }
74 }
75
76 fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
78 let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];
79
80 for (name, data) in inputs {
82 if let Some(id) = self.graph.input(name) {
83 values[id.index()] = Some(data.to_vec());
84 } else {
85 return Err(JitError::InputNotFound(name.to_string()));
86 }
87 }
88
89 for node in self.graph.nodes() {
91 let result = self.eval_node(node, &values)?;
92 values[node.id.index()] = Some(result);
93 }
94
95 if let Some((_, output_id)) = self.graph.outputs().iter().next() {
97 let output_node = self.graph.node(*output_id);
98 if let Op::Output { input, .. } = &output_node.op {
99 return Ok(values[input.index()].clone().unwrap_or_default());
100 }
101 }
102
103 Err(JitError::OutputNotFound("no output".to_string()))
104 }
105
106 fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
107 let get = |id: NodeId| -> JitResult<&Vec<f32>> {
108 values[id.index()]
109 .as_ref()
110 .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
111 };
112
113 match &node.op {
114 Op::Input { .. } => {
115 Ok(values[node.id.index()].clone().unwrap_or_default())
117 }
118
119 Op::Output { input, .. } => Ok(get(*input)?.clone()),
120
121 Op::Constant { value } => {
122 let numel = node.shape.numel();
123 Ok(vec![*value as f32; numel])
124 }
125
126 Op::Add { lhs, rhs } => {
128 let a = get(*lhs)?;
129 let b = get(*rhs)?;
130 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
131 }
132
133 Op::Sub { lhs, rhs } => {
134 let a = get(*lhs)?;
135 let b = get(*rhs)?;
136 Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
137 }
138
139 Op::Mul { lhs, rhs } => {
140 let a = get(*lhs)?;
141 let b = get(*rhs)?;
142 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
143 }
144
145 Op::Div { lhs, rhs } => {
146 let a = get(*lhs)?;
147 let b = get(*rhs)?;
148 Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
149 }
150
151 Op::Pow { base, exp } => {
152 let a = get(*base)?;
153 let b = get(*exp)?;
154 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
155 }
156
157 Op::Max { lhs, rhs } => {
158 let a = get(*lhs)?;
159 let b = get(*rhs)?;
160 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
161 }
162
163 Op::Min { lhs, rhs } => {
164 let a = get(*lhs)?;
165 let b = get(*rhs)?;
166 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
167 }
168
169 Op::AddScalar { input, scalar } => {
171 let a = get(*input)?;
172 Ok(a.iter().map(|x| x + *scalar as f32).collect())
173 }
174
175 Op::MulScalar { input, scalar } => {
176 let a = get(*input)?;
177 Ok(a.iter().map(|x| x * *scalar as f32).collect())
178 }
179
180 Op::Neg { input } => {
182 let a = get(*input)?;
183 Ok(a.iter().map(|x| -x).collect())
184 }
185
186 Op::Abs { input } => {
187 let a = get(*input)?;
188 Ok(a.iter().map(|x| x.abs()).collect())
189 }
190
191 Op::Sqrt { input } => {
192 let a = get(*input)?;
193 Ok(a.iter().map(|x| x.sqrt()).collect())
194 }
195
196 Op::Exp { input } => {
197 let a = get(*input)?;
198 Ok(a.iter().map(|x| x.exp()).collect())
199 }
200
201 Op::Log { input } => {
202 let a = get(*input)?;
203 Ok(a.iter().map(|x| x.ln()).collect())
204 }
205
206 Op::Sin { input } => {
207 let a = get(*input)?;
208 Ok(a.iter().map(|x| x.sin()).collect())
209 }
210
211 Op::Cos { input } => {
212 let a = get(*input)?;
213 Ok(a.iter().map(|x| x.cos()).collect())
214 }
215
216 Op::Tanh { input } => {
217 let a = get(*input)?;
218 Ok(a.iter().map(|x| x.tanh()).collect())
219 }
220
221 Op::Relu { input } => {
223 let a = get(*input)?;
224 Ok(a.iter().map(|x| x.max(0.0)).collect())
225 }
226
227 Op::Sigmoid { input } => {
228 let a = get(*input)?;
229 Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
230 }
231
232 Op::Gelu { input } => {
233 let a = get(*input)?;
234 const SQRT_2_OVER_PI: f32 = 0.7978845608;
235 Ok(a.iter()
236 .map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
237 .collect())
238 }
239
240 Op::Silu { input } => {
241 let a = get(*input)?;
242 Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
243 }
244
245 Op::Sum { input } => {
247 let a = get(*input)?;
248 Ok(vec![a.iter().sum()])
249 }
250
251 Op::Mean { input } => {
252 let a = get(*input)?;
253 let sum: f32 = a.iter().sum();
254 Ok(vec![sum / a.len() as f32])
255 }
256
257 Op::SumAxis {
258 input,
259 axis,
260 keepdim,
261 } => {
262 let a = get(*input)?;
264 let input_node = self.graph.node(*input);
265 let input_shape = input_node.shape.dims();
266
267 reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
268 }
269
270 Op::MeanAxis {
271 input,
272 axis,
273 keepdim,
274 } => {
275 let a = get(*input)?;
276 let input_node = self.graph.node(*input);
277 let input_shape = input_node.shape.dims();
278 let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];
279
280 let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
281 Ok(sum.iter().map(|x| x / axis_size as f32).collect())
282 }
283
284 Op::MaxAxis {
285 input,
286 axis,
287 keepdim,
288 } => {
289 let a = get(*input)?;
290 let input_node = self.graph.node(*input);
291 let input_shape = input_node.shape.dims();
292
293 reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
294 }
295
296 Op::Reshape { input, .. }
298 | Op::Transpose { input, .. }
299 | Op::Squeeze { input, .. }
300 | Op::Unsqueeze { input, .. }
301 | Op::Broadcast { input, .. }
302 | Op::Contiguous { input } => Ok(get(*input)?.clone()),
303
304 Op::MatMul { lhs, rhs } => {
306 let a = get(*lhs)?;
307 let b = get(*rhs)?;
308 let lhs_node = self.graph.node(*lhs);
309 let rhs_node = self.graph.node(*rhs);
310
311 let lhs_shape = lhs_node.shape.dims();
312 let rhs_shape = rhs_node.shape.dims();
313
314 matmul_impl(a, b, lhs_shape, rhs_shape)
315 }
316
317 Op::Gt { lhs, rhs } => {
319 let a = get(*lhs)?;
320 let b = get(*rhs)?;
321 Ok(a.iter()
322 .zip(b.iter())
323 .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
324 .collect())
325 }
326
327 Op::Lt { lhs, rhs } => {
328 let a = get(*lhs)?;
329 let b = get(*rhs)?;
330 Ok(a.iter()
331 .zip(b.iter())
332 .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
333 .collect())
334 }
335
336 Op::Eq { lhs, rhs } => {
337 let a = get(*lhs)?;
338 let b = get(*rhs)?;
339 Ok(a.iter()
340 .zip(b.iter())
341 .map(|(x, y)| {
342 if (x - y).abs() < f32::EPSILON {
343 1.0
344 } else {
345 0.0
346 }
347 })
348 .collect())
349 }
350
351 Op::Where { condition, x, y } => {
352 let cond = get(*condition)?;
353 let a = get(*x)?;
354 let b = get(*y)?;
355 Ok(cond
356 .iter()
357 .zip(a.iter().zip(b.iter()))
358 .map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
359 .collect())
360 }
361
362 Op::Cast { input, .. } => {
363 Ok(get(*input)?.clone())
365 }
366 }
367 }
368}
369
/// Converts a possibly negative axis (`-1` = last dimension) into a non-negative index for
/// a tensor of rank `ndim`.
///
/// No range check is performed: an axis below `-ndim` wraps to a huge `usize` through the
/// `as` cast, and an axis `>= ndim` is returned unchanged, so callers must bounds-check.
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    match axis {
        non_negative if non_negative >= 0 => non_negative as usize,
        negative => (negative + ndim as i32) as usize,
    }
}
377
378fn reduce_axis(
379 data: &[f32],
380 shape: &[usize],
381 axis: i32,
382 keepdim: bool,
383 op: fn(f32, f32) -> f32,
384 init: f32,
385) -> JitResult<Vec<f32>> {
386 let axis = normalize_axis(axis, shape.len());
387
388 let mut strides = vec![1usize; shape.len()];
390 for i in (0..shape.len() - 1).rev() {
391 strides[i] = strides[i + 1] * shape[i + 1];
392 }
393
394 let mut output_shape: Vec<usize> = shape.to_vec();
396 if keepdim {
397 output_shape[axis] = 1;
398 } else {
399 output_shape.remove(axis);
400 }
401
402 let output_numel: usize = output_shape.iter().product();
403 let mut result = vec![init; output_numel];
404
405 for i in 0..data.len() {
407 let mut multi_idx = vec![0usize; shape.len()];
409 let mut idx = i;
410 for d in 0..shape.len() {
411 multi_idx[d] = idx / strides[d];
412 idx %= strides[d];
413 }
414
415 let out_idx = if keepdim {
417 let mut out_idx = 0;
418 let mut temp_strides = vec![1usize; output_shape.len()];
419 for d in (0..output_shape.len() - 1).rev() {
420 temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
421 }
422 for d in 0..output_shape.len() {
423 let dim_idx = if d == axis { 0 } else { multi_idx[d] };
424 out_idx += dim_idx * temp_strides[d];
425 }
426 out_idx
427 } else {
428 let mut out_idx = 0;
429 let mut temp_strides = vec![1usize; output_shape.len()];
430 if !output_shape.is_empty() {
431 for d in (0..output_shape.len() - 1).rev() {
432 temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
433 }
434 }
435 let mut out_d = 0;
436 for d in 0..shape.len() {
437 if d == axis {
438 continue;
439 }
440 if out_d < temp_strides.len() {
441 out_idx += multi_idx[d] * temp_strides[out_d];
442 }
443 out_d += 1;
444 }
445 out_idx
446 };
447
448 if out_idx < result.len() {
449 result[out_idx] = op(result[out_idx], data[i]);
450 }
451 }
452
453 Ok(result)
454}
455
456fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
457 if a_shape.len() != 2 || b_shape.len() != 2 {
459 return Err(JitError::UnsupportedOp(
460 "Only 2D matmul supported in interpreter".to_string(),
461 ));
462 }
463
464 let m = a_shape[0];
465 let k = a_shape[1];
466 let n = b_shape[1];
467
468 if k != b_shape[0] {
469 return Err(JitError::ShapeMismatch {
470 expected: vec![k],
471 found: vec![b_shape[0]],
472 });
473 }
474
475 let mut result = vec![0.0f32; m * n];
476
477 for i in 0..m {
478 for j in 0..n {
479 let mut sum = 0.0;
480 for p in 0..k {
481 sum += a[i * k + p] * b[p * n + j];
482 }
483 result[i * n + j] = sum;
484 }
485 }
486
487 Ok(result)
488}
489
/// Compiles computation graphs into [`CompiledFunction`]s: validates, optimizes, lowers,
/// and caches the result keyed by `FunctionCache::hash_graph`.
pub struct JitCompiler {
    /// Optimization passes applied to every graph before lowering.
    optimizer: Optimizer,
    /// Previously compiled functions, keyed by graph hash.
    cache: FunctionCache,
    /// Whether to lower through Cranelift; both constructors in this file set it to `false`.
    use_native: bool,
}
496
497impl JitCompiler {
498 pub fn new() -> Self {
500 Self {
501 optimizer: Optimizer::default_passes(),
502 cache: FunctionCache::default_size(),
503 use_native: false, }
505 }
506
507 pub fn with_optimizer(optimizer: Optimizer) -> Self {
509 Self {
510 optimizer,
511 cache: FunctionCache::default_size(),
512 use_native: false,
513 }
514 }
515
516 pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
518 let cache_key = FunctionCache::hash_graph(graph);
520 if let Some(cached) = self.cache.get(cache_key) {
521 return Ok(cached);
522 }
523
524 graph.validate().map_err(JitError::InvalidGraph)?;
526
527 let optimized = self.optimizer.optimize(graph.clone());
529
530 let func = if self.use_native {
532 self.compile_native(&optimized)?
533 } else {
534 self.compile_interpreted(&optimized)
535 };
536
537 self.cache.insert(cache_key, func.clone());
539
540 Ok(func)
541 }
542
543 fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
544 CompiledFunction {
545 graph: Arc::new(graph.clone()),
546 kind: CompiledKind::Interpreted,
547 }
548 }
549
550 fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
551 use cranelift::prelude::*;
552 use cranelift_jit::{JITBuilder, JITModule};
553 use cranelift_module::{Linkage, Module};
554
555 let mut flag_builder = settings::builder();
557 flag_builder.set("use_colocated_libcalls", "false").unwrap();
558 flag_builder.set("is_pic", "false").unwrap();
559 let isa_builder = cranelift_native::builder()
560 .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
561 let isa = isa_builder
562 .finish(settings::Flags::new(flag_builder))
563 .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;
564
565 let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
566 let mut module = JITModule::new(builder);
567
568 let mut sig = module.make_signature();
570 sig.params.push(AbiParam::new(types::I64)); sig.params.push(AbiParam::new(types::I64)); let func_id = module
574 .declare_function("jit_kernel", Linkage::Export, &sig)
575 .map_err(|e| {
576 JitError::CompilationFailed(format!("Failed to declare function: {}", e))
577 })?;
578
579 let mut ctx = module.make_context();
580 ctx.func.signature = sig;
581
582 let mut builder_ctx = FunctionBuilderContext::new();
584 {
585 let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
586 let entry_block = builder.create_block();
587 builder.append_block_params_for_function_params(entry_block);
588 builder.switch_to_block(entry_block);
589 builder.seal_block(entry_block);
590
591 let input_ptr = builder.block_params(entry_block)[0];
592 let output_ptr = builder.block_params(entry_block)[1];
593
594 let mut values: Vec<Option<Value>> = vec![None; graph.len()];
596
597 for node in graph.nodes() {
598 let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
599 values[node.id.index()] = Some(result);
600 }
601
602 if let Some((_, output_id)) = graph.outputs().iter().next() {
604 let output_node = graph.node(*output_id);
605 if let Op::Output { input, .. } = &output_node.op {
606 if let Some(val) = values[input.index()] {
607 builder.ins().store(MemFlags::new(), val, output_ptr, 0);
608 }
609 }
610 }
611
612 builder.ins().return_(&[]);
613 builder.finalize();
614 }
615
616 module.define_function(func_id, &mut ctx).map_err(|e| {
618 JitError::CompilationFailed(format!("Failed to define function: {}", e))
619 })?;
620 module.clear_context(&mut ctx);
621 module
622 .finalize_definitions()
623 .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;
624
625 let code_ptr = module.get_finalized_function(func_id);
626 let code_size = 0; std::mem::forget(module);
630
631 Ok(CompiledFunction {
632 graph: Arc::new(graph.clone()),
633 kind: CompiledKind::Native {
634 code_ptr: code_ptr as *const u8,
635 code_size,
636 },
637 })
638 }
639
640 fn codegen_node(
641 &self,
642 builder: &mut cranelift::prelude::FunctionBuilder,
643 node: &Node,
644 values: &[Option<cranelift::prelude::Value>],
645 input_ptr: cranelift::prelude::Value,
646 ) -> JitResult<cranelift::prelude::Value> {
647 use cranelift::prelude::*;
648
649 let get = |id: NodeId| -> JitResult<Value> {
650 values[id.index()]
651 .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
652 };
653
654 match &node.op {
655 Op::Input { name, .. } => {
656 let offset = self.get_input_offset(name);
658 Ok(builder
659 .ins()
660 .load(types::F32, MemFlags::new(), input_ptr, offset))
661 }
662
663 Op::Output { input, .. } => get(*input),
664
665 Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),
666
667 Op::Add { lhs, rhs } => {
668 let a = get(*lhs)?;
669 let b = get(*rhs)?;
670 Ok(builder.ins().fadd(a, b))
671 }
672
673 Op::Sub { lhs, rhs } => {
674 let a = get(*lhs)?;
675 let b = get(*rhs)?;
676 Ok(builder.ins().fsub(a, b))
677 }
678
679 Op::Mul { lhs, rhs } => {
680 let a = get(*lhs)?;
681 let b = get(*rhs)?;
682 Ok(builder.ins().fmul(a, b))
683 }
684
685 Op::Div { lhs, rhs } => {
686 let a = get(*lhs)?;
687 let b = get(*rhs)?;
688 Ok(builder.ins().fdiv(a, b))
689 }
690
691 Op::Neg { input } => {
692 let a = get(*input)?;
693 Ok(builder.ins().fneg(a))
694 }
695
696 Op::Abs { input } => {
697 let a = get(*input)?;
698 Ok(builder.ins().fabs(a))
699 }
700
701 Op::Sqrt { input } => {
702 let a = get(*input)?;
703 Ok(builder.ins().sqrt(a))
704 }
705
706 Op::AddScalar { input, scalar } => {
707 let a = get(*input)?;
708 let s = builder.ins().f32const(*scalar as f32);
709 Ok(builder.ins().fadd(a, s))
710 }
711
712 Op::MulScalar { input, scalar } => {
713 let a = get(*input)?;
714 let s = builder.ins().f32const(*scalar as f32);
715 Ok(builder.ins().fmul(a, s))
716 }
717
718 _ => Err(JitError::UnsupportedOp(format!(
721 "Operation {:?} not supported in native codegen, using interpreter",
722 node.op
723 ))),
724 }
725 }
726
727 fn get_input_offset(&self, _name: &str) -> i32 {
728 0
730 }
731
732 pub fn cache_stats(&self) -> crate::cache::CacheStats {
734 self.cache.stats()
735 }
736
737 pub fn clear_cache(&self) {
739 self.cache.clear();
740 }
741}
742
743impl Default for JitCompiler {
744 fn default() -> Self {
745 Self::new()
746 }
747}
748
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Asserts that `actual` and `expected` have the same length and agree element-wise
    /// within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < tol, "element {}: {} vs {}", i, a, e);
        }
    }

    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        // sigmoid(0) = 0.5, sigmoid(1) ≈ 0.731, sigmoid(-1) ≈ 0.269
        assert_close(&result, &[0.5, 0.731, 0.269], 0.01);
    }

    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // a = [[1, 0, 0], [0, 1, 0]], b = [[1, 0], [0, 1], [0, 0]]  =>  a·b = identity(2).
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_unknown_input_is_rejected() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let data = [1.0, 2.0, 3.0, 4.0];
        let result = func.run(&[("not_an_input", &data)]);

        assert!(matches!(result, Err(JitError::InputNotFound(_))));
    }

    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        // Recompiling the same graph must hit the cache rather than add an entry.
        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
}