hugr_core/std_extensions/arithmetic/float_ops.rs

1//! Basic floating-point operations.
2
3use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use super::float_types::float64_type;
8use crate::{
9    Extension,
10    extension::{
11        ExtensionId, OpDef, SignatureFunc,
12        prelude::{bool_t, string_type},
13        simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError},
14    },
15    ops::OpName,
16    types::Signature,
17};
18mod const_fold;
19/// The extension identifier.
20pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
21/// Extension version.
22pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
23
24/// Integer extension operation definitions.
25#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
26#[allow(missing_docs, non_camel_case_types)]
27#[non_exhaustive]
28pub enum FloatOps {
29    feq,
30    fne,
31    flt,
32    fgt,
33    fle,
34    fge,
35    fmax,
36    fmin,
37    fadd,
38    fsub,
39    fneg,
40    fabs,
41    fmul,
42    fdiv,
43    fpow,
44    ffloor,
45    fceil,
46    fround,
47    ftostring,
48}
49
50impl MakeOpDef for FloatOps {
51    fn opdef_id(&self) -> OpName {
52        <&Self as Into<&'static str>>::into(self).into()
53    }
54
55    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
56        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
57    }
58
59    fn extension(&self) -> ExtensionId {
60        EXTENSION_ID.clone()
61    }
62
63    fn extension_ref(&self) -> Weak<Extension> {
64        Arc::downgrade(&EXTENSION)
65    }
66
67    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
68        use FloatOps::*;
69
70        match self {
71            feq | fne | flt | fgt | fle | fge => {
72                Signature::new(vec![float64_type(); 2], vec![bool_t()])
73            }
74            fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
75                Signature::new(vec![float64_type(); 2], vec![float64_type()])
76            }
77            fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]),
78            ftostring => Signature::new(vec![float64_type()], string_type()),
79        }
80        .into()
81    }
82
83    fn description(&self) -> String {
84        use FloatOps::*;
85        match self {
86            feq => "equality test",
87            fne => "inequality test",
88            flt => "\"less than\"",
89            fgt => "\"greater than\"",
90            fle => "\"less than or equal\"",
91            fge => "\"greater than or equal\"",
92            fmax => "maximum",
93            fmin => "minimum",
94            fadd => "addition",
95            fsub => "subtraction",
96            fneg => "negation",
97            fabs => "absolute value",
98            fmul => "multiplication",
99            fdiv => "division",
100            fpow => "exponentiation",
101            ffloor => "floor",
102            fceil => "ceiling",
103            fround => "round",
104            ftostring => "string representation",
105        }
106        .to_string()
107    }
108
109    fn post_opdef(&self, def: &mut OpDef) {
110        const_fold::set_fold(self, def);
111    }
112}
113
114/// Extension for basic float operations.
115pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
116    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
117        FloatOps::load_all_ops(extension, extension_ref).unwrap();
118    })
119});
120
121impl MakeRegisteredOp for FloatOps {
122    fn extension_id(&self) -> ExtensionId {
123        EXTENSION_ID.clone()
124    }
125
126    fn extension_ref(&self) -> Weak<Extension> {
127        Arc::downgrade(&EXTENSION)
128    }
129}
130
#[cfg(test)]
mod test {
    use cgmath::AbsDiffEq;
    use rstest::rstest;

    use super::*;

    /// Checks the extension's name, that it declares no custom types, and that
    /// every operation name starts with `f`.
    #[test]
    fn test_float_ops_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.float");
        assert_eq!(r.types().count(), 0);
        for (name, _) in r.operations() {
            assert!(name.as_str().starts_with('f'));
        }
    }

    /// Constant-folds `op` on float inputs and compares every folded output
    /// (in order) against `outputs`, up to `f64::EPSILON`.
    #[rstest]
    #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
    #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
    #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
    #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
    #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
    #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
    #[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
    #[case::fround(FloatOps::fround, &[42.42], &[42.])]
    fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::float_types::ConstF64;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this, surplus folded outputs would go unchecked and missing
        // ones would only surface as an opaque `unwrap` panic.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for ((_, value), expected) in res.iter().zip(outputs) {
            let res_val: f64 = value
                .get_custom_value::<ConstF64>()
                .expect("This function assumes all incoming constants are floats.")
                .value();

            assert!(
                res_val.abs_diff_eq(expected, f64::EPSILON),
                "expected {expected:?}, got {res_val:?}"
            );
        }
    }
}