hugr_core/std_extensions/arithmetic/
int_ops.rs

1//! Basic integer operations.
2
3use std::sync::{Arc, Weak};
4
5use super::int_types::{LOG_WIDTH_TYPE_PARAM, get_log_width, int_tv};
6use crate::extension::prelude::{bool_t, sum_with_error};
7use crate::extension::simple_op::{
8    HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9};
10use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs};
11use crate::ops::OpName;
12use crate::ops::custom::ExtensionOp;
13use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV};
14use crate::utils::collect_array;
15
16use crate::{
17    Extension,
18    extension::{ExtensionId, SignatureError},
19    types::{Type, type_param::TypeArg},
20};
21
22use lazy_static::lazy_static;
23use strum::{EnumIter, EnumString, IntoStaticStr};
24
25mod const_fold;
26
27/// The extension identifier.
28pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");
29/// Extension version.
30pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
31
32struct IOValidator {
33    // whether the first type argument should be greater than or equal to the second
34    f_ge_s: bool,
35}
36
37impl ValidateJustArgs for IOValidator {
38    fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError> {
39        let [arg0, arg1] = collect_array(arg_values);
40        let i: u8 = get_log_width(arg0)?;
41        let o: u8 = get_log_width(arg1)?;
42        let cmp = if self.f_ge_s { i >= o } else { i <= o };
43        if !cmp {
44            return Err(SignatureError::InvalidTypeArgs);
45        }
46        Ok(())
47    }
48}
49/// Integer extension operation definitions.
50#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
51#[allow(missing_docs, non_camel_case_types)]
52#[non_exhaustive]
53pub enum IntOpDef {
54    iwiden_u,
55    iwiden_s,
56    inarrow_u,
57    inarrow_s,
58    ieq,
59    ine,
60    ilt_u,
61    ilt_s,
62    igt_u,
63    igt_s,
64    ile_u,
65    ile_s,
66    ige_u,
67    ige_s,
68    imax_u,
69    imax_s,
70    imin_u,
71    imin_s,
72    iadd,
73    isub,
74    ineg,
75    imul,
76    idivmod_checked_u,
77    idivmod_u,
78    idivmod_checked_s,
79    idivmod_s,
80    idiv_checked_u,
81    idiv_u,
82    imod_checked_u,
83    imod_u,
84    idiv_checked_s,
85    idiv_s,
86    imod_checked_s,
87    imod_s,
88    ipow,
89    iabs,
90    iand,
91    ior,
92    ixor,
93    inot,
94    ishl,
95    ishr,
96    irotl,
97    irotr,
98    iu_to_s,
99    is_to_u,
100}
101
102impl MakeOpDef for IntOpDef {
103    fn opdef_id(&self) -> OpName {
104        <&Self as Into<&'static str>>::into(self).into()
105    }
106    fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> {
107        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
108    }
109
110    fn extension(&self) -> ExtensionId {
111        EXTENSION_ID.clone()
112    }
113
114    fn extension_ref(&self) -> Weak<Extension> {
115        Arc::downgrade(&EXTENSION)
116    }
117
118    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
119        use IntOpDef::*;
120        let tv0 = int_tv(0);
121        match self {
122            iwiden_s | iwiden_u => CustomValidator::new(
123                int_polytype(2, vec![tv0], vec![int_tv(1)]),
124                IOValidator { f_ge_s: false },
125            )
126            .into(),
127            inarrow_s | inarrow_u => CustomValidator::new(
128                int_polytype(2, tv0, sum_ty_with_err(int_tv(1))),
129                IOValidator { f_ge_s: true },
130            )
131            .into(),
132            ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => {
133                int_polytype(1, vec![tv0; 2], vec![bool_t()]).into()
134            }
135            imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => {
136                ibinop_sig().into()
137            }
138            ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(),
139            idivmod_checked_u | idivmod_checked_s => {
140                let intpair: TypeRowRV = vec![tv0; 2].into();
141                int_polytype(
142                    1,
143                    intpair.clone(),
144                    sum_ty_with_err(Type::new_tuple(intpair)),
145                )
146            }
147            .into(),
148            idivmod_u | idivmod_s => {
149                let intpair: TypeRowRV = vec![tv0; 2].into();
150                int_polytype(1, intpair.clone(), intpair.clone())
151            }
152            .into(),
153            idiv_u | idiv_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
154            idiv_checked_u | idiv_checked_s => {
155                int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
156            }
157            imod_checked_u | imod_checked_s => {
158                int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
159            }
160            imod_u | imod_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
161            ishl | ishr | irotl | irotr => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
162        }
163    }
164
165    fn description(&self) -> String {
166        use IntOpDef::*;
167
168        match self {
169            iwiden_u => "widen an unsigned integer to a wider one with the same value",
170            iwiden_s => "widen a signed integer to a wider one with the same value",
171            inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible",
172            inarrow_s => "narrow a signed integer to a narrower one with the same value if possible",
173            ieq => "equality test",
174            ine => "inequality test",
175            ilt_u => "\"less than\" as unsigned integers",
176            ilt_s => "\"less than\" as signed integers",
177            igt_u =>"\"greater than\" as unsigned integers",
178            igt_s => "\"greater than\" as signed integers",
179            ile_u => "\"less than or equal\" as unsigned integers",
180            ile_s => "\"less than or equal\" as signed integers",
181            ige_u => "\"greater than or equal\" as unsigned integers",
182            ige_s => "\"greater than or equal\" as signed integers",
183            imax_u => "maximum of unsigned integers",
184            imax_s => "maximum of signed integers",
185            imin_u => "minimum of unsigned integers",
186            imin_s => "minimum of signed integers",
187            iadd => "addition modulo 2^N (signed and unsigned versions are the same op)",
188            isub => "subtraction modulo 2^N (signed and unsigned versions are the same op)",
189            ineg => "negation modulo 2^N (signed and unsigned versions are the same op)",
190            imul => "multiplication modulo 2^N (signed and unsigned versions are the same op)",
191            idivmod_checked_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
192            q*m+r=n, 0<=r<m (m=0 is an error)",
193            idivmod_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
194            q*m+r=n, 0<=r<m (m=0 will call panic)",
195            idivmod_checked_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
196            signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 is an error)",
197            idivmod_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
198            signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 will call panic)",
199            idiv_checked_u => "as idivmod_checked_u but discarding the second output",
200            idiv_u => "as idivmod_u but discarding the second output",
201            imod_checked_u => "as idivmod_checked_u but discarding the first output",
202            imod_u => "as idivmod_u but discarding the first output",
203            idiv_checked_s => "as idivmod_checked_s but discarding the second output",
204            idiv_s => "as idivmod_s but discarding the second output",
205            imod_checked_s => "as idivmod_checked_s but discarding the first output",
206            imod_s => "as idivmod_s but discarding the first output",
207            ipow => "raise first input to the power of second input, the exponent is treated as an unsigned integer",
208            iabs => "convert signed to unsigned by taking absolute value",
209            iand => "bitwise AND",
210            ior => "bitwise OR",
211            ixor => "bitwise XOR",
212            inot => "bitwise NOT",
213            ishl => "shift first input left by k bits where k is unsigned interpretation of second input \
214            (leftmost bits dropped, rightmost bits set to zero",
215            ishr => "shift first input right by k bits where k is unsigned interpretation of second input \
216            (rightmost bits dropped, leftmost bits set to zero)",
217            irotl => "rotate first input left by k bits where k is unsigned interpretation of second input \
218            (leftmost bits replace rightmost bits)",
219            irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \
220            (rightmost bits replace leftmost bits)",
221            is_to_u => "convert signed to unsigned by taking absolute value",
222            iu_to_s => "convert unsigned to signed by taking absolute value",
223        }.into()
224    }
225
226    fn post_opdef(&self, def: &mut OpDef) {
227        const_fold::set_fold(self, def);
228    }
229}
230
231/// Returns a polytype composed by a function type, and a number of integer width type parameters.
232pub(in crate::std_extensions::arithmetic) fn int_polytype(
233    n_vars: usize,
234    input: impl Into<TypeRowRV>,
235    output: impl Into<TypeRowRV>,
236) -> PolyFuncTypeRV {
237    PolyFuncTypeRV::new(
238        vec![LOG_WIDTH_TYPE_PARAM; n_vars],
239        FuncValueType::new(input, output),
240    )
241}
242
243fn ibinop_sig() -> PolyFuncTypeRV {
244    let int_type_var = int_tv(0);
245
246    int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var])
247}
248
249fn iunop_sig() -> PolyFuncTypeRV {
250    let int_type_var = int_tv(0);
251    int_polytype(1, vec![int_type_var.clone()], vec![int_type_var])
252}
253
254lazy_static! {
255    /// Extension for basic integer operations.
256    pub static ref EXTENSION: Arc<Extension> = {
257        Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
258            IntOpDef::load_all_ops(extension, extension_ref).unwrap();
259        })
260    };
261}
262
263impl HasConcrete for IntOpDef {
264    type Concrete = ConcreteIntOp;
265
266    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
267        let log_widths: Vec<u8> = type_args
268            .iter()
269            .map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
270            .collect::<Result<_, _>>()?;
271        Ok(ConcreteIntOp {
272            def: *self,
273            log_widths,
274        })
275    }
276}
277
278impl HasDef for ConcreteIntOp {
279    type Def = IntOpDef;
280}
281
282/// Concrete integer operation with integer widths set.
283#[derive(Debug, Clone, PartialEq)]
284#[non_exhaustive]
285pub struct ConcreteIntOp {
286    /// The kind of int op.
287    pub def: IntOpDef,
288    /// The width parameters of the int op. These are interpreted differently,
289    /// depending on `def`. The types of inputs and outputs of the op will have
290    /// [`int_type`]s of these widths.
291    ///
292    /// [`int_type`]: crate::std_extensions::arithmetic::int_types::int_type
293    pub log_widths: Vec<u8>,
294}
295
296impl MakeExtensionOp for ConcreteIntOp {
297    fn op_id(&self) -> OpName {
298        self.def.opdef_id()
299    }
300
301    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
302        let def = IntOpDef::from_def(ext_op.def())?;
303        def.instantiate(ext_op.args())
304    }
305
306    fn type_args(&self) -> Vec<TypeArg> {
307        self.log_widths
308            .iter()
309            .map(|&n| u64::from(n).into())
310            .collect()
311    }
312}
313
314impl MakeRegisteredOp for ConcreteIntOp {
315    fn extension_id(&self) -> ExtensionId {
316        EXTENSION_ID.clone()
317    }
318
319    fn extension_ref(&self) -> Weak<Extension> {
320        Arc::downgrade(&EXTENSION)
321    }
322}
323
324impl IntOpDef {
325    /// Initialize a [`ConcreteIntOp`] from a [`IntOpDef`] which requires no
326    /// integer widths set.
327    #[must_use]
328    pub fn without_log_width(self) -> ConcreteIntOp {
329        ConcreteIntOp {
330            def: self,
331            log_widths: vec![],
332        }
333    }
334    /// Initialize a [`ConcreteIntOp`] from a [`IntOpDef`] which requires one
335    /// integer width set.
336    #[must_use]
337    pub fn with_log_width(self, log_width: u8) -> ConcreteIntOp {
338        ConcreteIntOp {
339            def: self,
340            log_widths: vec![log_width],
341        }
342    }
343    /// Initialize a [`ConcreteIntOp`] from a [`IntOpDef`] which requires two
344    /// integer widths set.
345    #[must_use]
346    pub fn with_two_log_widths(self, first_log_width: u8, second_log_width: u8) -> ConcreteIntOp {
347        ConcreteIntOp {
348            def: self,
349            log_widths: vec![first_log_width, second_log_width],
350        }
351    }
352}
353
354fn sum_ty_with_err(t: Type) -> Type {
355    sum_with_error(t).into()
356}
357
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::{
        ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type,
        types::Signature,
    };

    use super::*;

    #[test]
    fn test_int_ops_extension() {
        assert_eq!(EXTENSION.name() as &str, "arithmetic.int");
        assert_eq!(EXTENSION.types().count(), 0);
        for (name, _) in EXTENSION.operations() {
            assert!(name.starts_with('i'));
        }
    }

    #[test]
    fn test_binary_signatures() {
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 4)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), int_type(4))
        );
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new_endo(int_type(3))
        );
        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), sum_ty_with_err(int_type(3)))
        );
        assert!(
            IntOpDef::iwiden_u
                .with_two_log_widths(4, 3)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(2, 1)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(2), sum_ty_with_err(int_type(1)))
        );

        assert!(
            IntOpDef::inarrow_u
                .with_two_log_widths(1, 2)
                .to_extension_op()
                .is_none()
        );
    }

    #[rstest]
    #[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
    #[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
    #[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
    #[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
    #[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
    #[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
    #[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
    #[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
    #[should_panic(expected = "too large to be converted to signed")]
    #[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u64::from(u32::MAX)], &[], 5)]
    #[should_panic(expected = "Cannot convert negative integer")]
    #[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[u64::from(0u32.wrapping_sub(42))], &[], 5)]
    fn int_fold(
        #[case] op: ConcreteIntOp,
        #[case] inputs: &[u64],
        #[case] outputs: &[u64],
        #[case] log_width: u8,
    ) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::int_types::ConstInt;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| {
                (
                    i.into(),
                    Value::extension(ConstInt::new_u(log_width, x).unwrap()),
                )
            })
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        for (i, &expected) in outputs.iter().enumerate() {
            let res_val: u64 = res
                .get(i)
                .unwrap()
                .1
                .get_custom_value::<ConstInt>()
                // The folded outputs of integer ops are integer constants.
                .expect("This function assumes all folded outputs are integer constants.")
                .value_u();

            assert_eq!(res_val, expected);
        }
    }
}