use cubecl::prelude::*;
use cubecl_core::{self as cubecl};
use crate::tensor::{
launch::{BufferArg, ViewLayoutLaunchArg},
layout::{Coords1d, Layout, LayoutExpand},
};
/// Layout over `Coords1d` coordinates that is described only by a length.
///
/// Its `Layout` implementation returns positions unchanged as source
/// positions and treats `len` as the exclusive upper bound for bounds checks.
#[derive(CubeType, Clone)]
pub struct PlainLayout {
/// Length of the layout; reported as its shape and used for bounds checks.
len: usize,
}
#[cube]
impl PlainLayout {
/// Creates a layout whose length is `len`.
// NOTE(review): the launch-argument impl calls `PlainLayout::__expand_new`,
// presumably the expand-time counterpart generated by `#[cube]` from this
// method — confirm against the macro's output.
pub fn new(len: usize) -> Self {
PlainLayout { len }
}
}
/// Launch-side plumbing for [`PlainLayout`].
///
/// The layout takes no explicit runtime or compilation argument (both are `()`);
/// the single `usize` it needs is derived from the buffer it is registered with.
impl ViewLayoutLaunchArg for PlainLayout {
    type RuntimeArg<R: Runtime> = ();
    type CompilationArg = ();

    /// Registers the layout length with `launcher` as a `usize` launch argument.
    ///
    /// The registered value is `buffer.len()` divided by the vector size of `ty`.
    fn register<R: Runtime, B: BufferArg>(
        _runtime_arg: Self::RuntimeArg<R>,
        buffer: &B,
        ty: Type,
        launcher: &mut KernelLauncher<R>,
    ) {
        let vector_count = buffer.len() / ty.vector_size();
        <usize as LaunchArg>::register(vector_count, launcher);
    }

    /// Expands the `usize` launch argument through `builder` and builds the
    /// expand-time layout from it.
    fn expand(
        _compilation_arg: &Self::CompilationArg,
        _ty: Type,
        builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
        // The length must be expanded before `builder.scope` is borrowed mutably.
        let expanded_len = <usize as LaunchArg>::expand(&(), builder);
        PlainLayout::__expand_new(&mut builder.scope, expanded_len)
    }
}
// `Layout` implementation in which a coordinate and its source position are
// the same value, with `len` as the exclusive upper bound.
#[cube]
impl Layout for PlainLayout {
// Both the layout's coordinates and the source coordinates are `Coords1d`.
type Coordinates = Coords1d;
type SourceCoordinates = Coords1d;
// Returns `pos` unchanged; no bounds check is performed here.
fn to_source_pos(&self, pos: Self::Coordinates) -> usize {
pos
}
// Returns the source position together with the result of `is_in_bounds`
// for the same `pos`.
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (usize, bool) {
(self.to_source_pos(pos), self.is_in_bounds(pos))
}
// The shape is the stored length.
fn shape(&self) -> Self::Coordinates {
self.len
}
// A position is in bounds when it is strictly less than the stored length.
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
pos < self.len
}
}