Skip to main content

cubek_std/tile/compute/copy.rs

1use cubecl;
2use cubecl::prelude::*;
3
4use crate::{
5    StageIdent,
6    tile::{
7        MmaFragment, MmaFragmentExpand, Tile, TileExpand, TileScope,
8        compute::matmul::{
9            cmma::{cmma_load_from_shared, cmma_load_zeros, cmma_write_to_shared},
10            interleaved::{
11                interleaved_load_from_shared, interleaved_load_zeros, interleaved_write_to_shared,
12            },
13            mma::{
14                mma_load_acc_from_shared, mma_load_acc_zeros, mma_load_lhs_from_shared,
15                mma_load_rhs_from_shared, mma_write_to_shared,
16            },
17            plane_vec::{planevec_load_from_shared, planevec_load_zeros, planevec_write_to_shared},
18            register::{register_load_from_shared, register_load_zeros, register_write_to_shared},
19        },
20        data::BounceTile,
21    },
22};
23
24/// Internal `copy_from` between the `cmma` and `fragment` parts of a
25/// [`BounceTile`]: cmma -> smem -> fragment. Used by the high-level
26/// `softmax` / `scale_mul` / `scale_div` methods to make the fragment view
27/// current.
28#[cube]
29pub(crate) fn cmma_to_whitebox_fragment<E: Float>(b: &mut BounceTile<E>) {
30    let stride = comptime!(b.cmma.tile_size.n());
31    cubecl::cmma::store(
32        &mut b.smem,
33        &b.cmma.matrix,
34        stride,
35        cubecl::cmma::MatrixLayout::RowMajor,
36    );
37    sync_cube();
38    b.fragment.load_from_slice(&b.smem.to_slice());
39    sync_cube();
40}
41
42/// Internal `copy_from` between the `fragment` and `cmma` parts of a
43/// [`BounceTile`]: fragment -> smem -> cmma. Reverses
44/// [`cmma_to_whitebox_fragment`].
45#[cube]
46pub(crate) fn whitebox_fragment_to_cmma<E: Float>(b: &mut BounceTile<E>) {
47    let stride = comptime!(b.cmma.tile_size.n());
48    b.fragment.store_to(&mut b.smem);
49    sync_cube();
50    cubecl::cmma::load_with_layout(
51        &b.cmma.matrix,
52        &b.smem.to_slice(),
53        stride,
54        cubecl::cmma::MatrixLayout::RowMajor,
55    );
56}
57
58#[cube]
59impl<N: Numeric, Sc: TileScope> Tile<N, Sc, ReadWrite> {
60    /// Copies data from `source` into `self`.
61    ///
62    /// `SS` is the vector size of the shared memory tile involved in the copy
63    /// (whether that's the source on a load, or the destination on a write).
64    /// `L`/`R`/`A` are the matrix-level numeric types needed by the MMA
65    /// readers/writers — they are unused on non-MMA paths.
66    pub fn copy_from<
67        SE: Numeric,
68        SS: Size,
69        L: Numeric,
70        R: Numeric,
71        A: Numeric,
72        SIO: SliceVisibility,
73    >(
74        &mut self,
75        source: &Tile<SE, Sc, SIO>,
76        #[comptime] ident: StageIdent,
77    ) {
78        match (source, self) {
79            // --- Cmma loads ---
80            (Tile::SharedMemory(shared), Tile::Cmma(t)) => {
81                let shared = shared.view::<SS>();
82                cmma_load_from_shared::<SE, SS, N, SIO>(
83                    &shared,
84                    &mut t.matrix,
85                    ident,
86                    t.matrix_layout,
87                );
88            }
89            (Tile::None, Tile::Cmma(t)) => {
90                cmma_load_zeros::<N>(&mut t.matrix);
91            }
92
93            // --- Bounce loads (delegate to inner cmma) ---
94            (Tile::SharedMemory(shared), Tile::Bounce(b)) => {
95                let shared = shared.view::<SS>();
96                cmma_load_from_shared::<SE, SS, N, SIO>(
97                    &shared,
98                    &mut b.cmma.matrix,
99                    ident,
100                    b.cmma.matrix_layout,
101                );
102            }
103            (Tile::None, Tile::Bounce(b)) => {
104                cmma_load_zeros::<N>(&mut b.cmma.matrix);
105            }
106
107            // --- Mma loads ---
108            (Tile::SharedMemory(shared), Tile::Mma(t)) => {
109                let shared = shared.view::<SS>();
110                match &mut t.fragment {
111                    MmaFragment::Lhs(f) => mma_load_lhs_from_shared::<SE, SS, N, R, A, SIO>(
112                        &shared,
113                        f,
114                        t.matrix_layout,
115                        t.config,
116                    ),
117                    MmaFragment::Rhs(f) => mma_load_rhs_from_shared::<SE, SS, N, L, A, SIO>(
118                        &shared,
119                        f,
120                        t.matrix_layout,
121                        t.config,
122                    ),
123                    MmaFragment::Acc(f) => mma_load_acc_from_shared::<SE, SS, N, L, R, SIO>(
124                        &shared,
125                        f,
126                        t.matrix_layout,
127                        t.config,
128                    ),
129                }
130            }
131            (Tile::None, Tile::Mma(t)) => match &mut t.fragment {
132                MmaFragment::Acc(f) => {
133                    mma_load_acc_zeros::<SE, SS, N, L, R>(f, t.matrix_layout, t.config);
134                }
135                MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
136                    panic!("Mma zero-load only supported for Acc role")
137                }
138            },
139
140            // --- Register loads ---
141            (Tile::SharedMemory(shared), Tile::Register(t)) => {
142                let shared = shared.view::<SS>();
143                register_load_from_shared::<SE, SS, N, SIO>(
144                    &shared,
145                    &mut t.data,
146                    t.matrix_layout,
147                    t.config,
148                    ident,
149                );
150            }
151            (Tile::None, Tile::Register(t)) => {
152                register_load_zeros::<N>(&mut t.data, t.config, ident);
153            }
154
155            // --- PlaneVec loads ---
156            (Tile::SharedMemory(shared), Tile::PlaneVec(t)) => {
157                let shared = shared.view::<SS>();
158                planevec_load_from_shared::<SE, SS, N, SIO>(&shared, &mut t.data, t.config, ident);
159            }
160            (Tile::None, Tile::PlaneVec(t)) => {
161                planevec_load_zeros::<N>(&mut t.data, t.config);
162            }
163
164            // --- Interleaved loads ---
165            (Tile::SharedMemory(shared), Tile::Interleaved(t)) => {
166                let shared = shared.view::<SS>();
167                interleaved_load_from_shared::<SE, SS, N, SIO>(
168                    &shared,
169                    &mut t.data,
170                    t.config,
171                    ident,
172                );
173            }
174            (Tile::None, Tile::Interleaved(t)) => {
175                interleaved_load_zeros::<N>(&mut t.data, t.config);
176            }
177
178            // --- Writes: shared memory copies from a compute container ---
179            (Tile::Cmma(t), Tile::SharedMemory(shared)) => {
180                let mut shared = shared.view::<SS>();
181                cmma_write_to_shared::<N, SS, SE>(&mut shared, &t.matrix);
182            }
183            (Tile::Bounce(b), Tile::SharedMemory(shared)) => {
184                let mut shared = shared.view::<SS>();
185                cmma_write_to_shared::<N, SS, SE>(&mut shared, &b.cmma.matrix);
186            }
187            (Tile::Mma(t), Tile::SharedMemory(shared)) => {
188                let mut shared = shared.view::<SS>();
189                match &t.fragment {
190                    MmaFragment::Acc(f) => {
191                        mma_write_to_shared::<N, SS, SE, L, R>(&mut shared, f, t.config);
192                    }
193                    MmaFragment::Lhs(_) | MmaFragment::Rhs(_) => {
194                        panic!("Mma write_to_shared only supported for Acc role")
195                    }
196                }
197            }
198            (Tile::Register(t), Tile::SharedMemory(shared)) => {
199                let mut shared = shared.view::<SS>();
200                register_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
201            }
202            (Tile::PlaneVec(t), Tile::SharedMemory(shared)) => {
203                let mut shared = shared.view::<SS>();
204                planevec_write_to_shared::<SE, N, SS>(&mut shared, &t.data, t.config);
205            }
206            (Tile::Interleaved(t), Tile::SharedMemory(shared)) => {
207                let mut shared = shared.view::<SS>();
208                interleaved_write_to_shared::<N, SS, SE>(&mut shared, &t.data, t.config);
209            }
210
211            _ => panic!("Unsupported storage pair for copy_from"),
212        }
213    }
214}