hugr_llvm/extension/float.rs

1use anyhow::{Result, anyhow};
2use hugr_core::Node;
3use hugr_core::ops::ExtensionOp;
4use hugr_core::ops::{Value, constant::CustomConst};
5use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
6use hugr_core::{
7    HugrView,
8    std_extensions::arithmetic::float_types::{self, ConstF64},
9};
10use inkwell::{
11    types::{BasicType, FloatType},
12    values::{BasicValue, BasicValueEnum},
13};
14
15use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
16use crate::emit::{EmitOpArgs, func::EmitFuncContext};
17use crate::emit::{emit_value, get_intrinsic};
18
19use crate::custom::CodegenExtsBuilder;
20
21/// Emit a float comparison operation.
22fn emit_fcmp<'c, H: HugrView<Node = Node>>(
23    context: &mut EmitFuncContext<'c, '_, H>,
24    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
25    pred: inkwell::FloatPredicate,
26) -> Result<()> {
27    let true_val = emit_value(context, &Value::true_val())?;
28    let false_val = emit_value(context, &Value::false_val())?;
29
30    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
31        // get result as an i1
32        let r = ctx.builder().build_float_compare(
33            pred,
34            lhs.into_float_value(),
35            rhs.into_float_value(),
36            "",
37        )?;
38        // convert to whatever bool_t is
39        Ok(vec![
40            ctx.builder().build_select(r, true_val, false_val, "")?,
41        ])
42    })
43}
44
45fn emit_float_op<'c, H: HugrView<Node = Node>>(
46    context: &mut EmitFuncContext<'c, '_, H>,
47    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48    op: FloatOps,
49) -> Result<()> {
50    // We emit the float comparison variants where NaN is an absorbing value.
51    // Any comparison with NaN is always false.
52    #[allow(clippy::wildcard_in_or_patterns)]
53    match op {
54        FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
55        FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
56        FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
57        FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
58        FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
59        FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
60        FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
61            Ok(vec![
62                ctx.builder()
63                    .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
64                    .as_basic_value_enum(),
65            ])
66        }),
67        FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
68            Ok(vec![
69                ctx.builder()
70                    .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
71                    .as_basic_value_enum(),
72            ])
73        }),
74        FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
75            Ok(vec![
76                ctx.builder()
77                    .build_float_neg(v.into_float_value(), "")?
78                    .as_basic_value_enum(),
79            ])
80        }),
81        FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
82            Ok(vec![
83                ctx.builder()
84                    .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
85                    .as_basic_value_enum(),
86            ])
87        }),
88        FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
89            Ok(vec![
90                ctx.builder()
91                    .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
92                    .as_basic_value_enum(),
93            ])
94        }),
95        FloatOps::fpow => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
96            let float_ty = ctx.iw_context().f64_type().as_basic_type_enum();
97            let func = get_intrinsic(ctx.get_current_module(), "llvm.pow.f64", [float_ty])?;
98            Ok(vec![
99                ctx.builder()
100                    .build_call(func, &[lhs.into(), rhs.into()], "")?
101                    .try_as_basic_value()
102                    .unwrap_left()
103                    .as_basic_value_enum(),
104            ])
105        }),
106        // Missing ops, not supported by inkwell
107        FloatOps::fmax
108        | FloatOps::fmin
109        | FloatOps::fabs
110        | FloatOps::ffloor
111        | FloatOps::fceil
112        | FloatOps::ftostring
113        | _ => {
114            let name: &str = op.into();
115            Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
116        }
117    }
118}
119
120fn emit_constf64<'c, H: HugrView<Node = Node>>(
121    context: &mut EmitFuncContext<'c, '_, H>,
122    k: &ConstF64,
123) -> Result<BasicValueEnum<'c>> {
124    let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
125    Ok(ty.const_float(k.value()).as_basic_value_enum())
126}
127
128pub fn add_float_extensions<'a, H: HugrView<Node = Node> + 'a>(
129    cem: CodegenExtsBuilder<'a, H>,
130) -> CodegenExtsBuilder<'a, H> {
131    cem.custom_type(
132        (
133            float_types::EXTENSION_ID,
134            float_types::FLOAT_TYPE_ID.clone(),
135        ),
136        |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
137    )
138    .custom_const(emit_constf64)
139    .simple_extension_op::<FloatOps>(emit_float_op)
140}
141
142impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
143    #[must_use]
144    pub fn add_float_extensions(self) -> Self {
145        add_float_extensions(self)
146    }
147}
148
#[cfg(test)]
mod test {
    // Emission tests for the float lowering. Each test builds a small HUGR,
    // registers `add_float_extensions` on the test context, and hands the
    // result to `check_emission!`.
    // NOTE(review): `check_emission!` presumably compares the emitted IR with
    // stored snapshots whose names may derive from the test function name or
    // the explicit name argument — verify before renaming tests or cases.
    use hugr_core::Hugr;
    use hugr_core::builder::DataflowHugr;
    use hugr_core::extension::SignatureFunc;
    // Presumably needed to bring `FloatOps::signature` into scope.
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::types::TypeRow;
    use hugr_core::{
        builder::Dataflow,
        std_extensions::arithmetic::float_types::{ConstF64, float64_type},
    };
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{TestContext, llvm_ctx},
    };

    /// Build a HUGR whose entry function has exactly `op`'s inputs and
    /// outputs, and whose body applies `op` once to its inputs.
    ///
    /// # Panics
    ///
    /// Panics if `op`'s signature is not a `PolyFuncType`, or if either
    /// signature row cannot be converted to a plain `TypeRow` (presumably when
    /// it contains row variables — TODO confirm).
    fn test_float_op(op: FloatOps) -> Hugr {
        let SignatureFunc::PolyFuncType(poly_sig) = op.signature() else {
            panic!("Expected PolyFuncType");
        };
        let sig = poly_sig.body();
        let inp: TypeRow = sig.input.clone().try_into().unwrap();
        let out: TypeRow = sig.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(inp)
            .with_outs(out)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let outputs = builder
                    .add_dataflow_op(op, builder.input_wires())
                    .unwrap()
                    .outputs();
                builder.finish_hugr_with_outputs(outputs).unwrap()
            })
    }

    /// Emission of a HUGR that loads the constant `ConstF64::new(3.12)` and
    /// returns it.
    // `llvm_ctx` is injected by rstest from the fixture of the same name.
    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let c = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_hugr_with_outputs([c]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    /// Emission of each listed `FloatOps` applied once, checked under the
    /// op's own name (the `&str` obtained from `op.into()`).
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    #[case::fpow(FloatOps::fpow)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        let name: &str = op.into();
        let hugr = test_float_op(op);
        llvm_ctx.add_extensions(add_float_extensions);
        check_emission!(name, hugr, llvm_ctx);
    }
}