1use anyhow::{Ok, Result, bail};
2use hugr_core::{
3 HugrView, Node,
4 extension::simple_op::MakeExtensionOp as _,
5 ops::ExtensionOp,
6 std_extensions::collections::list::{self, ListOp, ListValue},
7 types::{SumType, Type, TypeArg},
8};
9use inkwell::values::FunctionValue;
10use inkwell::{
11 AddressSpace,
12 types::{BasicType, BasicTypeEnum, FunctionType},
13 values::{BasicValueEnum, PointerValue},
14};
15
16use crate::emit::func::{build_ok_or_else, build_option};
17use crate::{
18 custom::{CodegenExtension, CodegenExtsBuilder},
19 emit::{EmitOpArgs, emit_value, func::EmitFuncContext},
20 types::TypingSession,
21};
22
/// Runtime functions that implement operations on lists.
///
/// Each variant is lowered to an external function whose symbol name comes
/// from [`ListCodegen::rt_func_name`] and whose LLVM type comes from
/// [`ListRtFunc::signature`].
///
/// `Eq` is derived alongside `PartialEq` and `Hash` so the type can be used as
/// a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ListRtFunc {
    /// Create a list from `(capacity, elem_size, alignment, destructor)`.
    New,
    /// Append the element behind an `i8*` to the list.
    Push,
    /// Pop the last element into an `i8*` slot; returns a success flag.
    Pop,
    /// Read the element at an index into an `i8*` slot; returns a success flag.
    Get,
    /// Exchange the element at an index with the one behind an `i8*`; returns a
    /// success flag.
    Set,
    /// Insert the element behind an `i8*` at an index; returns a success flag.
    Insert,
    /// Return the number of elements as an `i64`.
    Length,
}
35
36impl ListRtFunc {
37 pub fn signature<'c>(
41 self,
42 ts: TypingSession<'c, '_>,
43 ccg: &(impl ListCodegen + 'c),
44 ) -> FunctionType<'c> {
45 let iwc = ts.iw_context();
46 match self {
47 ListRtFunc::New => ccg.list_type(ts).fn_type(
48 &[
49 iwc.i64_type().into(), iwc.i64_type().into(), iwc.i64_type().into(), iwc.i8_type().ptr_type(AddressSpace::default()).into(),
54 ],
55 false,
56 ),
57 ListRtFunc::Push => iwc.void_type().fn_type(
58 &[
59 ccg.list_type(ts).into(),
60 iwc.i8_type().ptr_type(AddressSpace::default()).into(),
61 ],
62 false,
63 ),
64 ListRtFunc::Pop => iwc.bool_type().fn_type(
65 &[
66 ccg.list_type(ts).into(),
67 iwc.i8_type().ptr_type(AddressSpace::default()).into(),
68 ],
69 false,
70 ),
71 ListRtFunc::Get | ListRtFunc::Set | ListRtFunc::Insert => iwc.bool_type().fn_type(
72 &[
73 ccg.list_type(ts).into(),
74 iwc.i64_type().into(),
75 iwc.i8_type().ptr_type(AddressSpace::default()).into(),
76 ],
77 false,
78 ),
79 ListRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false),
80 }
81 }
82
83 pub fn get_extern<'c, H: HugrView<Node = Node>>(
87 self,
88 ctx: &EmitFuncContext<'c, '_, H>,
89 ccg: &(impl ListCodegen + 'c),
90 ) -> Result<FunctionValue<'c>> {
91 ctx.get_extern_func(
92 ccg.rt_func_name(self),
93 self.signature(ctx.typing_session(), ccg),
94 )
95 }
96}
97
98impl From<ListOp> for ListRtFunc {
99 fn from(op: ListOp) -> Self {
100 match op {
101 ListOp::get => ListRtFunc::Get,
102 ListOp::set => ListRtFunc::Set,
103 ListOp::push => ListRtFunc::Push,
104 ListOp::pop => ListRtFunc::Pop,
105 ListOp::insert => ListRtFunc::Insert,
106 ListOp::length => ListRtFunc::Length,
107 _ => todo!(),
108 }
109 }
110}
111
112pub trait ListCodegen: Clone {
115 fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> {
117 session
118 .iw_context()
119 .i8_type()
120 .ptr_type(AddressSpace::default())
121 .into()
122 }
123
124 fn rt_func_name(&self, func: ListRtFunc) -> String {
126 match func {
127 ListRtFunc::New => "__rt__list__new",
128 ListRtFunc::Push => "__rt__list__push",
129 ListRtFunc::Pop => "__rt__list__pop",
130 ListRtFunc::Get => "__rt__list__get",
131 ListRtFunc::Set => "__rt__list__set",
132 ListRtFunc::Insert => "__rt__list__insert",
133 ListRtFunc::Length => "__rt__list__length",
134 }
135 .into()
136 }
137}
138
139#[derive(Default, Clone)]
142pub struct DefaultListCodegen;
143
144impl ListCodegen for DefaultListCodegen {}
145
146#[derive(Clone, Debug, Default)]
147pub struct ListCodegenExtension<CCG>(CCG);
148
149impl<CCG: ListCodegen> ListCodegenExtension<CCG> {
150 pub fn new(ccg: CCG) -> Self {
151 Self(ccg)
152 }
153}
154
155impl<CCG: ListCodegen> From<CCG> for ListCodegenExtension<CCG> {
156 fn from(ccg: CCG) -> Self {
157 Self::new(ccg)
158 }
159}
160
161impl<CCG: ListCodegen> CodegenExtension for ListCodegenExtension<CCG> {
162 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
163 self,
164 builder: CodegenExtsBuilder<'a, H>,
165 ) -> CodegenExtsBuilder<'a, H>
166 where
167 Self: 'a,
168 {
169 builder
170 .custom_type((list::EXTENSION_ID, list::LIST_TYPENAME), {
171 let ccg = self.0.clone();
172 move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum())
173 })
174 .custom_const::<ListValue>({
175 let ccg = self.0.clone();
176 move |ctx, k| emit_list_value(ctx, &ccg, k)
177 })
178 .simple_extension_op::<ListOp>(move |ctx, args, op| {
179 emit_list_op(ctx, &self.0, args, op)
180 })
181 }
182}
183
184impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
185 #[must_use]
188 pub fn add_default_list_extensions(self) -> Self {
189 self.add_list_extensions(DefaultListCodegen)
190 }
191
192 pub fn add_list_extensions(self, ccg: impl ListCodegen + 'a) -> Self {
195 self.add_extension(ListCodegenExtension::from(ccg))
196 }
197}
198
199fn emit_list_op<'c, H: HugrView<Node = Node>>(
200 ctx: &mut EmitFuncContext<'c, '_, H>,
201 ccg: &(impl ListCodegen + 'c),
202 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
203 op: ListOp,
204) -> Result<()> {
205 let hugr_elem_ty = match args.node().args() {
206 [TypeArg::Runtime(ty)] => ty.clone(),
207 _ => {
208 bail!("Collections: invalid type args for list op");
209 }
210 };
211 let elem_ty = ctx.llvm_type(&hugr_elem_ty)?;
212 let func = ListRtFunc::get_extern(op.into(), ctx, ccg)?;
213 match op {
214 ListOp::push => {
215 let [list, elem] = args.inputs.try_into().unwrap();
216 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
217 ctx.builder()
218 .build_call(func, &[list.into(), elem_ptr.into()], "")?;
219 args.outputs.finish(ctx.builder(), vec![list])?;
220 }
221 ListOp::pop => {
222 let [list] = args.inputs.try_into().unwrap();
223 let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
224 let ok = ctx
225 .builder()
226 .build_call(func, &[list.into(), out_ptr.into()], "")?
227 .try_as_basic_value()
228 .unwrap_left()
229 .into_int_value();
230 let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
231 let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
232 args.outputs.finish(ctx.builder(), vec![list, elem_opt])?;
233 }
234 ListOp::get => {
235 let [list, idx] = args.inputs.try_into().unwrap();
236 let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
237 let ok = ctx
238 .builder()
239 .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")?
240 .try_as_basic_value()
241 .unwrap_left()
242 .into_int_value();
243 let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
244 let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
245 args.outputs.finish(ctx.builder(), vec![elem_opt])?;
246 }
247 ListOp::set => {
248 let [list, idx, elem] = args.inputs.try_into().unwrap();
249 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
250 let ok = ctx
251 .builder()
252 .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
253 .try_as_basic_value()
254 .unwrap_left()
255 .into_int_value();
256 let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?;
257 let ok_or =
258 build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?;
259 args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
260 }
261 ListOp::insert => {
262 let [list, idx, elem] = args.inputs.try_into().unwrap();
263 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
264 let ok = ctx
265 .builder()
266 .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
267 .try_as_basic_value()
268 .unwrap_left()
269 .into_int_value();
270 let unit =
271 ctx.llvm_sum_type(SumType::new_unary(1))?
272 .build_tag(ctx.builder(), 0, vec![])?;
273 let ok_or = build_ok_or_else(ctx, ok, unit.into(), Type::UNIT, elem, hugr_elem_ty)?;
274 args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
275 }
276 ListOp::length => {
277 let [list] = args.inputs.try_into().unwrap();
278 let length = ctx
279 .builder()
280 .build_call(func, &[list.into()], "")?
281 .try_as_basic_value()
282 .unwrap_left()
283 .into_int_value();
284 args.outputs
285 .finish(ctx.builder(), vec![list, length.into()])?;
286 }
287 _ => bail!("Collections: unimplemented op: {}", op.op_id()),
288 }
289 Ok(())
290}
291
292fn emit_list_value<'c, H: HugrView<Node = Node>>(
293 ctx: &mut EmitFuncContext<'c, '_, H>,
294 ccg: &(impl ListCodegen + 'c),
295 val: &ListValue,
296) -> Result<BasicValueEnum<'c>> {
297 let elem_ty = ctx.llvm_type(val.get_element_type())?;
298 let iwc = ctx.typing_session().iw_context();
299 let capacity = iwc
300 .i64_type()
301 .const_int(val.get_contents().len() as u64, false);
302 let elem_size = elem_ty.size_of().unwrap();
303 let alignment = iwc.i64_type().const_int(8, false);
304 let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null();
306 let list = ctx
307 .builder()
308 .build_call(
309 ListRtFunc::New.get_extern(ctx, ccg)?,
310 &[
311 capacity.into(),
312 elem_size.into(),
313 alignment.into(),
314 destructor.into(),
315 ],
316 "",
317 )?
318 .try_as_basic_value()
319 .unwrap_left();
320 let rt_push = ListRtFunc::Push.get_extern(ctx, ccg)?;
322 for v in val.get_contents() {
323 let elem = emit_value(ctx, v)?;
324 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
325 ctx.builder()
326 .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?;
327 }
328 Ok(list)
329}
330
331fn build_alloca_i8_ptr<'c, H: HugrView<Node = Node>>(
337 ctx: &mut EmitFuncContext<'c, '_, H>,
338 ty: BasicTypeEnum<'c>,
339 value: Option<BasicValueEnum<'c>>,
340) -> Result<PointerValue<'c>> {
341 let builder = ctx.builder();
342 let ptr = builder.build_alloca(ty, "")?;
343 if let Some(val) = value {
344 builder.build_store(ptr, val)?;
345 }
346 let i8_ptr = builder.build_pointer_cast(
347 ptr,
348 ctx.iw_context().i8_type().ptr_type(AddressSpace::default()),
349 "",
350 )?;
351 Ok(i8_ptr)
352}
353
354fn build_load_i8_ptr<'c, H: HugrView<Node = Node>>(
356 ctx: &mut EmitFuncContext<'c, '_, H>,
357 i8_ptr: PointerValue<'c>,
358 ty: BasicTypeEnum<'c>,
359) -> Result<BasicValueEnum<'c>> {
360 let builder = ctx.builder();
361 let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?;
362 let val = builder.build_load(ptr, "")?;
363 Ok(val)
364}
365
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{Dataflow, DataflowHugr},
        extension::{
            ExtensionRegistry,
            prelude::{self, ConstUsize, qb_t, usize_t},
        },
        ops::{DataflowOpTrait, Value},
        std_extensions::collections::list::{self, ListOp, ListValue, list_type},
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        custom::CodegenExtsBuilder,
        emit::test::SimpleHugrConfig,
        test::{TestContext, llvm_ctx},
    };

    /// Checks the IR emitted for each list op, instantiated with the qubit
    /// element type `qb_t()`.
    #[rstest]
    #[case::push(ListOp::push)]
    #[case::pop(ListOp::pop)]
    #[case::get(ListOp::get)]
    #[case::set(ListOp::set)]
    #[case::insert(ListOp::insert)]
    #[case::length(ListOp::length)]
    fn test_list_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) {
        use hugr_core::extension::simple_op::MakeExtensionOp as _;

        let ext_op = list::EXTENSION
            .instantiate_extension_op(op.op_id().as_ref(), [qb_t().into()])
            .unwrap();
        let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]);
        es.validate().unwrap();
        // Build a HUGR whose inputs/outputs are exactly the op's signature,
        // wiring the inputs straight through the single op to the outputs.
        let hugr = SimpleHugrConfig::new()
            .with_ins(ext_op.signature().input().clone())
            .with_outs(ext_op.signature().output().clone())
            .with_extensions(es)
            .finish(|mut hugr_builder| {
                let outputs = hugr_builder
                    .add_dataflow_op(ext_op, hugr_builder.input_wires())
                    .unwrap()
                    .outputs();
                hugr_builder.finish_hugr_with_outputs(outputs).unwrap()
            });
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions);
        // The op name is passed to `check_emission!` as the label for this case.
        check_emission!(op.op_id().as_str(), hugr, llvm_ctx);
    }

    /// Checks the IR emitted for loading a constant `usize` list `[1, 2, 3]`.
    // NOTE(review): "emmission" is misspelled, but renaming the function may
    // change the associated snapshot name; check before renaming.
    #[rstest]
    fn test_const_list_emmission(mut llvm_ctx: TestContext) {
        let elem_ty = usize_t();
        let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i)));
        let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]);
        es.validate().unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![])
            .with_outs(vec![list_type(elem_ty.clone())])
            .with_extensions(es)
            .finish(|mut hugr_builder| {
                let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents));
                hugr_builder.finish_hugr_with_outputs(vec![list]).unwrap()
            });

        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions);
        check_emission!("const", hugr, llvm_ctx);
    }
}