hugr_llvm/extension/collections/list.rs

1use anyhow::{bail, Ok, Result};
2use hugr_core::{
3    ops::{ExtensionOp, NamedOp},
4    std_extensions::collections::list::{self, ListOp, ListValue},
5    types::{SumType, Type, TypeArg},
6    HugrView, Node,
7};
8use inkwell::values::FunctionValue;
9use inkwell::{
10    types::{BasicType, BasicTypeEnum, FunctionType},
11    values::{BasicValueEnum, PointerValue},
12    AddressSpace,
13};
14
15use crate::emit::func::{build_ok_or_else, build_option};
16use crate::{
17    custom::{CodegenExtension, CodegenExtsBuilder},
18    emit::{emit_value, func::EmitFuncContext, EmitOpArgs},
19    types::TypingSession,
20};
21
/// Runtime functions that implement operations on lists.
///
/// Each variant names one extern function called by generated code. Its LLVM
/// signature is given by [ListRtFunc::signature] and its symbol name by
/// [ListCodegen::rt_func_name]. Element values are exchanged with these
/// functions through `i8*` pointers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ListRtFunc {
    /// Create a list from a capacity, an element size in bytes, an element
    /// alignment and an element-destructor pointer.
    New,
    /// Push the element stored behind an `i8*` onto a list.
    Push,
    /// Pop an element from a list into an out-pointer; the returned `bool` is
    /// used by the lowering as a success flag.
    Pop,
    /// Read the element at an index into an out-pointer; returns a success flag.
    Get,
    /// Store the element behind a pointer at an index; returns a success flag.
    /// The lowering reads the previous element back from the same pointer.
    Set,
    /// Insert the element behind a pointer at an index; returns a success flag.
    Insert,
    /// Return the number of elements in a list as an `i64`.
    Length,
}
34
35impl ListRtFunc {
36    /// The signature of a given [ListRtFunc].
37    ///
38    /// Requires a [ListCodegen] to determine the type of lists.
39    pub fn signature<'c>(
40        self,
41        ts: TypingSession<'c, '_>,
42        ccg: &(impl ListCodegen + 'c),
43    ) -> FunctionType<'c> {
44        let iwc = ts.iw_context();
45        match self {
46            ListRtFunc::New => ccg.list_type(ts).fn_type(
47                &[
48                    iwc.i64_type().into(), // Capacity
49                    iwc.i64_type().into(), // Single element size in bytes
50                    iwc.i64_type().into(), // Element alignment
51                    // Pointer to element destructor
52                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
53                ],
54                false,
55            ),
56            ListRtFunc::Push => iwc.void_type().fn_type(
57                &[
58                    ccg.list_type(ts).into(),
59                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
60                ],
61                false,
62            ),
63            ListRtFunc::Pop => iwc.bool_type().fn_type(
64                &[
65                    ccg.list_type(ts).into(),
66                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
67                ],
68                false,
69            ),
70            ListRtFunc::Get | ListRtFunc::Set | ListRtFunc::Insert => iwc.bool_type().fn_type(
71                &[
72                    ccg.list_type(ts).into(),
73                    iwc.i64_type().into(),
74                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
75                ],
76                false,
77            ),
78            ListRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false),
79        }
80    }
81
82    /// Returns the extern function corresponding to this [ListRtFunc].
83    ///
84    /// Requires a [ListCodegen] to determine the function signature.
85    pub fn get_extern<'c, H: HugrView<Node = Node>>(
86        self,
87        ctx: &EmitFuncContext<'c, '_, H>,
88        ccg: &(impl ListCodegen + 'c),
89    ) -> Result<FunctionValue<'c>> {
90        ctx.get_extern_func(
91            ccg.rt_func_name(self),
92            self.signature(ctx.typing_session(), ccg),
93        )
94    }
95}
96
97impl From<ListOp> for ListRtFunc {
98    fn from(op: ListOp) -> Self {
99        match op {
100            ListOp::get => ListRtFunc::Get,
101            ListOp::set => ListRtFunc::Set,
102            ListOp::push => ListRtFunc::Push,
103            ListOp::pop => ListRtFunc::Pop,
104            ListOp::insert => ListRtFunc::Insert,
105            ListOp::length => ListRtFunc::Length,
106            _ => todo!(),
107        }
108    }
109}
110
111/// A helper trait for customising the lowering of [hugr_core::std_extensions::collections::list]
112/// types, [hugr_core::ops::constant::CustomConst]s, and ops.
113pub trait ListCodegen: Clone {
114    /// Return the llvm type of [hugr_core::std_extensions::collections::list::LIST_TYPENAME].
115    fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> {
116        session
117            .iw_context()
118            .i8_type()
119            .ptr_type(AddressSpace::default())
120            .into()
121    }
122
123    /// Return the name of a given [ListRtFunc].
124    fn rt_func_name(&self, func: ListRtFunc) -> String {
125        match func {
126            ListRtFunc::New => "__rt__list__new",
127            ListRtFunc::Push => "__rt__list__push",
128            ListRtFunc::Pop => "__rt__list__pop",
129            ListRtFunc::Get => "__rt__list__get",
130            ListRtFunc::Set => "__rt__list__set",
131            ListRtFunc::Insert => "__rt__list__insert",
132            ListRtFunc::Length => "__rt__list__length",
133        }
134        .into()
135    }
136}
137
138/// A trivial implementation of [ListCodegen] which passes all methods
139/// through to their default implementations.
140#[derive(Default, Clone)]
141pub struct DefaultListCodegen;
142
143impl ListCodegen for DefaultListCodegen {}
144
145#[derive(Clone, Debug, Default)]
146pub struct ListCodegenExtension<CCG>(CCG);
147
148impl<CCG: ListCodegen> ListCodegenExtension<CCG> {
149    pub fn new(ccg: CCG) -> Self {
150        Self(ccg)
151    }
152}
153
154impl<CCG: ListCodegen> From<CCG> for ListCodegenExtension<CCG> {
155    fn from(ccg: CCG) -> Self {
156        Self::new(ccg)
157    }
158}
159
160impl<CCG: ListCodegen> CodegenExtension for ListCodegenExtension<CCG> {
161    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
162        self,
163        builder: CodegenExtsBuilder<'a, H>,
164    ) -> CodegenExtsBuilder<'a, H>
165    where
166        Self: 'a,
167    {
168        builder
169            .custom_type((list::EXTENSION_ID, list::LIST_TYPENAME), {
170                let ccg = self.0.clone();
171                move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum())
172            })
173            .custom_const::<ListValue>({
174                let ccg = self.0.clone();
175                move |ctx, k| emit_list_value(ctx, &ccg, k)
176            })
177            .simple_extension_op::<ListOp>(move |ctx, args, op| {
178                emit_list_op(ctx, &self.0, args, op)
179            })
180    }
181}
182
183impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
184    /// Add a [ListCodegenExtension] to the given [CodegenExtsBuilder] using `ccg`
185    /// as the implementation.
186    pub fn add_default_list_extensions(self) -> Self {
187        self.add_list_extensions(DefaultListCodegen)
188    }
189
190    /// Add a [ListCodegenExtension] to the given [CodegenExtsBuilder] using
191    /// [DefaultListCodegen] as the implementation.
192    pub fn add_list_extensions(self, ccg: impl ListCodegen + 'a) -> Self {
193        self.add_extension(ListCodegenExtension::from(ccg))
194    }
195}
196
197fn emit_list_op<'c, H: HugrView<Node = Node>>(
198    ctx: &mut EmitFuncContext<'c, '_, H>,
199    ccg: &(impl ListCodegen + 'c),
200    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
201    op: ListOp,
202) -> Result<()> {
203    let hugr_elem_ty = match args.node().args() {
204        [TypeArg::Type { ty }] => ty.clone(),
205        _ => {
206            bail!("Collections: invalid type args for list op");
207        }
208    };
209    let elem_ty = ctx.llvm_type(&hugr_elem_ty)?;
210    let func = ListRtFunc::get_extern(op.into(), ctx, ccg)?;
211    match op {
212        ListOp::push => {
213            let [list, elem] = args.inputs.try_into().unwrap();
214            let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
215            ctx.builder()
216                .build_call(func, &[list.into(), elem_ptr.into()], "")?;
217            args.outputs.finish(ctx.builder(), vec![list])?;
218        }
219        ListOp::pop => {
220            let [list] = args.inputs.try_into().unwrap();
221            let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
222            let ok = ctx
223                .builder()
224                .build_call(func, &[list.into(), out_ptr.into()], "")?
225                .try_as_basic_value()
226                .unwrap_left()
227                .into_int_value();
228            let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
229            let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
230            args.outputs.finish(ctx.builder(), vec![list, elem_opt])?;
231        }
232        ListOp::get => {
233            let [list, idx] = args.inputs.try_into().unwrap();
234            let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
235            let ok = ctx
236                .builder()
237                .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")?
238                .try_as_basic_value()
239                .unwrap_left()
240                .into_int_value();
241            let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
242            let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
243            args.outputs.finish(ctx.builder(), vec![elem_opt])?;
244        }
245        ListOp::set => {
246            let [list, idx, elem] = args.inputs.try_into().unwrap();
247            let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
248            let ok = ctx
249                .builder()
250                .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
251                .try_as_basic_value()
252                .unwrap_left()
253                .into_int_value();
254            let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?;
255            let ok_or =
256                build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?;
257            args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
258        }
259        ListOp::insert => {
260            let [list, idx, elem] = args.inputs.try_into().unwrap();
261            let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
262            let ok = ctx
263                .builder()
264                .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
265                .try_as_basic_value()
266                .unwrap_left()
267                .into_int_value();
268            let unit =
269                ctx.llvm_sum_type(SumType::new_unary(1))?
270                    .build_tag(ctx.builder(), 0, vec![])?;
271            let ok_or = build_ok_or_else(ctx, ok, unit.into(), Type::UNIT, elem, hugr_elem_ty)?;
272            args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
273        }
274        ListOp::length => {
275            let [list] = args.inputs.try_into().unwrap();
276            let length = ctx
277                .builder()
278                .build_call(func, &[list.into()], "")?
279                .try_as_basic_value()
280                .unwrap_left()
281                .into_int_value();
282            args.outputs
283                .finish(ctx.builder(), vec![list, length.into()])?;
284        }
285        _ => bail!("Collections: unimplemented op: {}", op.name()),
286    }
287    Ok(())
288}
289
290fn emit_list_value<'c, H: HugrView<Node = Node>>(
291    ctx: &mut EmitFuncContext<'c, '_, H>,
292    ccg: &(impl ListCodegen + 'c),
293    val: &ListValue,
294) -> Result<BasicValueEnum<'c>> {
295    let elem_ty = ctx.llvm_type(val.get_element_type())?;
296    let iwc = ctx.typing_session().iw_context();
297    let capacity = iwc
298        .i64_type()
299        .const_int(val.get_contents().len() as u64, false);
300    let elem_size = elem_ty.size_of().unwrap();
301    let alignment = iwc.i64_type().const_int(8, false);
302    // TODO: Lookup destructor for elem_ty
303    let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null();
304    let list = ctx
305        .builder()
306        .build_call(
307            ListRtFunc::New.get_extern(ctx, ccg)?,
308            &[
309                capacity.into(),
310                elem_size.into(),
311                alignment.into(),
312                destructor.into(),
313            ],
314            "",
315        )?
316        .try_as_basic_value()
317        .unwrap_left();
318    // Push elements onto the list
319    let rt_push = ListRtFunc::Push.get_extern(ctx, ccg)?;
320    for v in val.get_contents() {
321        let elem = emit_value(ctx, v)?;
322        let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
323        ctx.builder()
324            .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?;
325    }
326    Ok(list)
327}
328
329/// Helper function to allocate space on the stack for a given type.
330///
331/// Optionally also stores a value at that location.
332///
333/// Returns an i8 pointer to the allocated memory.
334fn build_alloca_i8_ptr<'c, H: HugrView<Node = Node>>(
335    ctx: &mut EmitFuncContext<'c, '_, H>,
336    ty: BasicTypeEnum<'c>,
337    value: Option<BasicValueEnum<'c>>,
338) -> Result<PointerValue<'c>> {
339    let builder = ctx.builder();
340    let ptr = builder.build_alloca(ty, "")?;
341    if let Some(val) = value {
342        builder.build_store(ptr, val)?;
343    }
344    let i8_ptr = builder.build_pointer_cast(
345        ptr,
346        ctx.iw_context().i8_type().ptr_type(AddressSpace::default()),
347        "",
348    )?;
349    Ok(i8_ptr)
350}
351
352/// Helper function to load a value from an i8 pointer.
353fn build_load_i8_ptr<'c, H: HugrView<Node = Node>>(
354    ctx: &mut EmitFuncContext<'c, '_, H>,
355    i8_ptr: PointerValue<'c>,
356    ty: BasicTypeEnum<'c>,
357) -> Result<BasicValueEnum<'c>> {
358    let builder = ctx.builder();
359    let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?;
360    let val = builder.build_load(ptr, "")?;
361    Ok(val)
362}
363
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        extension::{
            prelude::{self, qb_t, usize_t, ConstUsize},
            ExtensionRegistry,
        },
        ops::{DataflowOpTrait, NamedOp, Value},
        std_extensions::collections::list::{self, list_type, ListOp, ListValue},
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        custom::CodegenExtsBuilder,
        emit::test::SimpleHugrConfig,
        test::{llvm_ctx, TestContext},
    };

    /// Returns a validated registry holding the list extension and the prelude,
    /// which is everything the HUGRs built in this module depend on.
    fn list_registry() -> ExtensionRegistry {
        let registry =
            ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]);
        registry.validate().unwrap();
        registry
    }

    /// Registers the default prelude and list codegen extensions with `ctx`.
    fn add_list_codegen(ctx: &mut TestContext) {
        ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions);
    }

    /// Lowers a HUGR consisting of one list op (instantiated at the qubit type)
    /// and snapshots the resulting module under the op's name.
    #[rstest]
    #[case::push(ListOp::push)]
    #[case::pop(ListOp::pop)]
    #[case::get(ListOp::get)]
    #[case::set(ListOp::set)]
    #[case::insert(ListOp::insert)]
    #[case::length(ListOp::length)]
    fn test_list_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) {
        let ext_op = list::EXTENSION
            .instantiate_extension_op(op.name().as_ref(), [qb_t().into()])
            .unwrap();
        let hugr = SimpleHugrConfig::new()
            .with_ins(ext_op.signature().input().clone())
            .with_outs(ext_op.signature().output().clone())
            .with_extensions(list_registry())
            .finish(|mut hugr_builder| {
                let outputs = hugr_builder
                    .add_dataflow_op(ext_op, hugr_builder.input_wires())
                    .unwrap()
                    .outputs();
                hugr_builder.finish_with_outputs(outputs).unwrap()
            });
        add_list_codegen(&mut llvm_ctx);
        check_emission!(op.name().as_str(), hugr, llvm_ctx);
    }

    /// Lowers a constant list `[1, 2, 3]` of `usize` and snapshots the module.
    #[rstest]
    fn test_const_list_emmission(mut llvm_ctx: TestContext) {
        let elem_ty = usize_t();
        let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i)));
        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![])
            .with_outs(vec![list_type(elem_ty.clone())])
            .with_extensions(list_registry())
            .finish(|mut hugr_builder| {
                let list = hugr_builder.add_load_value(ListValue::new(elem_ty, contents));
                hugr_builder.finish_with_outputs(vec![list]).unwrap()
            });
        add_list_codegen(&mut llvm_ctx);
        check_emission!("const", hugr, llvm_ctx);
    }
}