hugr_llvm/extension/collections/
array.rs

//! Codegen for the types, constants and operations of the
//! [hugr_core::std_extensions::collections::array] extension.
2use std::iter;
3
4use anyhow::{anyhow, Ok, Result};
5use hugr_core::extension::prelude::option_type;
6use hugr_core::extension::simple_op::{MakeExtensionOp, MakeRegisteredOp};
7use hugr_core::ops::DataflowOpTrait;
8use hugr_core::std_extensions::collections::array::{
9    self, array_type, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan,
10};
11use hugr_core::types::{TypeArg, TypeEnum};
12use hugr_core::{HugrView, Node};
13use inkwell::builder::{Builder, BuilderError};
14use inkwell::types::{BasicType, BasicTypeEnum};
15use inkwell::values::{
16    ArrayValue, BasicValue as _, BasicValueEnum, CallableValue, IntValue, PointerValue,
17};
18use inkwell::IntPredicate;
19use itertools::Itertools;
20
21use crate::emit::emit_value;
22use crate::{
23    emit::{deaggregate_call_result, EmitFuncContext, RowPromise},
24    types::{HugrType, TypingSession},
25};
26use crate::{CodegenExtension, CodegenExtsBuilder};
27
28impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
29    /// Add a [ArrayCodegenExtension] to the given [CodegenExtsBuilder] using `ccg`
30    /// as the implementation.
31    pub fn add_default_array_extensions(self) -> Self {
32        self.add_array_extensions(DefaultArrayCodegen)
33    }
34
35    /// Add a [ArrayCodegenExtension] to the given [CodegenExtsBuilder] using
36    /// [DefaultArrayCodegen] as the implementation.
37    pub fn add_array_extensions(self, ccg: impl ArrayCodegen + 'a) -> Self {
38        self.add_extension(ArrayCodegenExtension::from(ccg))
39    }
40}
41
42/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections::array]
43/// types, [hugr_core::ops::constant::CustomConst]s, and ops.
44pub trait ArrayCodegen: Clone {
45    /// Return the llvm type of [hugr_core::std_extensions::collections::array::ARRAY_TYPENAME].
46    fn array_type<'c>(
47        &self,
48        _session: &TypingSession<'c, '_>,
49        elem_ty: BasicTypeEnum<'c>,
50        size: u64,
51    ) -> impl BasicType<'c> {
52        elem_ty.array_type(size as u32)
53    }
54
55    /// Emit a [hugr_core::std_extensions::collections::array::ArrayValue].
56    fn emit_array_value<'c, H: HugrView<Node = Node>>(
57        &self,
58        ctx: &mut EmitFuncContext<'c, '_, H>,
59        value: &array::ArrayValue,
60    ) -> Result<BasicValueEnum<'c>> {
61        emit_array_value(self, ctx, value)
62    }
63
64    /// Emit a [hugr_core::std_extensions::collections::array::ArrayOp].
65    fn emit_array_op<'c, H: HugrView<Node = Node>>(
66        &self,
67        ctx: &mut EmitFuncContext<'c, '_, H>,
68        op: ArrayOp,
69        inputs: Vec<BasicValueEnum<'c>>,
70        outputs: RowPromise<'c>,
71    ) -> Result<()> {
72        emit_array_op(self, ctx, op, inputs, outputs)
73    }
74
75    /// Emit a [hugr_core::std_extensions::collections::array::ArrayRepeat] op.
76    fn emit_array_repeat<'c, H: HugrView<Node = Node>>(
77        &self,
78        ctx: &mut EmitFuncContext<'c, '_, H>,
79        op: ArrayRepeat,
80        func: BasicValueEnum<'c>,
81    ) -> Result<BasicValueEnum<'c>> {
82        emit_repeat_op(ctx, op, func)
83    }
84
85    /// Emit a [hugr_core::std_extensions::collections::array::ArrayScan] op.
86    ///
87    /// Returns the resulting array and the final values of the accumulators.
88    fn emit_array_scan<'c, H: HugrView<Node = Node>>(
89        &self,
90        ctx: &mut EmitFuncContext<'c, '_, H>,
91        op: ArrayScan,
92        src_array: BasicValueEnum<'c>,
93        func: BasicValueEnum<'c>,
94        initial_accs: &[BasicValueEnum<'c>],
95    ) -> Result<(BasicValueEnum<'c>, Vec<BasicValueEnum<'c>>)> {
96        emit_scan_op(ctx, op, src_array, func, initial_accs)
97    }
98}
99
100/// A trivial implementation of [ArrayCodegen] which passes all methods
101/// through to their default implementations.
102#[derive(Default, Clone)]
103pub struct DefaultArrayCodegen;
104
105impl ArrayCodegen for DefaultArrayCodegen {}
106
107#[derive(Clone, Debug, Default)]
108pub struct ArrayCodegenExtension<CCG>(CCG);
109
110impl<CCG: ArrayCodegen> ArrayCodegenExtension<CCG> {
111    pub fn new(ccg: CCG) -> Self {
112        Self(ccg)
113    }
114}
115
116impl<CCG: ArrayCodegen> From<CCG> for ArrayCodegenExtension<CCG> {
117    fn from(ccg: CCG) -> Self {
118        Self::new(ccg)
119    }
120}
121
122impl<CCG: ArrayCodegen> CodegenExtension for ArrayCodegenExtension<CCG> {
123    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
124        self,
125        builder: CodegenExtsBuilder<'a, H>,
126    ) -> CodegenExtsBuilder<'a, H>
127    where
128        Self: 'a,
129    {
130        builder
131            .custom_type((array::EXTENSION_ID, array::ARRAY_TYPENAME), {
132                let ccg = self.0.clone();
133                move |ts, hugr_type| {
134                    let [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] = hugr_type.args() else {
135                        return Err(anyhow!("Invalid type args for array type"));
136                    };
137                    let elem_ty = ts.llvm_type(ty)?;
138                    Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum())
139                }
140            })
141            .custom_const::<array::ArrayValue>({
142                let ccg = self.0.clone();
143                move |context, k| ccg.emit_array_value(context, k)
144            })
145            .simple_extension_op::<ArrayOpDef>({
146                let ccg = self.0.clone();
147                move |context, args, _| {
148                    ccg.emit_array_op(
149                        context,
150                        ArrayOp::from_extension_op(args.node().as_ref())?,
151                        args.inputs,
152                        args.outputs,
153                    )
154                }
155            })
156            .extension_op(array::EXTENSION_ID, array::ARRAY_REPEAT_OP_ID, {
157                let ccg = self.0.clone();
158                move |context, args| {
159                    let func = args.inputs[0];
160                    let op = ArrayRepeat::from_extension_op(args.node().as_ref())?;
161                    let arr = ccg.emit_array_repeat(context, op, func)?;
162                    args.outputs.finish(context.builder(), [arr])
163                }
164            })
165            .extension_op(array::EXTENSION_ID, array::ARRAY_SCAN_OP_ID, {
166                let ccg = self.0.clone();
167                move |context, args| {
168                    let src_array = args.inputs[0];
169                    let func = args.inputs[1];
170                    let initial_accs = &args.inputs[2..];
171                    let op = ArrayScan::from_extension_op(args.node().as_ref())?;
172                    let (tgt_array, final_accs) =
173                        ccg.emit_array_scan(context, op, src_array, func, initial_accs)?;
174                    args.outputs
175                        .finish(context.builder(), iter::once(tgt_array).chain(final_accs))
176                }
177            })
178    }
179}
180
181/// Helper function to allocate an array on the stack.
182///
183/// Returns two pointers: The first one is a pointer to the first element of the
184/// array (i.e. it is of type `array.get_element_type().ptr_type()`) whereas the
185/// second one points to the whole array value, i.e. it is of type `array.ptr_type()`.
186fn build_array_alloca<'c>(
187    builder: &Builder<'c>,
188    array: ArrayValue<'c>,
189) -> Result<(PointerValue<'c>, PointerValue<'c>), BuilderError> {
190    let array_ty = array.get_type();
191    let array_len: IntValue<'c> = {
192        let ctx = builder.get_insert_block().unwrap().get_context();
193        ctx.i32_type().const_int(array_ty.len() as u64, false)
194    };
195    let ptr = builder.build_array_alloca(array_ty.get_element_type(), array_len, "")?;
196    let array_ptr = builder
197        .build_bit_cast(ptr, array_ty.ptr_type(Default::default()), "")?
198        .into_pointer_value();
199    builder.build_store(array_ptr, array)?;
200    Result::Ok((ptr, array_ptr))
201}
202
203/// Helper function to allocate an array on the stack and pass a pointer to it
204/// to a closure.
205///
206/// The pointer forwarded to the closure is a pointer to the first element of
207/// the array. I.e. it is of type `array.get_element_type().ptr_type()` not
208/// `array.ptr_type()`
209fn with_array_alloca<'c, T, E: From<BuilderError>>(
210    builder: &Builder<'c>,
211    array: ArrayValue<'c>,
212    go: impl FnOnce(PointerValue<'c>) -> Result<T, E>,
213) -> Result<T, E> {
214    let (ptr, _) = build_array_alloca(builder, array)?;
215    go(ptr)
216}
217
218/// Helper function to build a loop that repeats for a given number of iterations.
219///
220/// The provided closure is called to build the loop body. Afterwards, the builder is positioned at
221/// the end of the loop exit block.
222fn build_loop<'c, T, H: HugrView<Node = Node>>(
223    ctx: &mut EmitFuncContext<'c, '_, H>,
224    iters: IntValue<'c>,
225    go: impl FnOnce(&mut EmitFuncContext<'c, '_, H>, IntValue<'c>) -> Result<T>,
226) -> Result<T> {
227    let builder = ctx.builder();
228    let idx_ty = ctx.iw_context().i32_type();
229    let idx_ptr = builder.build_alloca(idx_ty, "")?;
230    builder.build_store(idx_ptr, idx_ty.const_zero())?;
231
232    let exit_block = ctx.new_basic_block("", None);
233
234    let (body_block, val) = ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
235        let idx = ctx.builder().build_load(idx_ptr, "")?.into_int_value();
236        let val = go(ctx, idx)?;
237        let builder = ctx.builder();
238        let inc_idx = builder.build_int_add(idx, idx_ty.const_int(1, false), "")?;
239        builder.build_store(idx_ptr, inc_idx)?;
240        // Branch to the head is built later
241        Ok((bb, val))
242    })?;
243
244    let head_block = ctx.build_positioned_new_block("", Some(body_block), |ctx, bb| {
245        let builder = ctx.builder();
246        let idx = builder.build_load(idx_ptr, "")?.into_int_value();
247        let cmp = builder.build_int_compare(IntPredicate::ULT, idx, iters, "")?;
248        builder.build_conditional_branch(cmp, body_block, exit_block)?;
249        Ok(bb)
250    })?;
251
252    let builder = ctx.builder();
253    builder.build_unconditional_branch(head_block)?;
254    builder.position_at_end(body_block);
255    builder.build_unconditional_branch(head_block)?;
256    ctx.builder().position_at_end(exit_block);
257    Ok(val)
258}
259
260pub fn emit_array_value<'c, H: HugrView<Node = Node>>(
261    ccg: &impl ArrayCodegen,
262    ctx: &mut EmitFuncContext<'c, '_, H>,
263    value: &array::ArrayValue,
264) -> Result<BasicValueEnum<'c>> {
265    let ts = ctx.typing_session();
266    let llvm_array_ty = ccg
267        .array_type(
268            &ts,
269            ts.llvm_type(value.get_element_type())?,
270            value.get_contents().len() as u64,
271        )
272        .as_basic_type_enum()
273        .into_array_type();
274    let mut array_v = llvm_array_ty.get_undef();
275    for (i, v) in value.get_contents().iter().enumerate() {
276        let llvm_v = emit_value(ctx, v)?;
277        array_v = ctx
278            .builder()
279            .build_insert_value(array_v, llvm_v, i as u32, "")?
280            .into_array_value();
281    }
282    Ok(array_v.into())
283}
284
285pub fn emit_array_op<'c, H: HugrView<Node = Node>>(
286    ccg: &impl ArrayCodegen,
287    ctx: &mut EmitFuncContext<'c, '_, H>,
288    op: ArrayOp,
289    inputs: Vec<BasicValueEnum<'c>>,
290    outputs: RowPromise<'c>,
291) -> Result<()> {
292    let builder = ctx.builder();
293    let ts = ctx.typing_session();
294    let sig = op
295        .clone()
296        .to_extension_op()
297        .unwrap()
298        .signature()
299        .into_owned();
300    let ArrayOp {
301        def,
302        ref elem_ty,
303        size,
304    } = op;
305    let llvm_array_ty = ccg
306        .array_type(&ts, ts.llvm_type(elem_ty)?, size)
307        .as_basic_type_enum()
308        .into_array_type();
309    match def {
310        ArrayOpDef::new_array => {
311            let mut array_v = llvm_array_ty.get_undef();
312            for (i, v) in inputs.into_iter().enumerate() {
313                array_v = builder
314                    .build_insert_value(array_v, v, i as u32, "")?
315                    .into_array_value();
316            }
317            outputs.finish(builder, [array_v.as_basic_value_enum()])
318        }
319        ArrayOpDef::get => {
320            let [array_v, index_v] = inputs
321                .try_into()
322                .map_err(|_| anyhow!("ArrayOpDef::get expects two arguments"))?;
323            let array_v = array_v.into_array_value();
324            let index_v = index_v.into_int_value();
325            let res_hugr_ty = sig
326                .output()
327                .get(0)
328                .ok_or(anyhow!("ArrayOp::get has no outputs"))?;
329
330            let res_sum_ty = {
331                let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else {
332                    Err(anyhow!("ArrayOp::get output is not a sum type"))?
333                };
334                ts.llvm_sum_type(st.clone())?
335            };
336
337            let exit_rmb = ctx.new_row_mail_box([res_hugr_ty], "")?;
338
339            let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| {
340                outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?;
341                Ok(bb)
342            })?;
343
344            let success_block =
345                ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
346                    let builder = ctx.builder();
347                    let elem_v = with_array_alloca(builder, array_v, |ptr| {
348                        // inside `success_block` we know `index_v` to be in
349                        // bounds.
350                        let elem_addr =
351                            unsafe { builder.build_in_bounds_gep(ptr, &[index_v], "")? };
352                        builder.build_load(elem_addr, "")
353                    })?;
354                    let success_v = res_sum_ty.build_tag(builder, 1, vec![elem_v])?;
355                    exit_rmb.write(ctx.builder(), [success_v.into()])?;
356                    builder.build_unconditional_branch(exit_block)?;
357                    Ok(bb)
358                })?;
359
360            let failure_block =
361                ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| {
362                    let builder = ctx.builder();
363                    let failure_v = res_sum_ty.build_tag(builder, 0, vec![])?;
364                    exit_rmb.write(ctx.builder(), [failure_v.into()])?;
365                    builder.build_unconditional_branch(exit_block)?;
366                    Ok(bb)
367                })?;
368
369            let builder = ctx.builder();
370            let is_success = builder.build_int_compare(
371                IntPredicate::ULT,
372                index_v,
373                index_v.get_type().const_int(size, false),
374                "",
375            )?;
376
377            builder.build_conditional_branch(is_success, success_block, failure_block)?;
378            builder.position_at_end(exit_block);
379            Ok(())
380        }
381        ArrayOpDef::set => {
382            let [array_v0, index_v, value_v] = inputs
383                .try_into()
384                .map_err(|_| anyhow!("ArrayOpDef::set expects three arguments"))?;
385            let array_v = array_v0.into_array_value();
386            let index_v = index_v.into_int_value();
387
388            let res_hugr_ty = sig
389                .output()
390                .get(0)
391                .ok_or(anyhow!("ArrayOp::set has no outputs"))?;
392
393            let res_sum_ty = {
394                let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else {
395                    Err(anyhow!("ArrayOp::set output is not a sum type"))?
396                };
397                ts.llvm_sum_type(st.clone())?
398            };
399
400            let exit_rmb = ctx.new_row_mail_box([res_hugr_ty], "")?;
401
402            let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| {
403                outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?;
404                Ok(bb)
405            })?;
406
407            let success_block =
408                ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
409                    let builder = ctx.builder();
410                    let (elem_v, array_v) = with_array_alloca(builder, array_v, |ptr| {
411                        // inside `success_block` we know `index_v` to be in
412                        // bounds.
413                        let elem_addr =
414                            unsafe { builder.build_in_bounds_gep(ptr, &[index_v], "")? };
415                        let elem_v = builder.build_load(elem_addr, "")?;
416                        builder.build_store(elem_addr, value_v)?;
417                        let ptr = builder
418                            .build_bit_cast(
419                                ptr,
420                                array_v.get_type().ptr_type(Default::default()),
421                                "",
422                            )?
423                            .into_pointer_value();
424                        let array_v = builder.build_load(ptr, "")?;
425                        Ok((elem_v, array_v))
426                    })?;
427                    let success_v = res_sum_ty.build_tag(builder, 1, vec![elem_v, array_v])?;
428                    exit_rmb.write(ctx.builder(), [success_v.into()])?;
429                    builder.build_unconditional_branch(exit_block)?;
430                    Ok(bb)
431                })?;
432
433            let failure_block =
434                ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| {
435                    let builder = ctx.builder();
436                    let failure_v =
437                        res_sum_ty.build_tag(builder, 0, vec![value_v, array_v.into()])?;
438                    exit_rmb.write(ctx.builder(), [failure_v.into()])?;
439                    builder.build_unconditional_branch(exit_block)?;
440                    Ok(bb)
441                })?;
442
443            let builder = ctx.builder();
444            let is_success = builder.build_int_compare(
445                IntPredicate::ULT,
446                index_v,
447                index_v.get_type().const_int(size, false),
448                "",
449            )?;
450            builder.build_conditional_branch(is_success, success_block, failure_block)?;
451            builder.position_at_end(exit_block);
452            Ok(())
453        }
454        ArrayOpDef::swap => {
455            let [array_v0, index1_v, index2_v] = inputs
456                .try_into()
457                .map_err(|_| anyhow!("ArrayOpDef::swap expects three arguments"))?;
458            let array_v = array_v0.into_array_value();
459            let index1_v = index1_v.into_int_value();
460            let index2_v = index2_v.into_int_value();
461
462            let res_hugr_ty = sig
463                .output()
464                .get(0)
465                .ok_or(anyhow!("ArrayOp::swap has no outputs"))?;
466
467            let res_sum_ty = {
468                let TypeEnum::Sum(st) = res_hugr_ty.as_type_enum() else {
469                    Err(anyhow!("ArrayOp::swap output is not a sum type"))?
470                };
471                ts.llvm_sum_type(st.clone())?
472            };
473
474            let exit_rmb = ctx.new_row_mail_box([res_hugr_ty], "")?;
475
476            let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| {
477                outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?;
478                Ok(bb)
479            })?;
480
481            let success_block =
482                ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| {
483                    // if `index1_v` == `index2_v` then the following is a no-op.
484                    // We could check for this: either with a select instruction
485                    // here, or by branching to another case in earlier.
486                    // Doing so would generate better code in cases where the
487                    // optimiser can determine that the indices are the same, at
488                    // the cost of worse code in cases where it cannot.
489                    // For now we choose the simpler option of omitting the check.
490                    let builder = ctx.builder();
491                    let array_v = with_array_alloca(builder, array_v, |ptr| {
492                        // inside `success_block` we know `index1_v` and `index2_v`
493                        // to be in bounds.
494                        let elem1_addr =
495                            unsafe { builder.build_in_bounds_gep(ptr, &[index1_v], "")? };
496                        let elem1_v = builder.build_load(elem1_addr, "")?;
497                        let elem2_addr =
498                            unsafe { builder.build_in_bounds_gep(ptr, &[index2_v], "")? };
499                        let elem2_v = builder.build_load(elem2_addr, "")?;
500                        builder.build_store(elem1_addr, elem2_v)?;
501                        builder.build_store(elem2_addr, elem1_v)?;
502                        let ptr = builder
503                            .build_bit_cast(
504                                ptr,
505                                array_v.get_type().ptr_type(Default::default()),
506                                "",
507                            )?
508                            .into_pointer_value();
509                        builder.build_load(ptr, "")
510                    })?;
511                    let success_v = res_sum_ty.build_tag(builder, 1, vec![array_v])?;
512                    exit_rmb.write(ctx.builder(), [success_v.into()])?;
513                    builder.build_unconditional_branch(exit_block)?;
514                    Ok(bb)
515                })?;
516
517            let failure_block =
518                ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| {
519                    let builder = ctx.builder();
520                    let failure_v = res_sum_ty.build_tag(builder, 0, vec![array_v.into()])?;
521                    exit_rmb.write(ctx.builder(), [failure_v.into()])?;
522                    builder.build_unconditional_branch(exit_block)?;
523                    Ok(bb)
524                })?;
525
526            let builder = ctx.builder();
527            let is_success = {
528                let index1_ok = builder.build_int_compare(
529                    IntPredicate::ULT,
530                    index1_v,
531                    index1_v.get_type().const_int(size, false),
532                    "",
533                )?;
534                let index2_ok = builder.build_int_compare(
535                    IntPredicate::ULT,
536                    index2_v,
537                    index2_v.get_type().const_int(size, false),
538                    "",
539                )?;
540                builder.build_and(index1_ok, index2_ok, "")?
541            };
542            builder.build_conditional_branch(is_success, success_block, failure_block)?;
543            builder.position_at_end(exit_block);
544            Ok(())
545        }
546        ArrayOpDef::pop_left => {
547            let [array_v] = inputs
548                .try_into()
549                .map_err(|_| anyhow!("ArrayOpDef::pop_left expects one argument"))?;
550            let r = emit_pop_op(
551                builder,
552                &ts,
553                elem_ty.clone(),
554                size,
555                array_v.into_array_value(),
556                true,
557            )?;
558            outputs.finish(ctx.builder(), [r])
559        }
560        ArrayOpDef::pop_right => {
561            let [array_v] = inputs
562                .try_into()
563                .map_err(|_| anyhow!("ArrayOpDef::pop_right expects one argument"))?;
564            let r = emit_pop_op(
565                builder,
566                &ts,
567                elem_ty.clone(),
568                size,
569                array_v.into_array_value(),
570                false,
571            )?;
572            outputs.finish(ctx.builder(), [r])
573        }
574        ArrayOpDef::discard_empty => Ok(()),
575        _ => todo!(),
576    }
577}
578
579/// Helper function to emit the pop operations.
580fn emit_pop_op<'c>(
581    builder: &Builder<'c>,
582    ts: &TypingSession<'c, '_>,
583    elem_ty: HugrType,
584    size: u64,
585    array_v: ArrayValue<'c>,
586    pop_left: bool,
587) -> Result<BasicValueEnum<'c>> {
588    let ret_ty = ts.llvm_sum_type(option_type(vec![
589        elem_ty.clone(),
590        array_type(size.saturating_add_signed(-1), elem_ty),
591    ]))?;
592    if size == 0 {
593        return Ok(ret_ty.build_tag(builder, 0, vec![])?.into());
594    }
595    let ctx = builder.get_insert_block().unwrap().get_context();
596    let (elem_v, array_v) = with_array_alloca(builder, array_v, |ptr| {
597        let (elem_ptr, ptr) = {
598            if pop_left {
599                let rest_ptr =
600                    unsafe { builder.build_gep(ptr, &[ctx.i32_type().const_int(1, false)], "") }?;
601                (ptr, rest_ptr)
602            } else {
603                let elem_ptr = unsafe {
604                    builder.build_gep(ptr, &[ctx.i32_type().const_int(size - 1, false)], "")
605                }?;
606                (elem_ptr, ptr)
607            }
608        };
609        let elem_v = builder.build_load(elem_ptr, "")?;
610        let new_array_ty = array_v
611            .get_type()
612            .get_element_type()
613            .array_type(size as u32 - 1);
614        let ptr = builder
615            .build_bit_cast(ptr, new_array_ty.ptr_type(Default::default()), "")?
616            .into_pointer_value();
617        let array_v = builder.build_load(ptr, "")?;
618        Ok((elem_v, array_v))
619    })?;
620    Ok(ret_ty.build_tag(builder, 1, vec![elem_v, array_v])?.into())
621}
622
623/// Emits an [ArrayRepeat] op.
624pub fn emit_repeat_op<'c, H: HugrView<Node = Node>>(
625    ctx: &mut EmitFuncContext<'c, '_, H>,
626    op: ArrayRepeat,
627    func: BasicValueEnum<'c>,
628) -> Result<BasicValueEnum<'c>> {
629    let builder = ctx.builder();
630    let array_len = ctx.iw_context().i32_type().const_int(op.size, false);
631    let array_ty = ctx.llvm_type(&op.elem_ty)?.array_type(op.size as u32);
632    let (ptr, array_ptr) = build_array_alloca(builder, array_ty.get_undef())?;
633    build_loop(ctx, array_len, |ctx, idx| {
634        let builder = ctx.builder();
635        let func_ptr = CallableValue::try_from(func.into_pointer_value())
636            .map_err(|_| anyhow!("ArrayOpDef::repeat expects a function pointer"))?;
637        let v = builder
638            .build_call(func_ptr, &[], "")?
639            .try_as_basic_value()
640            .left()
641            .ok_or(anyhow!("ArrayOpDef::repeat function must return a value"))?;
642        let elem_addr = unsafe { builder.build_in_bounds_gep(ptr, &[idx], "")? };
643        builder.build_store(elem_addr, v)?;
644        Ok(())
645    })?;
646
647    let builder = ctx.builder();
648    let array_v = builder.build_load(array_ptr, "")?;
649    Ok(array_v)
650}
651
652/// Emits an [ArrayScan] op.
653///
654/// Returns the resulting array and the final values of the accumulators.
655pub fn emit_scan_op<'c, H: HugrView<Node = Node>>(
656    ctx: &mut EmitFuncContext<'c, '_, H>,
657    op: ArrayScan,
658    src_array: BasicValueEnum<'c>,
659    func: BasicValueEnum<'c>,
660    initial_accs: &[BasicValueEnum<'c>],
661) -> Result<(BasicValueEnum<'c>, Vec<BasicValueEnum<'c>>)> {
662    let builder = ctx.builder();
663    let ts = ctx.typing_session();
664    let array_len = ctx.iw_context().i32_type().const_int(op.size, false);
665    let tgt_array_ty = ts.llvm_type(&op.tgt_ty)?.array_type(op.size as u32);
666    let (src_ptr, _) = build_array_alloca(builder, src_array.into_array_value())?;
667    let (tgt_ptr, tgt_array_ptr) = build_array_alloca(builder, tgt_array_ty.get_undef())?;
668
669    let acc_tys: Vec<_> = op.acc_tys.iter().map(|ty| ts.llvm_type(ty)).try_collect()?;
670    let acc_ptrs: Vec<_> = acc_tys
671        .iter()
672        .map(|ty| builder.build_alloca(*ty, ""))
673        .try_collect()?;
674    for (ptr, initial_val) in acc_ptrs.iter().zip(initial_accs) {
675        builder.build_store(*ptr, *initial_val)?;
676    }
677
678    build_loop(ctx, array_len, |ctx, idx| {
679        let builder = ctx.builder();
680        let func_ptr = CallableValue::try_from(func.into_pointer_value())
681            .map_err(|_| anyhow!("ArrayOpDef::scan expects a function pointer"))?;
682        let src_elem_addr = unsafe { builder.build_in_bounds_gep(src_ptr, &[idx], "")? };
683        let src_elem = builder.build_load(src_elem_addr, "")?;
684        let mut args = vec![src_elem.into()];
685        for ptr in acc_ptrs.iter() {
686            args.push(builder.build_load(*ptr, "")?.into());
687        }
688        let call = builder.build_call(func_ptr, args.as_slice(), "")?;
689        let call_results = deaggregate_call_result(builder, call, 1 + acc_tys.len())?;
690        let tgt_elem_addr = unsafe { builder.build_in_bounds_gep(tgt_ptr, &[idx], "")? };
691        builder.build_store(tgt_elem_addr, call_results[0])?;
692        for (ptr, next_act) in acc_ptrs.iter().zip(call_results[1..].iter()) {
693            builder.build_store(*ptr, *next_act)?;
694        }
695        Ok(())
696    })?;
697
698    let builder = ctx.builder();
699    let tgt_array_v = builder.build_load(tgt_array_ptr, "")?;
700    let final_accs = acc_ptrs
701        .into_iter()
702        .map(|ptr| builder.build_load(ptr, ""))
703        .try_collect()?;
704    Ok((tgt_array_v, final_accs))
705}
706
#[cfg(test)]
mod test {
    use hugr_core::builder::Container as _;
    use hugr_core::extension::ExtensionSet;
    use hugr_core::ops::Tag;
    use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan};
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::Type;
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer, SubContainer},
        extension::{
            prelude::{self, bool_t, option_type, usize_t, ConstUsize, UnwrapBuilder as _},
            ExtensionRegistry,
        },
        ops::Value,
        std_extensions::{
            arithmetic::{
                int_ops::{self},
                int_types::{self, int_type, ConstInt},
            },
            logic,
        },
        type_row,
        types::Signature,
    };
    use itertools::Itertools as _;
    use rstest::rstest;

    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{exec_ctx, llvm_ctx, TestContext},
        utils::{array_op_builder, ArrayOpBuilder, IntOpBuilder, LogicOpBuilder},
    };

    #[rstest]
    fn emit_all_ops(mut llvm_ctx: TestContext) {
        // Emits a nested DFG containing the ops produced by
        // `array_op_builder::test::all_array_ops` and checks the emitted module.
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                array_op_builder::test::all_array_ops(builder.dfg_builder_endo([]).unwrap())
                    .finish_sub_container()
                    .unwrap();
                builder.finish_sub_container().unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn emit_get(mut llvm_ctx: TestContext) {
        // Emits a `new_array` of [1, 2] followed by a `get` at index 1 whose
        // result is discarded.
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let us1 = builder.add_load_value(ConstUsize::new(1));
                let us2 = builder.add_load_value(ConstUsize::new(2));
                let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
                builder.add_array_get(usize_t(), 2, arr, us1).unwrap();
                builder.finish_with_outputs([]).unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn emit_array_value(mut llvm_ctx: TestContext) {
        // Emits a constant array value [1, 2] that is returned from `main`.
        let hugr = SimpleHugrConfig::new()
            .with_extensions(STD_REG.to_owned())
            .with_outs(vec![array_type(2, usize_t())])
            .finish(|mut builder| {
                let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()];
                let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs));
                builder.finish_with_outputs([arr]).unwrap()
            });
        llvm_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        check_emission!(hugr, llvm_ctx);
    }

    /// The extension registry used to build the HUGRs in the execution tests
    /// below: integer types and ops, logic, the prelude and arrays.
    fn exec_registry() -> ExtensionRegistry {
        ExtensionRegistry::new([
            int_types::EXTENSION.to_owned(),
            int_ops::EXTENSION.to_owned(),
            logic::EXTENSION.to_owned(),
            prelude::PRELUDE.to_owned(),
            array::EXTENSION.to_owned(),
        ])
    }

    /// The ids of the extensions in [exec_registry], used as the extension
    /// delta of nested functions and of `ArrayRepeat`/`ArrayScan` ops.
    fn exec_extension_set() -> ExtensionSet {
        ExtensionSet::from_iter([
            int_types::EXTENSION_ID,
            int_ops::EXTENSION_ID,
            logic::EXTENSION_ID,
            prelude::PRELUDE_ID,
            array::EXTENSION_ID,
        ])
    }

    #[rstest]
    #[case(0, 1)]
    #[case(1, 2)]
    #[case(3, 0)]
    #[case(999999, 0)]
    fn exec_get(mut exec_ctx: TestContext, #[case] index: u64, #[case] expected: u64) {
        // We build a HUGR that:
        // - Creates an array of [1,2]
        // - Gets the element at the given index
        // - Returns the element if the index is in bounds, otherwise 0
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let us0 = builder.add_load_value(ConstUsize::new(0));
                let us1 = builder.add_load_value(ConstUsize::new(1));
                let us2 = builder.add_load_value(ConstUsize::new(2));
                let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
                let i = builder.add_load_value(ConstUsize::new(index));
                let get_r = builder.add_array_get(usize_t(), 2, arr, i).unwrap();
                let r = {
                    let ot = option_type(usize_t());
                    let variants = (0..ot.num_variants())
                        .map(|i| ot.get_variant(i).cloned().unwrap().try_into().unwrap())
                        .collect_vec();
                    let mut builder = builder
                        .conditional_builder((variants, get_r), [], usize_t().into())
                        .unwrap();
                    {
                        // Variant 0 (`None`): index out of bounds, return 0.
                        let failure_case = builder.case_builder(0).unwrap();
                        failure_case.finish_with_outputs([us0]).unwrap();
                    }
                    {
                        // Variant 1 (`Some`): forward the retrieved element.
                        let success_case = builder.case_builder(1).unwrap();
                        let inputs = success_case.input_wires();
                        success_case.finish_with_outputs(inputs).unwrap();
                    }
                    builder.finish_sub_container().unwrap().out_wire(0)
                };
                builder.finish_with_outputs([r]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
        });
        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(0, 3, 1, [3,2])]
    #[case(1, 3, 2, [1,3])]
    #[case(2, 3, 3, [1,2])]
    #[case(999999, 3, 3, [1,2])]
    fn exec_set(
        mut exec_ctx: TestContext,
        #[case] index: u64,
        #[case] value: u64,
        #[case] expected_elem: u64,
        #[case] expected_arr: [u64; 2],
    ) {
        // We build a HUGR that:
        // - Creates an array: [1,2]
        // - Sets the element at the given index to the given value
        // - Checks the following, returning 1 iff they are all true:
        //   - The element returned from set is `expected_elem`
        //   - The 0th element of the resulting array is `expected_arr[0]`
        //   - The 1st element of the resulting array is `expected_arr[1]`
        //
        // Both variants of the set result carry the same (element, array) row,
        // so the checks are identical whether or not the index was in bounds.
        // For out-of-bounds indices the cases expect `value` back and the array
        // unchanged.

        use hugr_core::extension::prelude::either_type;
        let int_ty = int_type(3);
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let us0 = builder.add_load_value(ConstUsize::new(0));
                let us1 = builder.add_load_value(ConstUsize::new(1));
                let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
                let i2 = builder.add_load_value(ConstInt::new_u(3, 2).unwrap());
                let arr = builder.add_new_array(int_ty.clone(), [i1, i2]).unwrap();
                let index = builder.add_load_value(ConstUsize::new(index));
                let value = builder.add_load_value(ConstInt::new_u(3, value).unwrap());
                let set_r = builder
                    .add_array_set(int_ty.clone(), 2, arr, index, value)
                    .unwrap();
                let all_ok = {
                    let res_sum_ty = {
                        let row = vec![int_ty.clone(), array_type(2, int_ty.clone())];
                        either_type(row.clone(), row)
                    };
                    let variants = (0..res_sum_ty.num_variants())
                        .map(|i| {
                            res_sum_ty
                                .get_variant(i)
                                .cloned()
                                .unwrap()
                                .try_into()
                                .unwrap()
                        })
                        .collect_vec();
                    let mut builder = builder
                        .conditional_builder((variants, set_r), [], bool_t().into())
                        .unwrap();
                    // Both variants are checked identically (see comment above).
                    for i in 0..2 {
                        let mut builder = builder.case_builder(i).unwrap();
                        let [elem, arr] = builder.input_wires_arr();
                        let expected_elem =
                            builder.add_load_value(ConstInt::new_u(3, expected_elem).unwrap());
                        let expected_arr_0 =
                            builder.add_load_value(ConstInt::new_u(3, expected_arr[0]).unwrap());
                        let expected_arr_1 =
                            builder.add_load_value(ConstInt::new_u(3, expected_arr[1]).unwrap());
                        let [arr_0] = {
                            let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap();
                            builder
                                .build_unwrap_sum(1, option_type(int_ty.clone()), r)
                                .unwrap()
                        };
                        let [arr_1] = {
                            let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap();
                            builder
                                .build_unwrap_sum(1, option_type(int_ty.clone()), r)
                                .unwrap()
                        };
                        let elem_eq = builder.add_ieq(3, elem, expected_elem).unwrap();
                        let arr_0_eq = builder.add_ieq(3, arr_0, expected_arr_0).unwrap();
                        let arr_1_eq = builder.add_ieq(3, arr_1, expected_arr_1).unwrap();
                        let r = builder.add_and(elem_eq, arr_0_eq).unwrap();
                        let r = builder.add_and(r, arr_1_eq).unwrap();
                        builder.finish_with_outputs([r]).unwrap();
                    }
                    builder.finish_sub_container().unwrap().out_wire(0)
                };
                // Convert the boolean result into a usize: false -> 0, true -> 1.
                let r = {
                    let mut conditional = builder
                        .conditional_builder(
                            ([type_row![], type_row![]], all_ok),
                            [],
                            usize_t().into(),
                        )
                        .unwrap();
                    conditional
                        .case_builder(0)
                        .unwrap()
                        .finish_with_outputs([us0])
                        .unwrap();
                    conditional
                        .case_builder(1)
                        .unwrap()
                        .finish_with_outputs([us1])
                        .unwrap();
                    conditional.finish_sub_container().unwrap().out_wire(0)
                };
                builder.finish_with_outputs([r]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
                .add_int_extensions()
                .add_logic_extensions()
        });
        assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(0, 1, [2,1], true)]
    #[case(0, 0, [1,2], true)]
    #[case(0, 2, [1,2], false)]
    #[case(2, 0, [1,2], false)]
    #[case(9999999, 0, [1,2], false)]
    #[case(0, 9999999, [1,2], false)]
    fn exec_swap(
        mut exec_ctx: TestContext,
        #[case] index1: u64,
        #[case] index2: u64,
        #[case] expected_arr: [u64; 2],
        #[case] expected_succeeded: bool,
    ) {
        // We build a HUGR that:
        // - Creates an array: [1,2]
        // - Swaps the elements at the given indices
        // - Checks the following, returning 1 iff they are all true:
        //  - The element at index 0 is `expected_arr[0]`
        //  - The element at index 1 is `expected_arr[1]`
        //  - The swap operation succeeded iff `expected_succeeded`
        //
        // The swap result is a two-variant sum, each variant carrying the array;
        // this test treats variant 1 as "succeeded".

        let int_ty = int_type(3);
        let arr_ty = array_type(2, int_ty.clone());
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let us0 = builder.add_load_value(ConstUsize::new(0));
                let us1 = builder.add_load_value(ConstUsize::new(1));
                let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap());
                let i2 = builder.add_load_value(ConstInt::new_u(3, 2).unwrap());
                let arr = builder.add_new_array(int_ty.clone(), [i1, i2]).unwrap();

                let index1 = builder.add_load_value(ConstUsize::new(index1));
                let index2 = builder.add_load_value(ConstUsize::new(index2));
                let swap_r = builder
                    .add_array_swap(int_ty.clone(), 2, arr, index1, index2)
                    .unwrap();
                let [arr, was_expected_success] = {
                    let mut conditional = builder
                        .conditional_builder(
                            (
                                [vec![arr_ty.clone()].into(), vec![arr_ty.clone()].into()],
                                swap_r,
                            ),
                            [],
                            vec![arr_ty, bool_t()].into(),
                        )
                        .unwrap();
                    for i in 0..2 {
                        let mut case = conditional.case_builder(i).unwrap();
                        let [arr] = case.input_wires_arr();
                        // True iff the variant taken matches the expected outcome.
                        let was_expected_success =
                            case.add_load_value(if (i == 1) == expected_succeeded {
                                Value::true_val()
                            } else {
                                Value::false_val()
                            });
                        case.finish_with_outputs([arr, was_expected_success])
                            .unwrap();
                    }
                    conditional.finish_sub_container().unwrap().outputs_arr()
                };
                let elem_0 = {
                    let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap();
                    builder
                        .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r)
                        .unwrap()[0]
                };
                let elem_1 = {
                    let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap();
                    builder
                        .build_unwrap_sum::<1>(1, option_type(int_ty), r)
                        .unwrap()[0]
                };
                let expected_elem_0 =
                    builder.add_load_value(ConstInt::new_u(3, expected_arr[0]).unwrap());
                let elem_0_ok = builder.add_ieq(3, elem_0, expected_elem_0).unwrap();
                let expected_elem_1 =
                    builder.add_load_value(ConstInt::new_u(3, expected_arr[1]).unwrap());
                let elem_1_ok = builder.add_ieq(3, elem_1, expected_elem_1).unwrap();
                let all_ok = builder.add_and(was_expected_success, elem_0_ok).unwrap();
                let all_ok = builder.add_and(all_ok, elem_1_ok).unwrap();
                // Convert the boolean result into a usize: false -> 0, true -> 1.
                let r = {
                    let mut conditional = builder
                        .conditional_builder(
                            ([type_row![], type_row![]], all_ok),
                            [],
                            usize_t().into(),
                        )
                        .unwrap();
                    conditional
                        .case_builder(0)
                        .unwrap()
                        .finish_with_outputs([us0])
                        .unwrap();
                    conditional
                        .case_builder(1)
                        .unwrap()
                        .finish_with_outputs([us1])
                        .unwrap();
                    conditional.finish_sub_container().unwrap().out_wire(0)
                };
                builder.finish_with_outputs([r]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
                .add_int_extensions()
                .add_logic_extensions()
        });
        assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(true, 0, 0)]
    #[case(true, 1, 1)]
    #[case(true, 2, 3)]
    #[case(true, 3, 7)]
    #[case(false, 0, 0)]
    #[case(false, 1, 4)]
    #[case(false, 2, 6)]
    #[case(false, 3, 7)]
    fn exec_pop(
        mut exec_ctx: TestContext,
        #[case] from_left: bool,
        #[case] num: usize,
        #[case] expected: u64,
    ) {
        // We build a HUGR that:
        // - Creates an array: [1,2,4]
        // - Pops `num` elements from the left or right
        // - Returns the sum of the popped elements

        let array_contents = [1, 2, 4];
        let int_ty = int_type(6);
        let hugr = SimpleHugrConfig::new()
            .with_outs(int_ty.clone())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
                let new_array_args = array_contents
                    .iter()
                    .map(|&i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
                    .collect_vec();
                let mut arr = builder
                    .add_new_array(int_ty.clone(), new_array_args)
                    .unwrap();
                for i in 0..num {
                    // Each pop shrinks the array type by one element.
                    let array_size = (array_contents.len() - i) as u64;
                    let pop_res = if from_left {
                        builder
                            .add_array_pop_left(int_ty.clone(), array_size, arr)
                            .unwrap()
                    } else {
                        builder
                            .add_array_pop_right(int_ty.clone(), array_size, arr)
                            .unwrap()
                    };
                    let [elem, new_arr] = builder
                        .build_unwrap_sum(
                            1,
                            option_type(vec![
                                int_ty.clone(),
                                array_type(array_size - 1, int_ty.clone()),
                            ]),
                            pop_res,
                        )
                        .unwrap();
                    arr = new_arr;
                    r = builder.add_iadd(6, r, elem).unwrap();
                }
                builder.finish_with_outputs([r]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
                .add_int_extensions()
        });
        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(5, 42, 0)]
    #[case(5, 42, 1)]
    #[case(5, 42, 2)]
    #[case(5, 42, 3)]
    #[case(5, 42, 4)]
    fn exec_repeat(
        mut exec_ctx: TestContext,
        #[case] size: u64,
        #[case] value: u64,
        #[case] idx: u64,
    ) {
        // We build a HUGR that:
        // - Contains a nested function that returns `value`
        // - Creates an array of length `size` populated via this function
        // - Looks up the value at `idx` and returns it

        let int_ty = int_type(6);
        let hugr = SimpleHugrConfig::new()
            .with_outs(int_ty.clone())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let mut func = builder
                    .define_function(
                        "foo",
                        Signature::new(vec![], vec![int_ty.clone()])
                            .with_extension_delta(exec_extension_set()),
                    )
                    .unwrap();
                let v = func.add_load_value(ConstInt::new_u(6, value).unwrap());
                let func_id = func.finish_with_outputs(vec![v]).unwrap();
                let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
                let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set());
                let arr = builder
                    .add_dataflow_op(repeat, vec![func_v])
                    .unwrap()
                    .out_wire(0);
                let idx_v = builder.add_load_value(ConstUsize::new(idx));
                let get_res = builder
                    .add_array_get(int_ty.clone(), size, arr, idx_v)
                    .unwrap();
                let [elem] = builder
                    .build_unwrap_sum(1, option_type(vec![int_ty.clone()]), get_res)
                    .unwrap();
                builder.finish_with_outputs([elem]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
                .add_int_extensions()
        });
        assert_eq!(value, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(10, 1)]
    #[case(10, 2)]
    #[case(0, 1)]
    fn exec_scan_map(mut exec_ctx: TestContext, #[case] size: u64, #[case] inc: u64) {
        // We build a HUGR that:
        // - Creates an array [0, 1, 2, ..., size - 1]
        // - Maps a function that increments each element by `inc` (an
        //   `ArrayScan` with no accumulators)
        // - Pops every element of the mapped array and returns their sum

        let int_ty = int_type(6);
        let hugr = SimpleHugrConfig::new()
            .with_outs(int_ty.clone())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
                let new_array_args = (0..size)
                    .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
                    .collect_vec();
                let arr = builder
                    .add_new_array(int_ty.clone(), new_array_args)
                    .unwrap();

                let mut func = builder
                    .define_function(
                        "foo",
                        Signature::new(vec![int_ty.clone()], vec![int_ty.clone()])
                            .with_extension_delta(exec_extension_set()),
                    )
                    .unwrap();
                let [elem] = func.input_wires_arr();
                let delta = func.add_load_value(ConstInt::new_u(6, inc).unwrap());
                let out = func.add_iadd(6, elem, delta).unwrap();
                let func_id = func.finish_with_outputs(vec![out]).unwrap();
                let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
                let scan = ArrayScan::new(
                    int_ty.clone(),
                    int_ty.clone(),
                    vec![],
                    size,
                    exec_extension_set(),
                );
                let mut arr = builder
                    .add_dataflow_op(scan, [arr, func_v])
                    .unwrap()
                    .out_wire(0);

                for i in 0..size {
                    // Each pop shrinks the array type by one element.
                    let array_size = size - i;
                    let pop_res = builder
                        .add_array_pop_left(int_ty.clone(), array_size, arr)
                        .unwrap();
                    let [elem, new_arr] = builder
                        .build_unwrap_sum(
                            1,
                            option_type(vec![
                                int_ty.clone(),
                                array_type(array_size - 1, int_ty.clone()),
                            ]),
                            pop_res,
                        )
                        .unwrap();
                    arr = new_arr;
                    r = builder.add_iadd(6, r, elem).unwrap();
                }
                builder.finish_with_outputs([r]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
                .add_int_extensions()
        });
        // Sum of (i + inc) for i in 0..size.
        let expected: u64 = (inc..size + inc).sum();
        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    #[rstest]
    #[case(0)]
    #[case(1)]
    #[case(10)]
    fn exec_scan_fold(mut exec_ctx: TestContext, #[case] size: u64) {
        // We build a HUGR that:
        // - Creates an array [0, 1, 2, ..., size - 1]
        // - Sums up the elements using an `ArrayScan` with a single accumulator
        //   (initialised to 0) that maps every element to a unit value
        // - Returns the final accumulator, i.e. the sum

        let int_ty = int_type(6);
        let hugr = SimpleHugrConfig::new()
            .with_outs(int_ty.clone())
            .with_extensions(exec_registry())
            .finish(|mut builder| {
                let new_array_args = (0..size)
                    .map(|i| builder.add_load_value(ConstInt::new_u(6, i).unwrap()))
                    .collect_vec();
                let arr = builder
                    .add_new_array(int_ty.clone(), new_array_args)
                    .unwrap();

                let mut func = builder
                    .define_function(
                        "foo",
                        Signature::new(
                            vec![int_ty.clone(), int_ty.clone()],
                            vec![Type::UNIT, int_ty.clone()],
                        )
                        .with_extension_delta(exec_extension_set()),
                    )
                    .unwrap();
                let [elem, acc] = func.input_wires_arr();
                let acc = func.add_iadd(6, elem, acc).unwrap();
                let unit = func
                    .add_dataflow_op(Tag::new(0, vec![type_row![]]), [])
                    .unwrap()
                    .out_wire(0);
                let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap();
                let func_v = builder.load_func(func_id.handle(), &[]).unwrap();
                let scan = ArrayScan::new(
                    int_ty.clone(),
                    Type::UNIT,
                    vec![int_ty.clone()],
                    size,
                    exec_extension_set(),
                );
                let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
                // Output 0 is the array of units; output 1 is the final accumulator.
                let sum = builder
                    .add_dataflow_op(scan, [arr, func_v, zero])
                    .unwrap()
                    .out_wire(1);
                builder.finish_with_outputs([sum]).unwrap()
            });
        exec_ctx.add_extensions(|cge| {
            cge.add_default_prelude_extensions()
                .add_default_array_extensions()
                .add_int_extensions()
        });
        let expected: u64 = (0..size).sum();
        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
    }
}