cubecl_std/tensor/layout/
coordinates.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use variadics_please::all_tuples;
4
/// A set of coordinates usable by tensor layouts inside `#[cube]` code.
///
/// Every operation is an associated function that takes its operands explicitly
/// (`T::add(a, b)`) rather than a `self` method. This file implements the trait for the
/// primitive integer types, for tuples whose elements are `Coordinates`, and for
/// `Sequence<T>`; compound implementations apply each operation element by element.
#[cube]
pub trait Coordinates: CubeType + Clone {
    /// Adds `other` to `this` and returns the result.
    fn add(this: Self, other: Self) -> Self;
    /// Subtracts `other` from `this` and returns the result.
    fn sub(this: Self, other: Self) -> Self;
    /// Returns the minimum of `this` and `other` (element-wise for compound coordinates).
    fn min(this: Self, other: Self) -> Self;
    /// Returns the maximum of `this` and `other` (element-wise for compound coordinates).
    fn max(this: Self, other: Self) -> Self;
    /// Returns `true` when `pos` lies within `bounds`.
    ///
    /// For compound coordinates, every element of `pos` must be in bounds of the matching
    /// element of `bounds`.
    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool;
    /// Builds a new value with the same structure as `this` and every element set to `value`.
    ///
    /// `this` only provides the shape (e.g. the length of a `Sequence`); its element values
    /// are not read. `value` is a compile-time constant, so `from_int(x, 0)` yields the origin.
    fn from_int(this: &Self, #[comptime] value: i64) -> Self;
}
22
// Naming: `CoordsNd` aliases hold N `u32` components, `CoordsNi` aliases hold N `i32`
// components.

/// Single unsigned (`u32`) coordinate.
pub type Coords1d = u32;
/// Single signed (`i32`) coordinate.
pub type Coords1i = i32;
/// Two unsigned (`u32`) coordinates.
pub type Coords2d = (u32, u32);
/// Two signed (`i32`) coordinates.
pub type Coords2i = (i32, i32);
/// Three unsigned (`u32`) coordinates.
pub type Coords3d = (u32, u32, u32);
/// Three signed (`i32`) coordinates.
pub type Coords3i = (i32, i32, i32);
/// Four unsigned (`u32`) coordinates.
pub type Coords4d = (u32, u32, u32, u32);
/// Four signed (`i32`) coordinates.
pub type Coords4i = (i32, i32, i32, i32);
/// Five unsigned (`u32`) coordinates.
pub type Coords5d = (u32, u32, u32, u32, u32);
/// Five signed (`i32`) coordinates.
pub type Coords5i = (i32, i32, i32, i32, i32);
/// A variable number of unsigned (`u32`) coordinates stored in a `Sequence`.
pub type CoordsDyn = Sequence<u32>;
35
/// Implements `Coordinates` for a tuple whose elements all implement `Coordinates`.
///
/// Each input triple is `(ElementType, lhs_binding, rhs_binding)`: the type parameter for one
/// tuple element plus two identifiers used to destructure the left- and right-hand operands.
/// Every operation is applied element by element by delegating to the element type's own
/// implementation, and the results are reassembled into a tuple.
macro_rules! impl_coordinates_tuple {
    ($(($T:ident, $t:ident, $o: ident)),*) => {
        #[cube]
        impl<$($T: Coordinates),*> Coordinates for ($($T),*) {
            fn add(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::add($t, $o)),*)
            }
            fn sub(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::sub($t, $o)),*)
            }
            fn min(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::min($t, $o)),*)
            }
            fn max(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::max($t, $o)),*)
            }
            // Parameter names match the trait declaration (`pos`, `bounds`).
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                // Destructuring through the references binds each element as `&T`, which
                // is exactly what the element-wise `is_in_bounds` takes.
                let ($($t),*) = pos;
                let ($($o),*) = bounds;
                // In bounds only if every element is in bounds.
                true $(&& $T::is_in_bounds($t, $o))*
            }
            fn from_int(this: &Self, #[comptime] value: i64) -> Self {
                // `this` is only used to recurse into each element's structure.
                let ($($t),*) = this;
                ($($T::from_int($t, value)),*)
            }
        }
    };
}
72
/// Implements `Coordinates` for every primitive numeric type in a comma-separated list.
///
/// A scalar is its own single coordinate: arithmetic uses the type's operators,
/// `min`/`max` forward to the `Min`/`Max` traits, and `from_int` converts the comptime
/// value with `Numeric::from_int`.
macro_rules! impl_coordinates_primitive {
    ($($ty: ty),*) => {
        $(
            #[cube]
            impl Coordinates for $ty {
                fn add(lhs: Self, rhs: Self) -> Self {
                    lhs + rhs
                }
                fn sub(lhs: Self, rhs: Self) -> Self {
                    lhs - rhs
                }
                fn min(lhs: Self, rhs: Self) -> Self {
                    Min::min(lhs, rhs)
                }
                fn max(lhs: Self, rhs: Self) -> Self {
                    Max::max(lhs, rhs)
                }
                // Only the upper bound is tested; for signed types a negative `pos` is
                // therefore reported as in bounds.
                fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                    pos < bounds
                }
                // A scalar has no structure to copy, so the template argument is ignored.
                fn from_int(_this: &Self, #[comptime] value: i64) -> Self {
                    <$ty as Numeric>::from_int(value)
                }
            }
        )*
    };
}
102
103impl_coordinates_primitive!(u8, u16, u32, u64, i8, i16, i32, i64);
104all_tuples!(impl_coordinates_tuple, 2, 12, T, t, o);
105
106#[cube]
107impl<T: Coordinates + Copy> Coordinates for Sequence<T> {
108 fn add(this: Self, other: Self) -> Self {
109 let rank = comptime![this.len()];
110 let mut out = Sequence::new();
111
112 #[unroll]
113 for i in 0..rank {
114 out.push(T::add(*this.index(i), *other.index(i)));
115 }
116
117 out
118 }
119
120 fn sub(this: Self, other: Self) -> Self {
121 let rank = comptime![this.len()];
122 let mut out = Sequence::new();
123
124 #[unroll]
125 for i in 0..rank {
126 out.push(T::sub(*this.index(i), *other.index(i)));
127 }
128
129 out
130 }
131
132 fn min(this: Self, other: Self) -> Self {
133 let rank = comptime![this.len()];
134 let mut out = Sequence::new();
135
136 #[unroll]
137 for i in 0..rank {
138 out.push(T::min(*this.index(i), *other.index(i)));
139 }
140
141 out
142 }
143
144 fn max(this: Self, other: Self) -> Self {
145 let rank = comptime![this.len()];
146 let mut out = Sequence::new();
147
148 #[unroll]
149 for i in 0..rank {
150 out.push(T::max(*this.index(i), *other.index(i)));
151 }
152
153 out
154 }
155
156 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
157 let rank = comptime![pos.len()];
158 let mut out = true;
159
160 #[unroll]
161 for i in 0..rank {
162 out &= T::is_in_bounds(pos.index(i), bounds.index(i));
163 }
164
165 out
166 }
167
168 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
169 let rank = comptime![this.len()];
170 let mut origin = Sequence::new();
171
172 #[unroll]
173 for i in 0..rank {
174 origin.push(T::from_int(this.index(i), value));
175 }
176
177 origin
178 }
179}