hugr_llvm/extension/
logic.rs

1use hugr_core::{
2    extension::simple_op::MakeExtensionOp,
3    ops::{ExtensionOp, NamedOp, Value},
4    std_extensions::logic::{self, LogicOp},
5    types::SumType,
6    HugrView,
7};
8use inkwell::IntPredicate;
9
10use crate::{
11    custom::CodegenExtsBuilder,
12    emit::{emit_value, func::EmitFuncContext, EmitOpArgs},
13    sum::LLVMSumValue,
14};
15
16use anyhow::{anyhow, Result};
17
18fn emit_logic_op<'c, H: HugrView>(
19    context: &mut EmitFuncContext<'c, '_, H>,
20    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
21) -> Result<()> {
22    let lot = LogicOp::from_optype(&args.node().generalise()).ok_or(anyhow!(
23        "LogicOpEmitter: from_optype_failed: {:?}",
24        args.node().as_ref()
25    ))?;
26    let builder = context.builder();
27    // Turn bool sum inputs into i1's
28    let mut inputs = vec![];
29    for inp in args.inputs {
30        let bool_ty = context.llvm_sum_type(SumType::new_unary(2))?;
31        let bool_val = LLVMSumValue::try_new(inp, bool_ty)?;
32        inputs.push(bool_val.build_get_tag(builder)?);
33    }
34    let res = match lot {
35        LogicOp::And => builder.build_and(inputs[0], inputs[1], "")?,
36        LogicOp::Or => builder.build_or(inputs[0], inputs[1], "")?,
37        LogicOp::Xor => builder.build_xor(inputs[0], inputs[1], "")?,
38        LogicOp::Eq => builder.build_int_compare(IntPredicate::EQ, inputs[0], inputs[1], "")?,
39        LogicOp::Not => builder.build_not(inputs[0], "")?,
40        op => {
41            return Err(anyhow!("LogicOpEmitter: Unknown op: {op:?}"));
42        }
43    };
44    // Turn result back into sum
45    let res = builder.build_int_cast(res, context.iw_context().bool_type(), "")?;
46    let true_val = emit_value(context, &Value::true_val())?;
47    let false_val = emit_value(context, &Value::false_val())?;
48    let res = context
49        .builder()
50        .build_select(res, true_val, false_val, "")?;
51    args.outputs.finish(context.builder(), vec![res])
52}
53
54/// Populates a [CodegenExtsBuilder] with all extensions needed to lower logic
55/// ops.
56pub fn add_logic_extensions<'a, H: HugrView + 'a>(
57    cem: CodegenExtsBuilder<'a, H>,
58) -> CodegenExtsBuilder<'a, H> {
59    cem.extension_op(logic::EXTENSION_ID, LogicOp::Eq.name(), emit_logic_op)
60        .extension_op(logic::EXTENSION_ID, LogicOp::And.name(), emit_logic_op)
61        .extension_op(logic::EXTENSION_ID, LogicOp::Or.name(), emit_logic_op)
62        .extension_op(logic::EXTENSION_ID, LogicOp::Not.name(), emit_logic_op)
63        .extension_op(logic::EXTENSION_ID, LogicOp::Xor.name(), emit_logic_op) // Added Xor
64}
65
66impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
67    /// Populates a [CodegenExtsBuilder] with all extensions needed to lower
68    /// logic ops.
69    pub fn add_logic_extensions(self) -> Self {
70        add_logic_extensions(self)
71    }
72}
73
#[cfg(test)]
mod test {
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        extension::{prelude::bool_t, ExtensionRegistry},
        std_extensions::logic::{self, LogicOp},
        Hugr,
    };
    use rstest::rstest;

    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        extension::logic::add_logic_extensions,
        test::{llvm_ctx, TestContext},
    };

    /// Builds a HUGR with `arity` bool inputs and one bool output whose body
    /// is the single logic op `op`, fed by all of the input wires.
    fn test_logic_op(op: LogicOp, arity: usize) -> Hugr {
        let config = SimpleHugrConfig::new()
            .with_ins(vec![bool_t(); arity])
            .with_outs(vec![bool_t()])
            .with_extensions(ExtensionRegistry::new(vec![logic::EXTENSION.to_owned()]));
        config.finish(|mut dfg| {
            let op_inputs = dfg.input_wires();
            let op_handle = dfg.add_dataflow_op(op, op_inputs).unwrap();
            dfg.finish_with_outputs(op_handle.outputs()).unwrap()
        })
    }

    /// Runs `check_emission!` on a two-input `And` HUGR.
    #[rstest]
    fn and(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::And, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Runs `check_emission!` on a two-input `Or` HUGR.
    #[rstest]
    fn or(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Or, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Runs `check_emission!` on a two-input `Eq` HUGR.
    #[rstest]
    fn eq(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Eq, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Runs `check_emission!` on a one-input `Not` HUGR.
    #[rstest]
    fn not(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Not, 1);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Runs `check_emission!` on a two-input `Xor` HUGR.
    #[rstest]
    fn xor(mut llvm_ctx: TestContext) {
        let hugr = test_logic_op(LogicOp::Xor, 2);
        llvm_ctx.add_extensions(add_logic_extensions);
        check_emission!(hugr, llvm_ctx);
    }
}