hugr_llvm/extension/conversions.rs

1use anyhow::{Result, anyhow, bail, ensure};
2
3use hugr_core::{
4    HugrView, Node,
5    extension::{
6        prelude::{ConstError, bool_t, sum_with_error},
7        simple_op::MakeExtensionOp,
8    },
9    ops::{DataflowOpTrait as _, constant::Value, custom::ExtensionOp},
10    std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES},
11    types::{TypeEnum, TypeRow},
12};
13
14use inkwell::{FloatPredicate, IntPredicate, types::IntType, values::BasicValue};
15
16use crate::{
17    custom::{CodegenExtension, CodegenExtsBuilder},
18    emit::{
19        EmitOpArgs,
20        func::EmitFuncContext,
21        ops::{emit_custom_unary_op, emit_value},
22    },
23    extension::int::get_width_arg,
24    sum::LLVMSumValue,
25    types::HugrType,
26};
27
/// Returns the largest and smallest values that can be represented by an
/// integer of the given `width`.
///
/// The elements of the tuple are:
///  - The most negative signed integer
///  - The most positive signed integer
///  - The largest unsigned integer
///
/// # Panics
///
/// Panics if `width` is not in `1..=64`. A width of zero would otherwise
/// shift by 64 bits, which overflows: it panics with an opaque message in
/// debug builds and silently yields the 64-bit bounds in release builds.
#[must_use]
pub fn int_type_bounds(width: u32) -> (i64, i64, u64) {
    assert!(
        (1..=64).contains(&width),
        "int_type_bounds: width must be in 1..=64, found {width}"
    );
    // Shifting the 64-bit extremes right (arithmetically for the signed ones)
    // yields the extremes of the narrower type, e.g. `i64::MIN >> 56 == -128`.
    let shift = 64 - width;
    (i64::MIN >> shift, i64::MAX >> shift, u64::MAX >> shift)
}
44
45fn build_trunc_op<'c, H: HugrView<Node = Node>>(
46    context: &mut EmitFuncContext<'c, '_, H>,
47    signed: bool,
48    log_width: u64,
49    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
50) -> Result<()> {
51    let hugr_int_ty = INT_TYPES[log_width as usize].clone();
52    let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]);
53    // TODO: it would be nice to get this info out of `ops.node()`, this would
54    // require adding appropriate methods to `ConvertOpDef`. In the meantime, we
55    // assert that the output types are as we expect.
56    debug_assert_eq!(
57        TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]),
58        args.node().signature().output
59    );
60
61    let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else {
62        bail!("Expected `arithmetic.int` to lower to an llvm integer")
63    };
64
65    let sum_ty = context.llvm_sum_type(hugr_sum_ty)?;
66
67    let (width, (int_min_value_s, int_max_value_s, int_max_value_u)) = {
68        ensure!(
69            log_width <= 6,
70            "Expected log_width of output to be <= 6, found: {log_width}"
71        );
72        let width = 1 << log_width;
73        (width, int_type_bounds(width))
74    };
75
76    emit_custom_unary_op(context, args, |ctx, arg, _| {
77        // We have to check if the conversion will work, so we
78        // make the maximum int and convert to a float, then compare
79        // with the function input.
80        let flt_max = ctx.iw_context().f64_type().const_float(if signed {
81            int_max_value_s as f64
82        } else {
83            int_max_value_u as f64
84        });
85
86        let within_upper_bound = ctx.builder().build_float_compare(
87            FloatPredicate::OLT,
88            arg.into_float_value(),
89            flt_max,
90            "within_upper_bound",
91        )?;
92
93        let flt_min = ctx.iw_context().f64_type().const_float(if signed {
94            int_min_value_s as f64
95        } else {
96            0.0
97        });
98
99        let within_lower_bound = ctx.builder().build_float_compare(
100            FloatPredicate::OLE,
101            flt_min,
102            arg.into_float_value(),
103            "within_lower_bound",
104        )?;
105
106        // N.B. If the float value is NaN, we will never succeed.
107        let success = ctx
108            .builder()
109            .build_and(within_upper_bound, within_lower_bound, "success")
110            .unwrap();
111
112        // Perform the conversion unconditionally, which will result
113        // in a poison value if the input was too large. We will
114        // decide whether we return it based on the result of our
115        // earlier check.
116        let trunc_result = if signed {
117            ctx.builder()
118                .build_float_to_signed_int(arg.into_float_value(), int_ty, "trunc_result")
119        } else {
120            ctx.builder().build_float_to_unsigned_int(
121                arg.into_float_value(),
122                int_ty,
123                "trunc_result",
124            )
125        }?
126        .as_basic_value_enum();
127
128        let err_msg = Value::extension(ConstError::new(
129            2,
130            format!("Float value too big to convert to int of given width ({width})"),
131        ));
132
133        let err_val = emit_value(ctx, &err_msg)?;
134        let failure = sum_ty.build_tag(ctx.builder(), 0, vec![err_val]).unwrap();
135        let trunc_result = sum_ty
136            .build_tag(ctx.builder(), 1, vec![trunc_result])
137            .unwrap();
138
139        let final_result = ctx
140            .builder()
141            .build_select(success, trunc_result, failure, "")
142            .unwrap();
143        Ok(vec![final_result])
144    })
145}
146
147fn emit_conversion_op<'c, H: HugrView<Node = Node>>(
148    context: &mut EmitFuncContext<'c, '_, H>,
149    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
150    conversion_op: ConvertOpDef,
151) -> Result<()> {
152    match conversion_op {
153        ConvertOpDef::trunc_u | ConvertOpDef::trunc_s => {
154            let signed = conversion_op == ConvertOpDef::trunc_s;
155            let log_width = get_width_arg(&args, &conversion_op)?;
156            build_trunc_op(context, signed, log_width, args)
157        }
158
159        ConvertOpDef::convert_u => emit_custom_unary_op(context, args, |ctx, arg, out_tys| {
160            let out_ty = out_tys.last().unwrap();
161            Ok(vec![
162                ctx.builder()
163                    .build_unsigned_int_to_float(
164                        arg.into_int_value(),
165                        out_ty.into_float_type(),
166                        "",
167                    )?
168                    .as_basic_value_enum(),
169            ])
170        }),
171
172        ConvertOpDef::convert_s => emit_custom_unary_op(context, args, |ctx, arg, out_tys| {
173            let out_ty = out_tys.last().unwrap();
174            Ok(vec![
175                ctx.builder()
176                    .build_signed_int_to_float(arg.into_int_value(), out_ty.into_float_type(), "")?
177                    .as_basic_value_enum(),
178            ])
179        }),
180        // These ops convert between hugr's `USIZE` and u64. The former is
181        // implementation-dependent and we define them to be the same.
182        // Hence our implementation is a noop.
183        ConvertOpDef::itousize | ConvertOpDef::ifromusize => {
184            emit_custom_unary_op(context, args, |_, arg, _| Ok(vec![arg]))
185        }
186        ConvertOpDef::itobool | ConvertOpDef::ifrombool => {
187            assert!(conversion_op.type_args().is_empty()); // Always 1-bit int <-> bool
188            let i0_ty = context
189                .typing_session()
190                .llvm_type(&INT_TYPES[0])?
191                .into_int_type();
192            let sum_ty = context
193                .typing_session()
194                .llvm_sum_type(match bool_t().as_type_enum() {
195                    TypeEnum::Sum(st) => st.clone(),
196                    _ => panic!("Hugr prelude bool_t() not a Sum"),
197                })?;
198
199            emit_custom_unary_op(context, args, |ctx, arg, _| {
200                let res = if conversion_op == ConvertOpDef::itobool {
201                    let is1 = ctx.builder().build_int_compare(
202                        IntPredicate::EQ,
203                        arg.into_int_value(),
204                        i0_ty.const_int(1, false),
205                        "eq1",
206                    )?;
207                    let sum_f = sum_ty.build_tag(ctx.builder(), 0, vec![])?;
208                    let sum_t = sum_ty.build_tag(ctx.builder(), 1, vec![])?;
209                    ctx.builder().build_select(is1, sum_t, sum_f, "")?
210                } else {
211                    let tag_ty = sum_ty.tag_type();
212                    let tag = LLVMSumValue::try_new(arg, sum_ty)?.build_get_tag(ctx.builder())?;
213                    let is_true = ctx.builder().build_int_compare(
214                        IntPredicate::EQ,
215                        tag,
216                        tag_ty.const_int(1, false),
217                        "",
218                    )?;
219                    ctx.builder().build_select(
220                        is_true,
221                        i0_ty.const_int(1, false),
222                        i0_ty.const_int(0, false),
223                        "",
224                    )?
225                };
226                Ok(vec![res])
227            })
228        }
229        ConvertOpDef::bytecast_int64_to_float64 => {
230            emit_custom_unary_op(context, args, |ctx, arg, outs| {
231                let [out] = outs.try_into()?;
232                Ok(vec![ctx.builder().build_bit_cast(arg, out, "")?])
233            })
234        }
235        ConvertOpDef::bytecast_float64_to_int64 => {
236            emit_custom_unary_op(context, args, |ctx, arg, outs| {
237                let [out] = outs.try_into()?;
238                Ok(vec![ctx.builder().build_bit_cast(arg, out, "")?])
239            })
240        }
241        _ => Err(anyhow!(
242            "Conversion op not implemented: {:?}",
243            args.node().as_ref()
244        )),
245    }
246}
247
/// A [`CodegenExtension`] that lowers the ops defined in
/// `hugr_core::std_extensions::arithmetic::conversions` to LLVM IR.
#[derive(Clone, Debug)]
pub struct ConversionExtension;

