1#![expect(missing_docs)]
3
4use half::f16;
5
6#[expect(unused_imports)] use crate::datatypes::TensorData;
8
9#[derive(thiserror::Error, Debug, PartialEq, Clone)]
13pub enum TensorCastError {
14 #[error("ndarray type mismatch with tensor storage")]
15 TypeMismatch,
16
17 #[error("tensor shape did not match storage length")]
18 BadTensorShape {
19 #[from]
20 source: ndarray::ShapeError,
21 },
22
23 #[error("ndarray Array is not contiguous and in standard order")]
24 NotContiguousStdOrder,
25}
26
#[cfg(feature = "image")]
/// Errors that can occur while loading an image into a tensor.
#[derive(Clone, Debug, thiserror::Error)]
pub enum TensorImageLoadError {
    /// An error reported by the `image` crate. It is wrapped in an `Arc`, presumably so
    /// the enum can derive `Clone`.
    #[error(transparent)]
    Image(std::sync::Arc<image::ImageError>),

    /// The decoded image uses a color type this loader does not handle.
    #[error(
        "Unsupported color type: {0:?}. We support 8-bit, 16-bit, and f32 images, and RGB, RGBA, Luminance, and Luminance-Alpha."
    )]
    UnsupportedImageColorType(image::ColorType),

    /// Reading the file failed with an I/O error (shared behind an `Arc`).
    #[error("Failed to load file: {0}")]
    ReadError(std::sync::Arc<std::io::Error>),

    /// The shape found in the encoded data disagrees with the shape recorded in its metadata.
    #[error("The encoded tensor shape did not match its metadata {expected:?} != {found:?}")]
    InvalidMetaData { expected: Vec<u64>, found: Vec<u64> },
}
45
#[cfg(feature = "image")]
impl From<image::ImageError> for TensorImageLoadError {
    /// Wraps an `image` crate error in [`TensorImageLoadError::Image`].
    #[inline]
    fn from(error: image::ImageError) -> Self {
        // `Arc<T>: From<T>`, so `.into()` allocates the shared handle just like `Arc::new`.
        Self::Image(error.into())
    }
}
53
#[cfg(feature = "image")]
impl From<std::io::Error> for TensorImageLoadError {
    /// Wraps an I/O error in [`TensorImageLoadError::ReadError`].
    #[inline]
    fn from(error: std::io::Error) -> Self {
        // `Arc<T>: From<T>`, so `.into()` is equivalent to `Arc::new(error)`.
        Self::ReadError(error.into())
    }
}
61
/// The element type stored in a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorDataType {
    /// Unsigned 8-bit integer (`u8`).
    U8,

    /// Unsigned 16-bit integer (`u16`).
    U16,

    /// Unsigned 32-bit integer (`u32`).
    U32,

    /// Unsigned 64-bit integer (`u64`).
    U64,

    /// Signed 8-bit integer (`i8`).
    I8,

    /// Signed 16-bit integer (`i16`).
    I16,

    /// Signed 32-bit integer (`i32`).
    I32,

    /// Signed 64-bit integer (`i64`).
    I64,

    /// 16-bit half-precision float (`half::f16`).
    F16,

    /// 32-bit float (`f32`).
    F32,

    /// 64-bit float (`f64`).
    F64,
}
107
108impl TensorDataType {
109 #[inline]
111 pub fn size(&self) -> u64 {
112 match self {
113 Self::U8 => std::mem::size_of::<u8>() as _,
114 Self::U16 => std::mem::size_of::<u16>() as _,
115 Self::U32 => std::mem::size_of::<u32>() as _,
116 Self::U64 => std::mem::size_of::<u64>() as _,
117
118 Self::I8 => std::mem::size_of::<i8>() as _,
119 Self::I16 => std::mem::size_of::<i16>() as _,
120 Self::I32 => std::mem::size_of::<i32>() as _,
121 Self::I64 => std::mem::size_of::<i64>() as _,
122
123 Self::F16 => std::mem::size_of::<f16>() as _,
124 Self::F32 => std::mem::size_of::<f32>() as _,
125 Self::F64 => std::mem::size_of::<f64>() as _,
126 }
127 }
128
129 #[inline]
131 pub fn is_integer(&self) -> bool {
132 !self.is_float()
133 }
134
135 #[inline]
137 pub fn is_float(&self) -> bool {
138 match self {
139 Self::U8
140 | Self::U16
141 | Self::U32
142 | Self::U64
143 | Self::I8
144 | Self::I16
145 | Self::I32
146 | Self::I64 => false,
147 Self::F16 | Self::F32 | Self::F64 => true,
148 }
149 }
150
151 #[inline]
153 pub fn min_value(&self) -> f64 {
154 match self {
155 Self::U8 => u8::MIN as _,
156 Self::U16 => u16::MIN as _,
157 Self::U32 => u32::MIN as _,
158 Self::U64 => u64::MIN as _,
159
160 Self::I8 => i8::MIN as _,
161 Self::I16 => i16::MIN as _,
162 Self::I32 => i32::MIN as _,
163 Self::I64 => i64::MIN as _,
164
165 Self::F16 => f16::MIN.into(),
166 Self::F32 => f32::MIN as _,
167 Self::F64 => f64::MIN,
168 }
169 }
170
171 #[inline]
173 pub fn max_value(&self) -> f64 {
174 match self {
175 Self::U8 => u8::MAX as _,
176 Self::U16 => u16::MAX as _,
177 Self::U32 => u32::MAX as _,
178 Self::U64 => u64::MAX as _,
179
180 Self::I8 => i8::MAX as _,
181 Self::I16 => i16::MAX as _,
182 Self::I32 => i32::MAX as _,
183 Self::I64 => i64::MAX as _,
184
185 Self::F16 => f16::MAX.into(),
186 Self::F32 => f32::MAX as _,
187 Self::F64 => f64::MAX,
188 }
189 }
190}
191
192impl std::fmt::Display for TensorDataType {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 match self {
195 Self::U8 => "uint8".fmt(f),
196 Self::U16 => "uint16".fmt(f),
197 Self::U32 => "uint32".fmt(f),
198 Self::U64 => "uint64".fmt(f),
199
200 Self::I8 => "int8".fmt(f),
201 Self::I16 => "int16".fmt(f),
202 Self::I32 => "int32".fmt(f),
203 Self::I64 => "int64".fmt(f),
204
205 Self::F16 => "float16".fmt(f),
206 Self::F32 => "float32".fmt(f),
207 Self::F64 => "float64".fmt(f),
208 }
209 }
210}
211
212pub trait TensorDataTypeTrait: Copy + Clone + Send + Sync {
215 const DTYPE: TensorDataType;
216}
217
218impl TensorDataTypeTrait for u8 {
219 const DTYPE: TensorDataType = TensorDataType::U8;
220}
221
222impl TensorDataTypeTrait for u16 {
223 const DTYPE: TensorDataType = TensorDataType::U16;
224}
225
226impl TensorDataTypeTrait for u32 {
227 const DTYPE: TensorDataType = TensorDataType::U32;
228}
229
230impl TensorDataTypeTrait for u64 {
231 const DTYPE: TensorDataType = TensorDataType::U64;
232}
233
234impl TensorDataTypeTrait for i8 {
235 const DTYPE: TensorDataType = TensorDataType::I8;
236}
237
238impl TensorDataTypeTrait for i16 {
239 const DTYPE: TensorDataType = TensorDataType::I16;
240}
241
242impl TensorDataTypeTrait for i32 {
243 const DTYPE: TensorDataType = TensorDataType::I32;
244}
245
246impl TensorDataTypeTrait for i64 {
247 const DTYPE: TensorDataType = TensorDataType::I64;
248}
249
250impl TensorDataTypeTrait for f16 {
251 const DTYPE: TensorDataType = TensorDataType::F16;
252}
253
254impl TensorDataTypeTrait for f32 {
255 const DTYPE: TensorDataType = TensorDataType::F32;
256}
257
258impl TensorDataTypeTrait for f64 {
259 const DTYPE: TensorDataType = TensorDataType::F64;
260}
261
262#[derive(Clone, Copy, Debug, PartialEq)]
264pub enum TensorElement {
265 U8(u8),
269
270 U16(u16),
274
275 U32(u32),
277
278 U64(u64),
280
281 I8(i8),
283
284 I16(i16),
286
287 I32(i32),
289
290 I64(i64),
292
293 F16(half::f16),
298
299 F32(f32),
301
302 F64(f64),
304}
305
306impl TensorElement {
307 #[inline]
312 pub fn as_f64(&self) -> f64 {
313 match self {
314 Self::U8(value) => *value as _,
315 Self::U16(value) => *value as _,
316 Self::U32(value) => *value as _,
317 Self::U64(value) => *value as _,
318
319 Self::I8(value) => *value as _,
320 Self::I16(value) => *value as _,
321 Self::I32(value) => *value as _,
322 Self::I64(value) => *value as _,
323
324 Self::F16(value) => value.to_f32() as _,
325 Self::F32(value) => *value as _,
326 Self::F64(value) => *value,
327 }
328 }
329
330 #[inline]
333 pub fn try_as_u16(&self) -> Option<u16> {
334 fn u16_from_f64(f: f64) -> Option<u16> {
335 let u16_value = f as u16;
336 let roundtrips = u16_value as f64 == f;
337 roundtrips.then_some(u16_value)
338 }
339
340 match self {
341 Self::U8(value) => Some(*value as u16),
342 Self::U16(value) => Some(*value),
343 Self::U32(value) => u16::try_from(*value).ok(),
344 Self::U64(value) => u16::try_from(*value).ok(),
345
346 Self::I8(value) => u16::try_from(*value).ok(),
347 Self::I16(value) => u16::try_from(*value).ok(),
348 Self::I32(value) => u16::try_from(*value).ok(),
349 Self::I64(value) => u16::try_from(*value).ok(),
350
351 Self::F16(value) => u16_from_f64(value.to_f32() as f64),
352 Self::F32(value) => u16_from_f64(*value as f64),
353 Self::F64(value) => u16_from_f64(*value),
354 }
355 }
356
357 pub fn format(&self) -> String {
359 match self {
360 Self::U8(val) => re_format::format_uint(*val),
361 Self::U16(val) => re_format::format_uint(*val),
362 Self::U32(val) => re_format::format_uint(*val),
363 Self::U64(val) => re_format::format_uint(*val),
364 Self::I8(val) => re_format::format_int(*val),
365 Self::I16(val) => re_format::format_int(*val),
366 Self::I32(val) => re_format::format_int(*val),
367 Self::I64(val) => re_format::format_int(*val),
368 Self::F16(val) => re_format::format_f16(*val),
369 Self::F32(val) => re_format::format_f32(*val),
370 Self::F64(val) => re_format::format_f64(*val),
371 }
372 }
373
374 fn min_value(&self) -> Self {
376 match self {
377 Self::U8(_) => Self::U8(u8::MIN),
378 Self::U16(_) => Self::U16(u16::MIN),
379 Self::U32(_) => Self::U32(u32::MIN),
380 Self::U64(_) => Self::U64(u64::MIN),
381
382 Self::I8(_) => Self::I8(i8::MIN),
383 Self::I16(_) => Self::I16(i16::MIN),
384 Self::I32(_) => Self::I32(i32::MIN),
385 Self::I64(_) => Self::I64(i64::MIN),
386
387 Self::F16(_) => Self::F16(f16::MIN),
388 Self::F32(_) => Self::F32(f32::MIN),
389 Self::F64(_) => Self::F64(f64::MIN),
390 }
391 }
392
393 fn max_value(&self) -> Self {
395 match self {
396 Self::U8(_) => Self::U8(u8::MAX),
397 Self::U16(_) => Self::U16(u16::MAX),
398 Self::U32(_) => Self::U32(u32::MAX),
399 Self::U64(_) => Self::U64(u64::MAX),
400
401 Self::I8(_) => Self::I8(i8::MAX),
402 Self::I16(_) => Self::I16(i16::MAX),
403 Self::I32(_) => Self::I32(i32::MAX),
404 Self::I64(_) => Self::I64(i64::MAX),
405
406 Self::F16(_) => Self::F16(f16::MAX),
407 Self::F32(_) => Self::F32(f32::MAX),
408 Self::F64(_) => Self::F64(f64::MAX),
409 }
410 }
411
412 pub fn format_padded(&self) -> String {
414 let max_len = match self {
415 Self::U8(_) | Self::U16(_) | Self::U32(_) | Self::U64(_) => {
416 self.max_value().format().chars().count()
417 }
418 Self::I8(_) | Self::I16(_) | Self::I32(_) | Self::I64(_) => {
419 self.min_value().format().chars().count()
420 }
421 Self::F16(_) | Self::F32(_) => 12,
423 Self::F64(_) => 22,
424 };
425 let value_str = self.format();
426 format!("{value_str:>max_len$}")
427 }
428}
429
430impl std::fmt::Display for TensorElement {
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 match self {
433 Self::U8(elem) => std::fmt::Display::fmt(elem, f),
434 Self::U16(elem) => std::fmt::Display::fmt(elem, f),
435 Self::U32(elem) => std::fmt::Display::fmt(elem, f),
436 Self::U64(elem) => std::fmt::Display::fmt(elem, f),
437 Self::I8(elem) => std::fmt::Display::fmt(elem, f),
438 Self::I16(elem) => std::fmt::Display::fmt(elem, f),
439 Self::I32(elem) => std::fmt::Display::fmt(elem, f),
440 Self::I64(elem) => std::fmt::Display::fmt(elem, f),
441 Self::F16(elem) => std::fmt::Display::fmt(elem, f),
442 Self::F32(elem) => std::fmt::Display::fmt(elem, f),
443 Self::F64(elem) => std::fmt::Display::fmt(elem, f),
444 }
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_element_format() {
        let cases = [
            (TensorElement::U8(42), "42"),
            (TensorElement::F32(3.17), "3.17"),
            (TensorElement::I64(-123456789), "−123\u{2009}456\u{2009}789"),
        ];
        for (elem, expected) in cases {
            assert_eq!(elem.format(), expected);
        }
    }

    /// Asserts that `format_padded` yields the same `char` width for one reference value
    /// and 100 further values, all produced by `random` and wrapped with `make`.
    fn assert_padded_width_is_stable<T>(
        type_name: &str,
        make: impl Fn(T) -> TensorElement,
        mut random: impl FnMut() -> T,
    ) {
        let left_padded = make(random()).format_padded();
        for _ in 0..100 {
            let right_padded = make(random()).format_padded();
            assert_eq!(
                left_padded.chars().count(),
                right_padded.chars().count(),
                "Padded format length mismatch for type {type_name} with value '{left_padded}' and value '{right_padded}'",
            );
        }
    }

    #[test]
    fn test_tensor_element_format_padded() {
        assert_padded_width_is_stable("U8", TensorElement::U8, rand::random);
        assert_padded_width_is_stable("U16", TensorElement::U16, rand::random);
        assert_padded_width_is_stable("U32", TensorElement::U32, rand::random);
        assert_padded_width_is_stable("U64", TensorElement::U64, rand::random);
        assert_padded_width_is_stable("I8", TensorElement::I8, rand::random);
        assert_padded_width_is_stable("I16", TensorElement::I16, rand::random);
        assert_padded_width_is_stable("I32", TensorElement::I32, rand::random);
        assert_padded_width_is_stable("I64", TensorElement::I64, rand::random);

        // Floats are drawn from random bit patterns, so NaN and infinities are exercised too.
        assert_padded_width_is_stable("F16", TensorElement::F16, || {
            f16::from_bits(rand::random())
        });
        assert_padded_width_is_stable("F32", TensorElement::F32, || {
            f32::from_bits(rand::random())
        });
        assert_padded_width_is_stable("F64", TensorElement::F64, || {
            f64::from_bits(rand::random())
        });
    }
}