1use std::sync::Arc;
18
19use crate::cache::FunctionCache;
20use crate::error::{JitError, JitResult};
21use crate::ir::{Graph, Node, NodeId, Op};
22use crate::optimize::Optimizer;
23
/// A compiled, runnable form of a [`Graph`].
///
/// Cloning is cheap with respect to the graph, which is shared behind an
/// [`Arc`]; the execution strategy is recorded in `kind`.
#[derive(Clone)]
pub struct CompiledFunction {
    /// The graph this function executes (shared, immutable).
    graph: Arc<Graph>,
    /// How the graph is executed: interpreted or via generated machine code.
    kind: CompiledKind,
}
32
/// Execution strategy backing a [`CompiledFunction`].
#[derive(Clone)]
enum CompiledKind {
    /// Evaluate the graph node by node in `CompiledFunction::run_interpreted`.
    Interpreted,
    /// Call machine code emitted by `JitCompiler::compile_native`.
    // NOTE(review): `dead_code` is presumably allowed because no constructor in
    // this file enables the native path (`use_native` is always `false`) —
    // confirm before removing the attribute.
    #[allow(dead_code)]
    Native {
        /// Entry point of the finalized kernel.
        code_ptr: *const u8,
        /// Size of the generated code in bytes; `compile_native` currently
        /// always stores 0 here.
        code_size: usize,
    },
}
46
// SAFETY: the only non-`Send`/`Sync` field is `Native::code_ptr`. It is obtained
// from `get_finalized_function` in `compile_native`, which then leaks the owning
// module with `mem::forget`, so the pointed-to code is never freed. This file
// never writes through the pointer.
// NOTE(review): soundness also assumes nothing else mutates the JIT memory
// after finalization — confirm against the Cranelift JIT guarantees.
unsafe impl Send for CompiledKind {}
unsafe impl Sync for CompiledKind {}
50
51impl CompiledFunction {
52 pub fn placeholder() -> Self {
54 Self {
55 graph: Arc::new(Graph::new()),
56 kind: CompiledKind::Interpreted,
57 }
58 }
59
60 pub fn graph(&self) -> &Graph {
62 &self.graph
63 }
64
65 pub fn run(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
67 match &self.kind {
68 CompiledKind::Interpreted => self.run_interpreted(inputs),
69 CompiledKind::Native {
70 code_ptr,
71 code_size,
72 } => {
73 unsafe {
76 let func: extern "C" fn(*const f32, *mut f32) = std::mem::transmute(code_ptr);
77 let flat_inputs: Vec<f32> =
78 inputs.iter().flat_map(|(_, d)| d.iter().copied()).collect();
79 let mut output = vec![0.0f32; self.graph.outputs().len() * 1024]; func(flat_inputs.as_ptr(), output.as_mut_ptr());
81 let _ = code_size; Ok(output)
83 }
84 }
85 }
86 }
87
88 fn run_interpreted(&self, inputs: &[(&str, &[f32])]) -> JitResult<Vec<f32>> {
90 let mut values: Vec<Option<Vec<f32>>> = vec![None; self.graph.len()];
91
92 for (name, data) in inputs {
94 if let Some(id) = self.graph.input(name) {
95 values[id.index()] = Some(data.to_vec());
96 } else {
97 return Err(JitError::InputNotFound(name.to_string()));
98 }
99 }
100
101 for node in self.graph.nodes() {
103 let result = self.eval_node(node, &values)?;
104 values[node.id.index()] = Some(result);
105 }
106
107 if let Some((_, output_id)) = self.graph.outputs().iter().next() {
109 let output_node = self.graph.node(*output_id);
110 if let Op::Output { input, .. } = &output_node.op {
111 return Ok(values[input.index()].clone().unwrap_or_default());
112 }
113 }
114
115 Err(JitError::OutputNotFound("no output".to_string()))
116 }
117
118 fn eval_node(&self, node: &Node, values: &[Option<Vec<f32>>]) -> JitResult<Vec<f32>> {
119 let get = |id: NodeId| -> JitResult<&Vec<f32>> {
120 values[id.index()]
121 .as_ref()
122 .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not computed", id)))
123 };
124
125 match &node.op {
126 Op::Input { .. } => {
127 Ok(values[node.id.index()].clone().unwrap_or_default())
129 }
130
131 Op::Output { input, .. } => Ok(get(*input)?.clone()),
132
133 Op::Constant { value } => {
134 let numel = node.shape.numel();
135 Ok(vec![*value as f32; numel])
136 }
137
138 Op::Add { lhs, rhs } => {
140 let a = get(*lhs)?;
141 let b = get(*rhs)?;
142 Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
143 }
144
145 Op::Sub { lhs, rhs } => {
146 let a = get(*lhs)?;
147 let b = get(*rhs)?;
148 Ok(a.iter().zip(b.iter()).map(|(x, y)| x - y).collect())
149 }
150
151 Op::Mul { lhs, rhs } => {
152 let a = get(*lhs)?;
153 let b = get(*rhs)?;
154 Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).collect())
155 }
156
157 Op::Div { lhs, rhs } => {
158 let a = get(*lhs)?;
159 let b = get(*rhs)?;
160 Ok(a.iter().zip(b.iter()).map(|(x, y)| x / y).collect())
161 }
162
163 Op::Pow { base, exp } => {
164 let a = get(*base)?;
165 let b = get(*exp)?;
166 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.powf(*y)).collect())
167 }
168
169 Op::Max { lhs, rhs } => {
170 let a = get(*lhs)?;
171 let b = get(*rhs)?;
172 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.max(*y)).collect())
173 }
174
175 Op::Min { lhs, rhs } => {
176 let a = get(*lhs)?;
177 let b = get(*rhs)?;
178 Ok(a.iter().zip(b.iter()).map(|(x, y)| x.min(*y)).collect())
179 }
180
181 Op::AddScalar { input, scalar } => {
183 let a = get(*input)?;
184 Ok(a.iter().map(|x| x + *scalar as f32).collect())
185 }
186
187 Op::MulScalar { input, scalar } => {
188 let a = get(*input)?;
189 Ok(a.iter().map(|x| x * *scalar as f32).collect())
190 }
191
192 Op::Neg { input } => {
194 let a = get(*input)?;
195 Ok(a.iter().map(|x| -x).collect())
196 }
197
198 Op::Abs { input } => {
199 let a = get(*input)?;
200 Ok(a.iter().map(|x| x.abs()).collect())
201 }
202
203 Op::Sqrt { input } => {
204 let a = get(*input)?;
205 Ok(a.iter().map(|x| x.sqrt()).collect())
206 }
207
208 Op::Exp { input } => {
209 let a = get(*input)?;
210 Ok(a.iter().map(|x| x.exp()).collect())
211 }
212
213 Op::Log { input } => {
214 let a = get(*input)?;
215 Ok(a.iter().map(|x| x.ln()).collect())
216 }
217
218 Op::Sin { input } => {
219 let a = get(*input)?;
220 Ok(a.iter().map(|x| x.sin()).collect())
221 }
222
223 Op::Cos { input } => {
224 let a = get(*input)?;
225 Ok(a.iter().map(|x| x.cos()).collect())
226 }
227
228 Op::Tanh { input } => {
229 let a = get(*input)?;
230 Ok(a.iter().map(|x| x.tanh()).collect())
231 }
232
233 Op::Relu { input } => {
235 let a = get(*input)?;
236 Ok(a.iter().map(|x| x.max(0.0)).collect())
237 }
238
239 Op::Sigmoid { input } => {
240 let a = get(*input)?;
241 Ok(a.iter().map(|x| 1.0 / (1.0 + (-x).exp())).collect())
242 }
243
244 Op::Gelu { input } => {
245 let a = get(*input)?;
246 const SQRT_2_OVER_PI: f32 = 0.797_884_6;
247 Ok(a.iter()
248 .map(|x| 0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044715 * x.powi(3))).tanh()))
249 .collect())
250 }
251
252 Op::Silu { input } => {
253 let a = get(*input)?;
254 Ok(a.iter().map(|x| x / (1.0 + (-x).exp())).collect())
255 }
256
257 Op::Sum { input } => {
259 let a = get(*input)?;
260 Ok(vec![a.iter().sum()])
261 }
262
263 Op::Mean { input } => {
264 let a = get(*input)?;
265 let sum: f32 = a.iter().sum();
266 Ok(vec![sum / a.len() as f32])
267 }
268
269 Op::SumAxis {
270 input,
271 axis,
272 keepdim,
273 } => {
274 let a = get(*input)?;
276 let input_node = self.graph.node(*input);
277 let input_shape = input_node.shape.dims();
278
279 reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)
280 }
281
282 Op::MeanAxis {
283 input,
284 axis,
285 keepdim,
286 } => {
287 let a = get(*input)?;
288 let input_node = self.graph.node(*input);
289 let input_shape = input_node.shape.dims();
290 let axis_size = input_shape[normalize_axis(*axis, input_shape.len())];
291
292 let sum = reduce_axis(a, input_shape, *axis, *keepdim, |x, y| x + y, 0.0)?;
293 Ok(sum.iter().map(|x| x / axis_size as f32).collect())
294 }
295
296 Op::MaxAxis {
297 input,
298 axis,
299 keepdim,
300 } => {
301 let a = get(*input)?;
302 let input_node = self.graph.node(*input);
303 let input_shape = input_node.shape.dims();
304
305 reduce_axis(a, input_shape, *axis, *keepdim, f32::max, f32::NEG_INFINITY)
306 }
307
308 Op::Reshape { input, .. }
310 | Op::Transpose { input, .. }
311 | Op::Squeeze { input, .. }
312 | Op::Unsqueeze { input, .. }
313 | Op::Broadcast { input, .. }
314 | Op::Contiguous { input } => Ok(get(*input)?.clone()),
315
316 Op::MatMul { lhs, rhs } => {
318 let a = get(*lhs)?;
319 let b = get(*rhs)?;
320 let lhs_node = self.graph.node(*lhs);
321 let rhs_node = self.graph.node(*rhs);
322
323 let lhs_shape = lhs_node.shape.dims();
324 let rhs_shape = rhs_node.shape.dims();
325
326 matmul_impl(a, b, lhs_shape, rhs_shape)
327 }
328
329 Op::Gt { lhs, rhs } => {
331 let a = get(*lhs)?;
332 let b = get(*rhs)?;
333 Ok(a.iter()
334 .zip(b.iter())
335 .map(|(x, y)| if x > y { 1.0 } else { 0.0 })
336 .collect())
337 }
338
339 Op::Lt { lhs, rhs } => {
340 let a = get(*lhs)?;
341 let b = get(*rhs)?;
342 Ok(a.iter()
343 .zip(b.iter())
344 .map(|(x, y)| if x < y { 1.0 } else { 0.0 })
345 .collect())
346 }
347
348 Op::Eq { lhs, rhs } => {
349 let a = get(*lhs)?;
350 let b = get(*rhs)?;
351 Ok(a.iter()
352 .zip(b.iter())
353 .map(|(x, y)| {
354 if (x - y).abs() < f32::EPSILON {
355 1.0
356 } else {
357 0.0
358 }
359 })
360 .collect())
361 }
362
363 Op::Where { condition, x, y } => {
364 let cond = get(*condition)?;
365 let a = get(*x)?;
366 let b = get(*y)?;
367 Ok(cond
368 .iter()
369 .zip(a.iter().zip(b.iter()))
370 .map(|(c, (a, b))| if *c != 0.0 { *a } else { *b })
371 .collect())
372 }
373
374 Op::Cast { input, .. } => {
375 Ok(get(*input)?.clone())
377 }
378 }
379 }
380}
381
/// Converts a possibly negative axis into a zero-based index.
///
/// Negative axes count from the end (`-1` is the last dimension). No range
/// check is performed: an axis outside `-ndim..ndim` produces an index that is
/// out of bounds for a shape of rank `ndim`, and callers must validate it.
fn normalize_axis(axis: i32, ndim: usize) -> usize {
    let resolved = if axis >= 0 { axis } else { ndim as i32 + axis };
    resolved as usize
}
389
390fn reduce_axis(
391 data: &[f32],
392 shape: &[usize],
393 axis: i32,
394 keepdim: bool,
395 op: fn(f32, f32) -> f32,
396 init: f32,
397) -> JitResult<Vec<f32>> {
398 let axis = normalize_axis(axis, shape.len());
399
400 let mut strides = vec![1usize; shape.len()];
402 for i in (0..shape.len() - 1).rev() {
403 strides[i] = strides[i + 1] * shape[i + 1];
404 }
405
406 let mut output_shape: Vec<usize> = shape.to_vec();
408 if keepdim {
409 output_shape[axis] = 1;
410 } else {
411 output_shape.remove(axis);
412 }
413
414 let output_numel: usize = output_shape.iter().product();
415 let mut result = vec![init; output_numel];
416
417 for (i, &val) in data.iter().enumerate() {
419 let mut multi_idx = vec![0usize; shape.len()];
421 let mut idx = i;
422 for (d, &st) in strides.iter().enumerate() {
423 multi_idx[d] = idx / st;
424 idx %= st;
425 }
426
427 let out_idx = if keepdim {
429 let mut out_idx = 0;
430 let mut temp_strides = vec![1usize; output_shape.len()];
431 for d in (0..output_shape.len() - 1).rev() {
432 temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
433 }
434 for d in 0..output_shape.len() {
435 let dim_idx = if d == axis { 0 } else { multi_idx[d] };
436 out_idx += dim_idx * temp_strides[d];
437 }
438 out_idx
439 } else {
440 let mut out_idx = 0;
441 let mut temp_strides = vec![1usize; output_shape.len()];
442 if !output_shape.is_empty() {
443 for d in (0..output_shape.len() - 1).rev() {
444 temp_strides[d] = temp_strides[d + 1] * output_shape[d + 1];
445 }
446 }
447 let mut out_d = 0;
448 for (d, &mi) in multi_idx.iter().enumerate().take(shape.len()) {
449 if d == axis {
450 continue;
451 }
452 if out_d < temp_strides.len() {
453 out_idx += mi * temp_strides[out_d];
454 }
455 out_d += 1;
456 }
457 out_idx
458 };
459
460 if out_idx < result.len() {
461 result[out_idx] = op(result[out_idx], val);
462 }
463 }
464
465 Ok(result)
466}
467
468fn matmul_impl(a: &[f32], b: &[f32], a_shape: &[usize], b_shape: &[usize]) -> JitResult<Vec<f32>> {
469 if a_shape.len() != 2 || b_shape.len() != 2 {
471 return Err(JitError::UnsupportedOp(
472 "Only 2D matmul supported in interpreter".to_string(),
473 ));
474 }
475
476 let m = a_shape[0];
477 let k = a_shape[1];
478 let n = b_shape[1];
479
480 if k != b_shape[0] {
481 return Err(JitError::ShapeMismatch {
482 expected: vec![k],
483 found: vec![b_shape[0]],
484 });
485 }
486
487 let mut result = vec![0.0f32; m * n];
488
489 for i in 0..m {
490 for j in 0..n {
491 let mut sum = 0.0;
492 for p in 0..k {
493 sum += a[i * k + p] * b[p * n + j];
494 }
495 result[i * n + j] = sum;
496 }
497 }
498
499 Ok(result)
500}
501
/// Compiles graphs into [`CompiledFunction`]s, caching the results.
pub struct JitCompiler {
    /// Optimizer applied to a copy of each graph before compilation.
    optimizer: Optimizer,
    /// Compiled functions keyed by `FunctionCache::hash_graph` of the input graph.
    cache: FunctionCache,
    /// Selects native code generation over the interpreter; both constructors
    /// in this file set it to `false`.
    use_native: bool,
}
508
509impl JitCompiler {
510 pub fn new() -> Self {
512 Self {
513 optimizer: Optimizer::default_passes(),
514 cache: FunctionCache::default_size(),
515 use_native: false, }
517 }
518
519 pub fn with_optimizer(optimizer: Optimizer) -> Self {
521 Self {
522 optimizer,
523 cache: FunctionCache::default_size(),
524 use_native: false,
525 }
526 }
527
528 pub fn compile(&self, graph: &Graph) -> JitResult<CompiledFunction> {
530 let cache_key = FunctionCache::hash_graph(graph);
532 if let Some(cached) = self.cache.get(cache_key) {
533 return Ok(cached);
534 }
535
536 graph.validate().map_err(JitError::InvalidGraph)?;
538
539 let optimized = self.optimizer.optimize(graph.clone());
541
542 let func = if self.use_native {
544 self.compile_native(&optimized)?
545 } else {
546 self.compile_interpreted(&optimized)
547 };
548
549 self.cache.insert(cache_key, func.clone());
551
552 Ok(func)
553 }
554
555 fn compile_interpreted(&self, graph: &Graph) -> CompiledFunction {
556 CompiledFunction {
557 graph: Arc::new(graph.clone()),
558 kind: CompiledKind::Interpreted,
559 }
560 }
561
    /// Emits machine code for `graph` with Cranelift and returns a native
    /// [`CompiledFunction`].
    ///
    /// The generated kernel takes two integer parameters, used as the input and
    /// output pointers, and returns nothing. Each node is lowered by
    /// `codegen_node` to a single Cranelift value, and the value feeding the
    /// first graph output is stored at offset 0 of the output pointer.
    ///
    /// # Errors
    ///
    /// `JitError::CompilationFailed` for ISA, declaration, definition or
    /// finalization failures; any error from `codegen_node` (notably
    /// `JitError::UnsupportedOp`) is propagated.
    ///
    /// # Panics
    ///
    /// Panics if either of the two hard-coded Cranelift settings is rejected.
    fn compile_native(&self, graph: &Graph) -> JitResult<CompiledFunction> {
        use cranelift::prelude::*;
        use cranelift_jit::{JITBuilder, JITModule};
        use cranelift_module::{Linkage, Module};

        // Target configuration for the host machine.
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to get native ISA: {}", e)))?;
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .map_err(|e| JitError::CompilationFailed(format!("Failed to build ISA: {}", e)))?;

        let builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        let mut module = JITModule::new(builder);

        // Kernel signature: (input pointer, output pointer), both declared as I64.
        // NOTE(review): I64 assumes a 64-bit target; `run` calls the kernel as
        // `extern "C" fn(*const f32, *mut f32)` — confirm for other targets.
        let mut sig = module.make_signature();
        sig.params.push(AbiParam::new(types::I64));
        sig.params.push(AbiParam::new(types::I64));
        let func_id = module
            .declare_function("jit_kernel", Linkage::Export, &sig)
            .map_err(|e| {
                JitError::CompilationFailed(format!("Failed to declare function: {}", e))
            })?;

        let mut ctx = module.make_context();
        ctx.func.signature = sig;

        let mut builder_ctx = FunctionBuilderContext::new();
        // Scoped so the FunctionBuilder's borrow of `ctx.func` ends before
        // `define_function` needs `ctx` again.
        {
            let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
            let entry_block = builder.create_block();
            builder.append_block_params_for_function_params(entry_block);
            builder.switch_to_block(entry_block);
            builder.seal_block(entry_block);

            let input_ptr = builder.block_params(entry_block)[0];
            let output_ptr = builder.block_params(entry_block)[1];

            // One Cranelift value per graph node, indexed by node id.
            let mut values: Vec<Option<Value>> = vec![None; graph.len()];

            for node in graph.nodes() {
                let result = self.codegen_node(&mut builder, node, &values, input_ptr)?;
                values[node.id.index()] = Some(result);
            }

            // Only the first output is written, as a single value at offset 0.
            if let Some((_, output_id)) = graph.outputs().iter().next() {
                let output_node = graph.node(*output_id);
                if let Op::Output { input, .. } = &output_node.op {
                    if let Some(val) = values[input.index()] {
                        builder.ins().store(MemFlags::new(), val, output_ptr, 0);
                    }
                }
            }

            builder.ins().return_(&[]);
            builder.finalize();
        }

        module.define_function(func_id, &mut ctx).map_err(|e| {
            JitError::CompilationFailed(format!("Failed to define function: {}", e))
        })?;
        module.clear_context(&mut ctx);
        module
            .finalize_definitions()
            .map_err(|e| JitError::CompilationFailed(format!("Failed to finalize: {:?}", e)))?;

        let code_ptr = module.get_finalized_function(func_id);
        // The real code size is not recorded; 0 is a placeholder.
        let code_size = 0;
        // The module is leaked, presumably so the generated code outlives this
        // function; as a consequence its memory is never reclaimed.
        std::mem::forget(module);

        Ok(CompiledFunction {
            graph: Arc::new(graph.clone()),
            kind: CompiledKind::Native {
                code_ptr,
                code_size,
            },
        })
    }
651
652 fn codegen_node(
653 &self,
654 builder: &mut cranelift::prelude::FunctionBuilder,
655 node: &Node,
656 values: &[Option<cranelift::prelude::Value>],
657 input_ptr: cranelift::prelude::Value,
658 ) -> JitResult<cranelift::prelude::Value> {
659 use cranelift::prelude::*;
660
661 let get = |id: NodeId| -> JitResult<Value> {
662 values[id.index()]
663 .ok_or_else(|| JitError::RuntimeError(format!("Node {:?} not compiled", id)))
664 };
665
666 match &node.op {
667 Op::Input { name, .. } => {
668 let offset = self.get_input_offset(name);
670 Ok(builder
671 .ins()
672 .load(types::F32, MemFlags::new(), input_ptr, offset))
673 }
674
675 Op::Output { input, .. } => get(*input),
676
677 Op::Constant { value } => Ok(builder.ins().f32const(*value as f32)),
678
679 Op::Add { lhs, rhs } => {
680 let a = get(*lhs)?;
681 let b = get(*rhs)?;
682 Ok(builder.ins().fadd(a, b))
683 }
684
685 Op::Sub { lhs, rhs } => {
686 let a = get(*lhs)?;
687 let b = get(*rhs)?;
688 Ok(builder.ins().fsub(a, b))
689 }
690
691 Op::Mul { lhs, rhs } => {
692 let a = get(*lhs)?;
693 let b = get(*rhs)?;
694 Ok(builder.ins().fmul(a, b))
695 }
696
697 Op::Div { lhs, rhs } => {
698 let a = get(*lhs)?;
699 let b = get(*rhs)?;
700 Ok(builder.ins().fdiv(a, b))
701 }
702
703 Op::Neg { input } => {
704 let a = get(*input)?;
705 Ok(builder.ins().fneg(a))
706 }
707
708 Op::Abs { input } => {
709 let a = get(*input)?;
710 Ok(builder.ins().fabs(a))
711 }
712
713 Op::Sqrt { input } => {
714 let a = get(*input)?;
715 Ok(builder.ins().sqrt(a))
716 }
717
718 Op::AddScalar { input, scalar } => {
719 let a = get(*input)?;
720 let s = builder.ins().f32const(*scalar as f32);
721 Ok(builder.ins().fadd(a, s))
722 }
723
724 Op::MulScalar { input, scalar } => {
725 let a = get(*input)?;
726 let s = builder.ins().f32const(*scalar as f32);
727 Ok(builder.ins().fmul(a, s))
728 }
729
730 _ => Err(JitError::UnsupportedOp(format!(
733 "Operation {:?} not supported in native codegen, using interpreter",
734 node.op
735 ))),
736 }
737 }
738
    /// Returns the byte offset at which the named input is read from the
    /// kernel's input buffer.
    ///
    /// Currently a stub: the name is ignored and every input maps to offset 0.
    // NOTE(review): with more than one input, all of them load the same
    // element — confirm whether per-input offsets are still to be implemented.
    fn get_input_offset(&self, _name: &str) -> i32 {
        0
    }
743
    /// Returns the statistics reported by the compiled-function cache.
    pub fn cache_stats(&self) -> crate::cache::CacheStats {
        self.cache.stats()
    }
748
    /// Clears the compiled-function cache.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
753}
754
/// The default compiler is the one produced by [`JitCompiler::new`].
impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}
760
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::trace;

    /// Element-wise addition of two inputs.
    #[test]
    fn test_compile_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[4]);
            let b = tracer.input("b", &[4]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let a = [1.0, 2.0, 3.0, 4.0];
        let b = [5.0, 6.0, 7.0, 8.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
    }

    /// A chain of unary ops: relu(x) * 2 + 1.
    #[test]
    fn test_compile_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [-1.0, 0.0, 1.0, 2.0];
        let result = func.run(&[("x", &x)]).unwrap();

        assert_eq!(result, vec![1.0, 1.0, 3.0, 5.0]);
    }

    /// Sigmoid at 0, 1 and -1.
    #[test]
    fn test_compile_activations() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[3]);
            let y = x.sigmoid();
            tracer.output("y", y)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        let x = [0.0, 1.0, -1.0];
        let result = func.run(&[("x", &x)]).unwrap();

        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.5).abs() < 0.01);
        assert!((result[1] - 0.731).abs() < 0.01);
        // sigmoid(-1) = 1 - sigmoid(1)
        assert!((result[2] - 0.269).abs() < 0.01);
    }

    /// [2, 3] x [3, 2] matmul, checked by value and not only by length.
    #[test]
    fn test_compile_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 2]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let compiler = JitCompiler::new();
        let func = compiler.compile(&graph).unwrap();

        // a = [[1, 0, 0], [0, 1, 0]], b = [[1, 0], [0, 1], [0, 0]]
        let a = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let b = [1.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let result = func.run(&[("a", &a), ("b", &b)]).unwrap();

        assert_eq!(result.len(), 4);
        // a selects the first two rows of b, which form the 2x2 identity.
        assert_eq!(result, vec![1.0, 0.0, 0.0, 1.0]);
    }

    /// Compiling the same graph twice must not add a second cache entry.
    #[test]
    fn test_caching() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4]);
            tracer.output("y", x.relu())
        });

        let compiler = JitCompiler::new();
        assert_eq!(compiler.cache_stats().entries, 0);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);

        let _ = compiler.compile(&graph).unwrap();
        assert_eq!(compiler.cache_stats().entries, 1);
    }
}