use std::marker::PhantomData;
use cubecl_core::ir::Elem;
use cubecl_core::{
Runtime,
client::ComputeClient,
prelude::{Numeric, TensorHandleRef},
};
use cubecl_matmul::components::MatmulSelection;
use crate::{
base::ConvolutionProblem, homogeneous::simple::SimpleConvolutionFamily,
selection::convolution_matmul_selection,
};
use cubecl_matmul::components::stage::NumStages;
use cubecl_matmul::components::{
InputIdent,
global::args::TensorArgs,
stage::{FullReaderFamily, PlaneMatmulFamily},
tile::TileMatmulFamily,
};
use cubecl_std::tensor::{TensorHandle, into_contiguous};
use super::Algorithm;
/// Convolution algorithm selector parameterised by a tile matmul family.
///
/// The type holds no runtime data: its only field is a [`PhantomData`] marker
/// that ties the `TMM` type parameter to the struct so it can be used in the
/// associated types of its `Algorithm` implementation.
pub struct SimpleConvAlgorithm<TMM>
where
    TMM: TileMatmulFamily,
{
    /// Zero-sized marker carrying the tile matmul family type.
    _tmm: PhantomData<TMM>,
}
impl<TMM: TileMatmulFamily> Algorithm for SimpleConvAlgorithm<TMM> {
    type TileMatmul = TMM;
    type StageMatmul = PlaneMatmulFamily<Self::TileMatmul, FullReaderFamily, FullReaderFamily>;
    type GlobalConvolution = SimpleConvolutionFamily<Self::StageMatmul>;
    type Args = TensorArgs;

    /// Builds the owned tensor handle handed to the kernel for `ident`.
    ///
    /// When `has_valid_layout` accepts the incoming handle it is wrapped as-is
    /// with `TensorHandle::from_ref`; otherwise the result of `into_contiguous`
    /// is used instead.
    ///
    /// * `Lhs` is returned without further changes.
    /// * `Rhs` is re-described as a rank-2 tensor whose shape is
    ///   `[product of dims 1.., dim 0]` and whose strides are
    ///   `[stride of the last dim, stride of dim 0]`.
    ///
    /// # Panics
    ///
    /// Panics if the tensor has rank 0 (the last-axis index cannot be computed),
    /// or if indexing the shape/strides goes out of bounds.
    fn into_tensor_handle<R: Runtime, E: Numeric>(
        client: &ComputeClient<R::Server, R::Channel>,
        handle: &TensorHandleRef<'_, R>,
        ident: cubecl_matmul::components::InputIdent,
    ) -> TensorHandle<R, E> {
        // Index of the last axis of the *incoming* tensor; captured before any
        // reshaping so the Rhs branch can still pick that axis' stride.
        let last_axis = handle.shape.len() - 1;

        let tensor = if has_valid_layout(handle, ident) {
            TensorHandle::from_ref(handle)
        } else {
            into_contiguous(client, handle)
        };

        match ident {
            InputIdent::Lhs => tensor,
            InputIdent::Rhs => {
                let mut weights = tensor;

                // Collapse every axis after the first into a single row axis.
                let rows: usize = weights.shape[1..].iter().product();
                let cols = weights.shape[0];
                let row_stride = weights.strides[last_axis];
                let col_stride = weights.strides[0];

                weights.shape = vec![rows, cols];
                weights.strides = vec![row_stride, col_stride];
                weights
            }
        }
    }

    /// Number of stages used by this algorithm: the pair `(1, 1)` converted
    /// into a `NumStages`.
    fn num_stages() -> NumStages {
        (1, 1).into()
    }

    /// Picks the matmul selection for `problem` by delegating to
    /// `convolution_matmul_selection` with this algorithm's tile matmul family.
    fn selection<R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
        problem: &ConvolutionProblem,
        plane_dim: u32,
        elem_stage: Elem,
        elem_acc: Elem,
    ) -> MatmulSelection {
        convolution_matmul_selection::<TMM, R>(client, problem, plane_dim, elem_stage, elem_acc)
    }
}
/// Reports whether `handle` can be used directly for `ident`, without first
/// being copied into a contiguous buffer.
///
/// * `Lhs`: the stride of the last dimension must be `1`.
/// * `Rhs`: the strides must be non-increasing from the first to the last
///   dimension, and every dimension from index 1 up to the last must be packed
///   against its successor (`strides[i] == strides[i + 1] * shape[i + 1]`), so
///   that those dimensions can be addressed as one merged axis using the last
///   dimension's stride.
///
/// # Panics
///
/// Panics if the tensor has rank 0, or if `strides` is shorter than `shape`.
fn has_valid_layout<R: Runtime>(handle: &TensorHandleRef<'_, R>, ident: InputIdent) -> bool {
    let rank = handle.shape.len();
    let dim_c = rank - 1;
    match ident {
        InputIdent::Lhs => handle.strides[dim_c] == 1,
        InputIdent::Rhs => {
            let shape = handle.shape;
            let strides = handle.strides;

            // The packing relation below is a row-major one (outer stride is a
            // multiple of the inner stride), so the matching order check is
            // "non-increasing". Requiring ascending strides here would make the
            // two conditions mutually exclusive for any extent greater than 1.
            let ordered = strides.windows(2).all(|pair| pair[0] >= pair[1]);

            // Dimension 0 is deliberately excluded: it keeps its own stride
            // when the tensor is later viewed as a rank-2 matrix.
            let contiguous_k = (1..dim_c).all(|i| strides[i] == strides[i + 1] * shape[i + 1]);

            ordered && contiguous_k
        }
    }
}