1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use serde::{Deserialize, Serialize};
4
/// An [`Option`]-like value that can be used from `#[cube]` code.
///
/// The companion `CubeOptionExpand<T>` type used in this file is not declared
/// here; it is presumably generated by `#[derive(CubeType)]` — NOTE(review): confirm.
#[derive(CubeType, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq, Debug)]
pub enum CubeOption<T: CubeType> {
    Some(T),
    None,
}
10
#[cube]
impl<T: CubeType> CubeOption<T> {
    /// Returns `true` if this option is `Some`.
    ///
    /// The `match` produces a `bool`, and `.runtime()` is then applied to it.
    /// NOTE(review): `.runtime()` presumably turns the compile-time match result
    /// into a runtime value inside the kernel — confirm against cubecl's
    /// `IntoRuntime` helper.
    pub fn is_some(&self) -> bool {
        match self {
            CubeOption::Some(_) => true,
            CubeOption::None => false,
        }
        .runtime()
    }

    /// Returns the contained value.
    ///
    /// # Panics
    ///
    /// Panics with "Unwrap on a None CubeOption" when called on `None`.
    /// NOTE(review): under `#[cube]` this presumably fires while the kernel is
    /// being expanded rather than on the device — confirm.
    pub fn unwrap(self) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => panic!("Unwrap on a None CubeOption"),
        }
    }

    /// Returns `true` if this option is `None`; implemented as the negation of
    /// `is_some`.
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Returns the contained value, or `fallback` if this option is `None`.
    pub fn unwrap_or(self, fallback: T) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => fallback,
        }
    }
}
39
40impl<T: CubeType> CubeOptionExpand<T> {
41 pub fn is_some(&self) -> bool {
42 match self {
43 CubeOptionExpand::Some(_) => true,
44 CubeOptionExpand::None => false,
45 }
46 }
47
48 pub fn unwrap(self) -> T::ExpandType {
49 match self {
50 Self::Some(val) => val,
51 Self::None => panic!("Unwrap on a None CubeOption"),
52 }
53 }
54
55 pub fn is_none(&self) -> bool {
56 !self.is_some()
57 }
58
59 pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
60 match self {
61 CubeOptionExpand::Some(val) => val,
62 CubeOptionExpand::None => fallback,
63 }
64 }
65}
66
67impl<T: CubeType + Into<T::ExpandType>> From<CubeOption<T>> for CubeOptionExpand<T> {
68 fn from(value: CubeOption<T>) -> Self {
69 match value {
70 CubeOption::Some(val) => CubeOptionExpand::Some(val.into()),
71 CubeOption::None => CubeOptionExpand::None,
72 }
73 }
74}
75
76pub enum CubeOptionArgs<'a, T: LaunchArg, R: Runtime> {
79 Some(<T as LaunchArg>::RuntimeArg<'a, R>),
80 None,
81}
82
83impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
84 for CubeOptionArgs<'a, T, R>
85{
86 fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
87 match value {
88 Some(arg) => Self::Some(arg),
89 None => Self::None,
90 }
91 }
92}
93
94impl<T: LaunchArg, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
95 fn register(&self, launcher: &mut KernelLauncher<R>) {
96 match self {
97 CubeOptionArgs::Some(arg) => {
98 arg.register(launcher);
99 }
100 CubeOptionArgs::None => {}
101 }
102 }
103}
104impl<T: LaunchArg> LaunchArg for CubeOption<T> {
105 type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
106 type CompilationArg = CubeOptionCompilationArg<T>;
107
108 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
109 match runtime_arg {
110 CubeOptionArgs::Some(arg) => CubeOptionCompilationArg::Some(T::compilation_arg(arg)),
111 CubeOptionArgs::None => CubeOptionCompilationArg::None,
112 }
113 }
114
115 fn expand(
116 arg: &Self::CompilationArg,
117 builder: &mut KernelBuilder,
118 ) -> <Self as CubeType>::ExpandType {
119 match arg {
120 CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
121 CubeOptionCompilationArg::None => CubeOptionExpand::None,
122 }
123 }
124
125 fn expand_output(
126 arg: &Self::CompilationArg,
127 builder: &mut KernelBuilder,
128 ) -> <Self as CubeType>::ExpandType {
129 match arg {
130 CubeOptionCompilationArg::Some(arg) => {
131 CubeOptionExpand::Some(T::expand_output(arg, builder))
132 }
133 CubeOptionCompilationArg::None => CubeOptionExpand::None,
134 }
135 }
136}
137
138pub enum CubeOptionCompilationArg<T: LaunchArg> {
139 Some(<T as LaunchArg>::CompilationArg),
140 None,
141}
142
143impl<T: LaunchArg> Clone for CubeOptionCompilationArg<T> {
144 fn clone(&self) -> Self {
145 match self {
146 CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
147 CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
148 }
149 }
150}
151
152impl<T: LaunchArg> PartialEq for CubeOptionCompilationArg<T> {
153 fn eq(&self, other: &Self) -> bool {
154 match (self, other) {
155 (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
156 arg_0 == arg_1
157 }
158 (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
159 _ => false,
160 }
161 }
162}
163
164impl<T: LaunchArg> Eq for CubeOptionCompilationArg<T> {}
165
166impl<T: LaunchArg> core::hash::Hash for CubeOptionCompilationArg<T> {
167 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
168 match self {
169 CubeOptionCompilationArg::Some(arg) => {
170 arg.hash(state);
171 }
172 CubeOptionCompilationArg::None => {}
173 };
174 }
175}
176
177impl<T: LaunchArg> core::fmt::Debug for CubeOptionCompilationArg<T> {
178 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
179 match self {
180 CubeOptionCompilationArg::Some(arg) => f.debug_tuple("Some").field(arg).finish(),
181 CubeOptionCompilationArg::None => write!(f, "None"),
182 }
183 }
184}
185
186impl<T: LaunchArg> CompilationArg for CubeOptionCompilationArg<T> {}