hugr_llvm/extension/
int.rs

1use hugr_core::{
2    HugrView, Node,
3    extension::{
4        prelude::{ConstError, sum_with_error},
5        simple_op::MakeExtensionOp,
6    },
7    ops::{ExtensionOp, Value, constant::CustomConst},
8    std_extensions::arithmetic::{
9        int_ops::IntOpDef,
10        int_types::{self, ConstInt},
11    },
12    types::{CustomType, Type, TypeArg},
13};
14use inkwell::{
15    IntPredicate,
16    types::{BasicType, BasicTypeEnum, IntType},
17    values::{BasicValue, BasicValueEnum, IntValue},
18};
19use std::sync::LazyLock;
20
21use crate::{
22    CodegenExtension,
23    custom::CodegenExtsBuilder,
24    emit::{
25        EmitOpArgs, emit_value,
26        func::EmitFuncContext,
27        get_intrinsic,
28        ops::{emit_custom_binary_op, emit_custom_unary_op},
29    },
30    sum::{LLVMSumType, LLVMSumValue},
31    types::{HugrSumType, TypingSession},
32};
33
34use anyhow::{Result, anyhow, bail};
35
36use super::{DefaultPreludeCodegen, PreludeCodegen, conversions::int_type_bounds};
37
38#[derive(Clone, Debug, Default)]
39pub struct IntCodegenExtension<PCG>(PCG);
40
41impl<PCG: PreludeCodegen> IntCodegenExtension<PCG> {
42    pub fn new(ccg: PCG) -> Self {
43        Self(ccg)
44    }
45}
46
47impl<CCG: PreludeCodegen> From<CCG> for IntCodegenExtension<CCG> {
48    fn from(ccg: CCG) -> Self {
49        Self::new(ccg)
50    }
51}
52
53impl<CCG: PreludeCodegen> CodegenExtension for IntCodegenExtension<CCG> {
54    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
55        self,
56        builder: CodegenExtsBuilder<'a, H>,
57    ) -> CodegenExtsBuilder<'a, H>
58    where
59        Self: 'a,
60    {
61        builder
62            .custom_const(emit_const_int)
63            .custom_type((int_types::EXTENSION_ID, "int".into()), llvm_type)
64            .simple_extension_op::<IntOpDef>(move |context, args, op| {
65                emit_int_op(context, &self.0, args, op)
66            })
67    }
68}
69
70static ERR_NARROW: LazyLock<ConstError> = LazyLock::new(|| ConstError {
71    signal: 2,
72    message: "Can't narrow into bounds".to_string(),
73});
74static ERR_IU_TO_S: LazyLock<ConstError> = LazyLock::new(|| ConstError {
75    signal: 2,
76    message: "iu_to_s argument out of bounds".to_string(),
77});
78static ERR_IS_TO_U: LazyLock<ConstError> = LazyLock::new(|| ConstError {
79    signal: 2,
80    message: "is_to_u called on negative value".to_string(),
81});
82static ERR_DIV_0: LazyLock<ConstError> = LazyLock::new(|| ConstError {
83    signal: 2,
84    message: "Attempted division by 0".to_string(),
85});
86
/// Which component(s) of a `(quotient, remainder)` division result an
/// operation outputs.
#[derive(PartialEq, Eq, Debug)]
enum DivOrMod {
    /// Only the quotient (element 0 of the pair).
    Div,
    /// Only the remainder (element 1 of the pair).
    Mod,
    /// Both elements.
    DivMod,
}
93
94struct DivModOp {
95    op: DivOrMod,
96    signed: bool,
97    panic: bool,
98}
99
100impl DivModOp {
101    fn emit<'c, H: HugrView<Node = Node>>(
102        self,
103        ctx: &mut EmitFuncContext<'c, '_, H>,
104        pcg: &impl PreludeCodegen,
105        log_width: u64,
106        numerator: IntValue<'c>,
107        denominator: IntValue<'c>,
108    ) -> Result<Vec<BasicValueEnum<'c>>> {
109        // Hugr semantics say that div and mod are equivalent to doing a divmod,
110        // then projecting out an element from the pair, so that's what we do.
111        let quotrem = make_divmod(
112            ctx,
113            pcg,
114            log_width,
115            numerator,
116            denominator,
117            self.panic,
118            self.signed,
119        )?;
120
121        if self.op == DivOrMod::DivMod {
122            if self.panic {
123                // Unpack the tuple into two values.
124                Ok(quotrem.build_untag(ctx.builder(), 0).unwrap())
125            } else {
126                Ok(vec![quotrem.as_basic_value_enum()])
127            }
128        } else {
129            // Which field we should project out from the result of divmod.
130            let index = match self.op {
131                DivOrMod::Div => 0,
132                DivOrMod::Mod => 1,
133                _ => unreachable!(),
134            };
135            // If we emitted a panicking divmod, the result is just a tuple type.
136            if self.panic {
137                Ok(vec![
138                    quotrem
139                        .build_untag(ctx.builder(), 0)?
140                        .into_iter()
141                        .nth(index)
142                        .unwrap(),
143                ])
144            }
145            // Otherwise, we have a sum type `err + [int,int]`, which we need to
146            // turn into a `err + int`.
147            else {
148                // Get the data out the the divmod result.
149                let int_ty = numerator.get_type().as_basic_type_enum();
150                let tuple_ty =
151                    LLVMSumType::try_new(ctx.iw_context(), vec![vec![int_ty, int_ty]]).unwrap();
152                let tuple = quotrem
153                    .build_untag(ctx.builder(), 1)?
154                    .into_iter()
155                    .next()
156                    .unwrap();
157                let tuple_val = LLVMSumValue::try_new(tuple, tuple_ty)?;
158                let data_val = tuple_val
159                    .build_untag(ctx.builder(), 0)?
160                    .into_iter()
161                    .nth(index)
162                    .unwrap();
163                let err_val = quotrem
164                    .build_untag(ctx.builder(), 0)?
165                    .into_iter()
166                    .next()
167                    .unwrap();
168
169                let tag_val = quotrem.build_get_tag(ctx.builder())?;
170                tag_val.set_name("tag");
171
172                // Build a new struct with the old tag and error data.
173                let int_ty = int_types::INT_TYPES[log_width as usize].clone();
174                let out_ty = LLVMSumType::try_from_hugr_type(
175                    &ctx.typing_session(),
176                    sum_with_error(vec![int_ty.clone()]),
177                )
178                .unwrap();
179
180                let data_variant = out_ty.build_tag(ctx.builder(), 1, vec![data_val])?;
181                data_variant.set_name("data_variant");
182                let err_variant = out_ty.build_tag(ctx.builder(), 0, vec![err_val])?;
183                err_variant.set_name("err_variant");
184
185                let result = ctx
186                    .builder()
187                    .build_select(tag_val, data_variant, err_variant, "")?;
188                Ok(vec![result])
189            }
190        }
191    }
192}
193
194/// `ConstError` an integer comparison operation.
195fn emit_icmp<'c, H: HugrView<Node = Node>>(
196    context: &mut EmitFuncContext<'c, '_, H>,
197    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
198    pred: inkwell::IntPredicate,
199) -> Result<()> {
200    let true_val = emit_value(context, &Value::true_val())?;
201    let false_val = emit_value(context, &Value::false_val())?;
202
203    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
204        // get result as an i1
205        let r = ctx.builder().build_int_compare(
206            pred,
207            lhs.into_int_value(),
208            rhs.into_int_value(),
209            "",
210        )?;
211        // convert to whatever bool_t is
212        Ok(vec![
213            ctx.builder().build_select(r, true_val, false_val, "")?,
214        ])
215    })
216}
217
218/// Emit an ipow operation. This isn't directly supported in llvm, so we do a
219/// loop over the exponent, performing `imul`s instead.
220/// The insertion pointer is expected to be pointing to the end of `launch_bb`.
221fn emit_ipow<'c, H: HugrView<Node = Node>>(
222    context: &mut EmitFuncContext<'c, '_, H>,
223    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
224) -> Result<()> {
225    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
226        let done_bb = ctx.new_basic_block("done", None);
227        let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
228        let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
229        let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));
230
231        let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
232        let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
233        ctx.builder().build_store(acc_p, lhs)?;
234        ctx.builder().build_store(exp_p, rhs)?;
235        ctx.builder().build_unconditional_branch(pow_bb)?;
236
237        let zero = rhs.get_type().into_int_type().const_int(0, false);
238        // Assumes RHS type is the same as output type (which it should be)
239        let one = rhs.get_type().into_int_type().const_int(1, false);
240
241        // Block for just returning one
242        ctx.builder().position_at_end(return_one_bb);
243        ctx.builder().build_store(acc_p, one)?;
244        ctx.builder().build_unconditional_branch(done_bb)?;
245
246        ctx.builder().position_at_end(pow_bb);
247        let acc = ctx.builder().build_load(acc_p, "acc")?;
248        let exp = ctx.builder().build_load(exp_p, "exp")?;
249
250        // Special case if the exponent is 0 or 1
251        ctx.builder().build_switch(
252            exp.into_int_value(),
253            pow_body_bb,
254            &[(one, done_bb), (zero, return_one_bb)],
255        )?;
256
257        // Block that performs one `imul` and modifies the values in the store
258        ctx.builder().position_at_end(pow_body_bb);
259        let new_acc =
260            ctx.builder()
261                .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
262        let new_exp = ctx
263            .builder()
264            .build_int_sub(exp.into_int_value(), one, "new_exp")?;
265        ctx.builder().build_store(acc_p, new_acc)?;
266        ctx.builder().build_store(exp_p, new_exp)?;
267        ctx.builder().build_unconditional_branch(pow_bb)?;
268
269        ctx.builder().position_at_end(done_bb);
270        let result = ctx.builder().build_load(acc_p, "result")?;
271        Ok(vec![result.as_basic_value_enum()])
272    })
273}
274
275fn emit_int_op<'c, H: HugrView<Node = Node>>(
276    context: &mut EmitFuncContext<'c, '_, H>,
277    pcg: &impl PreludeCodegen,
278    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
279    op: IntOpDef,
280) -> Result<()> {
281    match op {
282        IntOpDef::iadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
283            Ok(vec![
284                ctx.builder()
285                    .build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
286                    .as_basic_value_enum(),
287            ])
288        }),
289        IntOpDef::imul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
290            Ok(vec![
291                ctx.builder()
292                    .build_int_mul(lhs.into_int_value(), rhs.into_int_value(), "")?
293                    .as_basic_value_enum(),
294            ])
295        }),
296        IntOpDef::isub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
297            Ok(vec![
298                ctx.builder()
299                    .build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
300                    .as_basic_value_enum(),
301            ])
302        }),
303        IntOpDef::idiv_s => {
304            let log_width = get_width_arg(&args, &op)?;
305            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
306                let op = DivModOp {
307                    op: DivOrMod::Div,
308                    signed: true,
309                    panic: true,
310                };
311                op.emit(
312                    ctx,
313                    pcg,
314                    log_width,
315                    lhs.into_int_value(),
316                    rhs.into_int_value(),
317                )
318            })
319        }
320        IntOpDef::idiv_u => {
321            let log_width = get_width_arg(&args, &op)?;
322            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
323                let op = DivModOp {
324                    op: DivOrMod::Div,
325                    signed: false,
326                    panic: true,
327                };
328                op.emit(
329                    ctx,
330                    pcg,
331                    log_width,
332                    lhs.into_int_value(),
333                    rhs.into_int_value(),
334                )
335            })
336        }
337        IntOpDef::imod_s => {
338            let log_width = get_width_arg(&args, &op)?;
339            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
340                let op = DivModOp {
341                    op: DivOrMod::Mod,
342                    signed: true,
343                    panic: true,
344                };
345                op.emit(
346                    ctx,
347                    pcg,
348                    log_width,
349                    lhs.into_int_value(),
350                    rhs.into_int_value(),
351                )
352            })
353        }
354        IntOpDef::imod_u => {
355            let log_width = get_width_arg(&args, &op)?;
356            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
357                let op = DivModOp {
358                    op: DivOrMod::Mod,
359                    signed: false,
360                    panic: true,
361                };
362                op.emit(
363                    ctx,
364                    pcg,
365                    log_width,
366                    lhs.into_int_value(),
367                    rhs.into_int_value(),
368                )
369            })
370        }
371        IntOpDef::idivmod_u => {
372            let log_width = get_width_arg(&args, &op)?;
373            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
374                let op = DivModOp {
375                    op: DivOrMod::DivMod,
376                    signed: false,
377                    panic: true,
378                };
379                op.emit(
380                    ctx,
381                    pcg,
382                    log_width,
383                    lhs.into_int_value(),
384                    rhs.into_int_value(),
385                )
386            })
387        }
388        IntOpDef::idivmod_s => {
389            let log_width = get_width_arg(&args, &op)?;
390            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
391                let op = DivModOp {
392                    op: DivOrMod::DivMod,
393                    signed: true,
394                    panic: true,
395                };
396                op.emit(
397                    ctx,
398                    pcg,
399                    log_width,
400                    lhs.into_int_value(),
401                    rhs.into_int_value(),
402                )
403            })
404        }
405        IntOpDef::idiv_checked_s => {
406            let log_width = get_width_arg(&args, &op)?;
407            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
408                let op = DivModOp {
409                    op: DivOrMod::Div,
410                    signed: true,
411                    panic: false,
412                };
413                op.emit(
414                    ctx,
415                    pcg,
416                    log_width,
417                    lhs.into_int_value(),
418                    rhs.into_int_value(),
419                )
420            })
421        }
422        IntOpDef::idiv_checked_u => {
423            let log_width = get_width_arg(&args, &op)?;
424
425            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
426                let op = DivModOp {
427                    op: DivOrMod::Div,
428                    signed: false,
429                    panic: false,
430                };
431                op.emit(
432                    ctx,
433                    pcg,
434                    log_width,
435                    lhs.into_int_value(),
436                    rhs.into_int_value(),
437                )
438            })
439        }
440        IntOpDef::imod_checked_s => {
441            let log_width = get_width_arg(&args, &op)?;
442            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
443                let op = DivModOp {
444                    op: DivOrMod::Mod,
445                    signed: true,
446                    panic: false,
447                };
448                op.emit(
449                    ctx,
450                    pcg,
451                    log_width,
452                    lhs.into_int_value(),
453                    rhs.into_int_value(),
454                )
455            })
456        }
457        IntOpDef::imod_checked_u => {
458            let log_width = get_width_arg(&args, &op)?;
459            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
460                let op = DivModOp {
461                    op: DivOrMod::Mod,
462                    signed: false,
463                    panic: false,
464                };
465                op.emit(
466                    ctx,
467                    pcg,
468                    log_width,
469                    lhs.into_int_value(),
470                    rhs.into_int_value(),
471                )
472            })
473        }
474        IntOpDef::idivmod_checked_u => {
475            let log_width = get_width_arg(&args, &op)?;
476            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
477                let op = DivModOp {
478                    op: DivOrMod::DivMod,
479                    signed: false,
480                    panic: false,
481                };
482                op.emit(
483                    ctx,
484                    pcg,
485                    log_width,
486                    lhs.into_int_value(),
487                    rhs.into_int_value(),
488                )
489            })
490        }
491        IntOpDef::idivmod_checked_s => {
492            let log_width = get_width_arg(&args, &op)?;
493            emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
494                let op = DivModOp {
495                    op: DivOrMod::DivMod,
496                    signed: true,
497                    panic: false,
498                };
499                op.emit(
500                    ctx,
501                    pcg,
502                    log_width,
503                    lhs.into_int_value(),
504                    rhs.into_int_value(),
505                )
506            })
507        }
508        IntOpDef::ineg => emit_custom_unary_op(context, args, |ctx, arg, _| {
509            Ok(vec![
510                ctx.builder()
511                    .build_int_neg(arg.into_int_value(), "")?
512                    .as_basic_value_enum(),
513            ])
514        }),
515        IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
516            let intr = get_intrinsic(
517                ctx.get_current_module(),
518                "llvm.abs.i64",
519                [ctx.iw_context().i64_type().as_basic_type_enum()],
520            )?;
521            let true_ = ctx.iw_context().bool_type().const_all_ones();
522            let r = ctx
523                .builder()
524                .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
525                .try_as_basic_value()
526                .unwrap_left();
527            Ok(vec![r])
528        }),
529        IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
530            let intr = get_intrinsic(
531                ctx.get_current_module(),
532                "llvm.smax.i64",
533                [ctx.iw_context().i64_type().as_basic_type_enum()],
534            )?;
535            let r = ctx
536                .builder()
537                .build_call(
538                    intr,
539                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
540                    "",
541                )?
542                .try_as_basic_value()
543                .unwrap_left();
544            Ok(vec![r])
545        }),
546        IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
547            let intr = get_intrinsic(
548                ctx.get_current_module(),
549                "llvm.umax.i64",
550                [ctx.iw_context().i64_type().as_basic_type_enum()],
551            )?;
552            let r = ctx
553                .builder()
554                .build_call(
555                    intr,
556                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
557                    "",
558                )?
559                .try_as_basic_value()
560                .unwrap_left();
561            Ok(vec![r])
562        }),
563        IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
564            let intr = get_intrinsic(
565                ctx.get_current_module(),
566                "llvm.smin.i64",
567                [ctx.iw_context().i64_type().as_basic_type_enum()],
568            )?;
569            let r = ctx
570                .builder()
571                .build_call(
572                    intr,
573                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
574                    "",
575                )?
576                .try_as_basic_value()
577                .unwrap_left();
578            Ok(vec![r])
579        }),
580        IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
581            let intr = get_intrinsic(
582                ctx.get_current_module(),
583                "llvm.umin.i64",
584                [ctx.iw_context().i64_type().as_basic_type_enum()],
585            )?;
586            let r = ctx
587                .builder()
588                .build_call(
589                    intr,
590                    &[lhs.into_int_value().into(), rhs.into_int_value().into()],
591                    "",
592                )?
593                .try_as_basic_value()
594                .unwrap_left();
595            Ok(vec![r])
596        }),
597        IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
598            Ok(vec![
599                ctx.builder()
600                    .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
601                    .as_basic_value_enum(),
602            ])
603        }),
604        IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
605            Ok(vec![
606                ctx.builder()
607                    .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
608                    .as_basic_value_enum(),
609            ])
610        }),
611        IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
612        IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
613        IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
614        IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
615        IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
616        IntOpDef::ige_s => emit_icmp(context, args, inkwell::IntPredicate::SGE),
617        IntOpDef::ilt_u => emit_icmp(context, args, inkwell::IntPredicate::ULT),
618        IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
619        IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
620        IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
621        IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
622            Ok(vec![
623                ctx.builder()
624                    .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
625                    .as_basic_value_enum(),
626            ])
627        }),
628        IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
629            Ok(vec![
630                ctx.builder()
631                    .build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
632                    .as_basic_value_enum(),
633            ])
634        }),
635        IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
636            Ok(vec![
637                ctx.builder()
638                    .build_not(arg.into_int_value(), "")?
639                    .as_basic_value_enum(),
640            ])
641        }),
642        IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
643            Ok(vec![
644                ctx.builder()
645                    .build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
646                    .as_basic_value_enum(),
647            ])
648        }),
649        IntOpDef::ipow => emit_ipow(context, args),
650        // Type args are width of input, width of output
651        IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
652            let [out] = outs.try_into()?;
653            Ok(vec![
654                ctx.builder()
655                    .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
656                    .as_basic_value_enum(),
657            ])
658        }),
659        IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
660            let [out] = outs.try_into()?;
661
662            Ok(vec![
663                ctx.builder()
664                    .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
665                    .as_basic_value_enum(),
666            ])
667        }),
668        IntOpDef::inarrow_s => {
669            let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
670            else {
671                bail!("Type arg to inarrow_s wasn't a Nat");
672            };
673            let (_, out_ty) = args.node.out_value_types().next().unwrap();
674            emit_custom_unary_op(context, args, |ctx, arg, outs| {
675                let result = make_narrow(
676                    ctx,
677                    arg,
678                    outs,
679                    out_log_width,
680                    true,
681                    out_ty.as_sum().unwrap().clone(),
682                )?;
683                Ok(vec![result])
684            })
685        }
686        IntOpDef::inarrow_u => {
687            let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
688            else {
689                bail!("Type arg to inarrow_u wasn't a Nat");
690            };
691            let (_, out_ty) = args.node.out_value_types().next().unwrap();
692            emit_custom_unary_op(context, args, |ctx, arg, outs| {
693                let result = make_narrow(
694                    ctx,
695                    arg,
696                    outs,
697                    out_log_width,
698                    false,
699                    out_ty.as_sum().unwrap().clone(),
700                )?;
701                Ok(vec![result])
702            })
703        }
704        IntOpDef::iu_to_s => {
705            let log_width = get_width_arg(&args, &op)?;
706            emit_custom_unary_op(context, args, |ctx, arg, _| {
707                let (_, max_val, _) = int_type_bounds(u32::pow(2, log_width as u32));
708                let max = arg
709                    .get_type()
710                    .into_int_type()
711                    .const_int(max_val as u64, false);
712
713                let within_bounds = ctx.builder().build_int_compare(
714                    IntPredicate::ULE,
715                    arg.into_int_value(),
716                    max,
717                    "bounds_check",
718                )?;
719
720                Ok(vec![val_or_panic(
721                    ctx,
722                    pcg,
723                    within_bounds,
724                    &ERR_IU_TO_S,
725                    |_| Ok(arg),
726                )?])
727            })
728        }
729        IntOpDef::is_to_u => emit_custom_unary_op(context, args, |ctx, arg, _| {
730            let zero = arg.get_type().into_int_type().const_zero();
731
732            let within_bounds = ctx.builder().build_int_compare(
733                IntPredicate::SGE,
734                arg.into_int_value(),
735                zero,
736                "bounds_check",
737            )?;
738
739            Ok(vec![val_or_panic(
740                ctx,
741                pcg,
742                within_bounds,
743                &ERR_IS_TO_U,
744                |_| Ok(arg),
745            )?])
746        }),
747        _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.op_id())),
748    }
749}
750
751// Helper to get the log width arg to an int op when it's the only argument
752// panic if there's not exactly one nat arg
753pub(crate) fn get_width_arg<H: HugrView<Node = Node>>(
754    args: &EmitOpArgs<'_, '_, ExtensionOp, H>,
755    op: &impl MakeExtensionOp,
756) -> Result<u64> {
757    let [TypeArg::BoundedNat(log_width)] = args.node.args() else {
758        bail!(
759            "Expected exactly one BoundedNat parameter to {}",
760            op.op_id()
761        )
762    };
763    Ok(*log_width)
764}
765
766// The semantics of the hugr operation specify that the divisor argument is
767// always unsigned, and the signed/unsigned variants affect the types of the
768// dividend and quotient only.
769//
770// LLVM's semantics for `srem`, however, have both operands being the same type.
771// Moreover, llvm's `srem` does not implement the modulo operation: the
772// remainder will have the same sign as the dividend instead of the divisor.
773//
774// See discussion at: https://github.com/CQCL/hugr/pull/2025#discussion_r2012537992
775fn make_divmod<'c, H: HugrView<Node = Node>>(
776    ctx: &mut EmitFuncContext<'c, '_, H>,
777    pcg: &impl PreludeCodegen,
778    log_width: u64,
779    numerator: IntValue<'c>,
780    denominator: IntValue<'c>,
781    panic: bool,
782    signed: bool,
783) -> Result<LLVMSumValue<'c>> {
784    let int_arg_ty = int_types::INT_TYPES[log_width as usize].clone();
785    let tuple_sum_ty = HugrSumType::new_tuple(vec![int_arg_ty.clone(), int_arg_ty.clone()]);
786
787    let pair_ty = LLVMSumType::try_from_hugr_type(&ctx.typing_session(), tuple_sum_ty.clone())?;
788
789    let build_divmod = |ctx: &mut EmitFuncContext<'c, '_, H>| -> Result<BasicValueEnum<'c>> {
790        if signed {
791            let max_signed_value = u64::pow(2, u32::pow(2, log_width as u32) - 1) - 1;
792            let max_signed = numerator.get_type().const_int(max_signed_value, false);
793            // Determine whether the divisor is "big" or "smol" for special casing.
794            // Here, "big" means the divisor is larger than the biggest value
795            // that could be represented by the type of the dividend.
796            let large_divisor_bool = ctx.builder().build_int_compare(
797                IntPredicate::UGT,
798                denominator,
799                max_signed,
800                "is_divisor_large",
801            )?;
802            let large_divisor =
803                ctx.builder()
804                    .build_int_z_extend(large_divisor_bool, denominator.get_type(), "")?;
805            let negative_numerator_bool = ctx.builder().build_int_compare(
806                IntPredicate::SLT,
807                numerator,
808                numerator.get_type().const_zero(),
809                "is_dividend_negative",
810            )?;
811            let negative_numerator = ctx.builder().build_int_z_extend(
812                negative_numerator_bool,
813                denominator.get_type(),
814                "",
815            )?;
816            let tag = ctx.builder().build_left_shift(
817                large_divisor,
818                denominator.get_type().const_int(1, false),
819                "",
820            )?;
821
822            let tag = ctx.builder().build_or(tag, negative_numerator, "tag")?;
823
824            let quot = ctx
825                .builder()
826                .build_int_signed_div(numerator, denominator, "quotient")?;
827            let rem = ctx
828                .builder()
829                .build_int_signed_rem(numerator, denominator, "remainder")?;
830
831            let result_ptr = ctx.builder().build_alloca(pair_ty.clone(), "result")?;
832
833            let finish = ctx.new_basic_block("finish", None);
834            let negative_bigdiv = ctx.new_basic_block("negative_bigdiv", Some(finish));
835            let negative_smoldiv = ctx.new_basic_block("negative_smoldiv", Some(finish));
836            let non_negative_bigdiv = ctx.new_basic_block("non_negative_bigdiv", Some(finish));
837            let non_negative_smoldiv = ctx.new_basic_block("non_negative_smoldiv", Some(finish));
838
839            ctx.builder().build_switch(
840                tag,
841                non_negative_smoldiv,
842                &[
843                    (denominator.get_type().const_int(1, false), negative_smoldiv),
844                    (
845                        denominator.get_type().const_int(2, false),
846                        non_negative_bigdiv,
847                    ),
848                    (denominator.get_type().const_int(3, false), negative_bigdiv),
849                ],
850            )?;
851
852            let build_and_store_result =
853                |ctx: &mut EmitFuncContext<'c, '_, H>, vs: Vec<BasicValueEnum<'c>>| -> Result<()> {
854                    let result = pair_ty
855                        .build_tag(ctx.builder(), 0, vs)?
856                        //.build_tag(ctx.builder(), 0, vec![tag.as_basic_value_enum(), tag.as_basic_value_enum()])?
857                        .as_basic_value_enum();
858                    ctx.builder().build_store(result_ptr, result)?;
859                    ctx.builder().build_unconditional_branch(finish)?;
860                    Ok(())
861                };
862
863            // Default case (although it should only be reached by one branch).
864            // When the divisor is smol and the dividend is positive, we can
865            // rely on LLVM intrinsics.
866            ctx.builder().position_at_end(non_negative_smoldiv);
867            build_and_store_result(
868                ctx,
869                vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
870            )?;
871
872            // When the divisor is smol and the dividend is negative,
873            // we have two cases:
874            ctx.builder().position_at_end(negative_smoldiv);
875            {
876                // If the remainder is 0, we can use the results of LLVM's `srem`
877                let if_rem_zero = pair_ty
878                    .build_tag(
879                        ctx.builder(),
880                        0,
881                        vec![
882                            quot.as_basic_value_enum(),
883                            rem.get_type().const_zero().as_basic_value_enum(),
884                        ],
885                    )?
886                    .as_basic_value_enum();
887
888                // Otherwise, we return `(quotient - 1, divisor + remainder)`
889                let if_rem_nonzero = pair_ty
890                    .build_tag(
891                        ctx.builder(),
892                        0,
893                        vec![
894                            ctx.builder()
895                                .build_int_sub(quot, quot.get_type().const_int(1, true), "")?
896                                .as_basic_value_enum(),
897                            ctx.builder()
898                                .build_int_add(denominator, rem, "")?
899                                .as_basic_value_enum(),
900                        ],
901                    )?
902                    .as_basic_value_enum();
903
904                let is_rem_zero = ctx.builder().build_int_compare(
905                    IntPredicate::EQ,
906                    rem,
907                    rem.get_type().const_zero(),
908                    "is_rem_0",
909                )?;
910                let result =
911                    ctx.builder()
912                        .build_select(is_rem_zero, if_rem_zero, if_rem_nonzero, "")?;
913                ctx.builder().build_store(result_ptr, result)?;
914                ctx.builder().build_unconditional_branch(finish)?;
915            }
916
917            // The (unsigned) divisor is bigger than the (signed) dividend could
918            // possibly be, so it's safe to return quotient 0 and remainder = dividend
919            ctx.builder().position_at_end(non_negative_bigdiv);
920            build_and_store_result(
921                ctx,
922                vec![
923                    numerator.get_type().const_zero().as_basic_value_enum(),
924                    numerator.as_basic_value_enum(),
925                ],
926            )?;
927
928            // The divisor is larger than the dividend can possibly be, and the
929            // dividend is negative. This means we have to return `quotient - 1`
930            // and the remainder is `dividend + divisor`.
931            ctx.builder().position_at_end(negative_bigdiv);
932            build_and_store_result(
933                ctx,
934                vec![
935                    numerator.get_type().const_all_ones().as_basic_value_enum(),
936                    ctx.builder()
937                        .build_int_add(numerator, denominator, "")?
938                        .as_basic_value_enum(),
939                ],
940            )?;
941
942            ctx.builder().position_at_end(finish);
943            let result = ctx.builder().build_load(result_ptr, "result")?;
944            Ok(result)
945        } else {
946            let quot = ctx
947                .builder()
948                .build_int_unsigned_div(numerator, denominator, "quotient")?;
949            let rem = ctx
950                .builder()
951                .build_int_unsigned_rem(numerator, denominator, "remainder")?;
952            Ok(pair_ty
953                .build_tag(
954                    ctx.builder(),
955                    0,
956                    vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
957                )?
958                .as_basic_value_enum())
959        }
960    };
961
962    let int_ty = numerator.get_type();
963    let zero = int_ty.const_zero();
964    let lower_bounds_check =
965        ctx.builder()
966            .build_int_compare(IntPredicate::NE, denominator, zero, "valid_div")?;
967
968    let sum_ty = LLVMSumType::try_from_hugr_type(
969        &ctx.typing_session(),
970        sum_with_error(vec![Type::from(tuple_sum_ty)]),
971    )?;
972
973    if panic {
974        LLVMSumValue::try_new(
975            val_or_panic(ctx, pcg, lower_bounds_check, &ERR_DIV_0, |ctx| {
976                build_divmod(ctx)
977            })?,
978            pair_ty,
979        )
980    } else {
981        let result = build_divmod(ctx)?;
982        LLVMSumValue::try_new(
983            val_or_error(ctx, lower_bounds_check, result, &ERR_DIV_0, sum_ty.clone())?,
984            sum_ty,
985        )
986    }
987}
988
989fn make_narrow<'c, H: HugrView<Node = Node>>(
990    ctx: &mut EmitFuncContext<'c, '_, H>,
991    arg: BasicValueEnum<'c>,
992    outs: &[BasicTypeEnum<'c>],
993    out_log_width: u64,
994    signed: bool,
995    sum_type: HugrSumType,
996) -> Result<BasicValueEnum<'c>> {
997    let [out] = TryInto::<[BasicTypeEnum; 1]>::try_into(outs)?;
998    let width = 1 << out_log_width;
999    let arg_int_ty: IntType = arg.get_type().into_int_type();
1000    let (int_min_value_s, int_max_value_s, int_max_value_u) = int_type_bounds(width);
1001    let out_int_ty = out
1002        .into_struct_type()
1003        .get_field_type_at_index(2)
1004        .unwrap()
1005        .into_int_type();
1006    let outside_range = if signed {
1007        let too_big = ctx.builder().build_int_compare(
1008            IntPredicate::SGT,
1009            arg.into_int_value(),
1010            arg_int_ty.const_int(int_max_value_s as u64, true),
1011            "upper_bounds_check",
1012        )?;
1013        let too_small = ctx.builder().build_int_compare(
1014            IntPredicate::SLT,
1015            arg.into_int_value(),
1016            arg_int_ty.const_int(int_min_value_s as u64, true),
1017            "lower_bounds_check",
1018        )?;
1019        ctx.builder()
1020            .build_or(too_big, too_small, "outside_range")?
1021    } else {
1022        ctx.builder().build_int_compare(
1023            IntPredicate::UGT,
1024            arg.into_int_value(),
1025            arg_int_ty.const_int(int_max_value_u, false),
1026            "upper_bounds_check",
1027        )?
1028    };
1029
1030    let inbounds = ctx.builder().build_not(outside_range, "inbounds")?;
1031    let narrowed_val = ctx
1032        .builder()
1033        .build_int_cast_sign_flag(arg.into_int_value(), out_int_ty, signed, "")?
1034        .as_basic_value_enum();
1035    val_or_error(
1036        ctx,
1037        inbounds,
1038        narrowed_val,
1039        &ERR_NARROW,
1040        LLVMSumType::try_from_hugr_type(&ctx.typing_session(), sum_type).unwrap(),
1041    )
1042}
1043
1044fn val_or_panic<'c, H: HugrView<Node = Node>>(
1045    ctx: &mut EmitFuncContext<'c, '_, H>,
1046    pcg: &impl PreludeCodegen,
1047    dont_panic: IntValue<'c>,
1048    err: &ConstError,
1049    // Returned value must be same int type as `dont_panic`.
1050    go: impl Fn(&mut EmitFuncContext<'c, '_, H>) -> Result<BasicValueEnum<'c>>,
1051) -> Result<BasicValueEnum<'c>> {
1052    let exit_bb = ctx.new_basic_block("exit", None);
1053    let go_bb = ctx.new_basic_block("panic_if_0", Some(exit_bb));
1054    let panic_bb = ctx.new_basic_block("panic", Some(exit_bb));
1055    ctx.builder().build_unconditional_branch(go_bb)?;
1056
1057    ctx.builder().position_at_end(panic_bb);
1058    let err = ctx.emit_custom_const(err)?;
1059    pcg.emit_panic(ctx, err)?;
1060    ctx.builder().build_unconditional_branch(exit_bb)?;
1061
1062    ctx.builder().position_at_end(go_bb);
1063    ctx.builder().build_switch(
1064        dont_panic,
1065        panic_bb,
1066        &[(dont_panic.get_type().const_int(1, false), exit_bb)],
1067    )?;
1068
1069    ctx.builder().position_at_end(exit_bb);
1070
1071    go(ctx)
1072}
1073
1074fn val_or_error<'c, H: HugrView<Node = Node>>(
1075    ctx: &mut EmitFuncContext<'c, '_, H>,
1076    should_succeed: IntValue<'c>,
1077    val: BasicValueEnum<'c>,
1078    err: &ConstError,
1079    ty: LLVMSumType<'c>,
1080) -> Result<BasicValueEnum<'c>> {
1081    let err_val = ctx.emit_custom_const(err)?;
1082
1083    let err_variant = ty.build_tag(ctx.builder(), 0, vec![err_val])?;
1084    let ok_variant = ty.build_tag(ctx.builder(), 1, vec![val])?;
1085
1086    Ok(ctx
1087        .builder()
1088        .build_select(should_succeed, ok_variant, err_variant, "")?)
1089}
1090
1091fn llvm_type<'c>(
1092    context: TypingSession<'c, '_>,
1093    hugr_type: &CustomType,
1094) -> Result<BasicTypeEnum<'c>> {
1095    if let [TypeArg::BoundedNat(n)] = hugr_type.args() {
1096        let m = *n as usize;
1097        if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() {
1098            return Ok(match m {
1099                0..=3 => context.iw_context().i8_type(),
1100                4 => context.iw_context().i16_type(),
1101                5 => context.iw_context().i32_type(),
1102                6 => context.iw_context().i64_type(),
1103                _ => Err(anyhow!(
1104                    "IntTypesCodegenExtension: unsupported log_width: {}",
1105                    m
1106                ))?,
1107            }
1108            .into());
1109        }
1110    }
1111    Err(anyhow!(
1112        "IntTypesCodegenExtension: unsupported type: {}",
1113        hugr_type
1114    ))
1115}
1116
1117fn emit_const_int<'c, H: HugrView<Node = Node>>(
1118    context: &mut EmitFuncContext<'c, '_, H>,
1119    k: &ConstInt,
1120) -> Result<BasicValueEnum<'c>> {
1121    let ty: IntType = context.llvm_type(&k.get_type())?.try_into().unwrap();
1122    // k.value_u() is in two's complement representation of the exactly
1123    // correct bit width, so we are safe to unconditionally retrieve the
1124    // unsigned value and do no sign extension.
1125    Ok(ty.const_int(k.value_u(), false).as_basic_value_enum())
1126}
1127
1128impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
1129    /// Populates a [`CodegenExtsBuilder`] with all extensions needed to lower int
1130    /// ops, types, and constants.
1131    ///
1132    /// Any ops that panic will do so using [`DefaultPreludeCodegen`].
1133    #[must_use]
1134    pub fn add_default_int_extensions(self) -> Self {
1135        self.add_extension(IntCodegenExtension::new(DefaultPreludeCodegen))
1136    }
1137}
1138
1139#[cfg(test)]
1140mod test {
1141    use anyhow::Result;
1142    use hugr_core::builder::DataflowHugr;
1143    use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type};
1144    use hugr_core::std_extensions::STD_REG;
1145    use hugr_core::{
1146        Hugr,
1147        builder::{Dataflow, DataflowSubContainer, SubContainer, handle::Outputs},
1148        extension::prelude::bool_t,
1149        ops::{DataflowOpTrait, ExtensionOp},
1150        std_extensions::arithmetic::{
1151            int_ops::{self, IntOpDef},
1152            int_types::{ConstInt, INT_TYPES},
1153        },
1154        types::{SumType, Type, TypeRow},
1155    };
1156    use rstest::rstest;
1157
1158    use crate::{
1159        check_emission,
1160        emit::test::{DFGW, SimpleHugrConfig},
1161        test::{TestContext, exec_ctx, llvm_ctx, single_op_hugr},
1162    };
1163
1164    #[rstest::fixture]
1165    fn int_exec_ctx(mut exec_ctx: TestContext) -> TestContext {
1166        exec_ctx.add_extensions(|cem| {
1167            cem.add_default_int_extensions()
1168                .add_default_prelude_extensions()
1169        });
1170        exec_ctx
1171    }
1172
1173    #[rstest::fixture]
1174    fn int_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
1175        llvm_ctx.add_extensions(|cem| {
1176            cem.add_default_int_extensions()
1177                .add_default_prelude_extensions()
1178        });
1179        llvm_ctx
1180    }
1181
1182    // Instantiate an extension op which takes one width argument
1183    fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
1184        int_ops::EXTENSION
1185            .instantiate_extension_op(name.as_ref(), [u64::from(log_width).into()])
1186            .unwrap()
1187    }
1188
1189    fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1190        let ty = &INT_TYPES[log_width as usize];
1191        test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
1192    }
1193
1194    fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1195        test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
1196    }
1197
1198    fn test_int_op_with_results<const N: usize>(
1199        ext_op: ExtensionOp,
1200        log_width: u8,
1201        inputs: Option<[ConstInt; N]>,
1202        output_type: Type,
1203    ) -> Hugr {
1204        test_int_op_with_results_processing(ext_op, log_width, inputs, output_type, |_, a| Ok(a))
1205    }
1206
1207    fn test_int_op_with_results_processing<const N: usize>(
1208        // N is the number of inputs to the hugr
1209        ext_op: ExtensionOp,
1210        log_width: u8,
1211        inputs: Option<[ConstInt; N]>, // If inputs are provided, they'll be wired into the op, otherwise the inputs to the hugr will be wired into the op
1212        output_type: Type,
1213        process: impl Fn(&mut DFGW, Outputs) -> Result<Outputs>,
1214    ) -> Hugr {
1215        let ty = &INT_TYPES[log_width as usize];
1216        let input_tys = if inputs.is_some() {
1217            vec![]
1218        } else {
1219            let input_tys = itertools::repeat_n(ty.clone(), N).collect();
1220            assert_eq!(input_tys, ext_op.signature().input.to_vec());
1221            input_tys
1222        };
1223        SimpleHugrConfig::new()
1224            .with_ins(input_tys)
1225            .with_outs(vec![output_type])
1226            .with_extensions(STD_REG.clone())
1227            .finish(|mut hugr_builder| {
1228                let input_wires = match inputs {
1229                    None => hugr_builder.input_wires_arr::<N>().to_vec(),
1230                    Some(inputs) => {
1231                        let mut input_wires = Vec::new();
1232                        for i in inputs.into_iter() {
1233                            let w = hugr_builder.add_load_value(i);
1234                            input_wires.push(w);
1235                        }
1236                        input_wires
1237                    }
1238                };
1239                let outputs = hugr_builder
1240                    .add_dataflow_op(ext_op, input_wires)
1241                    .unwrap()
1242                    .outputs();
1243                let processed_outputs = process(&mut hugr_builder, outputs).unwrap();
1244                hugr_builder
1245                    .finish_hugr_with_outputs(processed_outputs)
1246                    .unwrap()
1247            })
1248    }
1249
1250    #[rstest]
1251    #[case(IntOpDef::iu_to_s, &[3])]
1252    #[case(IntOpDef::is_to_u, &[3])]
1253    #[case(IntOpDef::ineg, &[2])]
1254    #[case::idiv_checked_u("idiv_checked_u", &[3])]
1255    #[case::idiv_checked_s("idiv_checked_s", &[3])]
1256    #[case::imod_checked_u("imod_checked_u", &[6])]
1257    #[case::imod_checked_s("imod_checked_s", &[6])]
1258    #[case::idivmod_u("idivmod_u", &[3])]
1259    #[case::idivmod_s("idivmod_s", &[3])]
1260    #[case::idivmod_checked_u("idivmod_checked_u", &[6])]
1261    #[case::idivmod_checked_s("idivmod_checked_s", &[6])]
1262    fn test_emission(int_llvm_ctx: TestContext, #[case] op: IntOpDef, #[case] args: &[u8]) {
1263        use hugr_core::extension::simple_op::MakeExtensionOp as _;
1264
1265        let mut insta = insta::Settings::clone_current();
1266        insta.set_snapshot_suffix(format!(
1267            "{}_{}_{:?}",
1268            insta.snapshot_suffix().unwrap_or(""),
1269            op.op_id(),
1270            args,
1271        ));
1272        let concrete = match *args {
1273            [] => op.without_log_width(),
1274            [log_width] => op.with_log_width(log_width),
1275            [lw1, lw2] => op.with_two_log_widths(lw1, lw2),
1276            _ => panic!("unexpected number of args to the op!"),
1277        };
1278        insta.bind(|| {
1279            let hugr = single_op_hugr(concrete.into());
1280            check_emission!(hugr, int_llvm_ctx);
1281        });
1282    }
1283
1284    #[rstest]
1285    #[case::iadd("iadd", 3)]
1286    #[case::isub("isub", 6)]
1287    #[case::ipow("ipow", 3)]
1288    #[case::idiv_u("idiv_u", 3)]
1289    #[case::idiv_s("idiv_s", 3)]
1290    #[case::imod_u("imod_u", 3)]
1291    #[case::imod_s("imod_s", 3)]
1292    fn test_binop_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1293        let ext_op = make_int_op(op.clone(), width);
1294        let hugr = test_binary_int_op(ext_op, width);
1295        check_emission!(op.clone(), hugr, int_llvm_ctx);
1296    }
1297
1298    #[rstest]
1299    #[case::signed_2_3("iwiden_s", 2, 3)]
1300    #[case::signed_1_6("iwiden_s", 1, 6)]
1301    #[case::unsigned_2_3("iwiden_u", 2, 3)]
1302    #[case::unsigned_1_6("iwiden_u", 1, 6)]
1303    fn test_widen_emission(
1304        int_llvm_ctx: TestContext,
1305        #[case] op: String,
1306        #[case] from: u8,
1307        #[case] to: u8,
1308    ) {
1309        let out_ty = INT_TYPES[to as usize].clone();
1310        let ext_op = int_ops::EXTENSION
1311            .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1312            .unwrap();
1313        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty);
1314
1315        check_emission!(
1316            format!("{}_{}_{}", op.clone(), from, to),
1317            hugr,
1318            int_llvm_ctx
1319        );
1320    }
1321
1322    #[rstest]
1323    #[case::signed("inarrow_s", 3, 2)]
1324    #[case::unsigned("inarrow_u", 6, 4)]
1325    fn test_narrow_emission(
1326        int_llvm_ctx: TestContext,
1327        #[case] op: String,
1328        #[case] from: u8,
1329        #[case] to: u8,
1330    ) {
1331        let out_ty = SumType::new([vec![error_type()], vec![INT_TYPES[to as usize].clone()]]);
1332        let ext_op = int_ops::EXTENSION
1333            .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1334            .unwrap();
1335        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());
1336
1337        check_emission!(
1338            format!("{}_{}_{}", op.clone(), from, to),
1339            hugr,
1340            int_llvm_ctx
1341        );
1342    }
1343
1344    #[rstest]
1345    #[case::ieq("ieq", 1)]
1346    #[case::ilt_s("ilt_s", 0)]
1347    fn test_cmp_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1348        let ext_op = make_int_op(op.clone(), width);
1349        let hugr = test_binary_icmp_op(ext_op, width);
1350        check_emission!(op.clone(), hugr, int_llvm_ctx);
1351    }
1352
1353    #[rstest]
1354    #[case::imax("imax_u", 1, 2, 2)]
1355    #[case::imax("imax_u", 2, 1, 2)]
1356    #[case::imax("imax_u", 2, 2, 2)]
1357    #[case::imin("imin_u", 1, 2, 1)]
1358    #[case::imin("imin_u", 2, 1, 1)]
1359    #[case::imin("imin_u", 2, 2, 2)]
1360    #[case::ishl("ishl", 73, 1, 146)]
1361    // (2^64 - 1) << 1 = (2^64 - 2)
1362    #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
1363    #[case::ishr("ishr", 73, 1, 36)]
1364    #[case::ior("ior", 6, 9, 15)]
1365    #[case::ior("ior", 6, 15, 15)]
1366    #[case::ixor("ixor", 6, 9, 15)]
1367    #[case::ixor("ixor", 6, 15, 9)]
1368    #[case::ixor("ixor", 15, 6, 9)]
1369    #[case::iand("iand", 6, 15, 6)]
1370    #[case::iand("iand", 15, 6, 6)]
1371    #[case::iand("iand", 15, 15, 15)]
1372    #[case::ipow("ipow", 2, 3, 8)]
1373    #[case::ipow("ipow", 42, 1, 42)]
1374    #[case::ipow("ipow", 42, 0, 1)]
1375    #[case::idiv("idiv_u", 42, 2, 21)]
1376    #[case::idiv("idiv_u", 42, 5, 8)]
1377    #[case::imod("imod_u", 42, 2, 0)]
1378    #[case::imod("imod_u", 42, 5, 2)]
1379    fn test_exec_unsigned_bin_op(
1380        int_exec_ctx: TestContext,
1381        #[case] op: String,
1382        #[case] lhs: u64,
1383        #[case] rhs: u64,
1384        #[case] result: u64,
1385    ) {
1386        let ty = &INT_TYPES[6].clone();
1387        let inputs = [
1388            ConstInt::new_u(6, lhs).unwrap(),
1389            ConstInt::new_u(6, rhs).unwrap(),
1390        ];
1391        let ext_op = make_int_op(&op, 6);
1392
1393        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1394        assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1395    }
1396
1397    #[rstest]
1398    #[case::imax("imax_s", 1, 2, 2)]
1399    #[case::imax("imax_s", 2, 1, 2)]
1400    #[case::imax("imax_s", 2, 2, 2)]
1401    #[case::imax("imax_s", -1, -2, -1)]
1402    #[case::imax("imax_s", -2, -1, -1)]
1403    #[case::imax("imax_s", -2, -2, -2)]
1404    #[case::imin("imin_s", 1, 2, 1)]
1405    #[case::imin("imin_s", 2, 1, 1)]
1406    #[case::imin("imin_s", 2, 2, 2)]
1407    #[case::imin("imin_s", -1, -2, -2)]
1408    #[case::imin("imin_s", -2, -1, -2)]
1409    #[case::imin("imin_s", -2, -2, -2)]
1410    #[case::ipow("ipow", -2, 1, -2)]
1411    #[case::ipow("ipow", -2, 2, 4)]
1412    #[case::ipow("ipow", -2, 3, -8)]
1413    fn test_exec_signed_bin_op(
1414        int_exec_ctx: TestContext,
1415        #[case] op: String,
1416        #[case] lhs: i64,
1417        #[case] rhs: i64,
1418        #[case] result: i64,
1419    ) {
1420        let ty = &INT_TYPES[6].clone();
1421        let inputs = [
1422            ConstInt::new_s(6, lhs).unwrap(),
1423            ConstInt::new_s(6, rhs).unwrap(),
1424        ];
1425        let ext_op = make_int_op(&op, 6);
1426
1427        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1428        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1429    }
1430
1431    #[rstest]
1432    #[case::iabs("iabs", 42, 42)]
1433    #[case::iabs("iabs", -42, 42)]
1434    fn test_exec_signed_unary_op(
1435        int_exec_ctx: TestContext,
1436        #[case] op: String,
1437        #[case] arg: i64,
1438        #[case] result: i64,
1439    ) {
1440        let input = ConstInt::new_s(6, arg).unwrap();
1441        let ty = INT_TYPES[6].clone();
1442        let ext_op = make_int_op(&op, 6);
1443
1444        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1445        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1446    }
1447
1448    #[rstest]
1449    #[case::inot("inot", 9223372036854775808, !9223372036854775808u64)]
1450    #[case::inot("inot", 42, !42u64)]
1451    #[case::inot("inot", !0u64, 0)]
1452    fn test_exec_unsigned_unary_op(
1453        int_exec_ctx: TestContext,
1454        #[case] op: String,
1455        #[case] arg: u64,
1456        #[case] result: u64,
1457    ) {
1458        let input = ConstInt::new_u(6, arg).unwrap();
1459        let ty = INT_TYPES[6].clone();
1460        let ext_op = make_int_op(&op, 6);
1461
1462        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1463        assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1464    }
1465
1466    #[rstest]
1467    #[case(-127)]
1468    #[case(-1)]
1469    #[case(0)]
1470    #[case(1)]
1471    #[case(127)]
1472    fn test_exec_widen(int_exec_ctx: TestContext, #[case] num: i16) {
1473        let from: u8 = 3;
1474        let to: u8 = 6;
1475        let ty = INT_TYPES[to as usize].clone();
1476
1477        if num >= 0 {
1478            let input = ConstInt::new_u(from, num as u64).unwrap();
1479
1480            let ext_op = int_ops::EXTENSION
1481                .instantiate_extension_op(
1482                    "iwiden_u".as_ref(),
1483                    [(from as u64).into(), (to as u64).into()],
1484                )
1485                .unwrap();
1486
1487            let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1488
1489            assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), num as u64);
1490        }
1491
1492        let input = ConstInt::new_s(from, num as i64).unwrap();
1493
1494        let ext_op = int_ops::EXTENSION
1495            .instantiate_extension_op(
1496                "iwiden_s".as_ref(),
1497                [(from as u64).into(), (to as u64).into()],
1498            )
1499            .unwrap();
1500
1501        let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1502
1503        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), num as i64);
1504    }
1505
1506    #[rstest]
1507    #[case("inarrow_s", 6, 2, 4)]
1508    #[case("inarrow_s", 6, 5, (1 << 5) - 1)]
1509    #[case("inarrow_s", 6, 4, -1)]
1510    #[case("inarrow_s", 6, 4, -(1 << 4) - 1)]
1511    #[case("inarrow_s", 6, 4, -(1 <<15))]
1512    #[case("inarrow_s", 6, 5, (1 << 31) - 1)]
1513    fn test_narrow_s(
1514        int_exec_ctx: TestContext,
1515        #[case] op: String,
1516        #[case] from: u8,
1517        #[case] to: u8,
1518        #[case] arg: i64,
1519    ) {
1520        let input = ConstInt::new_s(from, arg).unwrap();
1521        let to_ty = INT_TYPES[to as usize].clone();
1522        let ext_op = int_ops::EXTENSION
1523            .instantiate_extension_op(op.as_ref(), [u64::from(from).into(), u64::from(to).into()])
1524            .unwrap();
1525
1526        let hugr = test_int_op_with_results_processing::<1>(
1527            ext_op,
1528            to,
1529            Some([input]),
1530            to_ty.clone(),
1531            |builder, outs| {
1532                let [out] = outs.to_array();
1533
1534                let err_row = TypeRow::from(vec![error_type()]);
1535                let ty_row = TypeRow::from(vec![to_ty.clone()]);
1536                // Handle the sum type returned by narrow by building a conditional.
1537                // We're only testing the happy path here, so insert a panic in the
1538                // "error" branch, knowing that it wont come up.
1539                //
1540                // Negative results can be tested manually, but we lack the testing
1541                // infrastructure to detect execution crashes without crashing the
1542                // test process.
1543                let mut cond_b = builder.conditional_builder(
1544                    ([err_row, ty_row], out),
1545                    [],
1546                    vec![to_ty.clone()].into(),
1547                )?;
1548                let mut sad_b = cond_b.case_builder(0)?;
1549                let err = ConstError::new(2, "This shouldn't happen");
1550                let w = sad_b.add_load_value(ConstInt::new_s(to, 0)?);
1551                sad_b.add_panic(err, vec![to_ty.clone()], [(w, to_ty.clone())])?;
1552                sad_b.finish_with_outputs([w])?;
1553
1554                let happy_b = cond_b.case_builder(1)?;
1555                let [w] = happy_b.input_wires_arr();
1556                happy_b.finish_with_outputs([w])?;
1557
1558                let handle = cond_b.finish_sub_container()?;
1559                Ok(handle.outputs())
1560            },
1561        );
1562        assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), arg);
1563    }
1564
1565    #[rstest]
1566    #[case(6, 42)]
1567    #[case(4, 7)]
1568    //#[case(4, 256)] -- crashes because a panic is emitted (good)
1569    fn test_u_to_s(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: u64) {
1570        let ty = &INT_TYPES[log_width as usize].clone();
1571        let hugr = SimpleHugrConfig::new()
1572            .with_outs(vec![ty.clone()])
1573            .with_extensions(STD_REG.clone())
1574            .finish(|mut hugr_builder| {
1575                let unsigned =
1576                    hugr_builder.add_load_value(ConstInt::new_u(log_width, val).unwrap());
1577                let iu_to_s = make_int_op("iu_to_s", log_width);
1578                let [signed] = hugr_builder
1579                    .add_dataflow_op(iu_to_s, [unsigned])
1580                    .unwrap()
1581                    .outputs_arr();
1582                hugr_builder.finish_hugr_with_outputs([signed]).unwrap()
1583            });
1584        let act = int_exec_ctx.exec_hugr_i64(hugr, "main");
1585        assert_eq!(act, val as i64);
1586    }
1587
1588    #[rstest]
1589    #[case(3, 0)]
1590    #[case(4, 255)]
1591    // #[case(3, -1)] -- crashes because a panic is emitted (good)
1592    fn test_s_to_u(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: i64) {
1593        let ty = &INT_TYPES[log_width as usize].clone();
1594        let hugr = SimpleHugrConfig::new()
1595            .with_outs(vec![ty.clone()])
1596            .with_extensions(STD_REG.clone())
1597            .finish(|mut hugr_builder| {
1598                let signed = hugr_builder.add_load_value(ConstInt::new_s(log_width, val).unwrap());
1599                let is_to_u = make_int_op("is_to_u", log_width);
1600                let [unsigned] = hugr_builder
1601                    .add_dataflow_op(is_to_u, [signed])
1602                    .unwrap()
1603                    .outputs_arr();
1604                let num = hugr_builder.add_load_value(ConstInt::new_u(log_width, 42).unwrap());
1605                let [res] = hugr_builder
1606                    .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num])
1607                    .unwrap()
1608                    .outputs_arr();
1609                hugr_builder.finish_hugr_with_outputs([res]).unwrap()
1610            });
1611        let act = int_exec_ctx.exec_hugr_u64(hugr, "main");
1612        assert_eq!(act, (val as u64) + 42);
1613    }
1614
1615    // Log width fixed at 3 (i.e. divmod : Fn(i8, u8) -> (i8, u8)
1616    #[rstest]
1617    #[case::bigdiv_non_negative(127, 255, (0, 127))] // Big divisor, positive dividend
1618    #[case::bigdiv_negative(-42, 255, (-1, 213))] // Big divisor, negative dividend
1619    #[case::smoldiv_non_negative(42, 10, (4, 2))] // Normal divisor, positive dividend
1620    #[case::smoldiv_negative_rem0(-42, 21, (-2, 0))] // Normal divisor, negative dividend, remainder 0
1621    #[case::smoldiv_negative_rem_nonzero(-42, 10, (-5, 8))] // Normal divisor, negative dividend, remainder >0
1622    fn test_divmod_s(
1623        int_exec_ctx: TestContext,
1624        #[case] dividend: i64,
1625        #[case] divisor: u64,
1626        #[case] expected_result: (i64, u64),
1627    ) {
1628        let int_ty = INT_TYPES[3].clone();
1629        let k_dividend = ConstInt::new_s(3, dividend).unwrap();
1630        let k_divisor = ConstInt::new_u(3, divisor).unwrap();
1631        let quot_hugr = test_int_op_with_results(
1632            make_int_op("idiv_s", 3),
1633            3,
1634            Some([k_dividend.clone(), k_divisor.clone()]),
1635            int_ty.clone(),
1636        );
1637        let rem_hugr = test_int_op_with_results(
1638            make_int_op("imod_s", 3),
1639            3,
1640            Some([k_dividend, k_divisor]),
1641            int_ty,
1642        );
1643        let quot = int_exec_ctx.exec_hugr_i64(quot_hugr, "main");
1644        let rem = int_exec_ctx.exec_hugr_u64(rem_hugr, "main");
1645        assert_eq!((quot, rem), expected_result);
1646    }
1647}