// midenc_dialect_arith/ops/binary.rs

1use midenc_hir::{derive::operation, effects::*, traits::*, *};
2
3use crate::ArithDialect;
4
// TODO: replace this macro with `derive(InferTypeOpInterface)` and an `#[infer]` helper
// attribute:
//
// * `#[infer]` on a result field indicates its type should be inferred from the type of the first
//   operand field
// * `#[infer(from = field)]` on a result field indicates its type should be inferred from the
//   given field. The field is expected to implement `AsRef<Type>`
// * `#[infer(type = I1)]` on a field indicates that the field should always be inferred to have
//   the given type
// * `#[infer(with = path::to::function)]` on a field indicates that the given function should be
//   called to compute the inferred type for that field

/// Implements `InferTypeOpInterface` for a binary operation.
///
/// Supported forms:
///
/// * `infer_return_ty_for_binary_op!(Op)` - `result` takes the type of the `lhs` operand.
/// * `infer_return_ty_for_binary_op!(Op as TYPE)` - `result` is always assigned `TYPE`.
/// * `infer_return_ty_for_binary_op!(Op, field: TYPE, ...)` - `result` takes the type of `lhs`,
///   and every listed `field` is assigned its `TYPE` through its `<field>_mut()` accessor.
///
/// The op must expose the accessors each form uses (`lhs()`, `result_mut()`, `<field>_mut()`);
/// presumably these are the ones generated by `#[operation]`. The names `InferTypeOpInterface`,
/// `Context` and `Report` are resolved at the invocation site, so they must be in scope there.
macro_rules! infer_return_ty_for_binary_op {
    ($Op:ty) => {
        impl InferTypeOpInterface for $Op {
            fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> {
                let lhs = self.lhs().ty().clone();
                self.result_mut().set_type(lhs);
                Ok(())
            }
        }
    };

    // No identifier pasting happens here, so no `paste::paste!` wrapper is needed.
    ($Op:ty as $result_ty:expr) => {
        impl InferTypeOpInterface for $Op {
            fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> {
                self.result_mut().set_type($result_ty);
                Ok(())
            }
        }
    };

    // `paste` is required here to build the `<field>_mut` accessor names.
    ($Op:ty, $($field:ident : $field_ty:expr),+) => {
        paste::paste! {
            impl InferTypeOpInterface for $Op {
                fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> {
                    let lhs = self.lhs().ty().clone();
                    self.result_mut().set_type(lhs);
                    $(
                        self.[<$field _mut>]().set_type($field_ty);
                    )*
                    Ok(())
                }
            }
        }
    };
}

