1use hugr_core::{
2 HugrView, Node,
3 extension::simple_op::MakeExtensionOp,
4 ops::{ExtensionOp, Value},
5 std_extensions::logic::{self, LogicOp},
6 types::SumType,
7};
8use inkwell::IntPredicate;
9
10use crate::{
11 custom::CodegenExtsBuilder,
12 emit::{EmitOpArgs, emit_value, func::EmitFuncContext},
13 sum::LLVMSumValue,
14};
15
16use anyhow::{Result, anyhow};
17
18fn emit_logic_op<'c, H: HugrView<Node = Node>>(
19 context: &mut EmitFuncContext<'c, '_, H>,
20 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
21) -> Result<()> {
22 let lot = LogicOp::from_optype(&args.node().generalise()).ok_or(anyhow!(
23 "LogicOpEmitter: from_optype_failed: {:?}",
24 args.node().as_ref()
25 ))?;
26 let builder = context.builder();
27 let mut inputs = vec![];
29 for inp in args.inputs {
30 let bool_ty = context.llvm_sum_type(SumType::new_unary(2))?;
31 let bool_val = LLVMSumValue::try_new(inp, bool_ty)?;
32 inputs.push(bool_val.build_get_tag(builder)?);
33 }
34 let res = match lot {
35 LogicOp::And => builder.build_and(inputs[0], inputs[1], "")?,
36 LogicOp::Or => builder.build_or(inputs[0], inputs[1], "")?,
37 LogicOp::Xor => builder.build_xor(inputs[0], inputs[1], "")?,
38 LogicOp::Eq => builder.build_int_compare(IntPredicate::EQ, inputs[0], inputs[1], "")?,
39 LogicOp::Not => builder.build_not(inputs[0], "")?,
40 op => {
41 return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}"));
42 }
43 };
44 let res = builder.build_int_cast(res, context.iw_context().bool_type(), "")?;
46 let true_val = emit_value(context, &Value::true_val())?;
47 let false_val = emit_value(context, &Value::false_val())?;
48 let res = context
49 .builder()
50 .build_select(res, true_val, false_val, "")?;
51 args.outputs.finish(context.builder(), vec![res])
52}
53
54pub fn add_logic_extensions<'a, H: HugrView<Node = Node> + 'a>(
57 cem: CodegenExtsBuilder<'a, H>,
58) -> CodegenExtsBuilder<'a, H> {
59 cem.extension_op(logic::EXTENSION_ID, LogicOp::Eq.op_id(), emit_logic_op)
60 .extension_op(logic::EXTENSION_ID, LogicOp::And.op_id(), emit_logic_op)
61 .extension_op(logic::EXTENSION_ID, LogicOp::Or.op_id(), emit_logic_op)
62 .extension_op(logic::EXTENSION_ID, LogicOp::Not.op_id(), emit_logic_op)
63 .extension_op(logic::EXTENSION_ID, LogicOp::Xor.op_id(), emit_logic_op) }
65
66impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
67 #[must_use]
70 pub fn add_logic_extensions(self) -> Self {
71 add_logic_extensions(self)
72 }
73}
74
#[cfg(test)]
mod test {
    use hugr_core::{
        Hugr,
        builder::{Dataflow, DataflowHugr},
        extension::{ExtensionRegistry, prelude::bool_t},
        std_extensions::logic::{self, LogicOp},
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        extension::logic::add_logic_extensions,
        test::{TestContext, llvm_ctx},
    };

    /// Builds a HUGR with `arity` bool inputs wired straight into one `op`.
    ///
    /// The op's outputs become the HUGR's outputs, and the HUGR has a single bool output.
    fn test_logic_op(op: LogicOp, arity: usize) -> Hugr {
        let registry = ExtensionRegistry::new(vec![logic::EXTENSION.to_owned()]);
        let input_types = vec![bool_t(); arity];
        SimpleHugrConfig::new()
            .with_ins(input_types)
            .with_outs(vec![bool_t()])
            .with_extensions(registry)
            .finish(|mut builder| {
                let wires = builder.input_wires();
                let op_handle = builder.add_dataflow_op(op, wires).unwrap();
                builder
                    .finish_hugr_with_outputs(op_handle.outputs())
                    .unwrap()
            })
    }

    // Each test must invoke `check_emission!` directly: the snapshot name is
    // presumably derived from the enclosing test function's name.

    #[rstest]
    fn and(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::And, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn or(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Or, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn eq(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Eq, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn not(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Not, 1);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn xor(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Xor, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }
}