impl CodegenExtension for ConversionExtension {
    /// Registers `emit_conversion_op` on `builder` via `simple_extension_op`.
    /// The emitter's `ConvertOpDef` parameter determines which ops it handles.
    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
        self,
        builder: CodegenExtsBuilder<'a, H>,
    ) -> CodegenExtsBuilder<'a, H>
    where
        Self: 'a,
    {
        builder.simple_extension_op(emit_conversion_op)
    }
}
262
impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
    /// Returns this builder with a [`ConversionExtension`] added. This enables
    /// lowering of the ops in `hugr_core::std_extensions::arithmetic::conversions`.
    #[must_use]
    pub fn add_conversion_extensions(self) -> Self {
        self.add_extension(ConversionExtension)
    }
}
269
#[cfg(test)]
mod test {

    use super::*;

    use crate::check_emission;
    use crate::emit::test::{DFGW, SimpleHugrConfig};
    use crate::test::{TestContext, exec_ctx, llvm_ctx};
    use hugr_core::builder::{DataflowHugr, SubContainer};
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::std_extensions::arithmetic::float_types::ConstF64;
    use hugr_core::std_extensions::arithmetic::int_types::ConstInt;
    use hugr_core::{
        Hugr,
        builder::{Dataflow, DataflowSubContainer},
        extension::prelude::{ConstUsize, PRELUDE_REGISTRY, usize_t},
        std_extensions::arithmetic::{
            conversions::{ConvertOpDef, EXTENSION},
            float_types::float64_type,
            int_types::INT_TYPES,
        },
        types::Type,
    };
    use rstest::rstest;

