1use anyhow::{bail, Ok, Result};
2use hugr_core::{
3 ops::{ExtensionOp, NamedOp},
4 std_extensions::collections::list::{self, ListOp, ListValue},
5 types::{SumType, Type, TypeArg},
6 HugrView, Node,
7};
8use inkwell::values::FunctionValue;
9use inkwell::{
10 types::{BasicType, BasicTypeEnum, FunctionType},
11 values::{BasicValueEnum, PointerValue},
12 AddressSpace,
13};
14
15use crate::emit::func::{build_ok_or_else, build_option};
16use crate::{
17 custom::{CodegenExtension, CodegenExtsBuilder},
18 emit::{emit_value, func::EmitFuncContext, EmitOpArgs},
19 types::TypingSession,
20};
21
/// Runtime functions that list values and list operations are lowered to.
///
/// Each variant identifies one external function. Its LLVM signature is produced by
/// `ListRtFunc::signature` and its symbol name by `ListCodegen::rt_func_name`.
///
/// `Eq` is derived alongside `PartialEq` and `Hash` so the enum can be used as a key in
/// hash-based collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ListRtFunc {
    /// Constructs a list; called when a constant list value is emitted.
    New,
    /// Runtime function backing `ListOp::push`.
    Push,
    /// Runtime function backing `ListOp::pop`.
    Pop,
    /// Runtime function backing `ListOp::get`.
    Get,
    /// Runtime function backing `ListOp::set`.
    Set,
    /// Runtime function backing `ListOp::insert`.
    Insert,
    /// Runtime function backing `ListOp::length`.
    Length,
}
34
35impl ListRtFunc {
36 pub fn signature<'c>(
40 self,
41 ts: TypingSession<'c, '_>,
42 ccg: &(impl ListCodegen + 'c),
43 ) -> FunctionType<'c> {
44 let iwc = ts.iw_context();
45 match self {
46 ListRtFunc::New => ccg.list_type(ts).fn_type(
47 &[
48 iwc.i64_type().into(), iwc.i64_type().into(), iwc.i64_type().into(), iwc.i8_type().ptr_type(AddressSpace::default()).into(),
53 ],
54 false,
55 ),
56 ListRtFunc::Push => iwc.void_type().fn_type(
57 &[
58 ccg.list_type(ts).into(),
59 iwc.i8_type().ptr_type(AddressSpace::default()).into(),
60 ],
61 false,
62 ),
63 ListRtFunc::Pop => iwc.bool_type().fn_type(
64 &[
65 ccg.list_type(ts).into(),
66 iwc.i8_type().ptr_type(AddressSpace::default()).into(),
67 ],
68 false,
69 ),
70 ListRtFunc::Get | ListRtFunc::Set | ListRtFunc::Insert => iwc.bool_type().fn_type(
71 &[
72 ccg.list_type(ts).into(),
73 iwc.i64_type().into(),
74 iwc.i8_type().ptr_type(AddressSpace::default()).into(),
75 ],
76 false,
77 ),
78 ListRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false),
79 }
80 }
81
82 pub fn get_extern<'c, H: HugrView<Node = Node>>(
86 self,
87 ctx: &EmitFuncContext<'c, '_, H>,
88 ccg: &(impl ListCodegen + 'c),
89 ) -> Result<FunctionValue<'c>> {
90 ctx.get_extern_func(
91 ccg.rt_func_name(self),
92 self.signature(ctx.typing_session(), ccg),
93 )
94 }
95}
96
97impl From<ListOp> for ListRtFunc {
98 fn from(op: ListOp) -> Self {
99 match op {
100 ListOp::get => ListRtFunc::Get,
101 ListOp::set => ListRtFunc::Set,
102 ListOp::push => ListRtFunc::Push,
103 ListOp::pop => ListRtFunc::Pop,
104 ListOp::insert => ListRtFunc::Insert,
105 ListOp::length => ListRtFunc::Length,
106 _ => todo!(),
107 }
108 }
109}
110
111pub trait ListCodegen: Clone {
114 fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> {
116 session
117 .iw_context()
118 .i8_type()
119 .ptr_type(AddressSpace::default())
120 .into()
121 }
122
123 fn rt_func_name(&self, func: ListRtFunc) -> String {
125 match func {
126 ListRtFunc::New => "__rt__list__new",
127 ListRtFunc::Push => "__rt__list__push",
128 ListRtFunc::Pop => "__rt__list__pop",
129 ListRtFunc::Get => "__rt__list__get",
130 ListRtFunc::Set => "__rt__list__set",
131 ListRtFunc::Insert => "__rt__list__insert",
132 ListRtFunc::Length => "__rt__list__length",
133 }
134 .into()
135 }
136}
137
/// A list code generator with no customisation of its own.
///
/// `Debug` is derived so that wrappers which derive `Debug` over their code generator
/// (such as `ListCodegenExtension<DefaultListCodegen>`) are themselves `Debug`.
#[derive(Default, Clone, Debug)]
pub struct DefaultListCodegen;
142
// Empty impl: `DefaultListCodegen` uses the default method bodies of `ListCodegen` unchanged.
impl ListCodegen for DefaultListCodegen {}
144
145#[derive(Clone, Debug, Default)]
146pub struct ListCodegenExtension<CCG>(CCG);
147
148impl<CCG: ListCodegen> ListCodegenExtension<CCG> {
149 pub fn new(ccg: CCG) -> Self {
150 Self(ccg)
151 }
152}
153
154impl<CCG: ListCodegen> From<CCG> for ListCodegenExtension<CCG> {
155 fn from(ccg: CCG) -> Self {
156 Self::new(ccg)
157 }
158}
159
160impl<CCG: ListCodegen> CodegenExtension for ListCodegenExtension<CCG> {
161 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
162 self,
163 builder: CodegenExtsBuilder<'a, H>,
164 ) -> CodegenExtsBuilder<'a, H>
165 where
166 Self: 'a,
167 {
168 builder
169 .custom_type((list::EXTENSION_ID, list::LIST_TYPENAME), {
170 let ccg = self.0.clone();
171 move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum())
172 })
173 .custom_const::<ListValue>({
174 let ccg = self.0.clone();
175 move |ctx, k| emit_list_value(ctx, &ccg, k)
176 })
177 .simple_extension_op::<ListOp>(move |ctx, args, op| {
178 emit_list_op(ctx, &self.0, args, op)
179 })
180 }
181}
182
183impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
184 pub fn add_default_list_extensions(self) -> Self {
187 self.add_list_extensions(DefaultListCodegen)
188 }
189
190 pub fn add_list_extensions(self, ccg: impl ListCodegen + 'a) -> Self {
193 self.add_extension(ListCodegenExtension::from(ccg))
194 }
195}
196
197fn emit_list_op<'c, H: HugrView<Node = Node>>(
198 ctx: &mut EmitFuncContext<'c, '_, H>,
199 ccg: &(impl ListCodegen + 'c),
200 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
201 op: ListOp,
202) -> Result<()> {
203 let hugr_elem_ty = match args.node().args() {
204 [TypeArg::Type { ty }] => ty.clone(),
205 _ => {
206 bail!("Collections: invalid type args for list op");
207 }
208 };
209 let elem_ty = ctx.llvm_type(&hugr_elem_ty)?;
210 let func = ListRtFunc::get_extern(op.into(), ctx, ccg)?;
211 match op {
212 ListOp::push => {
213 let [list, elem] = args.inputs.try_into().unwrap();
214 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
215 ctx.builder()
216 .build_call(func, &[list.into(), elem_ptr.into()], "")?;
217 args.outputs.finish(ctx.builder(), vec![list])?;
218 }
219 ListOp::pop => {
220 let [list] = args.inputs.try_into().unwrap();
221 let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
222 let ok = ctx
223 .builder()
224 .build_call(func, &[list.into(), out_ptr.into()], "")?
225 .try_as_basic_value()
226 .unwrap_left()
227 .into_int_value();
228 let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
229 let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
230 args.outputs.finish(ctx.builder(), vec![list, elem_opt])?;
231 }
232 ListOp::get => {
233 let [list, idx] = args.inputs.try_into().unwrap();
234 let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
235 let ok = ctx
236 .builder()
237 .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")?
238 .try_as_basic_value()
239 .unwrap_left()
240 .into_int_value();
241 let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
242 let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
243 args.outputs.finish(ctx.builder(), vec![elem_opt])?;
244 }
245 ListOp::set => {
246 let [list, idx, elem] = args.inputs.try_into().unwrap();
247 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
248 let ok = ctx
249 .builder()
250 .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
251 .try_as_basic_value()
252 .unwrap_left()
253 .into_int_value();
254 let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?;
255 let ok_or =
256 build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?;
257 args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
258 }
259 ListOp::insert => {
260 let [list, idx, elem] = args.inputs.try_into().unwrap();
261 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
262 let ok = ctx
263 .builder()
264 .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
265 .try_as_basic_value()
266 .unwrap_left()
267 .into_int_value();
268 let unit =
269 ctx.llvm_sum_type(SumType::new_unary(1))?
270 .build_tag(ctx.builder(), 0, vec![])?;
271 let ok_or = build_ok_or_else(ctx, ok, unit.into(), Type::UNIT, elem, hugr_elem_ty)?;
272 args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
273 }
274 ListOp::length => {
275 let [list] = args.inputs.try_into().unwrap();
276 let length = ctx
277 .builder()
278 .build_call(func, &[list.into()], "")?
279 .try_as_basic_value()
280 .unwrap_left()
281 .into_int_value();
282 args.outputs
283 .finish(ctx.builder(), vec![list, length.into()])?;
284 }
285 _ => bail!("Collections: unimplemented op: {}", op.name()),
286 }
287 Ok(())
288}
289
290fn emit_list_value<'c, H: HugrView<Node = Node>>(
291 ctx: &mut EmitFuncContext<'c, '_, H>,
292 ccg: &(impl ListCodegen + 'c),
293 val: &ListValue,
294) -> Result<BasicValueEnum<'c>> {
295 let elem_ty = ctx.llvm_type(val.get_element_type())?;
296 let iwc = ctx.typing_session().iw_context();
297 let capacity = iwc
298 .i64_type()
299 .const_int(val.get_contents().len() as u64, false);
300 let elem_size = elem_ty.size_of().unwrap();
301 let alignment = iwc.i64_type().const_int(8, false);
302 let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null();
304 let list = ctx
305 .builder()
306 .build_call(
307 ListRtFunc::New.get_extern(ctx, ccg)?,
308 &[
309 capacity.into(),
310 elem_size.into(),
311 alignment.into(),
312 destructor.into(),
313 ],
314 "",
315 )?
316 .try_as_basic_value()
317 .unwrap_left();
318 let rt_push = ListRtFunc::Push.get_extern(ctx, ccg)?;
320 for v in val.get_contents() {
321 let elem = emit_value(ctx, v)?;
322 let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
323 ctx.builder()
324 .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?;
325 }
326 Ok(list)
327}
328
329fn build_alloca_i8_ptr<'c, H: HugrView<Node = Node>>(
335 ctx: &mut EmitFuncContext<'c, '_, H>,
336 ty: BasicTypeEnum<'c>,
337 value: Option<BasicValueEnum<'c>>,
338) -> Result<PointerValue<'c>> {
339 let builder = ctx.builder();
340 let ptr = builder.build_alloca(ty, "")?;
341 if let Some(val) = value {
342 builder.build_store(ptr, val)?;
343 }
344 let i8_ptr = builder.build_pointer_cast(
345 ptr,
346 ctx.iw_context().i8_type().ptr_type(AddressSpace::default()),
347 "",
348 )?;
349 Ok(i8_ptr)
350}
351
352fn build_load_i8_ptr<'c, H: HugrView<Node = Node>>(
354 ctx: &mut EmitFuncContext<'c, '_, H>,
355 i8_ptr: PointerValue<'c>,
356 ty: BasicTypeEnum<'c>,
357) -> Result<BasicValueEnum<'c>> {
358 let builder = ctx.builder();
359 let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?;
360 let val = builder.build_load(ptr, "")?;
361 Ok(val)
362}
363
// Emission tests for the list extension: each test builds a small HUGR and passes it to
// `check_emission!` together with a test context that has the default prelude and list
// extensions registered.
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        extension::{
            prelude::{self, qb_t, usize_t, ConstUsize},
            ExtensionRegistry,
        },
        ops::{DataflowOpTrait, NamedOp, Value},
        std_extensions::collections::list::{self, list_type, ListOp, ListValue},
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        custom::CodegenExtsBuilder,
        emit::test::SimpleHugrConfig,
        test::{llvm_ctx, TestContext},
    };

    // One case per supported list op. The op is instantiated with `qb_t()` as the element
    // type and wrapped in a HUGR whose inputs and outputs mirror the op's own signature.
    #[rstest]
    #[case::push(ListOp::push)]
    #[case::pop(ListOp::pop)]
    #[case::get(ListOp::get)]
    #[case::set(ListOp::set)]
    #[case::insert(ListOp::insert)]
    #[case::length(ListOp::length)]
    fn test_list_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) {
        let ext_op = list::EXTENSION
            .instantiate_extension_op(op.name().as_ref(), [qb_t().into()])
            .unwrap();
        let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]);
        es.validate().unwrap();
        let hugr = SimpleHugrConfig::new()
            .with_ins(ext_op.signature().input().clone())
            .with_outs(ext_op.signature().output().clone())
            .with_extensions(es)
            .finish(|mut hugr_builder| {
                // Feed the function inputs straight into the op and return its outputs.
                let outputs = hugr_builder
                    .add_dataflow_op(ext_op, hugr_builder.input_wires())
                    .unwrap()
                    .outputs();
                hugr_builder.finish_with_outputs(outputs).unwrap()
            });
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions);
        // The op name is passed as the first argument of `check_emission!`.
        check_emission!(op.name().as_str(), hugr, llvm_ctx);
    }

    // Emits a constant list of three `usize` values (1, 2, 3) loaded by a function with
    // no inputs and a single list output.
    // NOTE(review): `emmission` in the function name is misspelled; confirm whether
    // `check_emission!` derives snapshot names from the test name before renaming.
    #[rstest]
    fn test_const_list_emmission(mut llvm_ctx: TestContext) {
        let elem_ty = usize_t();
        let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i)));
        let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]);
        es.validate().unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![])
            .with_outs(vec![list_type(elem_ty.clone())])
            .with_extensions(es)
            .finish(|mut hugr_builder| {
                let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents));
                hugr_builder.finish_with_outputs(vec![list]).unwrap()
            });

        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions);
        check_emission!("const", hugr, llvm_ctx);
    }
}