hugr_core/std_extensions/arithmetic/
conversions.rs

1//! Conversions between integer and floating-point values.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::Extension;
8use crate::extension::prelude::sum_with_error;
9use crate::extension::prelude::{bool_t, string_type, usize_t};
10use crate::extension::simple_op::{HasConcrete, HasDef};
11use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
12use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc};
13use crate::ops::{ExtensionOp, OpName};
14use crate::std_extensions::arithmetic::int_ops::int_polytype;
15use crate::std_extensions::arithmetic::int_types::int_type;
16use crate::types::{TypeArg, TypeRV};
17
18use super::float_types::float64_type;
19use super::int_types::{get_log_width, int_tv};
20use lazy_static::lazy_static;
21mod const_fold;
22/// The extension identifier.
23pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
24/// Extension version.
25pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
26
27/// Extension for conversions between floats and integers.
28#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
29#[allow(missing_docs, non_camel_case_types)]
30#[non_exhaustive]
31pub enum ConvertOpDef {
32    trunc_u,
33    trunc_s,
34    convert_u,
35    convert_s,
36    itobool,
37    ifrombool,
38    itostring_u,
39    itostring_s,
40    itousize,
41    ifromusize,
42    bytecast_int64_to_float64,
43    bytecast_float64_to_int64,
44}
45
46impl MakeOpDef for ConvertOpDef {
47    fn opdef_id(&self) -> OpName {
48        <&'static str>::from(self).into()
49    }
50
51    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
52        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
53    }
54
55    fn extension(&self) -> ExtensionId {
56        EXTENSION_ID.clone()
57    }
58
59    fn extension_ref(&self) -> Weak<Extension> {
60        Arc::downgrade(&EXTENSION)
61    }
62
63    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
64        use ConvertOpDef::*;
65        match self {
66            trunc_s | trunc_u => int_polytype(
67                1,
68                vec![float64_type()],
69                TypeRV::from(sum_with_error(int_tv(0))),
70            ),
71            convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
72            itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
73            ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
74            itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
75            itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
76            ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
77            bytecast_int64_to_float64 => int_polytype(0, vec![int_type(6)], vec![float64_type()]),
78            bytecast_float64_to_int64 => int_polytype(0, vec![float64_type()], vec![int_type(6)]),
79        }
80        .into()
81    }
82
83    fn description(&self) -> String {
84        use ConvertOpDef::*;
85        match self {
86            trunc_u => "float to unsigned int",
87            trunc_s => "float to signed int",
88            convert_u => "unsigned int to float",
89            convert_s => "signed int to float",
90            itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
91            ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
92            itostring_s => "convert a signed integer to its string representation",
93            itostring_u => "convert an unsigned integer to its string representation",
94            itousize => "convert a 64b unsigned integer to its usize representation",
95            ifromusize => "convert a usize to a 64b unsigned integer",
96            bytecast_int64_to_float64 => {
97                "reinterpret an int64 as a float64 based on its bytes, with the same endianness"
98            }
99            bytecast_float64_to_int64 => {
100                "reinterpret an float64 as an int based on its bytes, with the same endianness"
101            }
102        }
103        .to_string()
104    }
105
106    fn post_opdef(&self, def: &mut OpDef) {
107        const_fold::set_fold(self, def);
108    }
109}
110
111impl ConvertOpDef {
112    /// Initialize a [`ConvertOpType`] from a [`ConvertOpDef`] which requires no
113    /// integer widths set.
114    #[must_use]
115    pub fn without_log_width(self) -> ConvertOpType {
116        ConvertOpType {
117            def: self,
118            log_width: None,
119        }
120    }
121    /// Initialize a [`ConvertOpType`] from a [`ConvertOpDef`] which requires one
122    /// integer width set.
123    #[must_use]
124    pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
125        ConvertOpType {
126            def: self,
127            log_width: Some(log_width),
128        }
129    }
130}
131/// Concrete convert operation with integer log width set.
132#[derive(Debug, Clone, PartialEq)]
133pub struct ConvertOpType {
134    /// The kind of conversion op.
135    def: ConvertOpDef,
136    /// The integer width parameter of the conversion op, if any. This is interpreted
137    /// differently, depending on `def`. The integer types in the inputs and
138    /// outputs of the op will have [`int_type`]s of this width.
139    log_width: Option<u8>,
140}
141
142impl ConvertOpType {
143    /// Returns the generic [`ConvertOpDef`] of this [`ConvertOpType`].
144    #[must_use]
145    pub fn def(&self) -> &ConvertOpDef {
146        &self.def
147    }
148
149    /// Returns the integer width parameters of this [`ConvertOpType`], if any.
150    #[must_use]
151    pub fn log_widths(&self) -> &[u8] {
152        self.log_width.as_slice()
153    }
154}
155
156impl MakeExtensionOp for ConvertOpType {
157    fn op_id(&self) -> OpName {
158        self.def.opdef_id()
159    }
160
161    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
162        let def = ConvertOpDef::from_def(ext_op.def())?;
163        def.instantiate(ext_op.args())
164    }
165
166    fn type_args(&self) -> Vec<TypeArg> {
167        self.log_width
168            .iter()
169            .map(|&n| u64::from(n).into())
170            .collect()
171    }
172}
173
174lazy_static! {
175    /// Extension for conversions between integers and floats.
176    pub static ref EXTENSION: Arc<Extension> = {
177        Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
178            ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
179        })
180    };
181}
182
183impl MakeRegisteredOp for ConvertOpType {
184    fn extension_id(&self) -> ExtensionId {
185        EXTENSION_ID.clone()
186    }
187
188    fn extension_ref(&self) -> Weak<Extension> {
189        Arc::downgrade(&EXTENSION)
190    }
191}
192
193impl HasConcrete for ConvertOpDef {
194    type Concrete = ConvertOpType;
195
196    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
197        let log_width = match type_args {
198            [] => None,
199            [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
200            _ => return Err(SignatureError::InvalidTypeArgs.into()),
201        };
202        Ok(ConvertOpType {
203            def: *self,
204            log_width,
205        })
206    }
207}
208
209impl HasDef for ConvertOpType {
210    type Def = ConvertOpDef;
211}
212
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::IncomingPort;
    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;

    use super::*;

    #[test]
    fn test_conversions_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.conversions");
        assert_eq!(r.types().count(), 0);
    }

    #[test]
    fn test_conversions() {
        // Initialization with an invalid number of type arguments should fail.
        assert!(
            ConvertOpDef::itobool
                .with_log_width(1)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        // This should work
        let o = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
        assert_eq!(
            ConvertOpDef::from_op(&ext_op).unwrap(),
            ConvertOpDef::itobool
        );
    }

    #[test]
    fn test_polymorphic_roundtrip() {
        // A width-polymorphic op keeps its log width through an ExtensionOp.
        let o = ConvertOpDef::trunc_u.with_log_width(3);
        assert_eq!(o.log_widths(), [3u8].as_slice());
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();
        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);

        // More than one type argument is never valid for a conversion op.
        assert!(
            ConvertOpDef::trunc_u
                .instantiate(&[3u64.into(), 4u64.into()])
                .is_err()
        );
    }

    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| (i.into(), v.clone()))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Folding must produce exactly the expected outputs — no more, no fewer.
        assert_eq!(res.len(), outputs.len());
        for (i, expected) in outputs.iter().enumerate() {
            let res_val: &Value = &res.get(i).unwrap().1;

            assert_eq!(res_val, expected);
        }
    }
}