1use anyhow::{anyhow, Result};
2use hugr_core::ops::ExtensionOp;
3use hugr_core::ops::{constant::CustomConst, Value};
4use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
5use hugr_core::Node;
6use hugr_core::{
7 std_extensions::arithmetic::float_types::{self, ConstF64},
8 HugrView,
9};
10use inkwell::{
11 types::{BasicType, FloatType},
12 values::{BasicValue, BasicValueEnum},
13};
14
15use crate::emit::emit_value;
16use crate::emit::ops::{emit_custom_binary_op, emit_custom_unary_op};
17use crate::emit::{func::EmitFuncContext, EmitOpArgs};
18
19use crate::custom::CodegenExtsBuilder;
20
21fn emit_fcmp<'c, H: HugrView<Node = Node>>(
23 context: &mut EmitFuncContext<'c, '_, H>,
24 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
25 pred: inkwell::FloatPredicate,
26) -> Result<()> {
27 let true_val = emit_value(context, &Value::true_val())?;
28 let false_val = emit_value(context, &Value::false_val())?;
29
30 emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
31 let r = ctx.builder().build_float_compare(
33 pred,
34 lhs.into_float_value(),
35 rhs.into_float_value(),
36 "",
37 )?;
38 Ok(vec![ctx
40 .builder()
41 .build_select(r, true_val, false_val, "")?])
42 })
43}
44
45fn emit_float_op<'c, H: HugrView<Node = Node>>(
46 context: &mut EmitFuncContext<'c, '_, H>,
47 args: EmitOpArgs<'c, '_, ExtensionOp, H>,
48 op: FloatOps,
49) -> Result<()> {
50 #[allow(clippy::wildcard_in_or_patterns)]
53 match op {
54 FloatOps::feq => emit_fcmp(context, args, inkwell::FloatPredicate::OEQ),
55 FloatOps::fne => emit_fcmp(context, args, inkwell::FloatPredicate::ONE),
56 FloatOps::flt => emit_fcmp(context, args, inkwell::FloatPredicate::OLT),
57 FloatOps::fgt => emit_fcmp(context, args, inkwell::FloatPredicate::OGT),
58 FloatOps::fle => emit_fcmp(context, args, inkwell::FloatPredicate::OLE),
59 FloatOps::fge => emit_fcmp(context, args, inkwell::FloatPredicate::OGE),
60 FloatOps::fadd => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
61 Ok(vec![ctx
62 .builder()
63 .build_float_add(lhs.into_float_value(), rhs.into_float_value(), "")?
64 .as_basic_value_enum()])
65 }),
66 FloatOps::fsub => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
67 Ok(vec![ctx
68 .builder()
69 .build_float_sub(lhs.into_float_value(), rhs.into_float_value(), "")?
70 .as_basic_value_enum()])
71 }),
72 FloatOps::fneg => emit_custom_unary_op(context, args, |ctx, v, _| {
73 Ok(vec![ctx
74 .builder()
75 .build_float_neg(v.into_float_value(), "")?
76 .as_basic_value_enum()])
77 }),
78 FloatOps::fmul => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
79 Ok(vec![ctx
80 .builder()
81 .build_float_mul(lhs.into_float_value(), rhs.into_float_value(), "")?
82 .as_basic_value_enum()])
83 }),
84 FloatOps::fdiv => emit_custom_binary_op(context, args, |ctx, (lhs, rhs), _| {
85 Ok(vec![ctx
86 .builder()
87 .build_float_div(lhs.into_float_value(), rhs.into_float_value(), "")?
88 .as_basic_value_enum()])
89 }),
90 FloatOps::fmax
92 | FloatOps::fmin
93 | FloatOps::fabs
94 | FloatOps::ffloor
95 | FloatOps::fceil
96 | FloatOps::ftostring
97 | _ => {
98 let name: &str = op.into();
99 Err(anyhow!("FloatOpEmitter: unimplemented op: {name}"))
100 }
101 }
102}
103
104fn emit_constf64<'c, H: HugrView<Node = Node>>(
105 context: &mut EmitFuncContext<'c, '_, H>,
106 k: &ConstF64,
107) -> Result<BasicValueEnum<'c>> {
108 let ty: FloatType = context.llvm_type(&k.get_type())?.try_into().unwrap();
109 Ok(ty.const_float(k.value()).as_basic_value_enum())
110}
111
112pub fn add_float_extensions<'a, H: HugrView<Node = Node> + 'a>(
113 cem: CodegenExtsBuilder<'a, H>,
114) -> CodegenExtsBuilder<'a, H> {
115 cem.custom_type(
116 (
117 float_types::EXTENSION_ID,
118 float_types::FLOAT_TYPE_ID.clone(),
119 ),
120 |ts, _custom_type| Ok(ts.iw_context().f64_type().as_basic_type_enum()),
121 )
122 .custom_const(emit_constf64)
123 .simple_extension_op::<FloatOps>(emit_float_op)
124}
125
126impl<'a, H: HugrView<Node = Node> + 'a> CodegenExtsBuilder<'a, H> {
127 pub fn add_float_extensions(self) -> Self {
128 add_float_extensions(self)
129 }
130}
131
/// Emission tests for the float type, constant, and op codegen.
#[cfg(test)]
mod test {
    use hugr_core::extension::simple_op::MakeOpDef;
    use hugr_core::extension::SignatureFunc;
    use hugr_core::std_extensions::arithmetic::float_ops::FloatOps;
    use hugr_core::std_extensions::STD_REG;
    use hugr_core::types::TypeRow;
    use hugr_core::Hugr;
    use hugr_core::{
        builder::{Dataflow, DataflowSubContainer},
        std_extensions::arithmetic::float_types::{float64_type, ConstF64},
    };
    use rstest::rstest;

