hugr_core/std_extensions/arithmetic/
float_ops.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use super::float_types::float64_type;
8use crate::{
9 Extension,
10 extension::{
11 ExtensionId, OpDef, SignatureFunc,
12 prelude::{bool_t, string_type},
13 simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError},
14 },
15 ops::OpName,
16 types::Signature,
17};
18use lazy_static::lazy_static;
19mod const_fold;
20pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
22pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
24
25#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
27#[allow(missing_docs, non_camel_case_types)]
28#[non_exhaustive]
29pub enum FloatOps {
30 feq,
31 fne,
32 flt,
33 fgt,
34 fle,
35 fge,
36 fmax,
37 fmin,
38 fadd,
39 fsub,
40 fneg,
41 fabs,
42 fmul,
43 fdiv,
44 fpow,
45 ffloor,
46 fceil,
47 fround,
48 ftostring,
49}
50
51impl MakeOpDef for FloatOps {
52 fn opdef_id(&self) -> OpName {
53 <&Self as Into<&'static str>>::into(self).into()
54 }
55
56 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
57 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
58 }
59
60 fn extension(&self) -> ExtensionId {
61 EXTENSION_ID.clone()
62 }
63
64 fn extension_ref(&self) -> Weak<Extension> {
65 Arc::downgrade(&EXTENSION)
66 }
67
68 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
69 use FloatOps::*;
70
71 match self {
72 feq | fne | flt | fgt | fle | fge => {
73 Signature::new(vec![float64_type(); 2], vec![bool_t()])
74 }
75 fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
76 Signature::new(vec![float64_type(); 2], vec![float64_type()])
77 }
78 fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]),
79 ftostring => Signature::new(vec![float64_type()], string_type()),
80 }
81 .into()
82 }
83
84 fn description(&self) -> String {
85 use FloatOps::*;
86 match self {
87 feq => "equality test",
88 fne => "inequality test",
89 flt => "\"less than\"",
90 fgt => "\"greater than\"",
91 fle => "\"less than or equal\"",
92 fge => "\"greater than or equal\"",
93 fmax => "maximum",
94 fmin => "minimum",
95 fadd => "addition",
96 fsub => "subtraction",
97 fneg => "negation",
98 fabs => "absolute value",
99 fmul => "multiplication",
100 fdiv => "division",
101 fpow => "exponentiation",
102 ffloor => "floor",
103 fceil => "ceiling",
104 fround => "round",
105 ftostring => "string representation",
106 }
107 .to_string()
108 }
109
110 fn post_opdef(&self, def: &mut OpDef) {
111 const_fold::set_fold(self, def);
112 }
113}
114
115lazy_static! {
116 pub static ref EXTENSION: Arc<Extension> = {
118 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
119 FloatOps::load_all_ops(extension, extension_ref).unwrap();
120 })
121 };
122}
123
124impl MakeRegisteredOp for FloatOps {
125 fn extension_id(&self) -> ExtensionId {
126 EXTENSION_ID.clone()
127 }
128
129 fn extension_ref(&self) -> Weak<Extension> {
130 Arc::downgrade(&EXTENSION)
131 }
132}
133
#[cfg(test)]
mod test {
    use cgmath::AbsDiffEq;
    use rstest::rstest;
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn test_float_ops_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.float");
        assert_eq!(r.types().count(), 0);
        // Every variant must be registered; otherwise the prefix check below
        // would pass vacuously on an empty extension.
        assert_eq!(r.operations().count(), FloatOps::iter().count());
        for (name, _) in r.operations() {
            assert!(name.as_str().starts_with('f'));
        }
    }

    #[rstest]
    #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
    #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
    #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
    #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
    #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
    #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
    #[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
    #[case::fround(FloatOps::fround, &[42.42], &[42.])]
    fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::float_types::ConstF64;

        // Input port `i` receives the constant `inputs[i]`.
        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
            .collect();

        let res = op
            .to_extension_op()
            .expect("op should be registered in the extension")
            .constant_fold(&consts)
            .expect("op should constant-fold on constant inputs");

        // Check the arity too, so extra or missing outputs are not silently ignored.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for ((_, value), expected) in res.iter().zip(outputs) {
            let res_val: f64 = value
                .get_custom_value::<ConstF64>()
                .expect("folded output should be a ConstF64")
                .value();

            assert!(
                res_val.abs_diff_eq(expected, f64::EPSILON),
                "expected {expected:?}, got {res_val:?}"
            );
        }
    }
}