cubecl_matmul/components/global/read/strategy/
base.rs

1use crate::components::global::memory::GlobalIterator;
2use crate::components::global::{CopyMechanism, GlobalConfig};
3use crate::components::stage::{StridedStage, TilingLayout};
4use crate::components::{InvalidConfigError, MatmulIdent, MatrixPrecision};
5use cubecl_core as cubecl;
6use cubecl_core::prelude::*;
7
8#[cube]
9/// A loading job represents a sequence of loading tasks.
10/// Each task is the smallest unit of loading work:
11/// one unit at one iteration, operating at a specific point within a read view.
12/// The job holds shared information reused across read views and iterations.
13/// By calling execute_task at strategic moments, one can hope to speed up the matmul.
14pub trait LoadingJob<IP: MatrixPrecision, TL: TilingLayout>: CubeType + Copy + Clone {
15    /// Execute the `task_id`th loading task
16    fn execute_task<G: GlobalConfig>(
17        this: &mut Self,
18        #[comptime] task_id: u32,
19        tensor_reader: &GlobalIterator<Line<IP::Global>>,
20        stage_memory: &mut StridedStage<IP::Stage, TL>,
21        #[comptime] config: G,
22    );
23
24    /// Get the number of tasks
25    fn task_count(this: &Self) -> comptime_type!(u32);
26}
27
28#[cube]
29/// A loading job represents a sequence of loading tasks.
30/// Each task is the smallest unit of loading work:
31/// one unit at one iteration, operating at a specific point within a read view.
32/// The job holds shared information reused across read views and iterations.
33/// By calling execute_task at strategic moments, one can hope to speed up the matmul.
34pub trait AsyncLoadingJob<IP: MatrixPrecision, TL: TilingLayout>: CubeType + Copy + Clone {
35    /// Execute the `task_id`th loading task
36    fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
37        this: &mut Self,
38        task_id: u32,
39        tensor_reader: &GlobalIterator<Line<IP::Global>>,
40        stage_memory: &mut StridedStage<IP::Stage, TL>,
41        mechanism: &CM,
42        #[comptime] config: G,
43    );
44
45    /// Get the number of tasks
46    fn task_count(this: &Self) -> comptime_type!(u32);
47}
48
49/// Allows to verify configs are valid for a reader
50pub trait LoadingValidation {
51    /// Verify that configs are valid for a reader, otherwise return an error stating why
52    fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError>;
53}
54
55/// Dummy trait implementation
56pub struct NoLoadingValidation {}
57impl LoadingValidation for NoLoadingValidation {
58    fn check<C: GlobalConfig>(_config: &C, _ident: MatmulIdent) -> Result<(), InvalidConfigError> {
59        Ok(())
60    }
61}
62
/// Controls bounds checking for reader operations.
///
/// This **does not** disable tensor read bounds checks.
/// It only affects the checks for whether the reader loads more data than
/// allowed at each global matmul iteration.
///
/// The default mode is [`ReaderMode::Relaxed`].
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ReaderMode {
    /// Enforces compile-time validation of balanced workloads across units.
    /// Restricts the valid combinations of tile shape, count, and line size.
    Strict,
    /// Inserts runtime checks only when an out-of-bounds access will occur.
    /// May reduce performance if workloads are imbalanced.
    #[default]
    Relaxed,
}