cubecl_std/
option.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use serde::{Deserialize, Serialize};
4
/// An optional value for CubeCL code, mirroring [`Option`].
///
/// Its expand form is `CubeOptionExpand` (this type's `CubeType::ExpandType`,
/// presumably generated by the `CubeType` derive).
#[derive(CubeType, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq, Debug)]
pub enum CubeOption<T: CubeType> {
    /// A present value of type `T`.
    Some(T),
    /// No value.
    None,
}
10
// Option-like helpers for `CubeOption`, processed by the `#[cube]` macro.
#[cube]
impl<T: CubeType> CubeOption<T> {
    /// Returns `true` if this is `CubeOption::Some`.
    pub fn is_some(&self) -> bool {
        match self {
            CubeOption::Some(_) => true,
            CubeOption::None => false,
        }
    }

    /// Returns the contained value.
    ///
    /// # Panics
    ///
    /// Panics with "Unwrap on a None CubeOption" if this is `CubeOption::None`.
    pub fn unwrap(self) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => panic!("Unwrap on a None CubeOption"),
        }
    }

    /// Returns `true` if this is `CubeOption::None`.
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Returns the contained value, or `fallback` if this is `CubeOption::None`.
    pub fn unwrap_or(self, fallback: T) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => fallback,
        }
    }
}
38
39impl<T: CubeType> CubeOptionExpand<T> {
40    pub fn is_some(&self) -> bool {
41        match self {
42            CubeOptionExpand::Some(_) => true,
43            CubeOptionExpand::None => false,
44        }
45    }
46
47    pub fn unwrap(self) -> T::ExpandType {
48        match self {
49            Self::Some(val) => val,
50            Self::None => panic!("Unwrap on a None CubeOption"),
51        }
52    }
53
54    pub fn is_none(&self) -> bool {
55        !self.is_some()
56    }
57
58    pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
59        match self {
60            CubeOptionExpand::Some(val) => val,
61            CubeOptionExpand::None => fallback,
62        }
63    }
64}
65
66impl<T: CubeType + Into<T::ExpandType>> From<CubeOption<T>> for CubeOptionExpand<T> {
67    fn from(value: CubeOption<T>) -> Self {
68        match value {
69            CubeOption::Some(val) => CubeOptionExpand::Some(val.into()),
70            CubeOption::None => CubeOptionExpand::None,
71        }
72    }
73}
74
// `LaunchArg` is implemented manually below because the derive macro is not yet
// permissive enough to handle this generic enum.
76
77pub enum CubeOptionArgs<'a, T: LaunchArg, R: Runtime> {
78    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
79    None,
80}
81
82impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
83    for CubeOptionArgs<'a, T, R>
84{
85    fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
86        match value {
87            Some(arg) => Self::Some(arg),
88            None => Self::None,
89        }
90    }
91}
92
93impl<T: LaunchArg, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
94    fn register(&self, launcher: &mut KernelLauncher<R>) {
95        match self {
96            CubeOptionArgs::Some(arg) => {
97                arg.register(launcher);
98            }
99            CubeOptionArgs::None => {}
100        }
101    }
102}
103impl<T: LaunchArg> LaunchArg for CubeOption<T> {
104    type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
105    type CompilationArg = CubeOptionCompilationArg<T>;
106
107    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
108        match runtime_arg {
109            CubeOptionArgs::Some(arg) => {
110                CubeOptionCompilationArg::Some(T::compilation_arg::<R>(arg))
111            }
112            CubeOptionArgs::None => CubeOptionCompilationArg::None,
113        }
114    }
115
116    fn expand(
117        arg: &Self::CompilationArg,
118        builder: &mut KernelBuilder,
119    ) -> <Self as CubeType>::ExpandType {
120        match arg {
121            CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
122            CubeOptionCompilationArg::None => CubeOptionExpand::None,
123        }
124    }
125
126    fn expand_output(
127        arg: &Self::CompilationArg,
128        builder: &mut KernelBuilder,
129    ) -> <Self as CubeType>::ExpandType {
130        match arg {
131            CubeOptionCompilationArg::Some(arg) => {
132                CubeOptionExpand::Some(T::expand_output(arg, builder))
133            }
134            CubeOptionCompilationArg::None => CubeOptionExpand::None,
135        }
136    }
137}
138
139pub enum CubeOptionCompilationArg<T: LaunchArg> {
140    Some(<T as LaunchArg>::CompilationArg),
141    None,
142}
143
144impl<T: LaunchArg> Clone for CubeOptionCompilationArg<T> {
145    fn clone(&self) -> Self {
146        match self {
147            CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
148            CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
149        }
150    }
151}
152
153impl<T: LaunchArg> PartialEq for CubeOptionCompilationArg<T> {
154    fn eq(&self, other: &Self) -> bool {
155        match (self, other) {
156            (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
157                arg_0 == arg_1
158            }
159            (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
160            _ => false,
161        }
162    }
163}
164
165impl<T: LaunchArg> Eq for CubeOptionCompilationArg<T> {}
166
167impl<T: LaunchArg> core::hash::Hash for CubeOptionCompilationArg<T> {
168    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
169        match self {
170            CubeOptionCompilationArg::Some(arg) => {
171                arg.hash(state);
172            }
173            CubeOptionCompilationArg::None => {}
174        };
175    }
176}
177
178impl<T: LaunchArg> core::fmt::Debug for CubeOptionCompilationArg<T> {
179    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
180        match self {
181            CubeOptionCompilationArg::Some(arg) => f
182                .debug_tuple("CubeOptionCompilationArg :: Some")
183                .field(arg)
184                .finish(),
185            CubeOptionCompilationArg::None => write!(f, "CubeOptionCompilationArg :: None"),
186        }
187    }
188}
189
190impl<T: LaunchArg> CompilationArg for CubeOptionCompilationArg<T> {}