hugr_core/std_extensions/arithmetic/float_ops.rs

1//! Basic floating-point operations.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use super::float_types::float64_type;
8use crate::{
9    extension::{
10        prelude::{bool_t, string_type},
11        simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError},
12        ExtensionId, ExtensionSet, OpDef, SignatureFunc,
13    },
14    types::Signature,
15    Extension,
16};
17use lazy_static::lazy_static;
18mod const_fold;
19/// The extension identifier.
20pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
21/// Extension version.
22pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
23
24/// Integer extension operation definitions.
25#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
26#[allow(missing_docs, non_camel_case_types)]
27#[non_exhaustive]
28pub enum FloatOps {
29    feq,
30    fne,
31    flt,
32    fgt,
33    fle,
34    fge,
35    fmax,
36    fmin,
37    fadd,
38    fsub,
39    fneg,
40    fabs,
41    fmul,
42    fdiv,
43    fpow,
44    ffloor,
45    fceil,
46    fround,
47    ftostring,
48}
49
50impl MakeOpDef for FloatOps {
51    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
52        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
53    }
54
55    fn extension(&self) -> ExtensionId {
56        EXTENSION_ID.to_owned()
57    }
58
59    fn extension_ref(&self) -> Weak<Extension> {
60        Arc::downgrade(&EXTENSION)
61    }
62
63    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
64        use FloatOps::*;
65
66        match self {
67            feq | fne | flt | fgt | fle | fge => {
68                Signature::new(vec![float64_type(); 2], vec![bool_t()])
69            }
70            fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
71                Signature::new(vec![float64_type(); 2], vec![float64_type()])
72            }
73            fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]),
74            ftostring => Signature::new(vec![float64_type()], string_type()),
75        }
76        .into()
77    }
78
79    fn description(&self) -> String {
80        use FloatOps::*;
81        match self {
82            feq => "equality test",
83            fne => "inequality test",
84            flt => "\"less than\"",
85            fgt => "\"greater than\"",
86            fle => "\"less than or equal\"",
87            fge => "\"greater than or equal\"",
88            fmax => "maximum",
89            fmin => "minimum",
90            fadd => "addition",
91            fsub => "subtraction",
92            fneg => "negation",
93            fabs => "absolute value",
94            fmul => "multiplication",
95            fdiv => "division",
96            fpow => "exponentiation",
97            ffloor => "floor",
98            fceil => "ceiling",
99            fround => "round",
100            ftostring => "string representation",
101        }
102        .to_string()
103    }
104
105    fn post_opdef(&self, def: &mut OpDef) {
106        const_fold::set_fold(self, def)
107    }
108}
109
110lazy_static! {
111    /// Extension for basic float operations.
112    pub static ref EXTENSION: Arc<Extension> = {
113        Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
114            extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID));
115            FloatOps::load_all_ops(extension, extension_ref).unwrap();
116        })
117    };
118}
119
120impl MakeRegisteredOp for FloatOps {
121    fn extension_id(&self) -> ExtensionId {
122        EXTENSION_ID.to_owned()
123    }
124
125    fn extension_ref(&self) -> Weak<Extension> {
126        Arc::downgrade(&EXTENSION)
127    }
128}
129
#[cfg(test)]
mod test {
    use cgmath::AbsDiffEq;
    use rstest::rstest;
    use strum::IntoEnumIterator;

    use super::*;

    #[test]
    fn test_float_ops_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.float");
        assert_eq!(r.types().count(), 0);
        // Every `FloatOps` variant should be registered as an operation, and nothing else.
        assert_eq!(r.operations().count(), FloatOps::iter().count());
        for (name, _) in r.operations() {
            assert!(name.as_str().starts_with('f'));
        }
    }

    /// Constant-fold `op` over the float constants `inputs` and check that the
    /// folded outputs match `outputs` (within `f64::EPSILON`).
    #[rstest]
    #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
    #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
    #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
    #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
    #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
    #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
    #[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
    #[case::fround(FloatOps::fround, &[42.42], &[42.])]
    fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::float_types::ConstF64;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this, extra folded outputs would go unchecked and missing ones
        // would surface only as an opaque unwrap panic.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for ((_, value), expected) in res.iter().zip(outputs) {
            let res_val: f64 = value
                .get_custom_value::<ConstF64>()
                .expect("folded output should be a float constant")
                .value();

            assert!(
                res_val.abs_diff_eq(expected, f64::EPSILON),
                "expected {:?}, got {:?}",
                expected,
                res_val
            );
        }
    }
}