1use anyhow::{anyhow, bail, ensure, Ok, Result};
2use hugr_core::extension::prelude::generic::LoadNat;
3use hugr_core::extension::prelude::{
4 self, error_type, generic, ConstError, ConstExternalSymbol, ConstString, ConstUsize, MakeTuple,
5 TupleOpDef, UnpackTuple,
6};
7use hugr_core::extension::prelude::{ERROR_TYPE_NAME, STRING_TYPE_NAME};
8use hugr_core::ops::ExtensionOp;
9use hugr_core::types::TypeArg;
10use hugr_core::Node;
11use hugr_core::{
12 extension::simple_op::MakeExtensionOp as _, ops::constant::CustomConst, types::SumType,
13 HugrView,
14};
15use inkwell::{
16 types::{BasicType, IntType, PointerType},
17 values::{BasicValue as _, BasicValueEnum, StructValue},
18 AddressSpace,
19};
20use itertools::Itertools;
21
22use crate::emit::EmitOpArgs;
23use crate::{
24 custom::{CodegenExtension, CodegenExtsBuilder},
25 emit::{
26 func::EmitFuncContext,
27 libc::{emit_libc_abort, emit_libc_printf},
28 },
29 sum::LLVMSumValue,
30 types::TypingSession,
31};
32
33pub trait PreludeCodegen: Clone {
40 fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> {
43 session.iw_context().i64_type()
44 }
45
46 fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> {
48 session.iw_context().i16_type()
49 }
50
51 fn error_type<'c>(&self, session: &TypingSession<'c, '_>) -> Result<impl BasicType<'c>> {
59 let ctx = session.iw_context();
60 Ok(session.iw_context().struct_type(
61 &[
62 ctx.i32_type().into(),
63 ctx.i8_type().ptr_type(AddressSpace::default()).into(),
64 ],
65 false,
66 ))
67 }
68
69 fn string_type<'c>(&self, session: &TypingSession<'c, '_>) -> Result<impl BasicType<'c>> {
76 Ok(session
77 .iw_context()
78 .i8_type()
79 .ptr_type(AddressSpace::default()))
80 }
81
82 fn emit_print<H: HugrView<Node = Node>>(
84 &self,
85 ctx: &mut EmitFuncContext<H>,
86 text: BasicValueEnum,
87 ) -> Result<()> {
88 let format_str = ctx
89 .builder()
90 .build_global_string_ptr("%s\n", "prelude.print_template")?
91 .as_basic_value_enum();
92 emit_libc_printf(ctx, &[format_str.into(), text.into()])
93 }
94
95 fn emit_const_error<'c, H: HugrView<Node = Node>>(
102 &self,
103 ctx: &mut EmitFuncContext<'c, '_, H>,
104 err: &ConstError,
105 ) -> Result<BasicValueEnum<'c>> {
106 let builder = ctx.builder();
107 let err_ty = ctx.llvm_type(&error_type())?.into_struct_type();
108 let signal = err_ty
109 .get_field_type_at_index(0)
110 .unwrap()
111 .into_int_type()
112 .const_int(err.signal as u64, false);
113 let message = builder
114 .build_global_string_ptr(&err.message, "")?
115 .as_basic_value_enum();
116 let err = err_ty.const_named_struct(&[signal.into(), message]);
117 Ok(err.into())
118 }
119
120 fn emit_panic<H: HugrView<Node = Node>>(
129 &self,
130 ctx: &mut EmitFuncContext<H>,
131 err: BasicValueEnum,
132 ) -> Result<()> {
133 let format_str = ctx
134 .builder()
135 .build_global_string_ptr(
136 "Program panicked (signal %i): %s\n",
137 "prelude.panic_template",
138 )?
139 .as_basic_value_enum();
140 let Some(err) = StructValue::try_from(err).ok() else {
141 bail!("emit_panic: Expected err value to be a struct type")
142 };
143 ensure!(err.get_type().count_fields() == 2);
144 let signal = ctx.builder().build_extract_value(err, 0, "")?;
145 ensure!(signal.get_type() == ctx.iw_context().i32_type().as_basic_type_enum());
146 let msg = ctx.builder().build_extract_value(err, 1, "")?;
147 ensure!(PointerType::try_from(msg.get_type()).is_ok());
148 emit_libc_printf(ctx, &[format_str.into(), signal.into(), msg.into()])?;
149 emit_libc_abort(ctx)
150 }
151
152 fn emit_exit<H: HugrView<Node = Node>>(
162 &self,
163 ctx: &mut EmitFuncContext<H>,
164 err: BasicValueEnum,
165 ) -> Result<()> {
166 self.emit_panic(ctx, err)
167 }
168
169 fn emit_const_string<'c, H: HugrView<Node = Node>>(
175 &self,
176 ctx: &mut EmitFuncContext<'c, '_, H>,
177 str: &ConstString,
178 ) -> Result<BasicValueEnum<'c>> {
179 let default_str_type = ctx
180 .iw_context()
181 .i8_type()
182 .ptr_type(AddressSpace::default())
183 .as_basic_type_enum();
184 let str_type = ctx.llvm_type(&str.get_type())?.as_basic_type_enum();
185 ensure!(str_type == default_str_type, "The default implementation of PreludeCodegen::string_type was overridden, but the default implementation of emit_const_string was not. String type is: {str_type}");
186 let s = ctx.builder().build_global_string_ptr(str.value(), "")?;
187 Ok(s.as_basic_value_enum())
188 }
189
190 fn emit_barrier<'c, H: HugrView<Node = Node>>(
191 &self,
192 ctx: &mut EmitFuncContext<'c, '_, H>,
193 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
194 ) -> Result<()> {
195 args.outputs.finish(ctx.builder(), args.inputs)
197 }
198}
199
/// A [`PreludeCodegen`] that uses the default implementation of every method.
///
/// `Debug` is derived so that `PreludeCodegenExtension<DefaultPreludeCodegen>`
/// is `Debug` too: its derived `Debug` impl requires the parameter to be `Debug`.
#[derive(Clone, Debug, Default)]
pub struct DefaultPreludeCodegen;
204
205impl PreludeCodegen for DefaultPreludeCodegen {}
206
/// A codegen extension lowering the prelude, driven by the wrapped
/// `PreludeCodegen` implementation.
#[derive(Debug, Default, Clone)]
pub struct PreludeCodegenExtension<PCG>(PCG);
209
210impl<PCG: PreludeCodegen> PreludeCodegenExtension<PCG> {
211 pub fn new(pcg: PCG) -> Self {
212 Self(pcg)
213 }
214}
215
216impl<PCG: PreludeCodegen> From<PCG> for PreludeCodegenExtension<PCG> {
217 fn from(pcg: PCG) -> Self {
218 Self::new(pcg)
219 }
220}
221
222impl<PCG: PreludeCodegen> CodegenExtension for PreludeCodegenExtension<PCG> {
223 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
224 self,
225 builder: CodegenExtsBuilder<'a, H>,
226 ) -> CodegenExtsBuilder<'a, H>
227 where
228 Self: 'a,
229 {
230 add_prelude_extensions(builder, self.0)
231 }
232}
233
234impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
235 pub fn add_default_prelude_extensions(self) -> Self {
238 self.add_prelude_extensions(DefaultPreludeCodegen)
239 }
240
241 pub fn add_prelude_extensions(self, pcg: impl PreludeCodegen + 'a) -> Self {
244 self.add_extension(PreludeCodegenExtension::from(pcg))
245 }
246}
247
248pub fn add_prelude_extensions<'a, H: HugrView<Node = Node> + 'a>(
251 cem: CodegenExtsBuilder<'a, H>,
252 pcg: impl PreludeCodegen + 'a,
253) -> CodegenExtsBuilder<'a, H> {
254 cem.custom_type((prelude::PRELUDE_ID, "qubit".into()), {
255 let pcg = pcg.clone();
256 move |ts, _| Ok(pcg.qubit_type(&ts).as_basic_type_enum())
257 })
258 .custom_type((prelude::PRELUDE_ID, "usize".into()), {
259 let pcg = pcg.clone();
260 move |ts, _| Ok(pcg.usize_type(&ts).as_basic_type_enum())
261 })
262 .custom_type((prelude::PRELUDE_ID, ERROR_TYPE_NAME.clone()), {
263 let pcg = pcg.clone();
264 move |ts, _| Ok(pcg.error_type(&ts)?.as_basic_type_enum())
265 })
266 .custom_type((prelude::PRELUDE_ID, STRING_TYPE_NAME.clone()), {
267 let pcg = pcg.clone();
268 move |ts, _| Ok(pcg.string_type(&ts)?.as_basic_type_enum())
269 })
270 .custom_const::<ConstUsize>(|context, k| {
271 let ty: IntType = context
272 .llvm_type(&k.get_type())?
273 .try_into()
274 .map_err(|_| anyhow!("Failed to get ConstUsize as IntType"))?;
275 Ok(ty.const_int(k.value(), false).into())
276 })
277 .custom_const::<ConstExternalSymbol>(|context, k| {
278 let llvm_type = context.llvm_type(&k.get_type())?;
281 let global = context.get_global(&k.symbol, llvm_type, k.constant)?;
282 Ok(context
283 .builder()
284 .build_load(global.as_pointer_value(), &k.symbol)?)
285 })
286 .custom_const::<ConstString>({
287 let pcg = pcg.clone();
288 move |context, k| {
289 let err = pcg.emit_const_string(context, k)?;
290 ensure!(
291 err.get_type()
292 == pcg
293 .string_type(&context.typing_session())?
294 .as_basic_type_enum()
295 );
296 Ok(err)
297 }
298 })
299 .custom_const::<ConstError>({
300 let pcg = pcg.clone();
301 move |context, k| {
302 let err = pcg.emit_const_error(context, k)?;
303 ensure!(
304 err.get_type()
305 == pcg
306 .error_type(&context.typing_session())?
307 .as_basic_type_enum()
308 );
309 Ok(err)
310 }
311 })
312 .simple_extension_op::<TupleOpDef>(|context, args, op| match op {
313 TupleOpDef::UnpackTuple => {
314 let unpack_tuple = UnpackTuple::from_extension_op(args.node().as_ref())?;
315 let llvm_sum_type = context.llvm_sum_type(SumType::new([unpack_tuple.0]))?;
316 let llvm_sum_value = args
317 .inputs
318 .into_iter()
319 .exactly_one()
320 .map_err(|_| anyhow!("UnpackTuple does not have exactly one input"))
321 .and_then(|v| LLVMSumValue::try_new(v, llvm_sum_type))?;
322 let rs = llvm_sum_value.build_untag(context.builder(), 0)?;
323 args.outputs.finish(context.builder(), rs)
324 }
325 TupleOpDef::MakeTuple => {
326 let make_tuple = MakeTuple::from_extension_op(args.node().as_ref())?;
327 let llvm_sum_type = context.llvm_sum_type(SumType::new([make_tuple.0]))?;
328 let r = llvm_sum_type.build_tag(context.builder(), 0, args.inputs)?;
329 args.outputs.finish(context.builder(), [r.into()])
330 }
331 _ => Err(anyhow!("Unsupported TupleOpDef")),
332 })
333 .extension_op(prelude::PRELUDE_ID, prelude::PRINT_OP_ID, {
334 let pcg = pcg.clone();
335 move |context, args| {
336 let text = args.inputs[0];
337 pcg.emit_print(context, text)?;
338 args.outputs.finish(context.builder(), [])
339 }
340 })
341 .extension_op(prelude::PRELUDE_ID, prelude::PANIC_OP_ID, {
342 let pcg = pcg.clone();
343 move |context, args| {
344 let err = args.inputs[0];
345 ensure!(
346 err.get_type()
347 == pcg
348 .error_type(&context.typing_session())?
349 .as_basic_type_enum()
350 );
351 pcg.emit_panic(context, err)?;
352 let returns = args
353 .outputs
354 .get_types()
355 .map(|ty| ty.const_zero())
356 .collect_vec();
357 args.outputs.finish(context.builder(), returns)
358 }
359 })
360 .extension_op(prelude::PRELUDE_ID, prelude::EXIT_OP_ID, {
361 let pcg = pcg.clone();
363 move |context, args| {
364 let err = args.inputs[0];
365 ensure!(
366 err.get_type()
367 == pcg
368 .error_type(&context.typing_session())?
369 .as_basic_type_enum()
370 );
371 pcg.emit_exit(context, err)?;
372 let returns = args
373 .outputs
374 .get_types()
375 .map(|ty| ty.const_zero())
376 .collect_vec();
377 args.outputs.finish(context.builder(), returns)
378 }
379 })
380 .extension_op(prelude::PRELUDE_ID, generic::LOAD_NAT_OP_ID, {
381 let pcg = pcg.clone();
382 move |context, args| {
383 let load_nat = LoadNat::from_extension_op(args.node().as_ref())?;
384 let v = match load_nat.get_nat() {
385 TypeArg::BoundedNat { n } => pcg
386 .usize_type(&context.typing_session())
387 .const_int(n, false),
388 arg => bail!("Unexpected type arg for LoadNat: {}", arg),
389 };
390 args.outputs.finish(context.builder(), vec![v.into()])
391 }
392 })
393 .extension_op(prelude::PRELUDE_ID, prelude::BARRIER_OP_ID, {
394 let pcg = pcg.clone();
395 move |context, args| pcg.emit_barrier(context, args)
396 })
397}
398
// Tests for the prelude lowering. Most tests build a small HUGR with
// `SimpleHugrConfig` and pass it to `check_emission!` (NOTE(review):
// presumably compares the emitted LLVM IR against a stored snapshot keyed by
// test name — confirm before renaming tests). `prelude_barrier_exec` also
// executes the lowered code.
#[cfg(test)]
mod test {
    use hugr_core::builder::{Dataflow, DataflowSubContainer};
    use hugr_core::extension::prelude::EXIT_OP_ID;
    use hugr_core::extension::PRELUDE;
    use hugr_core::types::{Type, TypeArg};
    use hugr_core::{type_row, Hugr};
    use prelude::{bool_t, qb_t, usize_t, PANIC_OP_ID, PRINT_OP_ID};
    use rstest::{fixture, rstest};

