cubecl_matmul/components/global/memory/iterator.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::{View, layout::Coords2d};
4
5#[derive(Clone, CubeType)]
6/// An iterator over global memory, advancing along k.
7pub struct GlobalIterator<EI: CubePrimitive> {
8    global_view: View<EI, Coords2d>,
9    offset: RuntimeCell<u32>,
10    /// The amount to advance by on each iteration
11    step: u32,
12    view_size: Coords2d,
13    #[cube(comptime)]
14    view_direction: ViewDirection,
15    #[cube(comptime)]
16    checked: bool,
17}
18
19unsafe impl<EG: CubePrimitive> Sync for GlobalIterator<EG> {}
20unsafe impl<EG: CubePrimitive> Send for GlobalIterator<EG> {}
21
22#[derive(CubeType, Clone, Copy)]
23pub enum ViewDirection {
24    Row,
25    Col,
26    /// Cannot advance if direction is none
27    None,
28}
29
#[cube]
impl<EG: CubePrimitive> GlobalIterator<EG> {
    /// Instantiate a read iterator over the given global view, which should be sliced to the size
    /// of one `m`/`n` stage and the full range of `k` handled by this matmul instance.
    ///
    /// `step` is the amount advanced in `view_direction` each iteration. It is also the extent
    /// of every returned window along that direction.
    /// `checked` determines whether the slices should be created as checked or unchecked.
    ///
    /// The window size is fixed here from the view's shape `(rows, cols)`:
    /// `Row` gives `(step, cols)`, `Col` gives `(rows, step)`, and `None` gives the full
    /// `(rows, cols)`. The offset starts at 0.
    pub fn new(
        global_view: View<EG, Coords2d>,
        step: u32,
        #[comptime] view_direction: ViewDirection,
        #[comptime] checked: bool,
    ) -> Self {
        let (size_row, size_col) = global_view.shape();
        // Window extent: `step` along the direction of travel, the full view extent on the other axis.
        let view_size = match view_direction {
            ViewDirection::Row => (step, size_col),
            ViewDirection::Col => (size_row, step),
            ViewDirection::None => (size_row, size_col),
        };

        GlobalIterator::<EG> {
            global_view,
            offset: RuntimeCell::new(0),
            step,
            view_size,
            view_direction,
            checked,
        }
    }

    /// Advance the iterator by `step` along its `view_direction`.
    ///
    /// The offset is updated in place through `&self` (it lives in a `RuntimeCell`). No bounds
    /// check is done here; out-of-range windows are only guarded when `checked` is set.
    /// With `ViewDirection::None`, the offset still grows but `view` ignores it.
    pub fn advance(&self) {
        self.offset.store(self.offset.read() + self.step);
    }

    /// Returns the current view slice of the iterator.
    ///
    /// The window starts at the accumulated offset along `view_direction` (or at `(0, 0)` for
    /// `None`) and has the size computed in `new`. It is built with `slice` when `checked` is
    /// set and with `slice_unchecked` otherwise.
    /// NOTE(review): how `slice` treats out-of-bounds regions is defined by `View`. Confirm there.
    pub fn view(&self) -> View<EG, Coords2d> {
        let offset = match comptime![self.view_direction] {
            ViewDirection::Row => (self.offset.read(), 0u32),
            ViewDirection::Col => (0u32, self.offset.read()),
            // Presumably `.runtime()` makes this constant tuple a runtime value, matching the
            // other arms, which read the runtime offset. TODO confirm.
            ViewDirection::None => (0u32, 0u32).runtime(),
        };
        if comptime![self.checked] {
            self.global_view.slice(offset, self.view_size)
        } else {
            self.global_view.slice_unchecked(offset, self.view_size)
        }
    }

    /// Returns the line size of the global view, as a comptime value.
    pub fn line_size(&self) -> comptime_type!(u32) {
        self.global_view.line_size()
    }
}