// cubecl_std/tensor/layout/coordinates.rs

use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use variadics_please::all_tuples;

5/// A set of coordinates used in layouts. Contains some utilities for comptime inspection.
6#[cube]
7pub trait Coordinates: CubeType + Clone {
8    /// Add two coordinates together and return the result.
9    fn add(this: Self, other: Self) -> Self;
10    /// Subtract two coordinates from each other and return the result.
11    fn sub(this: Self, other: Self) -> Self;
12    /// Apply an elementwise minimum to the coordinates and return the result.
13    fn min(this: Self, other: Self) -> Self;
14    /// Apply an elementwise maximum to the coordinates and return the result.
15    fn max(this: Self, other: Self) -> Self;
16    /// Check whether `pos` is fully contained within `bounds`.
17    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool;
18    /// Create a new coordinates object where all values are `value`.
19    /// `this` may be used as a reference coordinate for dynamically sized layouts.
20    fn from_int(this: &Self, #[comptime] value: i64) -> Self;
21}
22
23// Aliases for convenience and semantic clarity
24pub type Coords1d = u32;
25pub type Coords1i = i32;
26pub type Coords2d = (u32, u32);
27pub type Coords2i = (i32, i32);
28pub type Coords3d = (u32, u32, u32);
29pub type Coords3i = (i32, i32, i32);
30pub type Coords4d = (u32, u32, u32, u32);
31pub type Coords4i = (i32, i32, i32, i32);
32pub type Coords5d = (u32, u32, u32, u32, u32);
33pub type Coords5i = (i32, i32, i32, i32, i32);
34pub type CoordsDyn = Sequence<u32>;
35
// Implements `Coordinates` for a tuple whose element types all implement `Coordinates`.
//
// The macro receives one `(Type, this_binding, other_binding)` triple per tuple element
// (supplied by `all_tuples!`). Each operation destructures both operands into those bindings
// and forwards to the element type's own implementation, component by component.
macro_rules! impl_coordinates_tuple {
    ($(($T:ident, $t:ident, $o:ident)),*) => {
        #[cube]
        impl<$($T: Coordinates),*> Coordinates for ($($T),*) {
            fn add(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::add($t, $o)),*)
            }
            fn sub(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::sub($t, $o)),*)
            }
            fn min(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::min($t, $o)),*)
            }
            fn max(this: Self, other: Self) -> Self {
                let ($($t),*) = this;
                let ($($o),*) = other;
                ($($T::max($t, $o)),*)
            }
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                // Destructuring through the references binds each element by reference,
                // matching the `&T` parameters of the element impls.
                let ($($t),*) = pos;
                let ($($o),*) = bounds;
                // In bounds only if every component is; `true` seeds the `&&` chain.
                true $(&& $T::is_in_bounds($t, $o))*
            }
            fn from_int(this: &Self, #[comptime] value: i64) -> Self {
                // `this` only supplies each element's reference coordinate.
                let ($($t),*) = this;
                ($($T::from_int($t, value)),*)
            }
        }
    };
}

// Implements `Coordinates` for one or more primitive numeric types.
//
// Can't blanket implement (e.g. over all `Numeric` types) because of trait coherence rules
// (presumably such an impl would overlap with the tuple and `Sequence` impls), so each type is
// listed explicitly. The list form accepts an optional trailing comma.
macro_rules! impl_coordinates_primitive {
    ($ty:ty) => {
        #[cube]
        impl Coordinates for $ty {
            fn add(this: Self, other: Self) -> Self {
                this + other
            }
            fn sub(this: Self, other: Self) -> Self {
                this - other
            }
            fn min(this: Self, other: Self) -> Self {
                Min::min(this, other)
            }
            fn max(this: Self, other: Self) -> Self {
                Max::max(this, other)
            }
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                // Exclusive upper bound only; there is no lower-bound check.
                pos < bounds
            }
            fn from_int(_this: &Self, #[comptime] value: i64) -> Self {
                // A scalar has no dynamic shape, so the reference coordinate is unused.
                <$ty as Numeric>::from_int(value)
            }
        }
    };
    ($($ty:ty),* $(,)?) => {
        $(impl_coordinates_primitive!($ty);)*
    };
}

103impl_coordinates_primitive!(u8, u16, u32, u64, i8, i16, i32, i64);
104all_tuples!(impl_coordinates_tuple, 2, 12, T, t, o);
105
106#[cube]
107impl<T: Coordinates + Copy> Coordinates for Sequence<T> {
108    fn add(this: Self, other: Self) -> Self {
109        let rank = comptime![this.len()];
110        let mut out = Sequence::new();
111
112        #[unroll]
113        for i in 0..rank {
114            out.push(T::add(*this.index(i), *other.index(i)));
115        }
116
117        out
118    }
119
120    fn sub(this: Self, other: Self) -> Self {
121        let rank = comptime![this.len()];
122        let mut out = Sequence::new();
123
124        #[unroll]
125        for i in 0..rank {
126            out.push(T::sub(*this.index(i), *other.index(i)));
127        }
128
129        out
130    }
131
132    fn min(this: Self, other: Self) -> Self {
133        let rank = comptime![this.len()];
134        let mut out = Sequence::new();
135
136        #[unroll]
137        for i in 0..rank {
138            out.push(T::min(*this.index(i), *other.index(i)));
139        }
140
141        out
142    }
143
144    fn max(this: Self, other: Self) -> Self {
145        let rank = comptime![this.len()];
146        let mut out = Sequence::new();
147
148        #[unroll]
149        for i in 0..rank {
150            out.push(T::max(*this.index(i), *other.index(i)));
151        }
152
153        out
154    }
155
156    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
157        let rank = comptime![pos.len()];
158        let mut out = true;
159
160        #[unroll]
161        for i in 0..rank {
162            out &= T::is_in_bounds(pos.index(i), bounds.index(i));
163        }
164
165        out
166    }
167
168    fn from_int(this: &Self, #[comptime] value: i64) -> Self {
169        let rank = comptime![this.len()];
170        let mut origin = Sequence::new();
171
172        #[unroll]
173        for i in 0..rank {
174            origin.push(T::from_int(this.index(i), value));
175        }
176
177        origin
178    }
179}