1use hugr_core::{
2 HugrView, Node,
3 extension::{
4 prelude::{ConstError, sum_with_error},
5 simple_op::MakeExtensionOp,
6 },
7 ops::{ExtensionOp, Value, constant::CustomConst},
8 std_extensions::arithmetic::{
9 int_ops::IntOpDef,
10 int_types::{self, ConstInt},
11 },
12 types::{CustomType, Type, TypeArg},
13};
14use inkwell::{
15 IntPredicate,
16 types::{BasicType, BasicTypeEnum, IntType},
17 values::{BasicValue, BasicValueEnum, IntValue},
18};
19use lazy_static::lazy_static;
20
21use crate::{
22 CodegenExtension,
23 custom::CodegenExtsBuilder,
24 emit::{
25 EmitOpArgs, emit_value,
26 func::EmitFuncContext,
27 get_intrinsic,
28 ops::{emit_custom_binary_op, emit_custom_unary_op},
29 },
30 sum::{LLVMSumType, LLVMSumValue},
31 types::{HugrSumType, TypingSession},
32};
33
34use anyhow::{Result, anyhow, bail};
35
36use super::{DefaultPreludeCodegen, PreludeCodegen, conversions::int_type_bounds};
37
38#[derive(Clone, Debug, Default)]
39pub struct IntCodegenExtension<PCG>(PCG);
40
41impl<PCG: PreludeCodegen> IntCodegenExtension<PCG> {
42 pub fn new(ccg: PCG) -> Self {
43 Self(ccg)
44 }
45}
46
47impl<CCG: PreludeCodegen> From<CCG> for IntCodegenExtension<CCG> {
48 fn from(ccg: CCG) -> Self {
49 Self::new(ccg)
50 }
51}
52
53impl<CCG: PreludeCodegen> CodegenExtension for IntCodegenExtension<CCG> {
54 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
55 self,
56 builder: CodegenExtsBuilder<'a, H>,
57 ) -> CodegenExtsBuilder<'a, H>
58 where
59 Self: 'a,
60 {
61 builder
62 .custom_const(emit_const_int)
63 .custom_type((int_types::EXTENSION_ID, "int".into()), llvm_type)
64 .simple_extension_op::<IntOpDef>(move |context, args, op| {
65 emit_int_op(context, &self.0, args, op)
66 })
67 }
68}
69
70lazy_static! {
71 static ref ERR_NARROW: ConstError = ConstError {
72 signal: 2,
73 message: "Can't narrow into bounds".to_string(),
74 };
75 static ref ERR_IU_TO_S: ConstError = ConstError {
76 signal: 2,
77 message: "iu_to_s argument out of bounds".to_string(),
78 };
79 static ref ERR_IS_TO_U: ConstError = ConstError {
80 signal: 2,
81 message: "is_to_u called on negative value".to_string(),
82 };
83 static ref ERR_DIV_0: ConstError = ConstError {
84 signal: 2,
85 message: "Attempted division by 0".to_string(),
86 };
87}
88
/// Which result(s) of an integer division a [`DivModOp`] produces.
#[derive(Debug, Eq, PartialEq)]
enum DivOrMod {
    /// Only the quotient (field 0 of the `(quotient, remainder)` tuple).
    Div,
    /// Only the remainder (field 1 of the `(quotient, remainder)` tuple).
    Mod,
    /// Both the quotient and the remainder.
    DivMod,
}

