hugr_core/std_extensions/arithmetic/int_ops.rs

1//! Basic integer operations.
2
3use std::sync::{Arc, LazyLock, Weak};
4
5use super::int_types::{LOG_WIDTH_TYPE_PARAM, get_log_width, int_tv};
6use crate::extension::prelude::{bool_t, sum_with_error};
7use crate::extension::simple_op::{
8    HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9};
10use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs};
11use crate::ops::OpName;
12use crate::ops::custom::ExtensionOp;
13use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV};
14use crate::utils::collect_array;
15
16use crate::{
17    Extension,
18    extension::{ExtensionId, SignatureError},
19    types::{Type, type_param::TypeArg},
20};
21
22use strum::{EnumIter, EnumString, IntoStaticStr};
23
24mod const_fold;
25
/// The identifier of the basic integer operations extension, `arithmetic.int`.
///
/// Used both to build [`EXTENSION`] and as the extension id reported by every
/// [`IntOpDef`] / [`ConcreteIntOp`].
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");
/// Version of the `arithmetic.int` extension, passed to `Extension::new_arc`
/// when [`EXTENSION`] is built.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
30
31struct IOValidator {
32    // whether the first type argument should be greater than or equal to the second
33    f_ge_s: bool,
34}
35
36impl ValidateJustArgs for IOValidator {
37    fn validate(&self, arg_values: &[TypeArg]) -> Result<(), SignatureError> {
38        let [arg0, arg1] = collect_array(arg_values);
39        let i: u8 = get_log_width(arg0)?;
40        let o: u8 = get_log_width(arg1)?;
41        let cmp = if self.f_ge_s { i >= o } else { i <= o };
42        if !cmp {
43            return Err(SignatureError::InvalidTypeArgs);
44        }
45        Ok(())
46    }
47}
/// Integer extension operation definitions.
///
/// Each variant's name is also the name of the operation in the
/// `arithmetic.int` extension (see `MakeOpDef::opdef_id`).
///
/// Naming conventions:
/// - a `_u` / `_s` suffix selects the unsigned / signed interpretation of the
///   integer inputs;
/// - `_checked` division/remainder ops return a sum with an error variant for
///   `m = 0`, whereas their unchecked counterparts panic.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs, non_camel_case_types)]
#[non_exhaustive]
pub enum IntOpDef {
    // Width conversions (two log-width parameters).
    iwiden_u,
    iwiden_s,
    inarrow_u,
    inarrow_s,
    // Comparisons, producing a boolean.
    ieq,
    ine,
    ilt_u,
    ilt_s,
    igt_u,
    igt_s,
    ile_u,
    ile_s,
    ige_u,
    ige_s,
    imax_u,
    imax_s,
    imin_u,
    imin_s,
    // Arithmetic modulo 2^N.
    iadd,
    isub,
    ineg,
    imul,
    // Division and remainder.
    idivmod_checked_u,
    idivmod_u,
    idivmod_checked_s,
    idivmod_s,
    idiv_checked_u,
    idiv_u,
    imod_checked_u,
    imod_u,
    idiv_checked_s,
    idiv_s,
    imod_checked_s,
    imod_s,
    ipow,
    iabs,
    // Bitwise operations, shifts and rotations.
    iand,
    ior,
    ixor,
    inot,
    ishl,
    ishr,
    irotl,
    irotr,
    // Signedness conversions.
    iu_to_s,
    is_to_u,
}
100
101impl MakeOpDef for IntOpDef {
102    fn opdef_id(&self) -> OpName {
103        <&Self as Into<&'static str>>::into(self).into()
104    }
105    fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> {
106        crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
107    }
108
109    fn extension(&self) -> ExtensionId {
110        EXTENSION_ID.clone()
111    }
112
113    fn extension_ref(&self) -> Weak<Extension> {
114        Arc::downgrade(&EXTENSION)
115    }
116
117    fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
118        use IntOpDef::*;
119        let tv0 = int_tv(0);
120        match self {
121            iwiden_s | iwiden_u => CustomValidator::new(
122                int_polytype(2, vec![tv0], vec![int_tv(1)]),
123                IOValidator { f_ge_s: false },
124            )
125            .into(),
126            inarrow_s | inarrow_u => CustomValidator::new(
127                int_polytype(2, tv0, sum_ty_with_err(int_tv(1))),
128                IOValidator { f_ge_s: true },
129            )
130            .into(),
131            ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => {
132                int_polytype(1, vec![tv0; 2], vec![bool_t()]).into()
133            }
134            imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => {
135                ibinop_sig().into()
136            }
137            ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(),
138            idivmod_checked_u | idivmod_checked_s => {
139                let intpair: TypeRowRV = vec![tv0; 2].into();
140                int_polytype(
141                    1,
142                    intpair.clone(),
143                    sum_ty_with_err(Type::new_tuple(intpair)),
144                )
145            }
146            .into(),
147            idivmod_u | idivmod_s => {
148                let intpair: TypeRowRV = vec![tv0; 2].into();
149                int_polytype(1, intpair.clone(), intpair.clone())
150            }
151            .into(),
152            idiv_u | idiv_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
153            idiv_checked_u | idiv_checked_s => {
154                int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
155            }
156            imod_checked_u | imod_checked_s => {
157                int_polytype(1, vec![tv0.clone(); 2], sum_ty_with_err(tv0)).into()
158            }
159            imod_u | imod_s => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
160            ishl | ishr | irotl | irotr => int_polytype(1, vec![tv0.clone(); 2], vec![tv0]).into(),
161        }
162    }
163
164    fn description(&self) -> String {
165        use IntOpDef::*;
166
167        match self {
168            iwiden_u => "widen an unsigned integer to a wider one with the same value",
169            iwiden_s => "widen a signed integer to a wider one with the same value",
170            inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible",
171            inarrow_s => "narrow a signed integer to a narrower one with the same value if possible",
172            ieq => "equality test",
173            ine => "inequality test",
174            ilt_u => "\"less than\" as unsigned integers",
175            ilt_s => "\"less than\" as signed integers",
176            igt_u =>"\"greater than\" as unsigned integers",
177            igt_s => "\"greater than\" as signed integers",
178            ile_u => "\"less than or equal\" as unsigned integers",
179            ile_s => "\"less than or equal\" as signed integers",
180            ige_u => "\"greater than or equal\" as unsigned integers",
181            ige_s => "\"greater than or equal\" as signed integers",
182            imax_u => "maximum of unsigned integers",
183            imax_s => "maximum of signed integers",
184            imin_u => "minimum of unsigned integers",
185            imin_s => "minimum of signed integers",
186            iadd => "addition modulo 2^N (signed and unsigned versions are the same op)",
187            isub => "subtraction modulo 2^N (signed and unsigned versions are the same op)",
188            ineg => "negation modulo 2^N (signed and unsigned versions are the same op)",
189            imul => "multiplication modulo 2^N (signed and unsigned versions are the same op)",
190            idivmod_checked_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
191            q*m+r=n, 0<=r<m (m=0 is an error)",
192            idivmod_u => "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^N, generates unsigned q, r where \
193            q*m+r=n, 0<=r<m (m=0 will call panic)",
194            idivmod_checked_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
195            signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 is an error)",
196            idivmod_s => "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^N, generates \
197            signed q and unsigned r where q*m+r=n, 0<=r<m (m=0 will call panic)",
198            idiv_checked_u => "as idivmod_checked_u but discarding the second output",
199            idiv_u => "as idivmod_u but discarding the second output",
200            imod_checked_u => "as idivmod_checked_u but discarding the first output",
201            imod_u => "as idivmod_u but discarding the first output",
202            idiv_checked_s => "as idivmod_checked_s but discarding the second output",
203            idiv_s => "as idivmod_s but discarding the second output",
204            imod_checked_s => "as idivmod_checked_s but discarding the first output",
205            imod_s => "as idivmod_s but discarding the first output",
206            ipow => "raise first input to the power of second input, the exponent is treated as an unsigned integer",
207            iabs => "convert signed to unsigned by taking absolute value",
208            iand => "bitwise AND",
209            ior => "bitwise OR",
210            ixor => "bitwise XOR",
211            inot => "bitwise NOT",
212            ishl => "shift first input left by k bits where k is unsigned interpretation of second input \
213            (leftmost bits dropped, rightmost bits set to zero",
214            ishr => "shift first input right by k bits where k is unsigned interpretation of second input \
215            (rightmost bits dropped, leftmost bits set to zero)",
216            irotl => "rotate first input left by k bits where k is unsigned interpretation of second input \
217            (leftmost bits replace rightmost bits)",
218            irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \
219            (rightmost bits replace leftmost bits)",
220            is_to_u => "convert signed to unsigned by taking absolute value",
221            iu_to_s => "convert unsigned to signed by taking absolute value",
222        }.into()
223    }
224
225    fn post_opdef(&self, def: &mut OpDef) {
226        const_fold::set_fold(self, def);
227    }
228}
229
230/// Returns a polytype composed by a function type, and a number of integer width type parameters.
231pub(in crate::std_extensions::arithmetic) fn int_polytype(
232    n_vars: usize,
233    input: impl Into<TypeRowRV>,
234    output: impl Into<TypeRowRV>,
235) -> PolyFuncTypeRV {
236    PolyFuncTypeRV::new(
237        vec![LOG_WIDTH_TYPE_PARAM; n_vars],
238        FuncValueType::new(input, output),
239    )
240}
241
242fn ibinop_sig() -> PolyFuncTypeRV {
243    let int_type_var = int_tv(0);
244
245    int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var])
246}
247
248fn iunop_sig() -> PolyFuncTypeRV {
249    let int_type_var = int_tv(0);
250    int_polytype(1, vec![int_type_var.clone()], vec![int_type_var])
251}
252
253/// Extension for basic integer operations.
254pub static EXTENSION: LazyLock<Arc<Extension>> = LazyLock::new(|| {
255    Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
256        IntOpDef::load_all_ops(extension, extension_ref).unwrap();
257    })
258});
259
260impl HasConcrete for IntOpDef {
261    type Concrete = ConcreteIntOp;
262
263    fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
264        let log_widths: Vec<u8> = type_args
265            .iter()
266            .map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs))
267            .collect::<Result<_, _>>()?;
268        Ok(ConcreteIntOp {
269            def: *self,
270            log_widths,
271        })
272    }
273}
274
275impl HasDef for ConcreteIntOp {
276    type Def = IntOpDef;
277}
278
279/// Concrete integer operation with integer widths set.
280#[derive(Debug, Clone, PartialEq)]
281#[non_exhaustive]
282pub struct ConcreteIntOp {
283    /// The kind of int op.
284    pub def: IntOpDef,
285    /// The width parameters of the int op. These are interpreted differently,
286    /// depending on `def`. The types of inputs and outputs of the op will have
287    /// [`int_type`]s of these widths.
288    ///
289    /// [`int_type`]: crate::std_extensions::arithmetic::int_types::int_type
290    pub log_widths: Vec<u8>,
291}
292
293impl MakeExtensionOp for ConcreteIntOp {
294    fn op_id(&self) -> OpName {
295        self.def.opdef_id()
296    }
297
298    fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
299        let def = IntOpDef::from_def(ext_op.def())?;
300        def.instantiate(ext_op.args())
301    }
302
303    fn type_args(&self) -> Vec<TypeArg> {
304        self.log_widths
305            .iter()
306            .map(|&n| u64::from(n).into())
307            .collect()
308    }
309}
310
311impl MakeRegisteredOp for ConcreteIntOp {
312    fn extension_id(&self) -> ExtensionId {
313        EXTENSION_ID.clone()
314    }
315
316    fn extension_ref(&self) -> Weak<Extension> {
317        Arc::downgrade(&EXTENSION)
318    }
319}
320
321impl IntOpDef {
322    /// Initialize a [`ConcreteIntOp`] from a [`IntOpDef`] which requires no
323    /// integer widths set.
324    #[must_use]
325    pub fn without_log_width(self) -> ConcreteIntOp {
326        ConcreteIntOp {
327            def: self,
328            log_widths: vec![],
329        }
330    }
331    /// Initialize a [`ConcreteIntOp`] from a [`IntOpDef`] which requires one
332    /// integer width set.
333    #[must_use]
334    pub fn with_log_width(self, log_width: u8) -> ConcreteIntOp {
335        ConcreteIntOp {
336            def: self,
337            log_widths: vec![log_width],
338        }
339    }
340    /// Initialize a [`ConcreteIntOp`] from a [`IntOpDef`] which requires two
341    /// integer widths set.
342    #[must_use]
343    pub fn with_two_log_widths(self, first_log_width: u8, second_log_width: u8) -> ConcreteIntOp {
344        ConcreteIntOp {
345            def: self,
346            log_widths: vec![first_log_width, second_log_width],
347        }
348    }
349}
350
351fn sum_ty_with_err(t: Type) -> Type {
352    sum_with_error(t).into()
353}
354
#[cfg(test)]
mod test {
    use rstest::rstest;

