hugr_core/std_extensions/arithmetic/
conversions.rs1use std::sync::{Arc, Weak};
4
5use strum::{EnumIter, EnumString, IntoStaticStr};
6
7use crate::extension::prelude::sum_with_error;
8use crate::extension::prelude::{bool_t, string_type, usize_t};
9use crate::extension::simple_op::{HasConcrete, HasDef};
10use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError};
11use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc};
12use crate::ops::OpName;
13use crate::ops::{custom::ExtensionOp, NamedOp};
14use crate::std_extensions::arithmetic::int_ops::int_polytype;
15use crate::std_extensions::arithmetic::int_types::int_type;
16use crate::types::{TypeArg, TypeRV};
17use crate::Extension;
18
19use super::float_types::float64_type;
20use super::int_types::{get_log_width, int_tv};
21use lazy_static::lazy_static;
22mod const_fold;
23pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
25pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
27
28#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
30#[allow(missing_docs, non_camel_case_types)]
31#[non_exhaustive]
32pub enum ConvertOpDef {
33 trunc_u,
34 trunc_s,
35 convert_u,
36 convert_s,
37 itobool,
38 ifrombool,
39 itostring_u,
40 itostring_s,
41 itousize,
42 ifromusize,
43}
44
45impl MakeOpDef for ConvertOpDef {
46 fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
47 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
48 }
49
50 fn extension(&self) -> ExtensionId {
51 EXTENSION_ID.to_owned()
52 }
53
54 fn extension_ref(&self) -> Weak<Extension> {
55 Arc::downgrade(&EXTENSION)
56 }
57
58 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
59 use ConvertOpDef::*;
60 match self {
61 trunc_s | trunc_u => int_polytype(
62 1,
63 vec![float64_type()],
64 TypeRV::from(sum_with_error(int_tv(0))),
65 ),
66 convert_s | convert_u => int_polytype(1, vec![int_tv(0)], vec![float64_type()]),
67 itobool => int_polytype(0, vec![int_type(0)], vec![bool_t()]),
68 ifrombool => int_polytype(0, vec![bool_t()], vec![int_type(0)]),
69 itostring_u | itostring_s => int_polytype(1, vec![int_tv(0)], vec![string_type()]),
70 itousize => int_polytype(0, vec![int_type(6)], vec![usize_t()]),
71 ifromusize => int_polytype(0, vec![usize_t()], vec![int_type(6)]),
72 }
73 .into()
74 }
75
76 fn description(&self) -> String {
77 use ConvertOpDef::*;
78 match self {
79 trunc_u => "float to unsigned int",
80 trunc_s => "float to signed int",
81 convert_u => "unsigned int to float",
82 convert_s => "signed int to float",
83 itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)",
84 ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)",
85 itostring_s => "convert a signed integer to its string representation",
86 itostring_u => "convert an unsigned integer to its string representation",
87 itousize => "convert a 64b unsigned integer to its usize representation",
88 ifromusize => "convert a usize to a 64b unsigned integer",
89 }
90 .to_string()
91 }
92
93 fn post_opdef(&self, def: &mut OpDef) {
94 const_fold::set_fold(self, def)
95 }
96}
97
98impl ConvertOpDef {
99 pub fn without_log_width(self) -> ConvertOpType {
102 ConvertOpType {
103 def: self,
104 log_width: None,
105 }
106 }
107 pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
110 ConvertOpType {
111 def: self,
112 log_width: Some(log_width),
113 }
114 }
115}
116#[derive(Debug, Clone, PartialEq)]
118pub struct ConvertOpType {
119 def: ConvertOpDef,
121 log_width: Option<u8>,
125}
126
127impl ConvertOpType {
128 pub fn def(&self) -> &ConvertOpDef {
130 &self.def
131 }
132
133 pub fn log_widths(&self) -> &[u8] {
135 self.log_width.as_slice()
136 }
137}
138
139impl NamedOp for ConvertOpType {
140 fn name(&self) -> OpName {
141 self.def.name()
142 }
143}
144
145impl MakeExtensionOp for ConvertOpType {
146 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
147 let def = ConvertOpDef::from_def(ext_op.def())?;
148 def.instantiate(ext_op.args())
149 }
150
151 fn type_args(&self) -> Vec<TypeArg> {
152 self.log_width.iter().map(|&n| (n as u64).into()).collect()
153 }
154}
155
156lazy_static! {
157 pub static ref EXTENSION: Arc<Extension> = {
159 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
160 extension.add_requirements(
161 ExtensionSet::from_iter(vec![
162 super::int_types::EXTENSION_ID,
163 super::float_types::EXTENSION_ID,
164 ]));
165
166 ConvertOpDef::load_all_ops(extension, extension_ref).unwrap();
167 })
168 };
169}
170
171impl MakeRegisteredOp for ConvertOpType {
172 fn extension_id(&self) -> ExtensionId {
173 EXTENSION_ID.to_owned()
174 }
175
176 fn extension_ref(&self) -> Weak<Extension> {
177 Arc::downgrade(&EXTENSION)
178 }
179}
180
181impl HasConcrete for ConvertOpDef {
182 type Concrete = ConvertOpType;
183
184 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
185 let log_width = match type_args {
186 [] => None,
187 [arg] => Some(get_log_width(arg).map_err(|_| SignatureError::InvalidTypeArgs)?),
188 _ => return Err(SignatureError::InvalidTypeArgs.into()),
189 };
190 Ok(ConvertOpType {
191 def: *self,
192 log_width,
193 })
194 }
195}
196
197impl HasDef for ConvertOpType {
198 type Def = ConvertOpDef;
199}
200
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::extension::prelude::ConstUsize;
    use crate::ops::Value;
    use crate::std_extensions::arithmetic::int_types::ConstInt;
    use crate::IncomingPort;

    use super::*;

    /// The extension is registered under the expected name and declares no types.
    #[test]
    fn test_conversions_extension() {
        let r = &EXTENSION;
        assert_eq!(r.name() as &str, "arithmetic.conversions");
        assert_eq!(r.types().count(), 0);
    }

    /// Round-trips `itobool` through an `ExtensionOp`, and checks that giving
    /// it a log-width argument it does not accept is rejected.
    #[test]
    fn test_conversions() {
        // `itobool` takes no type arguments, so supplying one must fail.
        assert!(
            ConvertOpDef::itobool
                .with_log_width(1)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        // Without a log width the op is valid and survives the round trip.
        let o = ConvertOpDef::itobool.without_log_width();
        let ext_op: ExtensionOp = o.clone().to_extension_op().unwrap();

        assert_eq!(ConvertOpType::from_op(&ext_op).unwrap(), o);
        assert_eq!(
            ConvertOpDef::from_op(&ext_op).unwrap(),
            ConvertOpDef::itobool
        );
    }

    /// Constant-folds each conversion on constant inputs and compares the
    /// folded values with the expected outputs.
    #[rstest]
    #[case::itobool_false(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 0).unwrap().into()], &[Value::false_val()])]
    #[case::itobool_true(ConvertOpDef::itobool.without_log_width(), &[ConstInt::new_u(0, 1).unwrap().into()], &[Value::true_val()])]
    #[case::ifrombool_false(ConvertOpDef::ifrombool.without_log_width(), &[Value::false_val()], &[ConstInt::new_u(0, 0).unwrap().into()])]
    #[case::ifrombool_true(ConvertOpDef::ifrombool.without_log_width(), &[Value::true_val()], &[ConstInt::new_u(0, 1).unwrap().into()])]
    #[case::itousize(ConvertOpDef::itousize.without_log_width(), &[ConstInt::new_u(6, 42).unwrap().into()], &[ConstUsize::new(42).into()])]
    #[case::ifromusize(ConvertOpDef::ifromusize.without_log_width(), &[ConstUsize::new(42).into()], &[ConstInt::new_u(6, 42).unwrap().into()])]
    fn convert_fold(
        #[case] op: ConvertOpType,
        #[case] inputs: &[Value],
        #[case] outputs: &[Value],
    ) {
        // Feed each input value to the incoming port with the same index.
        let consts: Vec<(IncomingPort, Value)> = inputs
            .iter()
            .enumerate()
            .map(|(i, v)| (i.into(), v.clone()))
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this check, extra folded outputs would go unnoticed.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for ((_, res_val), expected) in res.iter().zip(outputs) {
            assert_eq!(res_val, expected);
        }
    }
}