1use cubecl_macros::derive_expand;
2
3use crate as cubecl;
4use crate::prelude::*;
5
// Input to `derive_expand`: lists the variants of `Option` for which the
// `CubeType`, `CubeTypeMut` and `IntoRuntime` support is derived.
// NOTE(review): presumably this mirrors `core::option::Option` rather than
// declaring a new type — confirm against the `derive_expand` documentation.
#[derive_expand(CubeType, CubeTypeMut, IntoRuntime)]
#[cube(runtime_variants, no_constructors)]
pub enum Option<T: CubeType> {
    // No value is present.
    None,
    // A value of type `T` is present.
    Some(T),
}
14
15fn discriminant(variant_name: &'static str) -> i32 {
16 OptionExpand::<u32>::discriminant_of(variant_name)
17}
18
19pub enum OptionArgs<T: LaunchArg, R: Runtime> {
20 Some(<T as LaunchArg>::RuntimeArg<R>),
21 None,
22}
23
24impl<T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<R>>> for OptionArgs<T, R> {
25 fn from(value: Option<<T as LaunchArg>::RuntimeArg<R>>) -> Self {
26 match value {
27 Some(arg) => Self::Some(arg),
28 None => Self::None,
29 }
30 }
31}
32
33impl<T: LaunchArg + Default + IntoRuntime> LaunchArg for Option<T> {
34 type RuntimeArg<R: Runtime> = OptionArgs<T, R>;
35 type CompilationArg = OptionCompilationArg<T>;
36
37 fn register<R: Runtime>(
38 arg: Self::RuntimeArg<R>,
39 launcher: &mut KernelLauncher<R>,
40 ) -> Self::CompilationArg {
41 match arg {
42 OptionArgs::Some(arg) => OptionCompilationArg::Some(T::register(arg, launcher)),
43 OptionArgs::None => OptionCompilationArg::None,
44 }
45 }
46
47 fn expand(
48 arg: &Self::CompilationArg,
49 builder: &mut KernelBuilder,
50 ) -> <Self as CubeType>::ExpandType {
51 match arg {
52 OptionCompilationArg::Some(value) => {
53 let value = T::expand(value, builder);
54 OptionExpand {
55 discriminant: discriminant("Some").into(),
56 value,
57 }
58 }
59 OptionCompilationArg::None => OptionExpand {
60 discriminant: discriminant("None").into(),
61 value: T::default().__expand_runtime_method(&mut builder.scope),
62 },
63 }
64 }
65
66 fn expand_output(
67 arg: &Self::CompilationArg,
68 builder: &mut KernelBuilder,
69 ) -> <Self as CubeType>::ExpandType {
70 match arg {
71 OptionCompilationArg::Some(value) => {
72 let value = T::expand_output(value, builder);
73 OptionExpand {
74 discriminant: discriminant("Some").into(),
75 value,
76 }
77 }
78 OptionCompilationArg::None => OptionExpand {
79 discriminant: discriminant("None").into(),
80 value: T::default().__expand_runtime_method(&mut builder.scope),
81 },
82 }
83 }
84}
85
/// Compilation argument of an optional kernel parameter of type `T`.
///
/// The `Clone`, `PartialEq`, `Eq`, `Hash` and `Debug` impls for this type are
/// written by hand in this file and are bounded on `T: LaunchArg` only.
pub enum OptionCompilationArg<T: LaunchArg> {
    /// The parameter is present; carries the compilation argument of `T`.
    Some(T::CompilationArg),
    /// The parameter is absent.
    None,
}
90
91impl<T: LaunchArg> Clone for OptionCompilationArg<T> {
92 fn clone(&self) -> Self {
93 match self {
94 OptionCompilationArg::Some(value) => OptionCompilationArg::Some(value.clone()),
95 OptionCompilationArg::None => OptionCompilationArg::None,
96 }
97 }
98}
99
100impl<T: LaunchArg> PartialEq for OptionCompilationArg<T> {
101 fn eq(&self, other: &Self) -> bool {
102 match (self, other) {
103 (Self::Some(l0), Self::Some(r0)) => l0 == r0,
104 _ => core::mem::discriminant(self) == core::mem::discriminant(other),
105 }
106 }
107}
108
109impl<T: LaunchArg> Eq for OptionCompilationArg<T> {}
110
111impl<T: LaunchArg> core::hash::Hash for OptionCompilationArg<T> {
112 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
113 core::mem::discriminant(self).hash(state);
114 match self {
115 OptionCompilationArg::Some(value) => value.hash(state),
116 OptionCompilationArg::None => {}
117 }
118 }
119}
120
121impl<T: LaunchArg> core::fmt::Debug for OptionCompilationArg<T> {
122 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
123 match self {
124 Self::Some(arg0) => f.debug_tuple("Some").field(arg0).finish(),
125 Self::None => write!(f, "None"),
126 }
127 }
128}
129
130#[allow(non_snake_case)]
132pub trait CubeOption<T: CubeType> {
133 fn new_Some(_0: T) -> Option<T> {
135 Option::Some(_0)
136 }
137 fn none_with_default(_0: T) -> Option<T> {
138 Option::None
139 }
140
141 #[doc(hidden)]
142 fn __expand_Some(scope: &mut Scope, value: T::ExpandType) -> OptionExpand<T> {
143 Self::__expand_new_Some(scope, value)
144 }
145 #[doc(hidden)]
146 fn __expand_new_Some(_scope: &mut Scope, value: T::ExpandType) -> OptionExpand<T> {
147 OptionExpand::<T> {
148 discriminant: discriminant("Some").into(),
149 value,
150 }
151 }
152 fn __expand_none_with_default(_scope: &mut Scope, value: T::ExpandType) -> OptionExpand<T> {
153 OptionExpand {
154 discriminant: discriminant("None").into(),
155 value,
156 }
157 }
158}
159
160#[allow(non_snake_case)]
162pub trait CubeOptionDefault<T: CubeType + Default + IntoRuntime>: CubeOption<T> {
163 fn new_None() -> Option<T> {
165 Option::None
166 }
167
168 #[doc(hidden)]
169 fn __expand_new_None(scope: &mut Scope) -> OptionExpand<T> {
170 let value = T::default().__expand_runtime_method(scope);
171 Self::__expand_none_with_default(scope, value)
172 }
173}
174
// `Option<T>` takes every `CubeOption` method from the trait's default bodies.
impl<T: CubeType> CubeOption<T> for Option<T> {}
// Likewise for `CubeOptionDefault`, for payload types that are `Default + IntoRuntime`.
impl<T: CubeType + Default + IntoRuntime> CubeOptionDefault<T> for Option<T> {}
177
178mod impls {
179 use core::ops::{Deref, DerefMut};
180
181 use super::*;
182 use crate as cubecl;
183
184 #[doc(hidden)]
189 impl<T: CubeType> OptionExpand<T> {
190 pub fn __expand_is_some_and_method(
191 self,
192 scope: &mut Scope,
193 f: impl FnOnce(&mut Scope, T::ExpandType) -> NativeExpand<bool>,
194 ) -> NativeExpand<bool> {
195 match_expand_expr(scope, self, discriminant("None"), |_, _| false.into())
196 .case(scope, discriminant("Some"), |scope, value| f(scope, value))
197 .finish(scope)
198 }
199
200 pub fn __expand_is_none_or_method(
201 self,
202 scope: &mut Scope,
203 f: impl FnOnce(&mut Scope, T::ExpandType) -> NativeExpand<bool>,
204 ) -> NativeExpand<bool> {
205 match_expand_expr(scope, self, discriminant("None"), |_, _| true.into())
206 .case(scope, discriminant("Some"), |scope, value| f(scope, value))
207 .finish(scope)
208 }
209
210 pub fn __expand_as_ref_method(&self, _scope: &mut Scope) -> OptionExpand<T> {
211 self.clone()
212 }
213
214 pub fn __expand_as_mut_method(&mut self, _scope: &mut Scope) -> OptionExpand<T> {
215 self.clone()
216 }
217
        /// Yields the payload, registering a `None` arm that reports `msg`.
        ///
        /// The `Some` arm returns the payload unchanged. The `None` arm passes
        /// `msg` to `printf_expand` with no format arguments, invokes
        /// `terminate!()`, and then returns the value it was handed.
        pub fn __expand_expect_method(self, scope: &mut Scope, msg: &str) -> T::ExpandType
        where
            T::ExpandType: Assign,
        {
            match_expand_expr(scope, self, discriminant("Some"), |_, value| value)
                .case(scope, discriminant("None"), |scope, value| {
                    printf_expand(scope, msg, alloc::vec![]);
                    // NOTE(review): presumably this stops the current unit — confirm
                    // what `terminate!()` expands to outside a `#[cube]` body.
                    terminate!();
                    // Returned so that both arms produce a `T::ExpandType`.
                    value
                })
                .finish(scope)
        }
231
232 pub fn __expand_unwrap_or_else_method<F>(self, scope: &mut Scope, f: F) -> T::ExpandType
233 where
234 F: FnOnce(&mut Scope) -> T::ExpandType,
235 T::ExpandType: Assign,
236 {
237 match_expand_expr(scope, self, discriminant("Some"), |_, value| value)
238 .case(scope, discriminant("None"), |scope, _| f(scope))
239 .finish(scope)
240 }
241
242 pub fn __expand_map_method<U, F>(self, scope: &mut Scope, f: F) -> OptionExpand<U>
243 where
244 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
245 U: CubeType + IntoRuntime + Default,
246 OptionExpand<U>: Assign,
247 {
248 match_expand_expr(scope, self, discriminant("Some"), |scope, value| {
249 let value = f(scope, value);
250 Option::__expand_new_Some(scope, value)
251 })
252 .case(scope, discriminant("None"), |scope, _| {
253 Option::__expand_new_None(scope)
254 })
255 .finish(scope)
256 }
257
258 pub fn __expand_inspect_method<F>(self, scope: &mut Scope, f: F) -> Self
259 where
260 F: FnOnce(&mut Scope, &T::ExpandType),
261 {
262 match_expand(scope, self.clone(), discriminant("Some"), |scope, value| {
263 f(scope, &value)
264 })
265 .case(scope, discriminant("None"), |_, _| {})
266 .finish(scope);
267 self
268 }
269
270 pub fn __expand_map_or_method<U, F>(
271 self,
272 scope: &mut Scope,
273 default: U::ExpandType,
274 f: F,
275 ) -> U::ExpandType
276 where
277 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
278 U: CubeType + Default + IntoRuntime,
279 U::ExpandType: Assign,
280 {
281 match_expand_expr(scope, self, discriminant("Some"), f)
282 .case(scope, discriminant("None"), |_, _| default)
283 .finish(scope)
284 }
285
286 pub fn __expand_map_or_else_method<U, D, F>(
287 self,
288 scope: &mut Scope,
289 default: D,
290 f: F,
291 ) -> U::ExpandType
292 where
293 D: FnOnce(&mut Scope) -> U::ExpandType,
294 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
295 U: CubeType + Default + IntoRuntime,
296 U::ExpandType: Assign,
297 {
298 match_expand_expr(scope, self, discriminant("Some"), f)
299 .case(scope, discriminant("None"), |scope, _| default(scope))
300 .finish(scope)
301 }
302
303 pub fn __expand_map_or_default_method<U, F>(self, scope: &mut Scope, f: F) -> U::ExpandType
304 where
305 U: CubeType + IntoRuntime + Default,
306 F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
307 U::ExpandType: Assign,
308 {
309 match_expand_expr(scope, self, discriminant("Some"), f)
310 .case(scope, discriminant("None"), |scope, _| {
311 U::default().__expand_runtime_method(scope)
312 })
313 .finish(scope)
314 }
315
316 pub fn __expand_as_deref_method(self, scope: &mut Scope) -> OptionExpand<T::Target>
317 where
318 T: Deref<Target: CubeType + Default + IntoRuntime>,
319 T::ExpandType: Deref<Target = <T::Target as CubeType>::ExpandType>,
320 <T::Target as CubeType>::ExpandType: Assign,
321 {
322 self.__expand_map_method(scope, |_, value| (*value).clone())
323 }
324
325 pub fn __expand_as_deref_mut_method(self, scope: &mut Scope) -> OptionExpand<T::Target>
326 where
327 T: DerefMut<Target: CubeType + Default + IntoRuntime>,
328 T::ExpandType: Deref<Target = <T::Target as CubeType>::ExpandType>,
329 <T::Target as CubeType>::ExpandType: Assign,
330 {
331 self.__expand_map_method(scope, |_, value| (*value).clone())
332 }
333
334 pub fn __expand_and_then_method<U, F>(self, scope: &mut Scope, f: F) -> OptionExpand<U>
335 where
336 F: FnOnce(&mut Scope, T::ExpandType) -> OptionExpand<U>,
337 U: CubeType + IntoRuntime + Default,
338 U::ExpandType: Assign,
339 {
340 match_expand_expr(scope, self, discriminant("Some"), f)
341 .case(scope, discriminant("None"), |scope, _| {
342 Option::__expand_new_None(scope)
343 })
344 .finish(scope)
345 }
346
347 pub fn __expand_filter_method<P>(self, scope: &mut Scope, predicate: P) -> Self
348 where
349 P: FnOnce(&mut Scope, T::ExpandType) -> NativeExpand<bool>,
350 T: Default + IntoRuntime,
351 Self: Assign,
352 {
353 match_expand_expr(scope, self, discriminant("Some"), |scope, value| {
354 let cond = predicate(scope, value.clone());
355 if_else_expr_expand(scope, cond, |scope| Option::__expand_new_Some(scope, value))
356 .or_else(scope, |scope| Option::__expand_new_None(scope))
357 })
358 .case(scope, discriminant("None"), |scope, _| {
359 Option::__expand_new_None(scope)
360 })
361 .finish(scope)
362 }
363
364 pub fn __expand_or_else_method<F>(self, scope: &mut Scope, f: F) -> OptionExpand<T>
365 where
366 F: FnOnce(&mut Scope) -> OptionExpand<T>,
367 OptionExpand<T>: Assign,
368 {
369 let is_some = self.clone().__expand_is_some_method(scope);
370 if_else_expr_expand(scope, is_some, |_| self).or_else(scope, |scope| f(scope))
371 }
372
373 pub fn __expand_zip_with_method<U, F, R>(
374 self,
375 scope: &mut Scope,
376 other: OptionExpand<U>,
377 f: F,
378 ) -> OptionExpand<R>
379 where
380 F: FnOnce(&mut Scope, T::ExpandType, U::ExpandType) -> R::ExpandType,
381 U: CubeType,
382 R: CubeType + IntoRuntime + Default,
383 OptionExpand<R>: Assign,
384 {
385 match_expand_expr(scope, self, discriminant("Some"), |scope, value| {
386 match_expand_expr(scope, other, discriminant("Some"), |scope, other| {
387 let value = f(scope, value, other);
388 Option::__expand_new_Some(scope, value)
389 })
390 .case(scope, discriminant("None"), |scope, _| {
391 Option::__expand_new_None(scope)
392 })
393 .finish(scope)
394 })
395 .case(scope, discriminant("None"), |scope, _| {
396 Option::__expand_new_None(scope)
397 })
398 .finish(scope)
399 }
400
        /// Reduces two options into one.
        ///
        /// - both `Some`: yields `Some(f(a, b))`;
        /// - only `self` is `Some`: yields `Some` of its payload converted with `.into()`;
        /// - only `other` is `Some`: yields `Some` of its payload converted with `.into()`;
        /// - both `None`: yields `None`.
        pub fn __expand_reduce_method<U, R, F>(
            self,
            scope: &mut Scope,
            other: OptionExpand<U>,
            f: F,
        ) -> OptionExpand<R>
        where
            T::ExpandType: Into<R::ExpandType>,
            U::ExpandType: Into<R::ExpandType>,
            F: FnOnce(&mut Scope, T::ExpandType, U::ExpandType) -> R::ExpandType,
            U: CubeType + IntoRuntime + Default,
            R: CubeType + IntoRuntime + Default,
            OptionExpand<R>: Assign,
        {
            match_expand_expr(scope, self, discriminant("Some"), {
                // `other` is matched in both outer arms, so this arm takes a clone
                // and leaves the original for the `None` arm below.
                let other = other.clone();
                |scope, value| {
                    match_expand_expr(scope, other, discriminant("Some"), {
                        // `value` is needed by both inner arms: the clone goes to
                        // `f`, the original to the `None` arm.
                        let value = value.clone();
                        |scope, other| {
                            let value = f(scope, value, other);
                            Option::__expand_new_Some(scope, value)
                        }
                    })
                    .case(scope, discriminant("None"), |scope, _| {
                        Option::__expand_new_Some(scope, value.into())
                    })
                    .finish(scope)
                }
            })
            .case(scope, discriminant("None"), |scope, _| {
                match_expand_expr(scope, other, discriminant("Some"), |scope, other| {
                    Option::__expand_new_Some(scope, other.into())
                })
                .case(scope, discriminant("None"), |scope, _| {
                    Option::__expand_new_None(scope)
                })
                .finish(scope)
            })
            .finish(scope)
        }
442
443 #[allow(clippy::missing_safety_doc)]
444 pub unsafe fn __expand_unwrap_unchecked_method(self, scope: &mut Scope) -> T::ExpandType
445 where
446 T::ExpandType: Assign,
447 {
448 match_expand_expr(scope, self, discriminant("Some"), |_, value| value).finish(scope)
449 }
450 }
451
452 #[cube(expand_only)]
453 impl<T: CubeType> Option<T> {
        /// Returns `true` if the option is `Some` and `false` if it is `None`.
        ///
        /// Both results are produced through `.runtime()`.
        pub fn is_some(&self) -> bool {
            match self {
                Option::Some(_) => true.runtime(),
                Option::None => false.runtime(),
            }
        }
475
        /// Returns `true` if the option is `None`; this is the negation of
        /// `is_some`.
        #[must_use = "if you intended to assert that this doesn't have a value, consider \
                      wrapping this in an `assert!()` instead"]
        pub fn is_none(&self) -> bool {
            !self.is_some()
        }
492
        /// Returns the contained value by delegating to `expect` with a fixed
        /// message naming `Option::unwrap()`.
        pub fn unwrap(self) -> T
        where
            T::ExpandType: Assign,
        {
            self.expect("called `Option::unwrap()` on a `None` value")
        }
535
        /// Returns the contained value when the option is `Some`, otherwise
        /// `default`.
        pub fn unwrap_or(self, default: T) -> T
        where
            T::ExpandType: Assign,
        {
            match self {
                Some(x) => x,
                None => default,
            }
        }
559
        /// Returns the contained value when the option is `Some`, otherwise
        /// `T::default()`.
        ///
        /// The default is built inside `comptime![…]` and then converted with
        /// `.runtime()`.
        pub fn unwrap_or_default(self) -> T
        where
            T: Default + IntoRuntime,
            T::ExpandType: Assign,
        {
            match self {
                Some(x) => x,
                None => comptime![T::default()].runtime(),
            }
        }
589
        /// Returns `optb` when `self` is `Some`, otherwise a `None` built with
        /// `Option::new_None()`.
        pub fn and<U>(self, optb: Option<U>) -> Option<U>
        where
            U: CubeType + IntoRuntime + Default,
            U::ExpandType: Assign,
        {
            match self {
                Option::Some(_) => optb,
                Option::None => Option::new_None(),
            }
        }
635
        /// Returns `self` when it is `Some`, otherwise `optb`.
        pub fn or(self, optb: Option<T>) -> Option<T>
        where
            T::ExpandType: Assign,
        {
            if self.is_some() { self } else { optb }
        }
669
        /// Returns whichever of `self` and `optb` is `Some` when exactly one of
        /// them is; otherwise returns a `None` built with `Option::new_None()`.
        pub fn xor(self, optb: Option<T>) -> Option<T>
        where
            T: Default + IntoRuntime,
            T::ExpandType: Assign,
        {
            if self.is_some() && optb.is_none() {
                self
            } else if self.is_none() && optb.is_some() {
                optb
            } else {
                // Reached when both are `Some` or both are `None`.
                Option::new_None()
            }
        }
704
        /// Pairs the payloads of `self` and `other`.
        ///
        /// Returns `Some((a, b))` when both options are `Some`; in every other
        /// case returns a `None` built with `Option::new_None()`.
        pub fn zip<U>(self, other: Option<U>) -> Option<(T, U)>
        where
            U: CubeType,
            (T, U): Default + IntoRuntime,
            (T::ExpandType, U::ExpandType): Into<<(T, U) as CubeType>::ExpandType>,
            OptionExpand<(T, U)>: Assign,
        {
            match self {
                Some(a) => match other {
                    Some(b) => Option::Some((a, b)),
                    None => Option::new_None(),
                },
                None => Option::new_None(),
            }
        }
741 }
742
    // Methods available only on options whose payload is a pair.
    #[cube(expand_only)]
    impl<
        T: CubeType<ExpandType: Assign> + IntoRuntime + Default,
        U: CubeType<ExpandType: Assign> + IntoRuntime + Default,
    > Option<(T, U)>
    {
        /// Splits an option of a pair into a pair of options.
        ///
        /// `Some((a, b))` becomes `(Some(a), Some(b))`; `None` becomes a pair of
        /// `None`s, each built with `Option::new_None()`.
        #[inline]
        pub fn unzip(self) -> (Option<T>, Option<U>) {
            match self {
                Option::Some(value) => (Option::Some(value.0), Option::Some(value.1)),
                Option::None => (Option::new_None(), Option::new_None()),
            }
        }
    }
771}