1use anyhow::{anyhow, bail, ensure, Ok, Result};
2use hugr_core::extension::prelude::generic::LoadNat;
3use hugr_core::extension::prelude::{
4 self, error_type, generic, ConstError, ConstExternalSymbol, ConstString, ConstUsize, MakeTuple,
5 TupleOpDef, UnpackTuple,
6};
7use hugr_core::extension::prelude::{ERROR_TYPE_NAME, STRING_TYPE_NAME};
8use hugr_core::ops::ExtensionOp;
9use hugr_core::types::TypeArg;
10use hugr_core::Node;
11use hugr_core::{
12 extension::simple_op::MakeExtensionOp as _, ops::constant::CustomConst, types::SumType,
13 HugrView,
14};
15use inkwell::{
16 types::{BasicType, IntType, PointerType},
17 values::{BasicValue as _, BasicValueEnum, StructValue},
18 AddressSpace,
19};
20use itertools::Itertools;
21
22use crate::emit::EmitOpArgs;
23use crate::{
24 custom::{CodegenExtension, CodegenExtsBuilder},
25 emit::{
26 func::EmitFuncContext,
27 libc::{emit_libc_abort, emit_libc_printf},
28 },
29 sum::LLVMSumValue,
30 types::TypingSession,
31};
32
33pub trait PreludeCodegen: Clone {
40 fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> {
43 session.iw_context().i64_type()
44 }
45
46 fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> {
48 session.iw_context().i16_type()
49 }
50
51 fn error_type<'c>(&self, session: &TypingSession<'c, '_>) -> Result<impl BasicType<'c>> {
59 let ctx = session.iw_context();
60 Ok(session.iw_context().struct_type(
61 &[
62 ctx.i32_type().into(),
63 ctx.i8_type().ptr_type(AddressSpace::default()).into(),
64 ],
65 false,
66 ))
67 }
68
69 fn string_type<'c>(&self, session: &TypingSession<'c, '_>) -> Result<impl BasicType<'c>> {
76 Ok(session
77 .iw_context()
78 .i8_type()
79 .ptr_type(AddressSpace::default()))
80 }
81
82 fn emit_print<H: HugrView<Node = Node>>(
84 &self,
85 ctx: &mut EmitFuncContext<H>,
86 text: BasicValueEnum,
87 ) -> Result<()> {
88 let format_str = ctx
89 .builder()
90 .build_global_string_ptr("%s\n", "prelude.print_template")?
91 .as_basic_value_enum();
92 emit_libc_printf(ctx, &[format_str.into(), text.into()])
93 }
94
95 fn emit_const_error<'c, H: HugrView<Node = Node>>(
102 &self,
103 ctx: &mut EmitFuncContext<'c, '_, H>,
104 err: &ConstError,
105 ) -> Result<BasicValueEnum<'c>> {
106 let builder = ctx.builder();
107 let err_ty = ctx.llvm_type(&error_type())?.into_struct_type();
108 let signal = err_ty
109 .get_field_type_at_index(0)
110 .unwrap()
111 .into_int_type()
112 .const_int(err.signal as u64, false);
113 let message = builder
114 .build_global_string_ptr(&err.message, "")?
115 .as_basic_value_enum();
116 let err = err_ty.const_named_struct(&[signal.into(), message]);
117 Ok(err.into())
118 }
119
120 fn emit_panic<H: HugrView<Node = Node>>(
129 &self,
130 ctx: &mut EmitFuncContext<H>,
131 err: BasicValueEnum,
132 ) -> Result<()> {
133 let format_str = ctx
134 .builder()
135 .build_global_string_ptr(
136 "Program panicked (signal %i): %s\n",
137 "prelude.panic_template",
138 )?
139 .as_basic_value_enum();
140 let Some(err) = StructValue::try_from(err).ok() else {
141 bail!("emit_panic: Expected err value to be a struct type")
142 };
143 ensure!(err.get_type().count_fields() == 2);
144 let signal = ctx.builder().build_extract_value(err, 0, "")?;
145 ensure!(signal.get_type() == ctx.iw_context().i32_type().as_basic_type_enum());
146 let msg = ctx.builder().build_extract_value(err, 1, "")?;
147 ensure!(PointerType::try_from(msg.get_type()).is_ok());
148 emit_libc_printf(ctx, &[format_str.into(), signal.into(), msg.into()])?;
149 emit_libc_abort(ctx)
150 }
151
152 fn emit_const_string<'c, H: HugrView<Node = Node>>(
158 &self,
159 ctx: &mut EmitFuncContext<'c, '_, H>,
160 str: &ConstString,
161 ) -> Result<BasicValueEnum<'c>> {
162 let default_str_type = ctx
163 .iw_context()
164 .i8_type()
165 .ptr_type(AddressSpace::default())
166 .as_basic_type_enum();
167 let str_type = ctx.llvm_type(&str.get_type())?.as_basic_type_enum();
168 ensure!(str_type == default_str_type, "The default implementation of PreludeCodegen::string_type was overridden, but the default implementation of emit_const_string was not. String type is: {str_type}");
169 let s = ctx.builder().build_global_string_ptr(str.value(), "")?;
170 Ok(s.as_basic_value_enum())
171 }
172
173 fn emit_barrier<'c, H: HugrView<Node = Node>>(
174 &self,
175 ctx: &mut EmitFuncContext<'c, '_, H>,
176 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
177 ) -> Result<()> {
178 args.outputs.finish(ctx.builder(), args.inputs)
180 }
181}
182
/// A `PreludeCodegen` that uses every default method of the trait unchanged.
///
/// It derives `Debug` so that wrappers deriving `Debug` over it (such as
/// `PreludeCodegenExtension<DefaultPreludeCodegen>`) are themselves `Debug`.
/// It is a field-less unit struct, so it is also `Copy`.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultPreludeCodegen;
187
188impl PreludeCodegen for DefaultPreludeCodegen {}
189
/// A codegen extension for the prelude whose lowering is customised by the
/// wrapped `PreludeCodegen` value.
#[derive(Default, Debug, Clone)]
pub struct PreludeCodegenExtension<PCG>(PCG);
192
193impl<PCG: PreludeCodegen> PreludeCodegenExtension<PCG> {
194 pub fn new(pcg: PCG) -> Self {
195 Self(pcg)
196 }
197}
198
199impl<PCG: PreludeCodegen> From<PCG> for PreludeCodegenExtension<PCG> {
200 fn from(pcg: PCG) -> Self {
201 Self::new(pcg)
202 }
203}
204
205impl<PCG: PreludeCodegen> CodegenExtension for PreludeCodegenExtension<PCG> {
206 fn add_extension<'a, H: HugrView<Node = Node> + 'a>(
207 self,
208 builder: CodegenExtsBuilder<'a, H>,
209 ) -> CodegenExtsBuilder<'a, H>
210 where
211 Self: 'a,
212 {
213 add_prelude_extensions(builder, self.0)
214 }
215}
216
217impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
218 pub fn add_default_prelude_extensions(self) -> Self {
221 self.add_prelude_extensions(DefaultPreludeCodegen)
222 }
223
224 pub fn add_prelude_extensions(self, pcg: impl PreludeCodegen + 'a) -> Self {
227 self.add_extension(PreludeCodegenExtension::from(pcg))
228 }
229}
230
231pub fn add_prelude_extensions<'a, H: HugrView<Node = Node> + 'a>(
234 cem: CodegenExtsBuilder<'a, H>,
235 pcg: impl PreludeCodegen + 'a,
236) -> CodegenExtsBuilder<'a, H> {
237 cem.custom_type((prelude::PRELUDE_ID, "qubit".into()), {
238 let pcg = pcg.clone();
239 move |ts, _| Ok(pcg.qubit_type(&ts).as_basic_type_enum())
240 })
241 .custom_type((prelude::PRELUDE_ID, "usize".into()), {
242 let pcg = pcg.clone();
243 move |ts, _| Ok(pcg.usize_type(&ts).as_basic_type_enum())
244 })
245 .custom_type((prelude::PRELUDE_ID, ERROR_TYPE_NAME.clone()), {
246 let pcg = pcg.clone();
247 move |ts, _| Ok(pcg.error_type(&ts)?.as_basic_type_enum())
248 })
249 .custom_type((prelude::PRELUDE_ID, STRING_TYPE_NAME.clone()), {
250 let pcg = pcg.clone();
251 move |ts, _| Ok(pcg.string_type(&ts)?.as_basic_type_enum())
252 })
253 .custom_const::<ConstUsize>(|context, k| {
254 let ty: IntType = context
255 .llvm_type(&k.get_type())?
256 .try_into()
257 .map_err(|_| anyhow!("Failed to get ConstUsize as IntType"))?;
258 Ok(ty.const_int(k.value(), false).into())
259 })
260 .custom_const::<ConstExternalSymbol>(|context, k| {
261 let llvm_type = context.llvm_type(&k.get_type())?;
264 let global = context.get_global(&k.symbol, llvm_type, k.constant)?;
265 Ok(context
266 .builder()
267 .build_load(global.as_pointer_value(), &k.symbol)?)
268 })
269 .custom_const::<ConstString>({
270 let pcg = pcg.clone();
271 move |context, k| {
272 let err = pcg.emit_const_string(context, k)?;
273 ensure!(
274 err.get_type()
275 == pcg
276 .string_type(&context.typing_session())?
277 .as_basic_type_enum()
278 );
279 Ok(err)
280 }
281 })
282 .custom_const::<ConstError>({
283 let pcg = pcg.clone();
284 move |context, k| {
285 let err = pcg.emit_const_error(context, k)?;
286 ensure!(
287 err.get_type()
288 == pcg
289 .error_type(&context.typing_session())?
290 .as_basic_type_enum()
291 );
292 Ok(err)
293 }
294 })
295 .simple_extension_op::<TupleOpDef>(|context, args, op| match op {
296 TupleOpDef::UnpackTuple => {
297 let unpack_tuple = UnpackTuple::from_extension_op(args.node().as_ref())?;
298 let llvm_sum_type = context.llvm_sum_type(SumType::new([unpack_tuple.0]))?;
299 let llvm_sum_value = args
300 .inputs
301 .into_iter()
302 .exactly_one()
303 .map_err(|_| anyhow!("UnpackTuple does not have exactly one input"))
304 .and_then(|v| LLVMSumValue::try_new(v, llvm_sum_type))?;
305 let rs = llvm_sum_value.build_untag(context.builder(), 0)?;
306 args.outputs.finish(context.builder(), rs)
307 }
308 TupleOpDef::MakeTuple => {
309 let make_tuple = MakeTuple::from_extension_op(args.node().as_ref())?;
310 let llvm_sum_type = context.llvm_sum_type(SumType::new([make_tuple.0]))?;
311 let r = llvm_sum_type.build_tag(context.builder(), 0, args.inputs)?;
312 args.outputs.finish(context.builder(), [r.into()])
313 }
314 _ => Err(anyhow!("Unsupported TupleOpDef")),
315 })
316 .extension_op(prelude::PRELUDE_ID, prelude::PRINT_OP_ID, {
317 let pcg = pcg.clone();
318 move |context, args| {
319 let text = args.inputs[0];
320 pcg.emit_print(context, text)?;
321 args.outputs.finish(context.builder(), [])
322 }
323 })
324 .extension_op(prelude::PRELUDE_ID, prelude::PANIC_OP_ID, {
325 let pcg = pcg.clone();
326 move |context, args| {
327 let err = args.inputs[0];
328 ensure!(
329 err.get_type()
330 == pcg
331 .error_type(&context.typing_session())?
332 .as_basic_type_enum()
333 );
334 pcg.emit_panic(context, err)?;
335 let returns = args
336 .outputs
337 .get_types()
338 .map(|ty| ty.const_zero())
339 .collect_vec();
340 args.outputs.finish(context.builder(), returns)
341 }
342 })
343 .extension_op(prelude::PRELUDE_ID, generic::LOAD_NAT_OP_ID, {
344 let pcg = pcg.clone();
345 move |context, args| {
346 let load_nat = LoadNat::from_extension_op(args.node().as_ref())?;
347 let v = match load_nat.get_nat() {
348 TypeArg::BoundedNat { n } => pcg
349 .usize_type(&context.typing_session())
350 .const_int(n, false),
351 arg => bail!("Unexpected type arg for LoadNat: {}", arg),
352 };
353 args.outputs.finish(context.builder(), vec![v.into()])
354 }
355 })
356 .extension_op(prelude::PRELUDE_ID, prelude::BARRIER_OP_ID, {
357 let pcg = pcg.clone();
358 move |context, args| pcg.emit_barrier(context, args)
359 })
360}
361
// Tests for the prelude lowering. Most use `check_emission!` to check the
// emitted LLVM IR for a small HUGR. `prelude_barrier_exec` executes the
// lowered program and checks its result.
#[cfg(test)]
mod test {
    use hugr_core::builder::{Dataflow, DataflowSubContainer};
    use hugr_core::extension::PRELUDE;
    use hugr_core::types::{Type, TypeArg};
    use hugr_core::{type_row, Hugr};
    use prelude::{bool_t, qb_t, usize_t, PANIC_OP_ID, PRINT_OP_ID};
    use rstest::{fixture, rstest};

