hugr_llvm/extension/collections/
static_array.rs

1use std::hash::Hasher as _;
2
3use hugr_core::{
4    HugrView, Node,
5    extension::{
6        prelude::{option_type, usize_t},
7        simple_op::HasConcrete as _,
8    },
9    ops::{ExtensionOp, constant::TryHash},
10    std_extensions::collections::static_array::{
11        self, StaticArrayOp, StaticArrayOpDef, StaticArrayValue,
12    },
13};
14use inkwell::{
15    AddressSpace, IntPredicate,
16    builder::Builder,
17    context::Context,
18    types::{BasicType, BasicTypeEnum, StructType},
19    values::{ArrayValue, BasicValue, BasicValueEnum, IntValue, PointerValue},
20};
21use itertools::Itertools as _;
22
23use crate::{
24    CodegenExtension, CodegenExtsBuilder,
25    emit::{EmitFuncContext, EmitOpArgs, emit_value},
26    types::{HugrType, TypingSession},
27};
28
29use anyhow::{Result, bail};
30
31#[derive(Debug, Clone, derive_more::From)]
32/// A [`CodegenExtension`] that lowers the
33/// [`hugr_core::std_extensions::collections::static_array`].
34///
35/// All behaviour is delegated to `SACG`.
36pub struct StaticArrayCodegenExtension<SACG>(SACG);
37
38impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
39    /// Add a [`StaticArrayCodegenExtension`] to the given [`CodegenExtsBuilder`] using `ccg`
40    /// as the implementation.
41    pub fn add_static_array_extensions(self, ccg: impl StaticArrayCodegen + 'static) -> Self {
42        self.add_extension(StaticArrayCodegenExtension::from(ccg))
43    }
44
45    /// Add a [`StaticArrayCodegenExtension`] to the given [`CodegenExtsBuilder`] using
46    /// [`DefaultStaticArrayCodegen`] as the implementation.
47    #[must_use]
48    pub fn add_default_static_array_extensions(self) -> Self {
49        self.add_static_array_extensions(DefaultStaticArrayCodegen)
50    }
51}
52
53// This is not provided by inkwell, it seems like it should be
54fn value_is_const<'c>(value: impl BasicValue<'c>) -> bool {
55    match value.as_basic_value_enum() {
56        BasicValueEnum::ArrayValue(v) => v.is_const(),
57        BasicValueEnum::IntValue(v) => v.is_const(),
58        BasicValueEnum::FloatValue(v) => v.is_const(),
59        BasicValueEnum::PointerValue(v) => v.is_const(),
60        BasicValueEnum::StructValue(v) => v.is_const(),
61        BasicValueEnum::VectorValue(v) => v.is_const(),
62        BasicValueEnum::ScalableVectorValue(v) => v.is_const(),
63    }
64}
65
66// This is not provided by inkwell, it seems like it should be
67fn const_array<'c>(
68    ty: impl BasicType<'c>,
69    values: impl IntoIterator<Item = impl BasicValue<'c>>,
70) -> ArrayValue<'c> {
71    match ty.as_basic_type_enum() {
72        BasicTypeEnum::ArrayType(t) => t.const_array(
73            values
74                .into_iter()
75                .map(|x| x.as_basic_value_enum().into_array_value())
76                .collect_vec()
77                .as_slice(),
78        ),
79        BasicTypeEnum::FloatType(t) => t.const_array(
80            values
81                .into_iter()
82                .map(|x| x.as_basic_value_enum().into_float_value())
83                .collect_vec()
84                .as_slice(),
85        ),
86        BasicTypeEnum::IntType(t) => t.const_array(
87            values
88                .into_iter()
89                .map(|x| x.as_basic_value_enum().into_int_value())
90                .collect_vec()
91                .as_slice(),
92        ),
93        BasicTypeEnum::PointerType(t) => t.const_array(
94            values
95                .into_iter()
96                .map(|x| x.as_basic_value_enum().into_pointer_value())
97                .collect_vec()
98                .as_slice(),
99        ),
100        BasicTypeEnum::StructType(t) => t.const_array(
101            values
102                .into_iter()
103                .map(|x| x.as_basic_value_enum().into_struct_value())
104                .collect_vec()
105                .as_slice(),
106        ),
107        BasicTypeEnum::VectorType(t) => t.const_array(
108            values
109                .into_iter()
110                .map(|x| x.as_basic_value_enum().into_vector_value())
111                .collect_vec()
112                .as_slice(),
113        ),
114        BasicTypeEnum::ScalableVectorType(t) => t.const_array(
115            values
116                .into_iter()
117                .map(|x| x.as_basic_value_enum().into_scalable_vector_value())
118                .collect_vec()
119                .as_slice(),
120        ),
121    }
122}
123
124fn static_array_struct_type<'c>(
125    context: &'c Context,
126    index_type: impl BasicType<'c>,
127    element_type: impl BasicType<'c>,
128    len: u32,
129) -> StructType<'c> {
130    context.struct_type(
131        &[
132            index_type.as_basic_type_enum(),
133            element_type.array_type(len).into(),
134        ],
135        false,
136    )
137}
138
139fn build_read_len<'c>(
140    context: &'c Context,
141    builder: &Builder<'c>,
142    struct_ty: StructType<'c>,
143    mut ptr: PointerValue<'c>,
144) -> Result<IntValue<'c>> {
145    let canonical_ptr_ty = struct_ty.ptr_type(AddressSpace::default());
146    if ptr.get_type() != canonical_ptr_ty {
147        ptr = builder.build_pointer_cast(ptr, canonical_ptr_ty, "")?;
148    }
149    let i32_ty = context.i32_type();
150    let indices = [i32_ty.const_zero(), i32_ty.const_zero()];
151    let len_ptr = unsafe { builder.build_in_bounds_gep(ptr, &indices, "") }?;
152    Ok(builder.build_load(len_ptr, "")?.into_int_value())
153}
154
155/// A helper trait for customising the lowering of [`hugr_core::std_extensions::collections::static_array`]
156/// types, [`hugr_core::ops::constant::CustomConst`]s, and ops.
157pub trait StaticArrayCodegen: Clone {
158    /// Return the llvm type of
159    /// [`hugr_core::std_extensions::collections::static_array::STATIC_ARRAY_TYPENAME`].
160    ///
161    /// By default a static array of llvm type `t` and length `l` is stored in a
162    /// global of type `struct { i64, [t * l] }``
163    ///
164    /// The `i64` stores the length of the array.
165    ///
166    /// However a `static_array` `HugrType` is represented by an llvm pointer type
167    /// `struct {i64, [t * 0]}`;  i.e. the array is zero length. This gives all
168    /// static arrays of the same element type a uniform llvm type.
169    ///
170    /// It is legal to index past the end of an array (it is only undefined behaviour
171    /// to index past the allocation).
172    fn static_array_type<'c>(
173        &self,
174        session: TypingSession<'c, '_>,
175        element_type: &HugrType,
176    ) -> Result<BasicTypeEnum<'c>> {
177        let index_type = session.llvm_type(&usize_t())?;
178        let element_type = session.llvm_type(element_type)?;
179        Ok(
180            static_array_struct_type(session.iw_context(), index_type, element_type, 0)
181                .ptr_type(AddressSpace::default())
182                .into(),
183        )
184    }
185
186    /// Emit a
187    /// [`hugr_core::std_extensions::collections::static_array::StaticArrayValue`].
188    ///
189    /// Note that the type of the return value must match the type returned by
190    /// [`Self::static_array_type`].
191    ///
192    /// By default a global is created and we return a pointer to it.
193    fn static_array_value<'c, H: HugrView<Node = Node>>(
194        &self,
195        context: &mut EmitFuncContext<'c, '_, H>,
196        value: &StaticArrayValue,
197    ) -> Result<BasicValueEnum<'c>> {
198        let element_type = value.get_element_type();
199        let llvm_element_type = context.llvm_type(element_type)?;
200        let index_type = context.llvm_type(&usize_t())?.into_int_type();
201        let array_elements = value.get_contents().iter().map(|v| {
202            let value = emit_value(context, v)?;
203            if !value_is_const(value) {
204                anyhow::bail!("Static array value must be constant. HUGR value '{v:?}' was codegened as non-const");
205            }
206            Ok(value)
207        }).collect::<Result<Vec<_>>>()?;
208        let len = array_elements.len();
209        let struct_ty = static_array_struct_type(
210            context.iw_context(),
211            index_type,
212            llvm_element_type,
213            len as u32,
214        );
215        let array_value = struct_ty.const_named_struct(&[
216            index_type.const_int(len as u64, false).into(),
217            const_array(llvm_element_type, array_elements).into(),
218        ]);
219
220        let gv = {
221            let module = context.get_current_module();
222            let hash = {
223                let mut hasher = std::collections::hash_map::DefaultHasher::new();
224                let _ = value.try_hash(&mut hasher);
225                hasher.finish() as u32 // a bit shorter than u64
226            };
227            let prefix = format!("sa.{}.{hash:x}.", value.name);
228            (0..)
229                .find_map(|i| {
230                    let sym = format!("{prefix}{i}");
231                    if let Some(global) = module.get_global(&sym) {
232                        // Note this comparison may be expensive for large
233                        // values.  We could avoid it(and therefore avoid
234                        // creating array_value in this branch) if we had
235                        // https://github.com/CQCL/hugr/issues/2004
236                        if global.get_initializer().is_some_and(|x| x == array_value) {
237                            Some(global)
238                        } else {
239                            None
240                        }
241                    } else {
242                        let global = module.add_global(struct_ty, None, &sym);
243                        global.set_constant(true);
244                        global.set_initializer(&array_value);
245                        Some(global)
246                    }
247                })
248                .unwrap()
249        };
250        let canonical_type = self
251            .static_array_type(context.typing_session(), value.get_element_type())?
252            .into_pointer_type();
253        Ok(gv.as_pointer_value().const_cast(canonical_type).into())
254    }
255
256    /// Emit a [`hugr_core::std_extensions::collections::static_array::StaticArrayOp`].
257    fn static_array_op<'c, H: HugrView<Node = Node>>(
258        &self,
259        context: &mut EmitFuncContext<'c, '_, H>,
260        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
261        op: StaticArrayOp,
262    ) -> Result<()> {
263        match op.def {
264            StaticArrayOpDef::get => {
265                let ptr = args.inputs[0].into_pointer_value();
266                let index = args.inputs[1].into_int_value();
267                let index_ty = index.get_type();
268                let element_llvm_ty = context.llvm_type(&op.elem_ty)?;
269                let struct_ty =
270                    static_array_struct_type(context.iw_context(), index_ty, element_llvm_ty, 0);
271
272                let len = build_read_len(context.iw_context(), context.builder(), struct_ty, ptr)?;
273
274                let result_sum_ty = option_type(op.elem_ty);
275                let rmb = context.new_row_mail_box([&result_sum_ty.clone().into()], "")?;
276                let result_llvm_sum_ty = context.llvm_sum_type(result_sum_ty)?;
277
278                let exit_block = context.build_positioned_new_block(
279                    "static_array_get_exit",
280                    context.builder().get_insert_block(),
281                    |context, bb| {
282                        args.outputs
283                            .finish(context.builder(), rmb.read_vec(context.builder(), [])?)?;
284                        anyhow::Ok(bb)
285                    },
286                )?;
287
288                let fail_block = context.build_positioned_new_block(
289                    "static_array_get_out_of_bounds",
290                    Some(exit_block),
291                    |context, bb| {
292                        rmb.write(
293                            context.builder(),
294                            [result_llvm_sum_ty
295                                .build_tag(context.builder(), 0, vec![])?
296                                .into()],
297                        )?;
298                        context.builder().build_unconditional_branch(exit_block)?;
299                        anyhow::Ok(bb)
300                    },
301                )?;
302
303                let success_block = context.build_positioned_new_block(
304                    "static_array_get_in_bounds",
305                    Some(exit_block),
306                    |context, bb| {
307                        let i32_ty = context.iw_context().i32_type();
308                        let indices = [i32_ty.const_zero(), i32_ty.const_int(1, false), index];
309                        let element_ptr =
310                            unsafe { context.builder().build_in_bounds_gep(ptr, &indices, "") }?;
311                        let element = context.builder().build_load(element_ptr, "")?;
312                        rmb.write(
313                            context.builder(),
314                            [result_llvm_sum_ty
315                                .build_tag(context.builder(), 1, vec![element])?
316                                .into()],
317                        )?;
318                        context.builder().build_unconditional_branch(exit_block)?;
319                        anyhow::Ok(bb)
320                    },
321                )?;
322
323                let inbounds =
324                    context
325                        .builder()
326                        .build_int_compare(IntPredicate::ULT, index, len, "")?;
327                context
328                    .builder()
329                    .build_conditional_branch(inbounds, success_block, fail_block)?;
330
331                context.builder().position_at_end(exit_block);
332                Ok(())
333            }
334            StaticArrayOpDef::len => {
335                let ptr = args.inputs[0].into_pointer_value();
336                let element_llvm_ty = context.llvm_type(&op.elem_ty)?;
337                let index_ty = args.outputs.get_types().next().unwrap().into_int_type();
338                let struct_ty =
339                    static_array_struct_type(context.iw_context(), index_ty, element_llvm_ty, 0);
340                let len = build_read_len(context.iw_context(), context.builder(), struct_ty, ptr)?;
341                args.outputs.finish(context.builder(), [len.into()])
342            }
343            op => bail!("StaticArrayCodegen: Unsupported op: {op:?}"),
344        }
345    }
346}
347
348#[derive(Debug, Clone)]
349/// An implementation of [`StaticArrayCodegen`] that uses all default
350/// implementations.
351pub struct DefaultStaticArrayCodegen;
352
353impl StaticArrayCodegen for DefaultStaticArrayCodegen {}
354
355impl<SAC: StaticArrayCodegen + 'static> CodegenExtension for StaticArrayCodegenExtension<SAC> {
356    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
357        self,
358        builder: CodegenExtsBuilder<'a, H>,
359    ) -> CodegenExtsBuilder<'a, H>
360    where
361        Self: 'a,
362    {
363        builder
364            .custom_type(
365                (
366                    static_array::EXTENSION_ID,
367                    static_array::STATIC_ARRAY_TYPENAME,
368                ),
369                {
370                    let sac = self.0.clone();
371                    move |ts, custom_type| {
372                        let element_type = custom_type.args()[0]
373                            .as_runtime()
374                            .expect("Type argument for static array must be a type");
375                        sac.static_array_type(ts, &element_type)
376                    }
377                },
378            )
379            .custom_const::<StaticArrayValue>({
380                let sac = self.0.clone();
381                move |context, sav| sac.static_array_value(context, sav)
382            })
383            .simple_extension_op::<StaticArrayOpDef>({
384                let sac = self.0.clone();
385                move |context, args, op| {
386                    let op = op.instantiate(args.node().args())?;
387                    sac.static_array_op(context, args, op)
388                }
389            })
390    }
391}
392
// Tests for the static array lowering: IR snapshot tests (`check_emission!`)
// plus execution tests (`exec_hugr_u64`).
#[cfg(test)]
mod test {
    use super::*;
    use float_types::float64_type;
    use hugr_core::builder::DataflowHugr;
    use hugr_core::extension::prelude::ConstUsize;
    use hugr_core::ops::OpType;
    use hugr_core::ops::Value;
    use hugr_core::ops::constant::CustomConst;
    use hugr_core::std_extensions::arithmetic::float_types::{self, ConstF64};
    use rstest::rstest;

