hugr_llvm/extension/conversions.rs

1use anyhow::{anyhow, bail, ensure, Result};
2
3use hugr_core::{
4    extension::{
5        prelude::{bool_t, sum_with_error, ConstError},
6        simple_op::MakeExtensionOp,
7    },
8    ops::{constant::Value, custom::ExtensionOp, DataflowOpTrait as _},
9    std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES},
10    types::{TypeArg, TypeEnum, TypeRow},
11    HugrView, Node,
12};
13
14use inkwell::{types::IntType, values::BasicValue, FloatPredicate, IntPredicate};
15
16use crate::{
17    custom::{CodegenExtension, CodegenExtsBuilder},
18    emit::{
19        func::EmitFuncContext,
20        ops::{emit_custom_unary_op, emit_value},
21        EmitOpArgs,
22    },
23    sum::LLVMSumValue,
24    types::HugrType,
25};
26
/// Returns the largest and smallest values that can be represented by an
/// integer of the given `width`.
///
/// The elements of the tuple are:
///  - The most negative signed integer
///  - The most positive signed integer
///  - The largest unsigned integer
///
/// # Panics
///
/// Panics unless `1 <= width <= 64`. Without this check a zero width would
/// shift by 64 bits. That panics with an overflow in debug builds, and in
/// release builds it silently yields the 64-bit bounds.
pub fn int_type_bounds(width: u32) -> (i64, i64, u64) {
    assert!(
        (1..=64).contains(&width),
        "integer width must be in 1..=64, found: {width}"
    );
    // Arithmetic right shift of the 64-bit extremes keeps the sign for the
    // signed bounds, and leaves exactly `width` one-bits for the unsigned bound.
    let shift = 64 - width;
    (i64::MIN >> shift, i64::MAX >> shift, u64::MAX >> shift)
}
42
43fn build_trunc_op<'c, H: HugrView<Node = Node>>(
44    context: &mut EmitFuncContext<'c, '_, H>,
45    signed: bool,
46    log_width: u64,
47    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48) -> Result<()> {
49    let hugr_int_ty = INT_TYPES[log_width as usize].clone();
50    let hugr_sum_ty = sum_with_error(vec![hugr_int_ty.clone()]);
51    // TODO: it would be nice to get this info out of `ops.node()`, this would
52    // require adding appropriate methods to `ConvertOpDef`. In the meantime, we
53    // assert that the output types are as we expect.
54    debug_assert_eq!(
55        TypeRow::from(vec![HugrType::from(hugr_sum_ty.clone())]),
56        args.node().signature().output
57    );
58
59    let Some(int_ty) = IntType::try_from(context.llvm_type(&hugr_int_ty)?).ok() else {
60        bail!("Expected `arithmetic.int` to lower to an llvm integer")
61    };
62
63    let sum_ty = context.llvm_sum_type(hugr_sum_ty)?;
64
65    let (width, (int_min_value_s, int_max_value_s, int_max_value_u)) = {
66        ensure!(
67            log_width <= 6,
68            "Expected log_width of output to be <= 6, found: {log_width}"
69        );
70        let width = 1 << log_width;
71        (width, int_type_bounds(width))
72    };
73
74    emit_custom_unary_op(context, args, |ctx, arg, _| {
75        // We have to check if the conversion will work, so we
76        // make the maximum int and convert to a float, then compare
77        // with the function input.
78        let flt_max = ctx.iw_context().f64_type().const_float(if signed {
79            int_max_value_s as f64
80        } else {
81            int_max_value_u as f64
82        });
83
84        let within_upper_bound = ctx.builder().build_float_compare(
85            FloatPredicate::OLT,
86            arg.into_float_value(),
87            flt_max,
88            "within_upper_bound",
89        )?;
90
91        let flt_min = ctx.iw_context().f64_type().const_float(if signed {
92            int_min_value_s as f64
93        } else {
94            0.0
95        });
96
97        let within_lower_bound = ctx.builder().build_float_compare(
98            FloatPredicate::OLE,
99            flt_min,
100            arg.into_float_value(),
101            "within_lower_bound",
102        )?;
103
104        // N.B. If the float value is NaN, we will never succeed.
105        let success = ctx
106            .builder()
107            .build_and(within_upper_bound, within_lower_bound, "success")
108            .unwrap();
109
110        // Perform the conversion unconditionally, which will result
111        // in a poison value if the input was too large. We will
112        // decide whether we return it based on the result of our
113        // earlier check.
114        let trunc_result = if signed {
115            ctx.builder()
116                .build_float_to_signed_int(arg.into_float_value(), int_ty, "trunc_result")
117        } else {
118            ctx.builder().build_float_to_unsigned_int(
119                arg.into_float_value(),
120                int_ty,
121                "trunc_result",
122            )
123        }?
124        .as_basic_value_enum();
125
126        let err_msg = Value::extension(ConstError::new(
127            2,
128            format!(
129                "Float value too big to convert to int of given width ({})",
130                width
131            ),
132        ));
133
134        let err_val = emit_value(ctx, &err_msg)?;
135        let failure = sum_ty.build_tag(ctx.builder(), 0, vec![err_val]).unwrap();
136        let trunc_result = sum_ty
137            .build_tag(ctx.builder(), 1, vec![trunc_result])
138            .unwrap();
139
140        let final_result = ctx
141            .builder()
142            .build_select(success, trunc_result, failure, "")
143            .unwrap();
144        Ok(vec![final_result])
145    })
146}
147
148fn emit_conversion_op<'c, H: HugrView<Node = Node>>(
149    context: &mut EmitFuncContext<'c, '_, H>,
150    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
151    conversion_op: ConvertOpDef,
152) -> Result<()> {
153    match conversion_op {
154        ConvertOpDef::trunc_u | ConvertOpDef::trunc_s => {
155            let signed = conversion_op == ConvertOpDef::trunc_s;
156            let Some(TypeArg::BoundedNat { n: log_width }) = args.node().args().last().cloned()
157            else {
158                panic!("This op should have one type arg only: the log-width of the int we're truncating to.: {:?}", conversion_op.type_args())
159            };
160
161            build_trunc_op(context, signed, log_width, args)
162        }
163
164        ConvertOpDef::convert_u => emit_custom_unary_op(context, args, |ctx, arg, out_tys| {
165            let out_ty = out_tys.last().unwrap();
166            Ok(vec![ctx
167                .builder()
168                .build_unsigned_int_to_float(arg.into_int_value(), out_ty.into_float_type(), "")?
169                .as_basic_value_enum()])
170        }),
171
172        ConvertOpDef::convert_s => emit_custom_unary_op(context, args, |ctx, arg, out_tys| {
173            let out_ty = out_tys.last().unwrap();
174            Ok(vec![ctx
175                .builder()
176                .build_signed_int_to_float(arg.into_int_value(), out_ty.into_float_type(), "")?
177                .as_basic_value_enum()])
178        }),
179        // These ops convert between hugr's `USIZE` and u64. The former is
180        // implementation-dependent and we define them to be the same.
181        // Hence our implementation is a noop.
182        ConvertOpDef::itousize | ConvertOpDef::ifromusize => {
183            emit_custom_unary_op(context, args, |_, arg, _| Ok(vec![arg]))
184        }
185        ConvertOpDef::itobool | ConvertOpDef::ifrombool => {
186            assert!(conversion_op.type_args().is_empty()); // Always 1-bit int <-> bool
187            let i0_ty = context
188                .typing_session()
189                .llvm_type(&INT_TYPES[0])?
190                .into_int_type();
191            let sum_ty = context
192                .typing_session()
193                .llvm_sum_type(match bool_t().as_type_enum() {
194                    TypeEnum::Sum(st) => st.clone(),
195                    _ => panic!("Hugr prelude bool_t() not a Sum"),
196                })?;
197
198            emit_custom_unary_op(context, args, |ctx, arg, _| {
199                let res = if conversion_op == ConvertOpDef::itobool {
200                    let is1 = ctx.builder().build_int_compare(
201                        IntPredicate::EQ,
202                        arg.into_int_value(),
203                        i0_ty.const_int(1, false),
204                        "eq1",
205                    )?;
206                    let sum_f = sum_ty.build_tag(ctx.builder(), 0, vec![])?;
207                    let sum_t = sum_ty.build_tag(ctx.builder(), 1, vec![])?;
208                    ctx.builder().build_select(is1, sum_t, sum_f, "")?
209                } else {
210                    let tag_ty = sum_ty.tag_type();
211                    let tag = LLVMSumValue::try_new(arg, sum_ty)?.build_get_tag(ctx.builder())?;
212                    let is_true = ctx.builder().build_int_compare(
213                        IntPredicate::EQ,
214                        tag,
215                        tag_ty.const_int(1, false),
216                        "",
217                    )?;
218                    ctx.builder().build_select(
219                        is_true,
220                        i0_ty.const_int(1, false),
221                        i0_ty.const_int(0, false),
222                        "",
223                    )?
224                };
225                Ok(vec![res])
226            })
227        }
228        ConvertOpDef::bytecast_int64_to_float64 => {
229            emit_custom_unary_op(context, args, |ctx, arg, outs| {
230                let [out] = outs.try_into()?;
231                Ok(vec![ctx.builder().build_bit_cast(arg, out, "")?])
232            })
233        }
234        ConvertOpDef::bytecast_float64_to_int64 => {
235            emit_custom_unary_op(context, args, |ctx, arg, outs| {
236                let [out] = outs.try_into()?;
237                Ok(vec![ctx.builder().build_bit_cast(arg, out, "")?])
238            })
239        }
240        _ => Err(anyhow!(
241            "Conversion op not implemented: {:?}",
242            args.node().as_ref()
243        )),
244    }
245}
246
247#[derive(Clone, Debug)]
248pub struct ConversionExtension;
249
250impl CodegenExtension for ConversionExtension {
251    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
252        self,
253        builder: CodegenExtsBuilder<'a, H>,
254    ) -> CodegenExtsBuilder<'a, H>
255    where
256        Self: 'a,
257    {
258        builder.simple_extension_op(emit_conversion_op)
259    }
260}
261
262impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
263    pub fn add_conversion_extensions(self) -> Self {
264        self.add_extension(ConversionExtension)
265    }
266}
267
#[cfg(test)]
mod test {

