1use std::sync::{Arc, LazyLock, Weak};
4
5use super::int_types::{LOG_WIDTH_TYPE_PARAM, get_log_width, int_tv};
6use crate::extension::prelude::{bool_t, sum_with_error};
7use crate::extension::simple_op::{
8 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9};
10use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs};
11use crate::ops::OpName;
12use crate::ops::custom::ExtensionOp;
13use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV};
14use crate::utils::collect_array;
15
16use crate::{
17 Extension,
18 extension::{ExtensionId, SignatureError},
19 types::{Type, type_param::TypeArg},
20};
21
22use strum::{EnumIter, EnumString, IntoStaticStr};
23
24mod const_fold;
25
/// Identifier of the integer-operations extension (`arithmetic.int`).
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");
/// Version of the integer-operations extension.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
30
31struct IOValidator {
32 f_ge_s: bool,
34}
35
36impl ValidateJustArgs for IOValidator {
37 fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError> {
38 let [arg0, arg1] = collect_array(arg_values);
39 let i: u8 = get_log_width(arg0)?;
40 let o: u8 = get_log_width(arg1)?;
41 let cmp = if self.f_ge_s { i >= o } else { i <= o };
42 if !cmp {
43 return Err(SignatureError::InvalidTypeArgs);
44 }
45 Ok(())
46 }
47}
/// Definitions of the operations provided by the integer extension.
///
/// Each variant name doubles as the operation name: the `IntoStaticStr` and
/// `EnumString` derives convert between variants and their names, and
/// `EnumIter` allows iterating over every operation.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
// Variants are deliberately lower-case so that they match the operation names.
#[allow(missing_docs, non_camel_case_types)]
#[non_exhaustive]
pub enum IntOpDef {
    iwiden_u,
    iwiden_s,
    inarrow_u,
    inarrow_s,
    ieq,
    ine,
    ilt_u,
    ilt_s,
    igt_u,
    igt_s,
    ile_u,
    ile_s,
    ige_u,
    ige_s,
    imax_u,
    imax_s,
    imin_u,
    imin_s,
    iadd,
    isub,
    ineg,
    imul,
    idivmod_checked_u,
    idivmod_u,
    idivmod_checked_s,
    idivmod_s,
    idiv_checked_u,
    idiv_u,
    imod_checked_u,
    imod_u,
    idiv_checked_s,
    idiv_s,
    imod_checked_s,
    imod_s,
    ipow,
    iabs,
    iand,
    ior,
    ixor,
    inot,
    ishl,
    ishr,
    irotl,
    irotr,
    iu_to_s,
    is_to_u,
}
100
101impl MakeOpDef for IntOpDef {
102 fn opdef_id(&self) -> OpName {
103 <&Self as Into<&'static str>>::into(self).into()
104 }
105 fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> {
106 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
107 }
108
109 fn extension(&self) -> ExtensionId {
110 EXTENSION_ID.clone()
111 }
112
113 fn extension_ref(&self) -> Weak<Extension> {
114 Arc::downgrade(&EXTENSION)
115 }
116
117 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
118 use IntOpDef::*;
119 let tv0 = int_tv(0);
120 match self {
121 iwiden_s | iwiden_u => CustomValidator::new(
122 int_polytype(2, vec![tv0], vec![int_tv(1)]),
123 IOValidator { f_ge_s: false },
124 )
125 .into(),
126 inarrow_s | inarrow_u => CustomValidator::new(
127 int_polytype(2, tv0, sum_ty_with_err(int_tv(1))),
128 IOValidator { f_ge_s: true },
129 )
130 .into(),
131 ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => {
132 int_polytype(1, vec![tv0; 2], vec![bool_t()]).into()
133 }
134 imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => {
135 ibinop_sig().into()
136 }
137 ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(),
138 idivmod_checked_u | idivmod_checked_s => {
139 let intpair: TypeRowRV = vec![tv0; 2].into();
140 int_polytype(
141 1,
142 intpair.clone(),
143 sum_ty_with_err(Type::new_tuple(intpair)),
144 )
145 }
146 .into(),
147 idivmod_u | idivmod_s => {
148 let intpair: TypeRowRV = vec![tv0; 2].into();
149 int_polytype(1, intpair.clone(), intpair.clone())
150 }
151 .into(),
152 idiv_u | idiv_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
153 idiv_checked_u | idiv_checked_s => {
154 int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
155 }
156 imod_checked_u | imod_checked_s => {
157 int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
158 }
159 imod_u | imod_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
160 ishl | ishr | irotl | irotr => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
161 }
162 }
163
164 fn description(&self) -> String {
165 use IntOpDef::*;
166
167 match self {
168 iwiden_u => "widen an unsigned integer to a wider one with the same value",
169 iwiden_s => "widen a signed integer to a wider one with the same value",
170 inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible",
171 inarrow_s => "narrow a signed integer to a narrower one with the same value if possible",
172 ieq => "equality test",
173 ine => "inequality test",
174 ilt_u => "\"less than\" as unsigned integers",
175 ilt_s => "\"less than\" as signed integers",
176 igt_u =>"\"greater than\" as unsigned integers",
177 igt_s => "\"greater than\" as signed integers",
178 ile_u => "\"less than or equal\" as unsigned integers",
179 ile_s => "\"less than or equal\" as signed integers",
180 ige_u => "\"greater than or equal\" as unsigned integers",
181 ige_s => "\"greater than or equal\" as signed integers",
182 imax_u => "maximum of unsigned integers",
183 imax_s => "maximum of signed integers",
184 imin_u => "minimum of unsigned integers",
185 imin_s => "minimum of signed integers",
186 iadd => "addition modulo 2^N (signed and unsigned versions are the same op)",
187 isub => "subtraction modulo 2^N (signed and unsigned versions are the same op)",
188 ineg => "negation modulo 2^N (signed and unsigned versions are the same op)",
189 imul => "multiplication modulo 2^N (signed and unsigned versions are the same op)",
190 idivmod_checked_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
191 q*m+r=n, 0<=r<m (m=0 is an error)",
192 idivmod_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
193 q*m+r=n, 0<=r<m (m=0 will call panic)",
194 idivmod_checked_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
195 signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 is an error)",
196 idivmod_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
197 signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 will call panic)",
198 idiv_checked_u => "as idivmod_checked_u but discarding the second output",
199 idiv_u => "as idivmod_u but discarding the second output",
200 imod_checked_u => "as idivmod_checked_u but discarding the first output",
201 imod_u => "as idivmod_u but discarding the first output",
202 idiv_checked_s => "as idivmod_checked_s but discarding the second output",
203 idiv_s => "as idivmod_s but discarding the second output",
204 imod_checked_s => "as idivmod_checked_s but discarding the first output",
205 imod_s => "as idivmod_s but discarding the first output",
206 ipow => "raise first input to the power of second input, the exponent is treated as an unsigned integer",
207 iabs => "convert signed to unsigned by taking absolute value",
208 iand => "bitwise AND",
209 ior => "bitwise OR",
210 ixor => "bitwise XOR",
211 inot => "bitwise NOT",
212 ishl => "shift first input left by k bits where k is unsigned interpretation of second input \
213 (leftmost bits dropped, rightmost bits set to zero",
214 ishr => "shift first input right by k bits where k is unsigned interpretation of second input \
215 (rightmost bits dropped, leftmost bits set to zero)",
216 irotl => "rotate first input left by k bits where k is unsigned interpretation of second input \
217 (leftmost bits replace rightmost bits)",
218 irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \
219 (rightmost bits replace leftmost bits)",
220 is_to_u => "convert signed to unsigned by taking absolute value",
221 iu_to_s => "convert unsigned to signed by taking absolute value",
222 }.into()
223 }
224
225 fn post_opdef(&self, def: &mut OpDef) {
226 const_fold::set_fold(self, def);
227 }
228}
229
230pub(in crate::std_extensions::arithmetic) fn int_polytype(
232 n_vars: usize,
233 input: impl Into<TypeRowRV>,
234 output: impl Into<TypeRowRV>,
235) -> PolyFuncTypeRV {
236 PolyFuncTypeRV::new(
237 vec![LOG_WIDTH_TYPE_PARAM; n_vars],
238 FuncValueType::new(input, output),
239 )
240}
241
242fn ibinop_sig() -> PolyFuncTypeRV {
243 let int_type_var = int_tv(0);
244
245 int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var])
246}
247
248fn iunop_sig() -> PolyFuncTypeRV {
249 let int_type_var = int_tv(0);
250 int_polytype(1, vec![int_type_var.clone()], vec![int_type_var])
251}
252
253pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
255 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
256 IntOpDef::load_all_ops(extension, extension_ref).unwrap();
257 })
258});
259
260impl HasConcrete for IntOpDef {
261 type Concrete = ConcreteIntOp;
262
263 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
264 let log_widths: Vec<u8> = type_args
265 .iter()
266 .map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
267 .collect::<Result<_, _>>()?;
268 Ok(ConcreteIntOp {
269 def: *self,
270 log_widths,
271 })
272 }
273}
274
impl HasDef for ConcreteIntOp {
    type Def = IntOpDef;
}

