1use crate::{self as cubecl, ExpandType};
2use cubecl::prelude::*;
3
/// Expand type of `Option<T>` (the `ExpandType` chosen by `impl CubeType for Option<T>`).
///
/// Mirrors `Option`, except that `Some` stores `T`'s expand type
/// (`<T as CubeType>::ExpandType`) instead of a `T` value.
#[doc(hidden)]
#[derive(Default)]
pub enum OptionExpand<T: CubeType> {
    /// A present value, held as `T`'s expand type.
    Some(<T as CubeType>::ExpandType),
    /// No value; this is the variant returned by `Default::default()`.
    #[default]
    None,
}
11
12impl<T: CubeType> CubeType for Option<T> {
13 type ExpandType = OptionExpand<T>;
14}
15
16impl<T: CubeType> IntoMut for OptionExpand<T> {
17 fn into_mut(self, scope: &mut cubecl::prelude::Scope) -> Self {
18 match self {
19 OptionExpand::Some(arg_0) => OptionExpand::Some(IntoMut::into_mut(arg_0, scope)),
20 OptionExpand::None => OptionExpand::None,
21 }
22 }
23}
24
25impl<T: CubeType> CubeDebug for Option<T> {}
26impl<T: CubeType> CubeDebug for OptionExpand<T> {}
27
28impl<T: CubeType> Clone for OptionExpand<T> {
29 fn clone(&self) -> Self {
30 match self {
31 OptionExpand::Some(arg_0) => OptionExpand::Some(arg_0.clone()),
32 OptionExpand::None => OptionExpand::None,
33 }
34 }
35}
36
37#[allow(non_snake_case)]
39pub trait CubeOption<T: CubeType> {
40 fn new_Some(_0: T) -> Option<T> {
42 Option::Some(_0)
43 }
44 fn new_None() -> Option<T> {
46 Option::None
47 }
48 #[doc(hidden)]
49 fn __expand_Some(_scope: &mut Scope, _0: ExpandType<T>) -> OptionExpand<T> {
50 OptionExpand::Some(_0)
51 }
52 #[doc(hidden)]
53 fn __expand_new_Some(_scope: &mut Scope, _0: ExpandType<T>) -> OptionExpand<T> {
54 OptionExpand::Some(_0)
55 }
56 #[doc(hidden)]
57 fn __expand_new_None(_scope: &mut Scope) -> OptionExpand<T> {
58 OptionExpand::None
59 }
60}
61
62impl<T: CubeType> CubeOption<T> for Option<T> {}
63
64impl<T: CubeType> OptionExpand<T> {
65 pub fn is_some(&self) -> bool {
66 match self {
67 OptionExpand::Some(_) => true,
68 OptionExpand::None => false,
69 }
70 }
71
72 pub fn unwrap(self) -> T::ExpandType {
73 match self {
74 Self::Some(val) => val,
75 Self::None => panic!("Unwrap on a None CubeOption"),
76 }
77 }
78
79 pub fn is_none(&self) -> bool {
80 !self.is_some()
81 }
82
83 pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
84 match self {
85 OptionExpand::Some(val) => val,
86 OptionExpand::None => fallback,
87 }
88 }
89}
90
/// Runtime argument for an `Option<T>` launch argument
/// (`<Option<T> as LaunchArg>::RuntimeArg`).
///
/// `Some` carries `T`'s runtime argument; `None` carries nothing.
pub enum OptionArgs<'a, T: LaunchArg, R: Runtime> {
    /// `T`'s runtime argument.
    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
    /// No argument supplied.
    None,
}
95
96impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
97 for OptionArgs<'a, T, R>
98{
99 fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
100 match value {
101 Some(arg) => Self::Some(arg),
102 None => Self::None,
103 }
104 }
105}
106
107impl<T: LaunchArg, R: Runtime> ArgSettings<R> for OptionArgs<'_, T, R> {
108 fn register(&self, launcher: &mut KernelLauncher<R>) {
109 match self {
110 OptionArgs::Some(arg) => {
111 arg.register(launcher);
112 }
113 OptionArgs::None => {}
114 }
115 }
116}
117impl<T: LaunchArg> LaunchArg for Option<T> {
118 type RuntimeArg<'a, R: Runtime> = OptionArgs<'a, T, R>;
119 type CompilationArg = OptionCompilationArg<T>;
120
121 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
122 match runtime_arg {
123 OptionArgs::Some(arg) => OptionCompilationArg::Some(T::compilation_arg(arg)),
124 OptionArgs::None => OptionCompilationArg::None,
125 }
126 }
127
128 fn expand(
129 arg: &Self::CompilationArg,
130 builder: &mut KernelBuilder,
131 ) -> <Self as CubeType>::ExpandType {
132 match arg {
133 OptionCompilationArg::Some(arg) => OptionExpand::Some(T::expand(arg, builder)),
134 OptionCompilationArg::None => OptionExpand::None,
135 }
136 }
137
138 fn expand_output(
139 arg: &Self::CompilationArg,
140 builder: &mut KernelBuilder,
141 ) -> <Self as CubeType>::ExpandType {
142 match arg {
143 OptionCompilationArg::Some(arg) => OptionExpand::Some(T::expand_output(arg, builder)),
144 OptionCompilationArg::None => OptionExpand::None,
145 }
146 }
147}
148
/// Compilation argument of an `Option<T>` launch argument
/// (`<Option<T> as LaunchArg>::CompilationArg`).
///
/// Records whether the option was `Some` and, if so, `T`'s compilation argument.
/// The trait impls for this type are hand-written, presumably so they only need
/// bounds on `T::CompilationArg` rather than on `T` itself.
pub enum OptionCompilationArg<T: LaunchArg> {
    /// `T`'s compilation argument.
    Some(<T as LaunchArg>::CompilationArg),
    /// The option was `None`.
    None,
}
153
154impl<T: LaunchArg> Clone for OptionCompilationArg<T> {
155 fn clone(&self) -> Self {
156 match self {
157 OptionCompilationArg::Some(arg) => OptionCompilationArg::Some(arg.clone()),
158 OptionCompilationArg::None => OptionCompilationArg::None,
159 }
160 }
161}
162
163impl<T: LaunchArg> PartialEq for OptionCompilationArg<T> {
164 fn eq(&self, other: &Self) -> bool {
165 match (self, other) {
166 (OptionCompilationArg::Some(arg_0), OptionCompilationArg::Some(arg_1)) => {
167 arg_0 == arg_1
168 }
169 (OptionCompilationArg::None, OptionCompilationArg::None) => true,
170 _ => false,
171 }
172 }
173}
174
175impl<T: LaunchArg> Eq for OptionCompilationArg<T> {}
176
177impl<T: LaunchArg> core::hash::Hash for OptionCompilationArg<T> {
178 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
179 match self {
180 OptionCompilationArg::Some(arg) => {
181 arg.hash(state);
182 }
183 OptionCompilationArg::None => {}
184 };
185 }
186}
187
188impl<T: LaunchArg> core::fmt::Debug for OptionCompilationArg<T> {
189 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
190 match self {
191 OptionCompilationArg::Some(arg) => f.debug_tuple("Some").field(arg).finish(),
192 OptionCompilationArg::None => write!(f, "None"),
193 }
194 }
195}
196
197impl<T: LaunchArg> CompilationArg for OptionCompilationArg<T> {}
198
199mod impls {
200 use core::ops::{Deref, DerefMut};
201
202 use super::*;
203 use OptionExpand::{None, Some};
204
205 #[doc(hidden)]
206 impl<T: CubeType> OptionExpand<T> {
207 pub fn __expand_is_some_method(&self, _scope: &mut Scope) -> bool {
208 matches!(*self, Some(_))
209 }
210
211 pub fn __expand_is_some_and_method(
212 self,
213 scope: &mut Scope,
214 f: impl FnOnce(&mut Scope, T::ExpandType) -> bool,
215 ) -> bool {
216 match self {
217 None => false,
218 Some(x) => f(scope, x),
219 }
220 }
221
222 pub fn __expand_is_none_method(&self, scope: &mut Scope) -> bool {
223 !self.__expand_is_some_method(scope)
224 }
225
226 pub fn __expand_is_none_or_method(
227 self,
228 scope: &mut Scope,
229 f: impl FnOnce(&mut Scope, T::ExpandType) -> bool,
230 ) -> bool {
231 match self {
232 None => true,
233 Some(x) => f(scope, x),
234 }
235 }
236
237 pub fn __expand_as_ref_method(self, _scope: &mut Scope) -> Self {
238 self
239 }
240
241 pub fn __expand_as_mut_method(self, _scope: &mut Scope) -> Self {
242 self
243 }
244
245 fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
246 match self {
247 Some(_) => 1,
248 None => 0,
249 }
250 }
251
252 #[allow(clippy::unnecessary_literal_unwrap)]
253 pub fn __expand_expect_method(self, _scope: &mut Scope, msg: &str) -> T::ExpandType {
254 match self {
255 Some(val) => val,
256 None => Option::None.expect(msg),
257 }
258 }
259
260 #[allow(clippy::unnecessary_literal_unwrap)]
261 pub fn __expand_unwrap_method(self, _scope: &mut Scope) -> T::ExpandType {
262 match self {
263 Some(val) => val,
264 None => Option::None.unwrap(),
265 }
266 }
267
268 pub fn __expand_unwrap_or_method(
269 self,
270 _scope: &mut Scope,
271 default: T::ExpandType,
272 ) -> T::ExpandType {
273 match self {
274 Some(x) => x,
275 None => default,
276 }
277 }
278
279 pub fn __expand_unwrap_or_else_method<F>(self, scope: &mut Scope, f: F) -> T::ExpandType
280 where
281 F: FnOnce(&mut Scope) -> T::ExpandType,
282 {
283 match self {
284 Some(x) => x,
285 None => f(scope),
286 }
287 }
288
289 pub fn __expand_unwrap_or_default_method(self, scope: &mut Scope) -> T::ExpandType
290 where
291 T: Default + IntoRuntime,
292 {
293 match self {
294 Some(x) => x,
295 None => T::default().__expand_runtime_method(scope),
296 }
297 }
298
299 pub fn __expand_map_method<U, F>(self, scope: &mut Scope, f: F) -> OptionExpand<U>
300 where
301 U: CubeType,
302 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
303 {
304 match self {
305 Some(x) => Some(f(scope, x)),
306 None => None,
307 }
308 }
309
310 pub fn __expand_inspect_method<F>(self, scope: &mut Scope, f: F) -> Self
311 where
312 F: FnOnce(&mut Scope, T::ExpandType),
313 {
314 if let Some(x) = self.clone() {
315 f(scope, x);
316 }
317
318 self
319 }
320
321 pub fn __expand_map_or_method<U, F>(
322 self,
323 scope: &mut Scope,
324 default: U::ExpandType,
325 f: F,
326 ) -> U::ExpandType
327 where
328 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
329 U: CubeType,
330 {
331 match self {
332 Some(t) => f(scope, t),
333 None => default,
334 }
335 }
336
337 pub fn __expand_map_or_else_method<U, D, F>(
338 self,
339 scope: &mut Scope,
340 default: D,
341 f: F,
342 ) -> U::ExpandType
343 where
344 U: CubeType,
345 D: FnOnce(&mut Scope) -> U::ExpandType,
346 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
347 {
348 match self {
349 Some(t) => f(scope, t),
350 None => default(scope),
351 }
352 }
353
354 pub fn __expand_map_or_default_method<U, F>(self, scope: &mut Scope, f: F) -> U::ExpandType
355 where
356 U: CubeType + Default + Into<U::ExpandType>,
357 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
358 {
359 match self {
360 Some(t) => f(scope, t),
361 None => U::default().into(),
362 }
363 }
364
365 pub fn __expand_as_deref_method(self, scope: &mut Scope) -> OptionExpand<T::Target>
366 where
367 T: Deref<Target: CubeType + Sized>,
368 T::ExpandType: Deref<Target = <T::Target as CubeType>::ExpandType>,
369 {
370 self.__expand_map_method(scope, |_, it| (*it).clone())
371 }
372
373 pub fn __expand_as_deref_mut_method(self, scope: &mut Scope) -> OptionExpand<T::Target>
374 where
375 T: DerefMut<Target: CubeType + Sized>,
376 T::ExpandType: Deref<Target = <T::Target as CubeType>::ExpandType>,
377 {
378 self.__expand_map_method(scope, |_, it| (*it).clone())
379 }
380
381 pub fn __expand_and_method<U>(
382 self,
383 _scope: &mut Scope,
384 optb: OptionExpand<U>,
385 ) -> OptionExpand<U>
386 where
387 U: CubeType,
388 {
389 match self {
390 Some(_) => optb,
391 None => None,
392 }
393 }
394
395 pub fn __expand_and_then_method<U, F>(self, scope: &mut Scope, f: F) -> OptionExpand<U>
396 where
397 U: CubeType,
398 F: FnOnce(&mut Scope, T::ExpandType) -> OptionExpand<U>,
399 {
400 match self {
401 Some(x) => f(scope, x),
402 None => None,
403 }
404 }
405
406 pub fn __expand_filter_method<P>(self, scope: &mut Scope, predicate: P) -> Self
407 where
408 P: FnOnce(&mut Scope, T::ExpandType) -> bool,
409 {
410 if let Some(x) = self
411 && predicate(scope, x.clone())
412 {
413 Some(x)
414 } else {
415 None
416 }
417 }
418
419 pub fn __expand_or_method(
420 self,
421 _scope: &mut Scope,
422 optb: OptionExpand<T>,
423 ) -> OptionExpand<T> {
424 match self {
425 x @ Some(_) => x,
426 None => optb,
427 }
428 }
429
430 pub fn __expand_or_else_method<F>(self, scope: &mut Scope, f: F) -> OptionExpand<T>
431 where
432 F: FnOnce(&mut Scope) -> OptionExpand<T>,
433 {
434 match self {
435 x @ Some(_) => x,
436 None => f(scope),
437 }
438 }
439
440 pub fn __expand_xor_method(
441 self,
442 _scope: &mut Scope,
443 optb: OptionExpand<T>,
444 ) -> OptionExpand<T> {
445 match (self, optb) {
446 (a @ Some(_), None) => a,
447 (None, b @ Some(_)) => b,
448 _ => None,
449 }
450 }
451
452 pub fn __expand_take_method(&mut self, _scope: &mut Scope) -> OptionExpand<T> {
455 core::mem::take(self)
456 }
457
458 pub fn __expand_take_if_method<P>(
459 &mut self,
460 scope: &mut Scope,
461 predicate: P,
462 ) -> OptionExpand<T>
463 where
464 P: FnOnce(&mut Scope, T::ExpandType) -> bool,
465 {
466 match self {
467 Some(value) if predicate(scope, value.clone()) => self.__expand_take_method(scope),
468 _ => None,
469 }
470 }
471
472 pub fn __expand_replace_method(
473 &mut self,
474 _scope: &mut Scope,
475 value: T::ExpandType,
476 ) -> OptionExpand<T> {
477 core::mem::replace(self, Some(value))
478 }
479
480 pub fn __expand_zip_method<U>(
481 self,
482 _scope: &mut Scope,
483 other: OptionExpand<U>,
484 ) -> OptionExpand<(T, U)>
485 where
486 U: CubeType,
487 {
488 match (self, other) {
489 (Some(a), Some(b)) => Some((a, b)),
490 _ => None,
491 }
492 }
493
494 pub fn __expand_zip_with_method<U, F, R>(
495 self,
496 scope: &mut Scope,
497 other: OptionExpand<U>,
498 f: F,
499 ) -> OptionExpand<R>
500 where
501 F: FnOnce(&mut Scope, T::ExpandType, U::ExpandType) -> R::ExpandType,
502 R: CubeType,
503 U: CubeType,
504 {
505 match (self, other) {
506 (Some(a), Some(b)) => Some(f(scope, a, b)),
507 _ => None,
508 }
509 }
510
511 pub fn __expand_reduce_method<U, R, F>(
512 self,
513 scope: &mut Scope,
514 other: OptionExpand<U>,
515 f: F,
516 ) -> OptionExpand<R>
517 where
518 U: CubeType,
519 R: CubeType,
520 T::ExpandType: Into<R::ExpandType>,
521 U::ExpandType: Into<R::ExpandType>,
522 F: FnOnce(&mut Scope, T::ExpandType, U::ExpandType) -> R::ExpandType,
523 {
524 match (self, other) {
525 (Some(a), Some(b)) => Some(f(scope, a, b)),
526 (Some(a), _) => Some(a.into()),
527 (_, Some(b)) => Some(b.into()),
528 _ => None,
529 }
530 }
531 }
532
533 #[doc(hidden)]
534 impl<T: CubeType, U: CubeType> OptionExpand<(T, U)> {
535 pub fn __expand_unzip_method(
536 self,
537 _scope: &mut Scope,
538 ) -> (OptionExpand<T>, OptionExpand<U>) {
539 match self {
540 Some((a, b)) => (Some(a), Some(b)),
541 None => (None, None),
542 }
543 }
544 }
545}