hugr_llvm/extension/
prelude.rs

1use anyhow::{anyhow, bail, ensure, Ok, Result};
2use hugr_core::extension::prelude::generic::LoadNat;
3use hugr_core::extension::prelude::{
4    self, error_type, generic, ConstError, ConstExternalSymbol, ConstString, ConstUsize, MakeTuple,
5    TupleOpDef, UnpackTuple,
6};
7use hugr_core::extension::prelude::{ERROR_TYPE_NAME, STRING_TYPE_NAME};
8use hugr_core::ops::ExtensionOp;
9use hugr_core::types::TypeArg;
10use hugr_core::Node;
11use hugr_core::{
12    extension::simple_op::MakeExtensionOp as _, ops::constant::CustomConst, types::SumType,
13    HugrView,
14};
15use inkwell::{
16    types::{BasicType, IntType, PointerType},
17    values::{BasicValue as _, BasicValueEnum, StructValue},
18    AddressSpace,
19};
20use itertools::Itertools;
21
22use crate::emit::EmitOpArgs;
23use crate::{
24    custom::{CodegenExtension, CodegenExtsBuilder},
25    emit::{
26        func::EmitFuncContext,
27        libc::{emit_libc_abort, emit_libc_printf},
28    },
29    sum::LLVMSumValue,
30    types::TypingSession,
31};
32
33/// A helper trait for customising the lowering [hugr_core::extension::prelude]
34/// types, [CustomConst]s, and ops.
35///
36/// All methods have sensible defaults provided, and [DefaultPreludeCodegen] is
37/// a trivial implementation of this trait which delegates everything to those
38/// default implementations.
39pub trait PreludeCodegen: Clone {
40    /// Return the llvm type of [hugr_core::extension::prelude::usize_t]. That type
41    /// must be an [IntType].
42    fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> {
43        session.iw_context().i64_type()
44    }
45
46    /// Return the llvm type of [hugr_core::extension::prelude::qb_t].
47    fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> {
48        session.iw_context().i16_type()
49    }
50
51    /// Return the llvm type of [hugr_core::extension::prelude::error_type()].
52    ///
53    /// The returned type must always match the type of the returned value of
54    /// [Self::emit_const_error], and the `err` argument of [Self::emit_panic].
55    ///
56    /// The default implementation is a struct type with an i32 field and an i8*
57    /// field for the code and message.
58    fn error_type<'c>(&self, session: &TypingSession<'c, '_>) -> Result<impl BasicType<'c>> {
59        let ctx = session.iw_context();
60        Ok(session.iw_context().struct_type(
61            &[
62                ctx.i32_type().into(),
63                ctx.i8_type().ptr_type(AddressSpace::default()).into(),
64            ],
65            false,
66        ))
67    }
68
69    /// Return the llvm type of [hugr_core::extension::prelude::string_type()].
70    ///
71    /// The returned type must always match the type of the returned value of
72    /// [Self::emit_const_string], and the `text` argument of [Self::emit_print].
73    ///
74    /// The default implementation is i8*.
75    fn string_type<'c>(&self, session: &TypingSession<'c, '_>) -> Result<impl BasicType<'c>> {
76        Ok(session
77            .iw_context()
78            .i8_type()
79            .ptr_type(AddressSpace::default()))
80    }
81
82    /// Emit a [hugr_core::extension::prelude::PRINT_OP_ID] node.
83    fn emit_print<H: HugrView<Node = Node>>(
84        &self,
85        ctx: &mut EmitFuncContext<H>,
86        text: BasicValueEnum,
87    ) -> Result<()> {
88        let format_str = ctx
89            .builder()
90            .build_global_string_ptr("%s\n", "prelude.print_template")?
91            .as_basic_value_enum();
92        emit_libc_printf(ctx, &[format_str.into(), text.into()])
93    }
94
95    /// Emit instructions to materialise an LLVM value representing `err`.
96    ///
97    /// The type of the returned value must match [Self::error_type].
98    ///
99    /// The default implementation materialises an LLVM struct with the
100    /// [ConstError::signal] and [ConstError::message] of `err`.
101    fn emit_const_error<'c, H: HugrView<Node = Node>>(
102        &self,
103        ctx: &mut EmitFuncContext<'c, '_, H>,
104        err: &ConstError,
105    ) -> Result<BasicValueEnum<'c>> {
106        let builder = ctx.builder();
107        let err_ty = ctx.llvm_type(&error_type())?.into_struct_type();
108        let signal = err_ty
109            .get_field_type_at_index(0)
110            .unwrap()
111            .into_int_type()
112            .const_int(err.signal as u64, false);
113        let message = builder
114            .build_global_string_ptr(&err.message, "")?
115            .as_basic_value_enum();
116        let err = err_ty.const_named_struct(&[signal.into(), message]);
117        Ok(err.into())
118    }
119
120    /// Emit instructions to halt execution with the error `err`.
121    ///
122    /// The type of `err` must match that returned from [Self::error_type].
123    ///
124    /// The default implementation emits calls to libc's `printf` and `abort`.
125    ///
126    /// Note that implementations of `emit_panic` must not emit `unreachable`
127    /// terminators, that, if appropriate, is the responsibility of the caller.
128    fn emit_panic<H: HugrView<Node = Node>>(
129        &self,
130        ctx: &mut EmitFuncContext<H>,
131        err: BasicValueEnum,
132    ) -> Result<()> {
133        let format_str = ctx
134            .builder()
135            .build_global_string_ptr(
136                "Program panicked (signal %i): %s\n",
137                "prelude.panic_template",
138            )?
139            .as_basic_value_enum();
140        let Some(err) = StructValue::try_from(err).ok() else {
141            bail!("emit_panic: Expected err value to be a struct type")
142        };
143        ensure!(err.get_type().count_fields() == 2);
144        let signal = ctx.builder().build_extract_value(err, 0, "")?;
145        ensure!(signal.get_type() == ctx.iw_context().i32_type().as_basic_type_enum());
146        let msg = ctx.builder().build_extract_value(err, 1, "")?;
147        ensure!(PointerType::try_from(msg.get_type()).is_ok());
148        emit_libc_printf(ctx, &[format_str.into(), signal.into(), msg.into()])?;
149        emit_libc_abort(ctx)
150    }
151
152    /// Emit instructions to materialise an LLVM value representing `str`.
153    ///
154    /// The type of the returned value must match [Self::string_type].
155    ///
156    /// The default implementation creates a global C string.
157    fn emit_const_string<'c, H: HugrView<Node = Node>>(
158        &self,
159        ctx: &mut EmitFuncContext<'c, '_, H>,
160        str: &ConstString,
161    ) -> Result<BasicValueEnum<'c>> {
162        let default_str_type = ctx
163            .iw_context()
164            .i8_type()
165            .ptr_type(AddressSpace::default())
166            .as_basic_type_enum();
167        let str_type = ctx.llvm_type(&str.get_type())?.as_basic_type_enum();
168        ensure!(str_type == default_str_type, "The default implementation of PreludeCodegen::string_type was overridden, but the default implementation of emit_const_string was not. String type is: {str_type}");
169        let s = ctx.builder().build_global_string_ptr(str.value(), "")?;
170        Ok(s.as_basic_value_enum())
171    }
172
173    fn emit_barrier<'c, H: HugrView<Node = Node>>(
174        &self,
175        ctx: &mut EmitFuncContext<'c, '_, H>,
176        args: EmitOpArgs<'c, '_, ExtensionOp, H>,
177    ) -> Result<()> {
178        // By default, treat barriers as no-ops.
179        args.outputs.finish(ctx.builder(), args.inputs)
180    }
181}
182
/// A trivial implementation of [PreludeCodegen] which passes all methods
/// through to their default implementations.
///
/// Derives `Debug` so that wrappers such as `PreludeCodegenExtension`, which
/// derive `Debug` themselves, are also `Debug` when using this implementation.
#[derive(Debug, Default, Clone)]
pub struct DefaultPreludeCodegen;
187
188impl PreludeCodegen for DefaultPreludeCodegen {}
189
190#[derive(Clone, Debug, Default)]
191pub struct PreludeCodegenExtension<PCG>(PCG);
192
193impl<PCG: PreludeCodegen> PreludeCodegenExtension<PCG> {
194    pub fn new(pcg: PCG) -> Self {
195        Self(pcg)
196    }
197}
198
199impl<PCG: PreludeCodegen> From<PCG> for PreludeCodegenExtension<PCG> {
200    fn from(pcg: PCG) -> Self {
201        Self::new(pcg)
202    }
203}
204
205impl<PCG: PreludeCodegen> CodegenExtension for PreludeCodegenExtension<PCG> {
206    fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
207        self,
208        builder: CodegenExtsBuilder<'a, H>,
209    ) -> CodegenExtsBuilder<'a, H>
210    where
211        Self: 'a,
212    {
213        add_prelude_extensions(builder, self.0)
214    }
215}
216
217impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
218    /// Add a [PreludeCodegenExtension] to the given [CodegenExtsBuilder] using `pcg`
219    /// as the implementation.
220    pub fn add_default_prelude_extensions(self) -> Self {
221        self.add_prelude_extensions(DefaultPreludeCodegen)
222    }
223
224    /// Add a [PreludeCodegenExtension] to the given [CodegenExtsBuilder] using
225    /// [DefaultPreludeCodegen] as the implementation.
226    pub fn add_prelude_extensions(self, pcg: impl PreludeCodegen + 'a) -> Self {
227        self.add_extension(PreludeCodegenExtension::from(pcg))
228    }
229}
230
231/// Add a [PreludeCodegenExtension] to the given [CodegenExtsBuilder] using `pcg`
232/// as the implementation.
233pub fn add_prelude_extensions<'a, H: HugrView<Node = Node> + 'a>(
234    cem: CodegenExtsBuilder<'a, H>,
235    pcg: impl PreludeCodegen + 'a,
236) -> CodegenExtsBuilder<'a, H> {
237    cem.custom_type((prelude::PRELUDE_ID, "qubit".into()), {
238        let pcg = pcg.clone();
239        move |ts, _| Ok(pcg.qubit_type(&ts).as_basic_type_enum())
240    })
241    .custom_type((prelude::PRELUDE_ID, "usize".into()), {
242        let pcg = pcg.clone();
243        move |ts, _| Ok(pcg.usize_type(&ts).as_basic_type_enum())
244    })
245    .custom_type((prelude::PRELUDE_ID, ERROR_TYPE_NAME.clone()), {
246        let pcg = pcg.clone();
247        move |ts, _| Ok(pcg.error_type(&ts)?.as_basic_type_enum())
248    })
249    .custom_type((prelude::PRELUDE_ID, STRING_TYPE_NAME.clone()), {
250        let pcg = pcg.clone();
251        move |ts, _| Ok(pcg.string_type(&ts)?.as_basic_type_enum())
252    })
253    .custom_const::<ConstUsize>(|context, k| {
254        let ty: IntType = context
255            .llvm_type(&k.get_type())?
256            .try_into()
257            .map_err(|_| anyhow!("Failed to get ConstUsize as IntType"))?;
258        Ok(ty.const_int(k.value(), false).into())
259    })
260    .custom_const::<ConstExternalSymbol>(|context, k| {
261        // TODO we should namespace these symbols
262        // https://github.com/CQCL/hugr-llvm/issues/120
263        let llvm_type = context.llvm_type(&k.get_type())?;
264        let global = context.get_global(&k.symbol, llvm_type, k.constant)?;
265        Ok(context
266            .builder()
267            .build_load(global.as_pointer_value(), &k.symbol)?)
268    })
269    .custom_const::<ConstString>({
270        let pcg = pcg.clone();
271        move |context, k| {
272            let err = pcg.emit_const_string(context, k)?;
273            ensure!(
274                err.get_type()
275                    == pcg
276                        .string_type(&context.typing_session())?
277                        .as_basic_type_enum()
278            );
279            Ok(err)
280        }
281    })
282    .custom_const::<ConstError>({
283        let pcg = pcg.clone();
284        move |context, k| {
285            let err = pcg.emit_const_error(context, k)?;
286            ensure!(
287                err.get_type()
288                    == pcg
289                        .error_type(&context.typing_session())?
290                        .as_basic_type_enum()
291            );
292            Ok(err)
293        }
294    })
295    .simple_extension_op::<TupleOpDef>(|context, args, op| match op {
296        TupleOpDef::UnpackTuple => {
297            let unpack_tuple = UnpackTuple::from_extension_op(args.node().as_ref())?;
298            let llvm_sum_type = context.llvm_sum_type(SumType::new([unpack_tuple.0]))?;
299            let llvm_sum_value = args
300                .inputs
301                .into_iter()
302                .exactly_one()
303                .map_err(|_| anyhow!("UnpackTuple does not have exactly one input"))
304                .and_then(|v| LLVMSumValue::try_new(v, llvm_sum_type))?;
305            let rs = llvm_sum_value.build_untag(context.builder(), 0)?;
306            args.outputs.finish(context.builder(), rs)
307        }
308        TupleOpDef::MakeTuple => {
309            let make_tuple = MakeTuple::from_extension_op(args.node().as_ref())?;
310            let llvm_sum_type = context.llvm_sum_type(SumType::new([make_tuple.0]))?;
311            let r = llvm_sum_type.build_tag(context.builder(), 0, args.inputs)?;
312            args.outputs.finish(context.builder(), [r.into()])
313        }
314        _ => Err(anyhow!("Unsupported TupleOpDef")),
315    })
316    .extension_op(prelude::PRELUDE_ID, prelude::PRINT_OP_ID, {
317        let pcg = pcg.clone();
318        move |context, args| {
319            let text = args.inputs[0];
320            pcg.emit_print(context, text)?;
321            args.outputs.finish(context.builder(), [])
322        }
323    })
324    .extension_op(prelude::PRELUDE_ID, prelude::PANIC_OP_ID, {
325        let pcg = pcg.clone();
326        move |context, args| {
327            let err = args.inputs[0];
328            ensure!(
329                err.get_type()
330                    == pcg
331                        .error_type(&context.typing_session())?
332                        .as_basic_type_enum()
333            );
334            pcg.emit_panic(context, err)?;
335            let returns = args
336                .outputs
337                .get_types()
338                .map(|ty| ty.const_zero())
339                .collect_vec();
340            args.outputs.finish(context.builder(), returns)
341        }
342    })
343    .extension_op(prelude::PRELUDE_ID, generic::LOAD_NAT_OP_ID, {
344        let pcg = pcg.clone();
345        move |context, args| {
346            let load_nat = LoadNat::from_extension_op(args.node().as_ref())?;
347            let v = match load_nat.get_nat() {
348                TypeArg::BoundedNat { n } => pcg
349                    .usize_type(&context.typing_session())
350                    .const_int(n, false),
351                arg => bail!("Unexpected type arg for LoadNat: {}", arg),
352            };
353            args.outputs.finish(context.builder(), vec![v.into()])
354        }
355    })
356    .extension_op(prelude::PRELUDE_ID, prelude::BARRIER_OP_ID, {
357        let pcg = pcg.clone();
358        move |context, args| pcg.emit_barrier(context, args)
359    })
360}
361
// Tests for the prelude codegen extension. Most tests build a small HUGR and
// pass it to `check_emission!`; presumably that compares the emitted LLVM
// against a stored snapshot (TODO confirm), so test names and HUGR
// construction order should be kept stable.
#[cfg(test)]
mod test {
    use hugr_core::builder::{Dataflow, DataflowSubContainer};
    use hugr_core::extension::PRELUDE;
    use hugr_core::types::{Type, TypeArg};
    use hugr_core::{type_row, Hugr};
    use prelude::{bool_t, qb_t, usize_t, PANIC_OP_ID, PRINT_OP_ID};
    use rstest::{fixture, rstest};

