re_types/
tensor_data.rs

1//! Internal helpers; not part of the public API.
2#![expect(missing_docs)]
3
4use half::f16;
5
6#[expect(unused_imports)] // Used for docstring links
7use crate::datatypes::TensorData;
8
9// ----------------------------------------------------------------------------
10
/// Errors when trying to cast [`TensorData`] to an `ndarray`
///
/// The variants cover three failure cases: the requested `ndarray` element type
/// does not match the tensor storage, the tensor shape does not match the
/// storage length, or the `ndarray` array is not contiguous and in standard order.
#[derive(thiserror::Error, Debug, PartialEq, Clone)]
pub enum TensorCastError {
    #[error("ndarray type mismatch with tensor storage")]
    TypeMismatch,

    #[error("tensor shape did not match storage length")]
    BadTensorShape {
        // `#[from]` asks `thiserror` to generate `From<ndarray::ShapeError>`,
        // so a shape error can be converted into this variant with `?`.
        #[from]
        source: ndarray::ShapeError,
    },

    #[error("ndarray Array is not contiguous and in standard order")]
    NotContiguousStdOrder,
}
26
/// Errors when loading [`TensorData`] from the [`image`] crate.
///
/// Only compiled when the `image` feature is enabled.
#[cfg(feature = "image")]
#[derive(thiserror::Error, Clone, Debug)]
pub enum TensorImageLoadError {
    // The underlying error is shared through an `Arc`, which lets this enum derive `Clone`
    // (presumably because `image::ImageError` is not `Clone` itself — TODO confirm).
    // `transparent` forwards the message and source of the wrapped error.
    #[error(transparent)]
    Image(std::sync::Arc<image::ImageError>),

