hugr_llvm/extension/
float.rs

1use anyhow::{anyhow, Result};
2use hugr_core::ops::ExtensionOp;
3use hugr_core::ops::{constant::CustomConst, Value};
4use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
5use hugr_core::{
6    std_extensions::arithmetic::float_types::{self, ConstF64},
7    HugrView,
8};
9use inkwell::{
10    types::{BasicType, FloatType},
11    values::{BasicValue, BasicValueEnum},
12};
13
14use crate::emit::emit_value;
15use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
16use crate::emit::{func::EmitFuncContext, EmitOpArgs};
17
18use crate::custom::CodegenExtsBuilder;
19
20/// Emit a float comparison operation.
21fn emit_fcmp<'c, H: HugrView>(
22    context: &mut EmitFuncContext<'c, '_, H>,
23    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
24    pred: inkwell::FloatPredicate,
25) -> Result<()> {
26    let true_val = emit_value(context, &Value::true_val())?;
27    let false_val = emit_value(context, &Value::false_val())?;
28
29    emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
30        // get result as an i1
31        let r = ctx.builder().build_float_compare(
32            pred,
33            lhs.into_float_value(),
34            rhs.into_float_value(),
35            "",
36        )?;
37        // convert to whatever bool_t is
38        Ok(vec![ctx
39            .builder()
40            .build_select(r, true_val, false_val, "")?])
41    })
42}
43
44fn emit_float_op<'c, H: HugrView>(
45    context: &mut EmitFuncContext<'c, '_, H>,
46    args: EmitOpArgs<'c, '_, ExtensionOp, H>,
47    op: FloatOps,
48) -> Result<()> {
49    // We emit the float comparison variants where NaN is an absorbing value.
50    // Any comparison with NaN is always false.
51    #[allow(clippy::wildcard_in_or_patterns)]
52    match op {
53        FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
54        FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
55        FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
56        FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
57        FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
58        FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
59        FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
60            Ok(vec![ctx
61                .builder()
62                .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
63                .as_basic_value_enum()])
64        }),
65        FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
66            Ok(vec![ctx
67                .builder()
68                .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
69                .as_basic_value_enum()])
70        }),
71        FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
72            Ok(vec![ctx
73                .builder()
74                .build_float_neg(v.into_float_value(), "")?
75                .as_basic_value_enum()])
76        }),
77        FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
78            Ok(vec![ctx
79                .builder()
80                .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
81                .as_basic_value_enum()])
82        }),
83        FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
84            Ok(vec![ctx
85                .builder()
86                .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
87                .as_basic_value_enum()])
88        }),
89        // Missing ops, not supported by inkwell
90        FloatOps::fmax
91        | FloatOps::fmin
92        | FloatOps::fabs
93        | FloatOps::ffloor
94        | FloatOps::fceil
95        | FloatOps::ftostring
96        | _ => {
97            let name: &str = op.into();
98            Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
99        }
100    }
101}
102
103fn emit_constf64<'c, H: HugrView>(
104    context: &mut EmitFuncContext<'c, '_, H>,
105    k: &ConstF64,
106) -> Result<BasicValueEnum<'c>> {
107    let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
108    Ok(ty.const_float(k.value()).as_basic_value_enum())
109}
110
111pub fn add_float_extensions<'a, H: HugrView + 'a>(
112    cem: CodegenExtsBuilder<'a, H>,
113) -> CodegenExtsBuilder<'a, H> {
114    cem.custom_type(
115        (
116            float_types::EXTENSION_ID,
117            float_types::FLOAT_TYPE_ID.clone(),
118        ),
119        |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
120    )
121    .custom_const(emit_constf64)
122    .simple_extension_op::<FloatOps>(emit_float_op)
123}
124
125impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
126    pub fn add_float_extensions(self) -> Self {
127        add_float_extensions(self)
128    }
129}
130
#[cfg(test)]
mod test {
    use hugr_core::builder::{Dataflow, DataflowSubContainer};
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::std_extensions::arithmetic::float_types::{float64_type, ConstF64};
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::TypeRow;
    use hugr_core::Hugr;
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::check_emission;
    use crate::emit::test::SimpleHugrConfig;
    use crate::test::{llvm_ctx, TestContext};

    /// Build a HUGR that feeds its inputs straight into `op` and returns the
    /// op's outputs; the input and output rows are taken from the op's own
    /// signature.
    fn test_float_op(op: FloatOps) -> Hugr {
        let poly_sig = match op.signature() {
            SignatureFunc::PolyFuncType(poly) => poly,
            _ => panic!("Expected PolyFuncType"),
        };
        let body = poly_sig.body();
        let input_row: TypeRow = body.input.clone().try_into().unwrap();
        let output_row: TypeRow = body.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(input_row)
            .with_outs(output_row)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut b| {
                let op_node = b.add_dataflow_op(op, b.input_wires()).unwrap();
                b.finish_with_outputs(op_node.outputs()).unwrap()
            })
    }

    /// Emission check for loading a single `ConstF64`.
    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut b| {
                let loaded = b.add_load_value(ConstF64::new(3.12));
                b.finish_with_outputs([loaded]).unwrap()
            });
        llvm_ctx.add_extensions(add_float_extensions);
        check_emission!(hugr, llvm_ctx);
    }

    /// Emission check for every float op that has a lowering; each case is
    /// checked under the op's own name.
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        llvm_ctx.add_extensions(add_float_extensions);
        let name: &str = op.into();
        let hugr = test_float_op(op);
        check_emission!(name, hugr, llvm_ctx);
    }
}