hugr_llvm/extension/
int.rs

1use hugr_core::{
2    HugrView, Node,
3    extension::{
4        prelude::{ConstError, sum_with_error},
5        simple_op::MakeExtensionOp,
6    },
7    ops::{ExtensionOp, Value, constant::CustomConst},
8    std_extensions::arithmetic::{
9        int_ops::IntOpDef,
10        int_types::{self, ConstInt},
11    },
12    types::{CustomType, Type, TypeArg},
13};
14use inkwell::{
15    IntPredicate,
16    types::{BasicType, BasicTypeEnum, IntType},
17    values::{BasicValue, BasicValueEnum, IntValue},
18};
19use lazy_static::lazy_static;
20
21use crate::{
22    CodegenExtension,
23    custom::CodegenExtsBuilder,
24    emit::{
25        EmitOpArgs, emit_value,
26        func::EmitFuncContext,
27        get_intrinsic,
28        ops::{emit_custom_binary_op, emit_custom_unary_op},
29    },
30    sum::{LLVMSumType, LLVMSumValue},
31    types::{HugrSumType, TypingSession},
32};
33
34use anyhow::{Result, anyhow, bail};
35
36use super::{DefaultPreludeCodegen, PreludeCodegen, conversions::int_type_bounds};
37
38#[derive(Clone, Debug, Default)]
39pub struct IntCodegenExtension<PCG>(PCG);
40
41impl<PCG: PreludeCodegen> IntCodegenExtension<PCG> {
42    pub fn new(ccg: PCG) -> Self {
43        Self(ccg)
44    }
45}
46
47impl<CCG: PreludeCodegen> From<CCG> for IntCodegenExtension<CCG> {
48    fn from(ccg: CCG) -> Self {
49        Self::new(ccg)
50    }
51}
52
53impl<CCG: PreludeCodegen> CodegenExtension for IntCodegenExtension<CCG> {
54    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
55        self,
56        builder: CodegenExtsBuilder<'a, H>,
57    ) -> CodegenExtsBuilder<'a, H>
58    where
59        Self: 'a,
60    {
61        builder
62            .custom_const(emit_const_int)
63            .custom_type((int_types::EXTENSION_ID, "int".into()), llvm_type)
64            .simple_extension_op::<IntOpDef>(move |context, args, op| {
65                emit_int_op(context, &self.0, args, op)
66            })
67    }
68}
69
70lazy_static! {
71    static ref ERR_NARROW: ConstError = ConstError {
72        signal: 2,
73        message: "Can't narrow into bounds".to_string(),
74    };
75    static ref ERR_IU_TO_S: ConstError = ConstError {
76        signal: 2,
77        message: "iu_to_s argument out of bounds".to_string(),
78    };
79    static ref ERR_IS_TO_U: ConstError = ConstError {
80        signal: 2,
81        message: "is_to_u called on negative value".to_string(),
82    };
83    static ref ERR_DIV_0: ConstError = ConstError {
84        signal: 2,
85        message: "Attempted division by 0".to_string(),
86    };
87}
88
/// Which result(s) of an integer division a [`DivModOp`] produces.
#[derive(PartialEq, Eq, Debug)]
enum DivOrMod {
    /// Only the quotient.
    Div,
    /// Only the remainder.
    Mod,
    /// Both the quotient and the remainder.
    DivMod,
}

