hugr_core/std_extensions/arithmetic/
conversions.rs1use std::sync::{Arc, LazyLock, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::Extension;
8use crate::extension::prelude::sum_with_error;
9use crate::extension::prelude::{bool_t, string_type, usize_t};
10use crate::extension::simple_op::{HasConcrete, HasDef};
11use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
12use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc};
13use crate::ops::{ExtensionOp, OpName};
14use crate::std_extensions::arithmetic::int_ops::int_polytype;
15use crate::std_extensions::arithmetic::int_types::int_type;
16use crate::types::{TypeArg, TypeRV};
17
18use super::float_types::float64_type;
19use super::int_types::{get_log_width, int_tv};
20mod const_fold;
/// Identifier under which the conversions extension is registered.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
/// Version of the conversions extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
25
26#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
28#[allow(missing_docs, non_camel_case_types)]
29#[non_exhaustive]
30pub enum ConvertOpDef {
31 trunc_u,
32 trunc_s,
33 convert_u,
34 convert_s,
35 itobool,
36 ifrombool,
37 itostring_u,
38 itostring_s,
39 itousize,
40 ifromusize,
41 bytecast_int64_to_float64,
42 bytecast_float64_to_int64,
43}
44
45impl MakeOpDef for ConvertOpDef {
46 fn opdef_id(&self) -> OpName {
47 <&'static str>::from(self).into()
48 }
49
50 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
51 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
52 }
53
54 fn extension(&self) -> ExtensionId {
55 EXTENSION_ID.clone()
56 }
57
58 fn extension_ref(&self) -> Weak<Extension> {
59 Arc::downgrade(&EXTENSION)
60 }
61
62 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
63 use ConvertOpDef::*;
64 match self {
65 trunc_s | trunc_u => int_polytype(
66 1,
67 vec![float64_type()],
68 TypeRV::from(sum_with_error(int_tv(0))),
69 ),
70 convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
71 itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
72 ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
73 itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
74 itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
75 ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
76 bytecast_int64_to_float64 => int_polytype(0, vec![int_type(6)], vec![float64_type()]),
77 bytecast_float64_to_int64 => int_polytype(0, vec![float64_type()], vec![int_type(6)]),
78 }
79 .into()
80 }
81
82 fn description(&self) -> String {
83 use ConvertOpDef::*;
84 match self {
85 trunc_u => "float to unsigned int",
86 trunc_s => "float to signed int",
87 convert_u => "unsigned int to float",
88 convert_s => "signed int to float",
89 itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
90 ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
91 itostring_s => "convert a signed integer to its string representation",
92 itostring_u => "convert an unsigned integer to its string representation",
93 itousize => "convert a 64b unsigned integer to its usize representation",
94 ifromusize => "convert a usize to a 64b unsigned integer",
95 bytecast_int64_to_float64 => {
96 "reinterpret an int64 as a float64 based on its bytes, with the same endianness"
97 }
98 bytecast_float64_to_int64 => {
99 "reinterpret an float64 as an int based on its bytes, with the same endianness"
100 }
101 }
102 .to_string()
103 }
104
105 fn post_opdef(&self, def: &mut OpDef) {
106 const_fold::set_fold(self, def);
107 }
108}
109
110impl ConvertOpDef {
111 #[must_use]
114 pub fn without_log_width(self) -> ConvertOpType {
115 ConvertOpType {
116 def: self,
117 log_width: None,
118 }
119 }
120 #[must_use]
123 pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
124 ConvertOpType {
125 def: self,
126 log_width: Some(log_width),
127 }
128 }
129}
130#[derive(Debug, Clone, PartialEq)]
132pub struct ConvertOpType {
133 def: ConvertOpDef,
135 log_width: Option<u8>,
139}
140
141impl ConvertOpType {
142 #[must_use]
144 pub fn def(&self) -> &ConvertOpDef {
145 &self.def
146 }
147
148 #[must_use]
150 pub fn log_widths(&self) -> &[u8] {
151 self.log_width.as_slice()
152 }
153}
154
155impl MakeExtensionOp for ConvertOpType {
156 fn op_id(&self) -> OpName {
157 self.def.opdef_id()
158 }
159
160 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
161 let def = ConvertOpDef::from_def(ext_op.def())?;
162 def.instantiate(ext_op.args())
163 }
164
165 fn type_args(&self) -> Vec<TypeArg> {
166 self.log_width
167 .iter()
168 .map(|&n| u64::from(n).into())
169 .collect()
170 }
171}
172
173pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
175 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
176 ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
177 })
178});
179
180impl MakeRegisteredOp for ConvertOpType {
181 fn extension_id(&self) -> ExtensionId {
182 EXTENSION_ID.clone()
183 }
184
185 fn extension_ref(&self) -> Weak<Extension> {
186 Arc::downgrade(&EXTENSION)
187 }
188}
189
190impl HasConcrete for ConvertOpDef {
191 type Concrete = ConvertOpType;
192
193 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
194 let log_width = match type_args {
195 [] => None,
196 [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
197 _ => return Err(SignatureError::InvalidTypeArgs.into()),
198 };
199 Ok(ConvertOpType {
200 def: *self,
201 log_width,
202 })
203 }
204}
205
// Links the concrete op type back to the definition enum it is built from.
impl HasDef for ConvertOpType {
    type Def = ConvertOpDef;
}
209
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::IncomingPort;
    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;

    use super::*;

    /// The extension carries the expected name and declares no types.
    #[test]
    fn test_conversions_extension() {
        let ext = &EXTENSION;
        assert_eq!(ext.name() as &str, "arithmetic.conversions");
        assert_eq!(ext.types().count(), 0);
    }

    /// Round-trips `itobool` through an `ExtensionOp` and checks that a
    /// spurious log-width is rejected.
    #[test]
    fn test_conversions() {
        // `itobool` works on a fixed-width integer, so attaching a log-width
        // type argument must not produce an extension op.
        let rejected = ConvertOpDef::itobool.with_log_width(1).to_extension_op();
        assert!(rejected.is_none(), "type arguments invalid");

        let op = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = op.clone().to_extension_op().unwrap();

        // Both the concrete op and its definition are recoverable.
        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), op);
        assert_eq!(
            ConvertOpDef::from_op(&ext_op).unwrap(),
            ConvertOpDef::itobool
        );
    }

    /// Constant-folds each op on the given inputs and compares every
    /// expected output with the folded value at the same index.
    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        // Pair every input constant with the port index it is fed into.
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .cloned()
            .enumerate()
            .map(|(idx, value)| (idx.into(), value))
            .collect();

        let ext_op = op.to_extension_op().unwrap();
        let folded = ext_op.constant_fold(&consts).unwrap();

        for (idx, expected) in outputs.iter().enumerate() {
            // A missing result entry is a test failure, hence the unwrap.
            let (_, actual) = folded.get(idx).unwrap();

            assert_eq!(actual, expected);
        }
    }
}