cubecl_std/tensor/layout/
coordinates.rs1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use variadics_please::all_tuples;
4
5#[cube]
7pub trait Coordinates: CubeType + Clone {
8 fn add(this: Self, other: Self) -> Self;
10 fn sub(this: Self, other: Self) -> Self;
12 fn min(this: Self, other: Self) -> Self;
14 fn max(this: Self, other: Self) -> Self;
16 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool;
18 fn from_int(this: &Self, #[comptime] value: i64) -> Self;
21}
22
23pub type Coords1d = u32;
25pub type Coords1i = i32;
26pub type Coords2d = (u32, u32);
27pub type Coords2i = (i32, i32);
28pub type Coords3d = (u32, u32, u32);
29pub type Coords3i = (i32, i32, i32);
30pub type Coords4d = (u32, u32, u32, u32);
31pub type Coords4i = (i32, i32, i32, i32);
32pub type Coords5d = (u32, u32, u32, u32, u32);
33pub type Coords5i = (i32, i32, i32, i32, i32);
34pub type CoordsDyn = Sequence<u32>;
35
// Implements `Coordinates` for a tuple whose element types all implement `Coordinates`.
//
// Each repetition is `(TypeParam, first_operand_binding, second_operand_binding)`. The
// identifiers are supplied by the `all_tuples!` invocation. Every operation destructures both
// tuples and applies the element type's own operation to each pair of components.
macro_rules! impl_coordinates_tuple {
    ($(($T:ident, $t:ident, $o:ident)),*) => {
        #[cube(no_debug_symbols)]
        impl<$($T: Coordinates),*> Coordinates for ($($T),*) {
            fn add(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::add($t, $o)),*)
            }
            fn sub(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::sub($t, $o)),*)
            }
            fn min(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::min($t, $o)),*)
            }
            fn max(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::max($t, $o)),*)
            }
            // Parameter names match the trait declaration (`pos`, `bounds`).
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                let ($($t),*) = pos;
                let ($($o),*) = bounds;
                // In bounds only if every component is in bounds.
                true $(&& $T::is_in_bounds($t, $o))*
            }
            fn from_int(this: &Self, #[comptime] value: i64) -> Self {
                // Each component only contributes its shape; `value` fills every component.
                let ($($t),*) = this;
                ($($T::from_int($t, value)),*)
            }
        }
    };
}

// Implements `Coordinates` for one or more primitive integer types.
//
// A single type gets the scalar implementation. A comma-separated list, with an optional
// trailing comma, expands to one single-type invocation per entry. The single-type arm must stay
// first: otherwise the list arm would also match a lone type and keep re-invoking itself.
macro_rules! impl_coordinates_primitive {
    ($ty:ty) => {
        #[cube]
        impl Coordinates for $ty {
            fn add(this: Self, other: Self) -> Self {
                this + other
            }
            fn sub(this: Self, other: Self) -> Self {
                this - other
            }
            fn min(this: Self, other: Self) -> Self {
                Min::min(this, other)
            }
            fn max(this: Self, other: Self) -> Self {
                Max::max(this, other)
            }
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                // Upper bound only. NOTE(review): for signed types, a negative `pos` is reported
                // as in bounds. Confirm that callers check the lower bound themselves.
                pos < bounds
            }
            fn from_int(_this: &Self, #[comptime] value: i64) -> Self {
                // A scalar has no shape to copy, so `this` is unused.
                <$ty as Numeric>::from_int(value)
            }
        }
    };
    ($($ty:ty),* $(,)?) => {
        $(impl_coordinates_primitive!($ty);)*
    };
}

104impl_coordinates_primitive!(u8, u16, u32, u64, i8, i16, i32, i64);
105all_tuples!(impl_coordinates_tuple, 2, 12, T, t, o);
106
107#[cube]
108impl<T: Coordinates + Copy> Coordinates for Sequence<T> {
109 fn add(this: Self, other: Self) -> Self {
110 let rank = comptime![this.len()];
111 let mut out = Sequence::new();
112
113 #[unroll]
114 for i in 0..rank {
115 out.push(T::add(*this.index(i), *other.index(i)));
116 }
117
118 out
119 }
120
121 fn sub(this: Self, other: Self) -> Self {
122 let rank = comptime![this.len()];
123 let mut out = Sequence::new();
124
125 #[unroll]
126 for i in 0..rank {
127 out.push(T::sub(*this.index(i), *other.index(i)));
128 }
129
130 out
131 }
132
133 fn min(this: Self, other: Self) -> Self {
134 let rank = comptime![this.len()];
135 let mut out = Sequence::new();
136
137 #[unroll]
138 for i in 0..rank {
139 out.push(T::min(*this.index(i), *other.index(i)));
140 }
141
142 out
143 }
144
145 fn max(this: Self, other: Self) -> Self {
146 let rank = comptime![this.len()];
147 let mut out = Sequence::new();
148
149 #[unroll]
150 for i in 0..rank {
151 out.push(T::max(*this.index(i), *other.index(i)));
152 }
153
154 out
155 }
156
157 fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
158 let rank = comptime![pos.len()];
159 let mut out = true;
160
161 #[unroll]
162 for i in 0..rank {
163 out &= T::is_in_bounds(pos.index(i), bounds.index(i));
164 }
165
166 out
167 }
168
169 fn from_int(this: &Self, #[comptime] value: i64) -> Self {
170 let rank = comptime![this.len()];
171 let mut origin = Sequence::new();
172
173 #[unroll]
174 for i in 0..rank {
175 origin.push(T::from_int(this.index(i), value));
176 }
177
178 origin
179 }
180}