    use crate::check_emission;
    use crate::custom::CodegenExtsBuilder;
    use crate::emit::test::SimpleHugrConfig;
    use crate::test::{exec_ctx, llvm_ctx, TestContext};
    use crate::types::HugrType;

    use super::*;

    /// A `PreludeCodegen` that overrides the `usize` lowering (to `i32`) and
    /// the `qubit` lowering (to `f64`). Tests use it to check that
    /// customisations take effect.
    #[derive(Clone)]
    struct TestPreludeCodegen;
    impl PreludeCodegen for TestPreludeCodegen {
        fn usize_type<'c>(&self, session: &TypingSession<'c, '_>) -> IntType<'c> {
            session.iw_context().i32_type()
        }

        fn qubit_type<'c>(&self, session: &TypingSession<'c, '_>) -> impl BasicType<'c> {
            session.iw_context().f64_type()
        }
    }

    /// The overridden `usize` and `qubit` types are used by a type converter
    /// built directly from a `CodegenExtsBuilder`.
    #[rstest]
    fn prelude_extension_types(llvm_ctx: TestContext) {
        let iw_context = llvm_ctx.iw_context();
        let type_converter = CodegenExtsBuilder::<Hugr>::default()
            .add_prelude_extensions(TestPreludeCodegen)
            .finish()
            .type_converter;
        let session = type_converter.session(iw_context);

        assert_eq!(
            iw_context.i32_type().as_basic_type_enum(),
            session.llvm_type(&usize_t()).unwrap()
        );
        assert_eq!(
            iw_context.f64_type().as_basic_type_enum(),
            session.llvm_type(&qb_t()).unwrap()
        );
    }

