Skip to main content

cubecl_core/frontend/
option.rs

1use crate::{self as cubecl, ExpandType};
2use cubecl::prelude::*;
3
4#[doc(hidden)]
5#[derive(Default)]
6pub enum OptionExpand<T: CubeType> {
7    Some(<T as CubeType>::ExpandType),
8    #[default]
9    None,
10}
11
12impl<T: CubeType> CubeType for Option<T> {
13    type ExpandType = OptionExpand<T>;
14}
15
16impl<T: CubeType> IntoMut for OptionExpand<T> {
17    fn into_mut(self, scope: &mut cubecl::prelude::Scope) -> Self {
18        match self {
19            OptionExpand::Some(arg_0) => OptionExpand::Some(IntoMut::into_mut(arg_0, scope)),
20            OptionExpand::None => OptionExpand::None,
21        }
22    }
23}
24
25impl<T: CubeType> CubeDebug for Option<T> {}
26impl<T: CubeType> CubeDebug for OptionExpand<T> {}
27
28impl<T: CubeType> Clone for OptionExpand<T> {
29    fn clone(&self) -> Self {
30        match self {
31            OptionExpand::Some(arg_0) => OptionExpand::Some(arg_0.clone()),
32            OptionExpand::None => OptionExpand::None,
33        }
34    }
35}
36
37/// Extensions for [`Option`]
38#[allow(non_snake_case)]
39pub trait CubeOption<T: CubeType> {
40    /// Create a new [`Option::Some`] in a kernel
41    fn new_Some(_0: T) -> Option<T> {
42        Option::Some(_0)
43    }
44    /// Create a new [`Option::None`] in a kernel
45    fn new_None() -> Option<T> {
46        Option::None
47    }
48    #[doc(hidden)]
49    fn __expand_Some(_scope: &mut Scope, _0: ExpandType<T>) -> OptionExpand<T> {
50        OptionExpand::Some(_0)
51    }
52    #[doc(hidden)]
53    fn __expand_new_Some(_scope: &mut Scope, _0: ExpandType<T>) -> OptionExpand<T> {
54        OptionExpand::Some(_0)
55    }
56    #[doc(hidden)]
57    fn __expand_new_None(_scope: &mut Scope) -> OptionExpand<T> {
58        OptionExpand::None
59    }
60}
61
62impl<T: CubeType> CubeOption<T> for Option<T> {}
63
64impl<T: CubeType> OptionExpand<T> {
65    pub fn is_some(&self) -> bool {
66        match self {
67            OptionExpand::Some(_) => true,
68            OptionExpand::None => false,
69        }
70    }
71
72    pub fn unwrap(self) -> T::ExpandType {
73        match self {
74            Self::Some(val) => val,
75            Self::None => panic!("Unwrap on a None CubeOption"),
76        }
77    }
78
79    pub fn is_none(&self) -> bool {
80        !self.is_some()
81    }
82
83    pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
84        match self {
85            OptionExpand::Some(val) => val,
86            OptionExpand::None => fallback,
87        }
88    }
89}
90
91pub enum OptionArgs<'a, T: LaunchArg, R: Runtime> {
92    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
93    None,
94}
95
96impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
97    for OptionArgs<'a, T, R>
98{
99    fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
100        match value {
101            Some(arg) => Self::Some(arg),
102            None => Self::None,
103        }
104    }
105}
106
107impl<T: LaunchArg, R: Runtime> ArgSettings<R> for OptionArgs<'_, T, R> {
108    fn register(&self, launcher: &mut KernelLauncher<R>) {
109        match self {
110            OptionArgs::Some(arg) => {
111                arg.register(launcher);
112            }
113            OptionArgs::None => {}
114        }
115    }
116}
117impl<T: LaunchArg> LaunchArg for Option<T> {
118    type RuntimeArg<'a, R: Runtime> = OptionArgs<'a, T, R>;
119    type CompilationArg = OptionCompilationArg<T>;
120
121    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
122        match runtime_arg {
123            OptionArgs::Some(arg) => OptionCompilationArg::Some(T::compilation_arg(arg)),
124            OptionArgs::None => OptionCompilationArg::None,
125        }
126    }
127
128    fn expand(
129        arg: &Self::CompilationArg,
130        builder: &mut KernelBuilder,
131    ) -> <Self as CubeType>::ExpandType {
132        match arg {
133            OptionCompilationArg::Some(arg) => OptionExpand::Some(T::expand(arg, builder)),
134            OptionCompilationArg::None => OptionExpand::None,
135        }
136    }
137
138    fn expand_output(
139        arg: &Self::CompilationArg,
140        builder: &mut KernelBuilder,
141    ) -> <Self as CubeType>::ExpandType {
142        match arg {
143            OptionCompilationArg::Some(arg) => OptionExpand::Some(T::expand_output(arg, builder)),
144            OptionCompilationArg::None => OptionExpand::None,
145        }
146    }
147}
148
149pub enum OptionCompilationArg<T: LaunchArg> {
150    Some(<T as LaunchArg>::CompilationArg),
151    None,
152}
153
154impl<T: LaunchArg> Clone for OptionCompilationArg<T> {
155    fn clone(&self) -> Self {
156        match self {
157            OptionCompilationArg::Some(arg) => OptionCompilationArg::Some(arg.clone()),
158            OptionCompilationArg::None => OptionCompilationArg::None,
159        }
160    }
161}
162
163impl<T: LaunchArg> PartialEq for OptionCompilationArg<T> {
164    fn eq(&self, other: &Self) -> bool {
165        match (self, other) {
166            (OptionCompilationArg::Some(arg_0), OptionCompilationArg::Some(arg_1)) => {
167                arg_0 == arg_1
168            }
169            (OptionCompilationArg::None, OptionCompilationArg::None) => true,
170            _ => false,
171        }
172    }
173}
174
175impl<T: LaunchArg> Eq for OptionCompilationArg<T> {}
176
177impl<T: LaunchArg> core::hash::Hash for OptionCompilationArg<T> {
178    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
179        match self {
180            OptionCompilationArg::Some(arg) => {
181                arg.hash(state);
182            }
183            OptionCompilationArg::None => {}
184        };
185    }
186}
187
188impl<T: LaunchArg> core::fmt::Debug for OptionCompilationArg<T> {
189    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
190        match self {
191            OptionCompilationArg::Some(arg) => f.debug_tuple("Some").field(arg).finish(),
192            OptionCompilationArg::None => write!(f, "None"),
193        }
194    }
195}
196
197impl<T: LaunchArg> CompilationArg for OptionCompilationArg<T> {}
198
199mod impls {
200    use core::ops::{Deref, DerefMut};
201
202    use super::*;
203    use OptionExpand::{None, Some};
204
205    #[doc(hidden)]
206    impl<T: CubeType> OptionExpand<T> {
207        pub fn __expand_is_some_method(&self, _scope: &mut Scope) -> bool {
208            matches!(*self, Some(_))
209        }
210
211        pub fn __expand_is_some_and_method(
212            self,
213            scope: &mut Scope,
214            f: impl FnOnce(&mut Scope, T::ExpandType) -> bool,
215        ) -> bool {
216            match self {
217                None => false,
218                Some(x) => f(scope, x),
219            }
220        }
221
222        pub fn __expand_is_none_method(&self, scope: &mut Scope) -> bool {
223            !self.__expand_is_some_method(scope)
224        }
225
226        pub fn __expand_is_none_or_method(
227            self,
228            scope: &mut Scope,
229            f: impl FnOnce(&mut Scope, T::ExpandType) -> bool,
230        ) -> bool {
231            match self {
232                None => true,
233                Some(x) => f(scope, x),
234            }
235        }
236
237        pub fn __expand_as_ref_method(self, _scope: &mut Scope) -> Self {
238            self
239        }
240
241        pub fn __expand_as_mut_method(self, _scope: &mut Scope) -> Self {
242            self
243        }
244
245        fn __expand_len_method(&self, _scope: &mut Scope) -> usize {
246            match self {
247                Some(_) => 1,
248                None => 0,
249            }
250        }
251
252        #[allow(clippy::unnecessary_literal_unwrap)]
253        pub fn __expand_expect_method(self, _scope: &mut Scope, msg: &str) -> T::ExpandType {
254            match self {
255                Some(val) => val,
256                None => Option::None.expect(msg),
257            }
258        }
259
260        #[allow(clippy::unnecessary_literal_unwrap)]
261        pub fn __expand_unwrap_method(self, _scope: &mut Scope) -> T::ExpandType {
262            match self {
263                Some(val) => val,
264                None => Option::None.unwrap(),
265            }
266        }
267
268        pub fn __expand_unwrap_or_method(
269            self,
270            _scope: &mut Scope,
271            default: T::ExpandType,
272        ) -> T::ExpandType {
273            match self {
274                Some(x) => x,
275                None => default,
276            }
277        }
278
279        pub fn __expand_unwrap_or_else_method<F>(self, scope: &mut Scope, f: F) -> T::ExpandType
280        where
281            F: FnOnce(&mut Scope) -> T::ExpandType,
282        {
283            match self {
284                Some(x) => x,
285                None => f(scope),
286            }
287        }
288
289        pub fn __expand_unwrap_or_default_method(self, scope: &mut Scope) -> T::ExpandType
290        where
291            T: Default + IntoRuntime,
292        {
293            match self {
294                Some(x) => x,
295                None => T::default().__expand_runtime_method(scope),
296            }
297        }
298
299        pub fn __expand_map_method<U, F>(self, scope: &mut Scope, f: F) -> OptionExpand<U>
300        where
301            U: CubeType,
302            F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
303        {
304            match self {
305                Some(x) => Some(f(scope, x)),
306                None => None,
307            }
308        }
309
310        pub fn __expand_inspect_method<F>(self, scope: &mut Scope, f: F) -> Self
311        where
312            F: FnOnce(&mut Scope, T::ExpandType),
313        {
314            if let Some(x) = self.clone() {
315                f(scope, x);
316            }
317
318            self
319        }
320
321        pub fn __expand_map_or_method<U, F>(
322            self,
323            scope: &mut Scope,
324            default: U::ExpandType,
325            f: F,
326        ) -> U::ExpandType
327        where
328            F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
329            U: CubeType,
330        {
331            match self {
332                Some(t) => f(scope, t),
333                None => default,
334            }
335        }
336
337        pub fn __expand_map_or_else_method<U, D, F>(
338            self,
339            scope: &mut Scope,
340            default: D,
341            f: F,
342        ) -> U::ExpandType
343        where
344            U: CubeType,
345            D: FnOnce(&mut Scope) -> U::ExpandType,
346            F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
347        {
348            match self {
349                Some(t) => f(scope, t),
350                None => default(scope),
351            }
352        }
353
354        pub fn __expand_map_or_default_method<U, F>(self, scope: &mut Scope, f: F) -> U::ExpandType
355        where
356            U: CubeType + Default + Into<U::ExpandType>,
357            F: FnOnce(&mut Scope, T::ExpandType) -> U::ExpandType,
358        {
359            match self {
360                Some(t) => f(scope, t),
361                None => U::default().into(),
362            }
363        }
364
365        pub fn __expand_as_deref_method(self, scope: &mut Scope) -> OptionExpand<T::Target>
366        where
367            T: Deref<Target: CubeType + Sized>,
368            T::ExpandType: Deref<Target = <T::Target as CubeType>::ExpandType>,
369        {
370            self.__expand_map_method(scope, |_, it| (*it).clone())
371        }
372
373        pub fn __expand_as_deref_mut_method(self, scope: &mut Scope) -> OptionExpand<T::Target>
374        where
375            T: DerefMut<Target: CubeType + Sized>,
376            T::ExpandType: Deref<Target = <T::Target as CubeType>::ExpandType>,
377        {
378            self.__expand_map_method(scope, |_, it| (*it).clone())
379        }
380
381        pub fn __expand_and_method<U>(
382            self,
383            _scope: &mut Scope,
384            optb: OptionExpand<U>,
385        ) -> OptionExpand<U>
386        where
387            U: CubeType,
388        {
389            match self {
390                Some(_) => optb,
391                None => None,
392            }
393        }
394
395        pub fn __expand_and_then_method<U, F>(self, scope: &mut Scope, f: F) -> OptionExpand<U>
396        where
397            U: CubeType,
398            F: FnOnce(&mut Scope, T::ExpandType) -> OptionExpand<U>,
399        {
400            match self {
401                Some(x) => f(scope, x),
402                None => None,
403            }
404        }
405
406        pub fn __expand_filter_method<P>(self, scope: &mut Scope, predicate: P) -> Self
407        where
408            P: FnOnce(&mut Scope, T::ExpandType) -> bool,
409        {
410            if let Some(x) = self
411                && predicate(scope, x.clone())
412            {
413                Some(x)
414            } else {
415                None
416            }
417        }
418
419        pub fn __expand_or_method(
420            self,
421            _scope: &mut Scope,
422            optb: OptionExpand<T>,
423        ) -> OptionExpand<T> {
424            match self {
425                x @ Some(_) => x,
426                None => optb,
427            }
428        }
429
430        pub fn __expand_or_else_method<F>(self, scope: &mut Scope, f: F) -> OptionExpand<T>
431        where
432            F: FnOnce(&mut Scope) -> OptionExpand<T>,
433        {
434            match self {
435                x @ Some(_) => x,
436                None => f(scope),
437            }
438        }
439
440        pub fn __expand_xor_method(
441            self,
442            _scope: &mut Scope,
443            optb: OptionExpand<T>,
444        ) -> OptionExpand<T> {
445            match (self, optb) {
446                (a @ Some(_), None) => a,
447                (None, b @ Some(_)) => b,
448                _ => None,
449            }
450        }
451
452        // Entry methods that return &mut T excluded for now
453
454        pub fn __expand_take_method(&mut self, _scope: &mut Scope) -> OptionExpand<T> {
455            core::mem::take(self)
456        }
457
458        pub fn __expand_take_if_method<P>(
459            &mut self,
460            scope: &mut Scope,
461            predicate: P,
462        ) -> OptionExpand<T>
463        where
464            P: FnOnce(&mut Scope, T::ExpandType) -> bool,
465        {
466            match self {
467                Some(value) if predicate(scope, value.clone()) => self.__expand_take_method(scope),
468                _ => None,
469            }
470        }
471
472        pub fn __expand_replace_method(
473            &mut self,
474            _scope: &mut Scope,
475            value: T::ExpandType,
476        ) -> OptionExpand<T> {
477            core::mem::replace(self, Some(value))
478        }
479
480        pub fn __expand_zip_method<U>(
481            self,
482            _scope: &mut Scope,
483            other: OptionExpand<U>,
484        ) -> OptionExpand<(T, U)>
485        where
486            U: CubeType,
487        {
488            match (self, other) {
489                (Some(a), Some(b)) => Some((a, b)),
490                _ => None,
491            }
492        }
493
494        pub fn __expand_zip_with_method<U, F, R>(
495            self,
496            scope: &mut Scope,
497            other: OptionExpand<U>,
498            f: F,
499        ) -> OptionExpand<R>
500        where
501            F: FnOnce(&mut Scope, T::ExpandType, U::ExpandType) -> R::ExpandType,
502            R: CubeType,
503            U: CubeType,
504        {
505            match (self, other) {
506                (Some(a), Some(b)) => Some(f(scope, a, b)),
507                _ => None,
508            }
509        }
510
511        pub fn __expand_reduce_method<U, R, F>(
512            self,
513            scope: &mut Scope,
514            other: OptionExpand<U>,
515            f: F,
516        ) -> OptionExpand<R>
517        where
518            U: CubeType,
519            R: CubeType,
520            T::ExpandType: Into<R::ExpandType>,
521            U::ExpandType: Into<R::ExpandType>,
522            F: FnOnce(&mut Scope, T::ExpandType, U::ExpandType) -> R::ExpandType,
523        {
524            match (self, other) {
525                (Some(a), Some(b)) => Some(f(scope, a, b)),
526                (Some(a), _) => Some(a.into()),
527                (_, Some(b)) => Some(b.into()),
528                _ => None,
529            }
530        }
531    }
532
533    #[doc(hidden)]
534    impl<T: CubeType, U: CubeType> OptionExpand<(T, U)> {
535        pub fn __expand_unzip_method(
536            self,
537            _scope: &mut Scope,
538        ) -> (OptionExpand<T>, OptionExpand<U>) {
539            match self {
540                Some((a, b)) => (Some(a), Some(b)),
541                None => (None, None),
542            }
543        }
544    }
545}