1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use serde::{Deserialize, Serialize};
4
/// An `Option`-like enum: either holds a value (`Some`) or nothing (`None`).
///
/// NOTE(review): the `CubeType` derive presumably generates the companion
/// `CubeOptionExpand` type used by the host-side helpers in this file — confirm.
#[derive(CubeType, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq, Debug)]
pub enum CubeOption<T: CubeType> {
    /// A present value.
    Some(T),
    /// No value.
    None,
}
10
// Option-style accessors declared under `#[cube]`.
// NOTE(review): presumably so they can be called from kernel code — confirm.
#[cube]
impl<T: CubeType> CubeOption<T> {
    /// Returns `true` if this is `CubeOption::Some`.
    pub fn is_some(&self) -> bool {
        match self {
            CubeOption::Some(_) => true,
            CubeOption::None => false,
        }
    }

    /// Returns the contained value, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics with "Unwrap on a None CubeOption" if `self` is `CubeOption::None`.
    // NOTE(review): under `#[cube]` this panic presumably fires while the kernel is
    // being expanded on the host rather than on the device — confirm.
    pub fn unwrap(self) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => panic!("Unwrap on a None CubeOption"),
        }
    }

    /// Returns `true` if this is `CubeOption::None` (the negation of `is_some`).
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Returns the contained value, or `fallback` if `self` is `CubeOption::None`.
    pub fn unwrap_or(self, fallback: T) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => fallback,
        }
    }
}
38
39impl<T: CubeType> CubeOptionExpand<T> {
40 pub fn is_some(&self) -> bool {
41 match self {
42 CubeOptionExpand::Some(_) => true,
43 CubeOptionExpand::None => false,
44 }
45 }
46
47 pub fn unwrap(self) -> T::ExpandType {
48 match self {
49 Self::Some(val) => val,
50 Self::None => panic!("Unwrap on a None CubeOption"),
51 }
52 }
53
54 pub fn is_none(&self) -> bool {
55 !self.is_some()
56 }
57
58 pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
59 match self {
60 CubeOptionExpand::Some(val) => val,
61 CubeOptionExpand::None => fallback,
62 }
63 }
64}
65
66impl<T: CubeType + Into<T::ExpandType>> From<CubeOption<T>> for CubeOptionExpand<T> {
67 fn from(value: CubeOption<T>) -> Self {
68 match value {
69 CubeOption::Some(val) => CubeOptionExpand::Some(val.into()),
70 CubeOption::None => CubeOptionExpand::None,
71 }
72 }
73}
74
/// Runtime argument for a `CubeOption<T>` kernel parameter: either `T`'s
/// `LaunchArg::RuntimeArg` or nothing.
pub enum CubeOptionArgs<'a, T: LaunchArg, R: Runtime> {
    /// Carries the wrapped `T`'s runtime argument.
    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
    /// No runtime argument is supplied.
    None,
}
81
82impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
83 for CubeOptionArgs<'a, T, R>
84{
85 fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
86 match value {
87 Some(arg) => Self::Some(arg),
88 None => Self::None,
89 }
90 }
91}
92
93impl<T: LaunchArg, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
94 fn register(&self, launcher: &mut KernelLauncher<R>) {
95 match self {
96 CubeOptionArgs::Some(arg) => {
97 arg.register(launcher);
98 }
99 CubeOptionArgs::None => {}
100 }
101 }
102}
103impl<T: LaunchArg> LaunchArg for CubeOption<T> {
104 type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
105 type CompilationArg = CubeOptionCompilationArg<T>;
106
107 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
108 match runtime_arg {
109 CubeOptionArgs::Some(arg) => {
110 CubeOptionCompilationArg::Some(T::compilation_arg::<R>(arg))
111 }
112 CubeOptionArgs::None => CubeOptionCompilationArg::None,
113 }
114 }
115
116 fn expand(
117 arg: &Self::CompilationArg,
118 builder: &mut KernelBuilder,
119 ) -> <Self as CubeType>::ExpandType {
120 match arg {
121 CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
122 CubeOptionCompilationArg::None => CubeOptionExpand::None,
123 }
124 }
125
126 fn expand_output(
127 arg: &Self::CompilationArg,
128 builder: &mut KernelBuilder,
129 ) -> <Self as CubeType>::ExpandType {
130 match arg {
131 CubeOptionCompilationArg::Some(arg) => {
132 CubeOptionExpand::Some(T::expand_output(arg, builder))
133 }
134 CubeOptionCompilationArg::None => CubeOptionExpand::None,
135 }
136 }
137}
138
/// Compilation argument for a `CubeOption<T>` kernel parameter: either `T`'s
/// `LaunchArg::CompilationArg` or nothing.
///
/// Its `Clone`, `PartialEq`, `Eq`, `Hash` and `Debug` impls are written by hand with
/// only a `T: LaunchArg` bound (a derive would add `T: Clone`, `T: Hash`, ... bounds).
pub enum CubeOptionCompilationArg<T: LaunchArg> {
    /// Carries the wrapped `T`'s compilation argument.
    Some(<T as LaunchArg>::CompilationArg),
    /// No compilation argument.
    None,
}
143
144impl<T: LaunchArg> Clone for CubeOptionCompilationArg<T> {
145 fn clone(&self) -> Self {
146 match self {
147 CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
148 CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
149 }
150 }
151}
152
153impl<T: LaunchArg> PartialEq for CubeOptionCompilationArg<T> {
154 fn eq(&self, other: &Self) -> bool {
155 match (self, other) {
156 (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
157 arg_0 == arg_1
158 }
159 (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
160 _ => false,
161 }
162 }
163}
164
165impl<T: LaunchArg> Eq for CubeOptionCompilationArg<T> {}
166
167impl<T: LaunchArg> core::hash::Hash for CubeOptionCompilationArg<T> {
168 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
169 match self {
170 CubeOptionCompilationArg::Some(arg) => {
171 arg.hash(state);
172 }
173 CubeOptionCompilationArg::None => {}
174 };
175 }
176}
177
178impl<T: LaunchArg> core::fmt::Debug for CubeOptionCompilationArg<T> {
179 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
180 match self {
181 CubeOptionCompilationArg::Some(arg) => f
182 .debug_tuple("CubeOptionCompilationArg :: Some")
183 .field(arg)
184 .finish(),
185 CubeOptionCompilationArg::None => write!(f, "CubeOptionCompilationArg :: None"),
186 }
187 }
188}
189
// Marker impl: no items are provided here.
// NOTE(review): presumably `CompilationArg` is satisfied through its supertraits (the
// hand-written `Clone`, `PartialEq`, `Eq`, `Hash` and `Debug` impls for this type) — confirm.
impl<T: LaunchArg> CompilationArg for CubeOptionCompilationArg<T> {}