hugr_core/std_extensions/arithmetic/
float_ops.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use super::float_types::float64_type;
8use crate::{
9 extension::{
10 prelude::{bool_t, string_type},
11 simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError},
12 ExtensionId, ExtensionSet, OpDef, SignatureFunc,
13 },
14 types::Signature,
15 Extension,
16};
17use lazy_static::lazy_static;
18mod const_fold;
19pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
21pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
23
24#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
26#[allow(missing_docs, non_camel_case_types)]
27#[non_exhaustive]
28pub enum FloatOps {
29 feq,
30 fne,
31 flt,
32 fgt,
33 fle,
34 fge,
35 fmax,
36 fmin,
37 fadd,
38 fsub,
39 fneg,
40 fabs,
41 fmul,
42 fdiv,
43 fpow,
44 ffloor,
45 fceil,
46 fround,
47 ftostring,
48}
49
50impl MakeOpDef for FloatOps {
51 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
52 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
53 }
54
55 fn extension(&self) -> ExtensionId {
56 EXTENSION_ID.to_owned()
57 }
58
59 fn extension_ref(&self) -> Weak<Extension> {
60 Arc::downgrade(&EXTENSION)
61 }
62
63 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
64 use FloatOps::*;
65
66 match self {
67 feq | fne | flt | fgt | fle | fge => {
68 Signature::new(vec![float64_type(); 2], vec![bool_t()])
69 }
70 fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
71 Signature::new(vec![float64_type(); 2], vec![float64_type()])
72 }
73 fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]),
74 ftostring => Signature::new(vec![float64_type()], string_type()),
75 }
76 .into()
77 }
78
79 fn description(&self) -> String {
80 use FloatOps::*;
81 match self {
82 feq => "equality test",
83 fne => "inequality test",
84 flt => "\"less than\"",
85 fgt => "\"greater than\"",
86 fle => "\"less than or equal\"",
87 fge => "\"greater than or equal\"",
88 fmax => "maximum",
89 fmin => "minimum",
90 fadd => "addition",
91 fsub => "subtraction",
92 fneg => "negation",
93 fabs => "absolute value",
94 fmul => "multiplication",
95 fdiv => "division",
96 fpow => "exponentiation",
97 ffloor => "floor",
98 fceil => "ceiling",
99 fround => "round",
100 ftostring => "string representation",
101 }
102 .to_string()
103 }
104
105 fn post_opdef(&self, def: &mut OpDef) {
106 const_fold::set_fold(self, def)
107 }
108}
109
110lazy_static! {
111 pub static ref EXTENSION: Arc<Extension> = {
113 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
114 extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID));
115 FloatOps::load_all_ops(extension, extension_ref).unwrap();
116 })
117 };
118}
119
120impl MakeRegisteredOp for FloatOps {
121 fn extension_id(&self) -> ExtensionId {
122 EXTENSION_ID.to_owned()
123 }
124
125 fn extension_ref(&self) -> Weak<Extension> {
126 Arc::downgrade(&EXTENSION)
127 }
128}
129
#[cfg(test)]
mod test {
    use cgmath::AbsDiffEq;
    use rstest::rstest;

    use super::*;

    /// The extension is registered under the expected name, declares no
    /// types, and every operation name starts with `f`.
    #[test]
    fn test_float_ops_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.float");
        assert_eq!(r.types().count(), 0);
        for (name, _) in r.operations() {
            assert!(name.as_str().starts_with('f'));
        }
    }

    /// Constant-folds `op` over `inputs` and checks that the folded values
    /// match `outputs` to within `f64::EPSILON`.
    #[rstest]
    #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
    #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
    #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
    #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
    #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
    #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
    #[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
    #[case::fround(FloatOps::fround, &[42.42], &[42.])]
    fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::float_types::ConstF64;

        // One float constant per input, keyed by its position.
        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this check, surplus folded outputs would go unnoticed
        // because the comparison below only walks the expected values.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for ((_, folded), expected) in res.iter().zip(outputs) {
            let res_val: f64 = folded
                .get_custom_value::<ConstF64>()
                .expect("every folded output of a float operation should be a ConstF64")
                .value();

            assert!(
                res_val.abs_diff_eq(expected, f64::EPSILON),
                "expected {:?}, got {:?}",
                expected,
                res_val
            );
        }
    }
}