    use crate::check_emission;
    use crate::custom::CodegenExtsBuilder;
    use crate::emit::test::SimpleHugrConfig;
    use crate::test::{exec_ctx, llvm_ctx, TestContext};
    use crate::types::HugrType;

    use super::*;

    // Overrides `usize` -> i32 and `qubit` -> f64 (the defaults are i64 and
    // i16), so the type tests can tell that overrides are honoured.
    #[derive(Clone)]
    struct TestPreludeCodegen;
    impl PreludeCodegen for TestPreludeCodegen {
        fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> {
            session.iw_context().i32_type()
        }

        fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> {
            session.iw_context().f64_type()
        }
    }

    // The overridden types are used when the extension is registered on a
    // freshly built `CodegenExtsBuilder`.
    #[rstest]
    fn prelude_extension_types(llvm_ctx: TestContext) {
        let iw_context = llvm_ctx.iw_context();
        let type_converter = CodegenExtsBuilder::<Hugr>::default()
            .add_prelude_extensions(TestPreludeCodegen)
            .finish()
            .type_converter;
        let session = type_converter.session(iw_context);

        assert_eq!(
            iw_context.i32_type().as_basic_type_enum(),
            session.llvm_type(&usize_t()).unwrap()
        );
        assert_eq!(
            iw_context.f64_type().as_basic_type_enum(),
            session.llvm_type(&qb_t()).unwrap()
        );
    }

