1use hugr_core::{
2 extension::prelude::ConstError,
3 ops::{constant::CustomConst, ExtensionOp, NamedOp, Value},
4 std_extensions::arithmetic::{
5 int_ops::IntOpDef,
6 int_types::{self, ConstInt},
7 },
8 types::{CustomType, TypeArg},
9 HugrView, Node,
10};
11use inkwell::{
12 types::{BasicType, BasicTypeEnum, IntType},
13 values::{BasicValue, BasicValueEnum, IntValue},
14 IntPredicate,
15};
16
17use crate::{
18 custom::CodegenExtsBuilder,
19 emit::{
20 emit_value,
21 func::EmitFuncContext,
22 get_intrinsic,
23 libc::{emit_libc_abort, emit_libc_printf},
24 ops::{emit_custom_binary_op, emit_custom_unary_op},
25 EmitOpArgs,
26 },
27 sum::LLVMSumType,
28 types::{HugrSumType, TypingSession},
29};
30
31use anyhow::{anyhow, bail, Result};
32
33use super::conversions::int_type_bounds;
34
/// Runtime failures that integer operations report by returning a HUGR error value
/// (rather than aborting).
enum RuntimeError {
    /// The input does not fit in the narrower target type (`inarrow_s` / `inarrow_u`).
    Narrow,
}

