1use hugr_core::{
2 extension::simple_op::MakeExtensionOp,
3 ops::{ExtensionOp, NamedOp, Value},
4 std_extensions::logic::{self, LogicOp},
5 types::SumType,
6 HugrView,
7};
8use inkwell::IntPredicate;
9
10use crate::{
11 custom::CodegenExtsBuilder,
12 emit::{emit_value, func::EmitFuncContext, EmitOpArgs},
13 sum::LLVMSumValue,
14};
15
16use anyhow::{anyhow, Result};
17
18fn emit_logic_op<'c, H: HugrView>(
19 context: &mut EmitFuncContext<'c, '_, H>,
20 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
21) -> Result<()> {
22 let lot = LogicOp::from_optype(&args.node().generalise()).ok_or(anyhow!(
23 "LogicOpEmitter: from_optype_failed: {:?}",
24 args.node().as_ref()
25 ))?;
26 let builder = context.builder();
27 let mut inputs = vec![];
29 for inp in args.inputs {
30 let bool_ty = context.llvm_sum_type(SumType::new_unary(2))?;
31 let bool_val = LLVMSumValue::try_new(inp, bool_ty)?;
32 inputs.push(bool_val.build_get_tag(builder)?);
33 }
34 let res = match lot {
35 LogicOp::And => builder.build_and(inputs[0], inputs[1], "")?,
36 LogicOp::Or => builder.build_or(inputs[0], inputs[1], "")?,
37 LogicOp::Xor => builder.build_xor(inputs[0], inputs[1], "")?,
38 LogicOp::Eq => builder.build_int_compare(IntPredicate::EQ, inputs[0], inputs[1], "")?,
39 LogicOp::Not => builder.build_not(inputs[0], "")?,
40 op => {
41 return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}"));
42 }
43 };
44 let res = builder.build_int_cast(res, context.iw_context().bool_type(), "")?;
46 let true_val = emit_value(context, &Value::true_val())?;
47 let false_val = emit_value(context, &Value::false_val())?;
48 let res = context
49 .builder()
50 .build_select(res, true_val, false_val, "")?;
51 args.outputs.finish(context.builder(), vec![res])
52}
53
54pub fn add_logic_extensions<'a, H: HugrView + 'a>(
57 cem: CodegenExtsBuilder<'a, H>,
58) -> CodegenExtsBuilder<'a, H> {
59 cem.extension_op(logic::EXTENSION_ID, LogicOp::Eq.name(), emit_logic_op)
60 .extension_op(logic::EXTENSION_ID, LogicOp::And.name(), emit_logic_op)
61 .extension_op(logic::EXTENSION_ID, LogicOp::Or.name(), emit_logic_op)
62 .extension_op(logic::EXTENSION_ID, LogicOp::Not.name(), emit_logic_op)
63 .extension_op(logic::EXTENSION_ID, LogicOp::Xor.name(), emit_logic_op) }
65
66impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
67 pub fn add_logic_extensions(self) -> Self {
70 add_logic_extensions(self)
71 }
72}
73
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        extension::{prelude::bool_t, ExtensionRegistry},
        std_extensions::logic::{self, LogicOp},
        Hugr,
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        extension::logic::add_logic_extensions,
        test::{llvm_ctx, TestContext},
    };

    /// Builds a Hugr whose function takes `arity` booleans and wires them all into one `op`
    /// node. The function returns that node's outputs, declared as a single boolean.
    fn test_logic_op(op: LogicOp, arity: usize) -> Hugr {
        let registry = ExtensionRegistry::new(vec![logic::EXTENSION.to_owned()]);
        SimpleHugrConfig::new()
            .with_ins(vec![bool_t(); arity])
            .with_outs(vec![bool_t()])
            .with_extensions(registry)
            .finish(|mut builder| {
                let input_wires = builder.input_wires();
                let op_node = builder.add_dataflow_op(op, input_wires).unwrap();
                let op_outputs = op_node.outputs();
                builder.finish_with_outputs(op_outputs).unwrap()
            })
    }

    /// Snapshot of the IR emitted for a binary `And`.
    #[rstest]
    fn and(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::And, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Snapshot of the IR emitted for a binary `Or`.
    #[rstest]
    fn or(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Or, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Snapshot of the IR emitted for a binary `Eq`.
    #[rstest]
    fn eq(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Eq, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Snapshot of the IR emitted for a unary `Not`.
    #[rstest]
    fn not(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Not, 1);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Snapshot of the IR emitted for a binary `Xor`.
    #[rstest]
    fn xor(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Xor, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }
}