1use anyhow::{anyhow, Result};
2use hugr_core::ops::ExtensionOp;
3use hugr_core::ops::{constant::CustomConst, Value};
4use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
5use hugr_core::Node;
6use hugr_core::{
7 std_extensions::arithmetic::float_types::{self, ConstF64},
8 HugrView,
9};
10use inkwell::{
11 types::{BasicType, FloatType},
12 values::{BasicValue, BasicValueEnum},
13};
14
15use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
16use crate::emit::{emit_value, get_intrinsic};
17use crate::emit::{func::EmitFuncContext, EmitOpArgs};
18
19use crate::custom::CodegenExtsBuilder;
20
21fn emit_fcmp<'c, H: HugrView<Node = Node>>(
23 context: &mut EmitFuncContext<'c, '_, H>,
24 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
25 pred: inkwell::FloatPredicate,
26) -> Result<()> {
27 let true_val = emit_value(context, &Value::true_val())?;
28 let false_val = emit_value(context, &Value::false_val())?;
29
30 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
31 let r = ctx.builder().build_float_compare(
33 pred,
34 lhs.into_float_value(),
35 rhs.into_float_value(),
36 "",
37 )?;
38 Ok(vec![ctx
40 .builder()
41 .build_select(r, true_val, false_val, "")?])
42 })
43}
44
45fn emit_float_op<'c, H: HugrView<Node = Node>>(
46 context: &mut EmitFuncContext<'c, '_, H>,
47 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48 op: FloatOps,
49) -> Result<()> {
50 #[allow(clippy::wildcard_in_or_patterns)]
53 match op {
54 FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
55 FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
56 FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
57 FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
58 FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
59 FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
60 FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
61 Ok(vec![ctx
62 .builder()
63 .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
64 .as_basic_value_enum()])
65 }),
66 FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
67 Ok(vec![ctx
68 .builder()
69 .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
70 .as_basic_value_enum()])
71 }),
72 FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
73 Ok(vec![ctx
74 .builder()
75 .build_float_neg(v.into_float_value(), "")?
76 .as_basic_value_enum()])
77 }),
78 FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
79 Ok(vec![ctx
80 .builder()
81 .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
82 .as_basic_value_enum()])
83 }),
84 FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
85 Ok(vec![ctx
86 .builder()
87 .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
88 .as_basic_value_enum()])
89 }),
90 FloatOps::fpow => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
91 let float_ty = ctx.iw_context().f64_type().as_basic_type_enum();
92 let func = get_intrinsic(ctx.get_current_module(), "llvm.pow.f64", [float_ty])?;
93 Ok(vec![ctx
94 .builder()
95 .build_call(func, &[lhs.into(), rhs.into()], "")?
96 .try_as_basic_value()
97 .unwrap_left()
98 .as_basic_value_enum()])
99 }),
100 FloatOps::fmax
102 | FloatOps::fmin
103 | FloatOps::fabs
104 | FloatOps::ffloor
105 | FloatOps::fceil
106 | FloatOps::ftostring
107 | _ => {
108 let name: &str = op.into();
109 Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
110 }
111 }
112}
113
114fn emit_constf64<'c, H: HugrView<Node = Node>>(
115 context: &mut EmitFuncContext<'c, '_, H>,
116 k: &ConstF64,
117) -> Result<BasicValueEnum<'c>> {
118 let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
119 Ok(ty.const_float(k.value()).as_basic_value_enum())
120}
121
122pub fn add_float_extensions<'a, H: HugrView<Node = Node> + 'a>(
123 cem: CodegenExtsBuilder<'a, H>,
124) -> CodegenExtsBuilder<'a, H> {
125 cem.custom_type(
126 (
127 float_types::EXTENSION_ID,
128 float_types::FLOAT_TYPE_ID.clone(),
129 ),
130 |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
131 )
132 .custom_const(emit_constf64)
133 .simple_extension_op::<FloatOps>(emit_float_op)
134}
135
136impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
137 pub fn add_float_extensions(self) -> Self {
138 add_float_extensions(self)
139 }
140}
141
#[cfg(test)]
mod test {
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::TypeRow;
    use hugr_core::Hugr;
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
    };
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{llvm_ctx, TestContext},
    };

    /// Builds a HUGR through `SimpleHugrConfig`.
    ///
    /// The HUGR's inputs and outputs are taken from `op`'s signature, and its
    /// body applies `op` once to the inputs and returns the results.
    ///
    /// Panics if `op`'s signature is not a `PolyFuncType`, or if its input or
    /// output row cannot be converted into a `TypeRow`.
    fn test_float_op(op: FloatOps) -> Hugr {
        let SignatureFunc::PolyFuncType(poly_sig) = op.signature() else {
            panic!("Expected PolyFuncType");
        };
        let sig = poly_sig.body();
        // Convert the signature rows into plain `TypeRow`s for the builder;
        // unwrapping is acceptable in test code.
        let inp: TypeRow = sig.input.clone().try_into().unwrap();
        let out: TypeRow = sig.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(inp)
            .with_outs(out)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                // Wire every input straight into a single `op` node.
                let outputs = builder
                    .add_dataflow_op(op, builder.input_wires())
                    .unwrap()
                    .outputs();
                builder.finish_with_outputs(outputs).unwrap()
            })
    }

    /// Checks the emitted LLVM for a HUGR that loads a `ConstF64` (3.12) and
    /// returns it.
    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let c = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_with_outputs([c]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    /// Checks the emitted LLVM for each listed `FloatOps` case.
    ///
    /// The op's name is passed to `check_emission!`, presumably to keep each
    /// case's expected output separate; renaming cases likely affects stored
    /// expectations.
    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    #[case::fpow(FloatOps::fpow)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        let name: &str = op.into();
        let hugr = test_float_op(op);
        llvm_ctx.add_extensions(add_float_extensions);
        check_emission!(name, hugr, llvm_ctx);
    }
}