// cubecl_core/frontend/operation/cmp.rs

1use crate::frontend::ExpandElementTyped;
2use crate::frontend::operation::base::cmp_expand;
3use crate::ir::{Comparison, Scope};
4use crate::prelude::CubePrimitive;

// NOTE: Unary comparison tests are in the `unary` module.

8pub mod ne {
9    use super::*;
10
11    pub fn expand<C: CubePrimitive>(
12        scope: &mut Scope,
13        lhs: ExpandElementTyped<C>,
14        rhs: ExpandElementTyped<C>,
15    ) -> ExpandElementTyped<bool> {
16        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::NotEqual).into()
17    }
18}
19
20pub mod gt {
21    use super::*;
22
23    pub fn expand<C: CubePrimitive>(
24        scope: &mut Scope,
25        lhs: ExpandElementTyped<C>,
26        rhs: ExpandElementTyped<C>,
27    ) -> ExpandElementTyped<bool> {
28        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::Greater).into()
29    }
30}
31
32pub mod lt {
33    use super::*;
34
35    pub fn expand<C: CubePrimitive>(
36        scope: &mut Scope,
37        lhs: ExpandElementTyped<C>,
38        rhs: ExpandElementTyped<C>,
39    ) -> ExpandElementTyped<bool> {
40        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::Lower).into()
41    }
42}
43
44pub mod ge {
45    use super::*;
46
47    pub fn expand<C: CubePrimitive>(
48        scope: &mut Scope,
49        lhs: ExpandElementTyped<C>,
50        rhs: ExpandElementTyped<C>,
51    ) -> ExpandElementTyped<bool> {
52        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::GreaterEqual).into()
53    }
54}
55
56pub mod le {
57    use super::*;
58
59    pub fn expand<C: CubePrimitive>(
60        scope: &mut Scope,
61        lhs: ExpandElementTyped<C>,
62        rhs: ExpandElementTyped<C>,
63    ) -> ExpandElementTyped<bool> {
64        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::LowerEqual).into()
65    }
66}
67
68pub mod eq {
69
70    use super::*;
71
72    pub fn expand<C: CubePrimitive>(
73        scope: &mut Scope,
74        lhs: ExpandElementTyped<C>,
75        rhs: ExpandElementTyped<C>,
76    ) -> ExpandElementTyped<bool> {
77        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::Equal).into()
78    }
79}