use cubecl::{
Runtime, VectorizationError,
client::ComputeClient,
ir::VectorSize,
tensor_vector_size_parallel,
zspace::{Shape, Strides},
};
use cubek_std::MatrixLayout;
use std::fmt::Debug;
use crate::definition::error::MatmulSetupError;
/// One concrete vector size per matmul operand.
///
/// This is the value produced by `AvailableVectorSizes::pick_max` once a single size has been
/// settled on for each of the three operands.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MatmulVectorSizes {
    /// Vector size chosen for the left-hand-side operand.
    pub lhs: VectorSize,
    /// Vector size chosen for the right-hand-side operand.
    pub rhs: VectorSize,
    /// Vector size chosen for the output operand.
    pub out: VectorSize,
}
/// Candidate vector sizes for each matmul operand, before a final choice is made.
///
/// Each list is narrowed independently by the `filter_*` methods; `pick_max` then takes the
/// largest remaining entry of every list.
#[derive(Debug, Clone)]
pub struct AvailableVectorSizes {
    /// Candidates still allowed for the left-hand-side operand.
    pub lhs: Vec<VectorSize>,
    /// Candidates still allowed for the right-hand-side operand.
    pub rhs: Vec<VectorSize>,
    /// Candidates still allowed for the output operand.
    pub out: Vec<VectorSize>,
}
impl AvailableVectorSizes {
    /// Builds the candidate lists for the `tma` variant.
    ///
    /// `lhs` and `rhs` each get the single candidate `1`; `out` is filled from
    /// `client.io_optimized_vector_sizes(elem_out)`.
    pub fn from_type_size_tma<R: Runtime>(client: &ComputeClient<R>, elem_out: usize) -> Self {
        Self {
            lhs: vec![1],
            rhs: vec![1],
            out: client.io_optimized_vector_sizes(elem_out).collect(),
        }
    }

    /// Builds the candidate lists by asking `client` for its I/O-optimized vector sizes, once
    /// per operand, using that operand's element size.
    pub fn from_type_sizes<R: Runtime>(
        client: &ComputeClient<R>,
        elem_lhs: usize,
        elem_rhs: usize,
        elem_out: usize,
    ) -> Self {
        let query =
            |elem_size: usize| -> Vec<usize> { client.io_optimized_vector_sizes(elem_size).collect() };

        Self {
            lhs: query(elem_lhs),
            rhs: query(elem_rhs),
            out: query(elem_out),
        }
    }

    /// Maps a matrix layout to the tensor axis inspected for vectorization: the last axis for
    /// row-major, the second-to-last for column-major.
    ///
    /// NOTE(review): plain `usize` subtraction, so `rank` must be at least 1 (row-major) or 2
    /// (column-major) — confirm callers guarantee this.
    fn vectorization_axis(rank: usize, layout: MatrixLayout) -> usize {
        let offset_from_end = match layout {
            MatrixLayout::RowMajor => 1,
            MatrixLayout::ColMajor => 2,
        };
        rank - offset_from_end
    }

    /// Narrows the `lhs` candidates to the size `tensor_vector_size_parallel` reports for the
    /// given tensor along the axis implied by `layout`.
    ///
    /// Only candidates equal to the reported size are kept, so `lhs` becomes empty when that
    /// size is not among the current candidates. `rhs` and `out` are left untouched.
    pub fn filter_lhs_with_tensor(
        self,
        strides: &Strides,
        shape: &Shape,
        layout: MatrixLayout,
    ) -> Self {
        let axis = Self::vectorization_axis(strides.len(), layout);
        let chosen = tensor_vector_size_parallel(self.lhs.iter().copied(), shape, strides, axis);
        self.filter_lhs(move |candidate| *candidate == chosen)
    }

    /// Narrows the `rhs` candidates to the size `tensor_vector_size_parallel` reports for the
    /// given tensor along the axis implied by `layout`.
    ///
    /// Only candidates equal to the reported size are kept, so `rhs` becomes empty when that
    /// size is not among the current candidates. `lhs` and `out` are left untouched.
    pub fn filter_rhs_with_tensor(
        self,
        strides: &Strides,
        shape: &Shape,
        layout: MatrixLayout,
    ) -> Self {
        let axis = Self::vectorization_axis(strides.len(), layout);
        let chosen = tensor_vector_size_parallel(self.rhs.iter().copied(), shape, strides, axis);
        self.filter_rhs(move |candidate| *candidate == chosen)
    }

    /// Narrows the `out` candidates to the size `tensor_vector_size_parallel` reports for the
    /// given tensor along its last axis (`strides.len() - 1`).
    ///
    /// NOTE(review): the subtraction assumes `strides` is non-empty — confirm with callers.
    pub fn filter_out_with_tensor(self, strides: &Strides, shape: &Shape) -> Self {
        let last_axis = strides.len() - 1;
        let chosen =
            tensor_vector_size_parallel(self.out.iter().copied(), shape, strides, last_axis);
        self.filter_out(move |candidate| *candidate == chosen)
    }

    /// Keeps only the `lhs` candidates for which `pred` returns `true`, preserving their order.
    pub fn filter_lhs<F>(mut self, pred: F) -> Self
    where
        F: FnMut(&usize) -> bool,
    {
        self.lhs.retain(pred);
        self
    }

    /// Keeps only the `rhs` candidates for which `pred` returns `true`, preserving their order.
    pub fn filter_rhs<F>(mut self, pred: F) -> Self
    where
        F: FnMut(&usize) -> bool,
    {
        self.rhs.retain(pred);
        self
    }

    /// Keeps only the `out` candidates for which `pred` returns `true`, preserving their order.
    pub fn filter_out<F>(mut self, pred: F) -> Self
    where
        F: FnMut(&usize) -> bool,
    {
        self.out.retain(pred);
        self
    }

    /// Selects the largest remaining candidate for every operand.
    ///
    /// # Errors
    ///
    /// Returns `MatmulSetupError::Vectorization(VectorizationError::NoValidVectorization)` when
    /// any of the three candidate lists is empty.
    pub fn pick_max(self) -> Result<MatmulVectorSizes, MatmulSetupError> {
        /// Largest entry of `candidates`, or the "no valid vectorization" error when empty.
        fn largest(candidates: Vec<usize>) -> Result<usize, MatmulSetupError> {
            match candidates.into_iter().max() {
                Some(best) => Ok(best),
                None => Err(MatmulSetupError::Vectorization(
                    VectorizationError::NoValidVectorization,
                )),
            }
        }

        let Self { lhs, rhs, out } = self;
        let lhs = largest(lhs)?;
        let rhs = largest(rhs)?;
        let out = largest(out)?;
        Ok(MatmulVectorSizes { lhs, rhs, out })
    }
}