1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use serde::{Deserialize, Serialize};
4
/// An optional value for CubeCL code, mirroring [`Option`].
///
/// It derives `CubeType`, so its expand-time form `CubeOptionExpand<T>` holds
/// either `Some(T::ExpandType)` or `None`. It also implements `LaunchArg`, so an
/// optional kernel argument can be supplied at launch as a `CubeOptionArgs`
/// (which can be built from a plain `Option`).
#[derive(CubeType, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq, Debug)]
pub enum CubeOption<T: CubeType> {
    /// A value is present.
    Some(T),
    /// No value is present.
    None,
}
10
// Methods usable from `#[cube]` code. `CubeOptionExpand` has plain-Rust methods
// with the same names and the same behavior.
#[cube]
impl<T: CubeType> CubeOption<T> {
    /// Returns `true` if this is `CubeOption::Some`.
    pub fn is_some(&self) -> bool {
        match self {
            CubeOption::Some(_) => true,
            CubeOption::None => false,
        }
    }

    /// Moves the contained value out.
    ///
    /// # Panics
    ///
    /// Panics with "Unwrap on a None CubeOption" if this is `CubeOption::None`.
    // NOTE(review): inside `#[cube]` this panic presumably fires while the kernel
    // is being expanded rather than on the device — confirm.
    pub fn unwrap(self) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => panic!("Unwrap on a None CubeOption"),
        }
    }

    /// Returns `true` if this is `CubeOption::None` (the negation of `is_some`).
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Returns the contained value, or `fallback` if this is `CubeOption::None`.
    pub fn unwrap_or(self, fallback: T) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => fallback,
        }
    }
}
38
39impl<T: CubeType> CubeOptionExpand<T> {
40 pub fn is_some(&self) -> bool {
41 match self {
42 CubeOptionExpand::Some(_) => true,
43 CubeOptionExpand::None => false,
44 }
45 }
46
47 pub fn unwrap(self) -> T::ExpandType {
48 match self {
49 Self::Some(val) => val,
50 Self::None => panic!("Unwrap on a None CubeOption"),
51 }
52 }
53
54 pub fn is_none(&self) -> bool {
55 !self.is_some()
56 }
57
58 pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
59 match self {
60 CubeOptionExpand::Some(val) => val,
61 CubeOptionExpand::None => fallback,
62 }
63 }
64}
65
66impl<T: CubeType + Into<T::ExpandType>> From<CubeOption<T>> for CubeOptionExpand<T> {
67 fn from(value: CubeOption<T>) -> Self {
68 match value {
69 CubeOption::Some(val) => CubeOptionExpand::Some(val.into()),
70 CubeOption::None => CubeOptionExpand::None,
71 }
72 }
73}
74
75pub enum CubeOptionArgs<'a, T: LaunchArg, R: Runtime> {
78 Some(<T as LaunchArg>::RuntimeArg<'a, R>),
79 None,
80}
81
82impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
83 for CubeOptionArgs<'a, T, R>
84{
85 fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
86 match value {
87 Some(arg) => Self::Some(arg),
88 None => Self::None,
89 }
90 }
91}
92
93impl<T: LaunchArg, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
94 fn register(&self, launcher: &mut KernelLauncher<R>) {
95 match self {
96 CubeOptionArgs::Some(arg) => {
97 arg.register(launcher);
98 }
99 CubeOptionArgs::None => {}
100 }
101 }
102}
103impl<T: LaunchArg> LaunchArg for CubeOption<T> {
104 type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
105 type CompilationArg = CubeOptionCompilationArg<T>;
106
107 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
108 match runtime_arg {
109 CubeOptionArgs::Some(arg) => CubeOptionCompilationArg::Some(T::compilation_arg(arg)),
110 CubeOptionArgs::None => CubeOptionCompilationArg::None,
111 }
112 }
113
114 fn expand(
115 arg: &Self::CompilationArg,
116 builder: &mut KernelBuilder,
117 ) -> <Self as CubeType>::ExpandType {
118 match arg {
119 CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
120 CubeOptionCompilationArg::None => CubeOptionExpand::None,
121 }
122 }
123
124 fn expand_output(
125 arg: &Self::CompilationArg,
126 builder: &mut KernelBuilder,
127 ) -> <Self as CubeType>::ExpandType {
128 match arg {
129 CubeOptionCompilationArg::Some(arg) => {
130 CubeOptionExpand::Some(T::expand_output(arg, builder))
131 }
132 CubeOptionCompilationArg::None => CubeOptionExpand::None,
133 }
134 }
135}
136
137pub enum CubeOptionCompilationArg<T: LaunchArg> {
138 Some(<T as LaunchArg>::CompilationArg),
139 None,
140}
141
142impl<T: LaunchArg> Clone for CubeOptionCompilationArg<T> {
143 fn clone(&self) -> Self {
144 match self {
145 CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
146 CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
147 }
148 }
149}
150
151impl<T: LaunchArg> PartialEq for CubeOptionCompilationArg<T> {
152 fn eq(&self, other: &Self) -> bool {
153 match (self, other) {
154 (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
155 arg_0 == arg_1
156 }
157 (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
158 _ => false,
159 }
160 }
161}
162
163impl<T: LaunchArg> Eq for CubeOptionCompilationArg<T> {}
164
165impl<T: LaunchArg> core::hash::Hash for CubeOptionCompilationArg<T> {
166 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
167 match self {
168 CubeOptionCompilationArg::Some(arg) => {
169 arg.hash(state);
170 }
171 CubeOptionCompilationArg::None => {}
172 };
173 }
174}
175
176impl<T: LaunchArg> core::fmt::Debug for CubeOptionCompilationArg<T> {
177 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
178 match self {
179 CubeOptionCompilationArg::Some(arg) => f
180 .debug_tuple("CubeOptionCompilationArg :: Some")
181 .field(arg)
182 .finish(),
183 CubeOptionCompilationArg::None => write!(f, "CubeOptionCompilationArg :: None"),
184 }
185 }
186}
187
188impl<T: LaunchArg> CompilationArg for CubeOptionCompilationArg<T> {}