Skip to main content

cubek_std/tile/
scope.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4
5use crate::CubeDimResource;
6
7/// Identifies which compute primitive executes a tile matmul.
8pub trait Scope: Clone + Copy + Send + Sync + 'static {
9    /// Compute resource a single instance of this scope occupies.
10    fn default_resource() -> CubeDimResource;
11
12    /// Comptime tag used at dispatch sites that need to assert a particular scope
13    /// (e.g. variants that only make sense on a plane).
14    const KIND: ScopeKind;
15}
16
/// Value-level tag identifying which `Scope` implementation is in use.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScopeKind {
    /// Tag carried by the `Unit` scope marker.
    Unit,
    /// Tag carried by the `Plane` scope marker.
    Plane,
    /// Tag carried by the `Cube` scope marker.
    Cube,
}
23
/// Scope marker for tile matmuls executed by a single unit.
// Extra derives (Debug, Default, PartialEq, Eq, Hash) are additive: public
// marker types should be printable and comparable, matching `ScopeKind`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Unit;

/// Scope marker for tile matmuls executed by a plane.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Plane;

/// Scope marker for tile matmuls executed at cube scope.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Cube;
30
31impl Scope for Unit {
32    fn default_resource() -> CubeDimResource {
33        CubeDimResource::Units(1)
34    }
35    const KIND: ScopeKind = ScopeKind::Unit;
36}
37impl Scope for Plane {
38    fn default_resource() -> CubeDimResource {
39        CubeDimResource::Planes(1)
40    }
41    const KIND: ScopeKind = ScopeKind::Plane;
42}
43impl Scope for Cube {
44    fn default_resource() -> CubeDimResource {
45        unimplemented!("Cube scope does not have a default cube-dim resource")
46    }
47    const KIND: ScopeKind = ScopeKind::Cube;
48}
49
/// Zero-sized comptime marker used to carry a [Scope] generic through [Tile].
///
/// Its only field is a `PhantomData<Sc>` tagged `#[cube(comptime)]`; the struct
/// exists to bind the `Sc` type parameter, not to hold data.
#[derive(CubeType, Clone, Copy)]
pub struct ScopeMarker<Sc: Scope> {
    // Binds `Sc` without storing a value; marked comptime for the CubeType derive.
    #[cube(comptime)]
    _phantom: PhantomData<Sc>,
}
56
57/// Comptime assertion that a tile-scope generic resolves to `Plane`.
58/// Use at construction sites of plane-only variants (`Tile::Local`, `Tile::Bounce`).
59pub fn assert_plane_scope(kind: ScopeKind) {
60    match kind {
61        ScopeKind::Plane => {}
62        _ => panic!("This Tile variant is only valid in Plane scope"),
63    }
64}
65
66/// Comptime assertion that a tile-scope generic resolves to `Unit`.
67pub fn assert_unit_scope(kind: ScopeKind) {
68    match kind {
69        ScopeKind::Unit => {}
70        _ => panic!("This Tile variant is only valid in Unit scope"),
71    }
72}