    #[error(
        "Unsupported color type: {0:?}. We support 8-bit, 16-bit, and f32 images, and RGB, RGBA, Luminance, and Luminance-Alpha."
    )]
    UnsupportedImageColorType(image::ColorType),

    // `std::io::Error` does not implement `Clone`, hence the `Arc` wrapper.
    #[error("Failed to load file: {0}")]
    ReadError(std::sync::Arc<std::io::Error>),

    #[error("The encoded tensor shape did not match its metadata {expected:?} != {found:?}")]
    InvalidMetaData { expected: Vec<u64>, found: Vec<u64> },
}
45
#[cfg(feature = "image")]
impl From<image::ImageError> for TensorImageLoadError {
    /// Moves the `image` error behind an [`std::sync::Arc`] and stores it as [`Self::Image`].
    #[inline]
    fn from(err: image::ImageError) -> Self {
        let shared = std::sync::Arc::new(err);
        Self::Image(shared)
    }
}
53
#[cfg(feature = "image")]
impl From<std::io::Error> for TensorImageLoadError {
    /// Moves the I/O error behind an [`std::sync::Arc`] and stores it as [`Self::ReadError`].
    #[inline]
    fn from(err: std::io::Error) -> Self {
        let shared = std::sync::Arc::new(err);
        Self::ReadError(shared)
    }
}
61
62// ----------------------------------------------------------------------------
63
/// The data types supported by a [`crate::datatypes::TensorData`].
///
/// This is a plain tag without a payload; see [`TensorDataType::size`] for the
/// number of bytes each element type occupies.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TensorDataType {
    /// Unsigned 8 bit integer.
    ///
    /// Commonly used for sRGB(A).
    U8,

    /// Unsigned 16 bit integer.
    ///
    /// Used by some depth images and some high-bitrate images.
    U16,

    /// Unsigned 32 bit integer.
    U32,

    /// Unsigned 64 bit integer.
    U64,

    /// Signed 8 bit integer.
    I8,

    /// Signed 16 bit integer.
    I16,

    /// Signed 32 bit integer.
    I32,

    /// Signed 64 bit integer.
    I64,

    /// 16-bit floating point number.
    ///
    /// Uses the standard IEEE 754-2008 binary16 format.
    /// See <https://en.wikipedia.org/wiki/Half-precision_floating-point_format>.
    F16,

    /// 32-bit floating point number.
    F32,

    /// 64-bit floating point number.
    F64,
}
107
108impl TensorDataType {
109    /// Number of bytes used by the type
110    #[inline]
111    pub fn size(&self) -> u64 {
112        match self {
113            Self::U8 => std::mem::size_of::<u8>() as _,
114            Self::U16 => std::mem::size_of::<u16>() as _,
115            Self::U32 => std::mem::size_of::<u32>() as _,
116            Self::U64 => std::mem::size_of::<u64>() as _,
117
118            Self::I8 => std::mem::size_of::<i8>() as _,
119            Self::I16 => std::mem::size_of::<i16>() as _,
120            Self::I32 => std::mem::size_of::<i32>() as _,
121            Self::I64 => std::mem::size_of::<i64>() as _,
122
123            Self::F16 => std::mem::size_of::<f16>() as _,
124            Self::F32 => std::mem::size_of::<f32>() as _,
125            Self::F64 => std::mem::size_of::<f64>() as _,
126        }
127    }
128
129    /// Is this datatype an integer?
130    #[inline]
131    pub fn is_integer(&self) -> bool {
132        !self.is_float()
133    }
134
135    /// Is this datatype a floating point number?
136    #[inline]
137    pub fn is_float(&self) -> bool {
138        match self {
139            Self::U8
140            | Self::U16
141            | Self::U32
142            | Self::U64
143            | Self::I8
144            | Self::I16
145            | Self::I32
146            | Self::I64 => false,
147            Self::F16 | Self::F32 | Self::F64 => true,
148        }
149    }
150
151    /// What is the minimum finite value representable by this datatype?
152    #[inline]
153    pub fn min_value(&self) -> f64 {
154        match self {
155            Self::U8 => u8::MIN as _,
156            Self::U16 => u16::MIN as _,
157            Self::U32 => u32::MIN as _,
158            Self::U64 => u64::MIN as _,
159
160            Self::I8 => i8::MIN as _,
161            Self::I16 => i16::MIN as _,
162            Self::I32 => i32::MIN as _,
163            Self::I64 => i64::MIN as _,
164
165            Self::F16 => f16::MIN.into(),
166            Self::F32 => f32::MIN as _,
167            Self::F64 => f64::MIN,
168        }
169    }
170
171    /// What is the maximum finite value representable by this datatype?
172    #[inline]
173    pub fn max_value(&self) -> f64 {
174        match self {
175            Self::U8 => u8::MAX as _,
176            Self::U16 => u16::MAX as _,
177            Self::U32 => u32::MAX as _,
178            Self::U64 => u64::MAX as _,
179
180            Self::I8 => i8::MAX as _,
181            Self::I16 => i16::MAX as _,
182            Self::I32 => i32::MAX as _,
183            Self::I64 => i64::MAX as _,
184
185            Self::F16 => f16::MAX.into(),
186            Self::F32 => f32::MAX as _,
187            Self::F64 => f64::MAX,
188        }
189    }
190}
191
192impl std::fmt::Display for TensorDataType {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        match self {
195            Self::U8 => "uint8".fmt(f),
196            Self::U16 => "uint16".fmt(f),
197            Self::U32 => "uint32".fmt(f),
198            Self::U64 => "uint64".fmt(f),
199
200            Self::I8 => "int8".fmt(f),
201            Self::I16 => "int16".fmt(f),
202            Self::I32 => "int32".fmt(f),
203            Self::I64 => "int64".fmt(f),
204
205            Self::F16 => "float16".fmt(f),
206            Self::F32 => "float32".fmt(f),
207            Self::F64 => "float64".fmt(f),
208        }
209    }
210}
211
212// ----------------------------------------------------------------------------
213
/// Associates a Rust element type with its [`TensorDataType`] tag at compile time.
///
/// Implementors must be `Copy + Clone + Send + Sync`.
pub trait TensorDataTypeTrait: Copy + Clone + Send + Sync {
    /// The [`TensorDataType`] that corresponds to the implementing type.
    const DTYPE: TensorDataType;
}
217
218impl TensorDataTypeTrait for u8 {
219    const DTYPE: TensorDataType = TensorDataType::U8;
220}
221
222impl TensorDataTypeTrait for u16 {
223    const DTYPE: TensorDataType = TensorDataType::U16;
224}
225
226impl TensorDataTypeTrait for u32 {
227    const DTYPE: TensorDataType = TensorDataType::U32;
228}
229
230impl TensorDataTypeTrait for u64 {
231    const DTYPE: TensorDataType = TensorDataType::U64;
232}
233
234impl TensorDataTypeTrait for i8 {
235    const DTYPE: TensorDataType = TensorDataType::I8;
236}
237
238impl TensorDataTypeTrait for i16 {
239    const DTYPE: TensorDataType = TensorDataType::I16;
240}
241
242impl TensorDataTypeTrait for i32 {
243    const DTYPE: TensorDataType = TensorDataType::I32;
244}
245
246impl TensorDataTypeTrait for i64 {
247    const DTYPE: TensorDataType = TensorDataType::I64;
248}
249
250impl TensorDataTypeTrait for f16 {
251    const DTYPE: TensorDataType = TensorDataType::F16;
252}
253
254impl TensorDataTypeTrait for f32 {
255    const DTYPE: TensorDataType = TensorDataType::F32;
256}
257
258impl TensorDataTypeTrait for f64 {
259    const DTYPE: TensorDataType = TensorDataType::F64;
260}
261
/// The data that can be stored in a [`crate::datatypes::TensorData`].
///
/// Each variant carries a single scalar value of the corresponding type.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TensorElement {
    /// Unsigned 8 bit integer.
    ///
    /// Commonly used for sRGB(A).
    U8(u8),

    /// Unsigned 16 bit integer.
    ///
    /// Used by some depth images and some high-bitrate images.
    U16(u16),

    /// Unsigned 32 bit integer.
    U32(u32),

    /// Unsigned 64 bit integer.
    U64(u64),

    /// Signed 8 bit integer.
    I8(i8),

    /// Signed 16 bit integer.
    I16(i16),

    /// Signed 32 bit integer.
    I32(i32),

    /// Signed 64 bit integer.
    I64(i64),

    /// 16-bit floating point number.
    ///
    /// Uses the standard IEEE 754-2008 binary16 format.
    /// See <https://en.wikipedia.org/wiki/Half-precision_floating-point_format>.
    F16(half::f16),

    /// 32-bit floating point number.
    F32(f32),

    /// 64-bit floating point number.
    F64(f64),
}
305
306impl TensorElement {
307    /// Get the value as a 64-bit floating point number.
308    ///
309    /// Note that this may cause rounding for large 64-bit integers,
310    /// as `f64` can only represent integers up to 2^53 exactly.
311    #[inline]
312    pub fn as_f64(&self) -> f64 {
313        match self {
314            Self::U8(value) => *value as _,
315            Self::U16(value) => *value as _,
316            Self::U32(value) => *value as _,
317            Self::U64(value) => *value as _,
318
319            Self::I8(value) => *value as _,
320            Self::I16(value) => *value as _,
321            Self::I32(value) => *value as _,
322            Self::I64(value) => *value as _,
323
324            Self::F16(value) => value.to_f32() as _,
325            Self::F32(value) => *value as _,
326            Self::F64(value) => *value,
327        }
328    }
329
330    /// Convert the value to a `u16`, but only if it can be represented
331    /// exactly as a `u16`, without any rounding or clamping.
332    #[inline]
333    pub fn try_as_u16(&self) -> Option<u16> {
334        fn u16_from_f64(f: f64) -> Option<u16> {
335            let u16_value = f as u16;
336            let roundtrips = u16_value as f64 == f;
337            roundtrips.then_some(u16_value)
338        }
339
340        match self {
341            Self::U8(value) => Some(*value as u16),
342            Self::U16(value) => Some(*value),
343            Self::U32(value) => u16::try_from(*value).ok(),
344            Self::U64(value) => u16::try_from(*value).ok(),
345
346            Self::I8(value) => u16::try_from(*value).ok(),
347            Self::I16(value) => u16::try_from(*value).ok(),
348            Self::I32(value) => u16::try_from(*value).ok(),
349            Self::I64(value) => u16::try_from(*value).ok(),
350
351            Self::F16(value) => u16_from_f64(value.to_f32() as f64),
352            Self::F32(value) => u16_from_f64(*value as f64),
353            Self::F64(value) => u16_from_f64(*value),
354        }
355    }
356
357    /// Format the value with `re_format`
358    pub fn format(&self) -> String {
359        match self {
360            Self::U8(val) => re_format::format_uint(*val),
361            Self::U16(val) => re_format::format_uint(*val),
362            Self::U32(val) => re_format::format_uint(*val),
363            Self::U64(val) => re_format::format_uint(*val),
364            Self::I8(val) => re_format::format_int(*val),
365            Self::I16(val) => re_format::format_int(*val),
366            Self::I32(val) => re_format::format_int(*val),
367            Self::I64(val) => re_format::format_int(*val),
368            Self::F16(val) => re_format::format_f16(*val),
369            Self::F32(val) => re_format::format_f32(*val),
370            Self::F64(val) => re_format::format_f64(*val),
371        }
372    }
373
374    /// Get the minimum value representable by this element's type.
375    fn min_value(&self) -> Self {
376        match self {
377            Self::U8(_) => Self::U8(u8::MIN),
378            Self::U16(_) => Self::U16(u16::MIN),
379            Self::U32(_) => Self::U32(u32::MIN),
380            Self::U64(_) => Self::U64(u64::MIN),
381
382            Self::I8(_) => Self::I8(i8::MIN),
383            Self::I16(_) => Self::I16(i16::MIN),
384            Self::I32(_) => Self::I32(i32::MIN),
385            Self::I64(_) => Self::I64(i64::MIN),
386
387            Self::F16(_) => Self::F16(f16::MIN),
388            Self::F32(_) => Self::F32(f32::MIN),
389            Self::F64(_) => Self::F64(f64::MIN),
390        }
391    }
392
393    /// Get the maximum value representable by this element's type.
394    fn max_value(&self) -> Self {
395        match self {
396            Self::U8(_) => Self::U8(u8::MAX),
397            Self::U16(_) => Self::U16(u16::MAX),
398            Self::U32(_) => Self::U32(u32::MAX),
399            Self::U64(_) => Self::U64(u64::MAX),
400
401            Self::I8(_) => Self::I8(i8::MAX),
402            Self::I16(_) => Self::I16(i16::MAX),
403            Self::I32(_) => Self::I32(i32::MAX),
404            Self::I64(_) => Self::I64(i64::MAX),
405
406            Self::F16(_) => Self::F16(f16::MAX),
407            Self::F32(_) => Self::F32(f32::MAX),
408            Self::F64(_) => Self::F64(f64::MAX),
409        }
410    }
411
412    /// Formats the element as a string, padded to the width of the largest possible value.
413    pub fn format_padded(&self) -> String {
414        let max_len = match self {
415            Self::U8(_) | Self::U16(_) | Self::U32(_) | Self::U64(_) => {
416                self.max_value().format().chars().count()
417            }
418            Self::I8(_) | Self::I16(_) | Self::I32(_) | Self::I64(_) => {
419                self.min_value().format().chars().count()
420            }
421            // These were determined by checking the length of random formatted values
422            Self::F16(_) | Self::F32(_) => 12,
423            Self::F64(_) => 22,
424        };
425        let value_str = self.format();
426        format!("{value_str:>max_len$}")
427    }
428}
429
430impl std::fmt::Display for TensorElement {
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        match self {
433            Self::U8(elem) => std::fmt::Display::fmt(elem, f),
434            Self::U16(elem) => std::fmt::Display::fmt(elem, f),
435            Self::U32(elem) => std::fmt::Display::fmt(elem, f),
436            Self::U64(elem) => std::fmt::Display::fmt(elem, f),
437            Self::I8(elem) => std::fmt::Display::fmt(elem, f),
438            Self::I16(elem) => std::fmt::Display::fmt(elem, f),
439            Self::I32(elem) => std::fmt::Display::fmt(elem, f),
440            Self::I64(elem) => std::fmt::Display::fmt(elem, f),
441            Self::F16(elem) => std::fmt::Display::fmt(elem, f),
442            Self::F32(elem) => std::fmt::Display::fmt(elem, f),
443            Self::F64(elem) => std::fmt::Display::fmt(elem, f),
444        }
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks `TensorElement::format` for an unsigned, a float and a negative value.
    #[test]
    fn test_tensor_element_format() {
        assert_eq!(TensorElement::U8(42).format(), "42");
        assert_eq!(TensorElement::F32(3.17).format(), "3.17");
        assert_eq!(
            TensorElement::I64(-123456789).format(),
            "−123\u{2009}456\u{2009}789"
        );
    }

    /// Random values of the same variant must all pad to the same number of `char`s.
    #[test]
    fn test_tensor_element_format_padded() {
        /// Compares one freshly sampled reference element against 100 further samples.
        fn assert_constant_width(type_name: &str, sample: impl Fn() -> TensorElement) {
            let left_padded = sample().format_padded();
            for _ in 0..100 {
                let right_padded = sample().format_padded();
                assert_eq!(
                    left_padded.chars().count(),
                    right_padded.chars().count(),
                    "Padded format length mismatch for type {type_name} with value '{left_padded}' and value '{right_padded}'",
                );
            }
        }

        assert_constant_width("U8", || TensorElement::U8(rand::random()));
        assert_constant_width("U16", || TensorElement::U16(rand::random()));
        assert_constant_width("U32", || TensorElement::U32(rand::random()));
        assert_constant_width("U64", || TensorElement::U64(rand::random()));
        assert_constant_width("I8", || TensorElement::I8(rand::random()));
        assert_constant_width("I16", || TensorElement::I16(rand::random()));
        assert_constant_width("I32", || TensorElement::I32(rand::random()));
        assert_constant_width("I64", || TensorElement::I64(rand::random()));

        // Floats are built from random bit patterns.
        assert_constant_width("F16", || TensorElement::F16(f16::from_bits(rand::random())));
        assert_constant_width("F32", || TensorElement::F32(f32::from_bits(rand::random())));
        assert_constant_width("F64", || TensorElement::F64(f64::from_bits(rand::random())));
    }
}