    use super::add_float_extensions;
    use crate::{
        check_emission,
        emit::test::SimpleHugrConfig,
        test::{llvm_ctx, TestContext},
    };

    /// Builds a HUGR whose function wires its inputs straight into `op` and
    /// returns the op's outputs.
    ///
    /// Panics if `op`'s signature is not a `PolyFuncType`, or if its input or
    /// output row cannot be converted into a [`TypeRow`].
    fn test_float_op(op: FloatOps) -> Hugr {
        let SignatureFunc::PolyFuncType(poly_sig) = op.signature() else {
            panic!("Expected PolyFuncType");
        };
        let body = poly_sig.body();
        let ins: TypeRow = body.input.clone().try_into().unwrap();
        let outs: TypeRow = body.output.clone().try_into().unwrap();

        SimpleHugrConfig::new()
            .with_ins(ins)
            .with_outs(outs)
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let fn_inputs = builder.input_wires();
                let op_node = builder.add_dataflow_op(op, fn_inputs).unwrap();
                let results = op_node.outputs();
                builder.finish_with_outputs(results).unwrap()
            })
    }

    #[rstest]
    fn const_float(mut llvm_ctx: TestContext) {
        llvm_ctx.add_extensions(add_float_extensions);
        let hugr = SimpleHugrConfig::new()
            .with_outs(float64_type())
            .with_extensions(STD_REG.to_owned())
            .finish(|mut builder| {
                let konst = builder.add_load_value(ConstF64::new(3.12));
                builder.finish_with_outputs([konst]).unwrap()
            });
        check_emission!(hugr, llvm_ctx);
    }

    #[rstest]
    #[case::feq(FloatOps::feq)]
    #[case::fne(FloatOps::fne)]
    #[case::flt(FloatOps::flt)]
    #[case::fgt(FloatOps::fgt)]
    #[case::fle(FloatOps::fle)]
    #[case::fge(FloatOps::fge)]
    #[case::fadd(FloatOps::fadd)]
    #[case::fsub(FloatOps::fsub)]
    #[case::fneg(FloatOps::fneg)]
    #[case::fmul(FloatOps::fmul)]
    #[case::fdiv(FloatOps::fdiv)]
    fn float_operations(mut llvm_ctx: TestContext, #[case] op: FloatOps) {
        llvm_ctx.add_extensions(add_float_extensions);
        let name: &str = op.into();
        let hugr = test_float_op(op);
        check_emission!(name, hugr, llvm_ctx);
    }
}