impl RuntimeError {
    /// The message text stored in the emitted error value.
    fn show(&self) -> &str {
        match *self {
            Self::Narrow => "Can't narrow into bounds",
        }
    }
}
46
47fn emit_icmp<'c, H: HugrView<Node = Node>>(
49 context: &mut EmitFuncContext<'c, '_, H>,
50 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
51 pred: inkwell::IntPredicate,
52) -> Result<()> {
53 let true_val = emit_value(context, &Value::true_val())?;
54 let false_val = emit_value(context, &Value::false_val())?;
55
56 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
57 let r = ctx.builder().build_int_compare(
59 pred,
60 lhs.into_int_value(),
61 rhs.into_int_value(),
62 "",
63 )?;
64 Ok(vec![ctx
66 .builder()
67 .build_select(r, true_val, false_val, "")?])
68 })
69}
70
71fn emit_ipow<'c, H: HugrView<Node = Node>>(
75 context: &mut EmitFuncContext<'c, '_, H>,
76 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
77) -> Result<()> {
78 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
79 let done_bb = ctx.new_basic_block("done", None);
80 let pow_body_bb = ctx.new_basic_block("pow_body", Some(done_bb));
81 let return_one_bb = ctx.new_basic_block("power_of_zero", Some(pow_body_bb));
82 let pow_bb = ctx.new_basic_block("pow", Some(return_one_bb));
83
84 let acc_p = ctx.builder().build_alloca(lhs.get_type(), "acc_ptr")?;
85 let exp_p = ctx.builder().build_alloca(rhs.get_type(), "exp_ptr")?;
86 ctx.builder().build_store(acc_p, lhs)?;
87 ctx.builder().build_store(exp_p, rhs)?;
88 ctx.builder().build_unconditional_branch(pow_bb)?;
89
90 let zero = rhs.get_type().into_int_type().const_int(0, false);
91 let one = rhs.get_type().into_int_type().const_int(1, false);
93
94 ctx.builder().position_at_end(return_one_bb);
96 ctx.builder().build_store(acc_p, one)?;
97 ctx.builder().build_unconditional_branch(done_bb)?;
98
99 ctx.builder().position_at_end(pow_bb);
100 let acc = ctx.builder().build_load(acc_p, "acc")?;
101 let exp = ctx.builder().build_load(exp_p, "exp")?;
102
103 ctx.builder().build_switch(
105 exp.into_int_value(),
106 pow_body_bb,
107 &[(one, done_bb), (zero, return_one_bb)],
108 )?;
109
110 ctx.builder().position_at_end(pow_body_bb);
112 let new_acc =
113 ctx.builder()
114 .build_int_mul(acc.into_int_value(), lhs.into_int_value(), "new_acc")?;
115 let new_exp = ctx
116 .builder()
117 .build_int_sub(exp.into_int_value(), one, "new_exp")?;
118 ctx.builder().build_store(acc_p, new_acc)?;
119 ctx.builder().build_store(exp_p, new_exp)?;
120 ctx.builder().build_unconditional_branch(pow_bb)?;
121
122 ctx.builder().position_at_end(done_bb);
123 let result = ctx.builder().build_load(acc_p, "result")?;
124 Ok(vec![result.as_basic_value_enum()])
125 })
126}
127
128fn emit_int_op<'c, H: HugrView<Node = Node>>(
129 context: &mut EmitFuncContext<'c, '_, H>,
130 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
131 op: IntOpDef,
132) -> Result<()> {
133 match op {
134 IntOpDef::iadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
135 Ok(vec![ctx
136 .builder()
137 .build_int_add(lhs.into_int_value(), rhs.into_int_value(), "")?
138 .as_basic_value_enum()])
139 }),
140 IntOpDef::imul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
141 Ok(vec![ctx
142 .builder()
143 .build_int_mul(lhs.into_int_value(), rhs.into_int_value(), "")?
144 .as_basic_value_enum()])
145 }),
146 IntOpDef::isub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
147 Ok(vec![ctx
148 .builder()
149 .build_int_sub(lhs.into_int_value(), rhs.into_int_value(), "")?
150 .as_basic_value_enum()])
151 }),
152 IntOpDef::idiv_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
153 Ok(vec![ctx
154 .builder()
155 .build_int_signed_div(lhs.into_int_value(), rhs.into_int_value(), "")?
156 .as_basic_value_enum()])
157 }),
158 IntOpDef::idiv_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
159 Ok(vec![ctx
160 .builder()
161 .build_int_unsigned_div(lhs.into_int_value(), rhs.into_int_value(), "")?
162 .as_basic_value_enum()])
163 }),
164 IntOpDef::imod_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
165 Ok(vec![ctx
166 .builder()
167 .build_int_signed_rem(lhs.into_int_value(), rhs.into_int_value(), "")?
168 .as_basic_value_enum()])
169 }),
170 IntOpDef::ineg => emit_custom_unary_op(context, args, |ctx, arg, _| {
171 Ok(vec![ctx
172 .builder()
173 .build_int_neg(arg.into_int_value(), "")?
174 .as_basic_value_enum()])
175 }),
176 IntOpDef::iabs => emit_custom_unary_op(context, args, |ctx, arg, _| {
177 let intr = get_intrinsic(
178 ctx.get_current_module(),
179 "llvm.abs.i64",
180 [ctx.iw_context().i64_type().as_basic_type_enum()],
181 )?;
182 let true_ = ctx.iw_context().bool_type().const_all_ones();
183 let r = ctx
184 .builder()
185 .build_call(intr, &[arg.into_int_value().into(), true_.into()], "")?
186 .try_as_basic_value()
187 .unwrap_left();
188 Ok(vec![r])
189 }),
190 IntOpDef::imax_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
191 let intr = get_intrinsic(
192 ctx.get_current_module(),
193 "llvm.smax.i64",
194 [ctx.iw_context().i64_type().as_basic_type_enum()],
195 )?;
196 let r = ctx
197 .builder()
198 .build_call(
199 intr,
200 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
201 "",
202 )?
203 .try_as_basic_value()
204 .unwrap_left();
205 Ok(vec![r])
206 }),
207 IntOpDef::imax_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
208 let intr = get_intrinsic(
209 ctx.get_current_module(),
210 "llvm.umax.i64",
211 [ctx.iw_context().i64_type().as_basic_type_enum()],
212 )?;
213 let r = ctx
214 .builder()
215 .build_call(
216 intr,
217 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
218 "",
219 )?
220 .try_as_basic_value()
221 .unwrap_left();
222 Ok(vec![r])
223 }),
224 IntOpDef::imin_s => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
225 let intr = get_intrinsic(
226 ctx.get_current_module(),
227 "llvm.smin.i64",
228 [ctx.iw_context().i64_type().as_basic_type_enum()],
229 )?;
230 let r = ctx
231 .builder()
232 .build_call(
233 intr,
234 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
235 "",
236 )?
237 .try_as_basic_value()
238 .unwrap_left();
239 Ok(vec![r])
240 }),
241 IntOpDef::imin_u => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
242 let intr = get_intrinsic(
243 ctx.get_current_module(),
244 "llvm.umin.i64",
245 [ctx.iw_context().i64_type().as_basic_type_enum()],
246 )?;
247 let r = ctx
248 .builder()
249 .build_call(
250 intr,
251 &[lhs.into_int_value().into(), rhs.into_int_value().into()],
252 "",
253 )?
254 .try_as_basic_value()
255 .unwrap_left();
256 Ok(vec![r])
257 }),
258 IntOpDef::ishl => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
259 Ok(vec![ctx
260 .builder()
261 .build_left_shift(lhs.into_int_value(), rhs.into_int_value(), "")?
262 .as_basic_value_enum()])
263 }),
264 IntOpDef::ishr => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
265 Ok(vec![ctx
266 .builder()
267 .build_right_shift(lhs.into_int_value(), rhs.into_int_value(), false, "")?
268 .as_basic_value_enum()])
269 }),
270 IntOpDef::ieq => emit_icmp(context, args, inkwell::IntPredicate::EQ),
271 IntOpDef::ine => emit_icmp(context, args, inkwell::IntPredicate::NE),
272 IntOpDef::ilt_s => emit_icmp(context, args, inkwell::IntPredicate::SLT),
273 IntOpDef::igt_s => emit_icmp(context, args, inkwell::IntPredicate::SGT),
274 IntOpDef::ile_s => emit_icmp(context, args, inkwell::IntPredicate::SLE),
275 IntOpDef::ige_s => emit_icmp(context, args, inkwell::IntPredicate::SGE),
276 IntOpDef::ilt_u => emit_icmp(context, args, inkwell::IntPredicate::ULT),
277 IntOpDef::igt_u => emit_icmp(context, args, inkwell::IntPredicate::UGT),
278 IntOpDef::ile_u => emit_icmp(context, args, inkwell::IntPredicate::ULE),
279 IntOpDef::ige_u => emit_icmp(context, args, inkwell::IntPredicate::UGE),
280 IntOpDef::ixor => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
281 Ok(vec![ctx
282 .builder()
283 .build_xor(lhs.into_int_value(), rhs.into_int_value(), "")?
284 .as_basic_value_enum()])
285 }),
286 IntOpDef::ior => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
287 Ok(vec![ctx
288 .builder()
289 .build_or(lhs.into_int_value(), rhs.into_int_value(), "")?
290 .as_basic_value_enum()])
291 }),
292 IntOpDef::inot => emit_custom_unary_op(context, args, |ctx, arg, _| {
293 Ok(vec![ctx
294 .builder()
295 .build_not(arg.into_int_value(), "")?
296 .as_basic_value_enum()])
297 }),
298 IntOpDef::iand => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
299 Ok(vec![ctx
300 .builder()
301 .build_and(lhs.into_int_value(), rhs.into_int_value(), "")?
302 .as_basic_value_enum()])
303 }),
304 IntOpDef::ipow => emit_ipow(context, args),
305 IntOpDef::iwiden_u => emit_custom_unary_op(context, args, |ctx, arg, outs| {
307 let [out] = outs.try_into()?;
308 Ok(vec![ctx
309 .builder()
310 .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), false, "")?
311 .as_basic_value_enum()])
312 }),
313 IntOpDef::iwiden_s => emit_custom_unary_op(context, args, |ctx, arg, outs| {
314 let [out] = outs.try_into()?;
315
316 Ok(vec![ctx
317 .builder()
318 .build_int_cast_sign_flag(arg.into_int_value(), out.into_int_type(), true, "")?
319 .as_basic_value_enum()])
320 }),
321 IntOpDef::inarrow_s => {
322 let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned()
323 else {
324 bail!("Type arg to inarrow_s wasn't a Nat");
325 };
326 let (_, out_ty) = args.node.out_value_types().next().unwrap();
327 emit_custom_unary_op(context, args, |ctx, arg, outs| {
328 let result = make_narrow(
329 ctx,
330 arg,
331 outs,
332 out_log_width,
333 true,
334 out_ty.as_sum().unwrap().clone(),
335 )?;
336 Ok(vec![result])
337 })
338 }
339 IntOpDef::inarrow_u => {
340 let Some(TypeArg::BoundedNat { n: out_log_width }) = args.node().args().last().cloned()
341 else {
342 bail!("Type arg to inarrow_u wasn't a Nat");
343 };
344 let (_, out_ty) = args.node.out_value_types().next().unwrap();
345 emit_custom_unary_op(context, args, |ctx, arg, outs| {
346 let result = make_narrow(
347 ctx,
348 arg,
349 outs,
350 out_log_width,
351 false,
352 out_ty.as_sum().unwrap().clone(),
353 )?;
354 Ok(vec![result])
355 })
356 }
357 IntOpDef::iu_to_s => {
358 let [TypeArg::BoundedNat { n: log_width }] =
359 TryInto::<[TypeArg; 1]>::try_into(args.node.args().to_vec()).unwrap()
360 else {
361 bail!("Type argument to iu_to_s wasn't a number");
362 };
363 emit_custom_unary_op(context, args, |ctx, arg, _| {
364 let (_, max_val, _) = int_type_bounds(u32::pow(2, log_width as u32));
365 let max = arg
366 .get_type()
367 .into_int_type()
368 .const_int(max_val as u64, false);
369
370 let within_bounds = ctx.builder().build_int_compare(
371 IntPredicate::ULE,
372 arg.into_int_value(),
373 max,
374 "bounds_check",
375 )?;
376
377 Ok(vec![val_or_panic(
378 ctx,
379 within_bounds,
380 "iu_to_s argument out of bounds",
381 arg,
382 )?])
383 })
384 }
385 IntOpDef::is_to_u => emit_custom_unary_op(context, args, |ctx, arg, _| {
386 let zero = arg.get_type().into_int_type().const_zero();
387
388 let within_bounds = ctx.builder().build_int_compare(
389 IntPredicate::SGE,
390 arg.into_int_value(),
391 zero,
392 "bounds_check",
393 )?;
394
395 Ok(vec![val_or_panic(
396 ctx,
397 within_bounds,
398 "is_to_u called on negative value",
399 arg,
400 )?])
401 }),
402 _ => Err(anyhow!("IntOpEmitter: unimplemented op: {}", op.name())),
403 }
404}
405
406fn make_narrow<'c, H: HugrView<Node = Node>>(
407 ctx: &mut EmitFuncContext<'c, '_, H>,
408 arg: BasicValueEnum<'c>,
409 outs: &[BasicTypeEnum<'c>],
410 out_log_width: u64,
411 signed: bool,
412 sum_type: HugrSumType,
413) -> Result<BasicValueEnum<'c>> {
414 let [out] = TryInto::<[BasicTypeEnum; 1]>::try_into(outs)?;
415 let width = 1 << out_log_width;
416 let arg_int_ty: IntType = arg.get_type().into_int_type();
417 let (int_min_value_s, int_max_value_s, int_max_value_u) = int_type_bounds(width);
418 let out_int_ty = out
419 .into_struct_type()
420 .get_field_type_at_index(2)
421 .unwrap()
422 .into_int_type();
423 let outside_range = if signed {
424 let too_big = ctx.builder().build_int_compare(
425 IntPredicate::SGT,
426 arg.into_int_value(),
427 arg_int_ty.const_int(int_max_value_s as u64, true),
428 "upper_bounds_check",
429 )?;
430 let too_small = ctx.builder().build_int_compare(
431 IntPredicate::SLT,
432 arg.into_int_value(),
433 arg_int_ty.const_int(int_min_value_s as u64, true),
434 "lower_bounds_check",
435 )?;
436 ctx.builder()
437 .build_or(too_big, too_small, "outside_range")?
438 } else {
439 ctx.builder().build_int_compare(
440 IntPredicate::UGT,
441 arg.into_int_value(),
442 arg_int_ty.const_int(int_max_value_u, false),
443 "upper_bounds_check",
444 )?
445 };
446
447 let narrowed_val = ctx
448 .builder()
449 .build_int_cast_sign_flag(arg.into_int_value(), out_int_ty, signed, "")?
450 .as_basic_value_enum();
451 val_or_error(
452 ctx,
453 outside_range,
454 narrowed_val,
455 RuntimeError::Narrow,
456 LLVMSumType::try_from_hugr_type(&ctx.typing_session(), sum_type).unwrap(),
457 )
458}
459
460fn val_or_panic<'c, H: HugrView<Node = Node>>(
461 ctx: &mut EmitFuncContext<'c, '_, H>,
462 dont_panic: IntValue<'c>,
463 err_msg_str: &str,
464 val: BasicValueEnum<'c>, ) -> Result<BasicValueEnum<'c>> {
466 let done_bb = ctx.new_basic_block("done", None);
467 let exit_bb = ctx.new_basic_block("exit", Some(done_bb));
468 let go_bb = ctx.new_basic_block("panic_if_0", Some(exit_bb));
469 let panic_bb = ctx.new_basic_block("panic", Some(exit_bb));
470 ctx.builder().build_unconditional_branch(go_bb)?;
471
472 ctx.builder().position_at_end(exit_bb);
473 ctx.builder().build_return(Some(&val))?;
474
475 ctx.builder().position_at_end(panic_bb);
476 let err_msg = ctx
477 .builder()
478 .build_global_string_ptr(err_msg_str, "err_msg")?
479 .as_basic_value_enum();
480 emit_libc_printf(ctx, &[err_msg.into()])?;
481 emit_libc_abort(ctx)?;
482 ctx.builder().build_unconditional_branch(exit_bb)?;
483
484 ctx.builder().position_at_end(go_bb);
485 ctx.builder().build_switch(
486 dont_panic,
487 panic_bb,
488 &[(dont_panic.get_type().const_int(1, false), exit_bb)],
489 )?;
490
491 ctx.builder().position_at_end(done_bb);
492
493 Ok(val) }
495
496fn val_or_error<'c, H: HugrView<Node = Node>>(
497 ctx: &mut EmitFuncContext<'c, '_, H>,
498 should_fail: IntValue<'c>,
499 val: BasicValueEnum<'c>,
500 msg: RuntimeError,
501 ty: LLVMSumType<'c>,
502) -> Result<BasicValueEnum<'c>> {
503 let err_msg = Value::extension(ConstError::new(2, msg.show()));
504 let err_val = emit_value(ctx, &err_msg)?;
505
506 let err_variant = ty.build_tag(ctx.builder(), 0, vec![err_val])?;
507 let ok_variant = ty.build_tag(ctx.builder(), 1, vec![val])?;
508
509 Ok(ctx
510 .builder()
511 .build_select(should_fail, err_variant, ok_variant, "")?)
512}
513
514fn llvm_type<'c>(
515 context: TypingSession<'c, '_>,
516 hugr_type: &CustomType,
517) -> Result<BasicTypeEnum<'c>> {
518 if let [TypeArg::BoundedNat { n }] = hugr_type.args() {
519 let m = *n as usize;
520 if m < int_types::INT_TYPES.len() && int_types::INT_TYPES[m] == hugr_type.clone().into() {
521 return Ok(match m {
522 0..=3 => context.iw_context().i8_type(),
523 4 => context.iw_context().i16_type(),
524 5 => context.iw_context().i32_type(),
525 6 => context.iw_context().i64_type(),
526 _ => Err(anyhow!(
527 "IntTypesCodegenExtension: unsupported log_width: {}",
528 m
529 ))?,
530 }
531 .into());
532 }
533 }
534 Err(anyhow!(
535 "IntTypesCodegenExtension: unsupported type: {}",
536 hugr_type
537 ))
538}
539
540fn emit_const_int<'c, H: HugrView<Node = Node>>(
541 context: &mut EmitFuncContext<'c, '_, H>,
542 k: &ConstInt,
543) -> Result<BasicValueEnum<'c>> {
544 let ty: IntType = context.llvm_type(&k.get_type())?.try_into().unwrap();
545 Ok(ty.const_int(k.value_u(), false).as_basic_value_enum())
549}
550
551pub fn add_int_extensions<'a, H: HugrView<Node = Node> + 'a>(
554 cem: CodegenExtsBuilder<'a, H>,
555) -> CodegenExtsBuilder<'a, H> {
556 cem.custom_const(emit_const_int)
557 .custom_type((int_types::EXTENSION_ID, "int".into()), llvm_type)
558 .simple_extension_op::<IntOpDef>(emit_int_op)
559}
560
561impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
562 pub fn add_int_extensions(self) -> Self {
565 add_int_extensions(self)
566 }
567}
568
// Tests for the integer codegen extension.
//
// * `*_emission` tests snapshot the emitted LLVM IR via insta / `check_emission!`.
// * `test_exec_*`, `test_narrow_s`, `test_u_to_s` and `test_s_to_u` build a small HUGR, run it
//   with `exec_ctx`, and compare the returned integer.
#[cfg(test)]
mod test {
    use anyhow::Result;
    use hugr_core::extension::prelude::{error_type, ConstError, UnwrapBuilder};
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::{
        builder::{handle::Outputs, Dataflow, DataflowSubContainer, SubContainer},
        extension::prelude::bool_t,
        ops::{DataflowOpTrait, ExtensionOp, NamedOp},
        std_extensions::arithmetic::{
            int_ops::{self, IntOpDef},
            int_types::{ConstInt, INT_TYPES},
        },
        types::{SumType, Type, TypeRow},
        Hugr,
    };
    use rstest::rstest;

