hugr_llvm/extension/collections/
array.rs

//! Codegen for the array types, constants and operations of
//! [`hugr_core::std_extensions::collections::array`].
//!
//! An `array<n, T>` is lowered to a fat pointer `{ptr, usize}` that is allocated
//! to at least `n * sizeof(T)` bytes. The extra `usize` is an offset pointing to the
//! first element, i.e. the first element is at address `ptr + offset * sizeof(T)`.
//!
//! The rationale behind the additional offset is the `pop_left` operation, which bumps
//! the offset instead of mutating the pointer. This way, we can still free the original
//! pointer when the array is discarded after a pop.
//!
//! We provide utility functions [`array_fat_pointer_ty`], [`build_array_fat_pointer`], and
//! [`decompose_array_fat_pointer`] to work with array fat pointers.
//!
//! The [`DefaultArrayCodegen`] extension allocates all arrays on the heap using the
//! standard libc `malloc` and `free` functions. This behaviour can be customised
//! by providing a different implementation for [`ArrayCodegen::emit_allocate_array`]
//! and [`ArrayCodegen::emit_free_array`].
18use std::iter;
19
20use anyhow::{Ok, Result, anyhow};
21use hugr_core::extension::prelude::{option_type, usize_t};
22use hugr_core::extension::simple_op::{MakeExtensionOp, MakeRegisteredOp};
23use hugr_core::ops::DataflowOpTrait;
24use hugr_core::std_extensions::collections::array::{
25    self, ArrayClone, ArrayDiscard, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan, array_type,
26};
27use hugr_core::types::{TypeArg, TypeEnum};
28use hugr_core::{HugrView, Node};
29use inkwell::builder::Builder;
30use inkwell::intrinsics::Intrinsic;
31use inkwell::types::{BasicType, BasicTypeEnum, IntType, StructType};
32use inkwell::values::{
33    BasicValue as _, BasicValueEnum, CallableValue, IntValue, PointerValue, StructValue,
34};
35use inkwell::{AddressSpace, IntPredicate};
36use itertools::Itertools;
37
38use crate::emit::emit_value;
39use crate::emit::libc::{emit_libc_free, emit_libc_malloc};
40use crate::{CodegenExtension, CodegenExtsBuilder};
41use crate::{
42    emit::{EmitFuncContext, RowPromise, deaggregate_call_result},
43    types::{HugrType, TypingSession},
44};
45
46impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
47    /// Add a [`ArrayCodegenExtension`] to the given [`CodegenExtsBuilder`] using `ccg`
48    /// as the implementation.
49    #[must_use]
50    pub fn add_default_array_extensions(self) -> Self {
51        self.add_array_extensions(DefaultArrayCodegen)
52    }
53
54    /// Add a [`ArrayCodegenExtension`] to the given [`CodegenExtsBuilder`] using
55    /// [`DefaultArrayCodegen`] as the implementation.
56    pub fn add_array_extensions(self, ccg: impl ArrayCodegen + 'a) -> Self {
57        self.add_extension(ArrayCodegenExtension::from(ccg))
58    }
59}
60
61/// A helper trait for customising the lowering of [`hugr_core::std_extensions::collections::array`]
62/// types, [`hugr_core::ops::constant::CustomConst`]s, and ops.
63///
64/// An `array<n, T>` is now lowered to a fat pointer `{ptr, usize}` that is allocated
65/// to at least `n * sizeof(T)` bytes. The extra `usize` is an offset pointing to the
66/// first element, i.e. the first element is at address `ptr + offset * sizeof(T)`.
67///
68/// The rational behind the additional offset is the `pop_left` operation which bumps
69/// the offset instead of mutating the pointer. This way, we can still free the original
70/// pointer when the array is discarded after a pop.
71///
72/// By default, all arrays are allocated on the heap using the standard libc `malloc`
73/// and `free` functions. This behaviour can be customised by providing a different
74/// implementation for [`ArrayCodegen::emit_allocate_array`] and
75/// [`ArrayCodegen::emit_free_array`].
76pub trait ArrayCodegen: Clone {
77    /// Emit an allocation of `size` bytes and return the corresponding pointer.
78    ///
79    /// The default implementation allocates on the heap by emitting a call to the
80    /// standard libc `malloc` function.
81    fn emit_allocate_array<'c, H: HugrView<Node = Node>>(
82        &self,
83        ctx: &mut EmitFuncContext<'c, '_, H>,
84        size: IntValue<'c>,
85    ) -> Result<PointerValue<'c>> {
86        let ptr = emit_libc_malloc(ctx, size.into())?;
87        Ok(ptr.into_pointer_value())
88    }
89
90    /// Emit an deallocation of a pointer.
91    ///
92    /// The default implementation emits a call to the standard libc `free` function.
93    fn emit_free_array<'c, H: HugrView<Node = Node>>(
94        &self,
95        ctx: &mut EmitFuncContext<'c, '_, H>,
96        ptr: PointerValue<'c>,
97    ) -> Result<()> {
98        emit_libc_free(ctx, ptr.into())
99    }
100
101    /// Return the llvm type of [`hugr_core::std_extensions::collections::array::ARRAY_TYPENAME`].
102    fn array_type<'c>(
103        &self,
104        session: &TypingSession<'c, '_>,
105        elem_ty: BasicTypeEnum<'c>,
106        _size: u64,
107    ) -> impl BasicType<'c> {
108        array_fat_pointer_ty(session, elem_ty)
109    }
110
111    /// Emit a [`hugr_core::std_extensions::collections::array::ArrayValue`].
112    fn emit_array_value<'c, H: HugrView<Node = Node>>(
113        &self,
114        ctx: &mut EmitFuncContext<'c, '_, H>,
115        value: &array::ArrayValue,
116    ) -> Result<BasicValueEnum<'c>> {
117        emit_array_value(self, ctx, value)
118    }
119
120    /// Emit a [`hugr_core::std_extensions::collections::array::ArrayOp`].
121    fn emit_array_op<'c, H: HugrView<Node = Node>>(
122        &self,
123        ctx: &mut EmitFuncContext<'c, '_, H>,
124        op: ArrayOp,
125        inputs: Vec<BasicValueEnum<'c>>,
126        outputs: RowPromise<'c>,
127    ) -> Result<()> {
128        emit_array_op(self, ctx, op, inputs, outputs)
129    }
130
131    /// Emit a [`hugr_core::std_extensions::collections::array::ArrayClone`] operation.
132    fn emit_array_clone<'c, H: HugrView<Node = Node>>(
133        &self,
134        ctx: &mut EmitFuncContext<'c, '_, H>,
135        op: ArrayClone,
136        array_v: BasicValueEnum<'c>,
137    ) -> Result<(BasicValueEnum<'c>, BasicValueEnum<'c>)> {
138        emit_clone_op(self, ctx, op, array_v)
139    }
140
141    /// Emit a [`hugr_core::std_extensions::collections::array::ArrayDiscard`] operation.
142    fn emit_array_discard<'c, H: HugrView<Node = Node>>(
143        &self,
144        ctx: &mut EmitFuncContext<'c, '_, H>,
145        op: ArrayDiscard,
146        array_v: BasicValueEnum<'c>,
147    ) -> Result<()> {
148        emit_array_discard(self, ctx, op, array_v)
149    }
150
151    /// Emit a [`hugr_core::std_extensions::collections::array::ArrayRepeat`] op.
152    fn emit_array_repeat<'c, H: HugrView<Node = Node>>(
153        &self,
154        ctx: &mut EmitFuncContext<'c, '_, H>,
155        op: ArrayRepeat,
156        func: BasicValueEnum<'c>,
157    ) -> Result<BasicValueEnum<'c>> {
158        emit_repeat_op(self, ctx, op, func)
159    }
160
161    /// Emit a [`hugr_core::std_extensions::collections::array::ArrayScan`] op.
162    ///
163    /// Returns the resulting array and the final values of the accumulators.
164    fn emit_array_scan<'c, H: HugrView<Node = Node>>(
165        &self,
166        ctx: &mut EmitFuncContext<'c, '_, H>,
167        op: ArrayScan,
168        src_array: BasicValueEnum<'c>,
169        func: BasicValueEnum<'c>,
170        initial_accs: &[BasicValueEnum<'c>],
171    ) -> Result<(BasicValueEnum<'c>, Vec<BasicValueEnum<'c>>)> {
172        emit_scan_op(
173            self,
174            ctx,
175            op,
176            src_array.into_struct_value(),
177            func,
178            initial_accs,
179        )
180    }
181}
182
183/// A trivial implementation of [`ArrayCodegen`] which passes all methods
184/// through to their default implementations.
185#[derive(Default, Clone)]
186pub struct DefaultArrayCodegen;
187
188impl ArrayCodegen for DefaultArrayCodegen {}
189
190#[derive(Clone, Debug, Default)]
191pub struct ArrayCodegenExtension<CCG>(CCG);
192
193impl<CCG: ArrayCodegen> ArrayCodegenExtension<CCG> {
194    pub fn new(ccg: CCG) -> Self {
195        Self(ccg)
196    }
197}
198
199impl<CCG: ArrayCodegen> From<CCG> for ArrayCodegenExtension<CCG> {
200    fn from(ccg: CCG) -> Self {
201        Self::new(ccg)
202    }
203}
204
205impl<CCG: ArrayCodegen> CodegenExtension for ArrayCodegenExtension<CCG> {
206    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
207        self,
208        builder: CodegenExtsBuilder<'a, H>,
209    ) -> CodegenExtsBuilder<'a, H>
210    where
211        Self: 'a,
212    {
213        builder
214            .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), {
215                let ccg = self.0.clone();
216                move |ts, hugr_type| {
217                    let [TypeArg::BoundedNat(n), TypeArg::Runtime(ty)] = hugr_type.args() else {
218                        return Err(anyhow!("Invalid type args for array type"));
219                    };
220                    let elem_ty = ts.llvm_type(ty)?;
221                    Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum())
222                }
223            })
224            .custom_const::<array::ArrayValue>({
225                let ccg = self.0.clone();
226                move |context, k| ccg.emit_array_value(context, k)
227            })
228            .simple_extension_op::<ArrayOpDef>({
229                let ccg = self.0.clone();
230                move |context, args, _| {
231                    ccg.emit_array_op(
232                        context,
233                        ArrayOp::from_extension_op(args.node().as_ref())?,
234                        args.inputs,
235                        args.outputs,
236                    )
237                }
238            })
239            .extension_op(array::EXTENSION_ID, array::ARRAY_CLONE_OP_ID, {
240                let ccg = self.0.clone();
241                move |context, args| {
242                    let arr = args.inputs[0];
243                    let op = ArrayClone::from_extension_op(args.node().as_ref())?;
244                    let (arr1, arr2) = ccg.emit_array_clone(context, op, arr)?;
245                    args.outputs.finish(context.builder(), [arr1, arr2])
246                }
247            })
248            .extension_op(array::EXTENSION_ID, array::ARRAY_DISCARD_OP_ID, {
249                let ccg = self.0.clone();
250                move |context, args| {
251                    let arr = args.inputs[0];
252                    let op = ArrayDiscard::from_extension_op(args.node().as_ref())?;
253                    ccg.emit_array_discard(context, op, arr)?;
254                    args.outputs.finish(context.builder(), [])
255                }
256            })
257            .extension_op(array::EXTENSION_ID, array::ARRAY_REPEAT_OP_ID, {
258                let ccg = self.0.clone();
259                move |context, args| {
260                    let func = args.inputs[0];
261                    let op = ArrayRepeat::from_extension_op(args.node().as_ref())?;
262                    let arr = ccg.emit_array_repeat(context, op, func)?;
263                    args.outputs.finish(context.builder(), [arr])
264                }
265            })
266            .extension_op(array::EXTENSION_ID, array::ARRAY_SCAN_OP_ID, {
267                let ccg = self.0.clone();
268                move |context, args| {
269                    let src_array = args.inputs[0];
270                    let func = args.inputs[1];
271                    let initial_accs = &args.inputs[2..];
272                    let op = ArrayScan::from_extension_op(args.node().as_ref())?;
273                    let (tgt_array, final_accs) =
274                        ccg.emit_array_scan(context, op, src_array, func, initial_accs)?;
275                    args.outputs
276                        .finish(context.builder(), iter::once(tgt_array).chain(final_accs))
277                }
278            })
279    }
280}
281
282fn usize_ty<'c>(ts: &TypingSession<'c, '_>) -> IntType<'c> {
283    ts.llvm_type(&usize_t())
284        .expect("Prelude codegen is registered")
285        .into_int_type()
286}
287
288/// Returns the LLVM representation of an array value as a fat pointer.
289#[must_use]
290pub fn array_fat_pointer_ty<'c>(
291    session: &TypingSession<'c, '_>,
292    elem_ty: BasicTypeEnum<'c>,
293) -> StructType<'c> {
294    let iw_ctx = session.iw_context();
295    iw_ctx.struct_type(
296        &[
297            elem_ty.ptr_type(AddressSpace::default()).into(),
298            usize_ty(session).into(),
299        ],
300        false,
301    )
302}
303
304/// Constructs an array fat pointer value.
305pub fn build_array_fat_pointer<'c, H: HugrView<Node = Node>>(
306    ctx: &mut EmitFuncContext<'c, '_, H>,
307    ptr: PointerValue<'c>,
308    offset: IntValue<'c>,
309) -> Result<StructValue<'c>> {
310    let array_ty = array_fat_pointer_ty(
311        &ctx.typing_session(),
312        ptr.get_type().get_element_type().try_into().unwrap(),
313    );
314    let array_v = array_ty.get_poison();
315    let array_v = ctx
316        .builder()
317        .build_insert_value(array_v, ptr.as_basic_value_enum(), 0, "")?;
318    let array_v = ctx
319        .builder()
320        .build_insert_value(array_v, offset.as_basic_value_enum(), 1, "")?;
321    Ok(array_v.into_struct_value())
322}
323
324/// Returns the underlying pointer and offset stored in a fat array pointer.
325pub fn decompose_array_fat_pointer<'c>(
326    builder: &Builder<'c>,
327    array_v: BasicValueEnum<'c>,
328) -> Result<(PointerValue<'c>, IntValue<'c>)> {
329    let array_v = array_v.into_struct_value();
330    let array_ptr = builder.build_extract_value(array_v, 0, "array_ptr")?;
331    let array_offset = builder.build_extract_value(array_v, 1, "array_offset")?;
332    Ok((
333        array_ptr.into_pointer_value(),
334        array_offset.into_int_value(),
335    ))
336}
337
338/// Helper function to allocate a fat array pointer.
339///
340/// Returns a pointer and a struct: The pointer points to the first element of the array (i.e. it
341/// is of type `elem_ty.ptr_type()`). The struct is the fat pointer of the that stores an additional
342/// offset (initialised to be 0).
343pub fn build_array_alloc<'c, H: HugrView<Node = Node>>(
344    ctx: &mut EmitFuncContext<'c, '_, H>,
345    ccg: &impl ArrayCodegen,
346    elem_ty: BasicTypeEnum<'c>,
347    size: u64,
348) -> Result<(PointerValue<'c>, StructValue<'c>)> {
349    let usize_t = usize_ty(&ctx.typing_session());
350    let length = usize_t.const_int(size, false);
351    let size_value = ctx
352        .builder()
353        .build_int_mul(length, elem_ty.size_of().unwrap(), "")?;
354    let ptr = ccg.emit_allocate_array(ctx, size_value)?;
355    let elem_ptr = ctx
356        .builder()
357        .build_bit_cast(ptr, elem_ty.ptr_type(AddressSpace::default()), "")?
358        .into_pointer_value();
359    let offset = usize_t.const_zero();
360    let array_v = build_array_fat_pointer(ctx, elem_ptr, offset)?;
361    Ok((elem_ptr, array_v))
362}
363
364/// Helper function to build a loop that repeats for a given number of iterations.
365///
366/// The provided closure is called to build the loop body. Afterwards, the builder is positioned at
367/// the end of the loop exit block.
368fn build_loop<'c, T, H: HugrView<Node = Node>>(
369    ctx: &mut EmitFuncContext<'c, '_, H>,
370    iters: IntValue<'c>,
371    go: impl FnOnce(&mut EmitFuncContext<'c, '_, H>, IntValue<'c>) -> Result<T>,
372) -> Result<T> {
373    let builder = ctx.builder();
374    let idx_ty = usize_ty(&ctx.typing_session());
375    let idx_ptr = builder.build_alloca(idx_ty, "")?;
376    builder.build_store(idx_ptr, idx_ty.const_zero())?;
377
378    let exit_block = ctx.new_basic_block("", None);
379
380    let (body_block, val) = ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
381        let idx = ctx.builder().build_load(idx_ptr, "")?.into_int_value();
382        let val = go(ctx, idx)?;
383        let builder = ctx.builder();
384        let inc_idx = builder.build_int_add(idx, idx_ty.const_int(1, false), "")?;
385        builder.build_store(idx_ptr, inc_idx)?;
386        // Branch to the head is built later
387        Ok((bb, val))
388    })?;
389
390    let head_block = ctx.build_positioned_new_block("", Some(body_block), |ctx, bb| {
391        let builder = ctx.builder();
392        let idx = builder.build_load(idx_ptr, "")?.into_int_value();
393        let cmp = builder.build_int_compare(IntPredicate::ULT, idx, iters, "")?;
394        builder.build_conditional_branch(cmp, body_block, exit_block)?;
395        Ok(bb)
396    })?;
397
398    let builder = ctx.builder();
399    builder.build_unconditional_branch(head_block)?;
400    builder.position_at_end(body_block);
401    builder.build_unconditional_branch(head_block)?;
402    ctx.builder().position_at_end(exit_block);
403    Ok(val)
404}
405
406/// Emits an [`array::ArrayValue`].
407pub fn emit_array_value<'c, H: HugrView<Node = Node>>(
408    ccg: &impl ArrayCodegen,
409    ctx: &mut EmitFuncContext<'c, '_, H>,
410    value: &array::ArrayValue,
411) -> Result<BasicValueEnum<'c>> {
412    let ts = ctx.typing_session();
413    let elem_ty = ts.llvm_type(value.get_element_type())?;
414    let (elem_ptr, array_v) =
415        build_array_alloc(ctx, ccg, elem_ty, value.get_contents().len() as u64)?;
416    for (i, v) in value.get_contents().iter().enumerate() {
417        let llvm_v = emit_value(ctx, v)?;
418        let idx = ts.iw_context().i32_type().const_int(i as u64, true);
419        let elem_addr = unsafe { ctx.builder().build_in_bounds_gep(elem_ptr, &[idx], "")? };
420        ctx.builder().build_store(elem_addr, llvm_v)?;
421    }
422    Ok(array_v.into())
423}
424
425/// Emits an [`ArrayOp`].
426pub fn emit_array_op<'c, H: HugrView<Node = Node>>(
427    ccg: &impl ArrayCodegen,
428    ctx: &mut EmitFuncContext<'c, '_, H>,
429    op: ArrayOp,
430    inputs: Vec<BasicValueEnum<'c>>,
431    outputs: RowPromise<'c>,
432) -> Result<()> {
433    let builder = ctx.builder();
434    let ts = ctx.typing_session();
435    let sig = op
436        .clone()
437        .to_extension_op()
438        .unwrap()
439        .signature()
440        .into_owned();
441    let ArrayOp {
442        def,
443        elem_ty: ref hugr_elem_ty,
444        size,
445    } = op;
446    let elem_ty = ts.llvm_type(hugr_elem_ty)?;
447    match def {
448        ArrayOpDef::new_array => {
449            let (elem_ptr, array_v) = build_array_alloc(ctx, ccg, elem_ty, size)?;
450            let usize_t = usize_ty(&ctx.typing_session());
451            for (i, v) in inputs.into_iter().enumerate() {
452                let idx = usize_t.const_int(i as u64, true);
453                let elem_addr = unsafe { ctx.builder().build_in_bounds_gep(elem_ptr, &[idx], "")? };
454                ctx.builder().build_store(elem_addr, v)?;
455            }
456            outputs.finish(ctx.builder(), [array_v.into()])
457        }
458        ArrayOpDef::unpack => {
459            let [array_v] = inputs
460                .try_into()
461                .map_err(|_| anyhow!("ArrayOpDef::unpack expects one argument"))?;
462            let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?;
463
464            let mut result = Vec::with_capacity(size as usize);
465            let usize_t = usize_ty(&ctx.typing_session());
466
467            for i in 0..size {
468                let idx = builder.build_int_add(array_offset, usize_t.const_int(i, false), "")?;
469                let elem_addr = unsafe { builder.build_in_bounds_gep(array_ptr, &[idx], "")? };
470                let elem_v = builder.build_load(elem_addr, "")?;
471                result.push(elem_v);
472            }
473
474            outputs.finish(ctx.builder(), result)
475        }
476        ArrayOpDef::get => {
477            let [array_v, index_v] = inputs
478                .try_into()
479                .map_err(|_| anyhow!("ArrayOpDef::get expects two arguments"))?;
480            let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?;
481            let index_v = index_v.into_int_value();
482            let res_hugr_ty = sig
483                .output()
484                .get(0)
485                .ok_or(anyhow!("ArrayOp::get has no outputs"))?;
486
487            let res_sum_ty = {
488                let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else {
489                    Err(anyhow!("ArrayOp::get output is not a sum type"))?
490                };
491                ts.llvm_sum_type(st.clone())?
492            };
493
494            let exit_rmb = ctx.new_row_mail_box(sig.output.iter(), "")?;
495
496            let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| {
497                outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?;
498                Ok(bb)
499            })?;
500
501            let success_block =
502                ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
503                    let builder = ctx.builder();
504                    // inside `success_block` we know `index_v` to be in bounds
505                    let index_v = builder.build_int_add(index_v, array_offset, "")?;
506                    let elem_addr =
507                        unsafe { builder.build_in_bounds_gep(array_ptr, &[index_v], "")? };
508                    let elem_v = builder.build_load(elem_addr, "")?;
509                    let success_v = res_sum_ty.build_tag(builder, 1, vec![elem_v])?;
510                    exit_rmb.write(ctx.builder(), [success_v.into(), array_v])?;
511                    builder.build_unconditional_branch(exit_block)?;
512                    Ok(bb)
513                })?;
514
515            let failure_block =
516                ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| {
517                    let builder = ctx.builder();
518                    let failure_v = res_sum_ty.build_tag(builder, 0, vec![])?;
519                    exit_rmb.write(ctx.builder(), [failure_v.into(), array_v])?;
520                    builder.build_unconditional_branch(exit_block)?;
521                    Ok(bb)
522                })?;
523
524            let builder = ctx.builder();
525            let is_success = builder.build_int_compare(
526                IntPredicate::ULT,
527                index_v,
528                index_v.get_type().const_int(size, false),
529                "",
530            )?;
531
532            builder.build_conditional_branch(is_success, success_block, failure_block)?;
533            builder.position_at_end(exit_block);
534            Ok(())
535        }
536        ArrayOpDef::set => {
537            let [array_v, index_v, value_v] = inputs
538                .try_into()
539                .map_err(|_| anyhow!("ArrayOpDef::set expects three arguments"))?;
540            let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?;
541            let index_v = index_v.into_int_value();
542
543            let res_hugr_ty = sig
544                .output()
545                .get(0)
546                .ok_or(anyhow!("ArrayOp::set has no outputs"))?;
547
548            let res_sum_ty = {
549                let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else {
550                    Err(anyhow!("ArrayOp::set output is not a sum type"))?
551                };
552                ts.llvm_sum_type(st.clone())?
553            };
554
555            let exit_rmb = ctx.new_row_mail_box([res_hugr_ty], "")?;
556
557            let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| {
558                outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?;
559                Ok(bb)
560            })?;
561
562            let success_block =
563                ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
564                    let builder = ctx.builder();
565                    // inside `success_block` we know `index_v` to be in bounds.
566                    let index_v = builder.build_int_add(index_v, array_offset, "")?;
567                    let elem_addr =
568                        unsafe { builder.build_in_bounds_gep(array_ptr, &[index_v], "")? };
569                    let elem_v = builder.build_load(elem_addr, "")?;
570                    builder.build_store(elem_addr, value_v)?;
571                    let success_v = res_sum_ty.build_tag(builder, 1, vec![elem_v, array_v])?;
572                    exit_rmb.write(ctx.builder(), [success_v.into()])?;
573                    builder.build_unconditional_branch(exit_block)?;
574                    Ok(bb)
575                })?;
576
577            let failure_block =
578                ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| {
579                    let builder = ctx.builder();
580                    let failure_v = res_sum_ty.build_tag(builder, 0, vec![value_v, array_v])?;
581                    exit_rmb.write(ctx.builder(), [failure_v.into()])?;
582                    builder.build_unconditional_branch(exit_block)?;
583                    Ok(bb)
584                })?;
585
586            let builder = ctx.builder();
587            let is_success = builder.build_int_compare(
588                IntPredicate::ULT,
589                index_v,
590                index_v.get_type().const_int(size, false),
591                "",
592            )?;
593            builder.build_conditional_branch(is_success, success_block, failure_block)?;
594            builder.position_at_end(exit_block);
595            Ok(())
596        }
597        ArrayOpDef::swap => {
598            let [array_v, index1_v, index2_v] = inputs
599                .try_into()
600                .map_err(|_| anyhow!("ArrayOpDef::swap expects three arguments"))?;
601            let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?;
602            let index1_v = index1_v.into_int_value();
603            let index2_v = index2_v.into_int_value();
604
605            let res_hugr_ty = sig
606                .output()
607                .get(0)
608                .ok_or(anyhow!("ArrayOp::swap has no outputs"))?;
609
610            let res_sum_ty = {
611                let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else {
612                    Err(anyhow!("ArrayOp::swap output is not a sum type"))?
613                };
614                ts.llvm_sum_type(st.clone())?
615            };
616
617            let exit_rmb = ctx.new_row_mail_box([res_hugr_ty], "")?;
618
619            let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| {
620                outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?;
621                Ok(bb)
622            })?;
623
624            let success_block =
625                ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
626                    // if `index1_v` == `index2_v` then the following is a no-op.
627                    // We could check for this: either with a select instruction
628                    // here, or by branching to another case in earlier.
629                    // Doing so would generate better code in cases where the
630                    // optimiser can determine that the indices are the same, at
631                    // the cost of worse code in cases where it cannot.
632                    // For now we choose the simpler option of omitting the check.
633                    let builder = ctx.builder();
634                    // inside `success_block` we know `index1_v` and `index2_v`
635                    // to be in bounds.
636                    let index1_v = builder.build_int_add(index1_v, array_offset, "")?;
637                    let index2_v = builder.build_int_add(index2_v, array_offset, "")?;
638                    let elem1_addr =
639                        unsafe { builder.build_in_bounds_gep(array_ptr, &[index1_v], "")? };
640                    let elem1_v = builder.build_load(elem1_addr, "")?;
641                    let elem2_addr =
642                        unsafe { builder.build_in_bounds_gep(array_ptr, &[index2_v], "")? };
643                    let elem2_v = builder.build_load(elem2_addr, "")?;
644                    builder.build_store(elem1_addr, elem2_v)?;
645                    builder.build_store(elem2_addr, elem1_v)?;
646                    let success_v = res_sum_ty.build_tag(builder, 1, vec![array_v])?;
647                    exit_rmb.write(ctx.builder(), [success_v.into()])?;
648                    builder.build_unconditional_branch(exit_block)?;
649                    Ok(bb)
650                })?;
651
652            let failure_block =
653                ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| {
654                    let builder = ctx.builder();
655                    let failure_v = res_sum_ty.build_tag(builder, 0, vec![array_v])?;
656                    exit_rmb.write(ctx.builder(), [failure_v.into()])?;
657                    builder.build_unconditional_branch(exit_block)?;
658                    Ok(bb)
659                })?;
660
661            let builder = ctx.builder();
662            let is_success = {
663                let index1_ok = builder.build_int_compare(
664                    IntPredicate::ULT,
665                    index1_v,
666                    index1_v.get_type().const_int(size, false),
667                    "",
668                )?;
669                let index2_ok = builder.build_int_compare(
670                    IntPredicate::ULT,
671                    index2_v,
672                    index2_v.get_type().const_int(size, false),
673                    "",
674                )?;
675                builder.build_and(index1_ok, index2_ok, "")?
676            };
677            builder.build_conditional_branch(is_success, success_block, failure_block)?;
678            builder.position_at_end(exit_block);
679            Ok(())
680        }
681        ArrayOpDef::pop_left => {
682            let [array_v] = inputs
683                .try_into()
684                .map_err(|_| anyhow!("ArrayOpDef::pop_left expects one argument"))?;
685            let r = emit_pop_op(
686                ctx,
687                hugr_elem_ty.clone(),
688                size,
689                array_v.into_struct_value(),
690                true,
691            )?;
692            outputs.finish(ctx.builder(), [r])
693        }
694        ArrayOpDef::pop_right => {
695            let [array_v] = inputs
696                .try_into()
697                .map_err(|_| anyhow!("ArrayOpDef::pop_right expects one argument"))?;
698            let r = emit_pop_op(
699                ctx,
700                hugr_elem_ty.clone(),
701                size,
702                array_v.into_struct_value(),
703                false,
704            )?;
705            outputs.finish(ctx.builder(), [r])
706        }
707        ArrayOpDef::discard_empty => {
708            let [array_v] = inputs
709                .try_into()
710                .map_err(|_| anyhow!("ArrayOpDef::discard_empty expects one argument"))?;
711            let (ptr, _) = decompose_array_fat_pointer(builder, array_v)?;
712            ccg.emit_free_array(ctx, ptr)?;
713            outputs.finish(ctx.builder(), [])
714        }
715        _ => todo!(),
716    }
717}
718
719/// Emits an [`ArrayClone`] op.
720pub fn emit_clone_op<'c, H: HugrView<Node = Node>>(
721    ccg: &impl ArrayCodegen,
722    ctx: &mut EmitFuncContext<'c, '_, H>,
723    op: ArrayClone,
724    array_v: BasicValueEnum<'c>,
725) -> Result<(BasicValueEnum<'c>, BasicValueEnum<'c>)> {
726    let elem_ty = ctx.llvm_type(&op.elem_ty)?;
727    let (array_ptr, array_offset) = decompose_array_fat_pointer(ctx.builder(), array_v)?;
728    let (other_ptr, other_array_v) = build_array_alloc(ctx, ccg, elem_ty, op.size)?;
729    let src_ptr = unsafe {
730        ctx.builder()
731            .build_in_bounds_gep(array_ptr, &[array_offset], "")?
732    };
733    let length = usize_ty(&ctx.typing_session()).const_int(op.size, false);
734    let size_value = ctx
735        .builder()
736        .build_int_mul(length, elem_ty.size_of().unwrap(), "")?;
737    let is_volatile = ctx.iw_context().bool_type().const_zero();
738
739    let memcpy_intrinsic = Intrinsic::find("llvm.memcpy").unwrap();
740    let memcpy = memcpy_intrinsic
741        .get_declaration(
742            ctx.get_current_module(),
743            &[
744                other_ptr.get_type().into(),
745                src_ptr.get_type().into(),
746                size_value.get_type().into(),
747            ],
748        )
749        .unwrap();
750    ctx.builder().build_call(
751        memcpy,
752        &[
753            other_ptr.into(),
754            src_ptr.into(),
755            size_value.into(),
756            is_volatile.into(),
757        ],
758        "",
759    )?;
760    Ok((array_v, other_array_v.into()))
761}
762
763/// Emits an [`ArrayDiscard`] op.
764pub fn emit_array_discard<'c, H: HugrView<Node = Node>>(
765    ccg: &impl ArrayCodegen,
766    ctx: &mut EmitFuncContext<'c, '_, H>,
767    _op: ArrayDiscard,
768    array_v: BasicValueEnum<'c>,
769) -> Result<()> {
770    let array_ptr =
771        ctx.builder()
772            .build_extract_value(array_v.into_struct_value(), 0, "array_ptr")?;
773    ccg.emit_free_array(ctx, array_ptr.into_pointer_value())?;
774    Ok(())
775}
776
777/// Emits the [`ArrayOpDef::pop_left`] and [`ArrayOpDef::pop_right`] operations.
778fn emit_pop_op<'c, H: HugrView<Node = Node>>(
779    ctx: &mut EmitFuncContext<'c, '_, H>,
780    elem_ty: HugrType,
781    size: u64,
782    array_v: StructValue<'c>,
783    pop_left: bool,
784) -> Result<BasicValueEnum<'c>> {
785    let ts = ctx.typing_session();
786    let builder = ctx.builder();
787    let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v.into())?;
788    let ret_ty = ts.llvm_sum_type(option_type(vec![
789        elem_ty.clone(),
790        array_type(size.saturating_add_signed(-1), elem_ty),
791    ]))?;
792    if size == 0 {
793        return Ok(ret_ty.build_tag(builder, 0, vec![])?.into());
794    }
795    let (elem_ptr, new_array_offset) = {
796        if pop_left {
797            let new_array_offset = builder.build_int_add(
798                array_offset,
799                usize_ty(&ts).const_int(1, false),
800                "new_offset",
801            )?;
802            let elem_ptr = unsafe { builder.build_in_bounds_gep(array_ptr, &[array_offset], "") }?;
803            (elem_ptr, new_array_offset)
804        } else {
805            let idx = builder.build_int_add(
806                array_offset,
807                usize_ty(&ts).const_int(size - 1, false),
808                "",
809            )?;
810            let elem_ptr = unsafe { builder.build_in_bounds_gep(array_ptr, &[idx], "") }?;
811            (elem_ptr, array_offset)
812        }
813    };
814    let elem_v = builder.build_load(elem_ptr, "")?;
815    let new_array_v = build_array_fat_pointer(ctx, array_ptr, new_array_offset)?;
816
817    Ok(ret_ty
818        .build_tag(ctx.builder(), 1, vec![elem_v, new_array_v.into()])?
819        .into())
820}
821
822/// Emits an [`ArrayRepeat`] op.
823pub fn emit_repeat_op<'c, H: HugrView<Node = Node>>(
824    ccg: &impl ArrayCodegen,
825    ctx: &mut EmitFuncContext<'c, '_, H>,
826    op: ArrayRepeat,
827    func: BasicValueEnum<'c>,
828) -> Result<BasicValueEnum<'c>> {
829    let elem_ty = ctx.llvm_type(&op.elem_ty)?;
830    let (ptr, array_v) = build_array_alloc(ctx, ccg, elem_ty, op.size)?;
831    let array_len = usize_ty(&ctx.typing_session()).const_int(op.size, false);
832    build_loop(ctx, array_len, |ctx, idx| {
833        let builder = ctx.builder();
834        let func_ptr = CallableValue::try_from(func.into_pointer_value())
835            .map_err(|()| anyhow!("ArrayOpDef::repeat expects a function pointer"))?;
836        let v = builder
837            .build_call(func_ptr, &[], "")?
838            .try_as_basic_value()
839            .left()
840            .ok_or(anyhow!("ArrayOpDef::repeat function must return a value"))?;
841        let elem_addr = unsafe { builder.build_in_bounds_gep(ptr, &[idx], "")? };
842        builder.build_store(elem_addr, v)?;
843        Ok(())
844    })?;
845    Ok(array_v.into())
846}
847
848/// Emits an [`ArrayScan`] op.
849///
850/// Returns the resulting array and the final values of the accumulators.
851pub fn emit_scan_op<'c, H: HugrView<Node = Node>>(
852    ccg: &impl ArrayCodegen,
853    ctx: &mut EmitFuncContext<'c, '_, H>,
854    op: ArrayScan,
855    src_array_v: StructValue<'c>,
856    func: BasicValueEnum<'c>,
857    initial_accs: &[BasicValueEnum<'c>],
858) -> Result<(BasicValueEnum<'c>, Vec<BasicValueEnum<'c>>)> {
859    let (src_ptr, src_offset) = decompose_array_fat_pointer(ctx.builder(), src_array_v.into())?;
860    let tgt_elem_ty = ctx.llvm_type(&op.tgt_ty)?;
861    // TODO: If `sizeof(op.src_ty) >= sizeof(op.tgt_ty)`, we could reuse the memory
862    // from `src` instead of allocating a fresh array
863    let (tgt_ptr, tgt_array_v) = build_array_alloc(ctx, ccg, tgt_elem_ty, op.size)?;
864    let array_len = usize_ty(&ctx.typing_session()).const_int(op.size, false);
865    let acc_tys: Vec<_> = op
866        .acc_tys
867        .iter()
868        .map(|ty| ctx.llvm_type(ty))
869        .try_collect()?;
870    let builder = ctx.builder();
871    let acc_ptrs: Vec<_> = acc_tys
872        .iter()
873        .map(|ty| builder.build_alloca(*ty, ""))
874        .try_collect()?;
875    for (ptr, initial_val) in acc_ptrs.iter().zip(initial_accs) {
876        builder.build_store(*ptr, *initial_val)?;
877    }
878
879    build_loop(ctx, array_len, |ctx, idx| {
880        let builder = ctx.builder();
881        let func_ptr = CallableValue::try_from(func.into_pointer_value())
882            .map_err(|()| anyhow!("ArrayOpDef::scan expects a function pointer"))?;
883        let src_idx = builder.build_int_add(idx, src_offset, "")?;
884        let src_elem_addr = unsafe { builder.build_in_bounds_gep(src_ptr, &[src_idx], "")? };
885        let src_elem = builder.build_load(src_elem_addr, "")?;
886        let mut args = vec![src_elem.into()];
887        for ptr in &acc_ptrs {
888            args.push(builder.build_load(*ptr, "")?.into());
889        }
890        let call = builder.build_call(func_ptr, args.as_slice(), "")?;
891        let call_results = deaggregate_call_result(builder, call, 1 + acc_tys.len())?;
892        let tgt_elem_addr = unsafe { builder.build_in_bounds_gep(tgt_ptr, &[idx], "")? };
893        builder.build_store(tgt_elem_addr, call_results[0])?;
894        for (ptr, next_act) in acc_ptrs.iter().zip(call_results[1..].iter()) {
895            builder.build_store(*ptr, *next_act)?;
896        }
897        Ok(())
898    })?;
899
900    ccg.emit_free_array(ctx, src_ptr)?;
901    let builder = ctx.builder();
902    let final_accs = acc_ptrs
903        .into_iter()
904        .map(|ptr| builder.build_load(ptr, ""))
905        .try_collect()?;
906    Ok((tgt_array_v.into(), final_accs))
907}
908
909#[cfg(test)]
910mod test {
911    use hugr_core::builder::{DataflowHugr, HugrBuilder};
912    use hugr_core::extension::prelude::either_type;
913    use hugr_core::ops::Tag;
914    use hugr_core::std_extensions::STD_REG;
915    use hugr_core::std_extensions::collections::array::op_builder::build_all_array_ops;
916    use hugr_core::std_extensions::collections::array::{
917        self, ArrayOpBuilder, ArrayRepeat, ArrayScan, array_type,
918    };
919    use hugr_core::types::Type;
920    use hugr_core::{
921        builder::{Dataflow, DataflowSubContainer, SubContainer},
922        extension::{
923            ExtensionRegistry,
924            prelude::{self, ConstUsize, UnwrapBuilder as _, bool_t, option_type, usize_t},
925        },
926        ops::Value,
927        std_extensions::{
928            arithmetic::{
929                int_ops::{self},
930                int_types::{self, ConstInt, int_type},
931            },
932            logic,
933        },
934        type_row,
935        types::Signature,
936    };
937    use itertools::Itertools as _;
938    use rstest::rstest;
939
940    use crate::{
941        check_emission,
942        emit::test::SimpleHugrConfig,
943        test::{TestContext, exec_ctx, llvm_ctx},
944        utils::{IntOpBuilder, LogicOpBuilder},
945    };
946
    /// Emission test covering the ops produced by `build_all_array_ops`.
    ///
    /// Builds a nested DFG with `build_all_array_ops`, then lowers the HUGR with the
    /// default prelude and array codegen extensions and checks the result with
    /// `check_emission!`.
    /// NOTE(review): presumably `check_emission!` compares against a stored IR snapshot,
    /// so the order of builder calls here affects the expected output.
    #[rstest]
    fn emit_all_ops(mut llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                build_all_array_ops(builder.dfg_builder_endo([]).unwrap())
                    .finish_sub_container()
                    .unwrap();
                builder.finish_hugr().unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }
963
    /// Emission test for `get`.
    ///
    /// Creates the usize array `[1, 2]`, gets index 1 (ignoring the returned option),
    /// discards the array, and checks the lowered module with `check_emission!`.
    #[rstest]
    fn emit_get(mut llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let us1 = builder.add_load_value(ConstUsize::new(1));
                let us2 = builder.add_load_value(ConstUsize::new(2));
                let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
                let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap();
                builder.add_array_discard(usize_t(), 2, arr).unwrap();
                builder.finish_hugr_with_outputs([]).unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }
982
    /// Emission test for `clone`.
    ///
    /// Creates the usize array `[1, 2]`, clones it, discards both resulting arrays, and
    /// checks the lowered module with `check_emission!`.
    #[rstest]
    fn emit_clone(mut llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let us1 = builder.add_load_value(ConstUsize::new(1));
                let us2 = builder.add_load_value(ConstUsize::new(2));
                let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
                let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap();
                builder.add_array_discard(usize_t(), 2, arr1).unwrap();
                builder.add_array_discard(usize_t(), 2, arr2).unwrap();
                builder.finish_hugr_with_outputs([]).unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }
1002
    /// Emission test for constant arrays.
    ///
    /// Loads a constant `ArrayValue` holding the usizes `[1, 2]`, returns it as the
    /// function's `array<2, usize>` output, and checks the lowered module with
    /// `check_emission!`.
    #[rstest]
    fn emit_array_value(mut llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .with_outs(vec![array_type(2, usize_t())])
            .finish(|mut builder| {
                let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()];
                let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs));
                builder.finish_hugr_with_outputs([arr]).unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }
