hugr_core/std_extensions/arithmetic/
conversions.rs

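//! Conversions between integers and other value types: floating-point numbers, booleans,
//! `usize`, and strings, plus bit-level reinterpretation between `int64` and `float64`.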
2
3use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::Extension;
8use crate::extension::prelude::sum_with_error;
9use crate::extension::prelude::{bool_t, string_type, usize_t};
10use crate::extension::simple_op::{HasConcrete, HasDef};
11use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
12use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc};
13use crate::ops::{ExtensionOp, OpName};
14use crate::std_extensions::arithmetic::int_ops::int_polytype;
15use crate::std_extensions::arithmetic::int_types::int_type;
16use crate::types::{TypeArg, TypeRV};
17
18use super::float_types::float64_type;
19use super::int_types::{get_log_width, int_tv};
20mod const_fold;
21/// The extension identifier.
22pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
23/// Extension version.
24pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
25
26/// Extension for conversions between floats and integers.
27#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
28#[allow(missing_docs, non_camel_case_types)]
29#[non_exhaustive]
30pub enum ConvertOpDef {
31    trunc_u,
32    trunc_s,
33    convert_u,
34    convert_s,
35    itobool,
36    ifrombool,
37    itostring_u,
38    itostring_s,
39    itousize,
40    ifromusize,
41    bytecast_int64_to_float64,
42    bytecast_float64_to_int64,
43}
44
45impl MakeOpDef for ConvertOpDef {
46    fn opdef_id(&self) -> OpName {
47        <&'static str>::from(self).into()
48    }
49
50    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
51        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
52    }
53
54    fn extension(&self) -> ExtensionId {
55        EXTENSION_ID.clone()
56    }
57
58    fn extension_ref(&self) -> Weak<Extension> {
59        Arc::downgrade(&EXTENSION)
60    }
61
62    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
63        use ConvertOpDef::*;
64        match self {
65            trunc_s | trunc_u => int_polytype(
66                1,
67                vec![float64_type()],
68                TypeRV::from(sum_with_error(int_tv(0))),
69            ),
70            convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
71            itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
72            ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
73            itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
74            itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
75            ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
76            bytecast_int64_to_float64 => int_polytype(0, vec![int_type(6)], vec![float64_type()]),
77            bytecast_float64_to_int64 => int_polytype(0, vec![float64_type()], vec![int_type(6)]),
78        }
79        .into()
80    }
81
82    fn description(&self) -> String {
83        use ConvertOpDef::*;
84        match self {
85            trunc_u => "float to unsigned int",
86            trunc_s => "float to signed int",
87            convert_u => "unsigned int to float",
88            convert_s => "signed int to float",
89            itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
90            ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
91            itostring_s => "convert a signed integer to its string representation",
92            itostring_u => "convert an unsigned integer to its string representation",
93            itousize => "convert a 64b unsigned integer to its usize representation",
94            ifromusize => "convert a usize to a 64b unsigned integer",
95            bytecast_int64_to_float64 => {
96                "reinterpret an int64 as a float64 based on its bytes, with the same endianness"
97            }
98            bytecast_float64_to_int64 => {
99                "reinterpret an float64 as an int based on its bytes, with the same endianness"
100            }
101        }
102        .to_string()
103    }
104
105    fn post_opdef(&self, def: &mut OpDef) {
106        const_fold::set_fold(self, def);
107    }
108}
109
110impl ConvertOpDef {
111    /// Initialize a [`ConvertOpType`] from a [`ConvertOpDef`] which requires no
112    /// integer widths set.
113    #[must_use]
114    pub fn without_log_width(self) -> ConvertOpType {
115        ConvertOpType {
116            def: self,
117            log_width: None,
118        }
119    }
120    /// Initialize a [`ConvertOpType`] from a [`ConvertOpDef`] which requires one
121    /// integer width set.
122    #[must_use]
123    pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
124        ConvertOpType {
125            def: self,
126            log_width: Some(log_width),
127        }
128    }
129}
130/// Concrete convert operation with integer log width set.
131#[derive(Debug, Clone, PartialEq)]
132pub struct ConvertOpType {
133    /// The kind of conversion op.
134    def: ConvertOpDef,
135    /// The integer width parameter of the conversion op, if any. This is interpreted
136    /// differently, depending on `def`. The integer types in the inputs and
137    /// outputs of the op will have [`int_type`]s of this width.
138    log_width: Option<u8>,
139}
140
141impl ConvertOpType {
142    /// Returns the generic [`ConvertOpDef`] of this [`ConvertOpType`].
143    #[must_use]
144    pub fn def(&self) -> &ConvertOpDef {
145        &self.def
146    }
147
148    /// Returns the integer width parameters of this [`ConvertOpType`], if any.
149    #[must_use]
150    pub fn log_widths(&self) -> &[u8] {
151        self.log_width.as_slice()
152    }
153}
154
155impl MakeExtensionOp for ConvertOpType {
156    fn op_id(&self) -> OpName {
157        self.def.opdef_id()
158    }
159
160    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
161        let def = ConvertOpDef::from_def(ext_op.def())?;
162        def.instantiate(ext_op.args())
163    }
164
165    fn type_args(&self) -> Vec<TypeArg> {
166        self.log_width
167            .iter()
168            .map(|&n| u64::from(n).into())
169            .collect()
170    }
171}
172
173/// Extension for conversions between integers and floats.
174pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
175    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
176        ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
177    })
178});
179
180impl MakeRegisteredOp for ConvertOpType {
181    fn extension_id(&self) -> ExtensionId {
182        EXTENSION_ID.clone()
183    }
184
185    fn extension_ref(&self) -> Weak<Extension> {
186        Arc::downgrade(&EXTENSION)
187    }
188}
189
190impl HasConcrete for ConvertOpDef {
191    type Concrete = ConvertOpType;
192
193    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
194        let log_width = match type_args {
195            [] => None,
196            [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
197            _ => return Err(SignatureError::InvalidTypeArgs.into()),
198        };
199        Ok(ConvertOpType {
200            def: *self,
201            log_width,
202        })
203    }
204}
205
206impl HasDef for ConvertOpType {
207    type Def = ConvertOpDef;
208}
209
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::IncomingPort;
    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;

    use super::*;

    #[test]
    fn test_conversions_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.conversions");
        assert_eq!(r.types().count(), 0);
    }

    #[test]
    fn test_conversions() {
        // Initialization with an invalid number of type arguments should fail.
        assert!(
            ConvertOpDef::itobool
                .with_log_width(1)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        // This should work
        let o = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
        assert_eq!(
            ConvertOpDef::from_op(&ext_op).unwrap(),
            ConvertOpDef::itobool
        );
    }

    /// A width-parametrized op round-trips through `ExtensionOp` with its log
    /// width preserved.
    #[test]
    fn test_log_width_roundtrip() {
        let o = ConvertOpDef::convert_u.with_log_width(5);
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

        let loaded = ConvertOpType::from_op(&ext_op).unwrap();
        assert_eq!(loaded, o);
        assert_eq!(loaded.log_widths(), &[5u8]);
    }

    /// Instantiation rejects more than one type argument.
    #[test]
    fn test_instantiate_too_many_args() {
        let args: [TypeArg; 2] = [5u64.into(), 5u64.into()];
        assert!(ConvertOpDef::convert_u.instantiate(&args).is_err());
    }

    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| (i.into(), v.clone()))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Every expected output must be produced, and nothing more.
        assert_eq!(res.len(), outputs.len(), "unexpected number of folded outputs");
        for ((_, actual), expected) in res.iter().zip(outputs) {
            assert_eq!(actual, expected);
        }
    }
}