cubecl_std/
option.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
/// An `Option`-like enum usable as a CubeCL type (`CubeType`).
///
/// Mirrors `core::option::Option`: `Some(T)` carries a value and `None` carries nothing.
/// It can also be used as a kernel launch argument through the hand-written
/// `LaunchArg` / `LaunchArgExpand` impls in this file, where a `None` argument
/// registers nothing with the launcher.
#[derive(CubeType, Clone, Copy)]
pub enum CubeOption<T: CubeType> {
    Some(T),
    None,
}
9
10#[cube]
11impl<T: CubeType> CubeOption<T> {
12    pub fn is_some(&self) -> bool {
13        match self {
14            CubeOption::Some(_) => true,
15            CubeOption::None => false,
16        }
17    }
18
19    pub fn unwrap(self) -> T {
20        match self {
21            CubeOption::Some(val) => val,
22            CubeOption::None => panic!("Unwrap on a None CubeOption"),
23        }
24    }
25
26    pub fn is_none(&self) -> bool {
27        !self.is_some()
28    }
29
30    pub fn unwrap_or(self, fallback: T) -> T {
31        match self {
32            CubeOption::Some(val) => val,
33            CubeOption::None => fallback,
34        }
35    }
36}
37
38impl<T: CubeType> CubeOptionExpand<T> {
39    pub fn is_some(&self) -> bool {
40        match self {
41            CubeOptionExpand::Some(_) => true,
42            CubeOptionExpand::None => false,
43        }
44    }
45
46    pub fn unwrap(self) -> T::ExpandType {
47        match self {
48            Self::Some(val) => val,
49            Self::None => panic!("Unwrap on a None CubeOption"),
50        }
51    }
52
53    pub fn is_none(&self) -> bool {
54        !self.is_some()
55    }
56
57    pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
58        match self {
59            CubeOptionExpand::Some(val) => val,
60            CubeOptionExpand::None => fallback,
61        }
62    }
63}
64
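// The `CubeLaunch` support below (runtime/compilation argument types plus the `LaunchArg`
// and `LaunchArgExpand` impls) is written by hand because the macro is currently not
// permissive enough to generate it for this type.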
66
67pub enum CubeOptionArgs<'a, T: CubeLaunch, R: Runtime> {
68    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
69    None,
70}
71
72impl<'a, T: CubeLaunch, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
73    for CubeOptionArgs<'a, T, R>
74{
75    fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
76        match value {
77            Some(arg) => Self::Some(arg),
78            None => Self::None,
79        }
80    }
81}
82
83impl<T: CubeLaunch, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
84    fn register(&self, launcher: &mut KernelLauncher<R>) {
85        match self {
86            CubeOptionArgs::Some(arg) => {
87                arg.register(launcher);
88            }
89            CubeOptionArgs::None => {}
90        }
91    }
92}
93impl<T: CubeLaunch> LaunchArg for CubeOption<T> {
94    type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
95
96    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
97        match runtime_arg {
98            CubeOptionArgs::Some(arg) => {
99                CubeOptionCompilationArg::Some(T::compilation_arg::<R>(arg))
100            }
101            CubeOptionArgs::None => CubeOptionCompilationArg::None,
102        }
103    }
104}
105
106#[derive(serde::Serialize, serde::Deserialize)]
107#[serde(bound(serialize = "", deserialize = ""))]
108pub enum CubeOptionCompilationArg<T: CubeLaunch> {
109    Some(<T as LaunchArgExpand>::CompilationArg),
110    None,
111}
112
113impl<T: CubeLaunch> Clone for CubeOptionCompilationArg<T> {
114    fn clone(&self) -> Self {
115        match self {
116            CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
117            CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
118        }
119    }
120}
121
122impl<T: CubeLaunch> PartialEq for CubeOptionCompilationArg<T> {
123    fn eq(&self, other: &Self) -> bool {
124        match (self, other) {
125            (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
126                arg_0 == arg_1
127            }
128            (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
129            _ => false,
130        }
131    }
132}
133
134impl<T: CubeLaunch> Eq for CubeOptionCompilationArg<T> {}
135
136impl<T: CubeLaunch> core::hash::Hash for CubeOptionCompilationArg<T> {
137    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
138        match self {
139            CubeOptionCompilationArg::Some(arg) => {
140                arg.hash(state);
141            }
142            CubeOptionCompilationArg::None => {}
143        };
144    }
145}
146
147impl<T: CubeLaunch> core::fmt::Debug for CubeOptionCompilationArg<T> {
148    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
149        match self {
150            CubeOptionCompilationArg::Some(arg) => f
151                .debug_tuple("CubeOptionCompilationArg :: Some")
152                .field(arg)
153                .finish(),
154            CubeOptionCompilationArg::None => write!(f, "CubeOptionCompilationArg :: None"),
155        }
156    }
157}
158
159impl<T: CubeLaunch> CompilationArg for CubeOptionCompilationArg<T> {}
160
161impl<T: CubeLaunch> LaunchArgExpand for CubeOption<T> {
162    type CompilationArg = CubeOptionCompilationArg<T>;
163
164    fn expand(
165        arg: &Self::CompilationArg,
166        builder: &mut KernelBuilder,
167    ) -> <Self as CubeType>::ExpandType {
168        match arg {
169            CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
170            CubeOptionCompilationArg::None => CubeOptionExpand::None,
171        }
172    }
173
174    fn expand_output(
175        arg: &Self::CompilationArg,
176        builder: &mut KernelBuilder,
177    ) -> <Self as CubeType>::ExpandType {
178        match arg {
179            CubeOptionCompilationArg::Some(arg) => {
180                CubeOptionExpand::Some(T::expand_output(arg, builder))
181            }
182            CubeOptionCompilationArg::None => CubeOptionExpand::None,
183        }
184    }
185}