1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4
5use crate::CubeDimResource;
6
7pub trait Scope: Clone + Copy + Send + Sync + 'static {
9 fn default_resource() -> CubeDimResource;
11
12 const KIND: ScopeKind;
15}
16
/// Value-level tag for a [`Scope`], for code that must branch on the scope as a value
/// instead of through the type system.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScopeKind {
    /// Unit-level scope.
    Unit,
    /// Plane-level scope.
    Plane,
    /// Cube-level scope.
    Cube,
}
23
// Zero-sized scope markers. They derive the same common traits as `ScopeKind`, so they can be
// printed in diagnostics and compared or hashed like any other small value type.

/// Marker type for the unit-level [`Scope`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Unit;

/// Marker type for the plane-level [`Scope`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Plane;

/// Marker type for the cube-level [`Scope`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Cube;
30
31impl Scope for Unit {
32 fn default_resource() -> CubeDimResource {
33 CubeDimResource::Units(1)
34 }
35 const KIND: ScopeKind = ScopeKind::Unit;
36}
37impl Scope for Plane {
38 fn default_resource() -> CubeDimResource {
39 CubeDimResource::Planes(1)
40 }
41 const KIND: ScopeKind = ScopeKind::Plane;
42}
43impl Scope for Cube {
44 fn default_resource() -> CubeDimResource {
45 unimplemented!("Cube scope does not have a default cube-dim resource")
46 }
47 const KIND: ScopeKind = ScopeKind::Cube;
48}
49
50#[derive(CubeType, Clone, Copy)]
52pub struct ScopeMarker<Sc: Scope> {
53 #[cube(comptime)]
54 _phantom: PhantomData<Sc>,
55}
56
57pub fn assert_plane_scope(kind: ScopeKind) {
60 match kind {
61 ScopeKind::Plane => {}
62 _ => panic!("This Tile variant is only valid in Plane scope"),
63 }
64}
65
66pub fn assert_unit_scope(kind: ScopeKind) {
68 match kind {
69 ScopeKind::Unit => {}
70 _ => panic!("This Tile variant is only valid in Unit scope"),
71 }
72}