1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4
5use crate::CubeDimResource;
6
7pub trait TileScope: Clone + Copy + Send + Sync + 'static {
9 fn default_resource() -> CubeDimResource;
11
12 const KIND: ScopeKind;
15}
16
/// Value-level tag for the three scopes a tile may run at.
///
/// Each [`TileScope`] implementation in this module exposes one of these
/// through its `KIND` constant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScopeKind {
    /// Tag of the [`Unit`] scope marker.
    Unit,
    /// Tag of the [`Plane`] scope marker.
    Plane,
    /// Tag of the [`Cube`] scope marker.
    Cube,
}
23
/// Scope marker for tile work performed by a single unit; see [`TileScope`].
#[derive(Copy, Clone)]
pub struct Unit;

/// Scope marker for tile work performed by a plane; see [`TileScope`].
#[derive(Copy, Clone)]
pub struct Plane;

/// Scope marker for tile work performed by a whole cube; see [`TileScope`].
#[derive(Copy, Clone)]
pub struct Cube;
30
31impl TileScope for Unit {
32 fn default_resource() -> CubeDimResource {
33 CubeDimResource::Units(1)
34 }
35 const KIND: ScopeKind = ScopeKind::Unit;
36}
37impl TileScope for Plane {
38 fn default_resource() -> CubeDimResource {
39 CubeDimResource::Planes(1)
40 }
41 const KIND: ScopeKind = ScopeKind::Plane;
42}
43impl TileScope for Cube {
44 fn default_resource() -> CubeDimResource {
45 unimplemented!("Cube scope does not have a default cube-dim resource")
46 }
47 const KIND: ScopeKind = ScopeKind::Cube;
48}
49
50#[derive(CubeType, Clone, Copy)]
52pub struct ScopeMarker<Sc: TileScope> {
53 #[cube(comptime)]
54 _phantom: PhantomData<Sc>,
55}
56
57pub fn assert_plane_scope(kind: ScopeKind) {
59 match kind {
60 ScopeKind::Plane => {}
61 _ => panic!("This Tile variant is only valid in Plane scope"),
62 }
63}
64
65pub fn assert_unit_scope(kind: ScopeKind) {
67 match kind {
68 ScopeKind::Unit => {}
69 _ => panic!("This Tile variant is only valid in Unit scope"),
70 }
71}