    use crate::{
        ops::dataflow::DataflowOpTrait, std_extensions::arithmetic::int_types::int_type,
        types::Signature,
    };

    use super::*;

    #[test]
    fn test_int_ops_extension() {
        assert_eq!(EXTENSION.name() as &str, "arithmetic.int");
        assert_eq!(EXTENSION.types().count(), 0);
        for (name, _) in EXTENSION.operations() {
            assert!(name.starts_with('i'));
        }
    }

    #[test]
    fn test_binary_signatures() {
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 4)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), int_type(4))
        );
        assert_eq!(
            IntOpDef::iwiden_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new_endo(int_type(3))
        );
        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(3, 3)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(3), sum_ty_with_err(int_type(3)))
        );
        assert!(
            IntOpDef::iwiden_u
                .with_two_log_widths(4, 3)
                .to_extension_op()
                .is_none(),
            "type arguments invalid"
        );

        assert_eq!(
            IntOpDef::inarrow_s
                .with_two_log_widths(2, 1)
                .to_extension_op()
                .unwrap()
                .signature()
                .as_ref(),
            &Signature::new(int_type(2), sum_ty_with_err(int_type(1)))
        );

        assert!(
            IntOpDef::inarrow_u
                .with_two_log_widths(1, 2)
                .to_extension_op()
                .is_none()
        );
    }

    #[rstest]
    #[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)]
    #[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)]
    #[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)]
    #[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)]
    #[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)]
    #[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)]
    #[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)]
    #[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)]
    #[should_panic(expected = "too large to be converted to signed")]
    #[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u64::from(u32::MAX)], &[], 5)]
    #[should_panic(expected = "Cannot convert negative integer")]
    #[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[u64::from(0u32.wrapping_sub(42))], &[], 5)]
    fn int_fold(
        #[case] op: ConcreteIntOp,
        #[case] inputs: &[u64],
        #[case] outputs: &[u64],
        #[case] log_width: u8,
    ) {
        use crate::ops::Value;
        use crate::std_extensions::arithmetic::int_types::ConstInt;

        let consts: Vec<_> = inputs
            .iter()
            .enumerate()
            .map(|(i, &x)| {
                (
                    i.into(),
                    Value::extension(ConstInt::new_u(log_width, x).unwrap()),
                )
            })
            .collect();

        let res = op
            .to_extension_op()
            .unwrap()
            .constant_fold(&consts)
            .unwrap();

        // A fold producing extra or missing outputs is as wrong as one
        // producing a wrong value, so check the count before the values.
        assert_eq!(
            res.len(),
            outputs.len(),
            "unexpected number of folded outputs"
        );

        for (i, &expected) in outputs.iter().enumerate() {
            let res_val: u64 = res
                .get(i)
                .unwrap()
                .1
                .get_custom_value::<ConstInt>()
                .expect("This function assumes all folded outputs are integer constants.")
                .value_u();

            assert_eq!(res_val, expected);
        }
    }
}