1use hugr_core::{
2 HugrView, Node,
3 extension::{
4 prelude::{ConstError, sum_with_error},
5 simple_op::MakeExtensionOp,
6 },
7 ops::{ExtensionOp, Value, constant::CustomConst},
8 std_extensions::arithmetic::{
9 int_ops::IntOpDef,
10 int_types::{self, ConstInt},
11 },
12 types::{CustomType, Type, TypeArg},
13};
14use inkwell::{
15 IntPredicate,
16 types::{BasicType, BasicTypeEnum, IntType},
17 values::{BasicValue, BasicValueEnum, IntValue},
18};
19use std::sync::LazyLock;
20
21use crate::{
22 CodegenExtension,
23 custom::CodegenExtsBuilder,
24 emit::{
25 EmitOpArgs, emit_value,
26 func::EmitFuncContext,
27 get_intrinsic,
28 ops::{emit_custom_binary_op, emit_custom_unary_op},
29 },
30 sum::{LLVMSumType, LLVMSumValue},
31 types::{HugrSumType, TypingSession},
32};
33
34use anyhow::{Result, anyhow, bail};
35
36use super::{DefaultPreludeCodegen, PreludeCodegen, conversions::int_type_bounds};
37
38#[derive(Clone, Debug, Default)]
39pub struct IntCodegenExtension<PCG>(PCG);
40
41impl<PCG: PreludeCodegen> IntCodegenExtension<PCG> {
42 pub fn new(ccg: PCG) -> Self {
43 Self(ccg)
44 }
45}
46
47impl<CCG: PreludeCodegen> From<CCG> for IntCodegenExtension<CCG> {
48 fn from(ccg: CCG) -> Self {
49 Self::new(ccg)
50 }
51}
52
53impl<CCG: PreludeCodegen> CodegenExtension for IntCodegenExtension<CCG> {
54 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
55 self,
56 builder: CodegenExtsBuilder<'a, H>,
57 ) -> CodegenExtsBuilder<'a, H>
58 where
59 Self: 'a,
60 {
61 builder
62 .custom_const(emit_const_int)
63 .custom_type((int_types::EXTENSION_ID, "int".into()), llvm_type)
64 .simple_extension_op::<IntOpDef>(move |context, args, op| {
65 emit_int_op(context, &self.0, args, op)
66 })
67 }
68}
69
70static ERR_NARROW: LazyLock<ConstError> = LazyLock::new(|| ConstError {
71 signal: 2,
72 message: "Can't narrow into bounds".to_string(),
73});
74static ERR_IU_TO_S: LazyLock<ConstError> = LazyLock::new(|| ConstError {
75 signal: 2,
76 message: "iu_to_s argument out of bounds".to_string(),
77});
78static ERR_IS_TO_U: LazyLock<ConstError> = LazyLock::new(|| ConstError {
79 signal: 2,
80 message: "is_to_u called on negative value".to_string(),
81});
82static ERR_DIV_0: LazyLock<ConstError> = LazyLock::new(|| ConstError {
83 signal: 2,
84 message: "Attempted division by 0".to_string(),
85});
86
/// Which outputs of an integer division a [`DivModOp`] produces.
#[derive(PartialEq, Eq, Debug)]
enum DivOrMod {
    /// Only the quotient.
    Div,
    /// Only the remainder.
    Mod,
    /// Both the quotient and the remainder.
    DivMod,
}