    use hugr_core::extension::simple_op::MakeRegisteredOp;
    use hugr_core::extension::{ExtensionRegistry, prelude::bool_t};
    use hugr_core::{builder::SubContainer as _, type_row};
    use static_array::StaticArrayOpBuilder as _;

    use crate::check_emission;
    use crate::test::single_op_hugr;
    use crate::{
        emit::test::SimpleHugrConfig,
        test::{TestContext, exec_ctx, llvm_ctx},
    };
    use hugr_core::builder::{Dataflow as _, DataflowSubContainer as _};

    // Snapshot the IR for a HUGR containing a single `get` or `len` op, for
    // both `usize` and `bool` elements.
    // `_i` is distinct per case and is passed to the fixture via `#[with(_i)]`;
    // presumably this keeps per-case snapshots apart (TODO confirm in `llvm_ctx`).
    #[rstest]
    #[case(0, StaticArrayOpDef::get, usize_t())]
    #[case(1, StaticArrayOpDef::get, bool_t())]
    #[case(2, StaticArrayOpDef::len, usize_t())]
    #[case(3, StaticArrayOpDef::len, bool_t())]
    fn static_array_op_codegen(
        #[case] _i: i32,
        #[with(_i)] mut llvm_ctx: TestContext,
        #[case] op: StaticArrayOpDef,
        #[case] ty: HugrType,
    ) {
        let op = op.instantiate(&[ty.clone().into()]).unwrap();
        let op = OpType::from(op.to_extension_op().unwrap());
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
        });
        let hugr = single_op_hugr(op);
        check_emission!(hugr, llvm_ctx);
    }

    // Snapshot the global emitted for a 10-element constant array of
    // `usize`, `f64`, `bool` and `option<usize>` elements respectively.
    #[rstest]
    #[case(0, StaticArrayValue::try_new("a", usize_t(), (0..10).map(|x| ConstUsize::new(x).into())).unwrap())]
    #[case(1, StaticArrayValue::try_new("b", float64_type(), (0..10).map(|x| ConstF64::new(f64::from(x)).into())).unwrap())]
    #[case(2, StaticArrayValue::try_new("c", bool_t(), (0..10).map(|x| Value::from_bool(x % 2 == 0))).unwrap())]
    #[case(3, StaticArrayValue::try_new("d", option_type(usize_t()).into(), (0..10).map(|x| Value::some([ConstUsize::new(x)]))).unwrap())]
    fn static_array_const_codegen(
        #[case] _i: i32,
        #[with(_i)] mut llvm_ctx: TestContext,
        #[case] value: StaticArrayValue,
    ) {
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
                .add_float_extensions()
        });

        let hugr = SimpleHugrConfig::new()
            .with_outs(value.get_type())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
                float_types::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let a = builder.add_load_value(value);
                builder.finish_hugr_with_outputs([a]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    // Execute `get` on a 1000-element array holding `999 - i` at index `i`.
    // In-bounds indices return the element. Index 1000 is out of bounds, so
    // the conditional's case 0 yields the sentinel `u64::MAX`.
    #[rstest]
    #[case(0, 0, 999)]
    #[case(1, 1, 998)]
    #[case(2, 1000, u64::MAX)]
    fn static_array_exec(
        #[case] _i: i32,
        #[with(_i)] mut exec_ctx: TestContext,
        #[case] index: u64,
        #[case] expected: u64,
    ) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let arr = builder.add_load_value(
                    StaticArrayValue::try_new(
                        "exec_arr",
                        usize_t(),
                        (0..1000)
                            .map(|x| ConstUsize::new(999 - x).into())
                            .collect_vec(),
                    )
                    .unwrap(),
                );
                let index = builder.add_load_value(ConstUsize::new(index));
                let get_r = builder.add_static_array_get(usize_t(), arr, index).unwrap();
                // Unpack the option returned by `get`: case 0 (none) -> u64::MAX,
                // case 1 (some) -> the element itself.
                let [out] = {
                    let mut cond = builder
                        .conditional_builder(
                            ([type_row!(), usize_t().into()], get_r),
                            [],
                            usize_t().into(),
                        )
                        .unwrap();
                    {
                        let mut oob_case = cond.case_builder(0).unwrap();
                        let err = oob_case.add_load_value(ConstUsize::new(u64::MAX));
                        oob_case.finish_with_outputs([err]).unwrap();
                    }
                    {
                        let inbounds_case = cond.case_builder(1).unwrap();
                        let [out] = inbounds_case.input_wires_arr();
                        inbounds_case.finish_with_outputs([out]).unwrap();
                    }
                    cond.finish_sub_container().unwrap().outputs_arr()
                };
                builder.finish_hugr_with_outputs([out]).unwrap()
            });

        exec_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
                .add_float_extensions()
        });
        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // `len` of an empty static array executes and returns 0.
    #[rstest]
    fn len_0_array(mut exec_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let arr = builder
                    .add_load_value(StaticArrayValue::try_new("empty", usize_t(), vec![]).unwrap());
                let len = builder.add_static_array_len(usize_t(), arr).unwrap();
                builder.finish_hugr_with_outputs([len]).unwrap()
            });

        exec_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
        });
        assert_eq!(0, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // Snapshot a nested constant: an outer static array whose 10 elements are
    // inner static arrays (inner array `i` holds `i` copies of `i`). All inner
    // arrays share the name "inner", which exercises the global-name probing.
    #[rstest]
    fn emit_static_array_of_static_array(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
        });
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let inner_arrs: Vec<Value> = (0..10)
                    .map(|i| {
                        StaticArrayValue::try_new(
                            "inner",
                            usize_t(),
                            vec![Value::from(ConstUsize::new(i)); i as usize],
                        )
                        .unwrap()
                        .into()
                    })
                    .collect_vec();
                let inner_arr_ty = inner_arrs[0].get_type();
                let outer_arr = builder.add_load_value(
                    StaticArrayValue::try_new("outer", inner_arr_ty.clone(), inner_arrs).unwrap(),
                );
                let len = builder
                    .add_static_array_len(inner_arr_ty, outer_arr)
                    .unwrap();
                builder.finish_hugr_with_outputs([len]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }
}