    use crate::check_emission;
    use crate::custom::CodegenExtsBuilder;
    use crate::emit::test::SimpleHugrConfig;
    use crate::test::{exec_ctx, llvm_ctx, TestContext};
    use crate::types::HugrType;

    use super::*;

    // A non-default `PreludeCodegen` that overrides only the type hooks.
    // `usize` is lowered to i32 and qubits to f64, so the tests can tell
    // whether the overrides are actually used.
    #[derive(Clone)]
    struct TestPreludeCodegen;
    impl PreludeCodegen for TestPreludeCodegen {
        fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> {
            session.iw_context().i32_type()
        }

        fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> {
            session.iw_context().f64_type()
        }
    }

    // A type converter built directly from a `CodegenExtsBuilder` honours the
    // `TestPreludeCodegen` type overrides.
    #[rstest]
    fn prelude_extension_types(llvm_ctx: TestContext) {
        let iw_context = llvm_ctx.iw_context();
        let type_converter = CodegenExtsBuilder::<Hugr>::default()
            .add_prelude_extensions(TestPreludeCodegen)
            .finish()
            .type_converter;
        let session = type_converter.session(iw_context);

        assert_eq!(
            iw_context.i32_type().as_basic_type_enum(),
            session.llvm_type(&usize_t()).unwrap()
        );
        assert_eq!(
            iw_context.f64_type().as_basic_type_enum(),
            session.llvm_type(&qb_t()).unwrap()
        );
    }