/// Describes one of the HUGR integer division/modulo ops for lowering.
struct DivModOp {
    /// Which of quotient/remainder the op outputs.
    op: DivOrMod,
    /// Use the signed-numerator lowering (`true`) or the unsigned one (`false`).
    signed: bool,
    /// Panic on a zero divisor (`true`) or return an error value instead (`false`).
    panic: bool,
}
99
100impl DivModOp {
101 fn emit<'c, H: HugrView<Node = Node>>(
102 self,
103 ctx: &mut EmitFuncContext<'c, '_, H>,
104 pcg: &impl PreludeCodegen,
105 log_width: u64,
106 numerator: IntValue<'c>,
107 denominator: IntValue<'c>,
108 ) -> Result<Vec<BasicValueEnum<'c>>> {
109 let quotrem = make_divmod(
112 ctx,
113 pcg,
114 log_width,
115 numerator,
116 denominator,
117 self.panic,
118 self.signed,
119 )?;
120
121 if self.op == DivOrMod::DivMod {
122 if self.panic {
123 Ok(quotrem.build_untag(ctx.builder(), 0).unwrap())
125 } else {
126 Ok(vec![quotrem.as_basic_value_enum()])
127 }
128 } else {
129 let index = match self.op {
131 DivOrMod::Div => 0,
132 DivOrMod::Mod => 1,
133 _ => unreachable!(),
134 };
135 if self.panic {
137 Ok(vec![
138 quotrem
139 .build_untag(ctx.builder(), 0)?
140 .into_iter()
141 .nth(index)
142 .unwrap(),
143 ])
144 }
145 else {
148 let int_ty = numerator.get_type().as_basic_type_enum();
150 let tuple_ty =
151 LLVMSumType::try_new(ctx.iw_context(), vec![vec![int_ty, int_ty]]).unwrap();
152 let tuple = quotrem
153 .build_untag(ctx.builder(), 1)?
154 .into_iter()
155 .next()
156 .unwrap();
157 let tuple_val = LLVMSumValue::try_new(tuple, tuple_ty)?;
158 let data_val = tuple_val
159 .build_untag(ctx.builder(), 0)?
160 .into_iter()
161 .nth(index)
162 .unwrap();
163 let err_val = quotrem
164 .build_untag(ctx.builder(), 0)?
165 .into_iter()
166 .next()
167 .unwrap();
168
169 let tag_val = quotrem.build_get_tag(ctx.builder())?;
170 tag_val.set_name("tag");
171
172 let int_ty = int_types::INT_TYPES[log_width as usize].clone();
174 let out_ty = LLVMSumType::try_from_hugr_type(
175 &ctx.typing_session(),
176 sum_with_error(vec![int_ty.clone()]),
177 )
178 .unwrap();
179
180 let data_variant = out_ty.build_tag(ctx.builder(), 1, vec![data_val])?;
181 data_variant.set_name("data_variant");
182 let err_variant = out_ty.build_tag(ctx.builder(), 0, vec![err_val])?;
183 err_variant.set_name("err_variant");
184
185 let result = ctx
186 .builder()
187 .build_select(tag_val, data_variant, err_variant, "")?;
188 Ok(vec![result])
189 }
190 }
191 }
192}
193
194fn emit_icmp<'c, H: HugrView<Node = Node>>(
196 context: &mut EmitFuncContext<'c, '_, H>,
197 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
198 pred: inkwell::IntPredicate,
199) -> Result<()> {
200 let true_val = emit_value(context, &Value::true_val())?;
201 let false_val = emit_value(context, &Value::false_val())?;
202
203 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
204 let r = ctx.builder().build_int_compare(
206 pred,
207 lhs.into_int_value(),
208 rhs.into_int_value(),
209 "",
210 )?;
211 Ok(vec![
213 ctx.builder().build_select(r, true_val, false_val, "")?,
214 ])
215 })
216}
217
218fn emit_ipow<'c, H: HugrView<Node = Node>>(
222 context: &mut EmitFuncContext<'c, '_, H>,
223 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
224) -> Result<()> {
225 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
226 let done_bb = ctx.new_basic_block("done", None);
227 let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
228 let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
229 let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));
230
231 let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
232 let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
233 ctx.builder().build_store(acc_p, lhs)?;
234 ctx.builder().build_store(exp_p, rhs)?;
235 ctx.builder().build_unconditional_branch(pow_bb)?;
236
237 let zero = rhs.get_type().into_int_type().const_int(0, false);
238 let one = rhs.get_type().into_int_type().const_int(1, false);
240
241 ctx.builder().position_at_end(return_one_bb);
243 ctx.builder().build_store(acc_p, one)?;
244 ctx.builder().build_unconditional_branch(done_bb)?;
245
246 ctx.builder().position_at_end(pow_bb);
247 let acc = ctx.builder().build_load(acc_p, "acc")?;
248 let exp = ctx.builder().build_load(exp_p, "exp")?;
249
250 ctx.builder().build_switch(
252 exp.into_int_value(),
253 pow_body_bb,
254 &[(one, done_bb), (zero, return_one_bb)],
255 )?;
256
257 ctx.builder().position_at_end(pow_body_bb);
259 let new_acc =
260 ctx.builder()
261 .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
262 let new_exp = ctx
263 .builder()
264 .build_int_sub(exp.into_int_value(), one, "new_exp")?;
265 ctx.builder().build_store(acc_p, new_acc)?;
266 ctx.builder().build_store(exp_p, new_exp)?;
267 ctx.builder().build_unconditional_branch(pow_bb)?;
268
269 ctx.builder().position_at_end(done_bb);
270 let result = ctx.builder().build_load(acc_p, "result")?;
271 Ok(vec![result.as_basic_value_enum()])
272 })
273}
274
275fn emit_int_op<'c, H: HugrView<Node = Node>>(
276 context: &mut EmitFuncContext<'c, '_, H>,
277 pcg: &impl PreludeCodegen,
278 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
279 op: IntOpDef,
280) -> Result<()> {
281 match op {
282 IntOpDef::iadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
283 Ok(vec![
284 ctx.builder()
285 .build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
286 .as_basic_value_enum(),
287 ])
288 }),
289 IntOpDef::imul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
290 Ok(vec![
291 ctx.builder()
292 .build_int_mul(lhs.into_int_value(), rhs.into_int_value(), "")?
293 .as_basic_value_enum(),
294 ])
295 }),
296 IntOpDef::isub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
297 Ok(vec![
298 ctx.builder()
299 .build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
300 .as_basic_value_enum(),
301 ])
302 }),
303 IntOpDef::idiv_s => {
304 let log_width = get_width_arg(&args, &op)?;
305 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
306 let op = DivModOp {
307 op: DivOrMod::Div,
308 signed: true,
309 panic: true,
310 };
311 op.emit(
312 ctx,
313 pcg,
314 log_width,
315 lhs.into_int_value(),
316 rhs.into_int_value(),
317 )
318 })
319 }
320 IntOpDef::idiv_u => {
321 let log_width = get_width_arg(&args, &op)?;
322 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
323 let op = DivModOp {
324 op: DivOrMod::Div,
325 signed: false,
326 panic: true,
327 };
328 op.emit(
329 ctx,
330 pcg,
331 log_width,
332 lhs.into_int_value(),
333 rhs.into_int_value(),
334 )
335 })
336 }
337 IntOpDef::imod_s => {
338 let log_width = get_width_arg(&args, &op)?;
339 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
340 let op = DivModOp {
341 op: DivOrMod::Mod,
342 signed: true,
343 panic: true,
344 };
345 op.emit(
346 ctx,
347 pcg,
348 log_width,
349 lhs.into_int_value(),
350 rhs.into_int_value(),
351 )
352 })
353 }
354 IntOpDef::imod_u => {
355 let log_width = get_width_arg(&args, &op)?;
356 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
357 let op = DivModOp {
358 op: DivOrMod::Mod,
359 signed: false,
360 panic: true,
361 };
362 op.emit(
363 ctx,
364 pcg,
365 log_width,
366 lhs.into_int_value(),
367 rhs.into_int_value(),
368 )
369 })
370 }
371 IntOpDef::idivmod_u => {
372 let log_width = get_width_arg(&args, &op)?;
373 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
374 let op = DivModOp {
375 op: DivOrMod::DivMod,
376 signed: false,
377 panic: true,
378 };
379 op.emit(
380 ctx,
381 pcg,
382 log_width,
383 lhs.into_int_value(),
384 rhs.into_int_value(),
385 )
386 })
387 }
388 IntOpDef::idivmod_s => {
389 let log_width = get_width_arg(&args, &op)?;
390 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
391 let op = DivModOp {
392 op: DivOrMod::DivMod,
393 signed: true,
394 panic: true,
395 };
396 op.emit(
397 ctx,
398 pcg,
399 log_width,
400 lhs.into_int_value(),
401 rhs.into_int_value(),
402 )
403 })
404 }
405 IntOpDef::idiv_checked_s => {
406 let log_width = get_width_arg(&args, &op)?;
407 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
408 let op = DivModOp {
409 op: DivOrMod::Div,
410 signed: true,
411 panic: false,
412 };
413 op.emit(
414 ctx,
415 pcg,
416 log_width,
417 lhs.into_int_value(),
418 rhs.into_int_value(),
419 )
420 })
421 }
422 IntOpDef::idiv_checked_u => {
423 let log_width = get_width_arg(&args, &op)?;
424
425 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
426 let op = DivModOp {
427 op: DivOrMod::Div,
428 signed: false,
429 panic: false,
430 };
431 op.emit(
432 ctx,
433 pcg,
434 log_width,
435 lhs.into_int_value(),
436 rhs.into_int_value(),
437 )
438 })
439 }
440 IntOpDef::imod_checked_s => {
441 let log_width = get_width_arg(&args, &op)?;
442 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
443 let op = DivModOp {
444 op: DivOrMod::Mod,
445 signed: true,
446 panic: false,
447 };
448 op.emit(
449 ctx,
450 pcg,
451 log_width,
452 lhs.into_int_value(),
453 rhs.into_int_value(),
454 )
455 })
456 }
457 IntOpDef::imod_checked_u => {
458 let log_width = get_width_arg(&args, &op)?;
459 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
460 let op = DivModOp {
461 op: DivOrMod::Mod,
462 signed: false,
463 panic: false,
464 };
465 op.emit(
466 ctx,
467 pcg,
468 log_width,
469 lhs.into_int_value(),
470 rhs.into_int_value(),
471 )
472 })
473 }
474 IntOpDef::idivmod_checked_u => {
475 let log_width = get_width_arg(&args, &op)?;
476 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
477 let op = DivModOp {
478 op: DivOrMod::DivMod,
479 signed: false,
480 panic: false,
481 };
482 op.emit(
483 ctx,
484 pcg,
485 log_width,
486 lhs.into_int_value(),
487 rhs.into_int_value(),
488 )
489 })
490 }
491 IntOpDef::idivmod_checked_s => {
492 let log_width = get_width_arg(&args, &op)?;
493 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
494 let op = DivModOp {
495 op: DivOrMod::DivMod,
496 signed: true,
497 panic: false,
498 };
499 op.emit(
500 ctx,
501 pcg,
502 log_width,
503 lhs.into_int_value(),
504 rhs.into_int_value(),
505 )
506 })
507 }
508 IntOpDef::ineg => emit_custom_unary_op(context, args, |ctx, arg, _| {
509 Ok(vec![
510 ctx.builder()
511 .build_int_neg(arg.into_int_value(), "")?
512 .as_basic_value_enum(),
513 ])
514 }),
515 IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
516 let intr = get_intrinsic(
517 ctx.get_current_module(),
518 "llvm.abs.i64",
519 [ctx.iw_context().i64_type().as_basic_type_enum()],
520 )?;
521 let true_ = ctx.iw_context().bool_type().const_all_ones();
522 let r = ctx
523 .builder()
524 .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
525 .try_as_basic_value()
526 .unwrap_left();
527 Ok(vec![r])
528 }),
529 IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
530 let intr = get_intrinsic(
531 ctx.get_current_module(),
532 "llvm.smax.i64",
533 [ctx.iw_context().i64_type().as_basic_type_enum()],
534 )?;
535 let r = ctx
536 .builder()
537 .build_call(
538 intr,
539 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
540 "",
541 )?
542 .try_as_basic_value()
543 .unwrap_left();
544 Ok(vec![r])
545 }),
546 IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
547 let intr = get_intrinsic(
548 ctx.get_current_module(),
549 "llvm.umax.i64",
550 [ctx.iw_context().i64_type().as_basic_type_enum()],
551 )?;
552 let r = ctx
553 .builder()
554 .build_call(
555 intr,
556 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
557 "",
558 )?
559 .try_as_basic_value()
560 .unwrap_left();
561 Ok(vec![r])
562 }),
563 IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
564 let intr = get_intrinsic(
565 ctx.get_current_module(),
566 "llvm.smin.i64",
567 [ctx.iw_context().i64_type().as_basic_type_enum()],
568 )?;
569 let r = ctx
570 .builder()
571 .build_call(
572 intr,
573 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
574 "",
575 )?
576 .try_as_basic_value()
577 .unwrap_left();
578 Ok(vec![r])
579 }),
580 IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
581 let intr = get_intrinsic(
582 ctx.get_current_module(),
583 "llvm.umin.i64",
584 [ctx.iw_context().i64_type().as_basic_type_enum()],
585 )?;
586 let r = ctx
587 .builder()
588 .build_call(
589 intr,
590 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
591 "",
592 )?
593 .try_as_basic_value()
594 .unwrap_left();
595 Ok(vec![r])
596 }),
597 IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
598 Ok(vec![
599 ctx.builder()
600 .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
601 .as_basic_value_enum(),
602 ])
603 }),
604 IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
605 Ok(vec![
606 ctx.builder()
607 .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
608 .as_basic_value_enum(),
609 ])
610 }),
611 IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
612 IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
613 IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
614 IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
615 IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
616 IntOpDef::ige_s => emit_icmp(context, args, inkwell::IntPredicate::SGE),
617 IntOpDef::ilt_u => emit_icmp(context, args, inkwell::IntPredicate::ULT),
618 IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
619 IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
620 IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
621 IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
622 Ok(vec![
623 ctx.builder()
624 .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
625 .as_basic_value_enum(),
626 ])
627 }),
628 IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
629 Ok(vec![
630 ctx.builder()
631 .build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
632 .as_basic_value_enum(),
633 ])
634 }),
635 IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
636 Ok(vec![
637 ctx.builder()
638 .build_not(arg.into_int_value(), "")?
639 .as_basic_value_enum(),
640 ])
641 }),
642 IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
643 Ok(vec![
644 ctx.builder()
645 .build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
646 .as_basic_value_enum(),
647 ])
648 }),
649 IntOpDef::ipow => emit_ipow(context, args),
650 IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
652 let [out] = outs.try_into()?;
653 Ok(vec![
654 ctx.builder()
655 .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
656 .as_basic_value_enum(),
657 ])
658 }),
659 IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
660 let [out] = outs.try_into()?;
661
662 Ok(vec![
663 ctx.builder()
664 .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
665 .as_basic_value_enum(),
666 ])
667 }),
668 IntOpDef::inarrow_s => {
669 let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
670 else {
671 bail!("Type arg to inarrow_s wasn't a Nat");
672 };
673 let (_, out_ty) = args.node.out_value_types().next().unwrap();
674 emit_custom_unary_op(context, args, |ctx, arg, outs| {
675 let result = make_narrow(
676 ctx,
677 arg,
678 outs,
679 out_log_width,
680 true,
681 out_ty.as_sum().unwrap().clone(),
682 )?;
683 Ok(vec![result])
684 })
685 }
686 IntOpDef::inarrow_u => {
687 let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
688 else {
689 bail!("Type arg to inarrow_u wasn't a Nat");
690 };
691 let (_, out_ty) = args.node.out_value_types().next().unwrap();
692 emit_custom_unary_op(context, args, |ctx, arg, outs| {
693 let result = make_narrow(
694 ctx,
695 arg,
696 outs,
697 out_log_width,
698 false,
699 out_ty.as_sum().unwrap().clone(),
700 )?;
701 Ok(vec![result])
702 })
703 }
704 IntOpDef::iu_to_s => {
705 let log_width = get_width_arg(&args, &op)?;
706 emit_custom_unary_op(context, args, |ctx, arg, _| {
707 let (_, max_val, _) = int_type_bounds(u32::pow(2, log_width as u32));
708 let max = arg
709 .get_type()
710 .into_int_type()
711 .const_int(max_val as u64, false);
712
713 let within_bounds = ctx.builder().build_int_compare(
714 IntPredicate::ULE,
715 arg.into_int_value(),
716 max,
717 "bounds_check",
718 )?;
719
720 Ok(vec![val_or_panic(
721 ctx,
722 pcg,
723 within_bounds,
724 &ERR_IU_TO_S,
725 |_| Ok(arg),
726 )?])
727 })
728 }
729 IntOpDef::is_to_u => emit_custom_unary_op(context, args, |ctx, arg, _| {
730 let zero = arg.get_type().into_int_type().const_zero();
731
732 let within_bounds = ctx.builder().build_int_compare(
733 IntPredicate::SGE,
734 arg.into_int_value(),
735 zero,
736 "bounds_check",
737 )?;
738
739 Ok(vec![val_or_panic(
740 ctx,
741 pcg,
742 within_bounds,
743 &ERR_IS_TO_U,
744 |_| Ok(arg),
745 )?])
746 }),
747 _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.op_id())),
748 }
749}
750
751pub(crate) fn get_width_arg<H: HugrView<Node = Node>>(
754 args: &EmitOpArgs<'_, '_, ExtensionOp, H>,
755 op: &impl MakeExtensionOp,
756) -> Result<u64> {
757 let [TypeArg::BoundedNat(log_width)] = args.node.args() else {
758 bail!(
759 "Expected exactly one BoundedNat parameter to {}",
760 op.op_id()
761 )
762 };
763 Ok(*log_width)
764}
765
766fn make_divmod<'c, H: HugrView<Node = Node>>(
776 ctx: &mut EmitFuncContext<'c, '_, H>,
777 pcg: &impl PreludeCodegen,
778 log_width: u64,
779 numerator: IntValue<'c>,
780 denominator: IntValue<'c>,
781 panic: bool,
782 signed: bool,
783) -> Result<LLVMSumValue<'c>> {
784 let int_arg_ty = int_types::INT_TYPES[log_width as usize].clone();
785 let tuple_sum_ty = HugrSumType::new_tuple(vec![int_arg_ty.clone(), int_arg_ty.clone()]);
786
787 let pair_ty = LLVMSumType::try_from_hugr_type(&ctx.typing_session(), tuple_sum_ty.clone())?;
788
789 let build_divmod = |ctx: &mut EmitFuncContext<'c, '_, H>| -> Result<BasicValueEnum<'c>> {
790 if signed {
791 let max_signed_value = u64::pow(2, u32::pow(2, log_width as u32) - 1) - 1;
792 let max_signed = numerator.get_type().const_int(max_signed_value, false);
793 let large_divisor_bool = ctx.builder().build_int_compare(
797 IntPredicate::UGT,
798 denominator,
799 max_signed,
800 "is_divisor_large",
801 )?;
802 let large_divisor =
803 ctx.builder()
804 .build_int_z_extend(large_divisor_bool, denominator.get_type(), "")?;
805 let negative_numerator_bool = ctx.builder().build_int_compare(
806 IntPredicate::SLT,
807 numerator,
808 numerator.get_type().const_zero(),
809 "is_dividend_negative",
810 )?;
811 let negative_numerator = ctx.builder().build_int_z_extend(
812 negative_numerator_bool,
813 denominator.get_type(),
814 "",
815 )?;
816 let tag = ctx.builder().build_left_shift(
817 large_divisor,
818 denominator.get_type().const_int(1, false),
819 "",
820 )?;
821
822 let tag = ctx.builder().build_or(tag, negative_numerator, "tag")?;
823
824 let quot = ctx
825 .builder()
826 .build_int_signed_div(numerator, denominator, "quotient")?;
827 let rem = ctx
828 .builder()
829 .build_int_signed_rem(numerator, denominator, "remainder")?;
830
831 let result_ptr = ctx.builder().build_alloca(pair_ty.clone(), "result")?;
832
833 let finish = ctx.new_basic_block("finish", None);
834 let negative_bigdiv = ctx.new_basic_block("negative_bigdiv", Some(finish));
835 let negative_smoldiv = ctx.new_basic_block("negative_smoldiv", Some(finish));
836 let non_negative_bigdiv = ctx.new_basic_block("non_negative_bigdiv", Some(finish));
837 let non_negative_smoldiv = ctx.new_basic_block("non_negative_smoldiv", Some(finish));
838
839 ctx.builder().build_switch(
840 tag,
841 non_negative_smoldiv,
842 &[
843 (denominator.get_type().const_int(1, false), negative_smoldiv),
844 (
845 denominator.get_type().const_int(2, false),
846 non_negative_bigdiv,
847 ),
848 (denominator.get_type().const_int(3, false), negative_bigdiv),
849 ],
850 )?;
851
852 let build_and_store_result =
853 |ctx: &mut EmitFuncContext<'c, '_, H>, vs: Vec<BasicValueEnum<'c>>| -> Result<()> {
854 let result = pair_ty
855 .build_tag(ctx.builder(), 0, vs)?
856 .as_basic_value_enum();
858 ctx.builder().build_store(result_ptr, result)?;
859 ctx.builder().build_unconditional_branch(finish)?;
860 Ok(())
861 };
862
863 ctx.builder().position_at_end(non_negative_smoldiv);
867 build_and_store_result(
868 ctx,
869 vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
870 )?;
871
872 ctx.builder().position_at_end(negative_smoldiv);
875 {
876 let if_rem_zero = pair_ty
878 .build_tag(
879 ctx.builder(),
880 0,
881 vec![
882 quot.as_basic_value_enum(),
883 rem.get_type().const_zero().as_basic_value_enum(),
884 ],
885 )?
886 .as_basic_value_enum();
887
888 let if_rem_nonzero = pair_ty
890 .build_tag(
891 ctx.builder(),
892 0,
893 vec![
894 ctx.builder()
895 .build_int_sub(quot, quot.get_type().const_int(1, true), "")?
896 .as_basic_value_enum(),
897 ctx.builder()
898 .build_int_add(denominator, rem, "")?
899 .as_basic_value_enum(),
900 ],
901 )?
902 .as_basic_value_enum();
903
904 let is_rem_zero = ctx.builder().build_int_compare(
905 IntPredicate::EQ,
906 rem,
907 rem.get_type().const_zero(),
908 "is_rem_0",
909 )?;
910 let result =
911 ctx.builder()
912 .build_select(is_rem_zero, if_rem_zero, if_rem_nonzero, "")?;
913 ctx.builder().build_store(result_ptr, result)?;
914 ctx.builder().build_unconditional_branch(finish)?;
915 }
916
917 ctx.builder().position_at_end(non_negative_bigdiv);
920 build_and_store_result(
921 ctx,
922 vec![
923 numerator.get_type().const_zero().as_basic_value_enum(),
924 numerator.as_basic_value_enum(),
925 ],
926 )?;
927
928 ctx.builder().position_at_end(negative_bigdiv);
932 build_and_store_result(
933 ctx,
934 vec![
935 numerator.get_type().const_all_ones().as_basic_value_enum(),
936 ctx.builder()
937 .build_int_add(numerator, denominator, "")?
938 .as_basic_value_enum(),
939 ],
940 )?;
941
942 ctx.builder().position_at_end(finish);
943 let result = ctx.builder().build_load(result_ptr, "result")?;
944 Ok(result)
945 } else {
946 let quot = ctx
947 .builder()
948 .build_int_unsigned_div(numerator, denominator, "quotient")?;
949 let rem = ctx
950 .builder()
951 .build_int_unsigned_rem(numerator, denominator, "remainder")?;
952 Ok(pair_ty
953 .build_tag(
954 ctx.builder(),
955 0,
956 vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
957 )?
958 .as_basic_value_enum())
959 }
960 };
961
962 let int_ty = numerator.get_type();
963 let zero = int_ty.const_zero();
964 let lower_bounds_check =
965 ctx.builder()
966 .build_int_compare(IntPredicate::NE, denominator, zero, "valid_div")?;
967
968 let sum_ty = LLVMSumType::try_from_hugr_type(
969 &ctx.typing_session(),
970 sum_with_error(vec![Type::from(tuple_sum_ty)]),
971 )?;
972
973 if panic {
974 LLVMSumValue::try_new(
975 val_or_panic(ctx, pcg, lower_bounds_check, &ERR_DIV_0, |ctx| {
976 build_divmod(ctx)
977 })?,
978 pair_ty,
979 )
980 } else {
981 let result = build_divmod(ctx)?;
982 LLVMSumValue::try_new(
983 val_or_error(ctx, lower_bounds_check, result, &ERR_DIV_0, sum_ty.clone())?,
984 sum_ty,
985 )
986 }
987}
988
989fn make_narrow<'c, H: HugrView<Node = Node>>(
990 ctx: &mut EmitFuncContext<'c, '_, H>,
991 arg: BasicValueEnum<'c>,
992 outs: &[BasicTypeEnum<'c>],
993 out_log_width: u64,
994 signed: bool,
995 sum_type: HugrSumType,
996) -> Result<BasicValueEnum<'c>> {
997 let [out] = TryInto::<[BasicTypeEnum; 1]>::try_into(outs)?;
998 let width = 1 << out_log_width;
999 let arg_int_ty: IntType = arg.get_type().into_int_type();
1000 let (int_min_value_s, int_max_value_s, int_max_value_u) = int_type_bounds(width);
1001 let out_int_ty = out
1002 .into_struct_type()
1003 .get_field_type_at_index(2)
1004 .unwrap()
1005 .into_int_type();
1006 let outside_range = if signed {
1007 let too_big = ctx.builder().build_int_compare(
1008 IntPredicate::SGT,
1009 arg.into_int_value(),
1010 arg_int_ty.const_int(int_max_value_s as u64, true),
1011 "upper_bounds_check",
1012 )?;
1013 let too_small = ctx.builder().build_int_compare(
1014 IntPredicate::SLT,
1015 arg.into_int_value(),
1016 arg_int_ty.const_int(int_min_value_s as u64, true),
1017 "lower_bounds_check",
1018 )?;
1019 ctx.builder()
1020 .build_or(too_big, too_small, "outside_range")?
1021 } else {
1022 ctx.builder().build_int_compare(
1023 IntPredicate::UGT,
1024 arg.into_int_value(),
1025 arg_int_ty.const_int(int_max_value_u, false),
1026 "upper_bounds_check",
1027 )?
1028 };
1029
1030 let inbounds = ctx.builder().build_not(outside_range, "inbounds")?;
1031 let narrowed_val = ctx
1032 .builder()
1033 .build_int_cast_sign_flag(arg.into_int_value(), out_int_ty, signed, "")?
1034 .as_basic_value_enum();
1035 val_or_error(
1036 ctx,
1037 inbounds,
1038 narrowed_val,
1039 &ERR_NARROW,
1040 LLVMSumType::try_from_hugr_type(&ctx.typing_session(), sum_type).unwrap(),
1041 )
1042}
1043
1044fn val_or_panic<'c, H: HugrView<Node = Node>>(
1045 ctx: &mut EmitFuncContext<'c, '_, H>,
1046 pcg: &impl PreludeCodegen,
1047 dont_panic: IntValue<'c>,
1048 err: &ConstError,
1049 go: impl Fn(&mut EmitFuncContext<'c, '_, H>) -> Result<BasicValueEnum<'c>>,
1051) -> Result<BasicValueEnum<'c>> {
1052 let exit_bb = ctx.new_basic_block("exit", None);
1053 let go_bb = ctx.new_basic_block("panic_if_0", Some(exit_bb));
1054 let panic_bb = ctx.new_basic_block("panic", Some(exit_bb));
1055 ctx.builder().build_unconditional_branch(go_bb)?;
1056
1057 ctx.builder().position_at_end(panic_bb);
1058 let err = ctx.emit_custom_const(err)?;
1059 pcg.emit_panic(ctx, err)?;
1060 ctx.builder().build_unconditional_branch(exit_bb)?;
1061
1062 ctx.builder().position_at_end(go_bb);
1063 ctx.builder().build_switch(
1064 dont_panic,
1065 panic_bb,
1066 &[(dont_panic.get_type().const_int(1, false), exit_bb)],
1067 )?;
1068
1069 ctx.builder().position_at_end(exit_bb);
1070
1071 go(ctx)
1072}
1073
1074fn val_or_error<'c, H: HugrView<Node = Node>>(
1075 ctx: &mut EmitFuncContext<'c, '_, H>,
1076 should_succeed: IntValue<'c>,
1077 val: BasicValueEnum<'c>,
1078 err: &ConstError,
1079 ty: LLVMSumType<'c>,
1080) -> Result<BasicValueEnum<'c>> {
1081 let err_val = ctx.emit_custom_const(err)?;
1082
1083 let err_variant = ty.build_tag(ctx.builder(), 0, vec![err_val])?;
1084 let ok_variant = ty.build_tag(ctx.builder(), 1, vec![val])?;
1085
1086 Ok(ctx
1087 .builder()
1088 .build_select(should_succeed, ok_variant, err_variant, "")?)
1089}
1090
1091fn llvm_type<'c>(
1092 context: TypingSession<'c, '_>,
1093 hugr_type: &CustomType,
1094) -> Result<BasicTypeEnum<'c>> {
1095 if let [TypeArg::BoundedNat(n)] = hugr_type.args() {
1096 let m = *n as usize;
1097 if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() {
1098 return Ok(match m {
1099 0..=3 => context.iw_context().i8_type(),
1100 4 => context.iw_context().i16_type(),
1101 5 => context.iw_context().i32_type(),
1102 6 => context.iw_context().i64_type(),
1103 _ => Err(anyhow!(
1104 "IntTypesCodegenExtension: unsupported log_width: {}",
1105 m
1106 ))?,
1107 }
1108 .into());
1109 }
1110 }
1111 Err(anyhow!(
1112 "IntTypesCodegenExtension: unsupported type: {}",
1113 hugr_type
1114 ))
1115}
1116
1117fn emit_const_int<'c, H: HugrView<Node = Node>>(
1118 context: &mut EmitFuncContext<'c, '_, H>,
1119 k: &ConstInt,
1120) -> Result<BasicValueEnum<'c>> {
1121 let ty: IntType = context.llvm_type(&k.get_type())?.try_into().unwrap();
1122 Ok(ty.const_int(k.value_u(), false).as_basic_value_enum())
1126}
1127
1128impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
1129 #[must_use]
1134 pub fn add_default_int_extensions(self) -> Self {
1135 self.add_extension(IntCodegenExtension::new(DefaultPreludeCodegen))
1136 }
1137}
1138
1139#[cfg(test)]
1140mod test {
1141 use anyhow::Result;
1142 use hugr_core::builder::DataflowHugr;
1143 use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type};
1144 use hugr_core::std_extensions::STD_REG;
1145 use hugr_core::{
1146 Hugr,
1147 builder::{Dataflow, DataflowSubContainer, SubContainer, handle::Outputs},
1148 extension::prelude::bool_t,
1149 ops::{DataflowOpTrait, ExtensionOp},
1150 std_extensions::arithmetic::{
1151 int_ops::{self, IntOpDef},
1152 int_types::{ConstInt, INT_TYPES},
1153 },
1154 types::{SumType, Type, TypeRow},
1155 };
1156 use rstest::rstest;
1157
1158 use crate::{
1159 check_emission,
1160 emit::test::{DFGW, SimpleHugrConfig},
1161 test::{TestContext, exec_ctx, llvm_ctx, single_op_hugr},
1162 };
1163
1164 #[rstest::fixture]
1165 fn int_exec_ctx(mut exec_ctx: TestContext) -> TestContext {
1166 exec_ctx.add_extensions(|cem| {
1167 cem.add_default_int_extensions()
1168 .add_default_prelude_extensions()
1169 });
1170 exec_ctx
1171 }
1172
1173 #[rstest::fixture]
1174 fn int_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
1175 llvm_ctx.add_extensions(|cem| {
1176 cem.add_default_int_extensions()
1177 .add_default_prelude_extensions()
1178 });
1179 llvm_ctx
1180 }
1181
1182 fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
1184 int_ops::EXTENSION
1185 .instantiate_extension_op(name.as_ref(), [u64::from(log_width).into()])
1186 .unwrap()
1187 }
1188
1189 fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1190 let ty = &INT_TYPES[log_width as usize];
1191 test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
1192 }
1193
1194 fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1195 test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
1196 }
1197
1198 fn test_int_op_with_results<const N: usize>(
1199 ext_op: ExtensionOp,
1200 log_width: u8,
1201 inputs: Option<[ConstInt; N]>,
1202 output_type: Type,
1203 ) -> Hugr {
1204 test_int_op_with_results_processing(ext_op, log_width, inputs, output_type, |_, a| Ok(a))
1205 }
1206
1207 fn test_int_op_with_results_processing<const N: usize>(
1208 ext_op: ExtensionOp,
1210 log_width: u8,
1211 inputs: Option<[ConstInt; N]>, output_type: Type,
1213 process: impl Fn(&mut DFGW, Outputs) -> Result<Outputs>,
1214 ) -> Hugr {
1215 let ty = &INT_TYPES[log_width as usize];
1216 let input_tys = if inputs.is_some() {
1217 vec![]
1218 } else {
1219 let input_tys = itertools::repeat_n(ty.clone(), N).collect();
1220 assert_eq!(input_tys, ext_op.signature().input.to_vec());
1221 input_tys
1222 };
1223 SimpleHugrConfig::new()
1224 .with_ins(input_tys)
1225 .with_outs(vec![output_type])
1226 .with_extensions(STD_REG.clone())
1227 .finish(|mut hugr_builder| {
1228 let input_wires = match inputs {
1229 None => hugr_builder.input_wires_arr::<N>().to_vec(),
1230 Some(inputs) => {
1231 let mut input_wires = Vec::new();
1232 for i in inputs.into_iter() {
1233 let w = hugr_builder.add_load_value(i);
1234 input_wires.push(w);
1235 }
1236 input_wires
1237 }
1238 };
1239 let outputs = hugr_builder
1240 .add_dataflow_op(ext_op, input_wires)
1241 .unwrap()
1242 .outputs();
1243 let processed_outputs = process(&mut hugr_builder, outputs).unwrap();
1244 hugr_builder
1245 .finish_hugr_with_outputs(processed_outputs)
1246 .unwrap()
1247 })
1248 }
1249
1250 #[rstest]
1251 #[case(IntOpDef::iu_to_s, &[3])]
1252 #[case(IntOpDef::is_to_u, &[3])]
1253 #[case(IntOpDef::ineg, &[2])]
1254 #[case::idiv_checked_u("idiv_checked_u", &[3])]
1255 #[case::idiv_checked_s("idiv_checked_s", &[3])]
1256 #[case::imod_checked_u("imod_checked_u", &[6])]
1257 #[case::imod_checked_s("imod_checked_s", &[6])]
1258 #[case::idivmod_u("idivmod_u", &[3])]
1259 #[case::idivmod_s("idivmod_s", &[3])]
1260 #[case::idivmod_checked_u("idivmod_checked_u", &[6])]
1261 #[case::idivmod_checked_s("idivmod_checked_s", &[6])]
1262 fn test_emission(int_llvm_ctx: TestContext, #[case] op: IntOpDef, #[case] args: &[u8]) {
1263 use hugr_core::extension::simple_op::MakeExtensionOp as _;
1264
1265 let mut insta = insta::Settings::clone_current();
1266 insta.set_snapshot_suffix(format!(
1267 "{}_{}_{:?}",
1268 insta.snapshot_suffix().unwrap_or(""),
1269 op.op_id(),
1270 args,
1271 ));
1272 let concrete = match *args {
1273 [] => op.without_log_width(),
1274 [log_width] => op.with_log_width(log_width),
1275 [lw1, lw2] => op.with_two_log_widths(lw1, lw2),
1276 _ => panic!("unexpected number of args to the op!"),
1277 };
1278 insta.bind(|| {
1279 let hugr = single_op_hugr(concrete.into());
1280 check_emission!(hugr, int_llvm_ctx);
1281 });
1282 }
1283
1284 #[rstest]
1285 #[case::iadd("iadd", 3)]
1286 #[case::isub("isub", 6)]
1287 #[case::ipow("ipow", 3)]
1288 #[case::idiv_u("idiv_u", 3)]
1289 #[case::idiv_s("idiv_s", 3)]
1290 #[case::imod_u("imod_u", 3)]
1291 #[case::imod_s("imod_s", 3)]
1292 fn test_binop_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1293 let ext_op = make_int_op(op.clone(), width);
1294 let hugr = test_binary_int_op(ext_op, width);
1295 check_emission!(op.clone(), hugr, int_llvm_ctx);
1296 }
1297
1298 #[rstest]
1299 #[case::signed_2_3("iwiden_s", 2, 3)]
1300 #[case::signed_1_6("iwiden_s", 1, 6)]
1301 #[case::unsigned_2_3("iwiden_u", 2, 3)]
1302 #[case::unsigned_1_6("iwiden_u", 1, 6)]
1303 fn test_widen_emission(
1304 int_llvm_ctx: TestContext,
1305 #[case] op: String,
1306 #[case] from: u8,
1307 #[case] to: u8,
1308 ) {
1309 let out_ty = INT_TYPES[to as usize].clone();
1310 let ext_op = int_ops::EXTENSION
1311 .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1312 .unwrap();
1313 let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty);
1314
1315 check_emission!(
1316 format!("{}_{}_{}", op.clone(), from, to),
1317 hugr,
1318 int_llvm_ctx
1319 );
1320 }
1321
1322 #[rstest]
1323 #[case::signed("inarrow_s", 3, 2)]
1324 #[case::unsigned("inarrow_u", 6, 4)]
1325 fn test_narrow_emission(
1326 int_llvm_ctx: TestContext,
1327 #[case] op: String,
1328 #[case] from: u8,
1329 #[case] to: u8,
1330 ) {
1331 let out_ty = SumType::new([vec![error_type()], vec![INT_TYPES[to as usize].clone()]]);
1332 let ext_op = int_ops::EXTENSION
1333 .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1334 .unwrap();
1335 let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());
1336
1337 check_emission!(
1338 format!("{}_{}_{}", op.clone(), from, to),
1339 hugr,
1340 int_llvm_ctx
1341 );
1342 }
1343
1344 #[rstest]
1345 #[case::ieq("ieq", 1)]
1346 #[case::ilt_s("ilt_s", 0)]
1347 fn test_cmp_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1348 let ext_op = make_int_op(op.clone(), width);
1349 let hugr = test_binary_icmp_op(ext_op, width);
1350 check_emission!(op.clone(), hugr, int_llvm_ctx);
1351 }
1352
1353 #[rstest]
1354 #[case::imax("imax_u", 1, 2, 2)]
1355 #[case::imax("imax_u", 2, 1, 2)]
1356 #[case::imax("imax_u", 2, 2, 2)]
1357 #[case::imin("imin_u", 1, 2, 1)]
1358 #[case::imin("imin_u", 2, 1, 1)]
1359 #[case::imin("imin_u", 2, 2, 2)]
1360 #[case::ishl("ishl", 73, 1, 146)]
1361 #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
1363 #[case::ishr("ishr", 73, 1, 36)]
1364 #[case::ior("ior", 6, 9, 15)]
1365 #[case::ior("ior", 6, 15, 15)]
1366 #[case::ixor("ixor", 6, 9, 15)]
1367 #[case::ixor("ixor", 6, 15, 9)]
1368 #[case::ixor("ixor", 15, 6, 9)]
1369 #[case::iand("iand", 6, 15, 6)]
1370 #[case::iand("iand", 15, 6, 6)]
1371 #[case::iand("iand", 15, 15, 15)]
1372 #[case::ipow("ipow", 2, 3, 8)]
1373 #[case::ipow("ipow", 42, 1, 42)]
1374 #[case::ipow("ipow", 42, 0, 1)]
1375 #[case::idiv("idiv_u", 42, 2, 21)]
1376 #[case::idiv("idiv_u", 42, 5, 8)]
1377 #[case::imod("imod_u", 42, 2, 0)]
1378 #[case::imod("imod_u", 42, 5, 2)]
1379 fn test_exec_unsigned_bin_op(
1380 int_exec_ctx: TestContext,
1381 #[case] op: String,
1382 #[case] lhs: u64,
1383 #[case] rhs: u64,
1384 #[case] result: u64,
1385 ) {
1386 let ty = &INT_TYPES[6].clone();
1387 let inputs = [
1388 ConstInt::new_u(6, lhs).unwrap(),
1389 ConstInt::new_u(6, rhs).unwrap(),
1390 ];
1391 let ext_op = make_int_op(&op, 6);
1392
1393 let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1394 assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1395 }
1396
1397 #[rstest]
1398 #[case::imax("imax_s", 1, 2, 2)]
1399 #[case::imax("imax_s", 2, 1, 2)]
1400 #[case::imax("imax_s", 2, 2, 2)]
1401 #[case::imax("imax_s", -1, -2, -1)]
1402 #[case::imax("imax_s", -2, -1, -1)]
1403 #[case::imax("imax_s", -2, -2, -2)]
1404 #[case::imin("imin_s", 1, 2, 1)]
1405 #[case::imin("imin_s", 2, 1, 1)]
1406 #[case::imin("imin_s", 2, 2, 2)]
1407 #[case::imin("imin_s", -1, -2, -2)]
1408 #[case::imin("imin_s", -2, -1, -2)]
1409 #[case::imin("imin_s", -2, -2, -2)]
1410 #[case::ipow("ipow", -2, 1, -2)]
1411 #[case::ipow("ipow", -2, 2, 4)]
1412 #[case::ipow("ipow", -2, 3, -8)]
1413 fn test_exec_signed_bin_op(
1414 int_exec_ctx: TestContext,
1415 #[case] op: String,
1416 #[case] lhs: i64,
1417 #[case] rhs: i64,
1418 #[case] result: i64,
1419 ) {
1420 let ty = &INT_TYPES[6].clone();
1421 let inputs = [
1422 ConstInt::new_s(6, lhs).unwrap(),
1423 ConstInt::new_s(6, rhs).unwrap(),
1424 ];
1425 let ext_op = make_int_op(&op, 6);
1426
1427 let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1428 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1429 }
1430
1431 #[rstest]
1432 #[case::iabs("iabs", 42, 42)]
1433 #[case::iabs("iabs", -42, 42)]
1434 fn test_exec_signed_unary_op(
1435 int_exec_ctx: TestContext,
1436 #[case] op: String,
1437 #[case] arg: i64,
1438 #[case] result: i64,
1439 ) {
1440 let input = ConstInt::new_s(6, arg).unwrap();
1441 let ty = INT_TYPES[6].clone();
1442 let ext_op = make_int_op(&op, 6);
1443
1444 let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1445 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1446 }
1447
1448 #[rstest]
1449 #[case::inot("inot", 9223372036854775808, !9223372036854775808u64)]
1450 #[case::inot("inot", 42, !42u64)]
1451 #[case::inot("inot", !0u64, 0)]
1452 fn test_exec_unsigned_unary_op(
1453 int_exec_ctx: TestContext,
1454 #[case] op: String,
1455 #[case] arg: u64,
1456 #[case] result: u64,
1457 ) {
1458 let input = ConstInt::new_u(6, arg).unwrap();
1459 let ty = INT_TYPES[6].clone();
1460 let ext_op = make_int_op(&op, 6);
1461
1462 let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1463 assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1464 }
1465
1466 #[rstest]
1467 #[case(-127)]
1468 #[case(-1)]
1469 #[case(0)]
1470 #[case(1)]
1471 #[case(127)]
1472 fn test_exec_widen(int_exec_ctx: TestContext, #[case] num: i16) {
1473 let from: u8 = 3;
1474 let to: u8 = 6;
1475 let ty = INT_TYPES[to as usize].clone();
1476
1477 if num >= 0 {
1478 let input = ConstInt::new_u(from, num as u64).unwrap();
1479
1480 let ext_op = int_ops::EXTENSION
1481 .instantiate_extension_op(
1482 "iwiden_u".as_ref(),
1483 [(from as u64).into(), (to as u64).into()],
1484 )
1485 .unwrap();
1486
1487 let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1488
1489 assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), num as u64);
1490 }
1491
1492 let input = ConstInt::new_s(from, num as i64).unwrap();
1493
1494 let ext_op = int_ops::EXTENSION
1495 .instantiate_extension_op(
1496 "iwiden_s".as_ref(),
1497 [(from as u64).into(), (to as u64).into()],
1498 )
1499 .unwrap();
1500
1501 let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1502
1503 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), num as i64);
1504 }
1505
1506 #[rstest]
1507 #[case("inarrow_s", 6, 2, 4)]
1508 #[case("inarrow_s", 6, 5, (1 << 5) - 1)]
1509 #[case("inarrow_s", 6, 4, -1)]
1510 #[case("inarrow_s", 6, 4, -(1 << 4) - 1)]
1511 #[case("inarrow_s", 6, 4, -(1 <<15))]
1512 #[case("inarrow_s", 6, 5, (1 << 31) - 1)]
1513 fn test_narrow_s(
1514 int_exec_ctx: TestContext,
1515 #[case] op: String,
1516 #[case] from: u8,
1517 #[case] to: u8,
1518 #[case] arg: i64,
1519 ) {
1520 let input = ConstInt::new_s(from, arg).unwrap();
1521 let to_ty = INT_TYPES[to as usize].clone();
1522 let ext_op = int_ops::EXTENSION
1523 .instantiate_extension_op(op.as_ref(), [u64::from(from).into(), u64::from(to).into()])
1524 .unwrap();
1525
1526 let hugr = test_int_op_with_results_processing::<1>(
1527 ext_op,
1528 to,
1529 Some([input]),
1530 to_ty.clone(),
1531 |builder, outs| {
1532 let [out] = outs.to_array();
1533
1534 let err_row = TypeRow::from(vec![error_type()]);
1535 let ty_row = TypeRow::from(vec![to_ty.clone()]);
1536 let mut cond_b = builder.conditional_builder(
1544 ([err_row, ty_row], out),
1545 [],
1546 vec![to_ty.clone()].into(),
1547 )?;
1548 let mut sad_b = cond_b.case_builder(0)?;
1549 let err = ConstError::new(2, "This shouldn't happen");
1550 let w = sad_b.add_load_value(ConstInt::new_s(to, 0)?);
1551 sad_b.add_panic(err, vec![to_ty.clone()], [(w, to_ty.clone())])?;
1552 sad_b.finish_with_outputs([w])?;
1553
1554 let happy_b = cond_b.case_builder(1)?;
1555 let [w] = happy_b.input_wires_arr();
1556 happy_b.finish_with_outputs([w])?;
1557
1558 let handle = cond_b.finish_sub_container()?;
1559 Ok(handle.outputs())
1560 },
1561 );
1562 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), arg);
1563 }
1564
1565 #[rstest]
1566 #[case(6, 42)]
1567 #[case(4, 7)]
1568 fn test_u_to_s(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: u64) {
1570 let ty = &INT_TYPES[log_width as usize].clone();
1571 let hugr = SimpleHugrConfig::new()
1572 .with_outs(vec![ty.clone()])
1573 .with_extensions(STD_REG.clone())
1574 .finish(|mut hugr_builder| {
1575 let unsigned =
1576 hugr_builder.add_load_value(ConstInt::new_u(log_width, val).unwrap());
1577 let iu_to_s = make_int_op("iu_to_s", log_width);
1578 let [signed] = hugr_builder
1579 .add_dataflow_op(iu_to_s, [unsigned])
1580 .unwrap()
1581 .outputs_arr();
1582 hugr_builder.finish_hugr_with_outputs([signed]).unwrap()
1583 });
1584 let act = int_exec_ctx.exec_hugr_i64(hugr, "main");
1585 assert_eq!(act, val as i64);
1586 }
1587
1588 #[rstest]
1589 #[case(3, 0)]
1590 #[case(4, 255)]
1591 fn test_s_to_u(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: i64) {
1593 let ty = &INT_TYPES[log_width as usize].clone();
1594 let hugr = SimpleHugrConfig::new()
1595 .with_outs(vec![ty.clone()])
1596 .with_extensions(STD_REG.clone())
1597 .finish(|mut hugr_builder| {
1598 let signed = hugr_builder.add_load_value(ConstInt::new_s(log_width, val).unwrap());
1599 let is_to_u = make_int_op("is_to_u", log_width);
1600 let [unsigned] = hugr_builder
1601 .add_dataflow_op(is_to_u, [signed])
1602 .unwrap()
1603 .outputs_arr();
1604 let num = hugr_builder.add_load_value(ConstInt::new_u(log_width, 42).unwrap());
1605 let [res] = hugr_builder
1606 .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num])
1607 .unwrap()
1608 .outputs_arr();
1609 hugr_builder.finish_hugr_with_outputs([res]).unwrap()
1610 });
1611 let act = int_exec_ctx.exec_hugr_u64(hugr, "main");
1612 assert_eq!(act, (val as u64) + 42);
1613 }
1614
1615 #[rstest]
1617 #[case::bigdiv_non_negative(127, 255, (0, 127))] #[case::bigdiv_negative(-42, 255, (-1, 213))] #[case::smoldiv_non_negative(42, 10, (4, 2))] #[case::smoldiv_negative_rem0(-42, 21, (-2, 0))] #[case::smoldiv_negative_rem_nonzero(-42, 10, (-5, 8))] fn test_divmod_s(
1623 int_exec_ctx: TestContext,
1624 #[case] dividend: i64,
1625 #[case] divisor: u64,
1626 #[case] expected_result: (i64, u64),
1627 ) {
1628 let int_ty = INT_TYPES[3].clone();
1629 let k_dividend = ConstInt::new_s(3, dividend).unwrap();
1630 let k_divisor = ConstInt::new_u(3, divisor).unwrap();
1631 let quot_hugr = test_int_op_with_results(
1632 make_int_op("idiv_s", 3),
1633 3,
1634 Some([k_dividend.clone(), k_divisor.clone()]),
1635 int_ty.clone(),
1636 );
1637 let rem_hugr = test_int_op_with_results(
1638 make_int_op("imod_s", 3),
1639 3,
1640 Some([k_dividend, k_divisor]),
1641 int_ty,
1642 );
1643 let quot = int_exec_ctx.exec_hugr_i64(quot_hugr, "main");
1644 let rem = int_exec_ctx.exec_hugr_u64(rem_hugr, "main");
1645 assert_eq!((quot, rem), expected_result);
1646 }
1647}