// cubecl_std/tensor/layout/coordinates.rs

1use cubecl::prelude::*;
2use cubecl_core::{self as cubecl};
3use variadics_please::all_tuples;
4
5/// A set of coordinates used in layouts. Contains some utilities for comptime inspection.
6#[cube]
7pub trait Coordinates: CubeType + Clone {
8    /// Add two coordinates together and return the result.
9    fn add(this: Self, other: Self) -> Self;
10    /// Subtract two coordinates from each other and return the result.
11    fn sub(this: Self, other: Self) -> Self;
12    /// Apply an elementwise minimum to the coordinates and return the result.
13    fn min(this: Self, other: Self) -> Self;
14    /// Apply an elementwise maximum to the coordinates and return the result.
15    fn max(this: Self, other: Self) -> Self;
16    /// Check whether `pos` is fully contained within `bounds`.
17    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool;
18    /// Create a new coordinates object where all values are `value`.
19    /// `this` may be used as a reference coordinate for dynamically sized layouts.
20    fn from_int(this: &Self, #[comptime] value: i64) -> Self;
21}
22
23// Aliases for convenience and semantic clarity
24pub type Coords1d = u32;
25pub type Coords1i = i32;
26pub type Coords2d = (u32, u32);
27pub type Coords2i = (i32, i32);
28pub type Coords3d = (u32, u32, u32);
29pub type Coords3i = (i32, i32, i32);
30pub type Coords4d = (u32, u32, u32, u32);
31pub type Coords4i = (i32, i32, i32, i32);
32pub type Coords5d = (u32, u32, u32, u32, u32);
33pub type Coords5i = (i32, i32, i32, i32, i32);
34pub type CoordsDyn = Sequence<u32>;
35
/// Implements [`Coordinates`] for a tuple whose components are themselves [`Coordinates`].
///
/// Invoked through `all_tuples!` with one `(T, t, o)` triple per component:
/// `$T` is the component type, and `$t` / `$o` are the binding names used for that component
/// of the first and second argument respectively. Every operation destructures its arguments
/// and forwards component by component to the component type's own implementation.
macro_rules! impl_coordinates_tuple {
    ($(($T:ident, $t:ident, $o:ident)),*) => {
        // Need to force off debug symbols because of macro hygiene weirdness.
        #[cube(no_debug_symbols)]
        impl<$($T: Coordinates),*> Coordinates for ($($T),*) {
            fn add(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::add($t, $o)),*)
            }
            fn sub(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::sub($t, $o)),*)
            }
            fn min(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::min($t, $o)),*)
            }
            fn max(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::max($t, $o)),*)
            }
            // Parameter names match the trait declaration (`pos`, `bounds`) so it is clear
            // which argument is the position. The tuple is in bounds only if every component is.
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                let ($($t),*) = pos;
                let ($($o),*) = bounds;
                true $(&& $T::is_in_bounds($t, $o))*
            }
            fn from_int(this: &Self, #[comptime] value: i64) -> Self {
                // Each component uses its own part of `this` as its shape reference.
                let ($($t),*) = this;
                ($($T::from_int($t, value)),*)
            }
        }
    };
}

/// Implements [`Coordinates`] for one or more scalar numeric types.
///
/// A separate impl is generated for each type instead of a blanket impl, presumably because a
/// blanket impl over a numeric trait would overlap with the tuple and `Sequence` impls under
/// Rust's coherence rules.
macro_rules! impl_coordinates_primitive {
    ($prim: ty) => {
        #[cube]
        impl Coordinates for $prim {
            fn add(this: Self, other: Self) -> Self {
                this + other
            }
            fn sub(this: Self, other: Self) -> Self {
                this - other
            }
            fn min(this: Self, other: Self) -> Self {
                Min::min(this, other)
            }
            fn max(this: Self, other: Self) -> Self {
                Max::max(this, other)
            }
            // Exclusive upper bound only. There is no `>= 0` check, so for signed types a
            // negative `pos` is reported as in bounds.
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                pos < bounds
            }
            // A scalar has no shape, so the reference coordinate is not used.
            fn from_int(_reference: &Self, #[comptime] value: i64) -> Self {
                <$prim as Numeric>::from_int(value)
            }
        }
    };
    // Variadic form: fan out to the single-type arm above. This arm must stay second;
    // otherwise a single type would match it and the macro would recurse forever.
    ($($prim: ty),*) => {
        $(impl_coordinates_primitive!($prim);)*
    };
}

104impl_coordinates_primitive!(u8, u16, u32, u64, i8, i16, i32, i64);
105all_tuples!(impl_coordinates_tuple, 2, 12, T, t, o);
106
107#[cube]
108impl<T: Coordinates + Copy> Coordinates for Sequence<T> {
109    fn add(this: Self, other: Self) -> Self {
110        let rank = comptime![this.len()];
111        let mut out = Sequence::new();
112
113        #[unroll]
114        for i in 0..rank {
115            out.push(T::add(*this.index(i), *other.index(i)));
116        }
117
118        out
119    }
120
121    fn sub(this: Self, other: Self) -> Self {
122        let rank = comptime![this.len()];
123        let mut out = Sequence::new();
124
125        #[unroll]
126        for i in 0..rank {
127            out.push(T::sub(*this.index(i), *other.index(i)));
128        }
129
130        out
131    }
132
133    fn min(this: Self, other: Self) -> Self {
134        let rank = comptime![this.len()];
135        let mut out = Sequence::new();
136
137        #[unroll]
138        for i in 0..rank {
139            out.push(T::min(*this.index(i), *other.index(i)));
140        }
141
142        out
143    }
144
145    fn max(this: Self, other: Self) -> Self {
146        let rank = comptime![this.len()];
147        let mut out = Sequence::new();
148
149        #[unroll]
150        for i in 0..rank {
151            out.push(T::max(*this.index(i), *other.index(i)));
152        }
153
154        out
155    }
156
157    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
158        let rank = comptime![pos.len()];
159        let mut out = true;
160
161        #[unroll]
162        for i in 0..rank {
163            out &= T::is_in_bounds(pos.index(i), bounds.index(i));
164        }
165
166        out
167    }
168
169    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
170        let rank = comptime![this.len()];
171        let mut origin = Sequence::new();
172
173        #[unroll]
174        for i in 0..rank {
175            origin.push(T::from_int(this.index(i), value));
176        }
177
178        origin
179    }
180}