    /// Builds a HUGR that takes one `in_type` input and applies the conversion
    /// op `name` to it. It returns the op's single `out_type` output.
    ///
    /// `int_width` is passed as the op's type argument. The callers below pass
    /// a log width here.
    fn test_conversion_op(
        name: impl AsRef<str>,
        in_type: Type,
        out_type: Type,
        int_width: u8,
    ) -> Hugr {
        SimpleHugrConfig::new()
            .with_ins(vec![in_type.clone()])
            .with_outs(vec![out_type.clone()])
            .with_extensions(STD_REG.clone())
            .finish(|mut hugr_builder| {
                let [in1] = hugr_builder.input_wires_arr();
                let ext_op = EXTENSION
                    .instantiate_extension_op(name.as_ref(), [u64::from(int_width).into()])
                    .unwrap();
                let outputs = hugr_builder
                    .add_dataflow_op(ext_op, [in1])
                    .unwrap()
                    .outputs();
                hugr_builder.finish_hugr_with_outputs(outputs).unwrap()
            })
    }

    // Checks (via `check_emission!`) the IR emitted for an int -> float64
    // conversion from an int of the given log width.
    #[rstest]
    #[case("convert_u", 4)]
    #[case("convert_s", 5)]
    fn test_convert(mut llvm_ctx: TestContext, #[case] op_name: &str, #[case] log_width: u8) -> () {
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_int_extensions()
                .add_float_extensions()
                .add_conversion_extensions()
        });
        let in_ty = INT_TYPES[log_width as usize].clone();
        let out_ty = float64_type();
        let hugr = test_conversion_op(op_name, in_ty, out_ty, log_width);
        check_emission!(op_name, hugr, llvm_ctx);
    }

    // Checks (via `check_emission!`) the IR emitted for a float64 -> int
    // truncation. The output is the int wrapped in `sum_with_error`.
    #[rstest]
    #[case("trunc_u", 6)]
    #[case("trunc_s", 5)]
    fn test_truncation(
        mut llvm_ctx: TestContext,
        #[case] op_name: &str,
        #[case] log_width: u8,
    ) -> () {
        llvm_ctx.add_extensions(|builder| {
            builder
                .add_default_int_extensions()
                .add_float_extensions()
                .add_conversion_extensions()
                .add_default_prelude_extensions()
        });
        let in_ty = float64_type();
        let out_ty = sum_with_error(INT_TYPES[log_width as usize].clone());
        let hugr = test_conversion_op(op_name, in_ty, out_ty.into(), log_width);
        check_emission!(op_name, hugr, llvm_ctx);
    }

    // Checks the IR emitted for `itobool` / `ifrombool`. `input_int` is true
    // when the op consumes the 1-bit int (itobool) and false when it produces
    // it (ifrombool), which is why the type pair is reversed in that case.
    #[rstest]
    #[case("itobool", true)]
    #[case("ifrombool", false)]
    fn test_intbool_emit(
        mut llvm_ctx: TestContext,
        #[case] op_name: &str,
        #[case] input_int: bool,
    ) {
        let mut tys = [INT_TYPES[0].clone(), bool_t()];
        if !input_int {
            tys.reverse();
        }
        let [in_t, out_t] = tys;
        llvm_ctx.add_extensions(|builder| {
            builder
                .add_default_int_extensions()
                .add_float_extensions()
                .add_conversion_extensions()
        });
        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![in_t])
            .with_outs(vec![out_t])
            .with_extensions(STD_REG.to_owned())
            .finish(|mut hugr_builder| {
                let [in1] = hugr_builder.input_wires_arr();
                let ext_op = EXTENSION.instantiate_extension_op(op_name, []).unwrap();
                let [out1] = hugr_builder
                    .add_dataflow_op(ext_op, [in1])
                    .unwrap()
                    .outputs_arr();
                hugr_builder.finish_hugr_with_outputs([out1]).unwrap()
            });
        check_emission!(op_name, hugr, llvm_ctx);
    }

    // Sanity check of the execution harness: a HUGR that just loads the
    // `usize` constant 42 must return 42.
    #[rstest]
    fn my_test_exec(mut exec_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder: DFGW| {
                let konst = builder.add_load_value(ConstUsize::new(42));
                builder.finish_hugr_with_outputs([konst]).unwrap()
            });
        exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // `ifromusize` followed by `itousize` must return its input unchanged,
    // including at the extremes of the u64 range.
    #[rstest]
    #[case(0)]
    #[case(42)]
    #[case(18_446_744_073_709_551_615)]
    fn usize_roundtrip(mut exec_ctx: TestContext, #[case] val: u64) -> () {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(STD_REG.clone())
            .finish(|mut builder: DFGW| {
                let k = builder.add_load_value(ConstUsize::new(val));
                let [int] = builder
                    .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k])
                    .unwrap()
                    .outputs_arr();
                let [usize_] = builder
                    .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [int])
                    .unwrap()
                    .outputs_arr();
                builder.finish_hugr_with_outputs([usize_]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_default_int_extensions()
                .add_conversion_extensions()
                .add_default_prelude_extensions()
        });
        assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    /// Builds a HUGR that sends `val` on a round trip. The steps are:
    /// `usize` -> 64-bit int (`ifromusize`) -> `float64` (`convert_s` /
    /// `convert_u`) -> int-or-error (`trunc_s` / `trunc_u`) -> `usize`.
    ///
    /// The HUGR returns the truncated value, or the sentinel `999` if the
    /// truncation produced its error variant.
    fn roundtrip_hugr(val: u64, signed: bool) -> Hugr {
        let int64 = INT_TYPES[6].clone();
        SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(STD_REG.clone())
            .finish(|mut builder| {
                let k = builder.add_load_value(ConstUsize::new(val));
                let [int] = builder
                    .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k])
                    .unwrap()
                    .outputs_arr();
                let [flt] = {
                    let op = if signed {
                        ConvertOpDef::convert_s.with_log_width(6)
                    } else {
                        ConvertOpDef::convert_u.with_log_width(6)
                    };
                    builder.add_dataflow_op(op, [int]).unwrap().outputs_arr()
                };

                let [int_or_err] = {
                    let op = if signed {
                        ConvertOpDef::trunc_s.with_log_width(6)
                    } else {
                        ConvertOpDef::trunc_u.with_log_width(6)
                    };
                    builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr()
                };
                let sum_ty = sum_with_error(int64.clone());
                let variants = (0..sum_ty.num_variants())
                    .map(|i| sum_ty.get_variant(i).unwrap().clone().try_into().unwrap());
                let mut cond_b = builder
                    .conditional_builder((variants, int_or_err), [], vec![int64].into())
                    .unwrap();
                // Case 1 is the success variant: pass the truncated int through.
                let win_case = cond_b.case_builder(1).unwrap();
                let [win_in] = win_case.input_wires_arr();
                win_case.finish_with_outputs([win_in]).unwrap();
                // Case 0 is the error variant: return the sentinel 999 instead.
                let mut lose_case = cond_b.case_builder(0).unwrap();
                let const_999 = lose_case.add_load_value(ConstUsize::new(999));
                let [const_999] = lose_case
                    .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [const_999])
                    .unwrap()
                    .outputs_arr();
                lose_case.finish_with_outputs([const_999]).unwrap();

                let cond = cond_b.finish_sub_container().unwrap();

                let [cond_result] = cond.outputs_arr();

                let [usize_] = builder
                    .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [cond_result])
                    .unwrap()
                    .outputs_arr();
                builder.finish_hugr_with_outputs([usize_]).unwrap()
            })
    }

    /// Registers every codegen extension the round-trip tests need.
    fn add_extensions(ctx: &mut TestContext) {
        ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_float_extensions()
                .add_default_int_extensions()
        });
    }

    // Unsigned values that survive the int -> f64 -> int round trip exactly.
    #[rstest]
    // Every integer up to 2**53 is exactly representable as an f64. The large
    // case is also exact, because 18e18 = 34_332_275_390_625 * 2**19.
    #[case(0)]
    #[case(3)]
    #[case(255)]
    #[case(4294967295)]
    #[case(42)]
    #[case(18_000_000_000_000_000_000)]
    fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
        add_extensions(&mut exec_ctx);
        let hugr = roundtrip_hugr(val, false);
        assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // Signed values that survive the int -> f64 -> int round trip exactly.
    #[rstest]
    // Every integer of magnitude up to 2**53 is exactly representable as an
    // f64. The large case is also exact, because
    // -9e18 = -34_332_275_390_625 * 2**18.
    #[case(0)]
    #[case(3)]
    #[case(255)]
    #[case(4294967295)]
    #[case(42)]
    #[case(-9_000_000_000_000_000_000)]
    fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
        add_extensions(&mut exec_ctx);
        let hugr = roundtrip_hugr(val as u64, true);
        assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main") as i64);
    }

    // For unsigned ints larger than 1 << 53, f64s do not have enough
    // precision to exactly roundtrip every int.
    // The exact behaviour of the round-trip is platform-dependent.
    #[rstest]
    #[case(u64::MAX)]
    #[case(u64::MAX - 1)]
    #[case(u64::MAX - (1 << 1))]
    #[case(u64::MAX - (1 << 2))]
    #[case(u64::MAX - (1 << 3))]
    #[case(u64::MAX - (1 << 4))]
    #[case(u64::MAX - (1 << 5))]
    #[case(u64::MAX - (1 << 6))]
    #[case(u64::MAX - (1 << 7))]
    #[case(u64::MAX - (1 << 8))]
    #[case(u64::MAX - (1 << 9))]
    #[case(u64::MAX - (1 << 10))]
    #[case(u64::MAX - (1 << 11))]
    fn approx_roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
        add_extensions(&mut exec_ctx);

        let hugr = roundtrip_hugr(val, false);
        let result = exec_ctx.exec_hugr_u64(hugr, "main");
        // Unsigned absolute difference, computed without underflow.
        let (v_r_max, v_r_min) = (val.max(result), val.min(result));
        // If val is too large, the `trunc_u` op in `hugr` returns its error
        // variant. In that case the hugr returns the magic number `999`.
        assert!(result == 999 || (v_r_max - v_r_min) < 1 << 10);
    }

    // Signed counterpart of `approx_roundtrip_unsigned`, covering values near
    // both ends of the i64 range.
    #[rstest]
    #[case(i64::MAX)]
    #[case(i64::MAX - 1)]
    #[case(i64::MAX - (1 << 1))]
    #[case(i64::MAX - (1 << 2))]
    #[case(i64::MAX - (1 << 3))]
    #[case(i64::MAX - (1 << 4))]
    #[case(i64::MAX - (1 << 5))]
    #[case(i64::MAX - (1 << 6))]
    #[case(i64::MAX - (1 << 7))]
    #[case(i64::MAX - (1 << 8))]
    #[case(i64::MAX - (1 << 9))]
    #[case(i64::MAX - (1 << 10))]
    #[case(i64::MAX - (1 << 11))]
    #[case(i64::MIN)]
    #[case(i64::MIN + 1)]
    #[case(i64::MIN + (1 << 1))]
    #[case(i64::MIN + (1 << 2))]
    #[case(i64::MIN + (1 << 3))]
    #[case(i64::MIN + (1 << 4))]
    #[case(i64::MIN + (1 << 5))]
    #[case(i64::MIN + (1 << 6))]
    #[case(i64::MIN + (1 << 7))]
    #[case(i64::MIN + (1 << 8))]
    #[case(i64::MIN + (1 << 9))]
    #[case(i64::MIN + (1 << 10))]
    #[case(i64::MIN + (1 << 11))]
    fn approx_roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
        add_extensions(&mut exec_ctx);

        let hugr = roundtrip_hugr(val as u64, true);
        let result = exec_ctx.exec_hugr_u64(hugr, "main") as i64;
        // If val.abs() is too large, the `trunc_s` op in `hugr` returns its
        // error variant. In that case the hugr returns the magic number `999`.
        assert!(result == 999 || (val - result).abs() < 1 << 10);
    }

    // `itobool` picks the conditional's case: case 0 (false) yields 1 and
    // case 1 (true) yields 6. The expected result is therefore `i * 5 + 1`.
    #[rstest]
    fn itobool_cond(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) {
        use hugr_core::type_row;

        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![usize_t()])
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap());
                let ext_op = EXTENSION.instantiate_extension_op("itobool", []).unwrap();
                let [b] = builder.add_dataflow_op(ext_op, [i]).unwrap().outputs_arr();
                let mut cond = builder
                    .conditional_builder(
                        ([type_row![], type_row![]], b),
                        [],
                        vec![usize_t()].into(),
                    )
                    .unwrap();
                let mut case_false = cond.case_builder(0).unwrap();
                let false_result = case_false.add_load_value(ConstUsize::new(1));
                case_false.finish_with_outputs([false_result]).unwrap();
                let mut case_true = cond.case_builder(1).unwrap();
                let true_result = case_true.add_load_value(ConstUsize::new(6));
                case_true.finish_with_outputs([true_result]).unwrap();
                let res = cond.finish_sub_container().unwrap();
                builder.finish_hugr_with_outputs(res.outputs()).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_default_int_extensions()
        });
        assert_eq!(i * 5 + 1, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // `ifrombool(itobool(i))` must return `i` for both 1-bit values.
    #[rstest]
    fn itobool_roundtrip(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![INT_TYPES[0].clone()])
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap());
                let i2b = EXTENSION.instantiate_extension_op("itobool", []).unwrap();
                let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr();
                let b2i = EXTENSION.instantiate_extension_op("ifrombool", []).unwrap();
                let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr();
                builder.finish_hugr_with_outputs([i]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_default_int_extensions()
        });
        assert_eq!(i, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // Loads `f`'s bit pattern as a 64-bit int constant, bytecasts it to
    // float64, and expects `f` back. NaNs are compared by NaN-ness only. Note
    // that `==` treats -0.0 and 0.0 as equal, so the sign of zero is not
    // verified here.
    #[rstest]
    #[case(42.0)]
    #[case(f64::INFINITY)]
    #[case(f64::NEG_INFINITY)]
    #[case(f64::NAN)]
    #[case(-0.0f64)]
    #[case(0.0f64)]
    fn bytecast_int64_to_float64(mut exec_ctx: TestContext, #[case] f: f64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstInt::new_u(6, f.to_bits()).unwrap());
                let i2f = EXTENSION
                    .instantiate_extension_op("bytecast_int64_to_float64", [])
                    .unwrap();
                let [f] = builder.add_dataflow_op(i2f, [i]).unwrap().outputs_arr();
                builder.finish_hugr_with_outputs([f]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_default_int_extensions()
                .add_float_extensions()
        });
        let hugr_f = exec_ctx.exec_hugr_f64(hugr, "main");
        assert!((f.is_nan() && hugr_f.is_nan()) || f == hugr_f);
    }

    // Bytecasts a float64 constant to a 64-bit int and expects exactly its
    // bit pattern. Unlike the test above, this distinguishes -0.0 from 0.0.
    #[rstest]
    #[case(42.0)]
    #[case(-0.0f64)]
    #[case(0.0f64)]
    fn bytecast_float64_to_int64(mut exec_ctx: TestContext, #[case] f: f64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(INT_TYPES[6].clone())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let f = builder.add_load_value(ConstF64::new(f));
                let f2i = EXTENSION
                    .instantiate_extension_op("bytecast_float64_to_int64", [])
                    .unwrap();
                let [i] = builder.add_dataflow_op(f2i, [f]).unwrap().outputs_arr();
                builder.finish_hugr_with_outputs([i]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_default_int_extensions()
                .add_float_extensions()
        });
        assert_eq!(f.to_bits(), exec_ctx.exec_hugr_u64(hugr, "main"));
    }
}