hugr_core/std_extensions/arithmetic/float_ops.rs

1//! Basic floating-point operations.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use super::float_types::float64_type;
8use crate::{
9    Extension,
10    extension::{
11        ExtensionId, OpDef, SignatureFunc,
12        prelude::{bool_t, string_type},
13        simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError},
14    },
15    ops::OpName,
16    types::Signature,
17};
18use lazy_static::lazy_static;
19mod const_fold;
/// The extension identifier.
///
/// Returned by [`FloatOps`]'s `extension`/`extension_id` methods and used to
/// construct [`EXTENSION`].
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float");
/// Extension version.
///
/// Passed to [`Extension::new_arc`] when [`EXTENSION`] is built.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
24
25/// Integer extension operation definitions.
26#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
27#[allow(missing_docs, non_camel_case_types)]
28#[non_exhaustive]
29pub enum FloatOps {
30    feq,
31    fne,
32    flt,
33    fgt,
34    fle,
35    fge,
36    fmax,
37    fmin,
38    fadd,
39    fsub,
40    fneg,
41    fabs,
42    fmul,
43    fdiv,
44    fpow,
45    ffloor,
46    fceil,
47    fround,
48    ftostring,
49}
50
51impl MakeOpDef for FloatOps {
52    fn opdef_id(&self) -> OpName {
53        <&Self as Into<&'static str>>::into(self).into()
54    }
55
56    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
57        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
58    }
59
60    fn extension(&self) -> ExtensionId {
61        EXTENSION_ID.clone()
62    }
63
64    fn extension_ref(&self) -> Weak<Extension> {
65        Arc::downgrade(&EXTENSION)
66    }
67
68    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
69        use FloatOps::*;
70
71        match self {
72            feq | fne | flt | fgt | fle | fge => {
73                Signature::new(vec![float64_type(); 2], vec![bool_t()])
74            }
75            fmax | fmin | fadd | fsub | fmul | fdiv | fpow => {
76                Signature::new(vec![float64_type(); 2], vec![float64_type()])
77            }
78            fneg | fabs | ffloor | fceil | fround => Signature::new_endo(vec![float64_type()]),
79            ftostring => Signature::new(vec![float64_type()], string_type()),
80        }
81        .into()
82    }
83
84    fn description(&self) -> String {
85        use FloatOps::*;
86        match self {
87            feq => "equality test",
88            fne => "inequality test",
89            flt => "\"less than\"",
90            fgt => "\"greater than\"",
91            fle => "\"less than or equal\"",
92            fge => "\"greater than or equal\"",
93            fmax => "maximum",
94            fmin => "minimum",
95            fadd => "addition",
96            fsub => "subtraction",
97            fneg => "negation",
98            fabs => "absolute value",
99            fmul => "multiplication",
100            fdiv => "division",
101            fpow => "exponentiation",
102            ffloor => "floor",
103            fceil => "ceiling",
104            fround => "round",
105            ftostring => "string representation",
106        }
107        .to_string()
108    }
109
110    fn post_opdef(&self, def: &mut OpDef) {
111        const_fold::set_fold(self, def);
112    }
113}
114
115lazy_static! {
116    /// Extension for basic float operations.
117    pub static ref EXTENSION: Arc<Extension> = {
118        Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
119            FloatOps::load_all_ops(extension, extension_ref).unwrap();
120        })
121    };
122}
123
124impl MakeRegisteredOp for FloatOps {
125    fn extension_id(&self) -> ExtensionId {
126        EXTENSION_ID.clone()
127    }
128
129    fn extension_ref(&self) -> Weak<Extension> {
130        Arc::downgrade(&EXTENSION)
131    }
132}
133
#[cfg(test)]
mod test {
    use cgmath::AbsDiffEq;
    use rstest::rstest;
    use strum::IntoEnumIterator;

    use super::*;

    /// The extension has the expected name and declares no types. It registers
    /// exactly one operation per [`FloatOps`] variant, and every operation name
    /// starts with `f`.
    #[test]
    fn test_float_ops_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.float");
        assert_eq!(r.types().count(), 0);
        assert_eq!(r.operations().count(), FloatOps::iter().count());
        for (name, _) in r.operations() {
            assert!(name.as_str().starts_with('f'));
        }
    }

    /// Constant-folding `op` on float literal `inputs` yields exactly `outputs`
    /// (compared within `f64::EPSILON`), one value per output.
    #[rstest]
    #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])]
    #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])]
    #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])]
    #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])]
    #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])]
    #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])]
    #[case::fceil(FloatOps::fceil, &[42.42], &[43.])]
    #[case::fround(FloatOps::fround, &[42.42], &[42.])]
    fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::float_types::ConstF64;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x))))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this check, a folder producing surplus outputs would pass
        // unnoticed, because the loop below only visits the expected ones.
        assert_eq!(res.len(), outputs.len(), "wrong number of folded outputs");

        for ((_, value), expected) in res.iter().zip(outputs) {
            let res_val: f64 = value
                .get_custom_value::<ConstF64>()
                .expect("folded outputs of float ops should be float constants")
                .value();

            assert!(
                res_val.abs_diff_eq(expected, f64::EPSILON),
                "expected {expected:?}, got {res_val:?}"
            );
        }
    }
}