use alloc::vec::Vec;
use cubecl_zspace::Shape;
#[cfg(any(target_os = "windows", target_os = "linux", target_os = "macos"))]
use serde::{Deserialize, Serialize};
/// Parameters for the [`TensorMapFormat::Tiled`] tensor-map format.
///
/// Serde's `Serialize`/`Deserialize` are derived only when targeting Windows, Linux or
/// macOS; on every other target the type has no serde impls.
#[cfg_attr(
    any(target_os = "windows", target_os = "linux", target_os = "macos"),
    derive(Serialize, Deserialize)
)]
#[derive(new, Debug, Clone, PartialEq, Eq, Hash)]
pub struct TiledArgs {
    /// Extent of the tile described by the tensor map.
    ///
    /// NOTE(review): presumably one entry per tensor dimension (the box fetched by a single
    /// copy) — confirm against the backend that consumes this.
    pub tile_size: Shape,
}
/// Parameters for the [`TensorMapFormat::Im2col`] tensor-map format.
///
/// The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2colArgs {
    /// Lower corner of the pixel bounding box, as signed offsets.
    ///
    /// NOTE(review): presumably one entry per spatial dimension of the input — confirm.
    pub pixel_box_lower_corner: Vec<i32>,
    /// Upper corner of the pixel bounding box, as signed offsets.
    ///
    /// NOTE(review): presumably the same length as `pixel_box_lower_corner` — confirm.
    pub pixel_box_upper_corner: Vec<i32>,
    /// Channels handled per pixel (meaning inferred from the field name — confirm).
    pub channels_per_pixel: u32,
    /// Pixels handled per column (meaning inferred from the field name — confirm).
    pub pixels_per_column: u32,
}
/// Parameters for the [`TensorMapFormat::Im2colWide`] tensor-map format.
///
/// Unlike [`Im2colArgs`], the pixel-box corners are single scalars for the width dimension
/// rather than vectors. The derived `new` constructor takes the fields in declaration order.
#[derive(new, Debug, Clone, PartialEq, Eq, Hash)]
pub struct Im2colWideArgs {
    /// Lower corner of the pixel bounding box along the width dimension.
    pub pixel_box_lower_corner_width: i32,
    /// Upper corner of the pixel bounding box along the width dimension.
    pub pixel_box_upper_corner_width: i32,
    /// Channels handled per pixel (meaning inferred from the field name — confirm).
    pub channels_per_pixel: u32,
    /// Pixels handled per column (meaning inferred from the field name — confirm).
    pub pixels_per_column: u32,
}
/// Format of a tensor map, bundled with the parameters specific to that format.
///
/// NOTE(review): `TiledArgs` derives serde's `Serialize`/`Deserialize` on Windows/Linux/macOS,
/// but this enum and the im2col argument structs do not — confirm the asymmetry is intended.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorMapFormat {
    /// Tiled format; parameters are given by [`TiledArgs`].
    Tiled(TiledArgs),
    /// Im2col format; parameters are given by [`Im2colArgs`].
    Im2col(Im2colArgs),
    /// Wide im2col format; parameters are given by [`Im2colWideArgs`].
    Im2colWide(Im2colWideArgs),
}
/// Interleave mode for a tensor map; [`TensorMapInterleave::None`] is the default.
///
/// NOTE(review): the variants appear to mirror CUDA's `CUtensorMapInterleave`
/// (`NONE` / `16B` / `32B`) — confirm before relying on that mapping.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TensorMapInterleave {
    /// No interleaving. This is the `Default` variant.
    #[default]
    None,
    /// 16-byte interleave (size taken from the variant name).
    B16,
    /// 32-byte interleave (size taken from the variant name).
    B32,
}
/// Swizzle pattern for a tensor map; [`TensorMapSwizzle::None`] is the default.
///
/// NOTE(review): the variants appear to mirror CUDA's `CUtensorMapSwizzle`
/// (`NONE`, `32B`, `64B`, `128B`, `128B_ATOM_32B`, `128B_ATOM_32B_FLIP_8B`,
/// `128B_ATOM_64B`) — confirm before relying on that mapping.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TensorMapSwizzle {
    /// No swizzling. This is the `Default` variant.
    #[default]
    None,
    /// 32-byte swizzle span (size taken from the variant name).
    B32,
    /// 64-byte swizzle span (size taken from the variant name).
    B64,
    /// 128-byte swizzle span (size taken from the variant name).
    B128,
    /// 128-byte span with a 32-byte atom (per the variant name).
    B128Atom32B,
    /// 128-byte span with a 32-byte atom and an 8-byte flip (per the variant name).
    B128Atom32BFlip8B,
    /// 128-byte span with a 64-byte atom (per the variant name).
    B128Atom64B,
}
/// Prefetch size for a tensor map; [`TensorMapPrefetch::None`] is the default.
///
/// NOTE(review): the variants appear to mirror CUDA's `CUtensorMapL2promotion`
/// (`NONE`, `L2_64B`, `L2_128B`, `L2_256B`) — confirm before relying on that mapping.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum TensorMapPrefetch {
    /// No extra prefetch. This is the `Default` variant.
    #[default]
    None,
    /// 64-byte prefetch (size taken from the variant name).
    B64,
    /// 128-byte prefetch (size taken from the variant name).
    B128,
    /// 256-byte prefetch (size taken from the variant name).
    B256,
}
/// Fill value for out-of-bounds elements of a tensor map; [`OobFill::Zero`] is the default.
///
/// NOTE(review): presumably mirrors CUDA's `CUtensorMapFloatOOBfill`
/// (`NONE` vs `NAN_REQUEST_ZERO_FMA`) — confirm before relying on that mapping.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum OobFill {
    /// Fill with zero. This is the `Default` variant.
    #[default]
    Zero,
    /// Fill with NaN.
    NaN,
}