1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
/// An optional value usable in `#[cube]` code, analogous to [`Option`].
///
/// `#[derive(CubeType)]` supplies the expand-time counterpart `CubeOptionExpand`, whose
/// `Some` variant holds `T::ExpandType`. When used as a launch argument, the host-side
/// value is a `CubeOptionArgs`, and a `None` registers no argument with the launcher.
#[derive(CubeType, Clone, Copy)]
pub enum CubeOption<T: CubeType> {
    /// A present value.
    Some(T),
    /// No value.
    None,
}
9
// `Option`-like helpers callable from `#[cube]` code. The `#[cube]` macro rewrites these
// bodies, so they are kept in the simple `match` form shown here.
#[cube]
impl<T: CubeType> CubeOption<T> {
    /// Returns `true` if the option holds a value.
    pub fn is_some(&self) -> bool {
        match self {
            CubeOption::Some(_) => true,
            CubeOption::None => false,
        }
    }

    /// Consumes the option and returns the contained value.
    ///
    /// # Panics
    ///
    /// Panics with "Unwrap on a None CubeOption" if the option is `None`.
    /// NOTE(review): whether this fires during kernel expansion or at runtime depends on how
    /// `#[cube]` lowers `panic!` — confirm.
    pub fn unwrap(self) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => panic!("Unwrap on a None CubeOption"),
        }
    }

    /// Returns `true` if the option holds no value; the negation of `is_some`.
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Consumes the option and returns the contained value, or `fallback` if it is `None`.
    ///
    /// `fallback` is an ordinary argument, so the caller evaluates it regardless of which
    /// variant is present (like `Option::unwrap_or`).
    pub fn unwrap_or(self, fallback: T) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => fallback,
        }
    }
}
37
38impl<T: CubeType> CubeOptionExpand<T> {
39 pub fn is_some(&self) -> bool {
40 match self {
41 CubeOptionExpand::Some(_) => true,
42 CubeOptionExpand::None => false,
43 }
44 }
45
46 pub fn unwrap(self) -> T::ExpandType {
47 match self {
48 Self::Some(val) => val,
49 Self::None => panic!("Unwrap on a None CubeOption"),
50 }
51 }
52
53 pub fn is_none(&self) -> bool {
54 !self.is_some()
55 }
56
57 pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
58 match self {
59 CubeOptionExpand::Some(val) => val,
60 CubeOptionExpand::None => fallback,
61 }
62 }
63}
64
/// Runtime launch argument for a [`CubeOption`] kernel parameter.
///
/// `Some` wraps `T`'s own runtime argument and is registered with the launcher.
/// `None` registers nothing.
pub enum CubeOptionArgs<'a, T: CubeLaunch, R: Runtime> {
    /// The runtime argument of the wrapped `T`.
    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
    /// No argument supplied.
    None,
}
71
72impl<'a, T: CubeLaunch, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
73 for CubeOptionArgs<'a, T, R>
74{
75 fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
76 match value {
77 Some(arg) => Self::Some(arg),
78 None => Self::None,
79 }
80 }
81}
82
83impl<T: CubeLaunch, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
84 fn register(&self, launcher: &mut KernelLauncher<R>) {
85 match self {
86 CubeOptionArgs::Some(arg) => {
87 arg.register(launcher);
88 }
89 CubeOptionArgs::None => {}
90 }
91 }
92}
93impl<T: CubeLaunch> LaunchArg for CubeOption<T> {
94 type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
95
96 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
97 match runtime_arg {
98 CubeOptionArgs::Some(arg) => {
99 CubeOptionCompilationArg::Some(T::compilation_arg::<R>(arg))
100 }
101 CubeOptionArgs::None => CubeOptionCompilationArg::None,
102 }
103 }
104}
105
/// Compilation argument for a [`CubeOption`] kernel parameter.
///
/// `Some` carries `T`'s compilation argument; `None` records that the option was absent.
/// The hand-written equality in this file distinguishes the two variants.
#[derive(serde::Serialize, serde::Deserialize)]
// An explicit (empty) `bound` replaces the where-clauses serde would otherwise infer from
// the field types. NOTE(review): presumably the `CompilationArg` bounds on
// `<T as LaunchArgExpand>::CompilationArg` already guarantee (de)serializability — confirm.
#[serde(bound(serialize = "", deserialize = ""))]
pub enum CubeOptionCompilationArg<T: CubeLaunch> {
    /// The compilation argument of the wrapped `T`.
    Some(<T as LaunchArgExpand>::CompilationArg),
    /// The option was absent.
    None,
}
112
113impl<T: CubeLaunch> Clone for CubeOptionCompilationArg<T> {
114 fn clone(&self) -> Self {
115 match self {
116 CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
117 CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
118 }
119 }
120}
121
122impl<T: CubeLaunch> PartialEq for CubeOptionCompilationArg<T> {
123 fn eq(&self, other: &Self) -> bool {
124 match (self, other) {
125 (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
126 arg_0 == arg_1
127 }
128 (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
129 _ => false,
130 }
131 }
132}
133
// Marker impl. `eq` compares the variant and, for `Some`, the inner compilation argument.
// That is a full equivalence only if the inner argument's equality is; NOTE(review):
// presumably guaranteed by the `CompilationArg` bounds — confirm.
impl<T: CubeLaunch> Eq for CubeOptionCompilationArg<T> {}
135
136impl<T: CubeLaunch> core::hash::Hash for CubeOptionCompilationArg<T> {
137 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
138 match self {
139 CubeOptionCompilationArg::Some(arg) => {
140 arg.hash(state);
141 }
142 CubeOptionCompilationArg::None => {}
143 };
144 }
145}
146
147impl<T: CubeLaunch> core::fmt::Debug for CubeOptionCompilationArg<T> {
148 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
149 match self {
150 CubeOptionCompilationArg::Some(arg) => f
151 .debug_tuple("CubeOptionCompilationArg :: Some")
152 .field(arg)
153 .finish(),
154 CubeOptionCompilationArg::None => write!(f, "CubeOptionCompilationArg :: None"),
155 }
156 }
157}
158
// Marks `CubeOptionCompilationArg` as a compilation argument. The impl body is empty, so any
// items of `CompilationArg` use their defaults. NOTE(review): the hand-written
// Clone/PartialEq/Eq/Hash/Debug and serde impls in this file presumably satisfy its
// supertraits — confirm.
impl<T: CubeLaunch> CompilationArg for CubeOptionCompilationArg<T> {}
160
161impl<T: CubeLaunch> LaunchArgExpand for CubeOption<T> {
162 type CompilationArg = CubeOptionCompilationArg<T>;
163
164 fn expand(
165 arg: &Self::CompilationArg,
166 builder: &mut KernelBuilder,
167 ) -> <Self as CubeType>::ExpandType {
168 match arg {
169 CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
170 CubeOptionCompilationArg::None => CubeOptionExpand::None,
171 }
172 }
173
174 fn expand_output(
175 arg: &Self::CompilationArg,
176 builder: &mut KernelBuilder,
177 ) -> <Self as CubeType>::ExpandType {
178 match arg {
179 CubeOptionCompilationArg::Some(arg) => {
180 CubeOptionExpand::Some(T::expand_output(arg, builder))
181 }
182 CubeOptionCompilationArg::None => CubeOptionExpand::None,
183 }
184 }
185}