    use super::*;

    use crate::check_emission;
    use crate::emit::test::{SimpleHugrConfig, DFGW};
    use crate::test::{exec_ctx, llvm_ctx, TestContext};
    use hugr_core::builder::SubContainer;
    use hugr_core::std_extensions::arithmetic::float_types::ConstF64;
    use hugr_core::std_extensions::arithmetic::int_types::ConstInt;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        extension::prelude::{usize_t, ConstUsize, PRELUDE_REGISTRY},
        std_extensions::arithmetic::{
            conversions::{ConvertOpDef, EXTENSION},
            float_types::float64_type,
            int_types::INT_TYPES,
        },
        types::Type,
        Hugr,
    };
    use rstest::rstest;

    /// Builds a hugr containing a single `name` conversion op, instantiated
    /// with `log_width` as its only type arg, mapping `in_type` to `out_type`.
    fn test_conversion_op(
        name: impl AsRef<str>,
        in_type: Type,
        out_type: Type,
        log_width: u8,
    ) -> Hugr {
        SimpleHugrConfig::new()
            .with_ins(vec![in_type])
            .with_outs(vec![out_type])
            .with_extensions(STD_REG.clone())
            .finish(|mut hugr_builder| {
                let [in1] = hugr_builder.input_wires_arr();
                let ext_op = EXTENSION
                    .instantiate_extension_op(name.as_ref(), [u64::from(log_width).into()])
                    .unwrap();
                let outputs = hugr_builder
                    .add_dataflow_op(ext_op, [in1])
                    .unwrap()
                    .outputs();
                hugr_builder.finish_with_outputs(outputs).unwrap()
            })
    }

    #[rstest]
    #[case("convert_u", 4)]
    #[case("convert_s", 5)]
    fn test_convert(mut llvm_ctx: TestContext, #[case] op_name: &str, #[case] log_width: u8) {
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_int_extensions()
                .add_float_extensions()
                .add_conversion_extensions()
        });
        let in_ty = INT_TYPES[log_width as usize].clone();
        let out_ty = float64_type();
        let hugr = test_conversion_op(op_name, in_ty, out_ty, log_width);
        check_emission!(op_name, hugr, llvm_ctx);
    }

    #[rstest]
    #[case("trunc_u", 6)]
    #[case("trunc_s", 5)]
    fn test_truncation(
        mut llvm_ctx: TestContext,
        #[case] op_name: &str,
        #[case] log_width: u8,
    ) {
        llvm_ctx.add_extensions(|builder| {
            builder
                .add_int_extensions()
                .add_float_extensions()
                .add_conversion_extensions()
                .add_default_prelude_extensions()
        });
        let in_ty = float64_type();
        let out_ty = sum_with_error(INT_TYPES[log_width as usize].clone());
        let hugr = test_conversion_op(op_name, in_ty, out_ty.into(), log_width);
        check_emission!(op_name, hugr, llvm_ctx);
    }

    #[rstest]
    #[case("itobool", true)]
    #[case("ifrombool", false)]
    fn test_intbool_emit(
        mut llvm_ctx: TestContext,
        #[case] op_name: &str,
        #[case] input_int: bool,
    ) {
        let mut tys = [INT_TYPES[0].clone(), bool_t()];
        if !input_int {
            tys.reverse()
        };
        let [in_t, out_t] = tys;
        llvm_ctx.add_extensions(|builder| {
            builder
                .add_int_extensions()
                .add_float_extensions()
                .add_conversion_extensions()
        });
        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![in_t])
            .with_outs(vec![out_t])
            .with_extensions(STD_REG.to_owned())
            .finish(|mut hugr_builder| {
                let [in1] = hugr_builder.input_wires_arr();
                let ext_op = EXTENSION.instantiate_extension_op(op_name, []).unwrap();
                let [out1] = hugr_builder
                    .add_dataflow_op(ext_op, [in1])
                    .unwrap()
                    .outputs_arr();
                hugr_builder.finish_with_outputs([out1]).unwrap()
            });
        check_emission!(op_name, hugr, llvm_ctx);
    }

    #[rstest]
    fn my_test_exec(mut exec_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder: DFGW| {
                let konst = builder.add_load_value(ConstUsize::new(42));
                builder.finish_with_outputs([konst]).unwrap()
            });
        exec_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        assert_eq!(42, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(0)]
    #[case(42)]
    #[case(18_446_744_073_709_551_615)]
    fn usize_roundtrip(mut exec_ctx: TestContext, #[case] val: u64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(STD_REG.clone())
            .finish(|mut builder: DFGW| {
                let k = builder.add_load_value(ConstUsize::new(val));
                let [int] = builder
                    .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k])
                    .unwrap()
                    .outputs_arr();
                let [usize_] = builder
                    .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [int])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([usize_]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_int_extensions()
                .add_conversion_extensions()
                .add_default_prelude_extensions()
        });
        assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    /// Builds a hugr that converts the constant `val` to a float and truncates
    /// it back to a 64-bit int (signed or unsigned). It returns the resulting
    /// int, or the magic number `999` if the truncation reports an error.
    fn roundtrip_hugr(val: u64, signed: bool) -> Hugr {
        let int64 = INT_TYPES[6].clone();
        SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(STD_REG.clone())
            .finish(|mut builder| {
                let k = builder.add_load_value(ConstUsize::new(val));
                let [int] = builder
                    .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [k])
                    .unwrap()
                    .outputs_arr();
                let [flt] = {
                    let op = if signed {
                        ConvertOpDef::convert_s.with_log_width(6)
                    } else {
                        ConvertOpDef::convert_u.with_log_width(6)
                    };
                    builder.add_dataflow_op(op, [int]).unwrap().outputs_arr()
                };

                let [int_or_err] = {
                    let op = if signed {
                        ConvertOpDef::trunc_s.with_log_width(6)
                    } else {
                        ConvertOpDef::trunc_u.with_log_width(6)
                    };
                    builder.add_dataflow_op(op, [flt]).unwrap().outputs_arr()
                };
                let sum_ty = sum_with_error(int64.clone());
                let variants = (0..sum_ty.num_variants())
                    .map(|i| sum_ty.get_variant(i).unwrap().clone().try_into().unwrap());
                let mut cond_b = builder
                    .conditional_builder((variants, int_or_err), [], vec![int64].into())
                    .unwrap();
                let win_case = cond_b.case_builder(1).unwrap();
                let [win_in] = win_case.input_wires_arr();
                win_case.finish_with_outputs([win_in]).unwrap();
                let mut lose_case = cond_b.case_builder(0).unwrap();
                let const_999 = lose_case.add_load_value(ConstUsize::new(999));
                let [const_999] = lose_case
                    .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [const_999])
                    .unwrap()
                    .outputs_arr();
                lose_case.finish_with_outputs([const_999]).unwrap();

                let cond = cond_b.finish_sub_container().unwrap();

                let [cond_result] = cond.outputs_arr();

                let [usize_] = builder
                    .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [cond_result])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([usize_]).unwrap()
            })
    }

    fn add_extensions(ctx: &mut TestContext) {
        ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_float_extensions()
                .add_int_extensions()
        });
    }

    #[rstest]
    // f64 represents every integer up to 2**53 exactly, so these values
    // roundtrip exactly.
    #[case(0)]
    #[case(3)]
    #[case(255)]
    #[case(4294967295)]
    #[case(42)]
    #[case(18_000_000_000_000_000_000)]
    fn roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
        add_extensions(&mut exec_ctx);
        let hugr = roundtrip_hugr(val, false);
        assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    // f64 represents every integer up to 2**53 exactly, so these values
    // roundtrip exactly.
    #[case(0)]
    #[case(3)]
    #[case(255)]
    #[case(4294967295)]
    #[case(42)]
    #[case(-9_000_000_000_000_000_000)]
    fn roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
        add_extensions(&mut exec_ctx);
        let hugr = roundtrip_hugr(val as u64, true);
        assert_eq!(val, exec_ctx.exec_hugr_u64(hugr, "main") as i64);
    }

    // Beyond 1 << 53, f64s do not have enough precision to exactly roundtrip
    // every unsigned int.
    // The exact behaviour of the round-trip is platform-dependent.
    #[rstest]
    #[case(u64::MAX)]
    #[case(u64::MAX - 1)]
    #[case(u64::MAX - (1 << 1))]
    #[case(u64::MAX - (1 << 2))]
    #[case(u64::MAX - (1 << 3))]
    #[case(u64::MAX - (1 << 4))]
    #[case(u64::MAX - (1 << 5))]
    #[case(u64::MAX - (1 << 6))]
    #[case(u64::MAX - (1 << 7))]
    #[case(u64::MAX - (1 << 8))]
    #[case(u64::MAX - (1 << 9))]
    #[case(u64::MAX - (1 << 10))]
    #[case(u64::MAX - (1 << 11))]
    fn approx_roundtrip_unsigned(mut exec_ctx: TestContext, #[case] val: u64) {
        add_extensions(&mut exec_ctx);

        let hugr = roundtrip_hugr(val, false);
        let result = exec_ctx.exec_hugr_u64(hugr, "main");
        // If val is too large the `trunc_u` op in `hugr` returns an error.
        // In this case the hugr returns the magic number `999`.
        assert!(
            result == 999 || val.abs_diff(result) < 1 << 10,
            "{val} roundtripped to {result}"
        );
    }

    #[rstest]
    #[case(i64::MAX)]
    #[case(i64::MAX - 1)]
    #[case(i64::MAX - (1 << 1))]
    #[case(i64::MAX - (1 << 2))]
    #[case(i64::MAX - (1 << 3))]
    #[case(i64::MAX - (1 << 4))]
    #[case(i64::MAX - (1 << 5))]
    #[case(i64::MAX - (1 << 6))]
    #[case(i64::MAX - (1 << 7))]
    #[case(i64::MAX - (1 << 8))]
    #[case(i64::MAX - (1 << 9))]
    #[case(i64::MAX - (1 << 10))]
    #[case(i64::MAX - (1 << 11))]
    #[case(i64::MIN)]
    #[case(i64::MIN + 1)]
    #[case(i64::MIN + (1 << 1))]
    #[case(i64::MIN + (1 << 2))]
    #[case(i64::MIN + (1 << 3))]
    #[case(i64::MIN + (1 << 4))]
    #[case(i64::MIN + (1 << 5))]
    #[case(i64::MIN + (1 << 6))]
    #[case(i64::MIN + (1 << 7))]
    #[case(i64::MIN + (1 << 8))]
    #[case(i64::MIN + (1 << 9))]
    #[case(i64::MIN + (1 << 10))]
    #[case(i64::MIN + (1 << 11))]
    fn approx_roundtrip_signed(mut exec_ctx: TestContext, #[case] val: i64) {
        add_extensions(&mut exec_ctx);

        let hugr = roundtrip_hugr(val as u64, true);
        let result = exec_ctx.exec_hugr_u64(hugr, "main") as i64;
        // If val.abs() is too large the `trunc_s` op in `hugr` returns an error.
        // In this case the hugr returns the magic number `999`.
        // `abs_diff` cannot overflow, unlike `(val - result).abs()`.
        assert!(
            result == 999 || val.abs_diff(result) < 1 << 10,
            "{val} roundtripped to {result}"
        );
    }

    #[rstest]
    fn itobool_cond(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) {
        use hugr_core::type_row;

        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![usize_t()])
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap());
                let ext_op = EXTENSION.instantiate_extension_op("itobool", []).unwrap();
                let [b] = builder.add_dataflow_op(ext_op, [i]).unwrap().outputs_arr();
                let mut cond = builder
                    .conditional_builder(
                        ([type_row![], type_row![]], b),
                        [],
                        vec![usize_t()].into(),
                    )
                    .unwrap();
                let mut case_false = cond.case_builder(0).unwrap();
                let false_result = case_false.add_load_value(ConstUsize::new(1));
                case_false.finish_with_outputs([false_result]).unwrap();
                let mut case_true = cond.case_builder(1).unwrap();
                let true_result = case_true.add_load_value(ConstUsize::new(6));
                case_true.finish_with_outputs([true_result]).unwrap();
                let res = cond.finish_sub_container().unwrap();
                builder.finish_with_outputs(res.outputs()).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_int_extensions()
        });
        assert_eq!(i * 5 + 1, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    fn itobool_roundtrip(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![INT_TYPES[0].clone()])
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap());
                let i2b = EXTENSION.instantiate_extension_op("itobool", []).unwrap();
                let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr();
                let b2i = EXTENSION.instantiate_extension_op("ifrombool", []).unwrap();
                let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr();
                builder.finish_with_outputs([i]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_int_extensions()
        });
        assert_eq!(i, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(42.0)]
    #[case(f64::INFINITY)]
    #[case(f64::NEG_INFINITY)]
    #[case(f64::NAN)]
    #[case(-0.0f64)]
    #[case(0.0f64)]
    fn bytecast_int64_to_float64(mut exec_ctx: TestContext, #[case] f: f64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstInt::new_u(6, f.to_bits()).unwrap());
                let i2f = EXTENSION
                    .instantiate_extension_op("bytecast_int64_to_float64", [])
                    .unwrap();
                let [f] = builder.add_dataflow_op(i2f, [i]).unwrap().outputs_arr();
                builder.finish_with_outputs([f]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_int_extensions()
                .add_float_extensions()
        });
        let hugr_f = exec_ctx.exec_hugr_f64(hugr, "main");
        assert!((f.is_nan() && hugr_f.is_nan()) || f == hugr_f);
    }

    #[rstest]
    #[case(42.0)]
    #[case(-0.0f64)]
    #[case(0.0f64)]
    fn bytecast_float64_to_int64(mut exec_ctx: TestContext, #[case] f: f64) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(INT_TYPES[6].clone())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let f = builder.add_load_value(ConstF64::new(f));
                let f2i = EXTENSION
                    .instantiate_extension_op("bytecast_float64_to_int64", [])
                    .unwrap();
                let [i] = builder.add_dataflow_op(f2i, [f]).unwrap().outputs_arr();
                builder.finish_with_outputs([i]).unwrap()
            });
        exec_ctx.add_extensions(|builder| {
            builder
                .add_conversion_extensions()
                .add_default_prelude_extensions()
                .add_int_extensions()
                .add_float_extensions()
        });
        assert_eq!(f.to_bits(), exec_ctx.exec_hugr_u64(hugr, "main"));
    }
}