1use anyhow::{Result, anyhow};
2use hugr_core::Node;
3use hugr_core::ops::ExtensionOp;
4use hugr_core::ops::{Value, constant::CustomConst};
5use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
6use hugr_core::{
7 HugrView,
8 std_extensions::arithmetic::float_types::{self, ConstF64},
9};
10use inkwell::{
11 types::{BasicType, FloatType},
12 values::{BasicValue, BasicValueEnum},
13};
14
15use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
16use crate::emit::{EmitOpArgs, func::EmitFuncContext};
17use crate::emit::{emit_value, get_intrinsic};
18
19use crate::custom::CodegenExtsBuilder;
20
21fn emit_fcmp<'c, H: HugrView<Node = Node>>(
23 context: &mut EmitFuncContext<'c, '_, H>,
24 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
25 pred: inkwell::FloatPredicate,
26) -> Result<()> {
27 let true_val = emit_value(context, &Value::true_val())?;
28 let false_val = emit_value(context, &Value::false_val())?;
29
30 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
31 let r = ctx.builder().build_float_compare(
33 pred,
34 lhs.into_float_value(),
35 rhs.into_float_value(),
36 "",
37 )?;
38 Ok(vec![
40 ctx.builder().build_select(r, true_val, false_val, "")?,
41 ])
42 })
43}
44
45fn emit_float_op<'c, H: HugrView<Node = Node>>(
46 context: &mut EmitFuncContext<'c, '_, H>,
47 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48 op: FloatOps,
49) -> Result<()> {
50 #[allow(clippy::wildcard_in_or_patterns)]
53 match op {
54 FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
55 FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
56 FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
57 FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
58 FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
59 FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
60 FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
61 Ok(vec![
62 ctx.builder()
63 .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
64 .as_basic_value_enum(),
65 ])
66 }),
67 FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
68 Ok(vec![
69 ctx.builder()
70 .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
71 .as_basic_value_enum(),
72 ])
73 }),
74 FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
75 Ok(vec![
76 ctx.builder()
77 .build_float_neg(v.into_float_value(), "")?
78 .as_basic_value_enum(),
79 ])
80 }),
81 FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
82 Ok(vec![
83 ctx.builder()
84 .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
85 .as_basic_value_enum(),
86 ])
87 }),
88 FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
89 Ok(vec![
90 ctx.builder()
91 .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
92 .as_basic_value_enum(),
93 ])
94 }),
95 FloatOps::fpow => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
96 let float_ty = ctx.iw_context().f64_type().as_basic_type_enum();
97 let func = get_intrinsic(ctx.get_current_module(), "llvm.pow.f64", [float_ty])?;
98 Ok(vec![
99 ctx.builder()
100 .build_call(func, &[lhs.into(), rhs.into()], "")?
101 .try_as_basic_value()
102 .unwrap_left()
103 .as_basic_value_enum(),
104 ])
105 }),
106 FloatOps::fmax
108 | FloatOps::fmin
109 | FloatOps::fabs
110 | FloatOps::ffloor
111 | FloatOps::fceil
112 | FloatOps::ftostring
113 | _ => {
114 let name: &str = op.into();
115 Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
116 }
117 }
118}
119
120fn emit_constf64<'c, H: HugrView<Node = Node>>(
121 context: &mut EmitFuncContext<'c, '_, H>,
122 k: &ConstF64,
123) -> Result<BasicValueEnum<'c>> {
124 let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
125 Ok(ty.const_float(k.value()).as_basic_value_enum())
126}
127
128pub fn add_float_extensions<'a, H: HugrView<Node = Node> + 'a>(
129 cem: CodegenExtsBuilder<'a, H>,
130) -> CodegenExtsBuilder<'a, H> {
131 cem.custom_type(
132 (
133 float_types::EXTENSION_ID,
134 float_types::FLOAT_TYPE_ID.clone(),
135 ),
136 |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
137 )
138 .custom_const(emit_constf64)
139 .simple_extension_op::<FloatOps>(emit_float_op)
140}
141
142impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
143 #[must_use]
144 pub fn add_float_extensions(self) -> Self {
145 add_float_extensions(self)
146 }
147}
148
#[cfg(test)]
mod test {
    use hugr_core::Hugr;
    use hugr_core::builder::DataflowHugr;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::types::TypeRow;
    use hugr_core::{
        builder::Dataflow,
        std_extensions::arithmetic::float_types::{ConstF64, float64_type},
    };
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{TestContext, llvm_ctx},
    };

    /// Builds a HUGR (via `SimpleHugrConfig`) that applies `op` once.
    ///
    /// The HUGR's input and output rows are taken from `op`'s own signature,
    /// and the single application of `op` is wired from the inputs to the
    /// outputs.
    ///
    /// # Panics
    ///
    /// Panics if any of the following fails:
    /// - `op`'s signature is not a `PolyFuncType`;
    /// - its input/output rows cannot be converted into a `TypeRow`;
    /// - building the dataflow graph fails.
    fn test_float_op(op: FloatOps) -> Hugr {
        let SignatureFunc::PolyFuncType(poly_sig) = op.signature() else {
            panic!("Expected PolyFuncType");
        };
        let sig = poly_sig.body();
        let inp: TypeRow = sig.input.clone().try_into().unwrap();
        let out: TypeRow = sig.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(inp)
            .with_outs(out)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let outputs = builder
                    .add_dataflow_op(op, builder.input_wires())
                    .unwrap()
                    .outputs();
                builder.finish_hugr_with_outputs(outputs).unwrap()
            })
    }

    /// Checks the emission of a HUGR that loads and returns a `ConstF64`.
    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let c = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_hugr_with_outputs([c]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    /// Checks the emission of each `FloatOps` variant that `emit_float_op`
    /// lowers.
    ///
    /// The op's name is passed to `check_emission!` so that each case gets its
    /// own name there. NOTE(review): this presumably keys the snapshot, so
    /// renaming a case would likely orphan its stored output — confirm
    /// against `check_emission!`.
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    #[case::fpow(FloatOps::fpow)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        let name: &str = op.into();
        let hugr = test_float_op(op);
        llvm_ctx.add_extensions(add_float_extensions);
        check_emission!(name, hugr, llvm_ctx);
    }
}