hugr_llvm/extension/float.rs

1use anyhow::{anyhow, Result};
2use hugr_core::ops::ExtensionOp;
3use hugr_core::ops::{constant::CustomConst, Value};
4use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
5use hugr_core::Node;
6use hugr_core::{
7    std_extensions::arithmetic::float_types::{self, ConstF64},
8    HugrView,
9};
10use inkwell::{
11    types::{BasicType, FloatType},
12    values::{BasicValue, BasicValueEnum},
13};
14
15use crate::emit::emit_value;
16use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
17use crate::emit::{func::EmitFuncContext, EmitOpArgs};
18
19use crate::custom::CodegenExtsBuilder;
20
21/// Emit a float comparison operation.
22fn emit_fcmp<'c, H: HugrView<Node = Node>>(
23    context: &mut EmitFuncContext<'c, '_, H>,
24    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
25    pred: inkwell::FloatPredicate,
26) -> Result<()> {
27    let true_val = emit_value(context, &Value::true_val())?;
28    let false_val = emit_value(context, &Value::false_val())?;
29
30    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
31        // get result as an i1
32        let r = ctx.builder().build_float_compare(
33            pred,
34            lhs.into_float_value(),
35            rhs.into_float_value(),
36            "",
37        )?;
38        // convert to whatever bool_t is
39        Ok(vec![ctx
40            .builder()
41            .build_select(r, true_val, false_val, "")?])
42    })
43}
44
45fn emit_float_op<'c, H: HugrView<Node = Node>>(
46    context: &mut EmitFuncContext<'c, '_, H>,
47    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48    op: FloatOps,
49) -> Result<()> {
50    // We emit the float comparison variants where NaN is an absorbing value.
51    // Any comparison with NaN is always false.
52    #[allow(clippy::wildcard_in_or_patterns)]
53    match op {
54        FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
55        FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
56        FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
57        FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
58        FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
59        FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
60        FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
61            Ok(vec![ctx
62                .builder()
63                .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
64                .as_basic_value_enum()])
65        }),
66        FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
67            Ok(vec![ctx
68                .builder()
69                .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
70                .as_basic_value_enum()])
71        }),
72        FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
73            Ok(vec![ctx
74                .builder()
75                .build_float_neg(v.into_float_value(), "")?
76                .as_basic_value_enum()])
77        }),
78        FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
79            Ok(vec![ctx
80                .builder()
81                .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
82                .as_basic_value_enum()])
83        }),
84        FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
85            Ok(vec![ctx
86                .builder()
87                .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
88                .as_basic_value_enum()])
89        }),
90        // Missing ops, not supported by inkwell
91        FloatOps::fmax
92        | FloatOps::fmin
93        | FloatOps::fabs
94        | FloatOps::ffloor
95        | FloatOps::fceil
96        | FloatOps::ftostring
97        | _ => {
98            let name: &str = op.into();
99            Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
100        }
101    }
102}
103
104fn emit_constf64<'c, H: HugrView<Node = Node>>(
105    context: &mut EmitFuncContext<'c, '_, H>,
106    k: &ConstF64,
107) -> Result<BasicValueEnum<'c>> {
108    let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
109    Ok(ty.const_float(k.value()).as_basic_value_enum())
110}
111
112pub fn add_float_extensions<'a, H: HugrView<Node = Node> + 'a>(
113    cem: CodegenExtsBuilder<'a, H>,
114) -> CodegenExtsBuilder<'a, H> {
115    cem.custom_type(
116        (
117            float_types::EXTENSION_ID,
118            float_types::FLOAT_TYPE_ID.clone(),
119        ),
120        |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
121    )
122    .custom_const(emit_constf64)
123    .simple_extension_op::<FloatOps>(emit_float_op)
124}
125
126impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
127    pub fn add_float_extensions(self) -> Self {
128        add_float_extensions(self)
129    }
130}
131
#[cfg(test)]
mod test {
    use hugr_core::builder::{Dataflow, DataflowSubContainer};
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::std_extensions::arithmetic::float_types::{float64_type, ConstF64};
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::TypeRow;
    use hugr_core::Hugr;
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::check_emission;
    use crate::emit::test::SimpleHugrConfig;
    use crate::test::{llvm_ctx, TestContext};

    /// Build a HUGR (via `SimpleHugrConfig`) that applies `op` once to its inputs
    /// and returns the op's outputs.
    ///
    /// The input and output rows come from the body of `op`'s polymorphic signature.
    ///
    /// # Panics
    ///
    /// Panics if `op`'s signature is not `SignatureFunc::PolyFuncType`, or if its
    /// input/output rows cannot be converted into a `TypeRow`.
    fn test_float_op(op: FloatOps) -> Hugr {
        let poly_sig = match op.signature() {
            SignatureFunc::PolyFuncType(poly_sig) => poly_sig,
            _ => panic!("Expected PolyFuncType"),
        };
        let body = poly_sig.body();
        let ins: TypeRow = body.input.clone().try_into().unwrap();
        let outs: TypeRow = body.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(ins)
            .with_outs(outs)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let inputs = builder.input_wires();
                let handle = builder.add_dataflow_op(op, inputs).unwrap();
                builder.finish_with_outputs(handle.outputs()).unwrap()
            })
    }

    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let loaded = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_with_outputs([loaded]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    // Snapshot names derive from the test/case names below; keep them stable.
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        llvm_ctx.add_extensions(add_float_extensions);
        let name: &str = op.into();
        let hugr = test_float_op(op);
        check_emission!(name, hugr, llvm_ctx);
    }
}