1use std::sync::{Arc, Weak};
4
5use super::int_types::{LOG_WIDTH_TYPE_PARAM, get_log_width, int_tv};
6use crate::extension::prelude::{bool_t, sum_with_error};
7use crate::extension::simple_op::{
8 HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9};
10use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs};
11use crate::ops::OpName;
12use crate::ops::custom::ExtensionOp;
13use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV};
14use crate::utils::collect_array;
15
16use crate::{
17 Extension,
18 extension::{ExtensionId, SignatureError},
19 types::{Type, type_param::TypeArg},
20};
21
22use lazy_static::lazy_static;
23use strum::{EnumIter, EnumString, IntoStaticStr};
24
25mod const_fold;
26
/// The extension identifier (`arithmetic.int`).
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
31
32struct IOValidator {
33 f_ge_s: bool,
35}
36
37impl ValidateJustArgs for IOValidator {
38 fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError> {
39 let [arg0, arg1] = collect_array(arg_values);
40 let i: u8 = get_log_width(arg0)?;
41 let o: u8 = get_log_width(arg1)?;
42 let cmp = if self.f_ge_s { i >= o } else { i <= o };
43 if !cmp {
44 return Err(SignatureError::InvalidTypeArgs);
45 }
46 Ok(())
47 }
48}
/// Integer extension operation definitions.
///
/// Each variant's name is used as its operation name in the `arithmetic.int`
/// extension (via the `IntoStaticStr` derive). Variants suffixed `_u` / `_s`
/// interpret their operands as unsigned / signed respectively.
// NOTE(review): variant order presumably determines the `EnumIter` order in
// which ops are loaded into the extension; do not reorder casually.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs, non_camel_case_types)]
#[non_exhaustive]
pub enum IntOpDef {
    // Width-changing conversions (two log-width type parameters).
    iwiden_u,
    iwiden_s,
    inarrow_u,
    inarrow_s,
    // Comparisons, producing a bool.
    ieq,
    ine,
    ilt_u,
    ilt_s,
    igt_u,
    igt_s,
    ile_u,
    ile_s,
    ige_u,
    ige_s,
    // Maximum / minimum.
    imax_u,
    imax_s,
    imin_u,
    imin_s,
    // Arithmetic modulo 2^N.
    iadd,
    isub,
    ineg,
    imul,
    // Division and remainder; `_checked` variants return a sum with an error.
    idivmod_checked_u,
    idivmod_u,
    idivmod_checked_s,
    idivmod_s,
    idiv_checked_u,
    idiv_u,
    imod_checked_u,
    imod_u,
    idiv_checked_s,
    idiv_s,
    imod_checked_s,
    imod_s,
    // Exponentiation and absolute value.
    ipow,
    iabs,
    // Bitwise logic.
    iand,
    ior,
    ixor,
    inot,
    // Shifts and rotations.
    ishl,
    ishr,
    irotl,
    irotr,
    // Signedness conversions.
    iu_to_s,
    is_to_u,
}
101
102impl MakeOpDef for IntOpDef {
103 fn opdef_id(&self) -> OpName {
104 <&Self as Into<&'static str>>::into(self).into()
105 }
106 fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> {
107 crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
108 }
109
110 fn extension(&self) -> ExtensionId {
111 EXTENSION_ID.clone()
112 }
113
114 fn extension_ref(&self) -> Weak<Extension> {
115 Arc::downgrade(&EXTENSION)
116 }
117
118 fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
119 use IntOpDef::*;
120 let tv0 = int_tv(0);
121 match self {
122 iwiden_s | iwiden_u => CustomValidator::new(
123 int_polytype(2, vec![tv0], vec![int_tv(1)]),
124 IOValidator { f_ge_s: false },
125 )
126 .into(),
127 inarrow_s | inarrow_u => CustomValidator::new(
128 int_polytype(2, tv0, sum_ty_with_err(int_tv(1))),
129 IOValidator { f_ge_s: true },
130 )
131 .into(),
132 ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => {
133 int_polytype(1, vec![tv0; 2], vec![bool_t()]).into()
134 }
135 imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => {
136 ibinop_sig().into()
137 }
138 ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(),
139 idivmod_checked_u | idivmod_checked_s => {
140 let intpair: TypeRowRV = vec![tv0; 2].into();
141 int_polytype(
142 1,
143 intpair.clone(),
144 sum_ty_with_err(Type::new_tuple(intpair)),
145 )
146 }
147 .into(),
148 idivmod_u | idivmod_s => {
149 let intpair: TypeRowRV = vec![tv0; 2].into();
150 int_polytype(1, intpair.clone(), intpair.clone())
151 }
152 .into(),
153 idiv_u | idiv_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
154 idiv_checked_u | idiv_checked_s => {
155 int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
156 }
157 imod_checked_u | imod_checked_s => {
158 int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
159 }
160 imod_u | imod_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
161 ishl | ishr | irotl | irotr => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
162 }
163 }
164
165 fn description(&self) -> String {
166 use IntOpDef::*;
167
168 match self {
169 iwiden_u => "widen an unsigned integer to a wider one with the same value",
170 iwiden_s => "widen a signed integer to a wider one with the same value",
171 inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible",
172 inarrow_s => "narrow a signed integer to a narrower one with the same value if possible",
173 ieq => "equality test",
174 ine => "inequality test",
175 ilt_u => "\"less than\" as unsigned integers",
176 ilt_s => "\"less than\" as signed integers",
177 igt_u =>"\"greater than\" as unsigned integers",
178 igt_s => "\"greater than\" as signed integers",
179 ile_u => "\"less than or equal\" as unsigned integers",
180 ile_s => "\"less than or equal\" as signed integers",
181 ige_u => "\"greater than or equal\" as unsigned integers",
182 ige_s => "\"greater than or equal\" as signed integers",
183 imax_u => "maximum of unsigned integers",
184 imax_s => "maximum of signed integers",
185 imin_u => "minimum of unsigned integers",
186 imin_s => "minimum of signed integers",
187 iadd => "addition modulo 2^N (signed and unsigned versions are the same op)",
188 isub => "subtraction modulo 2^N (signed and unsigned versions are the same op)",
189 ineg => "negation modulo 2^N (signed and unsigned versions are the same op)",
190 imul => "multiplication modulo 2^N (signed and unsigned versions are the same op)",
191 idivmod_checked_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
192 q*m+r=n, 0<=r<m (m=0 is an error)",
193 idivmod_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
194 q*m+r=n, 0<=r<m (m=0 will call panic)",
195 idivmod_checked_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
196 signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 is an error)",
197 idivmod_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
198 signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 will call panic)",
199 idiv_checked_u => "as idivmod_checked_u but discarding the second output",
200 idiv_u => "as idivmod_u but discarding the second output",
201 imod_checked_u => "as idivmod_checked_u but discarding the first output",
202 imod_u => "as idivmod_u but discarding the first output",
203 idiv_checked_s => "as idivmod_checked_s but discarding the second output",
204 idiv_s => "as idivmod_s but discarding the second output",
205 imod_checked_s => "as idivmod_checked_s but discarding the first output",
206 imod_s => "as idivmod_s but discarding the first output",
207 ipow => "raise first input to the power of second input, the exponent is treated as an unsigned integer",
208 iabs => "convert signed to unsigned by taking absolute value",
209 iand => "bitwise AND",
210 ior => "bitwise OR",
211 ixor => "bitwise XOR",
212 inot => "bitwise NOT",
213 ishl => "shift first input left by k bits where k is unsigned interpretation of second input \
214 (leftmost bits dropped, rightmost bits set to zero",
215 ishr => "shift first input right by k bits where k is unsigned interpretation of second input \
216 (rightmost bits dropped, leftmost bits set to zero)",
217 irotl => "rotate first input left by k bits where k is unsigned interpretation of second input \
218 (leftmost bits replace rightmost bits)",
219 irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \
220 (rightmost bits replace leftmost bits)",
221 is_to_u => "convert signed to unsigned by taking absolute value",
222 iu_to_s => "convert unsigned to signed by taking absolute value",
223 }.into()
224 }
225
226 fn post_opdef(&self, def: &mut OpDef) {
227 const_fold::set_fold(self, def);
228 }
229}
230
231pub(in crate::std_extensions::arithmetic) fn int_polytype(
233 n_vars: usize,
234 input: impl Into<TypeRowRV>,
235 output: impl Into<TypeRowRV>,
236) -> PolyFuncTypeRV {
237 PolyFuncTypeRV::new(
238 vec![LOG_WIDTH_TYPE_PARAM; n_vars],
239 FuncValueType::new(input, output),
240 )
241}
242
243fn ibinop_sig() -> PolyFuncTypeRV {
244 let int_type_var = int_tv(0);
245
246 int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var])
247}
248
249fn iunop_sig() -> PolyFuncTypeRV {
250 let int_type_var = int_tv(0);
251 int_polytype(1, vec![int_type_var.clone()], vec![int_type_var])
252}
253
254lazy_static! {
255 pub static ref EXTENSION: Arc<Extension> = {
257 Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
258 IntOpDef::load_all_ops(extension, extension_ref).unwrap();
259 })
260 };
261}
262
263impl HasConcrete for IntOpDef {
264 type Concrete = ConcreteIntOp;
265
266 fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
267 let log_widths: Vec<u8> = type_args
268 .iter()
269 .map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
270 .collect::<Result<_, _>>()?;
271 Ok(ConcreteIntOp {
272 def: *self,
273 log_widths,
274 })
275 }
276}
277
278impl HasDef for ConcreteIntOp {
279 type Def = IntOpDef;
280}
281
/// Concrete integer operation: an [`IntOpDef`] together with its log-width
/// type arguments.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConcreteIntOp {
    /// The kind of int op.
    pub def: IntOpDef,
    /// The log-width type arguments of the op, in order.
    // NOTE(review): the op signatures declare two parameters for
    // `iwiden_*`/`inarrow_*` and one for every other op, so this is presumably
    // expected to hold that many entries; it is not checked on construction.
    pub log_widths: Vec<u8>,
}
295
296impl MakeExtensionOp for ConcreteIntOp {
297 fn op_id(&self) -> OpName {
298 self.def.opdef_id()
299 }
300
301 fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
302 let def = IntOpDef::from_def(ext_op.def())?;
303 def.instantiate(ext_op.args())
304 }
305
306 fn type_args(&self) -> Vec<TypeArg> {
307 self.log_widths
308 .iter()
309 .map(|&n| u64::from(n).into())
310 .collect()
311 }
312}
313
314impl MakeRegisteredOp for ConcreteIntOp {
315 fn extension_id(&self) -> ExtensionId {
316 EXTENSION_ID.clone()
317 }
318
319 fn extension_ref(&self) -> Weak<Extension> {
320 Arc::downgrade(&EXTENSION)
321 }
322}
323
324impl IntOpDef {
325 #[must_use]
328 pub fn without_log_width(self) -> ConcreteIntOp {
329 ConcreteIntOp {
330 def: self,
331 log_widths: vec![],
332 }
333 }
334 #[must_use]
337 pub fn with_log_width(self, log_width: u8) -> ConcreteIntOp {
338 ConcreteIntOp {
339 def: self,
340 log_widths: vec![log_width],
341 }
342 }
343 #[must_use]
346 pub fn with_two_log_widths(self, first_log_width: u8, second_log_width: u8) -> ConcreteIntOp {
347 ConcreteIntOp {
348 def: self,
349 log_widths: vec![first_log_width, second_log_width],
350 }
351 }
352}
353
354fn sum_ty_with_err(t: Type) -> Type {
355 sum_with_error(t).into()
356}
357
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::{
        ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type,
        types::Signature,
    };

    use super::*;

    #[test]
    fn test_int_ops_extension() {
        assert_eq!(EXTENSION.name() as &str, "arithmetic.int");
        assert_eq!(EXTENSION.types().count(), 0);
        for (name, _) in EXTENSION.operations() {
            assert!(name.starts_with('i'));
        }
    }

    #[test]
    fn test_binary_signatures() {
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 4)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), int_type(4))
        );
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new_endo(int_type(3))
        );
        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), sum_ty_with_err(int_type(3)))
        );
        assert!(
            IntOpDef::iwiden_u
                .with_two_log_widths(4, 3)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(2, 1)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(2), sum_ty_with_err(int_type(1)))
        );

        assert!(
            IntOpDef::inarrow_u
                .with_two_log_widths(1, 2)
                .to_extension_op()
                .is_none()
        );
    }

    #[rstest]
    #[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
    #[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
    #[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
    #[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
    #[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
    #[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
    #[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
    #[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
    #[should_panic(expected = "too large to be converted to signed")]
    #[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u64::from(u32::MAX)], &[], 5)]
    #[should_panic(expected = "Cannot convert negative integer")]
    #[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[u64::from(0u32.wrapping_sub(42))], &[], 5)]
    fn int_fold(
        #[case] op: ConcreteIntOp,
        #[case] inputs: &[u64],
        #[case] outputs: &[u64],
        #[case] log_width: u8,
    ) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::int_types::ConstInt;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| {
                (
                    i.into(),
                    Value::extension(ConstInt::new_u(log_width, x).unwrap()),
                )
            })
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // Without this, extra or missing folded outputs would go unnoticed.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for (i, &expected) in outputs.iter().enumerate() {
            let res_val: u64 = res
                .get(i)
                .unwrap()
                .1
                .get_custom_value::<ConstInt>()
                .expect("This function assumes all outgoing constants are ConstInts.")
                .value_u();

            assert_eq!(res_val, expected);
        }
    }
}