use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use variadics_please::all_tuples;
/// Multi-dimensional coordinates that can be manipulated inside `#[cube]` code.
///
/// This file implements it for primitive integers, for tuples whose elements
/// implement `Coordinates`, and for `Sequence<T>`. The tuple and sequence
/// implementations apply every operation component by component.
#[cube]
pub trait Coordinates: CubeType + Clone {
    /// Returns `this + other`.
    fn add(this: Self, other: Self) -> Self;
    /// Returns `this - other`.
    fn sub(this: Self, other: Self) -> Self;
    /// Returns the minimum of `this` and `other`.
    fn min(this: Self, other: Self) -> Self;
    /// Returns the maximum of `this` and `other`.
    fn max(this: Self, other: Self) -> Self;
    /// Returns whether `pos` lies within `bounds`.
    ///
    /// The primitive implementations only check `pos < bounds`. For signed types,
    /// a negative `pos` is therefore reported as in bounds.
    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool;
    /// Builds a value shaped like `this` with every component set to the comptime `value`.
    ///
    /// `this` only provides the shape (for example, the length of a `Sequence`).
    /// Its component values are not read.
    fn from_int(this: &Self, #[comptime] value: i64) -> Self;
}
/// 1D unsigned coordinate.
///
/// NOTE(review): this is `usize`, while the multi-dimensional unsigned aliases
/// below use `u32` components. Confirm the difference is intentional.
pub type Coords1d = usize;
/// 1D signed coordinate.
pub type Coords1i = i32;
/// 2D unsigned coordinates.
pub type Coords2d = (u32, u32);
/// 2D signed coordinates.
pub type Coords2i = (i32, i32);
/// 3D unsigned coordinates.
pub type Coords3d = (u32, u32, u32);
/// 3D signed coordinates.
pub type Coords3i = (i32, i32, i32);
/// 4D unsigned coordinates.
pub type Coords4d = (u32, u32, u32, u32);
/// 4D signed coordinates.
pub type Coords4i = (i32, i32, i32, i32);
/// 5D unsigned coordinates.
pub type Coords5d = (u32, u32, u32, u32, u32);
/// 5D signed coordinates.
pub type Coords5i = (i32, i32, i32, i32, i32);
/// Unsigned coordinates whose rank is not fixed by the type (a `Sequence` of `u32`).
pub type CoordsDyn = Sequence<u32>;
// Implements `Coordinates` for a tuple whose element types all implement it.
//
// Each macro argument is `(ElementType, a_ident, b_ident)`. The two identifiers
// are fresh names used to destructure the first and second operand. Every
// operation is forwarded to the matching element implementation.
macro_rules! impl_coordinates_tuple {
    ($(($Elem:ident, $a:ident, $b:ident)),*) => {
        #[cube(no_debug_symbols)]
        impl<$($Elem: Coordinates),*> Coordinates for ($($Elem),*) {
            fn add(lhs: Self, rhs: Self) -> Self {
                let ($($a),*) = lhs;
                let ($($b),*) = rhs;
                ($($Elem::add($a, $b)),*)
            }
            fn sub(lhs: Self, rhs: Self) -> Self {
                let ($($a),*) = lhs;
                let ($($b),*) = rhs;
                ($($Elem::sub($a, $b)),*)
            }
            fn min(lhs: Self, rhs: Self) -> Self {
                let ($($a),*) = lhs;
                let ($($b),*) = rhs;
                ($($Elem::min($a, $b)),*)
            }
            fn max(lhs: Self, rhs: Self) -> Self {
                let ($($a),*) = lhs;
                let ($($b),*) = rhs;
                ($($Elem::max($a, $b)),*)
            }
            // In bounds only if every element is in bounds.
            fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                let ($($a),*) = pos;
                let ($($b),*) = bounds;
                true $(&& $Elem::is_in_bounds($a, $b))*
            }
            // Each element builds its own value; `template` only provides the shape.
            fn from_int(template: &Self, #[comptime] value: i64) -> Self {
                let ($($a),*) = template;
                ($($Elem::from_int($a, value)),*)
            }
        }
    };
}
// Implements `Coordinates` for every primitive integer type listed.
//
// Each operation maps directly onto the primitive's own arithmetic or comparison.
// A single repetition arm covers zero, one or many types.
macro_rules! impl_coordinates_primitive {
    ($($prim: ty),*) => {
        $(
            #[cube]
            impl Coordinates for $prim {
                fn add(lhs: Self, rhs: Self) -> Self {
                    lhs + rhs
                }
                fn sub(lhs: Self, rhs: Self) -> Self {
                    lhs - rhs
                }
                fn min(lhs: Self, rhs: Self) -> Self {
                    lhs.min(rhs)
                }
                fn max(lhs: Self, rhs: Self) -> Self {
                    lhs.max(rhs)
                }
                // Only the upper bound is checked. For signed types, a negative
                // `pos` is reported as in bounds.
                fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
                    pos < bounds
                }
                // A scalar has no components, so the template value is unused.
                fn from_int(_template: &Self, #[comptime] value: i64) -> Self {
                    <$prim as Numeric>::from_int(value)
                }
            }
        )*
    };
}
// Scalar coordinates: unsigned primitive integers.
impl_coordinates_primitive!(u8, u16, u32, u64, usize);
// Scalar coordinates: signed primitive integers.
impl_coordinates_primitive!(i8, i16, i32, i64);
// Tuple coordinates of arity 2 through 12.
all_tuples!(impl_coordinates_tuple, 2, 12, T, t, o);
// Dynamic-rank coordinates. Every operation is applied element by element in a
// loop marked `#[unroll]`.
//
// NOTE(review): the loop count comes from the first operand only. The second
// operand is presumably expected to have at least as many elements. Confirm
// with callers.
#[cube]
impl<T: Coordinates + Copy> Coordinates for Sequence<T> {
    fn add(lhs: Self, rhs: Self) -> Self {
        let count = lhs.len();
        let mut sum = Sequence::new();
        #[unroll]
        for axis in 0..count {
            sum.push(T::add(lhs[axis], rhs[axis]));
        }
        sum
    }
    fn sub(lhs: Self, rhs: Self) -> Self {
        let count = lhs.len();
        let mut difference = Sequence::new();
        #[unroll]
        for axis in 0..count {
            difference.push(T::sub(lhs[axis], rhs[axis]));
        }
        difference
    }
    fn min(lhs: Self, rhs: Self) -> Self {
        let count = lhs.len();
        let mut lowest = Sequence::new();
        #[unroll]
        for axis in 0..count {
            lowest.push(T::min(lhs[axis], rhs[axis]));
        }
        lowest
    }
    fn max(lhs: Self, rhs: Self) -> Self {
        let count = lhs.len();
        let mut highest = Sequence::new();
        #[unroll]
        for axis in 0..count {
            highest.push(T::max(lhs[axis], rhs[axis]));
        }
        highest
    }
    // In bounds only if every element is in bounds.
    fn is_in_bounds(pos: &Self, bounds: &Self) -> bool {
        let count = pos.len();
        let mut inside = true;
        #[unroll]
        for axis in 0..count {
            inside &= T::is_in_bounds(&pos[axis], &bounds[axis]);
        }
        inside
    }
    // Builds a sequence as long as `template`, filled element-wise from `value`.
    fn from_int(template: &Self, #[comptime] value: i64) -> Self {
        let count = template.len();
        let mut filled = Sequence::new();
        #[unroll]
        for axis in 0..count {
            filled.push(T::from_int(&template[axis], value));
        }
        filled
    }
}