    // Same as `prelude_extension_types`, but registers the extension through
    // `TestContext::add_extensions`.
    #[rstest]
    fn prelude_extension_types_in_test_context(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(|x| x.add_prelude_extensions(TestPreludeCodegen));
        let tc = llvm_ctx.get_typing_session();
        assert_eq!(
            llvm_ctx.iw_context().i32_type().as_basic_type_enum(),
            tc.llvm_type(&usize_t()).unwrap()
        );
        assert_eq!(
            llvm_ctx.iw_context().f64_type().as_basic_type_enum(),
            tc.llvm_type(&qb_t()).unwrap()
        );
    }

    // A test context with the default prelude lowering registered.
    #[rstest::fixture]
    fn prelude_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx
    }

    // Emission of a `ConstUsize` loaded as the function's only output.
    #[rstest]
    fn prelude_const_usize(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let k = builder.add_load_value(ConstUsize::new(17));
                builder.finish_with_outputs([k]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Emission of two external symbols: a constant `usize` one and a
    // non-constant one of a sum type.
    #[rstest]
    fn prelude_const_external_symbol(prelude_llvm_ctx: TestContext) {
        let konst1 = ConstExternalSymbol::new("sym1", usize_t(), true);
        let konst2 = ConstExternalSymbol::new(
            "sym2",
            HugrType::new_sum([
                vec![usize_t(), HugrType::new_unit_sum(3)].into(),
                type_row![],
            ]),
            false,
        );

        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![konst1.get_type(), konst2.get_type()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let k1 = builder.add_load_value(konst1);
                let k2 = builder.add_load_value(konst2);
                builder.finish_with_outputs([k1, k2]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Emission of `MakeTuple` packing the two bool inputs.
    #[rstest]
    fn prelude_make_tuple(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![bool_t(), bool_t()])
            .with_outs(Type::new_tuple(vec![bool_t(), bool_t()]))
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let in_wires = builder.input_wires();
                let r = builder.make_tuple(in_wires).unwrap();
                builder.finish_with_outputs([r]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Emission of `UnpackTuple` splitting a tuple input into two bool outputs.
    #[rstest]
    fn prelude_unpack_tuple(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_ins(Type::new_tuple(vec![bool_t(), bool_t()]))
            .with_outs(vec![bool_t(), bool_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let unpack = builder
                    .add_dataflow_op(
                        UnpackTuple::new(vec![bool_t(), bool_t()].into()),
                        builder.input_wires(),
                    )
                    .unwrap();
                builder.finish_with_outputs(unpack.outputs()).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Emission of a panic op fed a constant error and two qubits. The op is
    // instantiated with two type args, each a sequence of two qubit types.
    // Presumably these are the op's extra input and output rows (TODO confirm).
    #[rstest]
    fn prelude_panic(prelude_llvm_ctx: TestContext) {
        let error_val = ConstError::new(42, "PANIC");
        let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() };
        let type_arg_2q: TypeArg = TypeArg::Sequence {
            elems: vec![type_arg_q.clone(), type_arg_q],
        };
        let panic_op = PRELUDE
            .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()])
            .unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![qb_t(), qb_t()])
            .with_outs(vec![qb_t(), qb_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let [q0, q1] = builder.input_wires_arr();
                let err = builder.add_load_value(error_val);
                let [q0, q1] = builder
                    .add_dataflow_op(panic_op, [err, q0, q1])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([q0, q1]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Emission of a print op applied to a constant string.
    #[rstest]
    fn prelude_print(prelude_llvm_ctx: TestContext) {
        let greeting: ConstString = ConstString::new("Hello, world!".into());
        let print_op = PRELUDE.instantiate_extension_op(&PRINT_OP_ID, []).unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let greeting_out = builder.add_load_value(greeting);
                builder.add_dataflow_op(print_op, [greeting_out]).unwrap();
                builder.finish_with_outputs([]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Emission of `LoadNat` with a bounded nat type arg of 42.
    #[rstest]
    fn prelude_load_nat(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let v = builder
                    .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![])
                    .unwrap()
                    .out_wire(0);
                builder.finish_with_outputs([v]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // A HUGR that passes the same `usize` constant (42) twice through a
    // barrier and returns only the barrier's first output.
    #[fixture]
    fn barrier_hugr() -> Hugr {
        SimpleHugrConfig::new()
            .with_outs(vec![usize_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstUsize::new(42));
                let [w1, _w2] = builder.add_barrier([i, i]).unwrap().outputs_arr();
                builder.finish_with_outputs([w1]).unwrap()
            })
    }

    // Emission of `barrier_hugr` with the default prelude lowering.
    #[rstest]
    fn prelude_barrier(prelude_llvm_ctx: TestContext, barrier_hugr: Hugr) {
        check_emission!(barrier_hugr, prelude_llvm_ctx);
    }
    // Executes `barrier_hugr` with `TestPreludeCodegen` (so `usize` is i32)
    // and checks that the barrier forwards the value 42 unchanged.
    #[rstest]
    fn prelude_barrier_exec(mut exec_ctx: TestContext, barrier_hugr: Hugr) {
        exec_ctx.add_extensions(|cem| add_prelude_extensions(cem, TestPreludeCodegen));
        assert_eq!(exec_ctx.exec_hugr_u64(barrier_hugr, "main"), 42);
    }
}