1use half::f16;
2use num_traits::{NumCast, ToPrimitive};
3
4#[derive(Debug, Clone, PartialEq)]
9pub enum Scalar {
10 Bool(bool),
11 U8(u8),
12 I8(i8),
13 U16(u16),
14 I16(i16),
15 F16(f16),
16 U32(u32),
17 I32(i32),
18 F32(f32),
19 U64(u64),
20 I64(i64),
21 F64(f64),
22 Unsupported,
23}
24
25impl Scalar {
26 pub fn to_bool(&self) -> Option<bool> {
27 match self {
28 Scalar::Bool(v) => Some(*v),
29 _ => None,
30 }
31 }
32
33 pub fn to_u8(&self) -> Option<u8> {
34 self.to()
35 }
36
37 pub fn to_i8(&self) -> Option<i8> {
38 self.to()
39 }
40
41 pub fn to_u16(&self) -> Option<u16> {
42 self.to()
43 }
44
45 pub fn to_i16(&self) -> Option<i16> {
46 self.to()
47 }
48
49 pub fn to_f16(&self) -> Option<f16> {
50 self.to()
51 }
52
53 pub fn to_u32(&self) -> Option<u32> {
54 self.to()
55 }
56
57 pub fn to_i32(&self) -> Option<i32> {
58 self.to()
59 }
60
61 pub fn to_f32(&self) -> Option<f32> {
62 self.to()
63 }
64
65 pub fn to_u64(&self) -> Option<u64> {
66 self.to()
67 }
68
69 pub fn to_i64(&self) -> Option<i64> {
70 self.to()
71 }
72
73 pub fn to_f64(&self) -> Option<f64> {
74 self.to()
75 }
76
77 fn to<T: NumCast>(&self) -> Option<T> {
78 match self {
79 Scalar::Bool(v) => NumCast::from(*v as u8),
81 Scalar::U8(v) => NumCast::from(*v),
82 Scalar::I8(v) => NumCast::from(*v),
83 Scalar::U16(v) => NumCast::from(*v),
84 Scalar::I16(v) => NumCast::from(*v),
85 Scalar::F16(v) => NumCast::from(*v),
86 Scalar::U32(v) => NumCast::from(*v),
87 Scalar::I32(v) => NumCast::from(*v),
88 Scalar::F32(v) => NumCast::from(*v),
89 Scalar::U64(v) => NumCast::from(*v),
90 Scalar::I64(v) => NumCast::from(*v),
91 Scalar::F64(v) => NumCast::from(*v),
92 Scalar::Unsupported => None,
93 }
94 }
95}
96
97use ndarray::{Array, IxDyn};
101
102#[derive(Debug, Clone, PartialEq)]
104pub enum NDArray {
105 Bool(Array<bool, IxDyn>),
106 U8(Array<u8, IxDyn>),
107 I8(Array<i8, IxDyn>),
108 U16(Array<u16, IxDyn>),
109 I16(Array<i16, IxDyn>),
110 F16(Array<f16, IxDyn>),
111 U32(Array<u32, IxDyn>),
112 I32(Array<i32, IxDyn>),
113 F32(Array<f32, IxDyn>),
114 U64(Array<u64, IxDyn>),
115 I64(Array<i64, IxDyn>),
116 F64(Array<f64, IxDyn>),
117 Unsupported,
118}
119
120impl NDArray {
121 pub fn into_bool_array(self) -> Option<Array<bool, IxDyn>> {
122 match self {
123 NDArray::Bool(arr) => Some(arr),
124 _ => None,
125 }
126 }
127
128 pub fn into_u8_array(self) -> Option<Array<u8, IxDyn>> {
129 match self {
130 NDArray::U8(arr) => Some(arr),
131 _ => self.convert_into::<u8>(),
132 }
133 }
134
135 pub fn into_i8_array(self) -> Option<Array<i8, IxDyn>> {
136 match self {
137 NDArray::I8(arr) => Some(arr),
138 _ => self.convert_into::<i8>(),
139 }
140 }
141
142 pub fn into_u16_array(self) -> Option<Array<u16, IxDyn>> {
143 match self {
144 NDArray::U16(arr) => Some(arr),
145 _ => self.convert_into::<u16>(),
146 }
147 }
148
149 pub fn into_i16_array(self) -> Option<Array<i16, IxDyn>> {
150 match self {
151 NDArray::I16(arr) => Some(arr),
152 _ => self.convert_into::<i16>(),
153 }
154 }
155
156 pub fn into_f16_array(self) -> Option<Array<f16, IxDyn>> {
157 match self {
158 NDArray::F16(arr) => Some(arr),
159 _ => self.convert_into::<f16>(),
160 }
161 }
162
163 pub fn into_u32_array(self) -> Option<Array<u32, IxDyn>> {
164 match self {
165 NDArray::U32(arr) => Some(arr),
166 _ => self.convert_into::<u32>(),
167 }
168 }
169
170 pub fn into_i32_array(self) -> Option<Array<i32, IxDyn>> {
171 match self {
172 NDArray::I32(arr) => Some(arr),
173 _ => self.convert_into::<i32>(),
174 }
175 }
176
177 pub fn into_f32_array(self) -> Option<Array<f32, IxDyn>> {
178 match self {
179 NDArray::F32(arr) => Some(arr),
180 _ => self.convert_into::<f32>(),
181 }
182 }
183
184 pub fn into_u64_array(self) -> Option<Array<u64, IxDyn>> {
185 match self {
186 NDArray::U64(arr) => Some(arr),
187 _ => self.convert_into::<u64>(),
188 }
189 }
190
191 pub fn into_i64_array(self) -> Option<Array<i64, IxDyn>> {
192 match self {
193 NDArray::I64(arr) => Some(arr),
194 _ => self.convert_into::<i64>(),
195 }
196 }
197
198 pub fn into_f64_array(self) -> Option<Array<f64, IxDyn>> {
199 match self {
200 NDArray::F64(arr) => Some(arr),
201 _ => self.convert_into::<f64>(),
202 }
203 }
204
205 fn convert_into<T: NumCast + Copy>(self) -> Option<Array<T, IxDyn>> {
206 match self {
207 NDArray::Bool(arr) => Self::convert_bool_array(arr),
208 NDArray::U8(arr) => Self::convert_array(arr),
209 NDArray::I8(arr) => Self::convert_array(arr),
210 NDArray::U16(arr) => Self::convert_array(arr),
211 NDArray::I16(arr) => Self::convert_array(arr),
212 NDArray::F16(arr) => Self::convert_array(arr),
213 NDArray::U32(arr) => Self::convert_array(arr),
214 NDArray::I32(arr) => Self::convert_array(arr),
215 NDArray::F32(arr) => Self::convert_array(arr),
216 NDArray::U64(arr) => Self::convert_array(arr),
217 NDArray::I64(arr) => Self::convert_array(arr),
218 NDArray::F64(arr) => Self::convert_array(arr),
219 NDArray::Unsupported => None,
220 }
221 }
222
223 fn convert_array<S: Copy + ToPrimitive, T: NumCast>(
224 arr: Array<S, IxDyn>,
225 ) -> Option<Array<T, IxDyn>> {
226 let raw_dim = arr.raw_dim();
227 arr.into_iter()
228 .map(|v| NumCast::from(v).ok_or(()))
229 .collect::<Result<Vec<_>, _>>()
230 .ok()
231 .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap())
232 }
233
234 fn convert_bool_array<T: NumCast>(arr: Array<bool, IxDyn>) -> Option<Array<T, IxDyn>> {
235 let raw_dim = arr.raw_dim();
236 arr.into_iter()
237 .map(|v| NumCast::from(v as u8).ok_or(()))
238 .collect::<Result<Vec<_>, _>>()
239 .ok()
240 .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap())
241 }
242}
243
244use ndarray::CowArray;
248
249#[derive(Debug, Clone, PartialEq)]
251pub enum CowNDArray<'a> {
252 Bool(CowArray<'a, bool, IxDyn>),
253 U8(CowArray<'a, u8, IxDyn>),
254 I8(CowArray<'a, i8, IxDyn>),
255 U16(CowArray<'a, u16, IxDyn>),
256 I16(CowArray<'a, i16, IxDyn>),
257 F16(CowArray<'a, f16, IxDyn>),
258 U32(CowArray<'a, u32, IxDyn>),
259 I32(CowArray<'a, i32, IxDyn>),
260 F32(CowArray<'a, f32, IxDyn>),
261 U64(CowArray<'a, u64, IxDyn>),
262 I64(CowArray<'a, i64, IxDyn>),
263 F64(CowArray<'a, f64, IxDyn>),
264 Unsupported,
265}
266
267impl<'a> CowNDArray<'a> {
268 pub fn into_bool_array(self) -> Option<CowArray<'a, bool, IxDyn>> {
269 match self {
270 CowNDArray::Bool(arr) => Some(arr),
271 _ => None,
272 }
273 }
274
275 pub fn into_u8_array(self) -> Option<CowArray<'a, u8, IxDyn>> {
276 match self {
277 CowNDArray::U8(arr) => Some(arr),
278 _ => self.convert_into::<u8>(),
279 }
280 }
281
282 pub fn into_i8_array(self) -> Option<CowArray<'a, i8, IxDyn>> {
283 match self {
284 CowNDArray::I8(arr) => Some(arr),
285 _ => self.convert_into::<i8>(),
286 }
287 }
288
289 pub fn into_u16_array(self) -> Option<CowArray<'a, u16, IxDyn>> {
290 match self {
291 CowNDArray::U16(arr) => Some(arr),
292 _ => self.convert_into::<u16>(),
293 }
294 }
295
296 pub fn into_i16_array(self) -> Option<CowArray<'a, i16, IxDyn>> {
297 match self {
298 CowNDArray::I16(arr) => Some(arr),
299 _ => self.convert_into::<i16>(),
300 }
301 }
302
303 pub fn into_f16_array(self) -> Option<CowArray<'a, f16, IxDyn>> {
304 match self {
305 CowNDArray::F16(arr) => Some(arr),
306 _ => self.convert_into::<f16>(),
308 }
309 }
310
311 pub fn into_u32_array(self) -> Option<CowArray<'a, u32, IxDyn>> {
312 match self {
313 CowNDArray::U32(arr) => Some(arr),
314 _ => self.convert_into::<u32>(),
315 }
316 }
317
318 pub fn into_i32_array(self) -> Option<CowArray<'a, i32, IxDyn>> {
319 match self {
320 CowNDArray::I32(arr) => Some(arr),
321 _ => self.convert_into::<i32>(),
322 }
323 }
324
325 pub fn into_f32_array(self) -> Option<CowArray<'a, f32, IxDyn>> {
326 match self {
327 CowNDArray::F32(arr) => Some(arr),
328 _ => self.convert_into::<f32>(),
329 }
330 }
331
332 pub fn into_u64_array(self) -> Option<CowArray<'a, u64, IxDyn>> {
333 match self {
334 CowNDArray::U64(arr) => Some(arr),
335 _ => self.convert_into::<u64>(),
336 }
337 }
338
339 pub fn into_i64_array(self) -> Option<CowArray<'a, i64, IxDyn>> {
340 match self {
341 CowNDArray::I64(arr) => Some(arr),
342 _ => self.convert_into::<i64>(),
343 }
344 }
345
346 pub fn into_f64_array(self) -> Option<CowArray<'a, f64, IxDyn>> {
347 match self {
348 CowNDArray::F64(arr) => Some(arr),
349 _ => self.convert_into::<f64>(),
350 }
351 }
352
353 fn convert_into<T: NumCast + Copy>(self) -> Option<CowArray<'a, T, IxDyn>> {
354 match self {
355 CowNDArray::Bool(arr) => Self::convert_bool_array(arr),
356 CowNDArray::U8(arr) => Self::convert_array(arr),
357 CowNDArray::I8(arr) => Self::convert_array(arr),
358 CowNDArray::U16(arr) => Self::convert_array(arr),
359 CowNDArray::I16(arr) => Self::convert_array(arr),
360 CowNDArray::F16(arr) => Self::convert_array(arr),
361 CowNDArray::U32(arr) => Self::convert_array(arr),
362 CowNDArray::I32(arr) => Self::convert_array(arr),
363 CowNDArray::F32(arr) => Self::convert_array(arr),
364 CowNDArray::U64(arr) => Self::convert_array(arr),
365 CowNDArray::I64(arr) => Self::convert_array(arr),
366 CowNDArray::F64(arr) => Self::convert_array(arr),
367 CowNDArray::Unsupported => None,
368 }
369 }
370
371 fn convert_array<S: Copy + ToPrimitive, T: NumCast>(
372 arr: CowArray<S, IxDyn>,
373 ) -> Option<CowArray<T, IxDyn>> {
374 let raw_dim = arr.raw_dim();
375 arr.into_iter()
376 .map(|v| NumCast::from(v).ok_or(()))
377 .collect::<Result<Vec<_>, _>>()
378 .ok()
379 .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap().into())
380 }
381
382 fn convert_bool_array<T: NumCast>(arr: CowArray<bool, IxDyn>) -> Option<CowArray<T, IxDyn>> {
383 let raw_dim = arr.raw_dim();
384 arr.into_iter()
385 .map(|v| NumCast::from(v as u8).ok_or(()))
386 .collect::<Result<Vec<_>, _>>()
387 .ok()
388 .map(|vec| Array::from_shape_vec(raw_dim, vec).unwrap().into())
389 }
390}