hugr_llvm/extension/logic.rs

1use hugr_core::{
2    HugrView, Node,
3    extension::simple_op::MakeExtensionOp,
4    ops::{ExtensionOp, Value},
5    std_extensions::logic::{self, LogicOp},
6    types::SumType,
7};
8use inkwell::IntPredicate;
9
10use crate::{
11    custom::CodegenExtsBuilder,
12    emit::{EmitOpArgs, emit_value, func::EmitFuncContext},
13    sum::LLVMSumValue,
14};
15
16use anyhow::{Result, anyhow};
17
18fn emit_logic_op<'c, H: HugrView<Node = Node>>(
19    context: &mut EmitFuncContext<'c, '_, H>,
20    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
21) -> Result<()> {
22    let lot = LogicOp::from_optype(&args.node().generalise()).ok_or(anyhow!(
23        "LogicOpEmitter: from_optype_failed: {:?}",
24        args.node().as_ref()
25    ))?;
26    let builder = context.builder();
27    // Turn bool sum inputs into i1's
28    let mut inputs = vec![];
29    for inp in args.inputs {
30        let bool_ty = context.llvm_sum_type(SumType::new_unary(2))?;
31        let bool_val = LLVMSumValue::try_new(inp, bool_ty)?;
32        inputs.push(bool_val.build_get_tag(builder)?);
33    }
34    let res = match lot {
35        LogicOp::And => builder.build_and(inputs[0], inputs[1], "")?,
36        LogicOp::Or => builder.build_or(inputs[0], inputs[1], "")?,
37        LogicOp::Xor => builder.build_xor(inputs[0], inputs[1], "")?,
38        LogicOp::Eq => builder.build_int_compare(IntPredicate::EQ, inputs[0], inputs[1], "")?,
39        LogicOp::Not => builder.build_not(inputs[0], "")?,
40        op => {
41            return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}"));
42        }
43    };
44    // Turn result back into sum
45    let res = builder.build_int_cast(res, context.iw_context().bool_type(), "")?;
46    let true_val = emit_value(context, &Value::true_val())?;
47    let false_val = emit_value(context, &Value::false_val())?;
48    let res = context
49        .builder()
50        .build_select(res, true_val, false_val, "")?;
51    args.outputs.finish(context.builder(), vec![res])
52}
53
54/// Populates a [`CodegenExtsBuilder`] with all extensions needed to lower logic
55/// ops.
56pub fn add_logic_extensions<'a, H: HugrView<Node = Node> + 'a>(
57    cem: CodegenExtsBuilder<'a, H>,
58) -> CodegenExtsBuilder<'a, H> {
59    cem.extension_op(logic::EXTENSION_ID, LogicOp::Eq.op_id(), emit_logic_op)
60        .extension_op(logic::EXTENSION_ID, LogicOp::And.op_id(), emit_logic_op)
61        .extension_op(logic::EXTENSION_ID, LogicOp::Or.op_id(), emit_logic_op)
62        .extension_op(logic::EXTENSION_ID, LogicOp::Not.op_id(), emit_logic_op)
63        .extension_op(logic::EXTENSION_ID, LogicOp::Xor.op_id(), emit_logic_op) // Added Xor
64}
65
66impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
67    /// Populates a [`CodegenExtsBuilder`] with all extensions needed to lower
68    /// logic ops.
69    #[must_use]
70    pub fn add_logic_extensions(self) -> Self {
71        add_logic_extensions(self)
72    }
73}
74
#[cfg(test)]
mod test {
    use hugr_core::{
        Hugr,
        builder::{Dataflow, DataflowHugr},
        extension::{ExtensionRegistry, prelude::bool_t},
        std_extensions::logic::{self, LogicOp},
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        extension::logic::add_logic_extensions,
        test::{TestContext, llvm_ctx},
    };

    /// Builds a HUGR that takes `arity` bools, applies `op` to all of them,
    /// and returns the op's outputs.
    fn test_logic_op(op: LogicOp, arity: usize) -> Hugr {
        let bool_inputs = vec![bool_t(); arity];
        let registry = ExtensionRegistry::new(vec![logic::EXTENSION.to_owned()]);
        SimpleHugrConfig::new()
            .with_ins(bool_inputs)
            .with_outs(vec![bool_t()])
            .with_extensions(registry)
            .finish(|mut builder| {
                let in_wires = builder.input_wires();
                let op_handle = builder.add_dataflow_op(op, in_wires).unwrap();
                builder
                    .finish_hugr_with_outputs(op_handle.outputs())
                    .unwrap()
            })
    }

    #[rstest]
    fn and(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::And, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn or(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Or, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn eq(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Eq, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn not(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Not, 1);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    fn xor(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Xor, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }
}