use cubecl::{
prelude::*,
std::tensor::layout::{Coords2d, Layout, LayoutExpand},
};
use cubek_matmul::launch::BatchedCoords;
/// Layout that maps batched `(batch, row, col)` coordinates onto 2D `(row, col)`
/// coordinates by discarding the batch index.
///
/// The type has no fields, so it carries no shape or stride information of its own.
/// Its `Layout` impl reports every position as in bounds and reports an effectively
/// unbounded shape (`u32::MAX` in every dimension).
///
/// NOTE(review): the name suggests this describes the output-gradient operand of a
/// TMA-backed kernel. If so, the tensor-map descriptor (or the caller) presumably
/// handles batch offsets and out-of-bounds accesses; confirm against the call sites.
#[derive(CubeType, CubeLaunch)]
pub struct TmaOutGradLayout {}
// Pass-through layout: batched coordinates are mapped straight to 2D `(row, col)`.
// It does no bounds checking and holds no shape information of its own.
#[cube]
impl Layout for TmaOutGradLayout {
    // Input positions are batched `(batch, row, col)` triples.
    type Coordinates = BatchedCoords;
    // Output positions are plain `(row, col)` pairs.
    type SourceCoordinates = Coords2d;

    // Drops the batch component and forwards `(row, col)` unchanged.
    // NOTE(review): no batch offset is applied here. The batch index is presumably
    // resolved elsewhere (e.g. by the TMA descriptor or a caller-side base offset);
    // confirm.
    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
        let (_, row, col) = pos;
        (row, col)
    }

    // Always reports the position as in bounds; this layout performs no check.
    // NOTE(review): presumably relies on TMA handling out-of-bounds accesses in
    // hardware; confirm.
    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
        true.runtime()
    }

    // Reports `u32::MAX` in every dimension, i.e. an effectively unbounded shape,
    // because the layout stores no real extents. The batch extent is `u32::MAX`
    // widened to `usize` to match the batch component's type.
    fn shape(&self) -> Self::Coordinates {
        (u32::MAX as usize, u32::MAX, u32::MAX).runtime()
    }

    // Pairs the mapped position with the bounds flag. `is_in_bounds` always returns
    // true, so the flag is always true here.
    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
        (self.to_source_pos(pos), self.is_in_bounds(pos))
    }
}