hugr_core/std_extensions/arithmetic/
conversions.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::prelude::sum_with_error;
8use crate::extension::prelude::{bool_t, string_type, usize_t};
9use crate::extension::simple_op::{HasConcrete, HasDef};
10use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
11use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc};
12use crate::ops::{ExtensionOp, OpName};
13use crate::std_extensions::arithmetic::int_ops::int_polytype;
14use crate::std_extensions::arithmetic::int_types::int_type;
15use crate::types::{TypeArg, TypeRV};
16use crate::Extension;
17
18use super::float_types::float64_type;
19use super::int_types::{get_log_width, int_tv};
20use lazy_static::lazy_static;
21mod const_fold;
22pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
24pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
26
27#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
29#[allow(missing_docs, non_camel_case_types)]
30#[non_exhaustive]
31pub enum ConvertOpDef {
32 trunc_u,
33 trunc_s,
34 convert_u,
35 convert_s,
36 itobool,
37 ifrombool,
38 itostring_u,
39 itostring_s,
40 itousize,
41 ifromusize,
42 bytecast_int64_to_float64,
43 bytecast_float64_to_int64,
44}
45
46impl MakeOpDef for ConvertOpDef {
47 fn opdef_id(&self) -> OpName {
48 <&'static str>::from(self).into()
49 }
50
51 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
52 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
53 }
54
55 fn extension(&self) -> ExtensionId {
56 EXTENSION_ID.to_owned()
57 }
58
59 fn extension_ref(&self) -> Weak<Extension> {
60 Arc::downgrade(&EXTENSION)
61 }
62
63 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
64 use ConvertOpDef::*;
65 match self {
66 trunc_s | trunc_u => int_polytype(
67 1,
68 vec![float64_type()],
69 TypeRV::from(sum_with_error(int_tv(0))),
70 ),
71 convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
72 itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
73 ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
74 itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
75 itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
76 ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
77 bytecast_int64_to_float64 => int_polytype(0, vec![int_type(6)], vec![float64_type()]),
78 bytecast_float64_to_int64 => int_polytype(0, vec![float64_type()], vec![int_type(6)]),
79 }
80 .into()
81 }
82
83 fn description(&self) -> String {
84 use ConvertOpDef::*;
85 match self {
86 trunc_u => "float to unsigned int",
87 trunc_s => "float to signed int",
88 convert_u => "unsigned int to float",
89 convert_s => "signed int to float",
90 itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
91 ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
92 itostring_s => "convert a signed integer to its string representation",
93 itostring_u => "convert an unsigned integer to its string representation",
94 itousize => "convert a 64b unsigned integer to its usize representation",
95 ifromusize => "convert a usize to a 64b unsigned integer",
96 bytecast_int64_to_float64 => {
97 "reinterpret an int64 as a float64 based on its bytes, with the same endianness"
98 }
99 bytecast_float64_to_int64 => {
100 "reinterpret an float64 as an int based on its bytes, with the same endianness"
101 }
102 }
103 .to_string()
104 }
105
106 fn post_opdef(&self, def: &mut OpDef) {
107 const_fold::set_fold(self, def)
108 }
109}
110
111impl ConvertOpDef {
112 pub fn without_log_width(self) -> ConvertOpType {
115 ConvertOpType {
116 def: self,
117 log_width: None,
118 }
119 }
120 pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
123 ConvertOpType {
124 def: self,
125 log_width: Some(log_width),
126 }
127 }
128}
129#[derive(Debug, Clone, PartialEq)]
131pub struct ConvertOpType {
132 def: ConvertOpDef,
134 log_width: Option<u8>,
138}
139
140impl ConvertOpType {
141 pub fn def(&self) -> &ConvertOpDef {
143 &self.def
144 }
145
146 pub fn log_widths(&self) -> &[u8] {
148 self.log_width.as_slice()
149 }
150}
151
152impl MakeExtensionOp for ConvertOpType {
153 fn op_id(&self) -> OpName {
154 self.def.opdef_id()
155 }
156
157 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
158 let def = ConvertOpDef::from_def(ext_op.def())?;
159 def.instantiate(ext_op.args())
160 }
161
162 fn type_args(&self) -> Vec<TypeArg> {
163 self.log_width.iter().map(|&n| (n as u64).into()).collect()
164 }
165}
166
167lazy_static! {
168 pub static ref EXTENSION: Arc<Extension> = {
170 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
171 ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
172 })
173 };
174}
175
176impl MakeRegisteredOp for ConvertOpType {
177 fn extension_id(&self) -> ExtensionId {
178 EXTENSION_ID.to_owned()
179 }
180
181 fn extension_ref(&self) -> Weak<Extension> {
182 Arc::downgrade(&EXTENSION)
183 }
184}
185
186impl HasConcrete for ConvertOpDef {
187 type Concrete = ConvertOpType;
188
189 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
190 let log_width = match type_args {
191 [] => None,
192 [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
193 _ => return Err(SignatureError::InvalidTypeArgs.into()),
194 };
195 Ok(ConvertOpType {
196 def: *self,
197 log_width,
198 })
199 }
200}
201
202impl HasDef for ConvertOpType {
203 type Def = ConvertOpDef;
204}
205
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;
    use crate::IncomingPort;

    use super::*;

    #[test]
    fn test_conversions_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.conversions");
        assert_eq!(r.types().count(), 0);
    }

    #[test]
    fn test_conversions() {
        // `itobool` has a fixed 1-bit input, so a log-width argument must be rejected.
        assert!(
            ConvertOpDef::itobool
                .with_log_width(1)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        let o = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

        // Round-trip through the extension op at both the concrete and the def level.
        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
        assert_eq!(
            ConvertOpDef::from_op(&ext_op).unwrap(),
            ConvertOpDef::itobool
        );
    }

    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| (i.into(), v.clone()))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Every expected output must be produced, and nothing extra.
        assert_eq!(res.len(), outputs.len());
        for ((_, actual), expected) in res.iter().zip(outputs) {
            assert_eq!(actual, expected);
        }
    }
}