cubecl_matmul/components/global/read/strategy/base.rs
1use crate::components::global::memory::GlobalIterator;
2use crate::components::global::{CopyMechanism, GlobalConfig};
3use crate::components::stage::{StridedStage, TilingLayout};
4use crate::components::{InvalidConfigError, MatmulIdent, MatrixPrecision};
5use cubecl_core as cubecl;
6use cubecl_core::prelude::*;
7
8#[cube]
9/// A loading job represents a sequence of loading tasks.
10/// Each task is the smallest unit of loading work:
11/// one unit at one iteration, operating at a specific point within a read view.
12/// The job holds shared information reused across read views and iterations.
13/// By calling execute_task at strategic moments, one can hope to speed up the matmul.
14pub trait LoadingJob<IP: MatrixPrecision, TL: TilingLayout>: CubeType + Copy + Clone {
15 /// Execute the `task_id`th loading task
16 fn execute_task<G: GlobalConfig>(
17 this: &mut Self,
18 #[comptime] task_id: u32,
19 tensor_reader: &GlobalIterator<Line<IP::Global>>,
20 stage_memory: &mut StridedStage<IP::Stage, TL>,
21 #[comptime] config: G,
22 );
23
24 /// Get the number of tasks
25 fn task_count(this: &Self) -> comptime_type!(u32);
26}
27
28#[cube]
29/// A loading job represents a sequence of loading tasks.
30/// Each task is the smallest unit of loading work:
31/// one unit at one iteration, operating at a specific point within a read view.
32/// The job holds shared information reused across read views and iterations.
33/// By calling execute_task at strategic moments, one can hope to speed up the matmul.
34pub trait AsyncLoadingJob<IP: MatrixPrecision, TL: TilingLayout>: CubeType + Copy + Clone {
35 /// Execute the `task_id`th loading task
36 fn execute_task<CM: CopyMechanism, G: GlobalConfig>(
37 this: &mut Self,
38 task_id: u32,
39 tensor_reader: &GlobalIterator<Line<IP::Global>>,
40 stage_memory: &mut StridedStage<IP::Stage, TL>,
41 mechanism: &CM,
42 #[comptime] config: G,
43 );
44
45 /// Get the number of tasks
46 fn task_count(this: &Self) -> comptime_type!(u32);
47}
48
49/// Allows to verify configs are valid for a reader
50pub trait LoadingValidation {
51 /// Verify that configs are valid for a reader, otherwise return an error stating why
52 fn check<C: GlobalConfig>(config: &C, ident: MatmulIdent) -> Result<(), InvalidConfigError>;
53}
54
55/// Dummy trait implementation
56pub struct NoLoadingValidation {}
57impl LoadingValidation for NoLoadingValidation {
58 fn check<C: GlobalConfig>(_config: &C, _ident: MatmulIdent) -> Result<(), InvalidConfigError> {
59 Ok(())
60 }
61}
62
/// Controls bounds checking for reader operations.
///
/// This **does not** disable tensor read bounds checks.
/// It only affects checks for whether the reader loads more data than allowed
/// at each global matmul iteration.
#[derive(Default, Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum ReaderMode {
    /// Enforces compile-time validation of balanced workloads across units.
    /// Restricts valid combinations of tile shape, count, and line size.
    Strict,
    /// Inserts runtime checks only where an out-of-bounds load could occur.
    /// May reduce performance if workloads are imbalanced.
    ///
    /// This is the default mode.
    #[default]
    Relaxed,
}