cubecl_std/
option.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use serde::{Deserialize, Serialize};
4
5#[derive(CubeType, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq, Debug)]
6pub enum CubeOption<T: CubeType> {
7    Some(T),
8    None,
9}
10
11#[cube]
12impl<T: CubeType> CubeOption<T> {
13    pub fn is_some(&self) -> bool {
14        match self {
15            CubeOption::Some(_) => true,
16            CubeOption::None => false,
17        }
18    }
19
20    pub fn unwrap(self) -> T {
21        match self {
22            CubeOption::Some(val) => val,
23            CubeOption::None => panic!("Unwrap on a None CubeOption"),
24        }
25    }
26
27    pub fn is_none(&self) -> bool {
28        !self.is_some()
29    }
30
31    pub fn unwrap_or(self, fallback: T) -> T {
32        match self {
33            CubeOption::Some(val) => val,
34            CubeOption::None => fallback,
35        }
36    }
37}
38
39impl<T: CubeType> CubeOptionExpand<T> {
40    pub fn is_some(&self) -> bool {
41        match self {
42            CubeOptionExpand::Some(_) => true,
43            CubeOptionExpand::None => false,
44        }
45    }
46
47    pub fn unwrap(self) -> T::ExpandType {
48        match self {
49            Self::Some(val) => val,
50            Self::None => panic!("Unwrap on a None CubeOption"),
51        }
52    }
53
54    pub fn is_none(&self) -> bool {
55        !self.is_some()
56    }
57
58    pub fn unwrap_or(self, fallback: T::ExpandType) -> T::ExpandType {
59        match self {
60            CubeOptionExpand::Some(val) => val,
61            CubeOptionExpand::None => fallback,
62        }
63    }
64}
65
66impl<T: CubeType + Into<T::ExpandType>> From<CubeOption<T>> for CubeOptionExpand<T> {
67    fn from(value: CubeOption<T>) -> Self {
68        match value {
69            CubeOption::Some(val) => CubeOptionExpand::Some(val.into()),
70            CubeOption::None => CubeOptionExpand::None,
71        }
72    }
73}
74
75// Manually implement LaunchArg as the macro is currently not permissive enough.
76
77pub enum CubeOptionArgs<'a, T: LaunchArg, R: Runtime> {
78    Some(<T as LaunchArg>::RuntimeArg<'a, R>),
79    None,
80}
81
82impl<'a, T: LaunchArg, R: Runtime> From<Option<<T as LaunchArg>::RuntimeArg<'a, R>>>
83    for CubeOptionArgs<'a, T, R>
84{
85    fn from(value: Option<<T as LaunchArg>::RuntimeArg<'a, R>>) -> Self {
86        match value {
87            Some(arg) => Self::Some(arg),
88            None => Self::None,
89        }
90    }
91}
92
93impl<T: LaunchArg, R: Runtime> ArgSettings<R> for CubeOptionArgs<'_, T, R> {
94    fn register(&self, launcher: &mut KernelLauncher<R>) {
95        match self {
96            CubeOptionArgs::Some(arg) => {
97                arg.register(launcher);
98            }
99            CubeOptionArgs::None => {}
100        }
101    }
102}
103impl<T: LaunchArg> LaunchArg for CubeOption<T> {
104    type RuntimeArg<'a, R: Runtime> = CubeOptionArgs<'a, T, R>;
105    type CompilationArg = CubeOptionCompilationArg<T>;
106
107    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
108        match runtime_arg {
109            CubeOptionArgs::Some(arg) => CubeOptionCompilationArg::Some(T::compilation_arg(arg)),
110            CubeOptionArgs::None => CubeOptionCompilationArg::None,
111        }
112    }
113
114    fn expand(
115        arg: &Self::CompilationArg,
116        builder: &mut KernelBuilder,
117    ) -> <Self as CubeType>::ExpandType {
118        match arg {
119            CubeOptionCompilationArg::Some(arg) => CubeOptionExpand::Some(T::expand(arg, builder)),
120            CubeOptionCompilationArg::None => CubeOptionExpand::None,
121        }
122    }
123
124    fn expand_output(
125        arg: &Self::CompilationArg,
126        builder: &mut KernelBuilder,
127    ) -> <Self as CubeType>::ExpandType {
128        match arg {
129            CubeOptionCompilationArg::Some(arg) => {
130                CubeOptionExpand::Some(T::expand_output(arg, builder))
131            }
132            CubeOptionCompilationArg::None => CubeOptionExpand::None,
133        }
134    }
135}
136
137pub enum CubeOptionCompilationArg<T: LaunchArg> {
138    Some(<T as LaunchArg>::CompilationArg),
139    None,
140}
141
142impl<T: LaunchArg> Clone for CubeOptionCompilationArg<T> {
143    fn clone(&self) -> Self {
144        match self {
145            CubeOptionCompilationArg::Some(arg) => CubeOptionCompilationArg::Some(arg.clone()),
146            CubeOptionCompilationArg::None => CubeOptionCompilationArg::None,
147        }
148    }
149}
150
151impl<T: LaunchArg> PartialEq for CubeOptionCompilationArg<T> {
152    fn eq(&self, other: &Self) -> bool {
153        match (self, other) {
154            (CubeOptionCompilationArg::Some(arg_0), CubeOptionCompilationArg::Some(arg_1)) => {
155                arg_0 == arg_1
156            }
157            (CubeOptionCompilationArg::None, CubeOptionCompilationArg::None) => true,
158            _ => false,
159        }
160    }
161}
162
163impl<T: LaunchArg> Eq for CubeOptionCompilationArg<T> {}
164
165impl<T: LaunchArg> core::hash::Hash for CubeOptionCompilationArg<T> {
166    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
167        match self {
168            CubeOptionCompilationArg::Some(arg) => {
169                arg.hash(state);
170            }
171            CubeOptionCompilationArg::None => {}
172        };
173    }
174}
175
176impl<T: LaunchArg> core::fmt::Debug for CubeOptionCompilationArg<T> {
177    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
178        match self {
179            CubeOptionCompilationArg::Some(arg) => f
180                .debug_tuple("CubeOptionCompilationArg :: Some")
181                .field(arg)
182                .finish(),
183            CubeOptionCompilationArg::None => write!(f, "CubeOptionCompilationArg :: None"),
184        }
185    }
186}
187
188impl<T: LaunchArg> CompilationArg for CubeOptionCompilationArg<T> {}