hugr_core/std_extensions/arithmetic/
conversions.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::prelude::sum_with_error;
8use crate::extension::prelude::{bool_t, string_type, usize_t};
9use crate::extension::simple_op::{HasConcrete, HasDef};
10use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
11use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc};
12use crate::ops::OpName;
13use crate::ops::{custom::ExtensionOp, NamedOp};
14use crate::std_extensions::arithmetic::int_ops::int_polytype;
15use crate::std_extensions::arithmetic::int_types::int_type;
16use crate::types::{TypeArg, TypeRV};
17use crate::Extension;
18
19use super::float_types::float64_type;
20use super::int_types::{get_log_width, int_tv};
21use lazy_static::lazy_static;
22mod const_fold;
23pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
25pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
27
28#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
30#[allow(missing_docs, non_camel_case_types)]
31#[non_exhaustive]
32pub enum ConvertOpDef {
33 trunc_u,
34 trunc_s,
35 convert_u,
36 convert_s,
37 itobool,
38 ifrombool,
39 itostring_u,
40 itostring_s,
41 itousize,
42 ifromusize,
43 bytecast_int64_to_float64,
44 bytecast_float64_to_int64,
45}
46
47impl MakeOpDef for ConvertOpDef {
48 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
49 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
50 }
51
52 fn extension(&self) -> ExtensionId {
53 EXTENSION_ID.to_owned()
54 }
55
56 fn extension_ref(&self) -> Weak<Extension> {
57 Arc::downgrade(&EXTENSION)
58 }
59
60 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
61 use ConvertOpDef::*;
62 match self {
63 trunc_s | trunc_u => int_polytype(
64 1,
65 vec![float64_type()],
66 TypeRV::from(sum_with_error(int_tv(0))),
67 ),
68 convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
69 itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
70 ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
71 itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
72 itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
73 ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
74 bytecast_int64_to_float64 => int_polytype(0, vec![int_type(6)], vec![float64_type()]),
75 bytecast_float64_to_int64 => int_polytype(0, vec![float64_type()], vec![int_type(6)]),
76 }
77 .into()
78 }
79
80 fn description(&self) -> String {
81 use ConvertOpDef::*;
82 match self {
83 trunc_u => "float to unsigned int",
84 trunc_s => "float to signed int",
85 convert_u => "unsigned int to float",
86 convert_s => "signed int to float",
87 itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
88 ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
89 itostring_s => "convert a signed integer to its string representation",
90 itostring_u => "convert an unsigned integer to its string representation",
91 itousize => "convert a 64b unsigned integer to its usize representation",
92 ifromusize => "convert a usize to a 64b unsigned integer",
93 bytecast_int64_to_float64 => {
94 "reinterpret an int64 as a float64 based on its bytes, with the same endianness"
95 }
96 bytecast_float64_to_int64 => {
97 "reinterpret an float64 as an int based on its bytes, with the same endianness"
98 }
99 }
100 .to_string()
101 }
102
103 fn post_opdef(&self, def: &mut OpDef) {
104 const_fold::set_fold(self, def)
105 }
106}
107
108impl ConvertOpDef {
109 pub fn without_log_width(self) -> ConvertOpType {
112 ConvertOpType {
113 def: self,
114 log_width: None,
115 }
116 }
117 pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
120 ConvertOpType {
121 def: self,
122 log_width: Some(log_width),
123 }
124 }
125}
126#[derive(Debug, Clone, PartialEq)]
128pub struct ConvertOpType {
129 def: ConvertOpDef,
131 log_width: Option<u8>,
135}
136
137impl ConvertOpType {
138 pub fn def(&self) -> &ConvertOpDef {
140 &self.def
141 }
142
143 pub fn log_widths(&self) -> &[u8] {
145 self.log_width.as_slice()
146 }
147}
148
149impl NamedOp for ConvertOpType {
150 fn name(&self) -> OpName {
151 self.def.name()
152 }
153}
154
155impl MakeExtensionOp for ConvertOpType {
156 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
157 let def = ConvertOpDef::from_def(ext_op.def())?;
158 def.instantiate(ext_op.args())
159 }
160
161 fn type_args(&self) -> Vec<TypeArg> {
162 self.log_width.iter().map(|&n| (n as u64).into()).collect()
163 }
164}
165
166lazy_static! {
167 pub static ref EXTENSION: Arc<Extension> = {
169 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
170 extension.add_requirements(
171 ExtensionSet::from_iter(vec![
172 super::int_types::EXTENSION_ID,
173 super::float_types::EXTENSION_ID,
174 ]));
175
176 ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
177 })
178 };
179}
180
181impl MakeRegisteredOp for ConvertOpType {
182 fn extension_id(&self) -> ExtensionId {
183 EXTENSION_ID.to_owned()
184 }
185
186 fn extension_ref(&self) -> Weak<Extension> {
187 Arc::downgrade(&EXTENSION)
188 }
189}
190
191impl HasConcrete for ConvertOpDef {
192 type Concrete = ConvertOpType;
193
194 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
195 let log_width = match type_args {
196 [] => None,
197 [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
198 _ => return Err(SignatureError::InvalidTypeArgs.into()),
199 };
200 Ok(ConvertOpType {
201 def: *self,
202 log_width,
203 })
204 }
205}
206
207impl HasDef for ConvertOpType {
208 type Def = ConvertOpDef;
209}
210
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;
    use crate::IncomingPort;

    use super::*;

    #[test]
    fn test_conversions_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.conversions");
        assert_eq!(r.types().count(), 0);
    }

    #[test]
    fn test_conversions() {
        // `itobool` declares no type parameters, so supplying a log width must fail.
        assert!(
            ConvertOpDef::itobool
                .with_log_width(1)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        // Round-trip through an `ExtensionOp` and back.
        let o = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
        assert_eq!(
            ConvertOpDef::from_op(&ext_op).unwrap(),
            ConvertOpDef::itobool
        );
    }

    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        // Feed the constants to consecutive input ports.
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| (i.into(), v.clone()))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Every expected output must be produced, and no extra ones.
        assert_eq!(res.len(), outputs.len());
        for ((_, res_val), expected) in res.iter().zip(outputs) {
            assert_eq!(res_val, expected);
        }
    }
}