1use anyhow::{anyhow, Result};
2use hugr_core::ops::ExtensionOp;
3use hugr_core::ops::{constant::CustomConst, Value};
4use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
5use hugr_core::{
6 std_extensions::arithmetic::float_types::{self, ConstF64},
7 HugrView,
8};
9use inkwell::{
10 types::{BasicType, FloatType},
11 values::{BasicValue, BasicValueEnum},
12};
13
14use crate::emit::emit_value;
15use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
16use crate::emit::{func::EmitFuncContext, EmitOpArgs};
17
18use crate::custom::CodegenExtsBuilder;
19
20fn emit_fcmp<'c, H: HugrView>(
22 context: &mut EmitFuncContext<'c, '_, H>,
23 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
24 pred: inkwell::FloatPredicate,
25) -> Result<()> {
26 let true_val = emit_value(context, &Value::true_val())?;
27 let false_val = emit_value(context, &Value::false_val())?;
28
29 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
30 let r = ctx.builder().build_float_compare(
32 pred,
33 lhs.into_float_value(),
34 rhs.into_float_value(),
35 "",
36 )?;
37 Ok(vec![ctx
39 .builder()
40 .build_select(r, true_val, false_val, "")?])
41 })
42}
43
44fn emit_float_op<'c, H: HugrView>(
45 context: &mut EmitFuncContext<'c, '_, H>,
46 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
47 op: FloatOps,
48) -> Result<()> {
49 #[allow(clippy::wildcard_in_or_patterns)]
52 match op {
53 FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
54 FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
55 FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
56 FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
57 FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
58 FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
59 FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
60 Ok(vec![ctx
61 .builder()
62 .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
63 .as_basic_value_enum()])
64 }),
65 FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
66 Ok(vec![ctx
67 .builder()
68 .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
69 .as_basic_value_enum()])
70 }),
71 FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
72 Ok(vec![ctx
73 .builder()
74 .build_float_neg(v.into_float_value(), "")?
75 .as_basic_value_enum()])
76 }),
77 FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
78 Ok(vec![ctx
79 .builder()
80 .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
81 .as_basic_value_enum()])
82 }),
83 FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
84 Ok(vec![ctx
85 .builder()
86 .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
87 .as_basic_value_enum()])
88 }),
89 FloatOps::fmax
91 | FloatOps::fmin
92 | FloatOps::fabs
93 | FloatOps::ffloor
94 | FloatOps::fceil
95 | FloatOps::ftostring
96 | _ => {
97 let name: &str = op.into();
98 Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
99 }
100 }
101}
102
103fn emit_constf64<'c, H: HugrView>(
104 context: &mut EmitFuncContext<'c, '_, H>,
105 k: &ConstF64,
106) -> Result<BasicValueEnum<'c>> {
107 let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
108 Ok(ty.const_float(k.value()).as_basic_value_enum())
109}
110
111pub fn add_float_extensions<'a, H: HugrView + 'a>(
112 cem: CodegenExtsBuilder<'a, H>,
113) -> CodegenExtsBuilder<'a, H> {
114 cem.custom_type(
115 (
116 float_types::EXTENSION_ID,
117 float_types::FLOAT_TYPE_ID.clone(),
118 ),
119 |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
120 )
121 .custom_const(emit_constf64)
122 .simple_extension_op::<FloatOps>(emit_float_op)
123}
124
125impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {
126 pub fn add_float_extensions(self) -> Self {
127 add_float_extensions(self)
128 }
129}
130
#[cfg(test)]
mod test {
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::TypeRow;
    use hugr_core::Hugr;
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
    };
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{llvm_ctx, TestContext},
    };

    /// Builds a HUGR containing a single `op` node wired directly from the
    /// function inputs to the function outputs.
    ///
    /// The input and output rows come from the body of `op`'s polymorphic
    /// signature. The helper panics in two cases:
    /// - the signature is not a `SignatureFunc::PolyFuncType`;
    /// - either row cannot be converted into a `TypeRow` (presumably when it
    ///   contains row variables — TODO confirm).
    fn test_float_op(op: FloatOps) -> Hugr {
        let SignatureFunc::PolyFuncType(poly_sig) = op.signature() else {
            panic!("Expected PolyFuncType");
        };
        let sig = poly_sig.body();
        let inp: TypeRow = sig.input.clone().try_into().unwrap();
        let out: TypeRow = sig.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(inp)
            .with_outs(out)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let outputs = builder
                    .add_dataflow_op(op, builder.input_wires())
                    .unwrap()
                    .outputs();
                builder.finish_with_outputs(outputs).unwrap()
            })
    }

    /// Checks emission of a function that loads `ConstF64::new(3.12)` and
    /// returns it as its only output.
    // NOTE(review): `check_emission!` likely compares against a stored snapshot
    // keyed by this test's name — confirm before renaming.
    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let c = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_with_outputs([c]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    /// Checks emission of each `FloatOps` variant that `emit_float_op` lowers.
    ///
    /// The check is keyed by the op's string name.
    // Unimplemented ops (fmax, fmin, fabs, ...) are intentionally absent: they
    // return an error from `emit_float_op`.
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        let name: &str = op.into();
        let hugr = test_float_op(op);
        llvm_ctx.add_extensions(add_float_extensions);
        check_emission!(name, hugr, llvm_ctx);
    }
}