    use crate::extension::DefaultPreludeCodegen;
    use crate::{
        check_emission,
        emit::test::{SimpleHugrConfig, DFGW},
        extension::{int::add_int_extensions, prelude::add_prelude_extensions},
        test::{exec_ctx, llvm_ctx, single_op_hugr, TestContext},
    };

    // Instantiates the integer op `name` with a single log-width type argument.
    fn make_int_op(name: impl AsRef<str>, log_width: u8) -> ExtensionOp {
        int_ops::EXTENSION
            .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()])
            .unwrap()
    }

    // Builds a HUGR taking two `int<log_width>` inputs and returning the op's `int<log_width>`
    // result.
    fn test_binary_int_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
        let ty = &INT_TYPES[log_width as usize];
        test_int_op_with_results::<2>(ext_op, log_width, None, ty.clone())
    }

    // Same as `test_binary_int_op`, but the op's output is a `bool_t`.
    fn test_binary_icmp_op(ext_op: ExtensionOp, log_width: u8) -> Hugr {
        test_int_op_with_results::<2>(ext_op, log_width, None, bool_t())
    }

    // `test_int_op_with_results_processing` with the op's outputs returned unchanged.
    fn test_int_op_with_results<const N: usize>(
        ext_op: ExtensionOp,
        log_width: u8,
        inputs: Option<[ConstInt; N]>,
        output_type: Type,
    ) -> Hugr {
        test_int_op_with_results_processing(ext_op, log_width, inputs, output_type, |_, a| Ok(a))
    }

    // Builds a single-function HUGR that applies `ext_op` to N inputs.
    //
    // * `inputs == None`: the function takes N `int<log_width>` parameters. The helper asserts
    //   these match the op's input signature.
    // * `inputs == Some(..)`: the function takes no parameters, and the constants are loaded
    //   instead. `log_width` is then unused.
    //
    // `process` may add further nodes after the op; its returned wires become the function
    // outputs, which must have type `output_type`.
    fn test_int_op_with_results_processing<const N: usize>(
        ext_op: ExtensionOp,
        log_width: u8,
        inputs: Option<[ConstInt; N]>,
        output_type: Type,
        process: impl Fn(&mut DFGW, Outputs) -> Result<Outputs>,
    ) -> Hugr {
        let ty = &INT_TYPES[log_width as usize];
        let input_tys = if inputs.is_some() {
            vec![]
        } else {
            let input_tys = itertools::repeat_n(ty.clone(), N).collect();
            assert_eq!(input_tys, ext_op.signature().input.to_vec());
            input_tys
        };
        SimpleHugrConfig::new()
            .with_ins(input_tys)
            .with_outs(vec![output_type])
            .with_extensions(STD_REG.clone())
            .finish(|mut hugr_builder| {
                let input_wires = match inputs {
                    None => hugr_builder.input_wires_arr::<N>().to_vec(),
                    Some(inputs) => {
                        let mut input_wires = Vec::new();
                        inputs.into_iter().for_each(|i| {
                            let w = hugr_builder.add_load_value(i);
                            input_wires.push(w);
                        });
                        input_wires
                    }
                };
                let outputs = hugr_builder
                    .add_dataflow_op(ext_op, input_wires)
                    .unwrap()
                    .outputs();
                let processed_outputs = process(&mut hugr_builder, outputs).unwrap();
                hugr_builder.finish_with_outputs(processed_outputs).unwrap()
            })
    }

    // Snapshot test of IR for a single op instantiated via `IntOpDef`. `args` holds the op's
    // log-width type arguments (zero, one or two of them), and the snapshot suffix includes the
    // op name and `args`.
    #[rstest]
    #[case(IntOpDef::iu_to_s, &[3])]
    #[case(IntOpDef::is_to_u, &[3])]
    #[case(IntOpDef::ineg, &[2])]
    fn test_emission(mut llvm_ctx: TestContext, #[case] op: IntOpDef, #[case] args: &[u8]) {
        llvm_ctx.add_extensions(add_int_extensions);
        let mut insta = insta::Settings::clone_current();
        insta.set_snapshot_suffix(format!(
            "{}_{}_{:?}",
            insta.snapshot_suffix().unwrap_or(""),
            op.name(),
            args,
        ));
        let concrete = match *args {
            [] => op.without_log_width(),
            [log_width] => op.with_log_width(log_width),
            [lw1, lw2] => op.with_two_log_widths(lw1, lw2),
            _ => panic!("unexpected number of args to the op!"),
        };
        insta.bind(|| {
            let hugr = single_op_hugr(concrete.into());
            check_emission!(hugr, llvm_ctx);
        })
    }

    // Snapshot test of IR for binary arithmetic ops at the given log width.
    #[rstest]
    #[case::iadd("iadd", 3)]
    #[case::isub("isub", 6)]
    #[case::ipow("ipow", 3)]
    fn test_binop_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
        llvm_ctx.add_extensions(add_int_extensions);
        let ext_op = make_int_op(op.clone(), width);
        let hugr = test_binary_int_op(ext_op, width);
        check_emission!(op.clone(), hugr, llvm_ctx);
    }

    // Snapshot test of IR for widening from log width `from` to `to`.
    #[rstest]
    #[case::signed_2_3("iwiden_s", 2, 3)]
    #[case::signed_1_6("iwiden_s", 1, 6)]
    #[case::unsigned_2_3("iwiden_u", 2, 3)]
    #[case::unsigned_1_6("iwiden_u", 1, 6)]
    fn test_widen_emission(
        mut llvm_ctx: TestContext,
        #[case] op: String,
        #[case] from: u8,
        #[case] to: u8,
    ) {
        llvm_ctx.add_extensions(add_int_extensions);
        let out_ty = INT_TYPES[to as usize].clone();
        let ext_op = int_ops::EXTENSION
            .instantiate_extension_op(&op, [(from as u64).into(), (to as u64).into()])
            .unwrap();
        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty);

        check_emission!(format!("{}_{}_{}", op.clone(), from, to), hugr, llvm_ctx);
    }

    // Snapshot test of IR for narrowing from log width `from` to `to`. The output is the sum
    // type `error | int<to>`, so the prelude codegen (for the error type) is registered too.
    #[rstest]
    #[case::signed("inarrow_s", 3, 2)]
    #[case::unsigned("inarrow_u", 6, 4)]
    fn test_narrow_emission(
        mut llvm_ctx: TestContext,
        #[case] op: String,
        #[case] from: u8,
        #[case] to: u8,
    ) {
        llvm_ctx.add_extensions(add_int_extensions);
        llvm_ctx.add_extensions(|cem| add_prelude_extensions(cem, DefaultPreludeCodegen));
        let out_ty = SumType::new([vec![error_type()], vec![INT_TYPES[to as usize].clone()]]);
        let ext_op = int_ops::EXTENSION
            .instantiate_extension_op(&op, [(from as u64).into(), (to as u64).into()])
            .unwrap();
        let hugr = test_int_op_with_results::<1>(ext_op, from, None, out_ty.into());

        check_emission!(format!("{}_{}_{}", op.clone(), from, to), hugr, llvm_ctx);
    }

    // Snapshot test of IR for comparison ops, whose output is `bool_t`.
    #[rstest]
    #[case::ieq("ieq", 1)]
    #[case::ilt_s("ilt_s", 0)]
    fn test_cmp_emission(mut llvm_ctx: TestContext, #[case] op: String, #[case] width: u8) {
        llvm_ctx.add_extensions(add_int_extensions);
        let ext_op = make_int_op(op.clone(), width);
        let hugr = test_binary_icmp_op(ext_op, width);
        check_emission!(op.clone(), hugr, llvm_ctx);
    }

    // Executes a 64-bit binary op on constant unsigned inputs and checks the u64 result.
    #[rstest]
    #[case::imax("imax_u", 1, 2, 2)]
    #[case::imax("imax_u", 2, 1, 2)]
    #[case::imax("imax_u", 2, 2, 2)]
    #[case::imin("imin_u", 1, 2, 1)]
    #[case::imin("imin_u", 2, 1, 1)]
    #[case::imin("imin_u", 2, 2, 2)]
    #[case::ishl("ishl", 73, 1, 146)]
    // The top bit is shifted out.
    #[case::ishl("ishl", 18446744073709551615, 1, 18446744073709551614)]
    #[case::ishr("ishr", 73, 1, 36)]
    #[case::ior("ior", 6, 9, 15)]
    #[case::ior("ior", 6, 15, 15)]
    #[case::ixor("ixor", 6, 9, 15)]
    #[case::ixor("ixor", 6, 15, 9)]
    #[case::ixor("ixor", 15, 6, 9)]
    #[case::iand("iand", 6, 15, 6)]
    #[case::iand("iand", 15, 6, 6)]
    #[case::iand("iand", 15, 15, 15)]
    #[case::ipow("ipow", 2, 3, 8)]
    #[case::ipow("ipow", 42, 1, 42)]
    #[case::ipow("ipow", 42, 0, 1)]
    fn test_exec_unsigned_bin_op(
        mut exec_ctx: TestContext,
        #[case] op: String,
        #[case] lhs: u64,
        #[case] rhs: u64,
        #[case] result: u64,
    ) {
        exec_ctx.add_extensions(add_int_extensions);
        let ty = &INT_TYPES[6].clone();
        let inputs = [
            ConstInt::new_u(6, lhs).unwrap(),
            ConstInt::new_u(6, rhs).unwrap(),
        ];
        let ext_op = make_int_op(&op, 6);

        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
        assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
    }

    // Executes a 64-bit binary op on constant signed inputs and checks the i64 result.
    #[rstest]
    #[case::imax("imax_s", 1, 2, 2)]
    #[case::imax("imax_s", 2, 1, 2)]
    #[case::imax("imax_s", 2, 2, 2)]
    #[case::imax("imax_s", -1, -2, -1)]
    #[case::imax("imax_s", -2, -1, -1)]
    #[case::imax("imax_s", -2, -2, -2)]
    #[case::imin("imin_s", 1, 2, 1)]
    #[case::imin("imin_s", 2, 1, 1)]
    #[case::imin("imin_s", 2, 2, 2)]
    #[case::imin("imin_s", -1, -2, -2)]
    #[case::imin("imin_s", -2, -1, -2)]
    #[case::imin("imin_s", -2, -2, -2)]
    #[case::ipow("ipow", -2, 1, -2)]
    #[case::ipow("ipow", -2, 2, 4)]
    #[case::ipow("ipow", -2, 3, -8)]
    fn test_exec_signed_bin_op(
        mut exec_ctx: TestContext,
        #[case] op: String,
        #[case] lhs: i64,
        #[case] rhs: i64,
        #[case] result: i64,
    ) {
        exec_ctx.add_extensions(add_int_extensions);
        let ty = &INT_TYPES[6].clone();
        let inputs = [
            ConstInt::new_s(6, lhs).unwrap(),
            ConstInt::new_s(6, rhs).unwrap(),
        ];
        let ext_op = make_int_op(&op, 6);

        let hugr = test_int_op_with_results::<2>(ext_op, 6, Some(inputs), ty.clone());
        assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
    }

    // Executes a 64-bit unary op on a constant signed input and checks the i64 result.
    #[rstest]
    #[case::iabs("iabs", 42, 42)]
    #[case::iabs("iabs", -42, 42)]
    fn test_exec_signed_unary_op(
        mut exec_ctx: TestContext,
        #[case] op: String,
        #[case] arg: i64,
        #[case] result: i64,
    ) {
        exec_ctx.add_extensions(add_int_extensions);
        let input = ConstInt::new_s(6, arg).unwrap();
        let ty = INT_TYPES[6].clone();
        let ext_op = make_int_op(&op, 6);

        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
        assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), result);
    }

    // Executes a 64-bit unary op on a constant unsigned input and checks the u64 result.
    #[rstest]
    #[case::inot("inot", 9223372036854775808, !9223372036854775808u64)]
    #[case::inot("inot", 42, !42u64)]
    #[case::inot("inot", !0u64, 0)]
    fn test_exec_unsigned_unary_op(
        mut exec_ctx: TestContext,
        #[case] op: String,
        #[case] arg: u64,
        #[case] result: u64,
    ) {
        exec_ctx.add_extensions(add_int_extensions);
        let input = ConstInt::new_u(6, arg).unwrap();
        let ty = INT_TYPES[6].clone();
        let ext_op = make_int_op(&op, 6);

        let hugr = test_int_op_with_results::<1>(ext_op, 6, Some([input]), ty.clone());
        assert_eq!(exec_ctx.exec_hugr_u64(hugr, "main"), result);
    }

    // Narrows `arg` (an `int<from>` constant) to `int<to>`, then unpacks the sum result with
    // a Conditional:
    // * case 0 (error) panics with "This shouldn't happen";
    // * case 1 forwards the narrowed value.
    // Every case uses an in-range `arg`, so the program must return `arg` unchanged.
    #[rstest]
    #[case("inarrow_s", 6, 2, 4)]
    #[case("inarrow_s", 6, 5, (1 << 5) - 1)]
    #[case("inarrow_s", 6, 4, -1)]
    #[case("inarrow_s", 6, 4, -(1 << 4) - 1)]
    #[case("inarrow_s", 6, 4, -(1 << 15))]
    #[case("inarrow_s", 6, 5, (1 << 31) - 1)]
    fn test_narrow_s(
        mut exec_ctx: TestContext,
        #[case] op: String,
        #[case] from: u8,
        #[case] to: u8,
        #[case] arg: i64,
    ) {
        exec_ctx.add_extensions(add_int_extensions);
        exec_ctx.add_extensions(|cem| add_prelude_extensions(cem, DefaultPreludeCodegen));
        let input = ConstInt::new_s(from, arg).unwrap();
        let to_ty = INT_TYPES[to as usize].clone();
        let ext_op = int_ops::EXTENSION
            .instantiate_extension_op(op.as_ref(), [(from as u64).into(), (to as u64).into()])
            .unwrap();

        let hugr = test_int_op_with_results_processing::<1>(
            ext_op,
            to,
            Some([input]),
            to_ty.clone(),
            |builder, outs| {
                let [out] = outs.to_array();

                let err_row = TypeRow::from(vec![error_type()]);
                let ty_row = TypeRow::from(vec![to_ty.clone()]);
                let mut cond_b = builder.conditional_builder(
                    ([err_row, ty_row], out),
                    [],
                    vec![to_ty.clone()].into(),
                )?;
                let mut sad_b = cond_b.case_builder(0)?;
                let err = ConstError::new(2, "This shouldn't happen");
                let w = sad_b.add_load_value(ConstInt::new_s(to, 0)?);
                sad_b.add_panic(err, vec![to_ty.clone()], [(w, to_ty.clone())])?;
                sad_b.finish_with_outputs([w])?;

                let happy_b = cond_b.case_builder(1)?;
                let [w] = happy_b.input_wires_arr();
                happy_b.finish_with_outputs([w])?;

                let handle = cond_b.finish_sub_container()?;
                Ok(handle.outputs())
            },
        );
        assert_eq!(exec_ctx.exec_hugr_i64(hugr, "main"), arg);
    }

    // Converts an in-range unsigned constant with `iu_to_s` and checks the value is unchanged.
    #[rstest]
    #[case(6, 42)]
    #[case(4, 7)]
    fn test_u_to_s(mut exec_ctx: TestContext, #[case] log_width: u8, #[case] val: u64) {
        exec_ctx.add_extensions(add_int_extensions);
        let ty = &INT_TYPES[log_width as usize].clone();
        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![ty.clone()])
            .with_extensions(STD_REG.clone())
            .finish(|mut hugr_builder| {
                let unsigned =
                    hugr_builder.add_load_value(ConstInt::new_u(log_width, val).unwrap());
                let iu_to_s = make_int_op("iu_to_s", log_width);
                let [signed] = hugr_builder
                    .add_dataflow_op(iu_to_s, [unsigned])
                    .unwrap()
                    .outputs_arr();
                hugr_builder.finish_with_outputs([signed]).unwrap()
            });
        let act = exec_ctx.exec_hugr_i64(hugr, "main");
        assert_eq!(act, val as i64);
    }

    // Converts a non-negative signed constant with `is_to_u` and checks the value is unchanged.
    #[rstest]
    #[case(3, 0)]
    #[case(4, 255)]
    fn test_s_to_u(mut exec_ctx: TestContext, #[case] log_width: u8, #[case] val: i64) {
        exec_ctx.add_extensions(add_int_extensions);
        let ty = &INT_TYPES[log_width as usize].clone();
        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![ty.clone()])
            .with_extensions(STD_REG.clone())
            .finish(|mut hugr_builder| {
                let signed = hugr_builder.add_load_value(ConstInt::new_s(log_width, val).unwrap());
                let is_to_u = make_int_op("is_to_u", log_width);
                let [unsigned] = hugr_builder
                    .add_dataflow_op(is_to_u, [signed])
                    .unwrap()
                    .outputs_arr();
                hugr_builder.finish_with_outputs([unsigned]).unwrap()
            });
        let act = exec_ctx.exec_hugr_u64(hugr, "main");
        assert_eq!(act, val as u64);
    }
}