hugr_core/std_extensions/arithmetic/
float_ops.rs1use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use super::float_types::float64_type;
8use crate::{
9 Extension,
10 extension::{
11 ExtensionId, OpDef, SignatureFunc,
12 prelude::{bool_t, string_type},
13 simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError},
14 },
15 ops::OpName,
16 types::Signature,
17};
18mod const_fold;
19pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
21pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
23
24#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
26#[allow(missing_docs, non_camel_case_types)]
27#[non_exhaustive]
28pub enum FloatOps {
29 feq,
30 fne,
31 flt,
32 fgt,
33 fle,
34 fge,
35 fmax,
36 fmin,
37 fadd,
38 fsub,
39 fneg,
40 fabs,
41 fmul,
42 fdiv,
43 fpow,
44 ffloor,
45 fceil,
46 fround,
47 ftostring,
48}
49
50impl MakeOpDef for FloatOps {
51 fn opdef_id(&self) -> OpName {
52 <&Self as Into<&'static str>>::into(self).into()
53 }
54
55 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
56 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
57 }
58
59 fn extension(&self) -> ExtensionId {
60 EXTENSION_ID.clone()
61 }
62
63 fn extension_ref(&self) -> Weak<Extension> {
64 Arc::downgrade(&EXTENSION)
65 }
66
67 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
68 use FloatOps::*;
69
70 match self {
71 feq | fne | flt | fgt | fle | fge => {
72 Signature::new(vec![float64_type(); 2], vec![bool_t()])
73 }
74 fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
75 Signature::new(vec![float64_type(); 2], vec![float64_type()])
76 }
77 fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]),
78 ftostring => Signature::new(vec![float64_type()], string_type()),
79 }
80 .into()
81 }
82
83 fn description(&self) -> String {
84 use FloatOps::*;
85 match self {
86 feq => "equality test",
87 fne => "inequality test",
88 flt => "\"less than\"",
89 fgt => "\"greater than\"",
90 fle => "\"less than or equal\"",
91 fge => "\"greater than or equal\"",
92 fmax => "maximum",
93 fmin => "minimum",
94 fadd => "addition",
95 fsub => "subtraction",
96 fneg => "negation",
97 fabs => "absolute value",
98 fmul => "multiplication",
99 fdiv => "division",
100 fpow => "exponentiation",
101 ffloor => "floor",
102 fceil => "ceiling",
103 fround => "round",
104 ftostring => "string representation",
105 }
106 .to_string()
107 }
108
109 fn post_opdef(&self, def: &mut OpDef) {
110 const_fold::set_fold(self, def);
111 }
112}
113
114pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
116 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
117 FloatOps::load_all_ops(extension, extension_ref).unwrap();
118 })
119});
120
121impl MakeRegisteredOp for FloatOps {
122 fn extension_id(&self) -> ExtensionId {
123 EXTENSION_ID.clone()
124 }
125
126 fn extension_ref(&self) -> Weak<Extension> {
127 Arc::downgrade(&EXTENSION)
128 }
129}
130
#[cfg(test)]
mod test {
    use cgmath::AbsDiffEq;
    use rstest::rstest;
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn test_float_ops_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.float");
        assert_eq!(r.types().count(), 0);
        // Every `FloatOps` variant must be registered, and nothing else.
        assert_eq!(r.operations().count(), FloatOps::iter().count());
        for (name, _) in r.operations() {
            assert!(name.as_str().starts_with('f'));
            // Each registered name must round-trip to a variant with the same id.
            let op: FloatOps = name
                .parse()
                .unwrap_or_else(|e| panic!("unknown float op {name}: {e}"));
            assert_eq!(op.opdef_id().as_str(), name.as_str());
        }
    }

    #[rstest]
    #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
    #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
    #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
    #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
    #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
    #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
    #[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
    #[case::fround(FloatOps::fround, &[42.42], &[42.])]
    fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::float_types::ConstF64;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this, surplus folded outputs would go unnoticed and missing
        // ones would only surface as an opaque `unwrap` panic.
        assert_eq!(res.len(), outputs.len(), "wrong number of folded outputs");

        for ((_, value), expected) in res.iter().zip(outputs) {
            let res_val: f64 = value
                .get_custom_value::<ConstF64>()
                .expect("folding a float op should produce float constants")
                .value();

            assert!(
                res_val.abs_diff_eq(expected, f64::EPSILON),
                "expected {expected:?}, got {res_val:?}"
            );
        }
    }
}