cubecl_core/frontend/operation/cmp.rs

1use crate::frontend::ExpandElementTyped;
2use crate::frontend::operation::base::cmp_expand;
3use crate::ir::{Comparison, Scope};
4use crate::prelude::CubePrimitive;
5
6pub mod ne {
7    use super::*;
8
9    pub fn expand<C: CubePrimitive>(
10        scope: &mut Scope,
11        lhs: ExpandElementTyped<C>,
12        rhs: ExpandElementTyped<C>,
13    ) -> ExpandElementTyped<bool> {
14        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::NotEqual).into()
15    }
16}
17
18pub mod gt {
19    use super::*;
20
21    pub fn expand<C: CubePrimitive>(
22        scope: &mut Scope,
23        lhs: ExpandElementTyped<C>,
24        rhs: ExpandElementTyped<C>,
25    ) -> ExpandElementTyped<bool> {
26        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::Greater).into()
27    }
28}
29
30pub mod lt {
31    use super::*;
32
33    pub fn expand<C: CubePrimitive>(
34        scope: &mut Scope,
35        lhs: ExpandElementTyped<C>,
36        rhs: ExpandElementTyped<C>,
37    ) -> ExpandElementTyped<bool> {
38        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::Lower).into()
39    }
40}
41
42pub mod ge {
43    use super::*;
44
45    pub fn expand<C: CubePrimitive>(
46        scope: &mut Scope,
47        lhs: ExpandElementTyped<C>,
48        rhs: ExpandElementTyped<C>,
49    ) -> ExpandElementTyped<bool> {
50        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::GreaterEqual).into()
51    }
52}
53
54pub mod le {
55    use super::*;
56
57    pub fn expand<C: CubePrimitive>(
58        scope: &mut Scope,
59        lhs: ExpandElementTyped<C>,
60        rhs: ExpandElementTyped<C>,
61    ) -> ExpandElementTyped<bool> {
62        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::LowerEqual).into()
63    }
64}
65
66pub mod eq {
67
68    use super::*;
69
70    pub fn expand<C: CubePrimitive>(
71        scope: &mut Scope,
72        lhs: ExpandElementTyped<C>,
73        rhs: ExpandElementTyped<C>,
74    ) -> ExpandElementTyped<bool> {
75        cmp_expand(scope, lhs.into(), rhs.into(), Comparison::Equal).into()
76    }
77}