    /// Same as `prelude_extension_types`, but registering the extensions on
    /// a `TestContext`.
    #[rstest]
    fn prelude_extension_types_in_test_context(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(|x| x.add_prelude_extensions(TestPreludeCodegen));
        let tc = llvm_ctx.get_typing_session();
        assert_eq!(
            llvm_ctx.iw_context().i32_type().as_basic_type_enum(),
            tc.llvm_type(&usize_t()).unwrap()
        );
        assert_eq!(
            llvm_ctx.iw_context().f64_type().as_basic_type_enum(),
            tc.llvm_type(&qb_t()).unwrap()
        );
    }

    /// A test context with the default prelude lowering registered.
    #[rstest::fixture]
    fn prelude_llvm_ctx(mut llvm_ctx: TestContext) -> TestContext {
        llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_prelude_extensions);
        llvm_ctx
    }

    /// Lowering of a `ConstUsize` constant.
    #[rstest]
    fn prelude_const_usize(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let k = builder.add_load_value(ConstUsize::new(17));
                builder.finish_with_outputs([k]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// Lowering of external-symbol constants: one constant `usize` and one
    /// non-constant sum type.
    #[rstest]
    fn prelude_const_external_symbol(prelude_llvm_ctx: TestContext) {
        let konst1 = ConstExternalSymbol::new("sym1", usize_t(), true);
        let konst2 = ConstExternalSymbol::new(
            "sym2",
            HugrType::new_sum([
                vec![usize_t(), HugrType::new_unit_sum(3)].into(),
                type_row![],
            ]),
            false,
        );

        let hugr = SimpleHugrConfig::new()
            .with_outs(vec![konst1.get_type(), konst2.get_type()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let k1 = builder.add_load_value(konst1);
                let k2 = builder.add_load_value(konst2);
                builder.finish_with_outputs([k1, k2]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// Lowering of `MakeTuple` on two booleans.
    #[rstest]
    fn prelude_make_tuple(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![bool_t(), bool_t()])
            .with_outs(Type::new_tuple(vec![bool_t(), bool_t()]))
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let in_wires = builder.input_wires();
                let r = builder.make_tuple(in_wires).unwrap();
                builder.finish_with_outputs([r]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// Lowering of `UnpackTuple` on a tuple of two booleans.
    #[rstest]
    fn prelude_unpack_tuple(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_ins(Type::new_tuple(vec![bool_t(), bool_t()]))
            .with_outs(vec![bool_t(), bool_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let unpack = builder
                    .add_dataflow_op(
                        UnpackTuple::new(vec![bool_t(), bool_t()].into()),
                        builder.input_wires(),
                    )
                    .unwrap();
                builder.finish_with_outputs(unpack.outputs()).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// Lowering of a panic op that takes an error constant plus two qubits,
    /// and returns two qubits.
    #[rstest]
    fn prelude_panic(prelude_llvm_ctx: TestContext) {
        let error_val = ConstError::new(42, "PANIC");
        let type_arg_q: TypeArg = TypeArg::Type { ty: qb_t() };
        let type_arg_2q: TypeArg = TypeArg::Sequence {
            elems: vec![type_arg_q.clone(), type_arg_q],
        };
        let panic_op = PRELUDE
            .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()])
            .unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_ins(vec![qb_t(), qb_t()])
            .with_outs(vec![qb_t(), qb_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let [q0, q1] = builder.input_wires_arr();
                let err = builder.add_load_value(error_val);
                let [q0, q1] = builder
                    .add_dataflow_op(panic_op, [err, q0, q1])
                    .unwrap()
                    .outputs_arr();
                builder.finish_with_outputs([q0, q1]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// Lowering of a print op applied to a string constant.
    #[rstest]
    fn prelude_print(prelude_llvm_ctx: TestContext) {
        let greeting: ConstString = ConstString::new("Hello, world!".into());
        let print_op = PRELUDE.instantiate_extension_op(&PRINT_OP_ID, []).unwrap();

        let hugr = SimpleHugrConfig::new()
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let greeting_out = builder.add_load_value(greeting);
                builder.add_dataflow_op(print_op, [greeting_out]).unwrap();
                builder.finish_with_outputs([]).unwrap()
            });

        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// Lowering of `LoadNat` with a bounded-nat argument of 42.
    #[rstest]
    fn prelude_load_nat(prelude_llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(usize_t())
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let v = builder
                    .add_dataflow_op(LoadNat::new(TypeArg::BoundedNat { n: 42 }), vec![])
                    .unwrap()
                    .out_wire(0);
                builder.finish_with_outputs([v]).unwrap()
            });
        check_emission!(hugr, prelude_llvm_ctx);
    }

    /// A HUGR that passes the constant 42 twice through a barrier and
    /// returns the first barrier output.
    #[fixture]
    fn barrier_hugr() -> Hugr {
        SimpleHugrConfig::new()
            .with_outs(vec![usize_t()])
            .with_extensions(prelude::PRELUDE_REGISTRY.to_owned())
            .finish(|mut builder| {
                let i = builder.add_load_value(ConstUsize::new(42));
                let [w1, _w2] = builder.add_barrier([i, i]).unwrap().outputs_arr();
                builder.finish_with_outputs([w1]).unwrap()
            })
    }

    /// Emission of the barrier HUGR with the default prelude lowering.
    #[rstest]
    fn prelude_barrier(prelude_llvm_ctx: TestContext, barrier_hugr: Hugr) {
        check_emission!(barrier_hugr, prelude_llvm_ctx);
    }
    /// Executes the barrier HUGR, lowered with `TestPreludeCodegen`, and
    /// checks that the value 42 comes out of the barrier unchanged.
    #[rstest]
    fn prelude_barrier_exec(mut exec_ctx: TestContext, barrier_hugr: Hugr) {
        exec_ctx.add_extensions(|cem| add_prelude_extensions(cem, TestPreludeCodegen));
        assert_eq!(exec_ctx.exec_hugr_u64(barrier_hugr, "main"), 42);
    }
}
577}