hugr_llvm/extension/
float.rs

1use anyhow::{anyhow, Result};
2use hugr_core::ops::ExtensionOp;
3use hugr_core::ops::{constant::CustomConst, Value};
4use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
5use hugr_core::Node;
6use hugr_core::{
7    std_extensions::arithmetic::float_types::{self, ConstF64},
8    HugrView,
9};
10use inkwell::{
11    types::{BasicType, FloatType},
12    values::{BasicValue, BasicValueEnum},
13};
14
15use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
16use crate::emit::{emit_value, get_intrinsic};
17use crate::emit::{func::EmitFuncContext, EmitOpArgs};
18
19use crate::custom::CodegenExtsBuilder;
20
21/// Emit a float comparison operation.
22fn emit_fcmp<'c, H: HugrView<Node = Node>>(
23    context: &mut EmitFuncContext<'c, '_, H>,
24    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
25    pred: inkwell::FloatPredicate,
26) -> Result<()> {
27    let true_val = emit_value(context, &Value::true_val())?;
28    let false_val = emit_value(context, &Value::false_val())?;
29
30    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
31        // get result as an i1
32        let r = ctx.builder().build_float_compare(
33            pred,
34            lhs.into_float_value(),
35            rhs.into_float_value(),
36            "",
37        )?;
38        // convert to whatever bool_t is
39        Ok(vec![ctx
40            .builder()
41            .build_select(r, true_val, false_val, "")?])
42    })
43}
44
45fn emit_float_op<'c, H: HugrView<Node = Node>>(
46    context: &mut EmitFuncContext<'c, '_, H>,
47    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48    op: FloatOps,
49) -> Result<()> {
50    // We emit the float comparison variants where NaN is an absorbing value.
51    // Any comparison with NaN is always false.
52    #[allow(clippy::wildcard_in_or_patterns)]
53    match op {
54        FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
55        FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
56        FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
57        FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
58        FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
59        FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
60        FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
61            Ok(vec![ctx
62                .builder()
63                .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
64                .as_basic_value_enum()])
65        }),
66        FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
67            Ok(vec![ctx
68                .builder()
69                .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
70                .as_basic_value_enum()])
71        }),
72        FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
73            Ok(vec![ctx
74                .builder()
75                .build_float_neg(v.into_float_value(), "")?
76                .as_basic_value_enum()])
77        }),
78        FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
79            Ok(vec![ctx
80                .builder()
81                .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
82                .as_basic_value_enum()])
83        }),
84        FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
85            Ok(vec![ctx
86                .builder()
87                .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
88                .as_basic_value_enum()])
89        }),
90        FloatOps::fpow => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
91            let float_ty = ctx.iw_context().f64_type().as_basic_type_enum();
92            let func = get_intrinsic(ctx.get_current_module(), "llvm.pow.f64", [float_ty])?;
93            Ok(vec![ctx
94                .builder()
95                .build_call(func, &[lhs.into(), rhs.into()], "")?
96                .try_as_basic_value()
97                .unwrap_left()
98                .as_basic_value_enum()])
99        }),
100        // Missing ops, not supported by inkwell
101        FloatOps::fmax
102        | FloatOps::fmin
103        | FloatOps::fabs
104        | FloatOps::ffloor
105        | FloatOps::fceil
106        | FloatOps::ftostring
107        | _ => {
108            let name: &str = op.into();
109            Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
110        }
111    }
112}
113
114fn emit_constf64<'c, H: HugrView<Node = Node>>(
115    context: &mut EmitFuncContext<'c, '_, H>,
116    k: &ConstF64,
117) -> Result<BasicValueEnum<'c>> {
118    let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
119    Ok(ty.const_float(k.value()).as_basic_value_enum())
120}
121
122pub fn add_float_extensions<'a, H: HugrView<Node = Node> + 'a>(
123    cem: CodegenExtsBuilder<'a, H>,
124) -> CodegenExtsBuilder<'a, H> {
125    cem.custom_type(
126        (
127            float_types::EXTENSION_ID,
128            float_types::FLOAT_TYPE_ID.clone(),
129        ),
130        |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
131    )
132    .custom_const(emit_constf64)
133    .simple_extension_op::<FloatOps>(emit_float_op)
134}
135
136impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
137    pub fn add_float_extensions(self) -> Self {
138        add_float_extensions(self)
139    }
140}
141
#[cfg(test)]
mod test {
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::TypeRow;
    use hugr_core::Hugr;
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
    };
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{llvm_ctx, TestContext},
    };

    /// Build a HUGR that feeds its inputs straight into `op` and returns the
    /// op's outputs. Input and output rows are taken from the op's signature.
    fn test_float_op(op: FloatOps) -> Hugr {
        let poly_sig = match op.signature() {
            SignatureFunc::PolyFuncType(poly_sig) => poly_sig,
            _ => panic!("Expected PolyFuncType"),
        };
        let body = poly_sig.body();
        let in_row: TypeRow = body.input.clone().try_into().unwrap();
        let out_row: TypeRow = body.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(in_row)
            .with_outs(out_row)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let inputs = builder.input_wires();
                let op_handle = builder.add_dataflow_op(op, inputs).unwrap();
                builder.finish_with_outputs(op_handle.outputs()).unwrap()
            })
    }

    /// Emission check for loading a single `ConstF64` constant.
    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let loaded = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_with_outputs([loaded]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    /// Emission check for each float op listed as a case; the op's name is
    /// passed to `check_emission!` alongside the HUGR.
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    #[case::fpow(FloatOps::fpow)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        let op_name: &str = op.into();
        let hugr = test_float_op(op);
        llvm_ctx.add_extensions(add_float_extensions);
        check_emission!(op_name, hugr, llvm_ctx);
    }
}