// hugr_llvm/extension/int.rs

1use hugr_core::{
2    extension::prelude::ConstError,
3    ops::{constant::CustomConst, ExtensionOp, NamedOp, Value},
4    std_extensions::arithmetic::{
5        int_ops::IntOpDef,
6        int_types::{self, ConstInt},
7    },
8    types::{CustomType, TypeArg},
9    HugrView, Node,
10};
11use inkwell::{
12    types::{BasicType, BasicTypeEnum, IntType},
13    values::{BasicValue, BasicValueEnum, IntValue},
14    IntPredicate,
15};
16
17use crate::{
18    custom::CodegenExtsBuilder,
19    emit::{
20        emit_value,
21        func::EmitFuncContext,
22        get_intrinsic,
23        libc::{emit_libc_abort, emit_libc_printf},
24        ops::{emit_custom_binary_op, emit_custom_unary_op},
25        EmitOpArgs,
26    },
27    sum::LLVMSumType,
28    types::{HugrSumType, TypingSession},
29};
30
31use anyhow::{anyhow, bail, Result};
32
33use super::conversions::int_type_bounds;
34
/// Failures that integer ops report as HUGR error values (wrapped in a sum
/// output) rather than by aborting the program.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RuntimeError {
    /// The input of an `inarrow_*` op does not fit in the target width.
    Narrow,
}

