Skip to main content

cubek_std/tile/
scope.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4
5use crate::CubeDimResource;
6
7/// Identifies which compute primitive executes a tile matmul.
8pub trait TileScope: Clone + Copy + Send + Sync + 'static {
9    /// Compute resource a single instance of this scope occupies.
10    fn default_resource() -> CubeDimResource;
11
12    /// Comptime tag used at dispatch sites that need to assert a particular scope
13    /// (e.g. variants that only make sense on a plane).
14    const KIND: ScopeKind;
15}
16
/// Comptime tag enumerating the tile scopes known to this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScopeKind {
    /// Tag of the unit scope.
    Unit,
    /// Tag of the plane scope.
    Plane,
    /// Tag of the cube scope.
    Cube,
}
23
/// Marker type selecting the unit tile scope.
#[derive(Copy, Clone)]
pub struct Unit;

/// Marker type selecting the plane tile scope.
#[derive(Copy, Clone)]
pub struct Plane;

/// Marker type selecting the cube tile scope.
#[derive(Copy, Clone)]
pub struct Cube;
30
31impl TileScope for Unit {
32    fn default_resource() -> CubeDimResource {
33        CubeDimResource::Units(1)
34    }
35    const KIND: ScopeKind = ScopeKind::Unit;
36}
37impl TileScope for Plane {
38    fn default_resource() -> CubeDimResource {
39        CubeDimResource::Planes(1)
40    }
41    const KIND: ScopeKind = ScopeKind::Plane;
42}
43impl TileScope for Cube {
44    fn default_resource() -> CubeDimResource {
45        unimplemented!("Cube scope does not have a default cube-dim resource")
46    }
47    const KIND: ScopeKind = ScopeKind::Cube;
48}
49
50/// Zero-sized comptime marker used to carry a [Scope] generic through [Tile].
51#[derive(CubeType, Clone, Copy)]
52pub struct ScopeMarker<Sc: TileScope> {
53    #[cube(comptime)]
54    _phantom: PhantomData<Sc>,
55}
56
57/// Comptime assertion that a tile-scope generic resolves to `Plane`.
58pub fn assert_plane_scope(kind: ScopeKind) {
59    match kind {
60        ScopeKind::Plane => {}
61        _ => panic!("This Tile variant is only valid in Plane scope"),
62    }
63}
64
65/// Comptime assertion that a tile-scope generic resolves to `Unit`.
66pub fn assert_unit_scope(kind: ScopeKind) {
67    match kind {
68        ScopeKind::Unit => {}
69        _ => panic!("This Tile variant is only valid in Unit scope"),
70    }
71}