1use std::hash::Hasher as _;
2
3use hugr_core::{
4 HugrView, Node,
5 extension::{
6 prelude::{option_type, usize_t},
7 simple_op::HasConcrete as _,
8 },
9 ops::{ExtensionOp, constant::TryHash},
10 std_extensions::collections::static_array::{
11 self, StaticArrayOp, StaticArrayOpDef, StaticArrayValue,
12 },
13};
14use inkwell::{
15 AddressSpace, IntPredicate,
16 builder::Builder,
17 context::Context,
18 types::{BasicType, BasicTypeEnum, StructType},
19 values::{ArrayValue, BasicValue, BasicValueEnum, IntValue, PointerValue},
20};
21use itertools::Itertools as _;
22
23use crate::{
24 CodegenExtension, CodegenExtsBuilder,
25 emit::{EmitFuncContext, EmitOpArgs, emit_value},
26 types::{HugrType, TypingSession},
27};
28
29use anyhow::{Result, bail};
30
31#[derive(Debug, Clone, derive_more::From)]
32pub struct StaticArrayCodegenExtension<SACG>(SACG);
37
38impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
39 pub fn add_static_array_extensions(self, ccg: impl StaticArrayCodegen + 'static) -> Self {
42 self.add_extension(StaticArrayCodegenExtension::from(ccg))
43 }
44
45 #[must_use]
48 pub fn add_default_static_array_extensions(self) -> Self {
49 self.add_static_array_extensions(DefaultStaticArrayCodegen)
50 }
51}
52
53fn value_is_const<'c>(value: impl BasicValue<'c>) -> bool {
55 match value.as_basic_value_enum() {
56 BasicValueEnum::ArrayValue(v) => v.is_const(),
57 BasicValueEnum::IntValue(v) => v.is_const(),
58 BasicValueEnum::FloatValue(v) => v.is_const(),
59 BasicValueEnum::PointerValue(v) => v.is_const(),
60 BasicValueEnum::StructValue(v) => v.is_const(),
61 BasicValueEnum::VectorValue(v) => v.is_const(),
62 BasicValueEnum::ScalableVectorValue(v) => v.is_const(),
63 }
64}
65
66fn const_array<'c>(
68 ty: impl BasicType<'c>,
69 values: impl IntoIterator<Item = impl BasicValue<'c>>,
70) -> ArrayValue<'c> {
71 match ty.as_basic_type_enum() {
72 BasicTypeEnum::ArrayType(t) => t.const_array(
73 values
74 .into_iter()
75 .map(|x| x.as_basic_value_enum().into_array_value())
76 .collect_vec()
77 .as_slice(),
78 ),
79 BasicTypeEnum::FloatType(t) => t.const_array(
80 values
81 .into_iter()
82 .map(|x| x.as_basic_value_enum().into_float_value())
83 .collect_vec()
84 .as_slice(),
85 ),
86 BasicTypeEnum::IntType(t) => t.const_array(
87 values
88 .into_iter()
89 .map(|x| x.as_basic_value_enum().into_int_value())
90 .collect_vec()
91 .as_slice(),
92 ),
93 BasicTypeEnum::PointerType(t) => t.const_array(
94 values
95 .into_iter()
96 .map(|x| x.as_basic_value_enum().into_pointer_value())
97 .collect_vec()
98 .as_slice(),
99 ),
100 BasicTypeEnum::StructType(t) => t.const_array(
101 values
102 .into_iter()
103 .map(|x| x.as_basic_value_enum().into_struct_value())
104 .collect_vec()
105 .as_slice(),
106 ),
107 BasicTypeEnum::VectorType(t) => t.const_array(
108 values
109 .into_iter()
110 .map(|x| x.as_basic_value_enum().into_vector_value())
111 .collect_vec()
112 .as_slice(),
113 ),
114 BasicTypeEnum::ScalableVectorType(t) => t.const_array(
115 values
116 .into_iter()
117 .map(|x| x.as_basic_value_enum().into_scalable_vector_value())
118 .collect_vec()
119 .as_slice(),
120 ),
121 }
122}
123
124fn static_array_struct_type<'c>(
125 context: &'c Context,
126 index_type: impl BasicType<'c>,
127 element_type: impl BasicType<'c>,
128 len: u32,
129) -> StructType<'c> {
130 context.struct_type(
131 &[
132 index_type.as_basic_type_enum(),
133 element_type.array_type(len).into(),
134 ],
135 false,
136 )
137}
138
139fn build_read_len<'c>(
140 context: &'c Context,
141 builder: &Builder<'c>,
142 struct_ty: StructType<'c>,
143 mut ptr: PointerValue<'c>,
144) -> Result<IntValue<'c>> {
145 let canonical_ptr_ty = struct_ty.ptr_type(AddressSpace::default());
146 if ptr.get_type() != canonical_ptr_ty {
147 ptr = builder.build_pointer_cast(ptr, canonical_ptr_ty, "")?;
148 }
149 let i32_ty = context.i32_type();
150 let indices = [i32_ty.const_zero(), i32_ty.const_zero()];
151 let len_ptr = unsafe { builder.build_in_bounds_gep(ptr, &indices, "") }?;
152 Ok(builder.build_load(len_ptr, "")?.into_int_value())
153}
154
155pub trait StaticArrayCodegen: Clone {
158 fn static_array_type<'c>(
173 &self,
174 session: TypingSession<'c, '_>,
175 element_type: &HugrType,
176 ) -> Result<BasicTypeEnum<'c>> {
177 let index_type = session.llvm_type(&usize_t())?;
178 let element_type = session.llvm_type(element_type)?;
179 Ok(
180 static_array_struct_type(session.iw_context(), index_type, element_type, 0)
181 .ptr_type(AddressSpace::default())
182 .into(),
183 )
184 }
185
186 fn static_array_value<'c, H: HugrView<Node = Node>>(
194 &self,
195 context: &mut EmitFuncContext<'c, '_, H>,
196 value: &StaticArrayValue,
197 ) -> Result<BasicValueEnum<'c>> {
198 let element_type = value.get_element_type();
199 let llvm_element_type = context.llvm_type(element_type)?;
200 let index_type = context.llvm_type(&usize_t())?.into_int_type();
201 let array_elements = value.get_contents().iter().map(|v| {
202 let value = emit_value(context, v)?;
203 if !value_is_const(value) {
204 anyhow::bail!("Static array value must be constant. HUGR value '{v:?}' was codegened as non-const");
205 }
206 Ok(value)
207 }).collect::<Result<Vec<_>>>()?;
208 let len = array_elements.len();
209 let struct_ty = static_array_struct_type(
210 context.iw_context(),
211 index_type,
212 llvm_element_type,
213 len as u32,
214 );
215 let array_value = struct_ty.const_named_struct(&[
216 index_type.const_int(len as u64, false).into(),
217 const_array(llvm_element_type, array_elements).into(),
218 ]);
219
220 let gv = {
221 let module = context.get_current_module();
222 let hash = {
223 let mut hasher = std::collections::hash_map::DefaultHasher::new();
224 let _ = value.try_hash(&mut hasher);
225 hasher.finish() as u32 };
227 let prefix = format!("sa.{}.{hash:x}.", value.name);
228 (0..)
229 .find_map(|i| {
230 let sym = format!("{prefix}{i}");
231 if let Some(global) = module.get_global(&sym) {
232 if global.get_initializer().is_some_and(|x| x == array_value) {
237 Some(global)
238 } else {
239 None
240 }
241 } else {
242 let global = module.add_global(struct_ty, None, &sym);
243 global.set_constant(true);
244 global.set_initializer(&array_value);
245 Some(global)
246 }
247 })
248 .unwrap()
249 };
250 let canonical_type = self
251 .static_array_type(context.typing_session(), value.get_element_type())?
252 .into_pointer_type();
253 Ok(gv.as_pointer_value().const_cast(canonical_type).into())
254 }
255
256 fn static_array_op<'c, H: HugrView<Node = Node>>(
258 &self,
259 context: &mut EmitFuncContext<'c, '_, H>,
260 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
261 op: StaticArrayOp,
262 ) -> Result<()> {
263 match op.def {
264 StaticArrayOpDef::get => {
265 let ptr = args.inputs[0].into_pointer_value();
266 let index = args.inputs[1].into_int_value();
267 let index_ty = index.get_type();
268 let element_llvm_ty = context.llvm_type(&op.elem_ty)?;
269 let struct_ty =
270 static_array_struct_type(context.iw_context(), index_ty, element_llvm_ty, 0);
271
272 let len = build_read_len(context.iw_context(), context.builder(), struct_ty, ptr)?;
273
274 let result_sum_ty = option_type(op.elem_ty);
275 let rmb = context.new_row_mail_box([&result_sum_ty.clone().into()], "")?;
276 let result_llvm_sum_ty = context.llvm_sum_type(result_sum_ty)?;
277
278 let exit_block = context.build_positioned_new_block(
279 "static_array_get_exit",
280 context.builder().get_insert_block(),
281 |context, bb| {
282 args.outputs
283 .finish(context.builder(), rmb.read_vec(context.builder(), [])?)?;
284 anyhow::Ok(bb)
285 },
286 )?;
287
288 let fail_block = context.build_positioned_new_block(
289 "static_array_get_out_of_bounds",
290 Some(exit_block),
291 |context, bb| {
292 rmb.write(
293 context.builder(),
294 [result_llvm_sum_ty
295 .build_tag(context.builder(), 0, vec![])?
296 .into()],
297 )?;
298 context.builder().build_unconditional_branch(exit_block)?;
299 anyhow::Ok(bb)
300 },
301 )?;
302
303 let success_block = context.build_positioned_new_block(
304 "static_array_get_in_bounds",
305 Some(exit_block),
306 |context, bb| {
307 let i32_ty = context.iw_context().i32_type();
308 let indices = [i32_ty.const_zero(), i32_ty.const_int(1, false), index];
309 let element_ptr =
310 unsafe { context.builder().build_in_bounds_gep(ptr, &indices, "") }?;
311 let element = context.builder().build_load(element_ptr, "")?;
312 rmb.write(
313 context.builder(),
314 [result_llvm_sum_ty
315 .build_tag(context.builder(), 1, vec![element])?
316 .into()],
317 )?;
318 context.builder().build_unconditional_branch(exit_block)?;
319 anyhow::Ok(bb)
320 },
321 )?;
322
323 let inbounds =
324 context
325 .builder()
326 .build_int_compare(IntPredicate::ULT, index, len, "")?;
327 context
328 .builder()
329 .build_conditional_branch(inbounds, success_block, fail_block)?;
330
331 context.builder().position_at_end(exit_block);
332 Ok(())
333 }
334 StaticArrayOpDef::len => {
335 let ptr = args.inputs[0].into_pointer_value();
336 let element_llvm_ty = context.llvm_type(&op.elem_ty)?;
337 let index_ty = args.outputs.get_types().next().unwrap().into_int_type();
338 let struct_ty =
339 static_array_struct_type(context.iw_context(), index_ty, element_llvm_ty, 0);
340 let len = build_read_len(context.iw_context(), context.builder(), struct_ty, ptr)?;
341 args.outputs.finish(context.builder(), [len.into()])
342 }
343 op => bail!("StaticArrayCodegen: Unsupported op: {op:?}"),
344 }
345 }
346}
347
348#[derive(Debug, Clone)]
349pub struct DefaultStaticArrayCodegen;
352
353impl StaticArrayCodegen for DefaultStaticArrayCodegen {}
354
355impl<SAC: StaticArrayCodegen + 'static> CodegenExtension for StaticArrayCodegenExtension<SAC> {
356 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
357 self,
358 builder: CodegenExtsBuilder<'a, H>,
359 ) -> CodegenExtsBuilder<'a, H>
360 where
361 Self: 'a,
362 {
363 builder
364 .custom_type(
365 (
366 static_array::EXTENSION_ID,
367 static_array::STATIC_ARRAY_TYPENAME,
368 ),
369 {
370 let sac = self.0.clone();
371 move |ts, custom_type| {
372 let element_type = custom_type.args()[0]
373 .as_runtime()
374 .expect("Type argument for static array must be a type");
375 sac.static_array_type(ts, &element_type)
376 }
377 },
378 )
379 .custom_const::<StaticArrayValue>({
380 let sac = self.0.clone();
381 move |context, sav| sac.static_array_value(context, sav)
382 })
383 .simple_extension_op::<StaticArrayOpDef>({
384 let sac = self.0.clone();
385 move |context, args, op| {
386 let op = op.instantiate(args.node().args())?;
387 sac.static_array_op(context, args, op)
388 }
389 })
390 }
391}
392
// Tests for static-array codegen. The `check_emission!` tests snapshot the
// emitted LLVM IR. The `exec_ctx` tests run the lowered HUGR and compare the
// `u64` it returns.
#[cfg(test)]
mod test {
    use super::*;
    use float_types::float64_type;
    use hugr_core::builder::DataflowHugr;
    use hugr_core::extension::prelude::ConstUsize;
    use hugr_core::ops::OpType;
    use hugr_core::ops::Value;
    use hugr_core::ops::constant::CustomConst;
    use hugr_core::std_extensions::arithmetic::float_types::{self, ConstF64};
    use rstest::rstest;

