cubecl_ir/
comparison.rs

1use core::fmt::Display;
2
3use crate::{TypeHash, UnaryOperator};
4
5use crate::{BinaryOperator, OperationReflect};
6
/// Comparison operations
///
/// The relational variants wrap a [`BinaryOperator`] (read as `lhs` and `rhs`
/// by the `Display` implementation), while `IsNan` and `IsInf` wrap a
/// [`UnaryOperator`] (read as `input`). The `#[operation(..)]` attributes are
/// consumed by the `OperationReflect` derive: the enum as a whole is tagged
/// `pure` with opcode type `ComparisonOpCode`, and `Equal` / `NotEqual` are
/// additionally tagged `commutative`.
// NOTE(review): the exact meaning of `pure` is defined by the derive macro,
// which is not visible in this file — confirm before relying on it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = ComparisonOpCode, pure)]
pub enum Comparison {
    /// Displayed as `lhs < rhs`.
    Lower(BinaryOperator),
    /// Displayed as `lhs <= rhs`.
    LowerEqual(BinaryOperator),
    /// Displayed as `lhs == rhs`; tagged commutative.
    #[operation(commutative)]
    Equal(BinaryOperator),
    /// Displayed as `lhs != rhs`; tagged commutative.
    #[operation(commutative)]
    NotEqual(BinaryOperator),
    /// Displayed as `lhs >= rhs`.
    GreaterEqual(BinaryOperator),
    /// Displayed as `lhs > rhs`.
    Greater(BinaryOperator),
    /// Displayed as `input.isnan()`; presumably a NaN test on the single input.
    IsNan(UnaryOperator),
    /// Displayed as `input.isinf()`; presumably an infinity test on the single input.
    IsInf(UnaryOperator),
}
23
24impl Display for Comparison {
25    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26        match self {
27            Comparison::Equal(op) => write!(f, "{} == {}", op.lhs, op.rhs),
28            Comparison::NotEqual(op) => write!(f, "{} != {}", op.lhs, op.rhs),
29            Comparison::Lower(op) => write!(f, "{} < {}", op.lhs, op.rhs),
30            Comparison::Greater(op) => write!(f, "{} > {}", op.lhs, op.rhs),
31            Comparison::LowerEqual(op) => write!(f, "{} <= {}", op.lhs, op.rhs),
32            Comparison::GreaterEqual(op) => write!(f, "{} >= {}", op.lhs, op.rhs),
33            Comparison::IsNan(op) => write!(f, "{}.isnan()", op.input),
34            Comparison::IsInf(op) => write!(f, "{}.isinf()", op.input),
35        }
36    }
37}