impl RuntimeError {
    /// The message carried by the `ConstError` emitted for this failure.
    ///
    /// Every message is a string literal, so it is returned with a `'static`
    /// lifetime instead of one tied to the borrow of `self`.
    fn show(&self) -> &'static str {
        match self {
            RuntimeError::Narrow => "Can't narrow into bounds",
        }
    }
}
46
47/// Emit an integer comparison operation.
48fn emit_icmp<'c, H: HugrView<Node = Node>>(
49    context: &mut EmitFuncContext<'c, '_, H>,
50    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
51    pred: inkwell::IntPredicate,
52) -> Result<()> {
53    let true_val = emit_value(context, &Value::true_val())?;
54    let false_val = emit_value(context, &Value::false_val())?;
55
56    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
57        // get result as an i1
58        let r = ctx.builder().build_int_compare(
59            pred,
60            lhs.into_int_value(),
61            rhs.into_int_value(),
62            "",
63        )?;
64        // convert to whatever bool_t is
65        Ok(vec![ctx
66            .builder()
67            .build_select(r, true_val, false_val, "")?])
68    })
69}
70
71/// Emit an ipow operation. This isn't directly supported in llvm, so we do a
72/// loop over the exponent, performing `imul`s instead.
73/// The insertion pointer is expected to be pointing to the end of `launch_bb`.
74fn emit_ipow<'c, H: HugrView<Node = Node>>(
75    context: &mut EmitFuncContext<'c, '_, H>,
76    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
77) -> Result<()> {
78    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
79        let done_bb = ctx.new_basic_block("done", None);
80        let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
81        let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
82        let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));
83
84        let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
85        let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
86        ctx.builder().build_store(acc_p, lhs)?;
87        ctx.builder().build_store(exp_p, rhs)?;
88        ctx.builder().build_unconditional_branch(pow_bb)?;
89
90        let zero = rhs.get_type().into_int_type().const_int(0, false);
91        // Assumes RHS type is the same as output type (which it should be)
92        let one = rhs.get_type().into_int_type().const_int(1, false);
93
94        // Block for just returning one
95        ctx.builder().position_at_end(return_one_bb);
96        ctx.builder().build_store(acc_p, one)?;
97        ctx.builder().build_unconditional_branch(done_bb)?;
98
99        ctx.builder().position_at_end(pow_bb);
100        let acc = ctx.builder().build_load(acc_p, "acc")?;
101        let exp = ctx.builder().build_load(exp_p, "exp")?;
102
103        // Special case if the exponent is 0 or 1
104        ctx.builder().build_switch(
105            exp.into_int_value(),
106            pow_body_bb,
107            &[(one, done_bb), (zero, return_one_bb)],
108        )?;
109
110        // Block that performs one `imul` and modifies the values in the store
111        ctx.builder().position_at_end(pow_body_bb);
112        let new_acc =
113            ctx.builder()
114                .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
115        let new_exp = ctx
116            .builder()
117            .build_int_sub(exp.into_int_value(), one, "new_exp")?;
118        ctx.builder().build_store(acc_p, new_acc)?;
119        ctx.builder().build_store(exp_p, new_exp)?;
120        ctx.builder().build_unconditional_branch(pow_bb)?;
121
122        ctx.builder().position_at_end(done_bb);
123        let result = ctx.builder().build_load(acc_p, "result")?;
124        Ok(vec![result.as_basic_value_enum()])
125    })
126}
127
128fn emit_int_op<'c, H: HugrView<Node = Node>>(
129    context: &mut EmitFuncContext<'c, '_, H>,
130    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
131    op: IntOpDef,
132) -> Result<()> {
133    match op {
134        IntOpDef::iadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
135            Ok(vec![ctx
136                .builder()
137                .build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
138                .as_basic_value_enum()])
139        }),
140        IntOpDef::imul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
141            Ok(vec![ctx
142                .builder()
143                .build_int_mul(lhs.into_int_value(), rhs.into_int_value(), "")?
144                .as_basic_value_enum()])
145        }),
146        IntOpDef::isub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
147            Ok(vec![ctx
148                .builder()
149                .build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
150                .as_basic_value_enum()])
151        }),
152        IntOpDef::idiv_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
153            Ok(vec![ctx
154                .builder()
155                .build_int_signed_div(lhs.into_int_value(), rhs.into_int_value(), "")?
156                .as_basic_value_enum()])
157        }),
158        IntOpDef::idiv_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
159            Ok(vec![ctx
160                .builder()
161                .build_int_unsigned_div(lhs.into_int_value(), rhs.into_int_value(), "")?
162                .as_basic_value_enum()])
163        }),
164        IntOpDef::imod_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
165            Ok(vec![ctx
166                .builder()
167                .build_int_signed_rem(lhs.into_int_value(), rhs.into_int_value(), "")?
168                .as_basic_value_enum()])
169        }),
170        IntOpDef::ineg => emit_custom_unary_op(context, args, |ctx, arg, _| {
171            Ok(vec![ctx
172                .builder()
173                .build_int_neg(arg.into_int_value(), "")?
174                .as_basic_value_enum()])
175        }),
176        IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
177            let intr = get_intrinsic(
178                ctx.get_current_module(),
179                "llvm.abs.i64",
180                [ctx.iw_context().i64_type().as_basic_type_enum()],
181            )?;
182            let true_ = ctx.iw_context().bool_type().const_all_ones();
183            let r = ctx
184                .builder()
185                .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
186                .try_as_basic_value()
187                .unwrap_left();
188            Ok(vec![r])
189        }),
190        IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
191            let intr = get_intrinsic(
192                ctx.get_current_module(),
193                "llvm.smax.i64",
194                [ctx.iw_context().i64_type().as_basic_type_enum()],
195            )?;
196            let r = ctx
197                .builder()
198                .build_call(
199                    intr,
200                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
201                    "",
202                )?
203                .try_as_basic_value()
204                .unwrap_left();
205            Ok(vec![r])
206        }),
207        IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
208            let intr = get_intrinsic(
209                ctx.get_current_module(),
210                "llvm.umax.i64",
211                [ctx.iw_context().i64_type().as_basic_type_enum()],
212            )?;
213            let r = ctx
214                .builder()
215                .build_call(
216                    intr,
217                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
218                    "",
219                )?
220                .try_as_basic_value()
221                .unwrap_left();
222            Ok(vec![r])
223        }),
224        IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
225            let intr = get_intrinsic(
226                ctx.get_current_module(),
227                "llvm.smin.i64",
228                [ctx.iw_context().i64_type().as_basic_type_enum()],
229            )?;
230            let r = ctx
231                .builder()
232                .build_call(
233                    intr,
234                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
235                    "",
236                )?
237                .try_as_basic_value()
238                .unwrap_left();
239            Ok(vec![r])
240        }),
241        IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
242            let intr = get_intrinsic(
243                ctx.get_current_module(),
244                "llvm.umin.i64",
245                [ctx.iw_context().i64_type().as_basic_type_enum()],
246            )?;
247            let r = ctx
248                .builder()
249                .build_call(
250                    intr,
251                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
252                    "",
253                )?
254                .try_as_basic_value()
255                .unwrap_left();
256            Ok(vec![r])
257        }),
258        IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
259            Ok(vec![ctx
260                .builder()
261                .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
262                .as_basic_value_enum()])
263        }),
264        IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
265            Ok(vec![ctx
266                .builder()
267                .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
268                .as_basic_value_enum()])
269        }),
270        IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
271        IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
272        IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
273        IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
274        IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
275        IntOpDef::ige_s => emit_icmp(context, args, inkwell::IntPredicate::SGE),
276        IntOpDef::ilt_u => emit_icmp(context, args, inkwell::IntPredicate::ULT),
277        IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
278        IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
279        IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
280        IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
281            Ok(vec![ctx
282                .builder()
283                .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
284                .as_basic_value_enum()])
285        }),
286        IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
287            Ok(vec![ctx
288                .builder()
289                .build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
290                .as_basic_value_enum()])
291        }),
292        IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
293            Ok(vec![ctx
294                .builder()
295                .build_not(arg.into_int_value(), "")?
296                .as_basic_value_enum()])
297        }),
298        IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
299            Ok(vec![ctx
300                .builder()
301                .build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
302                .as_basic_value_enum()])
303        }),
304        IntOpDef::ipow => emit_ipow(context, args),
305        // Type args are width of input, width of output
306        IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
307            let [out] = outs.try_into()?;
308            Ok(vec![ctx
309                .builder()
310                .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
311                .as_basic_value_enum()])
312        }),
313        IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
314            let [out] = outs.try_into()?;
315
316            Ok(vec![ctx
317                .builder()
318                .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
319                .as_basic_value_enum()])
320        }),
321        IntOpDef::inarrow_s => {
322            let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned()
323            else {
324                bail!("Type arg to inarrow_s wasn't a Nat");
325            };
326            let (_, out_ty) = args.node.out_value_types().next().unwrap();
327            emit_custom_unary_op(context, args, |ctx, arg, outs| {
328                let result = make_narrow(
329                    ctx,
330                    arg,
331                    outs,
332                    out_log_width,
333                    true,
334                    out_ty.as_sum().unwrap().clone(),
335                )?;
336                Ok(vec![result])
337            })
338        }
339        IntOpDef::inarrow_u => {
340            let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned()
341            else {
342                bail!("Type arg to inarrow_u wasn't a Nat");
343            };
344            let (_, out_ty) = args.node.out_value_types().next().unwrap();
345            emit_custom_unary_op(context, args, |ctx, arg, outs| {
346                let result = make_narrow(
347                    ctx,
348                    arg,
349                    outs,
350                    out_log_width,
351                    false,
352                    out_ty.as_sum().unwrap().clone(),
353                )?;
354                Ok(vec![result])
355            })
356        }
357        IntOpDef::iu_to_s => {
358            let [TypeArg::BoundedNat { n: log_width }] =
359                TryInto::<[TypeArg; 1]>::try_into(args.node.args().to_vec()).unwrap()
360            else {
361                bail!("Type argument to iu_to_s wasn't a number");
362            };
363            emit_custom_unary_op(context, args, |ctx, arg, _| {
364                let (_, max_val, _) = int_type_bounds(u32::pow(2, log_width as u32));
365                let max = arg
366                    .get_type()
367                    .into_int_type()
368                    .const_int(max_val as u64, false);
369
370                let within_bounds = ctx.builder().build_int_compare(
371                    IntPredicate::ULE,
372                    arg.into_int_value(),
373                    max,
374                    "bounds_check",
375                )?;
376
377                Ok(vec![val_or_panic(
378                    ctx,
379                    within_bounds,
380                    "iu_to_s argument out of bounds",
381                    arg,
382                )?])
383            })
384        }
385        IntOpDef::is_to_u => emit_custom_unary_op(context, args, |ctx, arg, _| {
386            let zero = arg.get_type().into_int_type().const_zero();
387
388            let within_bounds = ctx.builder().build_int_compare(
389                IntPredicate::SGE,
390                arg.into_int_value(),
391                zero,
392                "bounds_check",
393            )?;
394
395            Ok(vec![val_or_panic(
396                ctx,
397                within_bounds,
398                "is_to_u called on negative value",
399                arg,
400            )?])
401        }),
402        _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
403    }
404}
405
406fn make_narrow<'c, H: HugrView<Node = Node>>(
407    ctx: &mut EmitFuncContext<'c, '_, H>,
408    arg: BasicValueEnum<'c>,
409    outs: &[BasicTypeEnum<'c>],
410    out_log_width: u64,
411    signed: bool,
412    sum_type: HugrSumType,
413) -> Result<BasicValueEnum<'c>> {
414    let [out] = TryInto::<[BasicTypeEnum; 1]>::try_into(outs)?;
415    let width = 1 << out_log_width;
416    let arg_int_ty: IntType = arg.get_type().into_int_type();
417    let (int_min_value_s, int_max_value_s, int_max_value_u) = int_type_bounds(width);
418    let out_int_ty = out
419        .into_struct_type()
420        .get_field_type_at_index(2)
421        .unwrap()
422        .into_int_type();
423    let outside_range = if signed {
424        let too_big = ctx.builder().build_int_compare(
425            IntPredicate::SGT,
426            arg.into_int_value(),
427            arg_int_ty.const_int(int_max_value_s as u64, true),
428            "upper_bounds_check",
429        )?;
430        let too_small = ctx.builder().build_int_compare(
431            IntPredicate::SLT,
432            arg.into_int_value(),
433            arg_int_ty.const_int(int_min_value_s as u64, true),
434            "lower_bounds_check",
435        )?;
436        ctx.builder()
437            .build_or(too_big, too_small, "outside_range")?
438    } else {
439        ctx.builder().build_int_compare(
440            IntPredicate::UGT,
441            arg.into_int_value(),
442            arg_int_ty.const_int(int_max_value_u, false),
443            "upper_bounds_check",
444        )?
445    };
446
447    let narrowed_val = ctx
448        .builder()
449        .build_int_cast_sign_flag(arg.into_int_value(), out_int_ty, signed, "")?
450        .as_basic_value_enum();
451    val_or_error(
452        ctx,
453        outside_range,
454        narrowed_val,
455        RuntimeError::Narrow,
456        LLVMSumType::try_from_hugr_type(&ctx.typing_session(), sum_type).unwrap(),
457    )
458}
459
460fn val_or_panic<'c, H: HugrView<Node = Node>>(
461    ctx: &mut EmitFuncContext<'c, '_, H>,
462    dont_panic: IntValue<'c>,
463    err_msg_str: &str,
464    val: BasicValueEnum<'c>, // Must be same int type as `dont_panic`
465) -> Result<BasicValueEnum<'c>> {
466    let done_bb = ctx.new_basic_block("done", None);
467    let exit_bb = ctx.new_basic_block("exit", Some(done_bb));
468    let go_bb = ctx.new_basic_block("panic_if_0", Some(exit_bb));
469    let panic_bb = ctx.new_basic_block("panic", Some(exit_bb));
470    ctx.builder().build_unconditional_branch(go_bb)?;
471
472    ctx.builder().position_at_end(exit_bb);
473    ctx.builder().build_return(Some(&val))?;
474
475    ctx.builder().position_at_end(panic_bb);
476    let err_msg = ctx
477        .builder()
478        .build_global_string_ptr(err_msg_str, "err_msg")?
479        .as_basic_value_enum();
480    emit_libc_printf(ctx, &[err_msg.into()])?;
481    emit_libc_abort(ctx)?;
482    ctx.builder().build_unconditional_branch(exit_bb)?;
483
484    ctx.builder().position_at_end(go_bb);
485    ctx.builder().build_switch(
486        dont_panic,
487        panic_bb,
488        &[(dont_panic.get_type().const_int(1, false), exit_bb)],
489    )?;
490
491    ctx.builder().position_at_end(done_bb);
492
493    Ok(val) // Returning val should be nonsense if we panic
494}
495
496fn val_or_error<'c, H: HugrView<Node = Node>>(
497    ctx: &mut EmitFuncContext<'c, '_, H>,
498    should_fail: IntValue<'c>,
499    val: BasicValueEnum<'c>,
500    msg: RuntimeError,
501    ty: LLVMSumType<'c>,
502) -> Result<BasicValueEnum<'c>> {
503    let err_msg = Value::extension(ConstError::new(2, msg.show()));
504    let err_val = emit_value(ctx, &err_msg)?;
505
506    let err_variant = ty.build_tag(ctx.builder(), 0, vec![err_val])?;
507    let ok_variant = ty.build_tag(ctx.builder(), 1, vec![val])?;
508
509    Ok(ctx
510        .builder()
511        .build_select(should_fail, err_variant, ok_variant, "")?)
512}
513
514fn llvm_type<'c>(
515    context: TypingSession<'c, '_>,
516    hugr_type: &CustomType,
517) -> Result<BasicTypeEnum<'c>> {
518    if let [TypeArg::BoundedNat { n }] = hugr_type.args() {
519        let m = *n as usize;
520        if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() {
521            return Ok(match m {
522                0..=3 => context.iw_context().i8_type(),
523                4 => context.iw_context().i16_type(),
524                5 => context.iw_context().i32_type(),
525                6 => context.iw_context().i64_type(),
526                _ => Err(anyhow!(
527                    "IntTypesCodegenExtension: unsupported log_width: {}",
528                    m
529                ))?,
530            }
531            .into());
532        }
533    }
534    Err(anyhow!(
535        "IntTypesCodegenExtension: unsupported type: {}",
536        hugr_type
537    ))
538}
539
540fn emit_const_int<'c, H: HugrView<Node = Node>>(
541    context: &mut EmitFuncContext<'c, '_, H>,
542    k: &ConstInt,
543) -> Result<BasicValueEnum<'c>> {
544    let ty: IntType = context.llvm_type(&k.get_type())?.try_into().unwrap();
545    // k.value_u() is in two's complement representation of the exactly
546    // correct bit width, so we are safe to unconditionally retrieve the
547    // unsigned value and do no sign extension.
548    Ok(ty.const_int(k.value_u(), false).as_basic_value_enum())
549}
550
551/// Populates a [CodegenExtsBuilder] with all extensions needed to lower int
552/// ops, types, and constants.
553pub fn add_int_extensions<'a, H: HugrView<Node = Node> + 'a>(
554    cem: CodegenExtsBuilder<'a, H>,
555) -> CodegenExtsBuilder<'a, H> {
556    cem.custom_const(emit_const_int)
557        .custom_type((int_types::EXTENSION_ID, "int".into()), llvm_type)
558        .simple_extension_op::<IntOpDef>(emit_int_op)
559}
560
561impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
562    /// Populates a [CodegenExtsBuilder] with all extensions needed to lower int
563    /// ops, types, and constants.
564    pub fn add_int_extensions(self) -> Self {
565        add_int_extensions(self)
566    }
567}
568
569#[cfg(test)]
570mod test {
571    use anyhow::Result;
572    use hugr_core::extension::prelude::{error_type, ConstError, UnwrapBuilder};
573    use hugr_core::std_extensions::STD_REG;
574    use hugr_core::{
575        builder::{handle::Outputs, Dataflow, DataflowSubContainer, SubContainer},
576        extension::prelude::bool_t,
577        ops::{DataflowOpTrait, ExtensionOp, NamedOp},
578        std_extensions::arithmetic::{
579            int_ops::{self, IntOpDef},
580            int_types::{ConstInt, INT_TYPES},
581        },
582        types::{SumType, Type, TypeRow},
583        Hugr,
584    };
585    use rstest::rstest;
586
587    use crate::extension::DefaultPreludeCodegen;
588    use crate::{
589        check_emission,
590        emit::test::{SimpleHugrConfig, DFGW},
591        extension::{int::add_int_extensions, prelude::add_prelude_extensions},
592        test::{exec_ctx, llvm_ctx, single_op_hugr, TestContext},
593    };
594
595    // Instantiate an extension op which takes one width argument
596    fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
597        int_ops::EXTENSION
598            .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()])
599            .unwrap()
600    }
601
602    fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
603        let ty = &INT_TYPES[log_width as usize];
604        test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
605    }
606
607    fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
608        test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
609    }
610
611    fn test_int_op_with_results<const N: usize>(
612        ext_op: ExtensionOp,
613        log_width: u8,
614        inputs: Option<[ConstInt; N]>,
615        output_type: Type,
616    ) -> Hugr {
617        test_int_op_with_results_processing(ext_op, log_width, inputs, output_type, |_, a| Ok(a))
618    }
619
620    fn test_int_op_with_results_processing<const N: usize>(
621        // N is the number of inputs to the hugr
622        ext_op: ExtensionOp,
623        log_width: u8,
624        inputs: Option<[ConstInt; N]>, // If inputs are provided, they'll be wired into the op, otherwise the inputs to the hugr will be wired into the op
625        output_type: Type,
626        process: impl Fn(&mut DFGW, Outputs) -> Result<Outputs>,
627    ) -> Hugr {
628        let ty = &INT_TYPES[log_width as usize];
629        let input_tys = if inputs.is_some() {
630            vec![]
631        } else {
632            let input_tys = itertools::repeat_n(ty.clone(), N).collect();
633            assert_eq!(input_tys, ext_op.signature().input.to_vec());
634            input_tys
635        };
636        SimpleHugrConfig::new()
637            .with_ins(input_tys)
638            .with_outs(vec![output_type])
639            .with_extensions(STD_REG.clone())
640            .finish(|mut hugr_builder| {
641                let input_wires = match inputs {
642                    None => hugr_builder.input_wires_arr::<N>().to_vec(),
643                    Some(inputs) => {
644                        let mut input_wires = Vec::new();
645                        inputs.into_iter().for_each(|i| {
646                            let w = hugr_builder.add_load_value(i);
647                            input_wires.push(w);
648                        });
649                        input_wires
650                    }
651                };
652                let outputs = hugr_builder
653                    .add_dataflow_op(ext_op, input_wires)
654                    .unwrap()
655                    .outputs();
656                let processed_outputs = process(&mut hugr_builder, outputs).unwrap();
657                hugr_builder.finish_with_outputs(processed_outputs).unwrap()
658            })
659    }
660
661    #[rstest]
662    #[case(IntOpDef::iu_to_s, &[3])]
663    #[case(IntOpDef::is_to_u, &[3])]
664    #[case(IntOpDef::ineg, &[2])]
665    fn test_emission(mut llvm_ctx: TestContext, #[case] op: IntOpDef, #[case] args: &[u8]) {
666        llvm_ctx.add_extensions(add_int_extensions);
667        let mut insta = insta::Settings::clone_current();
668        insta.set_snapshot_suffix(format!(
669            "{}_{}_{:?}",
670            insta.snapshot_suffix().unwrap_or(""),
671            op.name(),
672            args,
673        ));
674        let concrete = match *args {
675            [] => op.without_log_width(),
676            [log_width] => op.with_log_width(log_width),
677            [lw1, lw2] => op.with_two_log_widths(lw1, lw2),
678            _ => panic!("unexpected number of args to the op!"),
679        };
680        insta.bind(|| {
681            let hugr = single_op_hugr(concrete.into());
682            check_emission!(hugr, llvm_ctx);
683        })
684    }
685
686    #[rstest]
687    #[case::iadd("iadd", 3)]
688    #[case::isub("isub", 6)]
689    #[case::ipow("ipow", 3)]
690    fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
691        llvm_ctx.add_extensions(add_int_extensions);
692        let ext_op = make_int_op(op.clone(), width);
693        let hugr = test_binary_int_op(ext_op, width);
694        check_emission!(op.clone(), hugr, llvm_ctx);
695    }
696
697    #[rstest]
698    #[case::signed_2_3("iwiden_s", 2, 3)]
699    #[case::signed_1_6("iwiden_s", 1, 6)]
700    #[case::unsigned_2_3("iwiden_u", 2, 3)]
701    #[case::unsigned_1_6("iwiden_u", 1, 6)]
702    fn test_widen_emission(
703        mut llvm_ctx: TestContext,
704        #[case] op: String,
705        #[case] from: u8,
706        #[case] to: u8,
707    ) {
708        llvm_ctx.add_extensions(add_int_extensions);
709        let out_ty = INT_TYPES[to as usize].clone();
710        let ext_op = int_ops::EXTENSION
711            .instantiate_extension_op(&op, [(from as u64).into(), (to as u64).into()])
712            .unwrap();
713        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty);
714
715        check_emission!(format!("{}_{}_{}", op.clone(), from, to), hugr, llvm_ctx);
716    }
717
718    #[rstest]
719    #[case::signed("inarrow_s", 3, 2)]
720    #[case::unsigned("inarrow_u", 6, 4)]
721    fn test_narrow_emission(
722        mut llvm_ctx: TestContext,
723        #[case] op: String,
724        #[case] from: u8,
725        #[case] to: u8,
726    ) {
727        llvm_ctx.add_extensions(add_int_extensions);
728        llvm_ctx.add_extensions(|cem| add_prelude_extensions(cem, DefaultPreludeCodegen));
729        let out_ty = SumType::new([vec![error_type()], vec![INT_TYPES[to as usize].clone()]]);
730        let ext_op = int_ops::EXTENSION
731            .instantiate_extension_op(&op, [(from as u64).into(), (to as u64).into()])
732            .unwrap();
733        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());
734
735        check_emission!(format!("{}_{}_{}", op.clone(), from, to), hugr, llvm_ctx);
736    }
737
738    #[rstest]
739    #[case::ieq("ieq", 1)]
740    #[case::ilt_s("ilt_s", 0)]
741    fn test_cmp_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
742        llvm_ctx.add_extensions(add_int_extensions);
743        let ext_op = make_int_op(op.clone(), width);
744        let hugr = test_binary_icmp_op(ext_op, width);
745        check_emission!(op.clone(), hugr, llvm_ctx);
746    }
747
748    #[rstest]
749    #[case::imax("imax_u", 1, 2, 2)]
750    #[case::imax("imax_u", 2, 1, 2)]
751    #[case::imax("imax_u", 2, 2, 2)]
752    #[case::imin("imin_u", 1, 2, 1)]
753    #[case::imin("imin_u", 2, 1, 1)]
754    #[case::imin("imin_u", 2, 2, 2)]
755    #[case::ishl("ishl", 73, 1, 146)]
756    // (2^64 - 1) << 1 = (2^64 - 2)
757    #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
758    #[case::ishr("ishr", 73, 1, 36)]
759    #[case::ior("ior", 6, 9, 15)]
760    #[case::ior("ior", 6, 15, 15)]
761    #[case::ixor("ixor", 6, 9, 15)]
762    #[case::ixor("ixor", 6, 15, 9)]
763    #[case::ixor("ixor", 15, 6, 9)]
764    #[case::iand("iand", 6, 15, 6)]
765    #[case::iand("iand", 15, 6, 6)]
766    #[case::iand("iand", 15, 15, 15)]
767    #[case::ipow("ipow", 2, 3, 8)]
768    #[case::ipow("ipow", 42, 1, 42)]
769    #[case::ipow("ipow", 42, 0, 1)]
770    fn test_exec_unsigned_bin_op(
771        mut exec_ctx: TestContext,
772        #[case] op: String,
773        #[case] lhs: u64,
774        #[case] rhs: u64,
775        #[case] result: u64,
776    ) {
777        exec_ctx.add_extensions(add_int_extensions);
778        let ty = &INT_TYPES[6].clone();
779        let inputs = [
780            ConstInt::new_u(6, lhs).unwrap(),
781            ConstInt::new_u(6, rhs).unwrap(),
782        ];
783        let ext_op = make_int_op(&op, 6);
784
785        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
786        assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
787    }
788
789    #[rstest]
790    #[case::imax("imax_s", 1, 2, 2)]
791    #[case::imax("imax_s", 2, 1, 2)]
792    #[case::imax("imax_s", 2, 2, 2)]
793    #[case::imax("imax_s", -1, -2, -1)]
794    #[case::imax("imax_s", -2, -1, -1)]
795    #[case::imax("imax_s", -2, -2, -2)]
796    #[case::imin("imin_s", 1, 2, 1)]
797    #[case::imin("imin_s", 2, 1, 1)]
798    #[case::imin("imin_s", 2, 2, 2)]
799    #[case::imin("imin_s", -1, -2, -2)]
800    #[case::imin("imin_s", -2, -1, -2)]
801    #[case::imin("imin_s", -2, -2, -2)]
802    #[case::ipow("ipow", -2, 1, -2)]
803    #[case::ipow("ipow", -2, 2, 4)]
804    #[case::ipow("ipow", -2, 3, -8)]
805    fn test_exec_signed_bin_op(
806        mut exec_ctx: TestContext,
807        #[case] op: String,
808        #[case] lhs: i64,
809        #[case] rhs: i64,
810        #[case] result: i64,
811    ) {
812        exec_ctx.add_extensions(add_int_extensions);
813        let ty = &INT_TYPES[6].clone();
814        let inputs = [
815            ConstInt::new_s(6, lhs).unwrap(),
816            ConstInt::new_s(6, rhs).unwrap(),
817        ];
818        let ext_op = make_int_op(&op, 6);
819
820        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
821        assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
822    }
823
824    #[rstest]
825    #[case::iabs("iabs", 42, 42)]
826    #[case::iabs("iabs", -42, 42)]
827    fn test_exec_signed_unary_op(
828        mut exec_ctx: TestContext,
829        #[case] op: String,
830        #[case] arg: i64,
831        #[case] result: i64,
832    ) {
833        exec_ctx.add_extensions(add_int_extensions);
834        let input = ConstInt::new_s(6, arg).unwrap();
835        let ty = INT_TYPES[6].clone();
836        let ext_op = make_int_op(&op, 6);
837
838        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
839        assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
840    }
841
842    #[rstest]
843    #[case::inot("inot", 9223372036854775808, !9223372036854775808u64)]
844    #[case::inot("inot", 42, !42u64)]
845    #[case::inot("inot", !0u64, 0)]
846    fn test_exec_unsigned_unary_op(
847        mut exec_ctx: TestContext,
848        #[case] op: String,
849        #[case] arg: u64,
850        #[case] result: u64,
851    ) {
852        exec_ctx.add_extensions(add_int_extensions);
853        let input = ConstInt::new_u(6, arg).unwrap();
854        let ty = INT_TYPES[6].clone();
855        let ext_op = make_int_op(&op, 6);
856
857        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
858        assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
859    }
860
861    #[rstest]
862    #[case("inarrow_s", 6, 2, 4)]
863    #[case("inarrow_s", 6, 5, (1 << 5) - 1)]
864    #[case("inarrow_s", 6, 4, -1)]
865    #[case("inarrow_s", 6, 4, -(1 << 4) - 1)]
866    #[case("inarrow_s", 6, 4, -(1 <<15))]
867    #[case("inarrow_s", 6, 5, (1 << 31) - 1)]
868    fn test_narrow_s(
869        mut exec_ctx: TestContext,
870        #[case] op: String,
871        #[case] from: u8,
872        #[case] to: u8,
873        #[case] arg: i64,
874    ) {
875        exec_ctx.add_extensions(add_int_extensions);
876        exec_ctx.add_extensions(|cem| add_prelude_extensions(cem, DefaultPreludeCodegen));
877        let input = ConstInt::new_s(from, arg).unwrap();
878        let to_ty = INT_TYPES[to as usize].clone();
879        let ext_op = int_ops::EXTENSION
880            .instantiate_extension_op(op.as_ref(), [(from as u64).into(), (to as u64).into()])
881            .unwrap();
882
883        let hugr = test_int_op_with_results_processing::<1>(
884            ext_op,
885            to,
886            Some([input]),
887            to_ty.clone(),
888            |builder, outs| {
889                let [out] = outs.to_array();
890
891                let err_row = TypeRow::from(vec![error_type()]);
892                let ty_row = TypeRow::from(vec![to_ty.clone()]);
893                // Handle the sum type returned by narrow by building a conditional.
894                // We're only testing the happy path here, so insert a panic in the
895                // "error" branch, knowing that it wont come up.
896                //
897                // Negative results can be tested manually, but we lack the testing
898                // infrastructure to detect execution crashes without crashing the
899                // test process.
900                let mut cond_b = builder.conditional_builder(
901                    ([err_row, ty_row], out),
902                    [],
903                    vec![to_ty.clone()].into(),
904                )?;
905                let mut sad_b = cond_b.case_builder(0)?;
906                let err = ConstError::new(2, "This shouldn't happen");
907                let w = sad_b.add_load_value(ConstInt::new_s(to, 0)?);
908                sad_b.add_panic(err, vec![to_ty.clone()], [(w, to_ty.clone())])?;
909                sad_b.finish_with_outputs([w])?;
910
911                let happy_b = cond_b.case_builder(1)?;
912                let [w] = happy_b.input_wires_arr();
913                happy_b.finish_with_outputs([w])?;
914
915                let handle = cond_b.finish_sub_container()?;
916                Ok(handle.outputs())
917            },
918        );
919        assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), arg);
920    }
921
922    #[rstest]
923    #[case(6, 42)]
924    #[case(4, 7)]
925    //#[case(4, 256)] -- crashes because a panic is emitted (good)
926    fn test_u_to_s(mut exec_ctx: TestContext, #[case] log_width: u8, #[case] val: u64) {
927        exec_ctx.add_extensions(add_int_extensions);
928        let ty = &INT_TYPES[log_width as usize].clone();
929        let hugr = SimpleHugrConfig::new()
930            .with_outs(vec![ty.clone()])
931            .with_extensions(STD_REG.clone())
932            .finish(|mut hugr_builder| {
933                let unsigned =
934                    hugr_builder.add_load_value(ConstInt::new_u(log_width, val).unwrap());
935                let iu_to_s = make_int_op("iu_to_s", log_width);
936                let [signed] = hugr_builder
937                    .add_dataflow_op(iu_to_s, [unsigned])
938                    .unwrap()
939                    .outputs_arr();
940                hugr_builder.finish_with_outputs([signed]).unwrap()
941            });
942        let act = exec_ctx.exec_hugr_i64(hugr, "main");
943        assert_eq!(act, val as i64);
944    }
945
946    #[rstest]
947    #[case(3, 0)]
948    #[case(4, 255)]
949    // #[case(3, -1)] -- crashes because a panic is emitted (good)
950    fn test_s_to_u(mut exec_ctx: TestContext, #[case] log_width: u8, #[case] val: i64) {
951        exec_ctx.add_extensions(add_int_extensions);
952        let ty = &INT_TYPES[log_width as usize].clone();
953        let hugr = SimpleHugrConfig::new()
954            .with_outs(vec![ty.clone()])
955            .with_extensions(STD_REG.clone())
956            .finish(|mut hugr_builder| {
957                let signed = hugr_builder.add_load_value(ConstInt::new_s(log_width, val).unwrap());
958                let is_to_u = make_int_op("is_to_u", log_width);
959                let [unsigned] = hugr_builder
960                    .add_dataflow_op(is_to_u, [signed])
961                    .unwrap()
962                    .outputs_arr();
963                hugr_builder.finish_with_outputs([unsigned]).unwrap()
964            });
965        let act = exec_ctx.exec_hugr_u64(hugr, "main");
966        assert_eq!(act, val as u64);
967    }
968}