    use hugr_core::extension::simple_op::MakeRegisteredOp;
    use hugr_core::extension::{ExtensionRegistry, prelude::bool_t};
    use hugr_core::{builder::SubContainer as _, type_row};
    use static_array::StaticArrayOpBuilder as _;

    use crate::check_emission;
    use crate::test::single_op_hugr;
    use crate::{
        emit::test::SimpleHugrConfig,
        test::{TestContext, exec_ctx, llvm_ctx},
    };
    use hugr_core::builder::{Dataflow as _, DataflowSubContainer as _};

    // Emits a HUGR containing a single `get` or `len` op for each element type.
    // `_i` only distinguishes the cases; it is forwarded to the `llvm_ctx`
    // fixture through `#[with(_i)]`.
    #[rstest]
    #[case(0, StaticArrayOpDef::get, usize_t())]
    #[case(1, StaticArrayOpDef::get, bool_t())]
    #[case(2, StaticArrayOpDef::len, usize_t())]
    #[case(3, StaticArrayOpDef::len, bool_t())]
    fn static_array_op_codegen(
        #[case] _i: i32,
        #[with(_i)] mut llvm_ctx: TestContext,
        #[case] op: StaticArrayOpDef,
        #[case] ty: HugrType,
    ) {
        let op = op.instantiate(&[ty.clone().into()]).unwrap();
        let op = OpType::from(op.to_extension_op().unwrap());
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
        });
        let hugr = single_op_hugr(op);
        check_emission!(hugr, llvm_ctx);
    }

    // Emits a HUGR that loads and returns a 10-element static-array constant.
    // The cases cover usize, float64, bool and option-of-usize (sum) elements.
    #[rstest]
    #[case(0, StaticArrayValue::try_new("a", usize_t(), (0..10).map(|x| ConstUsize::new(x).into())).unwrap())]
    #[case(1, StaticArrayValue::try_new("b", float64_type(), (0..10).map(|x| ConstF64::new(f64::from(x)).into())).unwrap())]
    #[case(2, StaticArrayValue::try_new("c", bool_t(), (0..10).map(|x| Value::from_bool(x % 2 == 0))).unwrap())]
    #[case(3, StaticArrayValue::try_new("d", option_type(usize_t()).into(), (0..10).map(|x| Value::some([ConstUsize::new(x)]))).unwrap())]
    fn static_array_const_codegen(
        #[case] _i: i32,
        #[with(_i)] mut llvm_ctx: TestContext,
        #[case] value: StaticArrayValue,
    ) {
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
                .add_float_extensions()
        });

        let hugr = SimpleHugrConfig::new()
            .with_outs(value.get_type())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
                float_types::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let a = builder.add_load_value(value);
                builder.finish_hugr_with_outputs([a]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    // Executes `get` on a 1000-element array whose element at position `i` is
    // `999 - i`. The in-bounds branch returns the element; the out-of-bounds
    // branch (tag 0) returns `u64::MAX`, as exercised by the index-1000 case.
    #[rstest]
    #[case(0, 0, 999)]
    #[case(1, 1, 998)]
    #[case(2, 1000, u64::MAX)]
    fn static_array_exec(
        #[case] _i: i32,
        #[with(_i)] mut exec_ctx: TestContext,
        #[case] index: u64,
        #[case] expected: u64,
    ) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let arr = builder.add_load_value(
                    StaticArrayValue::try_new(
                        "exec_arr",
                        usize_t(),
                        (0..1000)
                            .map(|x| ConstUsize::new(999 - x).into())
                            .collect_vec(),
                    )
                    .unwrap(),
                );
                let index = builder.add_load_value(ConstUsize::new(index));
                let get_r = builder.add_static_array_get(usize_t(), arr, index).unwrap();
                // Unpack the option produced by `get`: case 0 is "none", case 1
                // carries the element.
                let [out] = {
                    let mut cond = builder
                        .conditional_builder(
                            ([type_row!(), usize_t().into()], get_r),
                            [],
                            usize_t().into(),
                        )
                        .unwrap();
                    {
                        let mut oob_case = cond.case_builder(0).unwrap();
                        let err = oob_case.add_load_value(ConstUsize::new(u64::MAX));
                        oob_case.finish_with_outputs([err]).unwrap();
                    }
                    {
                        let inbounds_case = cond.case_builder(1).unwrap();
                        let [out] = inbounds_case.input_wires_arr();
                        inbounds_case.finish_with_outputs([out]).unwrap();
                    }
                    cond.finish_sub_container().unwrap().outputs_arr()
                };
                builder.finish_hugr_with_outputs([out]).unwrap()
            });

        exec_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
                .add_float_extensions()
        });
        assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // `len` of an empty static array executes and returns 0.
    #[rstest]
    fn len_0_array(mut exec_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let arr = builder
                    .add_load_value(StaticArrayValue::try_new("empty", usize_t(), vec![]).unwrap());
                let len = builder.add_static_array_len(usize_t(), arr).unwrap();
                builder.finish_hugr_with_outputs([len]).unwrap()
            });

        exec_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
        });
        assert_eq!(0, exec_ctx.exec_hugr_u64(hugr, "main"));
    }

    // Emits `len` of an outer static array whose 10 elements are themselves
    // static arrays; inner array `i` holds `i` copies of the value `i`.
    #[rstest]
    fn emit_static_array_of_static_array(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(|ceb| {
            ceb.add_default_static_array_extensions()
                .add_default_prelude_extensions()
        });
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(ExtensionRegistry::new(vec![
                static_array::EXTENSION.to_owned(),
            ]))
            .finish(|mut builder| {
                let inner_arrs: Vec<Value> = (0..10)
                    .map(|i| {
                        StaticArrayValue::try_new(
                            "inner",
                            usize_t(),
                            vec![Value::from(ConstUsize::new(i)); i as usize],
                        )
                        .unwrap()
                        .into()
                    })
                    .collect_vec();
                let inner_arr_ty = inner_arrs[0].get_type();
                let outer_arr = builder.add_load_value(
                    StaticArrayValue::try_new("outer", inner_arr_ty.clone(), inner_arrs).unwrap(),
                );
                let len = builder
                    .add_static_array_len(inner_arr_ty, outer_arr)
                    .unwrap();
                builder.finish_hugr_with_outputs([len]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }
}