/// Parameters describing one of the integer division/modulus operations.
struct DivModOp {
    /// Which part(s) of the division result to produce.
    op: DivOrMod,
    /// Whether the dividend (and quotient) are treated as signed.
    signed: bool,
    /// `true`: the emitted divmod panics on failure and yields plain values;
    /// `false`: results are wrapped in an `err + value` sum.
    panic: bool,
}
101
102impl DivModOp {
103    fn emit<'c, H: HugrView<Node = Node>>(
104        self,
105        ctx: &mut EmitFuncContext<'c, '_, H>,
106        pcg: &impl PreludeCodegen,
107        log_width: u64,
108        numerator: IntValue<'c>,
109        denominator: IntValue<'c>,
110    ) -> Result<Vec<BasicValueEnum<'c>>> {
111        // Hugr semantics say that div and mod are equivalent to doing a divmod,
112        // then projecting out an element from the pair, so that's what we do.
113        let quotrem = make_divmod(
114            ctx,
115            pcg,
116            log_width,
117            numerator,
118            denominator,
119            self.panic,
120            self.signed,
121        )?;
122
123        if self.op == DivOrMod::DivMod {
124            if self.panic {
125                // Unpack the tuple into two values.
126                Ok(quotrem.build_untag(ctx.builder(), 0).unwrap())
127            } else {
128                Ok(vec![quotrem.as_basic_value_enum()])
129            }
130        } else {
131            // Which field we should project out from the result of divmod.
132            let index = match self.op {
133                DivOrMod::Div => 0,
134                DivOrMod::Mod => 1,
135                _ => unreachable!(),
136            };
137            // If we emitted a panicking divmod, the result is just a tuple type.
138            if self.panic {
139                Ok(vec![
140                    quotrem
141                        .build_untag(ctx.builder(), 0)?
142                        .into_iter()
143                        .nth(index)
144                        .unwrap(),
145                ])
146            }
147            // Otherwise, we have a sum type `err + [int,int]`, which we need to
148            // turn into a `err + int`.
149            else {
150                // Get the data out the the divmod result.
151                let int_ty = numerator.get_type().as_basic_type_enum();
152                let tuple_ty =
153                    LLVMSumType::try_new(ctx.iw_context(), vec![vec![int_ty, int_ty]]).unwrap();
154                let tuple = quotrem
155                    .build_untag(ctx.builder(), 1)?
156                    .into_iter()
157                    .next()
158                    .unwrap();
159                let tuple_val = LLVMSumValue::try_new(tuple, tuple_ty)?;
160                let data_val = tuple_val
161                    .build_untag(ctx.builder(), 0)?
162                    .into_iter()
163                    .nth(index)
164                    .unwrap();
165                let err_val = quotrem
166                    .build_untag(ctx.builder(), 0)?
167                    .into_iter()
168                    .next()
169                    .unwrap();
170
171                let tag_val = quotrem.build_get_tag(ctx.builder())?;
172                tag_val.set_name("tag");
173
174                // Build a new struct with the old tag and error data.
175                let int_ty = int_types::INT_TYPES[log_width as usize].clone();
176                let out_ty = LLVMSumType::try_from_hugr_type(
177                    &ctx.typing_session(),
178                    sum_with_error(vec![int_ty.clone()]),
179                )
180                .unwrap();
181
182                let data_variant = out_ty.build_tag(ctx.builder(), 1, vec![data_val])?;
183                data_variant.set_name("data_variant");
184                let err_variant = out_ty.build_tag(ctx.builder(), 0, vec![err_val])?;
185                err_variant.set_name("err_variant");
186
187                let result = ctx
188                    .builder()
189                    .build_select(tag_val, data_variant, err_variant, "")?;
190                Ok(vec![result])
191            }
192        }
193    }
194}
195
196/// `ConstError` an integer comparison operation.
197fn emit_icmp<'c, H: HugrView<Node = Node>>(
198    context: &mut EmitFuncContext<'c, '_, H>,
199    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
200    pred: inkwell::IntPredicate,
201) -> Result<()> {
202    let true_val = emit_value(context, &Value::true_val())?;
203    let false_val = emit_value(context, &Value::false_val())?;
204
205    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
206        // get result as an i1
207        let r = ctx.builder().build_int_compare(
208            pred,
209            lhs.into_int_value(),
210            rhs.into_int_value(),
211            "",
212        )?;
213        // convert to whatever bool_t is
214        Ok(vec![
215            ctx.builder().build_select(r, true_val, false_val, "")?,
216        ])
217    })
218}
219
/// Emit an ipow operation. This isn't directly supported in llvm, so we do a
/// loop over the exponent, performing `imul`s instead.
///
/// The emitted control flow is:
/// * the current block stores `lhs` into an accumulator stack slot and `rhs`
///   into an exponent stack slot, then branches to `pow`;
/// * `pow` reloads both slots and switches on the exponent: `1` goes straight
///   to `done` (the accumulator already holds the answer), `0` goes to
///   `power_of_zero`, and anything else goes to `pow_body`;
/// * `power_of_zero` overwrites the accumulator with `1` and branches to `done`;
/// * `pow_body` multiplies the accumulator by `lhs`, decrements the exponent
///   by one and loops back to `pow`;
/// * `done` loads the accumulator as the result.
///
/// Emission starts at the builder's current insertion point and leaves the
/// builder positioned at the end of `done`.
///
/// NOTE(review): the exponent is decremented by one per iteration, so the
/// number of loop iterations grows linearly with the exponent's value.
fn emit_ipow<'c, H: HugrView<Node = Node>>(
    context: &mut EmitFuncContext<'c, '_, H>,
    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
        let done_bb = ctx.new_basic_block("done", None);
        let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
        let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
        let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));

        // Mutable loop state lives in two stack slots: the running product
        // (seeded with `lhs`) and the remaining exponent (seeded with `rhs`).
        let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
        let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
        ctx.builder().build_store(acc_p, lhs)?;
        ctx.builder().build_store(exp_p, rhs)?;
        ctx.builder().build_unconditional_branch(pow_bb)?;

        let zero = rhs.get_type().into_int_type().const_int(0, false);
        // Assumes RHS type is the same as output type (which it should be)
        let one = rhs.get_type().into_int_type().const_int(1, false);

        // Block for just returning one
        ctx.builder().position_at_end(return_one_bb);
        ctx.builder().build_store(acc_p, one)?;
        ctx.builder().build_unconditional_branch(done_bb)?;

        // Loop header: reload the current state on every iteration.
        ctx.builder().position_at_end(pow_bb);
        let acc = ctx.builder().build_load(acc_p, "acc")?;
        let exp = ctx.builder().build_load(exp_p, "exp")?;

        // Special case if the exponent is 0 or 1
        ctx.builder().build_switch(
            exp.into_int_value(),
            pow_body_bb,
            &[(one, done_bb), (zero, return_one_bb)],
        )?;

        // Block that performs one `imul` and modifies the values in the store
        ctx.builder().position_at_end(pow_body_bb);
        let new_acc =
            ctx.builder()
                .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
        let new_exp = ctx
            .builder()
            .build_int_sub(exp.into_int_value(), one, "new_exp")?;
        ctx.builder().build_store(acc_p, new_acc)?;
        ctx.builder().build_store(exp_p, new_exp)?;
        ctx.builder().build_unconditional_branch(pow_bb)?;

        // Exit: the accumulator slot holds the final power.
        ctx.builder().position_at_end(done_bb);
        let result = ctx.builder().build_load(acc_p, "result")?;
        Ok(vec![result.as_basic_value_enum()])
    })
}
276
277fn emit_int_op<'c, H: HugrView<Node = Node>>(
278    context: &mut EmitFuncContext<'c, '_, H>,
279    pcg: &impl PreludeCodegen,
280    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
281    op: IntOpDef,
282) -> Result<()> {
283    match op {
284        IntOpDef::iadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
285            Ok(vec![
286                ctx.builder()
287                    .build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
288                    .as_basic_value_enum(),
289            ])
290        }),
291        IntOpDef::imul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
292            Ok(vec![
293                ctx.builder()
294                    .build_int_mul(lhs.into_int_value(), rhs.into_int_value(), "")?
295                    .as_basic_value_enum(),
296            ])
297        }),
298        IntOpDef::isub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
299            Ok(vec![
300                ctx.builder()
301                    .build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
302                    .as_basic_value_enum(),
303            ])
304        }),
305        IntOpDef::idiv_s => {
306            let log_width = get_width_arg(&args, &op)?;
307            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
308                let op = DivModOp {
309                    op: DivOrMod::Div,
310                    signed: true,
311                    panic: true,
312                };
313                op.emit(
314                    ctx,
315                    pcg,
316                    log_width,
317                    lhs.into_int_value(),
318                    rhs.into_int_value(),
319                )
320            })
321        }
322        IntOpDef::idiv_u => {
323            let log_width = get_width_arg(&args, &op)?;
324            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
325                let op = DivModOp {
326                    op: DivOrMod::Div,
327                    signed: false,
328                    panic: true,
329                };
330                op.emit(
331                    ctx,
332                    pcg,
333                    log_width,
334                    lhs.into_int_value(),
335                    rhs.into_int_value(),
336                )
337            })
338        }
339        IntOpDef::imod_s => {
340            let log_width = get_width_arg(&args, &op)?;
341            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
342                let op = DivModOp {
343                    op: DivOrMod::Mod,
344                    signed: true,
345                    panic: true,
346                };
347                op.emit(
348                    ctx,
349                    pcg,
350                    log_width,
351                    lhs.into_int_value(),
352                    rhs.into_int_value(),
353                )
354            })
355        }
356        IntOpDef::imod_u => {
357            let log_width = get_width_arg(&args, &op)?;
358            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
359                let op = DivModOp {
360                    op: DivOrMod::Mod,
361                    signed: false,
362                    panic: true,
363                };
364                op.emit(
365                    ctx,
366                    pcg,
367                    log_width,
368                    lhs.into_int_value(),
369                    rhs.into_int_value(),
370                )
371            })
372        }
373        IntOpDef::idivmod_u => {
374            let log_width = get_width_arg(&args, &op)?;
375            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
376                let op = DivModOp {
377                    op: DivOrMod::DivMod,
378                    signed: false,
379                    panic: true,
380                };
381                op.emit(
382                    ctx,
383                    pcg,
384                    log_width,
385                    lhs.into_int_value(),
386                    rhs.into_int_value(),
387                )
388            })
389        }
390        IntOpDef::idivmod_s => {
391            let log_width = get_width_arg(&args, &op)?;
392            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
393                let op = DivModOp {
394                    op: DivOrMod::DivMod,
395                    signed: true,
396                    panic: true,
397                };
398                op.emit(
399                    ctx,
400                    pcg,
401                    log_width,
402                    lhs.into_int_value(),
403                    rhs.into_int_value(),
404                )
405            })
406        }
407        IntOpDef::idiv_checked_s => {
408            let log_width = get_width_arg(&args, &op)?;
409            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
410                let op = DivModOp {
411                    op: DivOrMod::Div,
412                    signed: true,
413                    panic: false,
414                };
415                op.emit(
416                    ctx,
417                    pcg,
418                    log_width,
419                    lhs.into_int_value(),
420                    rhs.into_int_value(),
421                )
422            })
423        }
424        IntOpDef::idiv_checked_u => {
425            let log_width = get_width_arg(&args, &op)?;
426
427            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
428                let op = DivModOp {
429                    op: DivOrMod::Div,
430                    signed: false,
431                    panic: false,
432                };
433                op.emit(
434                    ctx,
435                    pcg,
436                    log_width,
437                    lhs.into_int_value(),
438                    rhs.into_int_value(),
439                )
440            })
441        }
442        IntOpDef::imod_checked_s => {
443            let log_width = get_width_arg(&args, &op)?;
444            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
445                let op = DivModOp {
446                    op: DivOrMod::Mod,
447                    signed: true,
448                    panic: false,
449                };
450                op.emit(
451                    ctx,
452                    pcg,
453                    log_width,
454                    lhs.into_int_value(),
455                    rhs.into_int_value(),
456                )
457            })
458        }
459        IntOpDef::imod_checked_u => {
460            let log_width = get_width_arg(&args, &op)?;
461            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
462                let op = DivModOp {
463                    op: DivOrMod::Mod,
464                    signed: false,
465                    panic: false,
466                };
467                op.emit(
468                    ctx,
469                    pcg,
470                    log_width,
471                    lhs.into_int_value(),
472                    rhs.into_int_value(),
473                )
474            })
475        }
476        IntOpDef::idivmod_checked_u => {
477            let log_width = get_width_arg(&args, &op)?;
478            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
479                let op = DivModOp {
480                    op: DivOrMod::DivMod,
481                    signed: false,
482                    panic: false,
483                };
484                op.emit(
485                    ctx,
486                    pcg,
487                    log_width,
488                    lhs.into_int_value(),
489                    rhs.into_int_value(),
490                )
491            })
492        }
493        IntOpDef::idivmod_checked_s => {
494            let log_width = get_width_arg(&args, &op)?;
495            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
496                let op = DivModOp {
497                    op: DivOrMod::DivMod,
498                    signed: true,
499                    panic: false,
500                };
501                op.emit(
502                    ctx,
503                    pcg,
504                    log_width,
505                    lhs.into_int_value(),
506                    rhs.into_int_value(),
507                )
508            })
509        }
510        IntOpDef::ineg => emit_custom_unary_op(context, args, |ctx, arg, _| {
511            Ok(vec![
512                ctx.builder()
513                    .build_int_neg(arg.into_int_value(), "")?
514                    .as_basic_value_enum(),
515            ])
516        }),
517        IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
518            let intr = get_intrinsic(
519                ctx.get_current_module(),
520                "llvm.abs.i64",
521                [ctx.iw_context().i64_type().as_basic_type_enum()],
522            )?;
523            let true_ = ctx.iw_context().bool_type().const_all_ones();
524            let r = ctx
525                .builder()
526                .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
527                .try_as_basic_value()
528                .unwrap_left();
529            Ok(vec![r])
530        }),
531        IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
532            let intr = get_intrinsic(
533                ctx.get_current_module(),
534                "llvm.smax.i64",
535                [ctx.iw_context().i64_type().as_basic_type_enum()],
536            )?;
537            let r = ctx
538                .builder()
539                .build_call(
540                    intr,
541                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
542                    "",
543                )?
544                .try_as_basic_value()
545                .unwrap_left();
546            Ok(vec![r])
547        }),
548        IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
549            let intr = get_intrinsic(
550                ctx.get_current_module(),
551                "llvm.umax.i64",
552                [ctx.iw_context().i64_type().as_basic_type_enum()],
553            )?;
554            let r = ctx
555                .builder()
556                .build_call(
557                    intr,
558                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
559                    "",
560                )?
561                .try_as_basic_value()
562                .unwrap_left();
563            Ok(vec![r])
564        }),
565        IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
566            let intr = get_intrinsic(
567                ctx.get_current_module(),
568                "llvm.smin.i64",
569                [ctx.iw_context().i64_type().as_basic_type_enum()],
570            )?;
571            let r = ctx
572                .builder()
573                .build_call(
574                    intr,
575                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
576                    "",
577                )?
578                .try_as_basic_value()
579                .unwrap_left();
580            Ok(vec![r])
581        }),
582        IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
583            let intr = get_intrinsic(
584                ctx.get_current_module(),
585                "llvm.umin.i64",
586                [ctx.iw_context().i64_type().as_basic_type_enum()],
587            )?;
588            let r = ctx
589                .builder()
590                .build_call(
591                    intr,
592                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
593                    "",
594                )?
595                .try_as_basic_value()
596                .unwrap_left();
597            Ok(vec![r])
598        }),
599        IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
600            Ok(vec![
601                ctx.builder()
602                    .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
603                    .as_basic_value_enum(),
604            ])
605        }),
606        IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
607            Ok(vec![
608                ctx.builder()
609                    .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
610                    .as_basic_value_enum(),
611            ])
612        }),
613        IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
614        IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
615        IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
616        IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
617        IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
618        IntOpDef::ige_s => emit_icmp(context, args, inkwell::IntPredicate::SGE),
619        IntOpDef::ilt_u => emit_icmp(context, args, inkwell::IntPredicate::ULT),
620        IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
621        IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
622        IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
623        IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
624            Ok(vec![
625                ctx.builder()
626                    .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
627                    .as_basic_value_enum(),
628            ])
629        }),
630        IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
631            Ok(vec![
632                ctx.builder()
633                    .build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
634                    .as_basic_value_enum(),
635            ])
636        }),
637        IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
638            Ok(vec![
639                ctx.builder()
640                    .build_not(arg.into_int_value(), "")?
641                    .as_basic_value_enum(),
642            ])
643        }),
644        IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
645            Ok(vec![
646                ctx.builder()
647                    .build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
648                    .as_basic_value_enum(),
649            ])
650        }),
651        IntOpDef::ipow => emit_ipow(context, args),
652        // Type args are width of input, width of output
653        IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
654            let [out] = outs.try_into()?;
655            Ok(vec![
656                ctx.builder()
657                    .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
658                    .as_basic_value_enum(),
659            ])
660        }),
661        IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
662            let [out] = outs.try_into()?;
663
664            Ok(vec![
665                ctx.builder()
666                    .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
667                    .as_basic_value_enum(),
668            ])
669        }),
670        IntOpDef::inarrow_s => {
671            let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
672            else {
673                bail!("Type arg to inarrow_s wasn't a Nat");
674            };
675            let (_, out_ty) = args.node.out_value_types().next().unwrap();
676            emit_custom_unary_op(context, args, |ctx, arg, outs| {
677                let result = make_narrow(
678                    ctx,
679                    arg,
680                    outs,
681                    out_log_width,
682                    true,
683                    out_ty.as_sum().unwrap().clone(),
684                )?;
685                Ok(vec![result])
686            })
687        }
688        IntOpDef::inarrow_u => {
689            let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
690            else {
691                bail!("Type arg to inarrow_u wasn't a Nat");
692            };
693            let (_, out_ty) = args.node.out_value_types().next().unwrap();
694            emit_custom_unary_op(context, args, |ctx, arg, outs| {
695                let result = make_narrow(
696                    ctx,
697                    arg,
698                    outs,
699                    out_log_width,
700                    false,
701                    out_ty.as_sum().unwrap().clone(),
702                )?;
703                Ok(vec![result])
704            })
705        }
706        IntOpDef::iu_to_s => {
707            let log_width = get_width_arg(&args, &op)?;
708            emit_custom_unary_op(context, args, |ctx, arg, _| {
709                let (_, max_val, _) = int_type_bounds(u32::pow(2, log_width as u32));
710                let max = arg
711                    .get_type()
712                    .into_int_type()
713                    .const_int(max_val as u64, false);
714
715                let within_bounds = ctx.builder().build_int_compare(
716                    IntPredicate::ULE,
717                    arg.into_int_value(),
718                    max,
719                    "bounds_check",
720                )?;
721
722                Ok(vec![val_or_panic(
723                    ctx,
724                    pcg,
725                    within_bounds,
726                    &ERR_IU_TO_S,
727                    |_| Ok(arg),
728                )?])
729            })
730        }
731        IntOpDef::is_to_u => emit_custom_unary_op(context, args, |ctx, arg, _| {
732            let zero = arg.get_type().into_int_type().const_zero();
733
734            let within_bounds = ctx.builder().build_int_compare(
735                IntPredicate::SGE,
736                arg.into_int_value(),
737                zero,
738                "bounds_check",
739            )?;
740
741            Ok(vec![val_or_panic(
742                ctx,
743                pcg,
744                within_bounds,
745                &ERR_IS_TO_U,
746                |_| Ok(arg),
747            )?])
748        }),
749        _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.op_id())),
750    }
751}
752
753// Helper to get the log width arg to an int op when it's the only argument
754// panic if there's not exactly one nat arg
755pub(crate) fn get_width_arg<H: HugrView<Node = Node>>(
756    args: &EmitOpArgs<'_, '_, ExtensionOp, H>,
757    op: &impl MakeExtensionOp,
758) -> Result<u64> {
759    let [TypeArg::BoundedNat(log_width)] = args.node.args() else {
760        bail!(
761            "Expected exactly one BoundedNat parameter to {}",
762            op.op_id()
763        )
764    };
765    Ok(*log_width)
766}
767
768// The semantics of the hugr operation specify that the divisor argument is
769// always unsigned, and the signed/unsigned variants affect the types of the
770// dividend and quotient only.
771//
772// LLVM's semantics for `srem`, however, have both operands being the same type.
773// Moreover, llvm's `srem` does not implement the modulo operation: the
774// remainder will have the same sign as the dividend instead of the divisor.
775//
776// See discussion at: https://github.com/CQCL/hugr/pull/2025#discussion_r2012537992
777fn make_divmod<'c, H: HugrView<Node = Node>>(
778    ctx: &mut EmitFuncContext<'c, '_, H>,
779    pcg: &impl PreludeCodegen,
780    log_width: u64,
781    numerator: IntValue<'c>,
782    denominator: IntValue<'c>,
783    panic: bool,
784    signed: bool,
785) -> Result<LLVMSumValue<'c>> {
786    let int_arg_ty = int_types::INT_TYPES[log_width as usize].clone();
787    let tuple_sum_ty = HugrSumType::new_tuple(vec![int_arg_ty.clone(), int_arg_ty.clone()]);
788
789    let pair_ty = LLVMSumType::try_from_hugr_type(&ctx.typing_session(), tuple_sum_ty.clone())?;
790
791    let build_divmod = |ctx: &mut EmitFuncContext<'c, '_, H>| -> Result<BasicValueEnum<'c>> {
792        if signed {
793            let max_signed_value = u64::pow(2, u32::pow(2, log_width as u32) - 1) - 1;
794            let max_signed = numerator.get_type().const_int(max_signed_value, false);
795            // Determine whether the divisor is "big" or "smol" for special casing.
796            // Here, "big" means the divisor is larger than the biggest value
797            // that could be represented by the type of the dividend.
798            let large_divisor_bool = ctx.builder().build_int_compare(
799                IntPredicate::UGT,
800                denominator,
801                max_signed,
802                "is_divisor_large",
803            )?;
804            let large_divisor =
805                ctx.builder()
806                    .build_int_z_extend(large_divisor_bool, denominator.get_type(), "")?;
807            let negative_numerator_bool = ctx.builder().build_int_compare(
808                IntPredicate::SLT,
809                numerator,
810                numerator.get_type().const_zero(),
811                "is_dividend_negative",
812            )?;
813            let negative_numerator = ctx.builder().build_int_z_extend(
814                negative_numerator_bool,
815                denominator.get_type(),
816                "",
817            )?;
818            let tag = ctx.builder().build_left_shift(
819                large_divisor,
820                denominator.get_type().const_int(1, false),
821                "",
822            )?;
823
824            let tag = ctx.builder().build_or(tag, negative_numerator, "tag")?;
825
826            let quot = ctx
827                .builder()
828                .build_int_signed_div(numerator, denominator, "quotient")?;
829            let rem = ctx
830                .builder()
831                .build_int_signed_rem(numerator, denominator, "remainder")?;
832
833            let result_ptr = ctx.builder().build_alloca(pair_ty.clone(), "result")?;
834
835            let finish = ctx.new_basic_block("finish", None);
836            let negative_bigdiv = ctx.new_basic_block("negative_bigdiv", Some(finish));
837            let negative_smoldiv = ctx.new_basic_block("negative_smoldiv", Some(finish));
838            let non_negative_bigdiv = ctx.new_basic_block("non_negative_bigdiv", Some(finish));
839            let non_negative_smoldiv = ctx.new_basic_block("non_negative_smoldiv", Some(finish));
840
841            ctx.builder().build_switch(
842                tag,
843                non_negative_smoldiv,
844                &[
845                    (denominator.get_type().const_int(1, false), negative_smoldiv),
846                    (
847                        denominator.get_type().const_int(2, false),
848                        non_negative_bigdiv,
849                    ),
850                    (denominator.get_type().const_int(3, false), negative_bigdiv),
851                ],
852            )?;
853
854            let build_and_store_result =
855                |ctx: &mut EmitFuncContext<'c, '_, H>, vs: Vec<BasicValueEnum<'c>>| -> Result<()> {
856                    let result = pair_ty
857                        .build_tag(ctx.builder(), 0, vs)?
858                        //.build_tag(ctx.builder(), 0, vec![tag.as_basic_value_enum(), tag.as_basic_value_enum()])?
859                        .as_basic_value_enum();
860                    ctx.builder().build_store(result_ptr, result)?;
861                    ctx.builder().build_unconditional_branch(finish)?;
862                    Ok(())
863                };
864
865            // Default case (although it should only be reached by one branch).
866            // When the divisor is smol and the dividend is positive, we can
867            // rely on LLVM intrinsics.
868            ctx.builder().position_at_end(non_negative_smoldiv);
869            build_and_store_result(
870                ctx,
871                vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
872            )?;
873
874            // When the divisor is smol and the dividend is negative,
875            // we have two cases:
876            ctx.builder().position_at_end(negative_smoldiv);
877            {
878                // If the remainder is 0, we can use the results of LLVM's `srem`
879                let if_rem_zero = pair_ty
880                    .build_tag(
881                        ctx.builder(),
882                        0,
883                        vec![
884                            quot.as_basic_value_enum(),
885                            rem.get_type().const_zero().as_basic_value_enum(),
886                        ],
887                    )?
888                    .as_basic_value_enum();
889
890                // Otherwise, we return `(quotient - 1, divisor + remainder)`
891                let if_rem_nonzero = pair_ty
892                    .build_tag(
893                        ctx.builder(),
894                        0,
895                        vec![
896                            ctx.builder()
897                                .build_int_sub(quot, quot.get_type().const_int(1, true), "")?
898                                .as_basic_value_enum(),
899                            ctx.builder()
900                                .build_int_add(denominator, rem, "")?
901                                .as_basic_value_enum(),
902                        ],
903                    )?
904                    .as_basic_value_enum();
905
906                let is_rem_zero = ctx.builder().build_int_compare(
907                    IntPredicate::EQ,
908                    rem,
909                    rem.get_type().const_zero(),
910                    "is_rem_0",
911                )?;
912                let result =
913                    ctx.builder()
914                        .build_select(is_rem_zero, if_rem_zero, if_rem_nonzero, "")?;
915                ctx.builder().build_store(result_ptr, result)?;
916                ctx.builder().build_unconditional_branch(finish)?;
917            }
918
919            // The (unsigned) divisor is bigger than the (signed) dividend could
920            // possibly be, so it's safe to return quotient 0 and remainder = dividend
921            ctx.builder().position_at_end(non_negative_bigdiv);
922            build_and_store_result(
923                ctx,
924                vec![
925                    numerator.get_type().const_zero().as_basic_value_enum(),
926                    numerator.as_basic_value_enum(),
927                ],
928            )?;
929
930            // The divisor is larger than the dividend can possibly be, and the
931            // dividend is negative. This means we have to return `quotient - 1`
932            // and the remainder is `dividend + divisor`.
933            ctx.builder().position_at_end(negative_bigdiv);
934            build_and_store_result(
935                ctx,
936                vec![
937                    numerator.get_type().const_all_ones().as_basic_value_enum(),
938                    ctx.builder()
939                        .build_int_add(numerator, denominator, "")?
940                        .as_basic_value_enum(),
941                ],
942            )?;
943
944            ctx.builder().position_at_end(finish);
945            let result = ctx.builder().build_load(result_ptr, "result")?;
946            Ok(result)
947        } else {
948            let quot = ctx
949                .builder()
950                .build_int_unsigned_div(numerator, denominator, "quotient")?;
951            let rem = ctx
952                .builder()
953                .build_int_unsigned_rem(numerator, denominator, "remainder")?;
954            Ok(pair_ty
955                .build_tag(
956                    ctx.builder(),
957                    0,
958                    vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
959                )?
960                .as_basic_value_enum())
961        }
962    };
963
964    let int_ty = numerator.get_type();
965    let zero = int_ty.const_zero();
966    let lower_bounds_check =
967        ctx.builder()
968            .build_int_compare(IntPredicate::NE, denominator, zero, "valid_div")?;
969
970    let sum_ty = LLVMSumType::try_from_hugr_type(
971        &ctx.typing_session(),
972        sum_with_error(vec![Type::from(tuple_sum_ty)]),
973    )?;
974
975    if panic {
976        LLVMSumValue::try_new(
977            val_or_panic(ctx, pcg, lower_bounds_check, &ERR_DIV_0, |ctx| {
978                build_divmod(ctx)
979            })?,
980            pair_ty,
981        )
982    } else {
983        let result = build_divmod(ctx)?;
984        LLVMSumValue::try_new(
985            val_or_error(ctx, lower_bounds_check, result, &ERR_DIV_0, sum_ty.clone())?,
986            sum_ty,
987        )
988    }
989}
990
991fn make_narrow<'c, H: HugrView<Node = Node>>(
992    ctx: &mut EmitFuncContext<'c, '_, H>,
993    arg: BasicValueEnum<'c>,
994    outs: &[BasicTypeEnum<'c>],
995    out_log_width: u64,
996    signed: bool,
997    sum_type: HugrSumType,
998) -> Result<BasicValueEnum<'c>> {
999    let [out] = TryInto::<[BasicTypeEnum; 1]>::try_into(outs)?;
1000    let width = 1 << out_log_width;
1001    let arg_int_ty: IntType = arg.get_type().into_int_type();
1002    let (int_min_value_s, int_max_value_s, int_max_value_u) = int_type_bounds(width);
1003    let out_int_ty = out
1004        .into_struct_type()
1005        .get_field_type_at_index(2)
1006        .unwrap()
1007        .into_int_type();
1008    let outside_range = if signed {
1009        let too_big = ctx.builder().build_int_compare(
1010            IntPredicate::SGT,
1011            arg.into_int_value(),
1012            arg_int_ty.const_int(int_max_value_s as u64, true),
1013            "upper_bounds_check",
1014        )?;
1015        let too_small = ctx.builder().build_int_compare(
1016            IntPredicate::SLT,
1017            arg.into_int_value(),
1018            arg_int_ty.const_int(int_min_value_s as u64, true),
1019            "lower_bounds_check",
1020        )?;
1021        ctx.builder()
1022            .build_or(too_big, too_small, "outside_range")?
1023    } else {
1024        ctx.builder().build_int_compare(
1025            IntPredicate::UGT,
1026            arg.into_int_value(),
1027            arg_int_ty.const_int(int_max_value_u, false),
1028            "upper_bounds_check",
1029        )?
1030    };
1031
1032    let inbounds = ctx.builder().build_not(outside_range, "inbounds")?;
1033    let narrowed_val = ctx
1034        .builder()
1035        .build_int_cast_sign_flag(arg.into_int_value(), out_int_ty, signed, "")?
1036        .as_basic_value_enum();
1037    val_or_error(
1038        ctx,
1039        inbounds,
1040        narrowed_val,
1041        &ERR_NARROW,
1042        LLVMSumType::try_from_hugr_type(&ctx.typing_session(), sum_type).unwrap(),
1043    )
1044}
1045
1046fn val_or_panic<'c, H: HugrView<Node = Node>>(
1047    ctx: &mut EmitFuncContext<'c, '_, H>,
1048    pcg: &impl PreludeCodegen,
1049    dont_panic: IntValue<'c>,
1050    err: &ConstError,
1051    // Returned value must be same int type as `dont_panic`.
1052    go: impl Fn(&mut EmitFuncContext<'c, '_, H>) -> Result<BasicValueEnum<'c>>,
1053) -> Result<BasicValueEnum<'c>> {
1054    let exit_bb = ctx.new_basic_block("exit", None);
1055    let go_bb = ctx.new_basic_block("panic_if_0", Some(exit_bb));
1056    let panic_bb = ctx.new_basic_block("panic", Some(exit_bb));
1057    ctx.builder().build_unconditional_branch(go_bb)?;
1058
1059    ctx.builder().position_at_end(panic_bb);
1060    let err = ctx.emit_custom_const(err)?;
1061    pcg.emit_panic(ctx, err)?;
1062    ctx.builder().build_unconditional_branch(exit_bb)?;
1063
1064    ctx.builder().position_at_end(go_bb);
1065    ctx.builder().build_switch(
1066        dont_panic,
1067        panic_bb,
1068        &[(dont_panic.get_type().const_int(1, false), exit_bb)],
1069    )?;
1070
1071    ctx.builder().position_at_end(exit_bb);
1072
1073    go(ctx)
1074}
1075
1076fn val_or_error<'c, H: HugrView<Node = Node>>(
1077    ctx: &mut EmitFuncContext<'c, '_, H>,
1078    should_succeed: IntValue<'c>,
1079    val: BasicValueEnum<'c>,
1080    err: &ConstError,
1081    ty: LLVMSumType<'c>,
1082) -> Result<BasicValueEnum<'c>> {
1083    let err_val = ctx.emit_custom_const(err)?;
1084
1085    let err_variant = ty.build_tag(ctx.builder(), 0, vec![err_val])?;
1086    let ok_variant = ty.build_tag(ctx.builder(), 1, vec![val])?;
1087
1088    Ok(ctx
1089        .builder()
1090        .build_select(should_succeed, ok_variant, err_variant, "")?)
1091}
1092
1093fn llvm_type<'c>(
1094    context: TypingSession<'c, '_>,
1095    hugr_type: &CustomType,
1096) -> Result<BasicTypeEnum<'c>> {
1097    if let [TypeArg::BoundedNat(n)] = hugr_type.args() {
1098        let m = *n as usize;
1099        if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() {
1100            return Ok(match m {
1101                0..=3 => context.iw_context().i8_type(),
1102                4 => context.iw_context().i16_type(),
1103                5 => context.iw_context().i32_type(),
1104                6 => context.iw_context().i64_type(),
1105                _ => Err(anyhow!(
1106                    "IntTypesCodegenExtension: unsupported log_width: {}",
1107                    m
1108                ))?,
1109            }
1110            .into());
1111        }
1112    }
1113    Err(anyhow!(
1114        "IntTypesCodegenExtension: unsupported type: {}",
1115        hugr_type
1116    ))
1117}
1118
1119fn emit_const_int<'c, H: HugrView<Node = Node>>(
1120    context: &mut EmitFuncContext<'c, '_, H>,
1121    k: &ConstInt,
1122) -> Result<BasicValueEnum<'c>> {
1123    let ty: IntType = context.llvm_type(&k.get_type())?.try_into().unwrap();
1124    // k.value_u() is in two's complement representation of the exactly
1125    // correct bit width, so we are safe to unconditionally retrieve the
1126    // unsigned value and do no sign extension.
1127    Ok(ty.const_int(k.value_u(), false).as_basic_value_enum())
1128}
1129
1130impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
1131    /// Populates a [`CodegenExtsBuilder`] with all extensions needed to lower int
1132    /// ops, types, and constants.
1133    ///
1134    /// Any ops that panic will do so using [`DefaultPreludeCodegen`].
1135    #[must_use]
1136    pub fn add_default_int_extensions(self) -> Self {
1137        self.add_extension(IntCodegenExtension::new(DefaultPreludeCodegen))
1138    }
1139}
1140
1141#[cfg(test)]
1142mod test {
1143    use anyhow::Result;
1144    use hugr_core::builder::DataflowHugr;
1145    use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type};
1146    use hugr_core::std_extensions::STD_REG;
1147    use hugr_core::{
1148        Hugr,
1149        builder::{Dataflow, DataflowSubContainer, SubContainer, handle::Outputs},
1150        extension::prelude::bool_t,
1151        ops::{DataflowOpTrait, ExtensionOp},
1152        std_extensions::arithmetic::{
1153            int_ops::{self, IntOpDef},
1154            int_types::{ConstInt, INT_TYPES},
1155        },
1156        types::{SumType, Type, TypeRow},
1157    };
1158    use rstest::rstest;
1159
1160    use crate::{
1161        check_emission,
1162        emit::test::{DFGW, SimpleHugrConfig},
1163        test::{TestContext, exec_ctx, llvm_ctx, single_op_hugr},
1164    };
1165
1166    #[rstest::fixture]
1167    fn int_exec_ctx(mut exec_ctx: TestContext) -> TestContext {
1168        exec_ctx.add_extensions(|cem| {
1169            cem.add_default_int_extensions()
1170                .add_default_prelude_extensions()
1171        });
1172        exec_ctx
1173    }
1174
1175    #[rstest::fixture]
1176    fn int_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
1177        llvm_ctx.add_extensions(|cem| {
1178            cem.add_default_int_extensions()
1179                .add_default_prelude_extensions()
1180        });
1181        llvm_ctx
1182    }
1183
1184    // Instantiate an extension op which takes one width argument
1185    fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
1186        int_ops::EXTENSION
1187            .instantiate_extension_op(name.as_ref(), [u64::from(log_width).into()])
1188            .unwrap()
1189    }
1190
1191    fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1192        let ty = &INT_TYPES[log_width as usize];
1193        test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
1194    }
1195
1196    fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1197        test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
1198    }
1199
1200    fn test_int_op_with_results<const N: usize>(
1201        ext_op: ExtensionOp,
1202        log_width: u8,
1203        inputs: Option<[ConstInt; N]>,
1204        output_type: Type,
1205    ) -> Hugr {
1206        test_int_op_with_results_processing(ext_op, log_width, inputs, output_type, |_, a| Ok(a))
1207    }
1208
1209    fn test_int_op_with_results_processing<const N: usize>(
1210        // N is the number of inputs to the hugr
1211        ext_op: ExtensionOp,
1212        log_width: u8,
1213        inputs: Option<[ConstInt; N]>, // If inputs are provided, they'll be wired into the op, otherwise the inputs to the hugr will be wired into the op
1214        output_type: Type,
1215        process: impl Fn(&mut DFGW, Outputs) -> Result<Outputs>,
1216    ) -> Hugr {
1217        let ty = &INT_TYPES[log_width as usize];
1218        let input_tys = if inputs.is_some() {
1219            vec![]
1220        } else {
1221            let input_tys = itertools::repeat_n(ty.clone(), N).collect();
1222            assert_eq!(input_tys, ext_op.signature().input.to_vec());
1223            input_tys
1224        };
1225        SimpleHugrConfig::new()
1226            .with_ins(input_tys)
1227            .with_outs(vec![output_type])
1228            .with_extensions(STD_REG.clone())
1229            .finish(|mut hugr_builder| {
1230                let input_wires = match inputs {
1231                    None => hugr_builder.input_wires_arr::<N>().to_vec(),
1232                    Some(inputs) => {
1233                        let mut input_wires = Vec::new();
1234                        for i in inputs.into_iter() {
1235                            let w = hugr_builder.add_load_value(i);
1236                            input_wires.push(w);
1237                        }
1238                        input_wires
1239                    }
1240                };
1241                let outputs = hugr_builder
1242                    .add_dataflow_op(ext_op, input_wires)
1243                    .unwrap()
1244                    .outputs();
1245                let processed_outputs = process(&mut hugr_builder, outputs).unwrap();
1246                hugr_builder
1247                    .finish_hugr_with_outputs(processed_outputs)
1248                    .unwrap()
1249            })
1250    }
1251
1252    #[rstest]
1253    #[case(IntOpDef::iu_to_s, &[3])]
1254    #[case(IntOpDef::is_to_u, &[3])]
1255    #[case(IntOpDef::ineg, &[2])]
1256    #[case::idiv_checked_u("idiv_checked_u", &[3])]
1257    #[case::idiv_checked_s("idiv_checked_s", &[3])]
1258    #[case::imod_checked_u("imod_checked_u", &[6])]
1259    #[case::imod_checked_s("imod_checked_s", &[6])]
1260    #[case::idivmod_u("idivmod_u", &[3])]
1261    #[case::idivmod_s("idivmod_s", &[3])]
1262    #[case::idivmod_checked_u("idivmod_checked_u", &[6])]
1263    #[case::idivmod_checked_s("idivmod_checked_s", &[6])]
1264    fn test_emission(int_llvm_ctx: TestContext, #[case] op: IntOpDef, #[case] args: &[u8]) {
1265        use hugr_core::extension::simple_op::MakeExtensionOp as _;
1266
1267        let mut insta = insta::Settings::clone_current();
1268        insta.set_snapshot_suffix(format!(
1269            "{}_{}_{:?}",
1270            insta.snapshot_suffix().unwrap_or(""),
1271            op.op_id(),
1272            args,
1273        ));
1274        let concrete = match *args {
1275            [] => op.without_log_width(),
1276            [log_width] => op.with_log_width(log_width),
1277            [lw1, lw2] => op.with_two_log_widths(lw1, lw2),
1278            _ => panic!("unexpected number of args to the op!"),
1279        };
1280        insta.bind(|| {
1281            let hugr = single_op_hugr(concrete.into());
1282            check_emission!(hugr, int_llvm_ctx);
1283        });
1284    }
1285
1286    #[rstest]
1287    #[case::iadd("iadd", 3)]
1288    #[case::isub("isub", 6)]
1289    #[case::ipow("ipow", 3)]
1290    #[case::idiv_u("idiv_u", 3)]
1291    #[case::idiv_s("idiv_s", 3)]
1292    #[case::imod_u("imod_u", 3)]
1293    #[case::imod_s("imod_s", 3)]
1294    fn test_binop_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1295        let ext_op = make_int_op(op.clone(), width);
1296        let hugr = test_binary_int_op(ext_op, width);
1297        check_emission!(op.clone(), hugr, int_llvm_ctx);
1298    }
1299
1300    #[rstest]
1301    #[case::signed_2_3("iwiden_s", 2, 3)]
1302    #[case::signed_1_6("iwiden_s", 1, 6)]
1303    #[case::unsigned_2_3("iwiden_u", 2, 3)]
1304    #[case::unsigned_1_6("iwiden_u", 1, 6)]
1305    fn test_widen_emission(
1306        int_llvm_ctx: TestContext,
1307        #[case] op: String,
1308        #[case] from: u8,
1309        #[case] to: u8,
1310    ) {
1311        let out_ty = INT_TYPES[to as usize].clone();
1312        let ext_op = int_ops::EXTENSION
1313            .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1314            .unwrap();
1315        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty);
1316
1317        check_emission!(
1318            format!("{}_{}_{}", op.clone(), from, to),
1319            hugr,
1320            int_llvm_ctx
1321        );
1322    }
1323
1324    #[rstest]
1325    #[case::signed("inarrow_s", 3, 2)]
1326    #[case::unsigned("inarrow_u", 6, 4)]
1327    fn test_narrow_emission(
1328        int_llvm_ctx: TestContext,
1329        #[case] op: String,
1330        #[case] from: u8,
1331        #[case] to: u8,
1332    ) {
1333        let out_ty = SumType::new([vec![error_type()], vec![INT_TYPES[to as usize].clone()]]);
1334        let ext_op = int_ops::EXTENSION
1335            .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1336            .unwrap();
1337        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());
1338
1339        check_emission!(
1340            format!("{}_{}_{}", op.clone(), from, to),
1341            hugr,
1342            int_llvm_ctx
1343        );
1344    }
1345
1346    #[rstest]
1347    #[case::ieq("ieq", 1)]
1348    #[case::ilt_s("ilt_s", 0)]
1349    fn test_cmp_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1350        let ext_op = make_int_op(op.clone(), width);
1351        let hugr = test_binary_icmp_op(ext_op, width);
1352        check_emission!(op.clone(), hugr, int_llvm_ctx);
1353    }
1354
1355    #[rstest]
1356    #[case::imax("imax_u", 1, 2, 2)]
1357    #[case::imax("imax_u", 2, 1, 2)]
1358    #[case::imax("imax_u", 2, 2, 2)]
1359    #[case::imin("imin_u", 1, 2, 1)]
1360    #[case::imin("imin_u", 2, 1, 1)]
1361    #[case::imin("imin_u", 2, 2, 2)]
1362    #[case::ishl("ishl", 73, 1, 146)]
1363    // (2^64 - 1) << 1 = (2^64 - 2)
1364    #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
1365    #[case::ishr("ishr", 73, 1, 36)]
1366    #[case::ior("ior", 6, 9, 15)]
1367    #[case::ior("ior", 6, 15, 15)]
1368    #[case::ixor("ixor", 6, 9, 15)]
1369    #[case::ixor("ixor", 6, 15, 9)]
1370    #[case::ixor("ixor", 15, 6, 9)]
1371    #[case::iand("iand", 6, 15, 6)]
1372    #[case::iand("iand", 15, 6, 6)]
1373    #[case::iand("iand", 15, 15, 15)]
1374    #[case::ipow("ipow", 2, 3, 8)]
1375    #[case::ipow("ipow", 42, 1, 42)]
1376    #[case::ipow("ipow", 42, 0, 1)]
1377    #[case::idiv("idiv_u", 42, 2, 21)]
1378    #[case::idiv("idiv_u", 42, 5, 8)]
1379    #[case::imod("imod_u", 42, 2, 0)]
1380    #[case::imod("imod_u", 42, 5, 2)]
1381    fn test_exec_unsigned_bin_op(
1382        int_exec_ctx: TestContext,
1383        #[case] op: String,
1384        #[case] lhs: u64,
1385        #[case] rhs: u64,
1386        #[case] result: u64,
1387    ) {
1388        let ty = &INT_TYPES[6].clone();
1389        let inputs = [
1390            ConstInt::new_u(6, lhs).unwrap(),
1391            ConstInt::new_u(6, rhs).unwrap(),
1392        ];
1393        let ext_op = make_int_op(&op, 6);
1394
1395        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1396        assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1397    }
1398
1399    #[rstest]
1400    #[case::imax("imax_s", 1, 2, 2)]
1401    #[case::imax("imax_s", 2, 1, 2)]
1402    #[case::imax("imax_s", 2, 2, 2)]
1403    #[case::imax("imax_s", -1, -2, -1)]
1404    #[case::imax("imax_s", -2, -1, -1)]
1405    #[case::imax("imax_s", -2, -2, -2)]
1406    #[case::imin("imin_s", 1, 2, 1)]
1407    #[case::imin("imin_s", 2, 1, 1)]
1408    #[case::imin("imin_s", 2, 2, 2)]
1409    #[case::imin("imin_s", -1, -2, -2)]
1410    #[case::imin("imin_s", -2, -1, -2)]
1411    #[case::imin("imin_s", -2, -2, -2)]
1412    #[case::ipow("ipow", -2, 1, -2)]
1413    #[case::ipow("ipow", -2, 2, 4)]
1414    #[case::ipow("ipow", -2, 3, -8)]
1415    fn test_exec_signed_bin_op(
1416        int_exec_ctx: TestContext,
1417        #[case] op: String,
1418        #[case] lhs: i64,
1419        #[case] rhs: i64,
1420        #[case] result: i64,
1421    ) {
1422        let ty = &INT_TYPES[6].clone();
1423        let inputs = [
1424            ConstInt::new_s(6, lhs).unwrap(),
1425            ConstInt::new_s(6, rhs).unwrap(),
1426        ];
1427        let ext_op = make_int_op(&op, 6);
1428
1429        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1430        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1431    }
1432
1433    #[rstest]
1434    #[case::iabs("iabs", 42, 42)]
1435    #[case::iabs("iabs", -42, 42)]
1436    fn test_exec_signed_unary_op(
1437        int_exec_ctx: TestContext,
1438        #[case] op: String,
1439        #[case] arg: i64,
1440        #[case] result: i64,
1441    ) {
1442        let input = ConstInt::new_s(6, arg).unwrap();
1443        let ty = INT_TYPES[6].clone();
1444        let ext_op = make_int_op(&op, 6);
1445
1446        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1447        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1448    }
1449
1450    #[rstest]
1451    #[case::inot("inot", 9223372036854775808, !9223372036854775808u64)]
1452    #[case::inot("inot", 42, !42u64)]
1453    #[case::inot("inot", !0u64, 0)]
1454    fn test_exec_unsigned_unary_op(
1455        int_exec_ctx: TestContext,
1456        #[case] op: String,
1457        #[case] arg: u64,
1458        #[case] result: u64,
1459    ) {
1460        let input = ConstInt::new_u(6, arg).unwrap();
1461        let ty = INT_TYPES[6].clone();
1462        let ext_op = make_int_op(&op, 6);
1463
1464        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1465        assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1466    }
1467
1468    #[rstest]
1469    #[case(-127)]
1470    #[case(-1)]
1471    #[case(0)]
1472    #[case(1)]
1473    #[case(127)]
1474    fn test_exec_widen(int_exec_ctx: TestContext, #[case] num: i16) {
1475        let from: u8 = 3;
1476        let to: u8 = 6;
1477        let ty = INT_TYPES[to as usize].clone();
1478
1479        if num >= 0 {
1480            let input = ConstInt::new_u(from, num as u64).unwrap();
1481
1482            let ext_op = int_ops::EXTENSION
1483                .instantiate_extension_op(
1484                    "iwiden_u".as_ref(),
1485                    [(from as u64).into(), (to as u64).into()],
1486                )
1487                .unwrap();
1488
1489            let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1490
1491            assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), num as u64);
1492        }
1493
1494        let input = ConstInt::new_s(from, num as i64).unwrap();
1495
1496        let ext_op = int_ops::EXTENSION
1497            .instantiate_extension_op(
1498                "iwiden_s".as_ref(),
1499                [(from as u64).into(), (to as u64).into()],
1500            )
1501            .unwrap();
1502
1503        let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1504
1505        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), num as i64);
1506    }
1507
1508    #[rstest]
1509    #[case("inarrow_s", 6, 2, 4)]
1510    #[case("inarrow_s", 6, 5, (1 << 5) - 1)]
1511    #[case("inarrow_s", 6, 4, -1)]
1512    #[case("inarrow_s", 6, 4, -(1 << 4) - 1)]
1513    #[case("inarrow_s", 6, 4, -(1 <<15))]
1514    #[case("inarrow_s", 6, 5, (1 << 31) - 1)]
1515    fn test_narrow_s(
1516        int_exec_ctx: TestContext,
1517        #[case] op: String,
1518        #[case] from: u8,
1519        #[case] to: u8,
1520        #[case] arg: i64,
1521    ) {
1522        let input = ConstInt::new_s(from, arg).unwrap();
1523        let to_ty = INT_TYPES[to as usize].clone();
1524        let ext_op = int_ops::EXTENSION
1525            .instantiate_extension_op(op.as_ref(), [u64::from(from).into(), u64::from(to).into()])
1526            .unwrap();
1527
1528        let hugr = test_int_op_with_results_processing::<1>(
1529            ext_op,
1530            to,
1531            Some([input]),
1532            to_ty.clone(),
1533            |builder, outs| {
1534                let [out] = outs.to_array();
1535
1536                let err_row = TypeRow::from(vec![error_type()]);
1537                let ty_row = TypeRow::from(vec![to_ty.clone()]);
1538                // Handle the sum type returned by narrow by building a conditional.
1539                // We're only testing the happy path here, so insert a panic in the
1540                // "error" branch, knowing that it wont come up.
1541                //
1542                // Negative results can be tested manually, but we lack the testing
1543                // infrastructure to detect execution crashes without crashing the
1544                // test process.
1545                let mut cond_b = builder.conditional_builder(
1546                    ([err_row, ty_row], out),
1547                    [],
1548                    vec![to_ty.clone()].into(),
1549                )?;
1550                let mut sad_b = cond_b.case_builder(0)?;
1551                let err = ConstError::new(2, "This shouldn't happen");
1552                let w = sad_b.add_load_value(ConstInt::new_s(to, 0)?);
1553                sad_b.add_panic(err, vec![to_ty.clone()], [(w, to_ty.clone())])?;
1554                sad_b.finish_with_outputs([w])?;
1555
1556                let happy_b = cond_b.case_builder(1)?;
1557                let [w] = happy_b.input_wires_arr();
1558                happy_b.finish_with_outputs([w])?;
1559
1560                let handle = cond_b.finish_sub_container()?;
1561                Ok(handle.outputs())
1562            },
1563        );
1564        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), arg);
1565    }
1566
1567    #[rstest]
1568    #[case(6, 42)]
1569    #[case(4, 7)]
1570    //#[case(4, 256)] -- crashes because a panic is emitted (good)
1571    fn test_u_to_s(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: u64) {
1572        let ty = &INT_TYPES[log_width as usize].clone();
1573        let hugr = SimpleHugrConfig::new()
1574            .with_outs(vec![ty.clone()])
1575            .with_extensions(STD_REG.clone())
1576            .finish(|mut hugr_builder| {
1577                let unsigned =
1578                    hugr_builder.add_load_value(ConstInt::new_u(log_width, val).unwrap());
1579                let iu_to_s = make_int_op("iu_to_s", log_width);
1580                let [signed] = hugr_builder
1581                    .add_dataflow_op(iu_to_s, [unsigned])
1582                    .unwrap()
1583                    .outputs_arr();
1584                hugr_builder.finish_hugr_with_outputs([signed]).unwrap()
1585            });
1586        let act = int_exec_ctx.exec_hugr_i64(hugr, "main");
1587        assert_eq!(act, val as i64);
1588    }
1589
1590    #[rstest]
1591    #[case(3, 0)]
1592    #[case(4, 255)]
1593    // #[case(3, -1)] -- crashes because a panic is emitted (good)
1594    fn test_s_to_u(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: i64) {
1595        let ty = &INT_TYPES[log_width as usize].clone();
1596        let hugr = SimpleHugrConfig::new()
1597            .with_outs(vec![ty.clone()])
1598            .with_extensions(STD_REG.clone())
1599            .finish(|mut hugr_builder| {
1600                let signed = hugr_builder.add_load_value(ConstInt::new_s(log_width, val).unwrap());
1601                let is_to_u = make_int_op("is_to_u", log_width);
1602                let [unsigned] = hugr_builder
1603                    .add_dataflow_op(is_to_u, [signed])
1604                    .unwrap()
1605                    .outputs_arr();
1606                let num = hugr_builder.add_load_value(ConstInt::new_u(log_width, 42).unwrap());
1607                let [res] = hugr_builder
1608                    .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num])
1609                    .unwrap()
1610                    .outputs_arr();
1611                hugr_builder.finish_hugr_with_outputs([res]).unwrap()
1612            });
1613        let act = int_exec_ctx.exec_hugr_u64(hugr, "main");
1614        assert_eq!(act, (val as u64) + 42);
1615    }
1616
1617    // Log width fixed at 3 (i.e. divmod : Fn(i8, u8) -> (i8, u8)
1618    #[rstest]
1619    #[case::bigdiv_non_negative(127, 255, (0, 127))] // Big divisor, positive dividend
1620    #[case::bigdiv_negative(-42, 255, (-1, 213))] // Big divisor, negative dividend
1621    #[case::smoldiv_non_negative(42, 10, (4, 2))] // Normal divisor, positive dividend
1622    #[case::smoldiv_negative_rem0(-42, 21, (-2, 0))] // Normal divisor, negative dividend, remainder 0
1623    #[case::smoldiv_negative_rem_nonzero(-42, 10, (-5, 8))] // Normal divisor, negative dividend, remainder >0
1624    fn test_divmod_s(
1625        int_exec_ctx: TestContext,
1626        #[case] dividend: i64,
1627        #[case] divisor: u64,
1628        #[case] expected_result: (i64, u64),
1629    ) {
1630        let int_ty = INT_TYPES[3].clone();
1631        let k_dividend = ConstInt::new_s(3, dividend).unwrap();
1632        let k_divisor = ConstInt::new_u(3, divisor).unwrap();
1633        let quot_hugr = test_int_op_with_results(
1634            make_int_op("idiv_s", 3),
1635            3,
1636            Some([k_dividend.clone(), k_divisor.clone()]),
1637            int_ty.clone(),
1638        );
1639        let rem_hugr = test_int_op_with_results(
1640            make_int_op("imod_s", 3),
1641            3,
1642            Some([k_dividend, k_divisor]),
1643            int_ty,
1644        );
1645        let quot = int_exec_ctx.exec_hugr_i64(quot_hugr, "main");
1646        let rem = int_exec_ctx.exec_hugr_u64(rem_hugr, "main");
1647        assert_eq!((quot, rem), expected_result);
1648    }
1649}