1019
1020    // #[rstest]
1021    // #[case(1, 2, 3)]
1022    // #[case(0, 0, 0)]
1023    // #[case(10, 20, 30)]
1024    // fn exec_unpack_and_sum(mut exec_ctx: TestContext, #[case] a: u64, #[case] b: u64, #[case] expected: u64) {
1025    //     let hugr = SimpleHugrConfig::new()
1026    //         .with_extensions(exec_registry())
1027    //         .with_outs(vec![usize_t()])
1028    //         .finish(|mut builder| {
1029    //             // Create an array with the test values
1030    //             let values = vec![ConstUsize::new(a).into(), ConstUsize::new(b).into()];
1031    //             let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), values));
1032
1033    //             // Unpack the array
1034    //             let [val_a, val_b] = builder.add_array_unpack(usize_t(), 2, arr).unwrap().try_into().unwrap();
1035
1036    //             // Add the values
1037    //             let sum = {
1038    //                 let int_ty = int_type(6);
1039    //                 let a_int = builder.cast(val_a, int_ty.clone()).unwrap();
1040    //                 let b_int = builder.cast(val_b, int_ty.clone()).unwrap();
1041    //                 let sum_int = builder.add_iadd(6, a_int, b_int).unwrap();
1042    //                 builder.cast(sum_int, usize_t()).unwrap()
1043    //             };
1044
1045    //             builder.finish_hugr_with_outputs([sum]).unwrap()
1046    //         });
1047    //     exec_ctx.add_extensions(|cge| {
1048    //         cge.add_default_prelude_extensions()
1049    //             .add_default_array_extensions()
1050    //             .add_default_int_extensions()
1051    //     });
1052    //     assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1053    // }
1054
1055    fn exec_registry() -> ExtensionRegistry {
1056        ExtensionRegistry::new([
1057            int_types::EXTENSION.to_owned(),
1058            int_ops::EXTENSION.to_owned(),
1059            logic::EXTENSION.to_owned(),
1060            prelude::PRELUDE.to_owned(),
1061            array::EXTENSION.to_owned(),
1062        ])
1063    }
1064
1065    #[rstest]
1066    #[case(0, 1)]
1067    #[case(1, 2)]
1068    #[case(3, 0)]
1069    #[case(999999, 0)]
1070    fn exec_get(mut exec_ctx: TestContext, #[case] index: u64, #[case] expected: u64) {
1071        // We build a HUGR that:
1072        // - Creates an array of [1,2]
1073        // - Gets the element at the given index
1074        // - Returns the element if the index is in bounds, otherwise 0
1075        let hugr = SimpleHugrConfig::new()
1076            .with_outs(usize_t())
1077            .with_extensions(exec_registry())
1078            .finish(|mut builder| {
1079                let us0 = builder.add_load_value(ConstUsize::new(0));
1080                let us1 = builder.add_load_value(ConstUsize::new(1));
1081                let us2 = builder.add_load_value(ConstUsize::new(2));
1082                let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
1083                let i = builder.add_load_value(ConstUsize::new(index));
1084                let (get_r, arr) = builder.add_array_get(usize_t(), 2, arr, i).unwrap();
1085                builder.add_array_discard(usize_t(), 2, arr).unwrap();
1086                let r = {
1087                    let ot = option_type(usize_t());
1088                    let variants = (0..ot.num_variants())
1089                        .map(|i| ot.get_variant(i).cloned().unwrap().try_into().unwrap())
1090                        .collect_vec();
1091                    let mut builder = builder
1092                        .conditional_builder((variants, get_r), [], usize_t().into())
1093                        .unwrap();
1094                    {
1095                        let failure_case = builder.case_builder(0).unwrap();
1096                        failure_case.finish_with_outputs([us0]).unwrap();
1097                    }
1098                    {
1099                        let success_case = builder.case_builder(1).unwrap();
1100                        let inputs = success_case.input_wires();
1101                        success_case.finish_with_outputs(inputs).unwrap();
1102                    }
1103                    builder.finish_sub_container().unwrap().out_wire(0)
1104                };
1105                builder.finish_hugr_with_outputs([r]).unwrap()
1106            });
1107        exec_ctx.add_extensions(|cge| {
1108            cge.add_default_prelude_extensions()
1109                .add_default_array_extensions()
1110        });
1111        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1112    }
1113
1114    #[rstest]
1115    #[case(0, 3, 1, [3,2])]
1116    #[case(1, 3, 2, [1,3])]
1117    #[case(2, 3, 3, [1,2])]
1118    #[case(999999, 3, 3, [1,2])]
1119    fn exec_set(
1120        mut exec_ctx: TestContext,
1121        #[case] index: u64,
1122        #[case] value: u64,
1123        #[case] expected_elem: u64,
1124        #[case] expected_arr: [u64; 2],
1125    ) {
1126        // We build a HUGR that
1127        // - Creates an array: [1,2]
1128        // - Sets the element at the given index to the given value
1129        // - Checks the following, returning 1 iff they are all true:
1130        //   - The element returned from set is `expected_elem`
1131        //   - The Oth element of the resulting array is `expected_arr_0`
1132
1133        use hugr_core::extension::prelude::either_type;
1134        let int_ty = int_type(3);
1135        let hugr = SimpleHugrConfig::new()
1136            .with_outs(usize_t())
1137            .with_extensions(exec_registry())
1138            .finish(|mut builder| {
1139                let us0 = builder.add_load_value(ConstUsize::new(0));
1140                let us1 = builder.add_load_value(ConstUsize::new(1));
1141                let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
1142                let i2 = builder.add_load_value(ConstInt::new_u(3, 2).unwrap());
1143                let arr = builder.add_new_array(int_ty.clone(), [i1, i2]).unwrap();
1144                let index = builder.add_load_value(ConstUsize::new(index));
1145                let value = builder.add_load_value(ConstInt::new_u(3, value).unwrap());
1146                let get_r = builder
1147                    .add_array_set(int_ty.clone(), 2, arr, index, value)
1148                    .unwrap();
1149                let r = {
1150                    let res_sum_ty = {
1151                        let row = vec![int_ty.clone(), array_type(2, int_ty.clone())];
1152                        either_type(row.clone(), row)
1153                    };
1154                    let variants = (0..res_sum_ty.num_variants())
1155                        .map(|i| {
1156                            res_sum_ty
1157                                .get_variant(i)
1158                                .cloned()
1159                                .unwrap()
1160                                .try_into()
1161                                .unwrap()
1162                        })
1163                        .collect_vec();
1164                    let mut builder = builder
1165                        .conditional_builder((variants, get_r), [], bool_t().into())
1166                        .unwrap();
1167                    for i in 0..2 {
1168                        let mut builder = builder.case_builder(i).unwrap();
1169                        let [elem, arr] = builder.input_wires_arr();
1170                        let expected_elem =
1171                            builder.add_load_value(ConstInt::new_u(3, expected_elem).unwrap());
1172                        let expected_arr_0 =
1173                            builder.add_load_value(ConstInt::new_u(3, expected_arr[0]).unwrap());
1174                        let expected_arr_1 =
1175                            builder.add_load_value(ConstInt::new_u(3, expected_arr[1]).unwrap());
1176                        let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap();
1177                        let [arr_0] = builder
1178                            .build_unwrap_sum(1, option_type(int_ty.clone()), r)
1179                            .unwrap();
1180                        let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap();
1181                        let [arr_1] = builder
1182                            .build_unwrap_sum(1, option_type(int_ty.clone()), r)
1183                            .unwrap();
1184                        let elem_eq = builder.add_ieq(3, elem, expected_elem).unwrap();
1185                        let arr_0_eq = builder.add_ieq(3, arr_0, expected_arr_0).unwrap();
1186                        let arr_1_eq = builder.add_ieq(3, arr_1, expected_arr_1).unwrap();
1187                        let r = builder.add_and(elem_eq, arr_0_eq).unwrap();
1188                        let r = builder.add_and(r, arr_1_eq).unwrap();
1189                        builder.add_array_discard(int_ty.clone(), 2, arr).unwrap();
1190                        builder.finish_with_outputs([r]).unwrap();
1191                    }
1192                    builder.finish_sub_container().unwrap().out_wire(0)
1193                };
1194                let r = {
1195                    let mut conditional = builder
1196                        .conditional_builder(([type_row![], type_row![]], r), [], usize_t().into())
1197                        .unwrap();
1198                    conditional
1199                        .case_builder(0)
1200                        .unwrap()
1201                        .finish_with_outputs([us0])
1202                        .unwrap();
1203                    conditional
1204                        .case_builder(1)
1205                        .unwrap()
1206                        .finish_with_outputs([us1])
1207                        .unwrap();
1208                    conditional.finish_sub_container().unwrap().out_wire(0)
1209                };
1210                builder.finish_hugr_with_outputs([r]).unwrap()
1211            });
1212        exec_ctx.add_extensions(|cge| {
1213            cge.add_default_prelude_extensions()
1214                .add_default_array_extensions()
1215                .add_default_int_extensions()
1216                .add_logic_extensions()
1217        });
1218        assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
1219    }
1220
1221    #[rstest]
1222    #[case(0, 1, [2,1], true)]
1223    #[case(0, 0, [1,2], true)]
1224    #[case(0, 2, [1,2], false)]
1225    #[case(2, 0, [1,2], false)]
1226    #[case(9999999, 0, [1,2], false)]
1227    #[case(0, 9999999, [1,2], false)]
1228    fn exec_swap(
1229        mut exec_ctx: TestContext,
1230        #[case] index1: u64,
1231        #[case] index2: u64,
1232        #[case] expected_arr: [u64; 2],
1233        #[case] expected_succeeded: bool,
1234    ) {
1235        // We build a HUGR that:
1236        // - Creates an array: [1 ,2]
1237        // - Swaps the elements at the given indices
1238        // - Checks the following, returning 1 iff the following are all true:
1239        //  - The element at index 0 is `expected_elem_0`
1240        //  - The swap operation succeeded iff `expected_succeeded`
1241
1242        let int_ty = int_type(3);
1243        let arr_ty = array_type(2, int_ty.clone());
1244        let hugr = SimpleHugrConfig::new()
1245            .with_outs(usize_t())
1246            .with_extensions(exec_registry())
1247            .finish(|mut builder| {
1248                let us0 = builder.add_load_value(ConstUsize::new(0));
1249                let us1 = builder.add_load_value(ConstUsize::new(1));
1250                let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
1251                let i2 = builder.add_load_value(ConstInt::new_u(3, 2).unwrap());
1252                let arr = builder.add_new_array(int_ty.clone(), [i1, i2]).unwrap();
1253
1254                let index1 = builder.add_load_value(ConstUsize::new(index1));
1255                let index2 = builder.add_load_value(ConstUsize::new(index2));
1256                let r = builder
1257                    .add_array_swap(int_ty.clone(), 2, arr, index1, index2)
1258                    .unwrap();
1259                let [arr, was_expected_success] = {
1260                    let mut conditional = builder
1261                        .conditional_builder(
1262                            (
1263                                [vec![arr_ty.clone()].into(), vec![arr_ty.clone()].into()],
1264                                r,
1265                            ),
1266                            [],
1267                            vec![arr_ty, bool_t()].into(),
1268                        )
1269                        .unwrap();
1270                    for i in 0..2 {
1271                        let mut case = conditional.case_builder(i).unwrap();
1272                        let [arr] = case.input_wires_arr();
1273                        let was_expected_success =
1274                            case.add_load_value(if (i == 1) == expected_succeeded {
1275                                Value::true_val()
1276                            } else {
1277                                Value::false_val()
1278                            });
1279                        case.finish_with_outputs([arr, was_expected_success])
1280                            .unwrap();
1281                    }
1282                    conditional.finish_sub_container().unwrap().outputs_arr()
1283                };
1284                let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap();
1285                let elem_0 = builder
1286                    .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r)
1287                    .unwrap()[0];
1288                let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap();
1289                let elem_1 = builder
1290                    .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r)
1291                    .unwrap()[0];
1292                let expected_elem_0 =
1293                    builder.add_load_value(ConstInt::new_u(3, expected_arr[0]).unwrap());
1294                let elem_0_ok = builder.add_ieq(3, elem_0, expected_elem_0).unwrap();
1295                let expected_elem_1 =
1296                    builder.add_load_value(ConstInt::new_u(3, expected_arr[1]).unwrap());
1297                let elem_1_ok = builder.add_ieq(3, elem_1, expected_elem_1).unwrap();
1298                let r = builder.add_and(was_expected_success, elem_0_ok).unwrap();
1299                let r = builder.add_and(r, elem_1_ok).unwrap();
1300                let r = {
1301                    let mut conditional = builder
1302                        .conditional_builder(([type_row![], type_row![]], r), [], usize_t().into())
1303                        .unwrap();
1304                    conditional
1305                        .case_builder(0)
1306                        .unwrap()
1307                        .finish_with_outputs([us0])
1308                        .unwrap();
1309                    conditional
1310                        .case_builder(1)
1311                        .unwrap()
1312                        .finish_with_outputs([us1])
1313                        .unwrap();
1314                    conditional.finish_sub_container().unwrap().out_wire(0)
1315                };
1316                builder.add_array_discard(int_ty.clone(), 2, arr).unwrap();
1317                builder.finish_hugr_with_outputs([r]).unwrap()
1318            });
1319        exec_ctx.add_extensions(|cge| {
1320            cge.add_default_prelude_extensions()
1321                .add_default_array_extensions()
1322                .add_default_int_extensions()
1323                .add_logic_extensions()
1324        });
1325        assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
1326    }
1327
1328    #[rstest]
1329    #[case(0, 5)]
1330    #[case(1, 5)]
1331    fn exec_clone(mut exec_ctx: TestContext, #[case] index: u64, #[case] new_v: u64) {
1332        // We build a HUGR that:
1333        // - Creates an array: [1, 2]
1334        // - Clones the array
1335        // - Mutates the original at the given index
1336        // - Returns the unchanged element of the cloned array
1337
1338        let int_ty = int_type(3);
1339        let arr_ty = array_type(2, int_ty.clone());
1340        let hugr = SimpleHugrConfig::new()
1341            .with_outs(int_ty.clone())
1342            .with_extensions(exec_registry())
1343            .finish(|mut builder| {
1344                let idx = builder.add_load_value(ConstUsize::new(index));
1345                let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
1346                let i2 = builder.add_load_value(ConstInt::new_u(3, 2).unwrap());
1347                let inew = builder.add_load_value(ConstInt::new_u(3, new_v).unwrap());
1348                let arr = builder.add_new_array(int_ty.clone(), [i1, i2]).unwrap();
1349
1350                let (arr, arr_clone) = builder.add_array_clone(int_ty.clone(), 2, arr).unwrap();
1351                let r = builder
1352                    .add_array_set(int_ty.clone(), 2, arr, idx, inew)
1353                    .unwrap();
1354                let [_, arr] = builder
1355                    .build_unwrap_sum(
1356                        1,
1357                        either_type(
1358                            vec![int_ty.clone(), arr_ty.clone()],
1359                            vec![int_ty.clone(), arr_ty.clone()],
1360                        ),
1361                        r,
1362                    )
1363                    .unwrap();
1364                let (r, arr_clone) = builder
1365                    .add_array_get(int_ty.clone(), 2, arr_clone, idx)
1366                    .unwrap();
1367                let [elem] = builder
1368                    .build_unwrap_sum(1, option_type(int_ty.clone()), r)
1369                    .unwrap();
1370                builder.add_array_discard(int_ty.clone(), 2, arr).unwrap();
1371                builder
1372                    .add_array_discard(int_ty.clone(), 2, arr_clone)
1373                    .unwrap();
1374                builder.finish_hugr_with_outputs([elem]).unwrap()
1375            });
1376        exec_ctx.add_extensions(|cge| {
1377            cge.add_default_prelude_extensions()
1378                .add_default_array_extensions()
1379                .add_default_int_extensions()
1380                .add_logic_extensions()
1381        });
1382        assert_eq!([1, 2][index as usize], exec_ctx.exec_hugr_u64(hugr, "main"));
1383    }
1384
1385    #[rstest]
1386    #[case(&[], 0)]
1387    #[case(&[true], 1)]
1388    #[case(&[false], 4)]
1389    #[case(&[true, true], 3)]
1390    #[case(&[false, false], 6)]
1391    #[case(&[true, false, true], 7)]
1392    #[case(&[false, true, false], 7)]
1393    fn exec_pop(mut exec_ctx: TestContext, #[case] from_left: &[bool], #[case] expected: u64) {
1394        // We build a HUGR that:
1395        // - Creates an array: [1,2,4]
1396        // - Pops `num` elements from the left or right
1397        // - Returns the sum of the popped elements
1398
1399        let array_contents = [1, 2, 4];
1400        let int_ty = int_type(6);
1401        let hugr = SimpleHugrConfig::new()
1402            .with_outs(int_ty.clone())
1403            .with_extensions(exec_registry())
1404            .finish(|mut builder| {
1405                let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
1406                let new_array_args = array_contents
1407                    .iter()
1408                    .map(|&i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
1409                    .collect_vec();
1410                let mut arr = builder
1411                    .add_new_array(int_ty.clone(), new_array_args)
1412                    .unwrap();
1413                for (i, left) in from_left.iter().enumerate() {
1414                    let array_size = (array_contents.len() - i) as u64;
1415                    let pop_res = if *left {
1416                        builder
1417                            .add_array_pop_left(int_ty.clone(), array_size, arr)
1418                            .unwrap()
1419                    } else {
1420                        builder
1421                            .add_array_pop_right(int_ty.clone(), array_size, arr)
1422                            .unwrap()
1423                    };
1424                    let [elem, new_arr] = builder
1425                        .build_unwrap_sum(
1426                            1,
1427                            option_type(vec![
1428                                int_ty.clone(),
1429                                array_type(array_size - 1, int_ty.clone()),
1430                            ]),
1431                            pop_res,
1432                        )
1433                        .unwrap();
1434                    arr = new_arr;
1435                    r = builder.add_iadd(6, r, elem).unwrap();
1436                }
1437                builder
1438                    .add_array_discard(
1439                        int_ty.clone(),
1440                        (array_contents.len() - from_left.len()) as u64,
1441                        arr,
1442                    )
1443                    .unwrap();
1444                builder.finish_hugr_with_outputs([r]).unwrap()
1445            });
1446        exec_ctx.add_extensions(|cge| {
1447            cge.add_default_prelude_extensions()
1448                .add_default_array_extensions()
1449                .add_default_int_extensions()
1450        });
1451        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1452    }
1453
1454    #[rstest]
1455    #[case(&[], 0)]
1456    #[case(&[1, 2], 3)]
1457    #[case(&[6, 6, 6], 18)]
1458    fn exec_unpack(
1459        mut exec_ctx: TestContext,
1460        #[case] array_contents: &[u64],
1461        #[case] expected: u64,
1462    ) {
1463        // We build a HUGR that:
1464        // - Loads an array with the given contents
1465        // - Unpacks all the elements
1466        // - Returns the sum of the elements
1467
1468        let int_ty = int_type(6);
1469        let hugr = SimpleHugrConfig::new()
1470            .with_outs(int_ty.clone())
1471            .with_extensions(exec_registry())
1472            .finish(|mut builder| {
1473                let array = array::ArrayValue::new(
1474                    int_ty.clone(),
1475                    array_contents
1476                        .iter()
1477                        .map(|&i| ConstInt::new_u(6, i).unwrap().into())
1478                        .collect_vec(),
1479                );
1480                let array = builder.add_load_value(array);
1481                let unpacked = builder
1482                    .add_array_unpack(int_ty.clone(), array_contents.len() as u64, array)
1483                    .unwrap();
1484                let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
1485                for elem in unpacked {
1486                    r = builder.add_iadd(6, r, elem).unwrap();
1487                }
1488
1489                builder.finish_hugr_with_outputs([r]).unwrap()
1490            });
1491        exec_ctx.add_extensions(|cge| {
1492            cge.add_default_prelude_extensions()
1493                .add_default_array_extensions()
1494                .add_default_int_extensions()
1495        });
1496        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1497    }
1498
1499    #[rstest]
1500    #[case(5, 42, 0)]
1501    #[case(5, 42, 1)]
1502    #[case(5, 42, 2)]
1503    #[case(5, 42, 3)]
1504    #[case(5, 42, 4)]
1505    fn exec_repeat(
1506        mut exec_ctx: TestContext,
1507        #[case] size: u64,
1508        #[case] value: u64,
1509        #[case] idx: u64,
1510    ) {
1511        // We build a HUGR that:
1512        // - Contains a nested function that returns `value`
1513        // - Creates an array of length `size` populated via this function
1514        // - Looks up the value at `idx` and returns it
1515
1516        let int_ty = int_type(6);
1517        let hugr = SimpleHugrConfig::new()
1518            .with_outs(int_ty.clone())
1519            .with_extensions(exec_registry())
1520            .finish(|mut builder| {
1521                let mut mb = builder.module_root_builder();
1522                let mut func = mb
1523                    .define_function("foo", Signature::new(vec![], vec![int_ty.clone()]))
1524                    .unwrap();
1525                let v = func.add_load_value(ConstInt::new_u(6, value).unwrap());
1526                let func_id = func.finish_with_outputs(vec![v]).unwrap();
1527                let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
1528                let repeat = ArrayRepeat::new(int_ty.clone(), size);
1529                let arr = builder
1530                    .add_dataflow_op(repeat, vec![func_v])
1531                    .unwrap()
1532                    .out_wire(0);
1533                let idx_v = builder.add_load_value(ConstUsize::new(idx));
1534                let (get_res, arr) = builder
1535                    .add_array_get(int_ty.clone(), size, arr, idx_v)
1536                    .unwrap();
1537                let [elem] = builder
1538                    .build_unwrap_sum(1, option_type(vec![int_ty.clone()]), get_res)
1539                    .unwrap();
1540                builder
1541                    .add_array_discard(int_ty.clone(), size, arr)
1542                    .unwrap();
1543                builder.finish_hugr_with_outputs([elem]).unwrap()
1544            });
1545        exec_ctx.add_extensions(|cge| {
1546            cge.add_default_prelude_extensions()
1547                .add_default_array_extensions()
1548                .add_default_int_extensions()
1549        });
1550        assert_eq!(value, exec_ctx.exec_hugr_u64(hugr, "main"));
1551    }
1552
1553    #[rstest]
1554    #[case(10, 1)]
1555    #[case(10, 2)]
1556    #[case(0, 1)]
1557    fn exec_scan_map(mut exec_ctx: TestContext, #[case] size: u64, #[case] inc: u64) {
1558        // We build a HUGR that:
1559        // - Creates an array [1, 2, 3, ..., size]
1560        // - Maps a function that increments each element by `inc`
1561        // - Returns the sum of the array elements
1562
1563        let int_ty = int_type(6);
1564        let hugr = SimpleHugrConfig::new()
1565            .with_outs(int_ty.clone())
1566            .with_extensions(exec_registry())
1567            .finish(|mut builder| {
1568                let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
1569                let new_array_args = (0..size)
1570                    .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
1571                    .collect_vec();
1572                let arr = builder
1573                    .add_new_array(int_ty.clone(), new_array_args)
1574                    .unwrap();
1575
1576                let mut mb = builder.module_root_builder();
1577                let mut func = mb
1578                    .define_function(
1579                        "foo",
1580                        Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]),
1581                    )
1582                    .unwrap();
1583                let [elem] = func.input_wires_arr();
1584                let delta = func.add_load_value(ConstInt::new_u(6, inc).unwrap());
1585                let out = func.add_iadd(6, elem, delta).unwrap();
1586                let func_id = func.finish_with_outputs(vec![out]).unwrap();
1587                let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
1588                let scan = ArrayScan::new(int_ty.clone(), int_ty.clone(), vec![], size);
1589                let mut arr = builder
1590                    .add_dataflow_op(scan, [arr, func_v])
1591                    .unwrap()
1592                    .out_wire(0);
1593
1594                for i in 0..size {
1595                    let array_size = size - i;
1596                    let pop_res = builder
1597                        .add_array_pop_left(int_ty.clone(), array_size, arr)
1598                        .unwrap();
1599                    let [elem, new_arr] = builder
1600                        .build_unwrap_sum(
1601                            1,
1602                            option_type(vec![
1603                                int_ty.clone(),
1604                                array_type(array_size - 1, int_ty.clone()),
1605                            ]),
1606                            pop_res,
1607                        )
1608                        .unwrap();
1609                    arr = new_arr;
1610                    r = builder.add_iadd(6, r, elem).unwrap();
1611                }
1612                builder
1613                    .add_array_discard_empty(int_ty.clone(), arr)
1614                    .unwrap();
1615                builder.finish_hugr_with_outputs([r]).unwrap()
1616            });
1617        exec_ctx.add_extensions(|cge| {
1618            cge.add_default_prelude_extensions()
1619                .add_default_array_extensions()
1620                .add_default_int_extensions()
1621        });
1622        let expected: u64 = (inc..size + inc).sum();
1623        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1624    }
1625
1626    #[rstest]
1627    #[case(0)]
1628    #[case(1)]
1629    #[case(10)]
1630    fn exec_scan_fold(mut exec_ctx: TestContext, #[case] size: u64) {
1631        // We build a HUGR that:
1632        // - Creates an array [1, 2, 3, ..., size]
1633        // - Sums up the elements of the array using a scan and returns that sum
1634
1635        let int_ty = int_type(6);
1636        let hugr = SimpleHugrConfig::new()
1637            .with_outs(int_ty.clone())
1638            .with_extensions(exec_registry())
1639            .finish(|mut builder| {
1640                let new_array_args = (0..size)
1641                    .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
1642                    .collect_vec();
1643                let arr = builder
1644                    .add_new_array(int_ty.clone(), new_array_args)
1645                    .unwrap();
1646
1647                let mut mb = builder.module_root_builder();
1648                let mut func = mb
1649                    .define_function(
1650                        "foo",
1651                        Signature::new(
1652                            vec![int_ty.clone(), int_ty.clone()],
1653                            vec![Type::UNIT, int_ty.clone()],
1654                        ),
1655                    )
1656                    .unwrap();
1657                let [elem, acc] = func.input_wires_arr();
1658                let acc = func.add_iadd(6, elem, acc).unwrap();
1659                let unit = func
1660                    .add_dataflow_op(Tag::new(0, vec![type_row![]]), [])
1661                    .unwrap()
1662                    .out_wire(0);
1663                let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap();
1664                let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
1665                let scan = ArrayScan::new(int_ty.clone(), Type::UNIT, vec![int_ty.clone()], size);
1666                let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
1667                let [arr, sum] = builder
1668                    .add_dataflow_op(scan, [arr, func_v, zero])
1669                    .unwrap()
1670                    .outputs_arr();
1671                builder.add_array_discard(Type::UNIT, size, arr).unwrap();
1672                builder.finish_hugr_with_outputs([sum]).unwrap()
1673            });
1674        exec_ctx.add_extensions(|cge| {
1675            cge.add_default_prelude_extensions()
1676                .add_default_array_extensions()
1677                .add_default_int_extensions()
1678        });
1679        let expected: u64 = (0..size).sum();
1680        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1681    }
1682}