use cubecl_core as cubecl;
use cubecl_core::prelude::*;
/// Reader state that pairs a tensor map with the offsets used to address it.
///
/// NOTE(review): judging by the name, this is presumably the im2col-style
/// input reader of a TMA-based convolution kernel; the actual load logic is
/// not part of this type's inherent methods shown here, so confirm against
/// the call sites.
#[derive(CubeType)]
pub struct Im2colTmaReader<E: Numeric> {
/// Tensor map over lines of `E` that this reader wraps.
pub tensor: TensorMap<Line<E>>,
/// Offset stored as given to `new`; presumably along the batch (`n`)
/// dimension (TODO confirm).
pub n_offset: u32,
/// One `u32` offset per entry; presumably one per spatial dimension
/// (TODO confirm ordering against the caller).
pub spatial_offsets: Sequence<u32>,
/// Running offset along `k`; incremented in place by `update_view`.
pub k_offset: u32,
}
#[cube]
impl<E: Numeric> Im2colTmaReader<E> {
    /// Builds a reader from a tensor map and its initial offsets.
    ///
    /// Each argument is moved unchanged into the field of the same name; no
    /// validation or computation is performed here.
    pub fn new(
        tensor: TensorMap<Line<E>>,
        n_offset: u32,
        spatial_offsets: Sequence<u32>,
        k_offset: u32,
    ) -> Im2colTmaReader<E> {
        Im2colTmaReader::<E> {
            tensor,
            n_offset,
            spatial_offsets,
            k_offset,
        }
    }

    /// Advances the reader by adding `k_offset` to the stored `k_offset`.
    ///
    /// The argument is a relative step, not an absolute position: calling
    /// this repeatedly accumulates the steps. No other field is modified.
    pub fn update_view(&mut self, k_offset: u32) {
        self.k_offset += k_offset;
    }
}
// SAFETY: NOTE(review): the justification for these impls is not visible in
// this file. Presumably the reader is only a kernel-side handle and is never
// actually shared or sent across host threads while mutated — confirm that
// `TensorMap` and `Sequence` uphold this before relying on it.
unsafe impl<E: Numeric> Sync for Im2colTmaReader<E> {}
unsafe impl<E: Numeric> Send for Im2colTmaReader<E> {}