/// Description of one integer division/remainder operation, lowered by
/// [`DivModOp::emit`].
struct DivModOp {
    /// Which result(s) to produce.
    op: DivOrMod,
    /// Whether the numerator is interpreted as signed. The denominator is
    /// always compared as unsigned by `make_divmod`.
    signed: bool,
    /// `true`: panic on a zero divisor. `false`: produce an error value instead.
    panic: bool,
}
101
102impl DivModOp {
103 fn emit<'c, H: HugrView<Node = Node>>(
104 self,
105 ctx: &mut EmitFuncContext<'c, '_, H>,
106 pcg: &impl PreludeCodegen,
107 log_width: u64,
108 numerator: IntValue<'c>,
109 denominator: IntValue<'c>,
110 ) -> Result<Vec<BasicValueEnum<'c>>> {
111 let quotrem = make_divmod(
114 ctx,
115 pcg,
116 log_width,
117 numerator,
118 denominator,
119 self.panic,
120 self.signed,
121 )?;
122
123 if self.op == DivOrMod::DivMod {
124 if self.panic {
125 Ok(quotrem.build_untag(ctx.builder(), 0).unwrap())
127 } else {
128 Ok(vec![quotrem.as_basic_value_enum()])
129 }
130 } else {
131 let index = match self.op {
133 DivOrMod::Div => 0,
134 DivOrMod::Mod => 1,
135 _ => unreachable!(),
136 };
137 if self.panic {
139 Ok(vec![
140 quotrem
141 .build_untag(ctx.builder(), 0)?
142 .into_iter()
143 .nth(index)
144 .unwrap(),
145 ])
146 }
147 else {
150 let int_ty = numerator.get_type().as_basic_type_enum();
152 let tuple_ty =
153 LLVMSumType::try_new(ctx.iw_context(), vec![vec![int_ty, int_ty]]).unwrap();
154 let tuple = quotrem
155 .build_untag(ctx.builder(), 1)?
156 .into_iter()
157 .next()
158 .unwrap();
159 let tuple_val = LLVMSumValue::try_new(tuple, tuple_ty)?;
160 let data_val = tuple_val
161 .build_untag(ctx.builder(), 0)?
162 .into_iter()
163 .nth(index)
164 .unwrap();
165 let err_val = quotrem
166 .build_untag(ctx.builder(), 0)?
167 .into_iter()
168 .next()
169 .unwrap();
170
171 let tag_val = quotrem.build_get_tag(ctx.builder())?;
172 tag_val.set_name("tag");
173
174 let int_ty = int_types::INT_TYPES[log_width as usize].clone();
176 let out_ty = LLVMSumType::try_from_hugr_type(
177 &ctx.typing_session(),
178 sum_with_error(vec![int_ty.clone()]),
179 )
180 .unwrap();
181
182 let data_variant = out_ty.build_tag(ctx.builder(), 1, vec![data_val])?;
183 data_variant.set_name("data_variant");
184 let err_variant = out_ty.build_tag(ctx.builder(), 0, vec![err_val])?;
185 err_variant.set_name("err_variant");
186
187 let result = ctx
188 .builder()
189 .build_select(tag_val, data_variant, err_variant, "")?;
190 Ok(vec![result])
191 }
192 }
193 }
194}
195
196fn emit_icmp<'c, H: HugrView<Node = Node>>(
198 context: &mut EmitFuncContext<'c, '_, H>,
199 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
200 pred: inkwell::IntPredicate,
201) -> Result<()> {
202 let true_val = emit_value(context, &Value::true_val())?;
203 let false_val = emit_value(context, &Value::false_val())?;
204
205 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
206 let r = ctx.builder().build_int_compare(
208 pred,
209 lhs.into_int_value(),
210 rhs.into_int_value(),
211 "",
212 )?;
213 Ok(vec![
215 ctx.builder().build_select(r, true_val, false_val, "")?,
216 ])
217 })
218}
219
/// Emit `ipow`: raise the first operand to the power of the second.
///
/// The lowering is a loop over two stack slots (`acc_ptr`, `exp_ptr`). It
/// starts from `acc = lhs`, `exp = rhs`. Each `pow_body` iteration multiplies
/// `acc` by `lhs` and decrements `exp`, until `exp == 1` branches to `done`.
/// An exponent of 0 branches to `power_of_zero`, which stores 1.
///
/// This takes `rhs - 1` multiplications (no square-and-multiply). The loop
/// terminates by counting down to 1, so the exponent's bits are effectively
/// read as unsigned. Multiplication uses a plain `mul`, so it wraps on overflow.
fn emit_ipow<'c, H: HugrView<Node = Node>>(
    context: &mut EmitFuncContext<'c, '_, H>,
    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
) -> Result<()> {
    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
        let done_bb = ctx.new_basic_block("done", None);
        let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
        let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
        let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));

        // Loop state lives in stack slots, initialised to (lhs, rhs).
        let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
        let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
        ctx.builder().build_store(acc_p, lhs)?;
        ctx.builder().build_store(exp_p, rhs)?;
        ctx.builder().build_unconditional_branch(pow_bb)?;

        // Constants use the exponent's type. Storing `one` into `acc_p` assumes
        // the base, exponent and result share one integer type.
        let zero = rhs.get_type().into_int_type().const_int(0, false);
        let one = rhs.get_type().into_int_type().const_int(1, false);

        // x^0 == 1.
        ctx.builder().position_at_end(return_one_bb);
        ctx.builder().build_store(acc_p, one)?;
        ctx.builder().build_unconditional_branch(done_bb)?;

        // Loop header: stop at exp == 1 (acc already holds the answer), special-case
        // exp == 0, otherwise do one more multiplication.
        ctx.builder().position_at_end(pow_bb);
        let acc = ctx.builder().build_load(acc_p, "acc")?;
        let exp = ctx.builder().build_load(exp_p, "exp")?;

        ctx.builder().build_switch(
            exp.into_int_value(),
            pow_body_bb,
            &[(one, done_bb), (zero, return_one_bb)],
        )?;

        // One step: acc *= lhs; exp -= 1.
        ctx.builder().position_at_end(pow_body_bb);
        let new_acc =
            ctx.builder()
                .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
        let new_exp = ctx
            .builder()
            .build_int_sub(exp.into_int_value(), one, "new_exp")?;
        ctx.builder().build_store(acc_p, new_acc)?;
        ctx.builder().build_store(exp_p, new_exp)?;
        ctx.builder().build_unconditional_branch(pow_bb)?;

        ctx.builder().position_at_end(done_bb);
        let result = ctx.builder().build_load(acc_p, "result")?;
        Ok(vec![result.as_basic_value_enum()])
    })
}
276
277fn emit_int_op<'c, H: HugrView<Node = Node>>(
278 context: &mut EmitFuncContext<'c, '_, H>,
279 pcg: &impl PreludeCodegen,
280 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
281 op: IntOpDef,
282) -> Result<()> {
283 match op {
284 IntOpDef::iadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
285 Ok(vec![
286 ctx.builder()
287 .build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
288 .as_basic_value_enum(),
289 ])
290 }),
291 IntOpDef::imul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
292 Ok(vec![
293 ctx.builder()
294 .build_int_mul(lhs.into_int_value(), rhs.into_int_value(), "")?
295 .as_basic_value_enum(),
296 ])
297 }),
298 IntOpDef::isub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
299 Ok(vec![
300 ctx.builder()
301 .build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
302 .as_basic_value_enum(),
303 ])
304 }),
305 IntOpDef::idiv_s => {
306 let log_width = get_width_arg(&args, &op)?;
307 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
308 let op = DivModOp {
309 op: DivOrMod::Div,
310 signed: true,
311 panic: true,
312 };
313 op.emit(
314 ctx,
315 pcg,
316 log_width,
317 lhs.into_int_value(),
318 rhs.into_int_value(),
319 )
320 })
321 }
322 IntOpDef::idiv_u => {
323 let log_width = get_width_arg(&args, &op)?;
324 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
325 let op = DivModOp {
326 op: DivOrMod::Div,
327 signed: false,
328 panic: true,
329 };
330 op.emit(
331 ctx,
332 pcg,
333 log_width,
334 lhs.into_int_value(),
335 rhs.into_int_value(),
336 )
337 })
338 }
339 IntOpDef::imod_s => {
340 let log_width = get_width_arg(&args, &op)?;
341 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
342 let op = DivModOp {
343 op: DivOrMod::Mod,
344 signed: true,
345 panic: true,
346 };
347 op.emit(
348 ctx,
349 pcg,
350 log_width,
351 lhs.into_int_value(),
352 rhs.into_int_value(),
353 )
354 })
355 }
356 IntOpDef::imod_u => {
357 let log_width = get_width_arg(&args, &op)?;
358 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
359 let op = DivModOp {
360 op: DivOrMod::Mod,
361 signed: false,
362 panic: true,
363 };
364 op.emit(
365 ctx,
366 pcg,
367 log_width,
368 lhs.into_int_value(),
369 rhs.into_int_value(),
370 )
371 })
372 }
373 IntOpDef::idivmod_u => {
374 let log_width = get_width_arg(&args, &op)?;
375 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
376 let op = DivModOp {
377 op: DivOrMod::DivMod,
378 signed: false,
379 panic: true,
380 };
381 op.emit(
382 ctx,
383 pcg,
384 log_width,
385 lhs.into_int_value(),
386 rhs.into_int_value(),
387 )
388 })
389 }
390 IntOpDef::idivmod_s => {
391 let log_width = get_width_arg(&args, &op)?;
392 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
393 let op = DivModOp {
394 op: DivOrMod::DivMod,
395 signed: true,
396 panic: true,
397 };
398 op.emit(
399 ctx,
400 pcg,
401 log_width,
402 lhs.into_int_value(),
403 rhs.into_int_value(),
404 )
405 })
406 }
407 IntOpDef::idiv_checked_s => {
408 let log_width = get_width_arg(&args, &op)?;
409 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
410 let op = DivModOp {
411 op: DivOrMod::Div,
412 signed: true,
413 panic: false,
414 };
415 op.emit(
416 ctx,
417 pcg,
418 log_width,
419 lhs.into_int_value(),
420 rhs.into_int_value(),
421 )
422 })
423 }
424 IntOpDef::idiv_checked_u => {
425 let log_width = get_width_arg(&args, &op)?;
426
427 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
428 let op = DivModOp {
429 op: DivOrMod::Div,
430 signed: false,
431 panic: false,
432 };
433 op.emit(
434 ctx,
435 pcg,
436 log_width,
437 lhs.into_int_value(),
438 rhs.into_int_value(),
439 )
440 })
441 }
442 IntOpDef::imod_checked_s => {
443 let log_width = get_width_arg(&args, &op)?;
444 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
445 let op = DivModOp {
446 op: DivOrMod::Mod,
447 signed: true,
448 panic: false,
449 };
450 op.emit(
451 ctx,
452 pcg,
453 log_width,
454 lhs.into_int_value(),
455 rhs.into_int_value(),
456 )
457 })
458 }
459 IntOpDef::imod_checked_u => {
460 let log_width = get_width_arg(&args, &op)?;
461 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
462 let op = DivModOp {
463 op: DivOrMod::Mod,
464 signed: false,
465 panic: false,
466 };
467 op.emit(
468 ctx,
469 pcg,
470 log_width,
471 lhs.into_int_value(),
472 rhs.into_int_value(),
473 )
474 })
475 }
476 IntOpDef::idivmod_checked_u => {
477 let log_width = get_width_arg(&args, &op)?;
478 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
479 let op = DivModOp {
480 op: DivOrMod::DivMod,
481 signed: false,
482 panic: false,
483 };
484 op.emit(
485 ctx,
486 pcg,
487 log_width,
488 lhs.into_int_value(),
489 rhs.into_int_value(),
490 )
491 })
492 }
493 IntOpDef::idivmod_checked_s => {
494 let log_width = get_width_arg(&args, &op)?;
495 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
496 let op = DivModOp {
497 op: DivOrMod::DivMod,
498 signed: true,
499 panic: false,
500 };
501 op.emit(
502 ctx,
503 pcg,
504 log_width,
505 lhs.into_int_value(),
506 rhs.into_int_value(),
507 )
508 })
509 }
510 IntOpDef::ineg => emit_custom_unary_op(context, args, |ctx, arg, _| {
511 Ok(vec![
512 ctx.builder()
513 .build_int_neg(arg.into_int_value(), "")?
514 .as_basic_value_enum(),
515 ])
516 }),
517 IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
518 let intr = get_intrinsic(
519 ctx.get_current_module(),
520 "llvm.abs.i64",
521 [ctx.iw_context().i64_type().as_basic_type_enum()],
522 )?;
523 let true_ = ctx.iw_context().bool_type().const_all_ones();
524 let r = ctx
525 .builder()
526 .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
527 .try_as_basic_value()
528 .unwrap_left();
529 Ok(vec![r])
530 }),
531 IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
532 let intr = get_intrinsic(
533 ctx.get_current_module(),
534 "llvm.smax.i64",
535 [ctx.iw_context().i64_type().as_basic_type_enum()],
536 )?;
537 let r = ctx
538 .builder()
539 .build_call(
540 intr,
541 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
542 "",
543 )?
544 .try_as_basic_value()
545 .unwrap_left();
546 Ok(vec![r])
547 }),
548 IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
549 let intr = get_intrinsic(
550 ctx.get_current_module(),
551 "llvm.umax.i64",
552 [ctx.iw_context().i64_type().as_basic_type_enum()],
553 )?;
554 let r = ctx
555 .builder()
556 .build_call(
557 intr,
558 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
559 "",
560 )?
561 .try_as_basic_value()
562 .unwrap_left();
563 Ok(vec![r])
564 }),
565 IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
566 let intr = get_intrinsic(
567 ctx.get_current_module(),
568 "llvm.smin.i64",
569 [ctx.iw_context().i64_type().as_basic_type_enum()],
570 )?;
571 let r = ctx
572 .builder()
573 .build_call(
574 intr,
575 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
576 "",
577 )?
578 .try_as_basic_value()
579 .unwrap_left();
580 Ok(vec![r])
581 }),
582 IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
583 let intr = get_intrinsic(
584 ctx.get_current_module(),
585 "llvm.umin.i64",
586 [ctx.iw_context().i64_type().as_basic_type_enum()],
587 )?;
588 let r = ctx
589 .builder()
590 .build_call(
591 intr,
592 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
593 "",
594 )?
595 .try_as_basic_value()
596 .unwrap_left();
597 Ok(vec![r])
598 }),
599 IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
600 Ok(vec![
601 ctx.builder()
602 .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
603 .as_basic_value_enum(),
604 ])
605 }),
606 IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
607 Ok(vec![
608 ctx.builder()
609 .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
610 .as_basic_value_enum(),
611 ])
612 }),
613 IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
614 IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
615 IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
616 IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
617 IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
618 IntOpDef::ige_s => emit_icmp(context, args, inkwell::IntPredicate::SGE),
619 IntOpDef::ilt_u => emit_icmp(context, args, inkwell::IntPredicate::ULT),
620 IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
621 IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
622 IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
623 IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
624 Ok(vec![
625 ctx.builder()
626 .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
627 .as_basic_value_enum(),
628 ])
629 }),
630 IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
631 Ok(vec![
632 ctx.builder()
633 .build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
634 .as_basic_value_enum(),
635 ])
636 }),
637 IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
638 Ok(vec![
639 ctx.builder()
640 .build_not(arg.into_int_value(), "")?
641 .as_basic_value_enum(),
642 ])
643 }),
644 IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
645 Ok(vec![
646 ctx.builder()
647 .build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
648 .as_basic_value_enum(),
649 ])
650 }),
651 IntOpDef::ipow => emit_ipow(context, args),
652 IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
654 let [out] = outs.try_into()?;
655 Ok(vec![
656 ctx.builder()
657 .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
658 .as_basic_value_enum(),
659 ])
660 }),
661 IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
662 let [out] = outs.try_into()?;
663
664 Ok(vec![
665 ctx.builder()
666 .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
667 .as_basic_value_enum(),
668 ])
669 }),
670 IntOpDef::inarrow_s => {
671 let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
672 else {
673 bail!("Type arg to inarrow_s wasn't a Nat");
674 };
675 let (_, out_ty) = args.node.out_value_types().next().unwrap();
676 emit_custom_unary_op(context, args, |ctx, arg, outs| {
677 let result = make_narrow(
678 ctx,
679 arg,
680 outs,
681 out_log_width,
682 true,
683 out_ty.as_sum().unwrap().clone(),
684 )?;
685 Ok(vec![result])
686 })
687 }
688 IntOpDef::inarrow_u => {
689 let Some(TypeArg::BoundedNat(out_log_width)) = args.node().args().last().cloned()
690 else {
691 bail!("Type arg to inarrow_u wasn't a Nat");
692 };
693 let (_, out_ty) = args.node.out_value_types().next().unwrap();
694 emit_custom_unary_op(context, args, |ctx, arg, outs| {
695 let result = make_narrow(
696 ctx,
697 arg,
698 outs,
699 out_log_width,
700 false,
701 out_ty.as_sum().unwrap().clone(),
702 )?;
703 Ok(vec![result])
704 })
705 }
706 IntOpDef::iu_to_s => {
707 let log_width = get_width_arg(&args, &op)?;
708 emit_custom_unary_op(context, args, |ctx, arg, _| {
709 let (_, max_val, _) = int_type_bounds(u32::pow(2, log_width as u32));
710 let max = arg
711 .get_type()
712 .into_int_type()
713 .const_int(max_val as u64, false);
714
715 let within_bounds = ctx.builder().build_int_compare(
716 IntPredicate::ULE,
717 arg.into_int_value(),
718 max,
719 "bounds_check",
720 )?;
721
722 Ok(vec![val_or_panic(
723 ctx,
724 pcg,
725 within_bounds,
726 &ERR_IU_TO_S,
727 |_| Ok(arg),
728 )?])
729 })
730 }
731 IntOpDef::is_to_u => emit_custom_unary_op(context, args, |ctx, arg, _| {
732 let zero = arg.get_type().into_int_type().const_zero();
733
734 let within_bounds = ctx.builder().build_int_compare(
735 IntPredicate::SGE,
736 arg.into_int_value(),
737 zero,
738 "bounds_check",
739 )?;
740
741 Ok(vec![val_or_panic(
742 ctx,
743 pcg,
744 within_bounds,
745 &ERR_IS_TO_U,
746 |_| Ok(arg),
747 )?])
748 }),
749 _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.op_id())),
750 }
751}
752
753pub(crate) fn get_width_arg<H: HugrView<Node = Node>>(
756 args: &EmitOpArgs<'_, '_, ExtensionOp, H>,
757 op: &impl MakeExtensionOp,
758) -> Result<u64> {
759 let [TypeArg::BoundedNat(log_width)] = args.node.args() else {
760 bail!(
761 "Expected exactly one BoundedNat parameter to {}",
762 op.op_id()
763 )
764 };
765 Ok(*log_width)
766}
767
768fn make_divmod<'c, H: HugrView<Node = Node>>(
778 ctx: &mut EmitFuncContext<'c, '_, H>,
779 pcg: &impl PreludeCodegen,
780 log_width: u64,
781 numerator: IntValue<'c>,
782 denominator: IntValue<'c>,
783 panic: bool,
784 signed: bool,
785) -> Result<LLVMSumValue<'c>> {
786 let int_arg_ty = int_types::INT_TYPES[log_width as usize].clone();
787 let tuple_sum_ty = HugrSumType::new_tuple(vec![int_arg_ty.clone(), int_arg_ty.clone()]);
788
789 let pair_ty = LLVMSumType::try_from_hugr_type(&ctx.typing_session(), tuple_sum_ty.clone())?;
790
791 let build_divmod = |ctx: &mut EmitFuncContext<'c, '_, H>| -> Result<BasicValueEnum<'c>> {
792 if signed {
793 let max_signed_value = u64::pow(2, u32::pow(2, log_width as u32) - 1) - 1;
794 let max_signed = numerator.get_type().const_int(max_signed_value, false);
795 let large_divisor_bool = ctx.builder().build_int_compare(
799 IntPredicate::UGT,
800 denominator,
801 max_signed,
802 "is_divisor_large",
803 )?;
804 let large_divisor =
805 ctx.builder()
806 .build_int_z_extend(large_divisor_bool, denominator.get_type(), "")?;
807 let negative_numerator_bool = ctx.builder().build_int_compare(
808 IntPredicate::SLT,
809 numerator,
810 numerator.get_type().const_zero(),
811 "is_dividend_negative",
812 )?;
813 let negative_numerator = ctx.builder().build_int_z_extend(
814 negative_numerator_bool,
815 denominator.get_type(),
816 "",
817 )?;
818 let tag = ctx.builder().build_left_shift(
819 large_divisor,
820 denominator.get_type().const_int(1, false),
821 "",
822 )?;
823
824 let tag = ctx.builder().build_or(tag, negative_numerator, "tag")?;
825
826 let quot = ctx
827 .builder()
828 .build_int_signed_div(numerator, denominator, "quotient")?;
829 let rem = ctx
830 .builder()
831 .build_int_signed_rem(numerator, denominator, "remainder")?;
832
833 let result_ptr = ctx.builder().build_alloca(pair_ty.clone(), "result")?;
834
835 let finish = ctx.new_basic_block("finish", None);
836 let negative_bigdiv = ctx.new_basic_block("negative_bigdiv", Some(finish));
837 let negative_smoldiv = ctx.new_basic_block("negative_smoldiv", Some(finish));
838 let non_negative_bigdiv = ctx.new_basic_block("non_negative_bigdiv", Some(finish));
839 let non_negative_smoldiv = ctx.new_basic_block("non_negative_smoldiv", Some(finish));
840
841 ctx.builder().build_switch(
842 tag,
843 non_negative_smoldiv,
844 &[
845 (denominator.get_type().const_int(1, false), negative_smoldiv),
846 (
847 denominator.get_type().const_int(2, false),
848 non_negative_bigdiv,
849 ),
850 (denominator.get_type().const_int(3, false), negative_bigdiv),
851 ],
852 )?;
853
854 let build_and_store_result =
855 |ctx: &mut EmitFuncContext<'c, '_, H>, vs: Vec<BasicValueEnum<'c>>| -> Result<()> {
856 let result = pair_ty
857 .build_tag(ctx.builder(), 0, vs)?
858 .as_basic_value_enum();
860 ctx.builder().build_store(result_ptr, result)?;
861 ctx.builder().build_unconditional_branch(finish)?;
862 Ok(())
863 };
864
865 ctx.builder().position_at_end(non_negative_smoldiv);
869 build_and_store_result(
870 ctx,
871 vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
872 )?;
873
874 ctx.builder().position_at_end(negative_smoldiv);
877 {
878 let if_rem_zero = pair_ty
880 .build_tag(
881 ctx.builder(),
882 0,
883 vec![
884 quot.as_basic_value_enum(),
885 rem.get_type().const_zero().as_basic_value_enum(),
886 ],
887 )?
888 .as_basic_value_enum();
889
890 let if_rem_nonzero = pair_ty
892 .build_tag(
893 ctx.builder(),
894 0,
895 vec![
896 ctx.builder()
897 .build_int_sub(quot, quot.get_type().const_int(1, true), "")?
898 .as_basic_value_enum(),
899 ctx.builder()
900 .build_int_add(denominator, rem, "")?
901 .as_basic_value_enum(),
902 ],
903 )?
904 .as_basic_value_enum();
905
906 let is_rem_zero = ctx.builder().build_int_compare(
907 IntPredicate::EQ,
908 rem,
909 rem.get_type().const_zero(),
910 "is_rem_0",
911 )?;
912 let result =
913 ctx.builder()
914 .build_select(is_rem_zero, if_rem_zero, if_rem_nonzero, "")?;
915 ctx.builder().build_store(result_ptr, result)?;
916 ctx.builder().build_unconditional_branch(finish)?;
917 }
918
919 ctx.builder().position_at_end(non_negative_bigdiv);
922 build_and_store_result(
923 ctx,
924 vec![
925 numerator.get_type().const_zero().as_basic_value_enum(),
926 numerator.as_basic_value_enum(),
927 ],
928 )?;
929
930 ctx.builder().position_at_end(negative_bigdiv);
934 build_and_store_result(
935 ctx,
936 vec![
937 numerator.get_type().const_all_ones().as_basic_value_enum(),
938 ctx.builder()
939 .build_int_add(numerator, denominator, "")?
940 .as_basic_value_enum(),
941 ],
942 )?;
943
944 ctx.builder().position_at_end(finish);
945 let result = ctx.builder().build_load(result_ptr, "result")?;
946 Ok(result)
947 } else {
948 let quot = ctx
949 .builder()
950 .build_int_unsigned_div(numerator, denominator, "quotient")?;
951 let rem = ctx
952 .builder()
953 .build_int_unsigned_rem(numerator, denominator, "remainder")?;
954 Ok(pair_ty
955 .build_tag(
956 ctx.builder(),
957 0,
958 vec![quot.as_basic_value_enum(), rem.as_basic_value_enum()],
959 )?
960 .as_basic_value_enum())
961 }
962 };
963
964 let int_ty = numerator.get_type();
965 let zero = int_ty.const_zero();
966 let lower_bounds_check =
967 ctx.builder()
968 .build_int_compare(IntPredicate::NE, denominator, zero, "valid_div")?;
969
970 let sum_ty = LLVMSumType::try_from_hugr_type(
971 &ctx.typing_session(),
972 sum_with_error(vec![Type::from(tuple_sum_ty)]),
973 )?;
974
975 if panic {
976 LLVMSumValue::try_new(
977 val_or_panic(ctx, pcg, lower_bounds_check, &ERR_DIV_0, |ctx| {
978 build_divmod(ctx)
979 })?,
980 pair_ty,
981 )
982 } else {
983 let result = build_divmod(ctx)?;
984 LLVMSumValue::try_new(
985 val_or_error(ctx, lower_bounds_check, result, &ERR_DIV_0, sum_ty.clone())?,
986 sum_ty,
987 )
988 }
989}
990
/// Emit a narrowing conversion of integer `arg` to width `2^out_log_width`,
/// returning a value of the sum type `sum_type`.
///
/// The result is built by `val_or_error`. When `arg` fits the narrower width
/// (signed range if `signed`, unsigned range otherwise), variant 1 holds the
/// truncated integer. Otherwise variant 0 holds [`ERR_NARROW`]. The truncation
/// is emitted unconditionally; it has no side effects.
///
/// `outs` must contain exactly one type (the LLVM type of the output sum).
fn make_narrow<'c, H: HugrView<Node = Node>>(
    ctx: &mut EmitFuncContext<'c, '_, H>,
    arg: BasicValueEnum<'c>,
    outs: &[BasicTypeEnum<'c>],
    out_log_width: u64,
    signed: bool,
    sum_type: HugrSumType,
) -> Result<BasicValueEnum<'c>> {
    let [out] = TryInto::<[BasicTypeEnum; 1]>::try_into(outs)?;
    let width = 1 << out_log_width;
    let arg_int_ty: IntType = arg.get_type().into_int_type();
    let (int_min_value_s, int_max_value_s, int_max_value_u) = int_type_bounds(width);
    // NOTE(review): assumes the LLVM layout of the output sum struct puts the
    // narrowed integer at field index 2. This depends on `LLVMSumType`'s layout; confirm.
    let out_int_ty = out
        .into_struct_type()
        .get_field_type_at_index(2)
        .unwrap()
        .into_int_type();
    // Range checks are done at the input's (wider) width.
    let outside_range = if signed {
        let too_big = ctx.builder().build_int_compare(
            IntPredicate::SGT,
            arg.into_int_value(),
            arg_int_ty.const_int(int_max_value_s as u64, true),
            "upper_bounds_check",
        )?;
        let too_small = ctx.builder().build_int_compare(
            IntPredicate::SLT,
            arg.into_int_value(),
            arg_int_ty.const_int(int_min_value_s as u64, true),
            "lower_bounds_check",
        )?;
        ctx.builder()
            .build_or(too_big, too_small, "outside_range")?
    } else {
        ctx.builder().build_int_compare(
            IntPredicate::UGT,
            arg.into_int_value(),
            arg_int_ty.const_int(int_max_value_u, false),
            "upper_bounds_check",
        )?
    };

    let inbounds = ctx.builder().build_not(outside_range, "inbounds")?;
    let narrowed_val = ctx
        .builder()
        .build_int_cast_sign_flag(arg.into_int_value(), out_int_ty, signed, "")?
        .as_basic_value_enum();
    val_or_error(
        ctx,
        inbounds,
        narrowed_val,
        &ERR_NARROW,
        LLVMSumType::try_from_hugr_type(&ctx.typing_session(), sum_type).unwrap(),
    )
}
1045
1046fn val_or_panic<'c, H: HugrView<Node = Node>>(
1047 ctx: &mut EmitFuncContext<'c, '_, H>,
1048 pcg: &impl PreludeCodegen,
1049 dont_panic: IntValue<'c>,
1050 err: &ConstError,
1051 go: impl Fn(&mut EmitFuncContext<'c, '_, H>) -> Result<BasicValueEnum<'c>>,
1053) -> Result<BasicValueEnum<'c>> {
1054 let exit_bb = ctx.new_basic_block("exit", None);
1055 let go_bb = ctx.new_basic_block("panic_if_0", Some(exit_bb));
1056 let panic_bb = ctx.new_basic_block("panic", Some(exit_bb));
1057 ctx.builder().build_unconditional_branch(go_bb)?;
1058
1059 ctx.builder().position_at_end(panic_bb);
1060 let err = ctx.emit_custom_const(err)?;
1061 pcg.emit_panic(ctx, err)?;
1062 ctx.builder().build_unconditional_branch(exit_bb)?;
1063
1064 ctx.builder().position_at_end(go_bb);
1065 ctx.builder().build_switch(
1066 dont_panic,
1067 panic_bb,
1068 &[(dont_panic.get_type().const_int(1, false), exit_bb)],
1069 )?;
1070
1071 ctx.builder().position_at_end(exit_bb);
1072
1073 go(ctx)
1074}
1075
1076fn val_or_error<'c, H: HugrView<Node = Node>>(
1077 ctx: &mut EmitFuncContext<'c, '_, H>,
1078 should_succeed: IntValue<'c>,
1079 val: BasicValueEnum<'c>,
1080 err: &ConstError,
1081 ty: LLVMSumType<'c>,
1082) -> Result<BasicValueEnum<'c>> {
1083 let err_val = ctx.emit_custom_const(err)?;
1084
1085 let err_variant = ty.build_tag(ctx.builder(), 0, vec![err_val])?;
1086 let ok_variant = ty.build_tag(ctx.builder(), 1, vec![val])?;
1087
1088 Ok(ctx
1089 .builder()
1090 .build_select(should_succeed, ok_variant, err_variant, "")?)
1091}
1092
1093fn llvm_type<'c>(
1094 context: TypingSession<'c, '_>,
1095 hugr_type: &CustomType,
1096) -> Result<BasicTypeEnum<'c>> {
1097 if let [TypeArg::BoundedNat(n)] = hugr_type.args() {
1098 let m = *n as usize;
1099 if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() {
1100 return Ok(match m {
1101 0..=3 => context.iw_context().i8_type(),
1102 4 => context.iw_context().i16_type(),
1103 5 => context.iw_context().i32_type(),
1104 6 => context.iw_context().i64_type(),
1105 _ => Err(anyhow!(
1106 "IntTypesCodegenExtension: unsupported log_width: {}",
1107 m
1108 ))?,
1109 }
1110 .into());
1111 }
1112 }
1113 Err(anyhow!(
1114 "IntTypesCodegenExtension: unsupported type: {}",
1115 hugr_type
1116 ))
1117}
1118
1119fn emit_const_int<'c, H: HugrView<Node = Node>>(
1120 context: &mut EmitFuncContext<'c, '_, H>,
1121 k: &ConstInt,
1122) -> Result<BasicValueEnum<'c>> {
1123 let ty: IntType = context.llvm_type(&k.get_type())?.try_into().unwrap();
1124 Ok(ty.const_int(k.value_u(), false).as_basic_value_enum())
1128}
1129
1130impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
1131 #[must_use]
1136 pub fn add_default_int_extensions(self) -> Self {
1137 self.add_extension(IntCodegenExtension::new(DefaultPreludeCodegen))
1138 }
1139}
1140
1141#[cfg(test)]
1142mod test {
1143 use anyhow::Result;
1144 use hugr_core::builder::DataflowHugr;
1145 use hugr_core::extension::prelude::{ConstError, UnwrapBuilder, error_type};
1146 use hugr_core::std_extensions::STD_REG;
1147 use hugr_core::{
1148 Hugr,
1149 builder::{Dataflow, DataflowSubContainer, SubContainer, handle::Outputs},
1150 extension::prelude::bool_t,
1151 ops::{DataflowOpTrait, ExtensionOp},
1152 std_extensions::arithmetic::{
1153 int_ops::{self, IntOpDef},
1154 int_types::{ConstInt, INT_TYPES},
1155 },
1156 types::{SumType, Type, TypeRow},
1157 };
1158 use rstest::rstest;
1159
1160 use crate::{
1161 check_emission,
1162 emit::test::{DFGW, SimpleHugrConfig},
1163 test::{TestContext, exec_ctx, llvm_ctx, single_op_hugr},
1164 };
1165
1166 #[rstest::fixture]
1167 fn int_exec_ctx(mut exec_ctx: TestContext) -> TestContext {
1168 exec_ctx.add_extensions(|cem| {
1169 cem.add_default_int_extensions()
1170 .add_default_prelude_extensions()
1171 });
1172 exec_ctx
1173 }
1174
1175 #[rstest::fixture]
1176 fn int_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
1177 llvm_ctx.add_extensions(|cem| {
1178 cem.add_default_int_extensions()
1179 .add_default_prelude_extensions()
1180 });
1181 llvm_ctx
1182 }
1183
1184 fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
1186 int_ops::EXTENSION
1187 .instantiate_extension_op(name.as_ref(), [u64::from(log_width).into()])
1188 .unwrap()
1189 }
1190
1191 fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1192 let ty = &INT_TYPES[log_width as usize];
1193 test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
1194 }
1195
1196 fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
1197 test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
1198 }
1199
1200 fn test_int_op_with_results<const N: usize>(
1201 ext_op: ExtensionOp,
1202 log_width: u8,
1203 inputs: Option<[ConstInt; N]>,
1204 output_type: Type,
1205 ) -> Hugr {
1206 test_int_op_with_results_processing(ext_op, log_width, inputs, output_type, |_, a| Ok(a))
1207 }
1208
1209 fn test_int_op_with_results_processing<const N: usize>(
1210 ext_op: ExtensionOp,
1212 log_width: u8,
1213 inputs: Option<[ConstInt; N]>, output_type: Type,
1215 process: impl Fn(&mut DFGW, Outputs) -> Result<Outputs>,
1216 ) -> Hugr {
1217 let ty = &INT_TYPES[log_width as usize];
1218 let input_tys = if inputs.is_some() {
1219 vec![]
1220 } else {
1221 let input_tys = itertools::repeat_n(ty.clone(), N).collect();
1222 assert_eq!(input_tys, ext_op.signature().input.to_vec());
1223 input_tys
1224 };
1225 SimpleHugrConfig::new()
1226 .with_ins(input_tys)
1227 .with_outs(vec![output_type])
1228 .with_extensions(STD_REG.clone())
1229 .finish(|mut hugr_builder| {
1230 let input_wires = match inputs {
1231 None => hugr_builder.input_wires_arr::<N>().to_vec(),
1232 Some(inputs) => {
1233 let mut input_wires = Vec::new();
1234 for i in inputs.into_iter() {
1235 let w = hugr_builder.add_load_value(i);
1236 input_wires.push(w);
1237 }
1238 input_wires
1239 }
1240 };
1241 let outputs = hugr_builder
1242 .add_dataflow_op(ext_op, input_wires)
1243 .unwrap()
1244 .outputs();
1245 let processed_outputs = process(&mut hugr_builder, outputs).unwrap();
1246 hugr_builder
1247 .finish_hugr_with_outputs(processed_outputs)
1248 .unwrap()
1249 })
1250 }
1251
1252 #[rstest]
1253 #[case(IntOpDef::iu_to_s, &[3])]
1254 #[case(IntOpDef::is_to_u, &[3])]
1255 #[case(IntOpDef::ineg, &[2])]
1256 #[case::idiv_checked_u("idiv_checked_u", &[3])]
1257 #[case::idiv_checked_s("idiv_checked_s", &[3])]
1258 #[case::imod_checked_u("imod_checked_u", &[6])]
1259 #[case::imod_checked_s("imod_checked_s", &[6])]
1260 #[case::idivmod_u("idivmod_u", &[3])]
1261 #[case::idivmod_s("idivmod_s", &[3])]
1262 #[case::idivmod_checked_u("idivmod_checked_u", &[6])]
1263 #[case::idivmod_checked_s("idivmod_checked_s", &[6])]
1264 fn test_emission(int_llvm_ctx: TestContext, #[case] op: IntOpDef, #[case] args: &[u8]) {
1265 use hugr_core::extension::simple_op::MakeExtensionOp as _;
1266
1267 let mut insta = insta::Settings::clone_current();
1268 insta.set_snapshot_suffix(format!(
1269 "{}_{}_{:?}",
1270 insta.snapshot_suffix().unwrap_or(""),
1271 op.op_id(),
1272 args,
1273 ));
1274 let concrete = match *args {
1275 [] => op.without_log_width(),
1276 [log_width] => op.with_log_width(log_width),
1277 [lw1, lw2] => op.with_two_log_widths(lw1, lw2),
1278 _ => panic!("unexpected number of args to the op!"),
1279 };
1280 insta.bind(|| {
1281 let hugr = single_op_hugr(concrete.into());
1282 check_emission!(hugr, int_llvm_ctx);
1283 });
1284 }
1285
1286 #[rstest]
1287 #[case::iadd("iadd", 3)]
1288 #[case::isub("isub", 6)]
1289 #[case::ipow("ipow", 3)]
1290 #[case::idiv_u("idiv_u", 3)]
1291 #[case::idiv_s("idiv_s", 3)]
1292 #[case::imod_u("imod_u", 3)]
1293 #[case::imod_s("imod_s", 3)]
1294 fn test_binop_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1295 let ext_op = make_int_op(op.clone(), width);
1296 let hugr = test_binary_int_op(ext_op, width);
1297 check_emission!(op.clone(), hugr, int_llvm_ctx);
1298 }
1299
1300 #[rstest]
1301 #[case::signed_2_3("iwiden_s", 2, 3)]
1302 #[case::signed_1_6("iwiden_s", 1, 6)]
1303 #[case::unsigned_2_3("iwiden_u", 2, 3)]
1304 #[case::unsigned_1_6("iwiden_u", 1, 6)]
1305 fn test_widen_emission(
1306 int_llvm_ctx: TestContext,
1307 #[case] op: String,
1308 #[case] from: u8,
1309 #[case] to: u8,
1310 ) {
1311 let out_ty = INT_TYPES[to as usize].clone();
1312 let ext_op = int_ops::EXTENSION
1313 .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1314 .unwrap();
1315 let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty);
1316
1317 check_emission!(
1318 format!("{}_{}_{}", op.clone(), from, to),
1319 hugr,
1320 int_llvm_ctx
1321 );
1322 }
1323
1324 #[rstest]
1325 #[case::signed("inarrow_s", 3, 2)]
1326 #[case::unsigned("inarrow_u", 6, 4)]
1327 fn test_narrow_emission(
1328 int_llvm_ctx: TestContext,
1329 #[case] op: String,
1330 #[case] from: u8,
1331 #[case] to: u8,
1332 ) {
1333 let out_ty = SumType::new([vec![error_type()], vec![INT_TYPES[to as usize].clone()]]);
1334 let ext_op = int_ops::EXTENSION
1335 .instantiate_extension_op(&op, [u64::from(from).into(), u64::from(to).into()])
1336 .unwrap();
1337 let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());
1338
1339 check_emission!(
1340 format!("{}_{}_{}", op.clone(), from, to),
1341 hugr,
1342 int_llvm_ctx
1343 );
1344 }
1345
1346 #[rstest]
1347 #[case::ieq("ieq", 1)]
1348 #[case::ilt_s("ilt_s", 0)]
1349 fn test_cmp_emission(int_llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
1350 let ext_op = make_int_op(op.clone(), width);
1351 let hugr = test_binary_icmp_op(ext_op, width);
1352 check_emission!(op.clone(), hugr, int_llvm_ctx);
1353 }
1354
1355 #[rstest]
1356 #[case::imax("imax_u", 1, 2, 2)]
1357 #[case::imax("imax_u", 2, 1, 2)]
1358 #[case::imax("imax_u", 2, 2, 2)]
1359 #[case::imin("imin_u", 1, 2, 1)]
1360 #[case::imin("imin_u", 2, 1, 1)]
1361 #[case::imin("imin_u", 2, 2, 2)]
1362 #[case::ishl("ishl", 73, 1, 146)]
1363 #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
1365 #[case::ishr("ishr", 73, 1, 36)]
1366 #[case::ior("ior", 6, 9, 15)]
1367 #[case::ior("ior", 6, 15, 15)]
1368 #[case::ixor("ixor", 6, 9, 15)]
1369 #[case::ixor("ixor", 6, 15, 9)]
1370 #[case::ixor("ixor", 15, 6, 9)]
1371 #[case::iand("iand", 6, 15, 6)]
1372 #[case::iand("iand", 15, 6, 6)]
1373 #[case::iand("iand", 15, 15, 15)]
1374 #[case::ipow("ipow", 2, 3, 8)]
1375 #[case::ipow("ipow", 42, 1, 42)]
1376 #[case::ipow("ipow", 42, 0, 1)]
1377 #[case::idiv("idiv_u", 42, 2, 21)]
1378 #[case::idiv("idiv_u", 42, 5, 8)]
1379 #[case::imod("imod_u", 42, 2, 0)]
1380 #[case::imod("imod_u", 42, 5, 2)]
1381 fn test_exec_unsigned_bin_op(
1382 int_exec_ctx: TestContext,
1383 #[case] op: String,
1384 #[case] lhs: u64,
1385 #[case] rhs: u64,
1386 #[case] result: u64,
1387 ) {
1388 let ty = &INT_TYPES[6].clone();
1389 let inputs = [
1390 ConstInt::new_u(6, lhs).unwrap(),
1391 ConstInt::new_u(6, rhs).unwrap(),
1392 ];
1393 let ext_op = make_int_op(&op, 6);
1394
1395 let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1396 assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1397 }
1398
1399 #[rstest]
1400 #[case::imax("imax_s", 1, 2, 2)]
1401 #[case::imax("imax_s", 2, 1, 2)]
1402 #[case::imax("imax_s", 2, 2, 2)]
1403 #[case::imax("imax_s", -1, -2, -1)]
1404 #[case::imax("imax_s", -2, -1, -1)]
1405 #[case::imax("imax_s", -2, -2, -2)]
1406 #[case::imin("imin_s", 1, 2, 1)]
1407 #[case::imin("imin_s", 2, 1, 1)]
1408 #[case::imin("imin_s", 2, 2, 2)]
1409 #[case::imin("imin_s", -1, -2, -2)]
1410 #[case::imin("imin_s", -2, -1, -2)]
1411 #[case::imin("imin_s", -2, -2, -2)]
1412 #[case::ipow("ipow", -2, 1, -2)]
1413 #[case::ipow("ipow", -2, 2, 4)]
1414 #[case::ipow("ipow", -2, 3, -8)]
1415 fn test_exec_signed_bin_op(
1416 int_exec_ctx: TestContext,
1417 #[case] op: String,
1418 #[case] lhs: i64,
1419 #[case] rhs: i64,
1420 #[case] result: i64,
1421 ) {
1422 let ty = &INT_TYPES[6].clone();
1423 let inputs = [
1424 ConstInt::new_s(6, lhs).unwrap(),
1425 ConstInt::new_s(6, rhs).unwrap(),
1426 ];
1427 let ext_op = make_int_op(&op, 6);
1428
1429 let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
1430 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1431 }
1432
1433 #[rstest]
1434 #[case::iabs("iabs", 42, 42)]
1435 #[case::iabs("iabs", -42, 42)]
1436 fn test_exec_signed_unary_op(
1437 int_exec_ctx: TestContext,
1438 #[case] op: String,
1439 #[case] arg: i64,
1440 #[case] result: i64,
1441 ) {
1442 let input = ConstInt::new_s(6, arg).unwrap();
1443 let ty = INT_TYPES[6].clone();
1444 let ext_op = make_int_op(&op, 6);
1445
1446 let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1447 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), result);
1448 }
1449
1450 #[rstest]
1451 #[case::inot("inot", 9223372036854775808, !9223372036854775808u64)]
1452 #[case::inot("inot", 42, !42u64)]
1453 #[case::inot("inot", !0u64, 0)]
1454 fn test_exec_unsigned_unary_op(
1455 int_exec_ctx: TestContext,
1456 #[case] op: String,
1457 #[case] arg: u64,
1458 #[case] result: u64,
1459 ) {
1460 let input = ConstInt::new_u(6, arg).unwrap();
1461 let ty = INT_TYPES[6].clone();
1462 let ext_op = make_int_op(&op, 6);
1463
1464 let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
1465 assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), result);
1466 }
1467
1468 #[rstest]
1469 #[case(-127)]
1470 #[case(-1)]
1471 #[case(0)]
1472 #[case(1)]
1473 #[case(127)]
1474 fn test_exec_widen(int_exec_ctx: TestContext, #[case] num: i16) {
1475 let from: u8 = 3;
1476 let to: u8 = 6;
1477 let ty = INT_TYPES[to as usize].clone();
1478
1479 if num >= 0 {
1480 let input = ConstInt::new_u(from, num as u64).unwrap();
1481
1482 let ext_op = int_ops::EXTENSION
1483 .instantiate_extension_op(
1484 "iwiden_u".as_ref(),
1485 [(from as u64).into(), (to as u64).into()],
1486 )
1487 .unwrap();
1488
1489 let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1490
1491 assert_eq!(int_exec_ctx.exec_hugr_u64(hugr, "main"), num as u64);
1492 }
1493
1494 let input = ConstInt::new_s(from, num as i64).unwrap();
1495
1496 let ext_op = int_ops::EXTENSION
1497 .instantiate_extension_op(
1498 "iwiden_s".as_ref(),
1499 [(from as u64).into(), (to as u64).into()],
1500 )
1501 .unwrap();
1502
1503 let hugr = test_int_op_with_results::<1>(ext_op, to, Some([input]), ty.clone());
1504
1505 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), num as i64);
1506 }
1507
1508 #[rstest]
1509 #[case("inarrow_s", 6, 2, 4)]
1510 #[case("inarrow_s", 6, 5, (1 << 5) - 1)]
1511 #[case("inarrow_s", 6, 4, -1)]
1512 #[case("inarrow_s", 6, 4, -(1 << 4) - 1)]
1513 #[case("inarrow_s", 6, 4, -(1 <<15))]
1514 #[case("inarrow_s", 6, 5, (1 << 31) - 1)]
1515 fn test_narrow_s(
1516 int_exec_ctx: TestContext,
1517 #[case] op: String,
1518 #[case] from: u8,
1519 #[case] to: u8,
1520 #[case] arg: i64,
1521 ) {
1522 let input = ConstInt::new_s(from, arg).unwrap();
1523 let to_ty = INT_TYPES[to as usize].clone();
1524 let ext_op = int_ops::EXTENSION
1525 .instantiate_extension_op(op.as_ref(), [u64::from(from).into(), u64::from(to).into()])
1526 .unwrap();
1527
1528 let hugr = test_int_op_with_results_processing::<1>(
1529 ext_op,
1530 to,
1531 Some([input]),
1532 to_ty.clone(),
1533 |builder, outs| {
1534 let [out] = outs.to_array();
1535
1536 let err_row = TypeRow::from(vec![error_type()]);
1537 let ty_row = TypeRow::from(vec![to_ty.clone()]);
1538 let mut cond_b = builder.conditional_builder(
1546 ([err_row, ty_row], out),
1547 [],
1548 vec![to_ty.clone()].into(),
1549 )?;
1550 let mut sad_b = cond_b.case_builder(0)?;
1551 let err = ConstError::new(2, "This shouldn't happen");
1552 let w = sad_b.add_load_value(ConstInt::new_s(to, 0)?);
1553 sad_b.add_panic(err, vec![to_ty.clone()], [(w, to_ty.clone())])?;
1554 sad_b.finish_with_outputs([w])?;
1555
1556 let happy_b = cond_b.case_builder(1)?;
1557 let [w] = happy_b.input_wires_arr();
1558 happy_b.finish_with_outputs([w])?;
1559
1560 let handle = cond_b.finish_sub_container()?;
1561 Ok(handle.outputs())
1562 },
1563 );
1564 assert_eq!(int_exec_ctx.exec_hugr_i64(hugr, "main"), arg);
1565 }
1566
1567 #[rstest]
1568 #[case(6, 42)]
1569 #[case(4, 7)]
1570 fn test_u_to_s(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: u64) {
1572 let ty = &INT_TYPES[log_width as usize].clone();
1573 let hugr = SimpleHugrConfig::new()
1574 .with_outs(vec![ty.clone()])
1575 .with_extensions(STD_REG.clone())
1576 .finish(|mut hugr_builder| {
1577 let unsigned =
1578 hugr_builder.add_load_value(ConstInt::new_u(log_width, val).unwrap());
1579 let iu_to_s = make_int_op("iu_to_s", log_width);
1580 let [signed] = hugr_builder
1581 .add_dataflow_op(iu_to_s, [unsigned])
1582 .unwrap()
1583 .outputs_arr();
1584 hugr_builder.finish_hugr_with_outputs([signed]).unwrap()
1585 });
1586 let act = int_exec_ctx.exec_hugr_i64(hugr, "main");
1587 assert_eq!(act, val as i64);
1588 }
1589
1590 #[rstest]
1591 #[case(3, 0)]
1592 #[case(4, 255)]
1593 fn test_s_to_u(int_exec_ctx: TestContext, #[case] log_width: u8, #[case] val: i64) {
1595 let ty = &INT_TYPES[log_width as usize].clone();
1596 let hugr = SimpleHugrConfig::new()
1597 .with_outs(vec![ty.clone()])
1598 .with_extensions(STD_REG.clone())
1599 .finish(|mut hugr_builder| {
1600 let signed = hugr_builder.add_load_value(ConstInt::new_s(log_width, val).unwrap());
1601 let is_to_u = make_int_op("is_to_u", log_width);
1602 let [unsigned] = hugr_builder
1603 .add_dataflow_op(is_to_u, [signed])
1604 .unwrap()
1605 .outputs_arr();
1606 let num = hugr_builder.add_load_value(ConstInt::new_u(log_width, 42).unwrap());
1607 let [res] = hugr_builder
1608 .add_dataflow_op(make_int_op("iadd", log_width), [unsigned, num])
1609 .unwrap()
1610 .outputs_arr();
1611 hugr_builder.finish_hugr_with_outputs([res]).unwrap()
1612 });
1613 let act = int_exec_ctx.exec_hugr_u64(hugr, "main");
1614 assert_eq!(act, (val as u64) + 42);
1615 }
1616
1617 #[rstest]
1619 #[case::bigdiv_non_negative(127, 255, (0, 127))] #[case::bigdiv_negative(-42, 255, (-1, 213))] #[case::smoldiv_non_negative(42, 10, (4, 2))] #[case::smoldiv_negative_rem0(-42, 21, (-2, 0))] #[case::smoldiv_negative_rem_nonzero(-42, 10, (-5, 8))] fn test_divmod_s(
1625 int_exec_ctx: TestContext,
1626 #[case] dividend: i64,
1627 #[case] divisor: u64,
1628 #[case] expected_result: (i64, u64),
1629 ) {
1630 let int_ty = INT_TYPES[3].clone();
1631 let k_dividend = ConstInt::new_s(3, dividend).unwrap();
1632 let k_divisor = ConstInt::new_u(3, divisor).unwrap();
1633 let quot_hugr = test_int_op_with_results(
1634 make_int_op("idiv_s", 3),
1635 3,
1636 Some([k_dividend.clone(), k_divisor.clone()]),
1637 int_ty.clone(),
1638 );
1639 let rem_hugr = test_int_op_with_results(
1640 make_int_op("imod_s", 3),
1641 3,
1642 Some([k_dividend, k_divisor]),
1643 int_ty,
1644 );
1645 let quot = int_exec_ctx.exec_hugr_i64(quot_hugr, "main");
1646 let rem = int_exec_ctx.exec_hugr_u64(rem_hugr, "main");
1647 assert_eq!((quot, rem), expected_result);
1648 }
1649}