hugr_core/std_extensions/arithmetic/
conversions.rs

1//! Conversions between integer and floating-point values.
2
3use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::prelude::sum_with_error;
8use crate::extension::prelude::{bool_t, string_type, usize_t};
9use crate::extension::simple_op::{HasConcrete, HasDef};
10use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
11use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc};
12use crate::ops::OpName;
13use crate::ops::{custom::ExtensionOp, NamedOp};
14use crate::std_extensions::arithmetic::int_ops::int_polytype;
15use crate::std_extensions::arithmetic::int_types::int_type;
16use crate::types::{TypeArg, TypeRV};
17use crate::Extension;
18
19use super::float_types::float64_type;
20use super::int_types::{get_log_width, int_tv};
21use lazy_static::lazy_static;
22mod const_fold;
23/// The extension identifier.
24pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
25/// Extension version.
26pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
27
28/// Extension for conversions between floats and integers.
29#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
30#[allow(missing_docs, non_camel_case_types)]
31#[non_exhaustive]
32pub enum ConvertOpDef {
33    trunc_u,
34    trunc_s,
35    convert_u,
36    convert_s,
37    itobool,
38    ifrombool,
39    itostring_u,
40    itostring_s,
41    itousize,
42    ifromusize,
43    bytecast_int64_to_float64,
44    bytecast_float64_to_int64,
45}
46
47impl MakeOpDef for ConvertOpDef {
48    fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
49        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
50    }
51
52    fn extension(&self) -> ExtensionId {
53        EXTENSION_ID.to_owned()
54    }
55
56    fn extension_ref(&self) -> Weak<Extension> {
57        Arc::downgrade(&EXTENSION)
58    }
59
60    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
61        use ConvertOpDef::*;
62        match self {
63            trunc_s | trunc_u => int_polytype(
64                1,
65                vec![float64_type()],
66                TypeRV::from(sum_with_error(int_tv(0))),
67            ),
68            convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
69            itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
70            ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
71            itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
72            itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
73            ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
74            bytecast_int64_to_float64 => int_polytype(0, vec![int_type(6)], vec![float64_type()]),
75            bytecast_float64_to_int64 => int_polytype(0, vec![float64_type()], vec![int_type(6)]),
76        }
77        .into()
78    }
79
80    fn description(&self) -> String {
81        use ConvertOpDef::*;
82        match self {
83            trunc_u => "float to unsigned int",
84            trunc_s => "float to signed int",
85            convert_u => "unsigned int to float",
86            convert_s => "signed int to float",
87            itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
88            ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
89            itostring_s => "convert a signed integer to its string representation",
90            itostring_u => "convert an unsigned integer to its string representation",
91            itousize => "convert a 64b unsigned integer to its usize representation",
92            ifromusize => "convert a usize to a 64b unsigned integer",
93            bytecast_int64_to_float64 => {
94                "reinterpret an int64 as a float64 based on its bytes, with the same endianness"
95            }
96            bytecast_float64_to_int64 => {
97                "reinterpret an float64 as an int based on its bytes, with the same endianness"
98            }
99        }
100        .to_string()
101    }
102
103    fn post_opdef(&self, def: &mut OpDef) {
104        const_fold::set_fold(self, def)
105    }
106}
107
108impl ConvertOpDef {
109    /// Initialize a [ConvertOpType] from a [ConvertOpDef] which requires no
110    /// integer widths set.
111    pub fn without_log_width(self) -> ConvertOpType {
112        ConvertOpType {
113            def: self,
114            log_width: None,
115        }
116    }
117    /// Initialize a [ConvertOpType] from a [ConvertOpDef] which requires one
118    /// integer width set.
119    pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
120        ConvertOpType {
121            def: self,
122            log_width: Some(log_width),
123        }
124    }
125}
126/// Concrete convert operation with integer log width set.
127#[derive(Debug, Clone, PartialEq)]
128pub struct ConvertOpType {
129    /// The kind of conversion op.
130    def: ConvertOpDef,
131    /// The integer width parameter of the conversion op, if any. This is interpreted
132    /// differently, depending on `def`. The integer types in the inputs and
133    /// outputs of the op will have [int_type]s of this width.
134    log_width: Option<u8>,
135}
136
137impl ConvertOpType {
138    /// Returns the generic [ConvertOpDef] of this [ConvertOpType].
139    pub fn def(&self) -> &ConvertOpDef {
140        &self.def
141    }
142
143    /// Returns the integer width parameters of this [ConvertOpType], if any.
144    pub fn log_widths(&self) -> &[u8] {
145        self.log_width.as_slice()
146    }
147}
148
149impl NamedOp for ConvertOpType {
150    fn name(&self) -> OpName {
151        self.def.name()
152    }
153}
154
155impl MakeExtensionOp for ConvertOpType {
156    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
157        let def = ConvertOpDef::from_def(ext_op.def())?;
158        def.instantiate(ext_op.args())
159    }
160
161    fn type_args(&self) -> Vec<TypeArg> {
162        self.log_width.iter().map(|&n| (n as u64).into()).collect()
163    }
164}
165
166lazy_static! {
167    /// Extension for conversions between integers and floats.
168    pub static ref EXTENSION: Arc<Extension> = {
169        Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
170            extension.add_requirements(
171                ExtensionSet::from_iter(vec![
172                    super::int_types::EXTENSION_ID,
173                    super::float_types::EXTENSION_ID,
174            ]));
175
176            ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
177        })
178    };
179}
180
181impl MakeRegisteredOp for ConvertOpType {
182    fn extension_id(&self) -> ExtensionId {
183        EXTENSION_ID.to_owned()
184    }
185
186    fn extension_ref(&self) -> Weak<Extension> {
187        Arc::downgrade(&EXTENSION)
188    }
189}
190
191impl HasConcrete for ConvertOpDef {
192    type Concrete = ConvertOpType;
193
194    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
195        let log_width = match type_args {
196            [] => None,
197            [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
198            _ => return Err(SignatureError::InvalidTypeArgs.into()),
199        };
200        Ok(ConvertOpType {
201            def: *self,
202            log_width,
203        })
204    }
205}
206
207impl HasDef for ConvertOpType {
208    type Def = ConvertOpDef;
209}
210
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;
    use crate::IncomingPort;

    use super::*;

    /// The extension has the expected name and declares no types of its own.
    #[test]
    fn test_conversions_extension() {
        let extension = &EXTENSION;
        assert_eq!(extension.name() as &str, "arithmetic.conversions");
        assert_eq!(extension.types().count(), 0);
    }

    /// Round-trips `itobool` through an [ExtensionOp] and rejects a bad instantiation.
    #[test]
    fn test_conversions() {
        // `itobool` takes no type arguments, so supplying a width must fail.
        let over_specified = ConvertOpDef::itobool.with_log_width(1);
        assert!(
            over_specified.to_extension_op().is_none(),
            "type arguments invalid"
        );

        // Without a width the op is valid and survives the round trip.
        let op = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = op.clone().to_extension_op().unwrap();

        let recovered_op = ConvertOpType::from_op(&ext_op).unwrap();
        assert_eq!(recovered_op, op);

        let recovered_def = ConvertOpDef::from_op(&ext_op).unwrap();
        assert_eq!(recovered_def, ConvertOpDef::itobool);
    }

    /// Constant-folds `op` over `inputs` and compares each expected output
    /// against the folded value at the same position.
    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        // Feed the i-th input to the i-th incoming port.
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .cloned()
            .enumerate()
            .map(|(port, value)| (port.into(), value))
            .collect();

        let ext_op = op.to_extension_op().unwrap();
        let folded = ext_op.constant_fold(&consts).unwrap();

        for (idx, expected) in outputs.iter().enumerate() {
            let (_, actual) = folded.get(idx).unwrap();

            assert_eq!(actual, expected);
        }
    }
}