cubecl_std/
option.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use serde::{Deserialize, Serialize};
4
/// An optional value with the same two-variant shape as [`Option`], usable as a cube type.
///
/// `T` must implement [`CubeType`]; the enum itself derives `CubeType` along with the
/// usual value traits (`Clone`, `Copy`, `Hash`, `PartialEq`, `Eq`, `Debug`) and serde
/// (de)serialization.
#[derive(CubeType, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq, Debug)]
pub enum CubeOption<T: CubeType> {
    /// A value of type `T` is present.
    Some(T),
    /// No value is present.
    None,
}
10
// Helper methods on `CubeOption`, written inside a `#[cube]` impl so they are processed
// by the cube macro.
#[cube]
impl<T: CubeType> CubeOption<T> {
    /// Returns `true` when `self` is the `Some` variant and `false` for `None`.
    pub fn is_some(&self) -> bool {
        // NOTE(review): the trailing `.runtime()` presumably converts the constant picked
        // by the match into a runtime `bool` for the cube expansion — confirm against
        // cubecl's `runtime()` helper.
        match self {
            CubeOption::Some(_) => true,
            CubeOption::None => false,
        }
        .runtime()
    }

    /// Returns the value held by the `Some` variant.
    ///
    /// # Panics
    ///
    /// The `None` arm is `panic!("Unwrap on a None CubeOption")`.
    pub fn unwrap(self) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => panic!("Unwrap on a None CubeOption"),
        }
    }

    /// Returns the negation of [`is_some`](Self::is_some).
    pub fn is_none(&self) -> bool {
        !self.is_some()
    }

    /// Returns the value held by the `Some` variant, or `fallback` when `self` is `None`.
    pub fn unwrap_or(self, fallback: T) -> T {
        match self {
            CubeOption::Some(val) => val,
            CubeOption::None => fallback,
        }
    }
}
39
40impl<T: CubeType> CubeOptionExpand<T> {
41    pub fn is_some(&self) -> bool {
42        match self {
43            CubeOptionExpand::Some(_) => true,
44            CubeOptionExpand::None => false,
45        }
46    }
47
48    pub fn unwrap(self) -> T::ExpandType {
49        match self {
50            Self::Some(val) => val,
51            Self::None => panic!("Unwrap on a None CubeOption"),
52        }
53    }
54
55    pub fn is_none(&self) -> bool {
56        !self.is_some()
57    }
58
59    pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
60        match self {
61            CubeOptionExpand::Some(val) => val,
62            CubeOptionExpand::None => fallback,
63        }
64    }
65}
66
67impl<T: CubeType + Into<T::ExpandType>> From<CubeOption<T>> for CubeOptionExpand<T> {
68    fn from(value: CubeOption<T>) -> Self {
69        match value {
70            CubeOption::Some(val) => CubeOptionExpand::Some(val.into()),
71            CubeOption::None => CubeOptionExpand::None,
72        }
73    }
74}
75
76// Manually implement LaunchArg as the macro is currently not permissive enough.
77
78pub enum CubeOptionArgs<'a, T: LaunchArg, R: Runtime> {
79    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
80    None,
81}
82
83impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
84    for CubeOptionArgs<'a, T, R>
85{
86    fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
87        match value {
88            Some(arg) => Self::Some(arg),
89            None => Self::None,
90        }
91    }
92}
93
94impl<T: LaunchArg, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
95    fn register(&self, launcher: &mut KernelLauncher<R>) {
96        match self {
97            CubeOptionArgs::Some(arg) => {
98                arg.register(launcher);
99            }
100            CubeOptionArgs::None => {}
101        }
102    }
103}
104impl<T: LaunchArg> LaunchArg for CubeOption<T> {
105    type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
106    type CompilationArg = CubeOptionCompilationArg<T>;
107
108    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
109        match runtime_arg {
110            CubeOptionArgs::Some(arg) => CubeOptionCompilationArg::Some(T::compilation_arg(arg)),
111            CubeOptionArgs::None => CubeOptionCompilationArg::None,
112        }
113    }
114
115    fn expand(
116        arg: &Self::CompilationArg,
117        builder: &mut KernelBuilder,
118    ) -> <Self as CubeType>::ExpandType {
119        match arg {
120            CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
121            CubeOptionCompilationArg::None => CubeOptionExpand::None,
122        }
123    }
124
125    fn expand_output(
126        arg: &Self::CompilationArg,
127        builder: &mut KernelBuilder,
128    ) -> <Self as CubeType>::ExpandType {
129        match arg {
130            CubeOptionCompilationArg::Some(arg) => {
131                CubeOptionExpand::Some(T::expand_output(arg, builder))
132            }
133            CubeOptionCompilationArg::None => CubeOptionExpand::None,
134        }
135    }
136}
137
138pub enum CubeOptionCompilationArg<T: LaunchArg> {
139    Some(<T as LaunchArg>::CompilationArg),
140    None,
141}
142
143impl<T: LaunchArg> Clone for CubeOptionCompilationArg<T> {
144    fn clone(&self) -> Self {
145        match self {
146            CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
147            CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
148        }
149    }
150}
151
152impl<T: LaunchArg> PartialEq for CubeOptionCompilationArg<T> {
153    fn eq(&self, other: &Self) -> bool {
154        match (self, other) {
155            (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
156                arg_0 == arg_1
157            }
158            (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
159            _ => false,
160        }
161    }
162}
163
164impl<T: LaunchArg> Eq for CubeOptionCompilationArg<T> {}
165
166impl<T: LaunchArg> core::hash::Hash for CubeOptionCompilationArg<T> {
167    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
168        match self {
169            CubeOptionCompilationArg::Some(arg) => {
170                arg.hash(state);
171            }
172            CubeOptionCompilationArg::None => {}
173        };
174    }
175}
176
177impl<T: LaunchArg> core::fmt::Debug for CubeOptionCompilationArg<T> {
178    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
179        match self {
180            CubeOptionCompilationArg::Some(arg) => f.debug_tuple("Some").field(arg).finish(),
181            CubeOptionCompilationArg::None => write!(f, "None"),
182        }
183    }
184}
185
186impl<T: LaunchArg> CompilationArg for CubeOptionCompilationArg<T> {}