cubecl_matmul/components/global/memory/
iterator.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5#[derive(Clone, CubeType)]
6pub struct GlobalIterator<EI: CubePrimitive> {
8 global_view: View<EI, Coords2d>,
9 offset: RuntimeCell<u32>,
10 step: u32,
12 view_size: Coords2d,
13 #[cube(comptime)]
14 view_direction: ViewDirection,
15 #[cube(comptime)]
16 checked: bool,
17}
18
19unsafe impl<EG: CubePrimitive> Sync for GlobalIterator<EG> {}
20unsafe impl<EG: CubePrimitive> Send for GlobalIterator<EG> {}
21
22#[derive(CubeType, Clone, Copy)]
23pub enum ViewDirection {
24 Row,
25 Col,
26 None,
28}
29
30#[cube]
31impl<EG: CubePrimitive> GlobalIterator<EG> {
32 pub fn new(
38 global_view: View<EG, Coords2d>,
39 step: u32,
40 #[comptime] view_direction: ViewDirection,
41 #[comptime] checked: bool,
42 ) -> Self {
43 let (size_row, size_col) = global_view.shape();
44 let view_size = match view_direction {
45 ViewDirection::Row => (step, size_col),
46 ViewDirection::Col => (size_row, step),
47 ViewDirection::None => (size_row, size_col),
48 };
49
50 GlobalIterator::<EG> {
51 global_view,
52 offset: RuntimeCell::new(0),
53 step,
54 view_size,
55 view_direction,
56 checked,
57 }
58 }
59
60 pub fn advance(&self) {
62 self.offset.store(self.offset.read() + self.step);
63 }
64
65 pub fn view(&self) -> View<EG, Coords2d> {
67 let offset = match comptime![self.view_direction] {
68 ViewDirection::Row => (self.offset.read(), 0u32),
69 ViewDirection::Col => (0u32, self.offset.read()),
70 ViewDirection::None => (0u32, 0u32).runtime(),
71 };
72 if comptime![self.checked] {
73 self.global_view.slice(offset, self.view_size)
74 } else {
75 self.global_view.slice_unchecked(offset, self.view_size)
76 }
77 }
78
79 pub fn line_size(&self) -> comptime_type!(u32) {
81 self.global_view.line_size()
82 }
83}