cubecl-std 0.7.0

CubeCL Standard Library.
Documentation
use core::marker::PhantomData;

use cubecl::prelude::*;
use cubecl_core::{self as cubecl, unexpanded};

/// Reinterprets a slice of `Line<S>` as a slice of `T`.
///
/// Semantically, this is equivalent to reinterpreting the bytes of the slice of `Line<S>`
/// as a slice of `T`. When indexing, the index is valid in the reinterpreted (cast) list,
/// i.e. it counts values of `T`, not lines of `S`.
///
/// # Warning
///
/// Currently, this only works with `cube(launch_unchecked)` and is not supported on wgpu.
#[derive(CubeType)]
pub struct ReinterpretSlice<S: CubePrimitive, T: CubePrimitive> {
    // Underlying storage; `new` may re-view it with a smaller line size so that one line is
    // exactly one `T`.
    slice: Slice<Line<S>>,

    // Line size of the elements of `slice` as seen through this view (after any adjustment
    // made by `new`).
    #[cube(comptime)]
    line_size: u32,

    // When `Some(n)`, one `T` is assembled from `n` consecutive lines of `slice`.
    #[cube(comptime)]
    load_many: Option<u32>,

    #[cube(comptime)]
    _phantom: PhantomData<T>,
}

#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSlice<S, T> {
    /// Wraps `slice` so that it can be read as a slice of `T`.
    ///
    /// `line_size` is the comptime line size of the elements of `slice`; `read` indexes each
    /// element up to this size. Depending on the sizes of `Line<S>` and `T`, the slice is:
    /// kept as is (one line is one `T`), re-viewed with a smaller line size (one line becomes
    /// one `T`), or flagged so that several consecutive lines are combined into one `T` on
    /// every read.
    ///
    /// # Panics
    ///
    /// Panics during kernel expansion (inside `comptime!`) when the sizes of `Line<S>` and `T`
    /// are incompatible.
    pub fn new(slice: Slice<Line<S>>, #[comptime] line_size: u32) -> ReinterpretSlice<S, T> {
        let source_size = size_of::<S>();
        let target_size = size_of::<T>();
        // The layout decision is made at comptime from the element sizes only.
        let (optimized_line_size, load_many) =
            comptime!(optimize_line_size(source_size, line_size, target_size));
        match comptime!(optimized_line_size) {
            // A source line is larger than `T`: re-view the slice with the smaller line size.
            // The shadowed `line_size` (the optimized one) is what gets stored.
            Some(line_size) => ReinterpretSlice::<S, T> {
                slice: slice.with_line_size(line_size),
                line_size,
                load_many,
                _phantom: PhantomData,
            },
            // Either one line is exactly one `T`, or `load_many` records how many lines form
            // one `T`; the slice keeps its original line size.
            None => ReinterpretSlice::<S, T> {
                slice,
                line_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the `T` at `index`, where `index` counts values of `T` (not lines of `S`).
    pub fn read(&self, index: u32) -> T {
        match comptime!(self.load_many) {
            Some(amount) => {
                // One `T` spans `amount` consecutive lines: copy them element by element into
                // a single wider line, then reinterpret that line as `T`.
                let first = index * amount;
                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
                #[unroll]
                for k in 0..amount {
                    let elem = self.slice[first + k];
                    #[unroll]
                    for j in 0..self.line_size {
                        line[k * self.line_size + j] = elem[j];
                    }
                }
                T::reinterpret(line)
            }
            // One line is exactly one `T`.
            None => T::reinterpret(self.slice[index]),
        }
    }
}

/// Reinterprets a mutable slice of `Line<S>` as a mutable slice of `T`.
///
/// Semantically, this is equivalent to reinterpreting the bytes of the slice of `Line<S>`
/// as a mutable slice of `T`. When indexing, the index is valid in the reinterpreted (cast)
/// list, i.e. it counts values of `T`, not lines of `S`.
///
/// # Warning
///
/// Currently, this only works with `cube(launch_unchecked)` and is not supported on wgpu.
#[derive(CubeType)]
pub struct ReinterpretSliceMut<S: CubePrimitive, T: CubePrimitive> {
    // Underlying storage; `new` may re-view it with a smaller line size so that one line is
    // exactly one `T`.
    slice: SliceMut<Line<S>>,

    // Line size of the elements of `slice` as seen through this view (after any adjustment
    // made by `new`).
    #[cube(comptime)]
    line_size: u32,

    // When `Some(n)`, one `T` is spread over `n` consecutive lines of `slice`.
    #[cube(comptime)]
    load_many: Option<u32>,

    #[cube(comptime)]
    _phantom: PhantomData<T>,
}

#[cube]
impl<S: CubePrimitive, T: CubePrimitive> ReinterpretSliceMut<S, T> {
    /// Wraps `slice` so that it can be read and written as a slice of `T`.
    ///
    /// `line_size` is the comptime line size of the elements of `slice`; `read` indexes each
    /// element up to this size. Depending on the sizes of `Line<S>` and `T`, the slice is:
    /// kept as is (one line is one `T`), re-viewed with a smaller line size (one line becomes
    /// one `T`), or flagged so that one `T` is spread over several consecutive lines.
    ///
    /// # Panics
    ///
    /// Panics during kernel expansion (inside `comptime!`) when the sizes of `Line<S>` and `T`
    /// are incompatible.
    pub fn new(slice: SliceMut<Line<S>>, #[comptime] line_size: u32) -> ReinterpretSliceMut<S, T> {
        let source_size = size_of::<S>();
        let target_size = size_of::<T>();
        // The layout decision is made at comptime from the element sizes only.
        let (optimized_line_size, load_many) =
            comptime!(optimize_line_size(source_size, line_size, target_size));
        match comptime!(optimized_line_size) {
            // A source line is larger than `T`: re-view the slice with the smaller line size.
            // The shadowed `line_size` (the optimized one) is what gets stored.
            Some(line_size) => ReinterpretSliceMut::<S, T> {
                slice: slice.with_line_size(line_size),
                line_size,
                load_many,
                _phantom: PhantomData,
            },
            // Either one line is exactly one `T`, or `load_many` records how many lines form
            // one `T`; the slice keeps its original line size.
            None => ReinterpretSliceMut::<S, T> {
                slice,
                line_size,
                load_many,
                _phantom: PhantomData,
            },
        }
    }

    /// Reads the `T` at `index`, where `index` counts values of `T` (not lines of `S`).
    pub fn read(&self, index: u32) -> T {
        match comptime!(self.load_many) {
            Some(amount) => {
                // One `T` spans `amount` consecutive lines: copy them element by element into
                // a single wider line, then reinterpret that line as `T`.
                let first = index * amount;
                let mut line = Line::<S>::empty(comptime!(amount * self.line_size));
                #[unroll]
                for k in 0..amount {
                    let elem = self.slice[first + k];
                    #[unroll]
                    for j in 0..self.line_size {
                        line[k * self.line_size + j] = elem[j];
                    }
                }
                T::reinterpret(line)
            }
            // One line is exactly one `T`.
            None => T::reinterpret(self.slice[index]),
        }
    }

    /// Writes `value` at `index`, where `index` counts values of `T` (not lines of `S`).
    pub fn write(&mut self, index: u32, value: T) {
        // View the value as a line of `S` so it can be stored into the underlying slice.
        let reinterpreted = Line::<S>::reinterpret(value);
        match comptime!(self.load_many) {
            Some(amount) => {
                // One `T` spans `amount` consecutive lines: split the reinterpreted line into
                // `amount` equally sized chunks and store each chunk in its own slot.
                let first = index * amount;
                let line_size = comptime!(reinterpreted.size() / amount);

                #[unroll]
                for k in 0..amount {
                    let mut line = Line::empty(line_size);
                    #[unroll]
                    for j in 0..line_size {
                        line[j] = reinterpreted[k * line_size + j];
                    }
                    self.slice[first + k] = line;
                }
            }
            // One line is exactly one `T`.
            None => self.slice[index] = reinterpreted,
        }
    }
}

/// Computes how a slice of `Line<S>` must be viewed so that each logical index maps to
/// exactly one value of the target type.
///
/// `source_size` is the size of one `S` element, and `line_size` is the number of elements
/// per line of the incoming slice. `target_size` is the size of the target type. All sizes
/// use the unit returned by `size_of`.
///
/// Returns `(optimized_line_size, load_many)`:
/// - `(None, None)` when one source line is exactly one target value;
/// - `(None, Some(n))` when `n` consecutive source lines must be concatenated to form one
///   target value;
/// - `(Some(l), None)` when the slice must be re-viewed with `l` elements per line so that
///   one line is exactly one target value (`l` is always at least 1).
///
/// # Panics
///
/// Panics when the sizes cannot be matched, either by grouping whole source lines or by
/// splitting a source line into whole source elements. One example is a target type that
/// is smaller than a single source element.
fn optimize_line_size(
    source_size: u32,
    line_size: u32,
    target_size: u32,
) -> (Option<u32>, Option<u32>) {
    let line_source_size = source_size * line_size;
    match line_source_size.cmp(&target_size) {
        core::cmp::Ordering::Less => {
            if !target_size.is_multiple_of(line_source_size) {
                panic!("incompatible number of bytes");
            }

            let ratio = target_size / line_source_size;

            (None, Some(ratio))
        }
        core::cmp::Ordering::Greater => {
            if !line_source_size.is_multiple_of(target_size) {
                panic!("incompatible number of bytes");
            }
            let ratio = line_source_size / target_size;

            // The re-viewed line size is `target_size / source_size` and must be a whole,
            // non-zero number of `S` elements. Otherwise `line_size / ratio` would truncate,
            // possibly to zero, and a line would no longer hold exactly one target value.
            if !line_size.is_multiple_of(ratio) {
                panic!("target type is not a whole number of source elements");
            }

            (Some(line_size / ratio), None)
        }
        core::cmp::Ordering::Equal => (None, None),
    }
}

/// Returns the size of `S` as reported by its storage type.
///
/// Only usable from `#[cube]` code, where the call dispatches to `size_of::expand`.
/// Calling it directly reaches `unexpanded!()`.
pub fn size_of<S: CubePrimitive>() -> u32 {
    unexpanded!()
}

/// Expansion counterpart of `size_of`, invoked from `#[cube]` code.
pub mod size_of {
    use super::*;

    /// Resolves the storage type of `S` in `context` and returns its size as a `u32`.
    // Written in the shape of a macro-generated expand function, so lints that would fire
    // on generated code are silenced the same way.
    #[allow(unused, clippy::all)]
    pub fn expand<S: CubePrimitive>(context: &mut cubecl::prelude::Scope) -> u32 {
        let storage_type = S::as_type(context);
        storage_type.size() as u32
    }
}