hugr_llvm/extension/collections/
list.rs

1use anyhow::{Ok, Result, bail};
2use hugr_core::{
3    HugrView, Node,
4    extension::simple_op::MakeExtensionOp as _,
5    ops::ExtensionOp,
6    std_extensions::collections::list::{self, ListOp, ListValue},
7    types::{SumType, Type, TypeArg},
8};
9use inkwell::values::FunctionValue;
10use inkwell::{
11    AddressSpace,
12    types::{BasicType, BasicTypeEnum, FunctionType},
13    values::{BasicValueEnum, PointerValue},
14};
15
16use crate::emit::func::{build_ok_or_else, build_option};
17use crate::{
18    custom::{CodegenExtension, CodegenExtsBuilder},
19    emit::{EmitOpArgs, emit_value, func::EmitFuncContext},
20    types::TypingSession,
21};
22
/// Runtime functions that implement operations on lists.
///
/// Each variant names one external symbol: its name comes from
/// [`ListCodegen::rt_func_name`] and its LLVM type from [`ListRtFunc::signature`].
/// The type is `Eq + Hash`, so it can key hash maps/sets (e.g. caches of declared externs).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ListRtFunc {
    /// Create a list from a capacity, element size, element alignment and destructor pointer.
    New,
    /// Push an element (passed by `i8*`) onto a list.
    Push,
    /// Pop an element from a list into an `i8*` slot; returns a success flag.
    Pop,
    /// Read the element at an index into an `i8*` slot; returns a success flag.
    Get,
    /// Set the element at an index from an `i8*` slot; returns a success flag.
    Set,
    /// Insert an element at an index from an `i8*` slot; returns a success flag.
    Insert,
    /// Return the length of a list as an `i64`.
    Length,
}
35
36impl ListRtFunc {
37    /// The signature of a given [`ListRtFunc`].
38    ///
39    /// Requires a [`ListCodegen`] to determine the type of lists.
40    pub fn signature<'c>(
41        self,
42        ts: TypingSession<'c, '_>,
43        ccg: &(impl ListCodegen + 'c),
44    ) -> FunctionType<'c> {
45        let iwc = ts.iw_context();
46        match self {
47            ListRtFunc::New => ccg.list_type(ts).fn_type(
48                &[
49                    iwc.i64_type().into(), // Capacity
50                    iwc.i64_type().into(), // Single element size in bytes
51                    iwc.i64_type().into(), // Element alignment
52                    // Pointer to element destructor
53                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
54                ],
55                false,
56            ),
57            ListRtFunc::Push => iwc.void_type().fn_type(
58                &[
59                    ccg.list_type(ts).into(),
60                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
61                ],
62                false,
63            ),
64            ListRtFunc::Pop => iwc.bool_type().fn_type(
65                &[
66                    ccg.list_type(ts).into(),
67                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
68                ],
69                false,
70            ),
71            ListRtFunc::Get | ListRtFunc::Set | ListRtFunc::Insert => iwc.bool_type().fn_type(
72                &[
73                    ccg.list_type(ts).into(),
74                    iwc.i64_type().into(),
75                    iwc.i8_type().ptr_type(AddressSpace::default()).into(),
76                ],
77                false,
78            ),
79            ListRtFunc::Length => iwc.i64_type().fn_type(&[ccg.list_type(ts).into()], false),
80        }
81    }
82
83    /// Returns the extern function corresponding to this [`ListRtFunc`].
84    ///
85    /// Requires a [`ListCodegen`] to determine the function signature.
86    pub fn get_extern<'c, H: HugrView<Node = Node>>(
87        self,
88        ctx: &EmitFuncContext<'c, '_, H>,
89        ccg: &(impl ListCodegen + 'c),
90    ) -> Result<FunctionValue<'c>> {
91        ctx.get_extern_func(
92            ccg.rt_func_name(self),
93            self.signature(ctx.typing_session(), ccg),
94        )
95    }
96}
97
98impl From<ListOp> for ListRtFunc {
99    fn from(op: ListOp) -> Self {
100        match op {
101            ListOp::get => ListRtFunc::Get,
102            ListOp::set => ListRtFunc::Set,
103            ListOp::push => ListRtFunc::Push,
104            ListOp::pop => ListRtFunc::Pop,
105            ListOp::insert => ListRtFunc::Insert,
106            ListOp::length => ListRtFunc::Length,
107            _ => todo!(),
108        }
109    }
110}
111
112/// A helper trait for customising the lowering of [`hugr_core::std_extensions::collections::list`]
113/// types, [`hugr_core::ops::constant::CustomConst`]s, and ops.
114pub trait ListCodegen: Clone {
115    /// Return the llvm type of [`hugr_core::std_extensions::collections::list::LIST_TYPENAME`].
116    fn list_type<'c>(&self, session: TypingSession<'c, '_>) -> BasicTypeEnum<'c> {
117        session
118            .iw_context()
119            .i8_type()
120            .ptr_type(AddressSpace::default())
121            .into()
122    }
123
124    /// Return the name of a given [`ListRtFunc`].
125    fn rt_func_name(&self, func: ListRtFunc) -> String {
126        match func {
127            ListRtFunc::New => "__rt__list__new",
128            ListRtFunc::Push => "__rt__list__push",
129            ListRtFunc::Pop => "__rt__list__pop",
130            ListRtFunc::Get => "__rt__list__get",
131            ListRtFunc::Set => "__rt__list__set",
132            ListRtFunc::Insert => "__rt__list__insert",
133            ListRtFunc::Length => "__rt__list__length",
134        }
135        .into()
136    }
137}
138
139/// A trivial implementation of [`ListCodegen`] which passes all methods
140/// through to their default implementations.
141#[derive(Default, Clone)]
142pub struct DefaultListCodegen;
143
144impl ListCodegen for DefaultListCodegen {}
145
146#[derive(Clone, Debug, Default)]
147pub struct ListCodegenExtension<CCG>(CCG);
148
149impl<CCG: ListCodegen> ListCodegenExtension<CCG> {
150    pub fn new(ccg: CCG) -> Self {
151        Self(ccg)
152    }
153}
154
155impl<CCG: ListCodegen> From<CCG> for ListCodegenExtension<CCG> {
156    fn from(ccg: CCG) -> Self {
157        Self::new(ccg)
158    }
159}
160
161impl<CCG: ListCodegen> CodegenExtension for ListCodegenExtension<CCG> {
162    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
163        self,
164        builder: CodegenExtsBuilder<'a, H>,
165    ) -> CodegenExtsBuilder<'a, H>
166    where
167        Self: 'a,
168    {
169        builder
170            .custom_type((list::EXTENSION_ID, list::LIST_TYPENAME), {
171                let ccg = self.0.clone();
172                move |ts, _hugr_type| Ok(ccg.list_type(ts).as_basic_type_enum())
173            })
174            .custom_const::<ListValue>({
175                let ccg = self.0.clone();
176                move |ctx, k| emit_list_value(ctx, &ccg, k)
177            })
178            .simple_extension_op::<ListOp>(move |ctx, args, op| {
179                emit_list_op(ctx, &self.0, args, op)
180            })
181    }
182}
183
184impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
185    /// Add a [`ListCodegenExtension`] to the given [`CodegenExtsBuilder`] using `ccg`
186    /// as the implementation.
187    #[must_use]
188    pub fn add_default_list_extensions(self) -> Self {
189        self.add_list_extensions(DefaultListCodegen)
190    }
191
192    /// Add a [`ListCodegenExtension`] to the given [`CodegenExtsBuilder`] using
193    /// [`DefaultListCodegen`] as the implementation.
194    pub fn add_list_extensions(self, ccg: impl ListCodegen + 'a) -> Self {
195        self.add_extension(ListCodegenExtension::from(ccg))
196    }
197}
198
199fn emit_list_op<'c, H: HugrView<Node = Node>>(
200    ctx: &mut EmitFuncContext<'c, '_, H>,
201    ccg: &(impl ListCodegen + 'c),
202    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
203    op: ListOp,
204) -> Result<()> {
205    let hugr_elem_ty = match args.node().args() {
206        [TypeArg::Runtime(ty)] => ty.clone(),
207        _ => {
208            bail!("Collections: invalid type args for list op");
209        }
210    };
211    let elem_ty = ctx.llvm_type(&hugr_elem_ty)?;
212    let func = ListRtFunc::get_extern(op.into(), ctx, ccg)?;
213    match op {
214        ListOp::push => {
215            let [list, elem] = args.inputs.try_into().unwrap();
216            let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
217            ctx.builder()
218                .build_call(func, &[list.into(), elem_ptr.into()], "")?;
219            args.outputs.finish(ctx.builder(), vec![list])?;
220        }
221        ListOp::pop => {
222            let [list] = args.inputs.try_into().unwrap();
223            let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
224            let ok = ctx
225                .builder()
226                .build_call(func, &[list.into(), out_ptr.into()], "")?
227                .try_as_basic_value()
228                .unwrap_left()
229                .into_int_value();
230            let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
231            let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
232            args.outputs.finish(ctx.builder(), vec![list, elem_opt])?;
233        }
234        ListOp::get => {
235            let [list, idx] = args.inputs.try_into().unwrap();
236            let out_ptr = build_alloca_i8_ptr(ctx, elem_ty, None)?;
237            let ok = ctx
238                .builder()
239                .build_call(func, &[list.into(), idx.into(), out_ptr.into()], "")?
240                .try_as_basic_value()
241                .unwrap_left()
242                .into_int_value();
243            let elem = build_load_i8_ptr(ctx, out_ptr, elem_ty)?;
244            let elem_opt = build_option(ctx, ok, elem, hugr_elem_ty)?;
245            args.outputs.finish(ctx.builder(), vec![elem_opt])?;
246        }
247        ListOp::set => {
248            let [list, idx, elem] = args.inputs.try_into().unwrap();
249            let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
250            let ok = ctx
251                .builder()
252                .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
253                .try_as_basic_value()
254                .unwrap_left()
255                .into_int_value();
256            let old_elem = build_load_i8_ptr(ctx, elem_ptr, elem.get_type())?;
257            let ok_or =
258                build_ok_or_else(ctx, ok, elem, hugr_elem_ty.clone(), old_elem, hugr_elem_ty)?;
259            args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
260        }
261        ListOp::insert => {
262            let [list, idx, elem] = args.inputs.try_into().unwrap();
263            let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
264            let ok = ctx
265                .builder()
266                .build_call(func, &[list.into(), idx.into(), elem_ptr.into()], "")?
267                .try_as_basic_value()
268                .unwrap_left()
269                .into_int_value();
270            let unit =
271                ctx.llvm_sum_type(SumType::new_unary(1))?
272                    .build_tag(ctx.builder(), 0, vec![])?;
273            let ok_or = build_ok_or_else(ctx, ok, unit.into(), Type::UNIT, elem, hugr_elem_ty)?;
274            args.outputs.finish(ctx.builder(), vec![list, ok_or])?;
275        }
276        ListOp::length => {
277            let [list] = args.inputs.try_into().unwrap();
278            let length = ctx
279                .builder()
280                .build_call(func, &[list.into()], "")?
281                .try_as_basic_value()
282                .unwrap_left()
283                .into_int_value();
284            args.outputs
285                .finish(ctx.builder(), vec![list, length.into()])?;
286        }
287        _ => bail!("Collections: unimplemented op: {}", op.op_id()),
288    }
289    Ok(())
290}
291
292fn emit_list_value<'c, H: HugrView<Node = Node>>(
293    ctx: &mut EmitFuncContext<'c, '_, H>,
294    ccg: &(impl ListCodegen + 'c),
295    val: &ListValue,
296) -> Result<BasicValueEnum<'c>> {
297    let elem_ty = ctx.llvm_type(val.get_element_type())?;
298    let iwc = ctx.typing_session().iw_context();
299    let capacity = iwc
300        .i64_type()
301        .const_int(val.get_contents().len() as u64, false);
302    let elem_size = elem_ty.size_of().unwrap();
303    let alignment = iwc.i64_type().const_int(8, false);
304    // TODO: Lookup destructor for elem_ty
305    let destructor = iwc.i8_type().ptr_type(AddressSpace::default()).const_null();
306    let list = ctx
307        .builder()
308        .build_call(
309            ListRtFunc::New.get_extern(ctx, ccg)?,
310            &[
311                capacity.into(),
312                elem_size.into(),
313                alignment.into(),
314                destructor.into(),
315            ],
316            "",
317        )?
318        .try_as_basic_value()
319        .unwrap_left();
320    // Push elements onto the list
321    let rt_push = ListRtFunc::Push.get_extern(ctx, ccg)?;
322    for v in val.get_contents() {
323        let elem = emit_value(ctx, v)?;
324        let elem_ptr = build_alloca_i8_ptr(ctx, elem_ty, Some(elem))?;
325        ctx.builder()
326            .build_call(rt_push, &[list.into(), elem_ptr.into()], "")?;
327    }
328    Ok(list)
329}
330
331/// Helper function to allocate space on the stack for a given type.
332///
333/// Optionally also stores a value at that location.
334///
335/// Returns an i8 pointer to the allocated memory.
336fn build_alloca_i8_ptr<'c, H: HugrView<Node = Node>>(
337    ctx: &mut EmitFuncContext<'c, '_, H>,
338    ty: BasicTypeEnum<'c>,
339    value: Option<BasicValueEnum<'c>>,
340) -> Result<PointerValue<'c>> {
341    let builder = ctx.builder();
342    let ptr = builder.build_alloca(ty, "")?;
343    if let Some(val) = value {
344        builder.build_store(ptr, val)?;
345    }
346    let i8_ptr = builder.build_pointer_cast(
347        ptr,
348        ctx.iw_context().i8_type().ptr_type(AddressSpace::default()),
349        "",
350    )?;
351    Ok(i8_ptr)
352}
353
354/// Helper function to load a value from an i8 pointer.
355fn build_load_i8_ptr<'c, H: HugrView<Node = Node>>(
356    ctx: &mut EmitFuncContext<'c, '_, H>,
357    i8_ptr: PointerValue<'c>,
358    ty: BasicTypeEnum<'c>,
359) -> Result<BasicValueEnum<'c>> {
360    let builder = ctx.builder();
361    let ptr = builder.build_pointer_cast(i8_ptr, ty.ptr_type(AddressSpace::default()), "")?;
362    let val = builder.build_load(ptr, "")?;
363    Ok(val)
364}
365
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{Dataflow, DataflowHugr},
        extension::{
            ExtensionRegistry,
            prelude::{self, ConstUsize, qb_t, usize_t},
        },
        ops::{DataflowOpTrait, Value},
        std_extensions::collections::list::{self, ListOp, ListValue, list_type},
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        custom::CodegenExtsBuilder,
        emit::test::SimpleHugrConfig,
        test::{TestContext, llvm_ctx},
    };

    /// Builds and validates a registry holding the list extension and the prelude.
    fn list_and_prelude_registry() -> ExtensionRegistry {
        let registry =
            ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]);
        registry.validate().unwrap();
        registry
    }

    /// Installs the default prelude and list codegen extensions into `llvm_ctx`.
    fn add_list_codegen(llvm_ctx: &mut TestContext) {
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_list_extensions);
    }

    /// Snapshot-tests the lowering of each list op instantiated with `qb_t()` elements.
    #[rstest]
    #[case::push(ListOp::push)]
    #[case::pop(ListOp::pop)]
    #[case::get(ListOp::get)]
    #[case::set(ListOp::set)]
    #[case::insert(ListOp::insert)]
    #[case::length(ListOp::length)]
    fn test_list_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) {
        use hugr_core::extension::simple_op::MakeExtensionOp as _;

        let ext_op = list::EXTENSION
            .instantiate_extension_op(op.op_id().as_ref(), [qb_t().into()])
            .unwrap();
        let registry = list_and_prelude_registry();
        let hugr = SimpleHugrConfig::new()
            .with_ins(ext_op.signature().input().clone())
            .with_outs(ext_op.signature().output().clone())
            .with_extensions(registry)
            .finish(|mut builder| {
                let outputs = builder
                    .add_dataflow_op(ext_op, builder.input_wires())
                    .unwrap()
                    .outputs();
                builder.finish_hugr_with_outputs(outputs).unwrap()
            });
        add_list_codegen(&mut llvm_ctx);
        check_emission!(op.op_id().as_str(), hugr, llvm_ctx);
    }

    /// Snapshot-tests emission of a constant list of the `usize` values 1, 2 and 3.
    #[rstest]
    fn test_const_list_emmission(mut llvm_ctx: TestContext) {
        let elem_ty = usize_t();
        let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i)));
        let registry = list_and_prelude_registry();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![])
            .with_outs(vec![list_type(elem_ty.clone())])
            .with_extensions(registry)
            .finish(|mut builder| {
                let list = builder.add_load_value(ListValue::new(elem_ty, contents));
                builder.finish_hugr_with_outputs(vec![list]).unwrap()
            });

        add_list_codegen(&mut llvm_ctx);
        check_emission!("const", hugr, llvm_ctx);
    }
}