54/// Two's complement sum
55#[operation(
56    dialect = ArithDialect,
57    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
58    implements(InferTypeOpInterface, MemoryEffectOpInterface)
59)]
60pub struct Add {
61    #[operand]
62    lhs: AnyInteger,
63    #[operand]
64    rhs: AnyInteger,
65    #[result]
66    result: AnyInteger,
67    #[attr]
68    overflow: Overflow,
69}
70
71infer_return_ty_for_binary_op!(Add);
72has_no_effects!(Add);
73
74/// Two's complement sum with overflow bit
75#[operation(
76    dialect = ArithDialect,
77    traits(BinaryOp, Commutative, SameTypeOperands),
78    implements(InferTypeOpInterface, MemoryEffectOpInterface)
79)]
80pub struct AddOverflowing {
81    #[operand]
82    lhs: AnyInteger,
83    #[operand]
84    rhs: AnyInteger,
85    #[result]
86    overflowed: Bool,
87    #[result]
88    result: AnyInteger,
89}
90
91infer_return_ty_for_binary_op!(AddOverflowing, overflowed: Type::I1);
92has_no_effects!(AddOverflowing);
93
94/// Two's complement difference (subtraction)
95#[operation(
96    dialect = ArithDialect,
97    traits(BinaryOp, SameTypeOperands),
98    implements(InferTypeOpInterface, MemoryEffectOpInterface)
99)]
100pub struct Sub {
101    #[operand]
102    lhs: AnyInteger,
103    #[operand]
104    rhs: AnyInteger,
105    #[result]
106    result: AnyInteger,
107    #[attr]
108    overflow: Overflow,
109}
110
111infer_return_ty_for_binary_op!(Sub);
112has_no_effects!(Sub);
113
114/// Two's complement difference (subtraction) with underflow bit
115#[operation(
116    dialect = ArithDialect,
117    traits(BinaryOp, SameTypeOperands),
118    implements(InferTypeOpInterface, MemoryEffectOpInterface)
119)]
120pub struct SubOverflowing {
121    #[operand]
122    lhs: AnyInteger,
123    #[operand]
124    rhs: AnyInteger,
125    #[result]
126    overflowed: Bool,
127    #[result]
128    result: AnyInteger,
129}
130
131infer_return_ty_for_binary_op!(SubOverflowing, overflowed: Type::I1);
132has_no_effects!(SubOverflowing);
133
134/// Two's complement product
135#[operation(
136    dialect = ArithDialect,
137    traits(BinaryOp, Commutative, SameTypeOperands),
138    implements(InferTypeOpInterface, MemoryEffectOpInterface)
139)]
140pub struct Mul {
141    #[operand]
142    lhs: AnyInteger,
143    #[operand]
144    rhs: AnyInteger,
145    #[result]
146    result: AnyInteger,
147    #[attr]
148    overflow: Overflow,
149}
150
151infer_return_ty_for_binary_op!(Mul);
152has_no_effects!(Mul);
153
154/// Two's complement product with overflow bit
155#[operation(
156    dialect = ArithDialect,
157    traits(BinaryOp, Commutative, SameTypeOperands),
158    implements(InferTypeOpInterface, MemoryEffectOpInterface)
159)]
160pub struct MulOverflowing {
161    #[operand]
162    lhs: AnyInteger,
163    #[operand]
164    rhs: AnyInteger,
165    #[result]
166    overflowed: Bool,
167    #[result]
168    result: AnyInteger,
169}
170
171infer_return_ty_for_binary_op!(MulOverflowing, overflowed: Type::I1);
172has_no_effects!(MulOverflowing);
173
174/// Exponentiation for field elements
175#[operation(
176    dialect = ArithDialect,
177    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
178    implements(InferTypeOpInterface, MemoryEffectOpInterface)
179)]
180pub struct Exp {
181    #[operand]
182    lhs: IntFelt,
183    #[operand]
184    rhs: IntFelt,
185    #[result]
186    result: IntFelt,
187}
188
189infer_return_ty_for_binary_op!(Exp);
190has_no_effects!(Exp);
191
192/// Unsigned integer division, traps on division by zero
193#[operation(
194    dialect = ArithDialect,
195    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
196    implements(InferTypeOpInterface, MemoryEffectOpInterface)
197)]
198pub struct Div {
199    #[operand]
200    lhs: AnyInteger,
201    #[operand]
202    rhs: AnyInteger,
203    #[result]
204    result: AnyInteger,
205}
206
207infer_return_ty_for_binary_op!(Div);
208has_no_effects!(Div);
209
210/// Signed integer division, traps on division by zero or dividing the minimum signed value by -1
211#[operation(
212    dialect = ArithDialect,
213    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
214    implements(InferTypeOpInterface, MemoryEffectOpInterface)
215)]
216pub struct Sdiv {
217    #[operand]
218    lhs: AnyInteger,
219    #[operand]
220    rhs: AnyInteger,
221    #[result]
222    result: AnyInteger,
223}
224
225infer_return_ty_for_binary_op!(Sdiv);
226has_no_effects!(Sdiv);
227
228/// Unsigned integer Euclidean modulo, traps on division by zero
229#[operation(
230    dialect = ArithDialect,
231    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
232    implements(InferTypeOpInterface, MemoryEffectOpInterface)
233)]
234pub struct Mod {
235    #[operand]
236    lhs: AnyInteger,
237    #[operand]
238    rhs: AnyInteger,
239    #[result]
240    result: AnyInteger,
241}
242
243infer_return_ty_for_binary_op!(Mod);
244has_no_effects!(Mod);
245
246/// Signed integer Euclidean modulo, traps on division by zero
247///
248/// The result has the same sign as the dividend (lhs)
249#[operation(
250    dialect = ArithDialect,
251    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
252    implements(InferTypeOpInterface, MemoryEffectOpInterface)
253)]
254pub struct Smod {
255    #[operand]
256    lhs: AnyInteger,
257    #[operand]
258    rhs: AnyInteger,
259    #[result]
260    result: AnyInteger,
261}
262
263infer_return_ty_for_binary_op!(Smod);
264has_no_effects!(Smod);
265
266/// Combined unsigned integer Euclidean division and remainder (modulo).
267///
268/// Traps on division by zero.
269#[operation(
270    dialect = ArithDialect,
271    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
272    implements(InferTypeOpInterface, MemoryEffectOpInterface)
273)]
274pub struct Divmod {
275    #[operand]
276    lhs: AnyInteger,
277    #[operand]
278    rhs: AnyInteger,
279    #[result]
280    remainder: AnyInteger,
281    #[result]
282    quotient: AnyInteger,
283}
284
285has_no_effects!(Divmod);
286
287impl InferTypeOpInterface for Divmod {
288    fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> {
289        let lhs = self.lhs().ty().clone();
290        self.remainder_mut().set_type(lhs.clone());
291        self.quotient_mut().set_type(lhs);
292        Ok(())
293    }
294}
295
296/// Combined signed integer Euclidean division and remainder (modulo).
297///
298/// Traps on division by zero.
299///
300/// The remainder has the same sign as the dividend (lhs)
301#[operation(
302    dialect = ArithDialect,
303    traits(BinaryOp, SameTypeOperands, SameOperandsAndResultType),
304    implements(InferTypeOpInterface, MemoryEffectOpInterface)
305)]
306pub struct Sdivmod {
307    #[operand]
308    lhs: AnyInteger,
309    #[operand]
310    rhs: AnyInteger,
311    #[result]
312    remainder: AnyInteger,
313    #[result]
314    quotient: AnyInteger,
315}
316
317has_no_effects!(Sdivmod);
318
319impl InferTypeOpInterface for Sdivmod {
320    fn infer_return_types(&mut self, _context: &Context) -> Result<(), Report> {
321        let lhs = self.lhs().ty().clone();
322        self.remainder_mut().set_type(lhs.clone());
323        self.quotient_mut().set_type(lhs);
324        Ok(())
325    }
326}
327
328/// Logical AND
329///
330/// Operands must be boolean.
331#[operation(
332    dialect = ArithDialect,
333    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
334    implements(InferTypeOpInterface, MemoryEffectOpInterface)
335)]
336pub struct And {
337    #[operand]
338    lhs: Bool,
339    #[operand]
340    rhs: Bool,
341    #[result]
342    result: Bool,
343}
344
345infer_return_ty_for_binary_op!(And);
346has_no_effects!(And);
347
348/// Logical OR
349///
350/// Operands must be boolean.
351#[operation(
352    dialect = ArithDialect,
353    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
354    implements(InferTypeOpInterface, MemoryEffectOpInterface)
355)]
356pub struct Or {
357    #[operand]
358    lhs: Bool,
359    #[operand]
360    rhs: Bool,
361    #[result]
362    result: Bool,
363}
364
365infer_return_ty_for_binary_op!(Or);
366has_no_effects!(Or);
367
368/// Logical XOR
369///
370/// Operands must be boolean.
371#[operation(
372    dialect = ArithDialect,
373    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
374    implements(InferTypeOpInterface, MemoryEffectOpInterface)
375)]
376pub struct Xor {
377    #[operand]
378    lhs: Bool,
379    #[operand]
380    rhs: Bool,
381    #[result]
382    result: Bool,
383}
384
385infer_return_ty_for_binary_op!(Xor);
386has_no_effects!(Xor);
387
388/// Bitwise AND
389#[operation(
390    dialect = ArithDialect,
391    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
392    implements(InferTypeOpInterface, MemoryEffectOpInterface)
393)]
394pub struct Band {
395    #[operand]
396    lhs: AnyInteger,
397    #[operand]
398    rhs: AnyInteger,
399    #[result]
400    result: AnyInteger,
401}
402
403infer_return_ty_for_binary_op!(Band);
404has_no_effects!(Band);
405
406/// Bitwise OR
407#[operation(
408    dialect = ArithDialect,
409    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
410    implements(InferTypeOpInterface, MemoryEffectOpInterface)
411)]
412pub struct Bor {
413    #[operand]
414    lhs: AnyInteger,
415    #[operand]
416    rhs: AnyInteger,
417    #[result]
418    result: AnyInteger,
419}
420
421infer_return_ty_for_binary_op!(Bor);
422has_no_effects!(Bor);
423
424/// Bitwise XOR
425///
426/// Operands must be boolean.
427#[operation(
428    dialect = ArithDialect,
429    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
430    implements(InferTypeOpInterface, MemoryEffectOpInterface)
431)]
432pub struct Bxor {
433    #[operand]
434    lhs: AnyInteger,
435    #[operand]
436    rhs: AnyInteger,
437    #[result]
438    result: AnyInteger,
439}
440
441infer_return_ty_for_binary_op!(Bxor);
442has_no_effects!(Bxor);
443
444/// Bitwise shift-left
445///
446/// Shifts larger than the bitwidth of the value will be wrapped to zero.
447#[operation(
448    dialect = ArithDialect,
449    traits(BinaryOp),
450    implements(InferTypeOpInterface, MemoryEffectOpInterface)
451)]
452pub struct Shl {
453    #[operand]
454    lhs: AnyInteger,
455    #[operand]
456    shift: UInt32,
457    #[result]
458    result: AnyInteger,
459}
460
461infer_return_ty_for_binary_op!(Shl);
462has_no_effects!(Shl);
463
464/// Bitwise (logical) shift-right
465///
466/// Shifts larger than the bitwidth of the value will effectively truncate the value to zero.
467#[operation(
468    dialect = ArithDialect,
469    traits(BinaryOp),
470    implements(InferTypeOpInterface, MemoryEffectOpInterface)
471)]
472pub struct Shr {
473    #[operand]
474    lhs: AnyInteger,
475    #[operand]
476    shift: UInt32,
477    #[result]
478    result: AnyInteger,
479}
480
481infer_return_ty_for_binary_op!(Shr);
482has_no_effects!(Shr);
483
484/// Arithmetic (signed) shift-right
485///
486/// The result of shifts larger than the bitwidth of the value depend on the sign of the value;
487/// for positive values, it rounds to zero; for negative values, it rounds to MIN.
488#[operation(
489    dialect = ArithDialect,
490    traits(BinaryOp),
491    implements(InferTypeOpInterface, MemoryEffectOpInterface)
492)]
493pub struct Ashr {
494    #[operand]
495    lhs: AnyInteger,
496    #[operand]
497    shift: UInt32,
498    #[result]
499    result: AnyInteger,
500}
501
502infer_return_ty_for_binary_op!(Ashr);
503has_no_effects!(Ashr);
504
505/// Bitwise rotate-left
506///
507/// The rotation count must be < the bitwidth of the value type.
508#[operation(
509    dialect = ArithDialect,
510    traits(BinaryOp),
511    implements(InferTypeOpInterface, MemoryEffectOpInterface)
512)]
513pub struct Rotl {
514    #[operand]
515    lhs: AnyInteger,
516    #[operand]
517    shift: UInt32,
518    #[result]
519    result: AnyInteger,
520}
521
522infer_return_ty_for_binary_op!(Rotl);
523has_no_effects!(Rotl);
524
525/// Bitwise rotate-right
526///
527/// The rotation count must be < the bitwidth of the value type.
528#[operation(
529    dialect = ArithDialect,
530    traits(BinaryOp),
531    implements(InferTypeOpInterface, MemoryEffectOpInterface)
532)]
533pub struct Rotr {
534    #[operand]
535    lhs: AnyInteger,
536    #[operand]
537    shift: UInt32,
538    #[result]
539    result: AnyInteger,
540}
541
542infer_return_ty_for_binary_op!(Rotr);
543has_no_effects!(Rotr);
544
545/// Equality comparison
546#[operation(
547    dialect = ArithDialect,
548    traits(BinaryOp, Commutative, SameTypeOperands),
549    implements(InferTypeOpInterface, MemoryEffectOpInterface)
550)]
551pub struct Eq {
552    #[operand]
553    lhs: AnyInteger,
554    #[operand]
555    rhs: AnyInteger,
556    #[result]
557    result: Bool,
558}
559
560infer_return_ty_for_binary_op!(Eq as Type::I1);
561has_no_effects!(Eq);
562
563/// Inequality comparison
564#[operation(
565    dialect = ArithDialect,
566    traits(BinaryOp, Commutative, SameTypeOperands),
567    implements(InferTypeOpInterface, MemoryEffectOpInterface)
568)]
569pub struct Neq {
570    #[operand]
571    lhs: AnyInteger,
572    #[operand]
573    rhs: AnyInteger,
574    #[result]
575    result: Bool,
576}
577
578infer_return_ty_for_binary_op!(Neq as Type::I1);
579has_no_effects!(Neq);
580
581/// Greater-than comparison
582#[operation(
583    dialect = ArithDialect,
584    traits(BinaryOp, SameTypeOperands),
585    implements(InferTypeOpInterface, MemoryEffectOpInterface)
586)]
587pub struct Gt {
588    #[operand]
589    lhs: AnyInteger,
590    #[operand]
591    rhs: AnyInteger,
592    #[result]
593    result: Bool,
594}
595
596infer_return_ty_for_binary_op!(Gt as Type::I1);
597has_no_effects!(Gt);
598
599/// Greater-than-or-equal comparison
600#[operation(
601    dialect = ArithDialect,
602    traits(BinaryOp, SameTypeOperands),
603    implements(InferTypeOpInterface, MemoryEffectOpInterface)
604)]
605pub struct Gte {
606    #[operand]
607    lhs: AnyInteger,
608    #[operand]
609    rhs: AnyInteger,
610    #[result]
611    result: Bool,
612}
613
614infer_return_ty_for_binary_op!(Gte as Type::I1);
615has_no_effects!(Gte);
616
617/// Less-than comparison
618#[operation(
619    dialect = ArithDialect,
620    traits(BinaryOp, SameTypeOperands),
621    implements(InferTypeOpInterface, MemoryEffectOpInterface)
622)]
623pub struct Lt {
624    #[operand]
625    lhs: AnyInteger,
626    #[operand]
627    rhs: AnyInteger,
628    #[result]
629    result: Bool,
630}
631
632infer_return_ty_for_binary_op!(Lt as Type::I1);
633has_no_effects!(Lt);
634
635/// Less-than-or-equal comparison
636#[operation(
637    dialect = ArithDialect,
638    traits(BinaryOp, SameTypeOperands),
639    implements(InferTypeOpInterface, MemoryEffectOpInterface)
640)]
641pub struct Lte {
642    #[operand]
643    lhs: AnyInteger,
644    #[operand]
645    rhs: AnyInteger,
646    #[result]
647    result: Bool,
648}
649
650infer_return_ty_for_binary_op!(Lte as Type::I1);
651has_no_effects!(Lte);
652
653/// Select minimum value
654#[operation(
655    dialect = ArithDialect,
656    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
657    implements(InferTypeOpInterface, MemoryEffectOpInterface)
658)]
659pub struct Min {
660    #[operand]
661    lhs: AnyInteger,
662    #[operand]
663    rhs: AnyInteger,
664    #[result]
665    result: AnyInteger,
666}
667
668infer_return_ty_for_binary_op!(Min);
669has_no_effects!(Min);
670
671/// Select maximum value
672#[operation(
673    dialect = ArithDialect,
674    traits(BinaryOp, Commutative, SameTypeOperands, SameOperandsAndResultType),
675    implements(InferTypeOpInterface, MemoryEffectOpInterface)
676)]
677pub struct Max {
678    #[operand]
679    lhs: AnyInteger,
680    #[operand]
681    rhs: AnyInteger,
682    #[result]
683    result: AnyInteger,
684}
685
686infer_return_ty_for_binary_op!(Max);
687has_no_effects!(Max);