    // The same check, with the extension registered through the `TestContext`.
    #[rstest]
    fn prelude_extension_types_in_test_context(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(|x| x.add_prelude_extensions(TestPreludeCodegen));
        let tc = llvm_ctx.get_typing_session();
        assert_eq!(
            llvm_ctx.iw_context().i32_type().as_basic_type_enum(),
            tc.llvm_type(&usize_t()).unwrap()
        );
        assert_eq!(
            llvm_ctx.iw_context().f64_type().as_basic_type_enum(),
            tc.llvm_type(&qb_t()).unwrap()
        );
    }

    // A test context with the default prelude lowering registered.
    #[rstest::fixture]
    fn prelude_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx
    }

    #[rstest]
    fn prelude_const_usize(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let k = builder.add_load_value(ConstUsize::new(17));
                builder.finish_with_outputs([k]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Covers one constant and one non-constant external symbol; the second
    // has a sum type.
    #[rstest]
    fn prelude_const_external_symbol(prelude_llvm_ctx: TestContext) {
        let konst1 = ConstExternalSymbol::new("sym1", usize_t(), true);
        let konst2 = ConstExternalSymbol::new(
            "sym2",
            HugrType::new_sum([
                vec![usize_t(), HugrType::new_unit_sum(3)].into(),
                type_row![],
            ]),
            false,
        );

        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![konst1.get_type(), konst2.get_type()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let k1 = builder.add_load_value(konst1);
                let k2 = builder.add_load_value(konst2);
                builder.finish_with_outputs([k1, k2]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    #[rstest]
    fn prelude_make_tuple(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![bool_t(), bool_t()])
            .with_outs(Type::new_tuple(vec![bool_t(), bool_t()]))
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let in_wires = builder.input_wires();
                let r = builder.make_tuple(in_wires).unwrap();
                builder.finish_with_outputs([r]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    #[rstest]
    fn prelude_unpack_tuple(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_ins(Type::new_tuple(vec![bool_t(), bool_t()]))
            .with_outs(vec![bool_t(), bool_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let unpack = builder
                    .add_dataflow_op(
                        UnpackTuple::new(vec![bool_t(), bool_t()].into()),
                        builder.input_wires(),
                    )
                    .unwrap();
                builder.finish_with_outputs(unpack.outputs()).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Panic op instantiated with two qubits passing through (the same
    // two-qubit row for both type args), fed a constant error value.
    #[rstest]
    fn prelude_panic(prelude_llvm_ctx: TestContext) {
        let error_val = ConstError::new(42, "PANIC");
        let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() };
        let type_arg_2q: TypeArg = TypeArg::Sequence {
            elems: vec![type_arg_q.clone(), type_arg_q],
        };
        let panic_op = PRELUDE
            .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()])
            .unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![qb_t(), qb_t()])
            .with_outs(vec![qb_t(), qb_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let [q0, q1] = builder.input_wires_arr();
                let err = builder.add_load_value(error_val);
                let [q0, q1] = builder
                    .add_dataflow_op(panic_op, [err, q0, q1])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([q0, q1]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Same shape as `prelude_panic`, but using the exit op.
    #[rstest]
    fn prelude_exit(prelude_llvm_ctx: TestContext) {
        let error_val = ConstError::new(42, "EXIT");
        let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() };
        let type_arg_2q: TypeArg = TypeArg::Sequence {
            elems: vec![type_arg_q.clone(), type_arg_q],
        };
        let exit_op = PRELUDE
            .instantiate_extension_op(&EXIT_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()])
            .unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![qb_t(), qb_t()])
            .with_outs(vec![qb_t(), qb_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let [q0, q1] = builder.input_wires_arr();
                let err = builder.add_load_value(error_val);
                let [q0, q1] = builder
                    .add_dataflow_op(exit_op, [err, q0, q1])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([q0, q1]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    #[rstest]
    fn prelude_print(prelude_llvm_ctx: TestContext) {
        let greeting: ConstString = ConstString::new("Hello, world!".into());
        let print_op = PRELUDE.instantiate_extension_op(&PRINT_OP_ID, []).unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let greeting_out = builder.add_load_value(greeting);
                builder.add_dataflow_op(print_op, [greeting_out]).unwrap();
                builder.finish_with_outputs([]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    #[rstest]
    fn prelude_load_nat(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let v = builder
                    .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![])
                    .unwrap()
                    .out_wire(0);
                builder.finish_with_outputs([v]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    // Passes the constant 42 twice through a barrier and returns the first
    // barrier output.
    #[fixture]
    fn barrier_hugr() -> Hugr {
        SimpleHugrConfig::new()
            .with_outs(vec![usize_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstUsize::new(42));
                let [w1, _w2] = builder.add_barrier([i, i]).unwrap().outputs_arr();
                builder.finish_with_outputs([w1]).unwrap()
            })
    }

    #[rstest]
    fn prelude_barrier(prelude_llvm_ctx: TestContext, barrier_hugr: Hugr) {
        check_emission!(barrier_hugr, prelude_llvm_ctx);
    }
    // Executes the barrier HUGR lowered with `TestPreludeCodegen` (usize as
    // i32) and expects the value 42 to come through unchanged.
    #[rstest]
    fn prelude_barrier_exec(mut exec_ctx: TestContext, barrier_hugr: Hugr) {
        exec_ctx.add_extensions(|cem| add_prelude_extensions(cem, TestPreludeCodegen));
        assert_eq!(exec_ctx.exec_hugr_u64(barrier_hugr, "main"), 42);
    }
}