/// An integer operation definition together with its concrete log-width
/// type arguments.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConcreteIntOp {
    /// The kind of integer operation.
    pub def: IntOpDef,
    /// The log-widths used as the operation's type arguments, in order.
    pub log_widths: Vec<u8>,
}
292
293impl MakeExtensionOp for ConcreteIntOp {
294 fn op_id(&self) -> OpName {
295 self.def.opdef_id()
296 }
297
298 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
299 let def = IntOpDef::from_def(ext_op.def())?;
300 def.instantiate(ext_op.args())
301 }
302
303 fn type_args(&self) -> Vec<TypeArg> {
304 self.log_widths
305 .iter()
306 .map(|&n| u64::from(n).into())
307 .collect()
308 }
309}
310
311impl MakeRegisteredOp for ConcreteIntOp {
312 fn extension_id(&self) -> ExtensionId {
313 EXTENSION_ID.clone()
314 }
315
316 fn extension_ref(&self) -> Weak<Extension> {
317 Arc::downgrade(&EXTENSION)
318 }
319}
320
321impl IntOpDef {
322 #[must_use]
325 pub fn without_log_width(self) -> ConcreteIntOp {
326 ConcreteIntOp {
327 def: self,
328 log_widths: vec![],
329 }
330 }
331 #[must_use]
334 pub fn with_log_width(self, log_width: u8) -> ConcreteIntOp {
335 ConcreteIntOp {
336 def: self,
337 log_widths: vec![log_width],
338 }
339 }
340 #[must_use]
343 pub fn with_two_log_widths(self, first_log_width: u8, second_log_width: u8) -> ConcreteIntOp {
344 ConcreteIntOp {
345 def: self,
346 log_widths: vec![first_log_width, second_log_width],
347 }
348 }
349}
350
351fn sum_ty_with_err(t: Type) -> Type {
352 sum_with_error(t).into()
353}
354
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::{
        ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type,
        types::Signature,
    };

    use super::*;

    /// The extension has the expected name, declares no types, and every
    /// operation name starts with `i`.
    #[test]
    fn test_int_ops_extension() {
        assert_eq!(EXTENSION.name() as &str, "arithmetic.int");
        assert_eq!(EXTENSION.types().count(), 0);
        for (name, _) in EXTENSION.operations() {
            assert!(name.starts_with('i'));
        }
    }

    /// Widening and narrowing signatures, including rejection of width pairs
    /// that violate the ordering constraint.
    #[test]
    fn test_binary_signatures() {
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 4)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), int_type(4))
        );
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new_endo(int_type(3))
        );
        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), sum_ty_with_err(int_type(3)))
        );
        // Widening to a narrower type is rejected.
        assert!(
            IntOpDef::iwiden_u
                .with_two_log_widths(4, 3)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(2, 1)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(2), sum_ty_with_err(int_type(1)))
        );

        // Narrowing to a wider type is rejected.
        assert!(
            IntOpDef::inarrow_u
                .with_two_log_widths(1, 2)
                .to_extension_op()
                .is_none()
        );
    }

    /// Constant-fold `op` over unsigned integer constants built from
    /// `inputs` and compare the folded values against `outputs`.
    #[rstest]
    #[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
    #[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
    #[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
    #[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
    #[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
    #[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
    #[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
    #[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
    #[should_panic(expected = "too large to be converted to signed")]
    #[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u64::from(u32::MAX)], &[], 5)]
    #[should_panic(expected = "Cannot convert negative integer")]
    #[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[u64::from(0u32.wrapping_sub(42))], &[], 5)]
    fn int_fold(
        #[case] op: ConcreteIntOp,
        #[case] inputs: &[u64],
        #[case] outputs: &[u64],
        #[case] log_width: u8,
    ) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::int_types::ConstInt;

        // Pair each input constant with the index of the port it feeds.
        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| {
                (
                    i.into(),
                    Value::extension(ConstInt::new_u(log_width, x).unwrap()),
                )
            })
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        for (i, &expected) in outputs.iter().enumerate() {
            let res_val: u64 = res
                .get(i)
                .unwrap()
                .1
                .get_custom_value::<ConstInt>()
                .expect("This function assumes all folded outputs are integer constants.")
                .value_u();

            assert_eq!(res_val, expected);
        }
    }
}