tract_data/
datum.rs

//! Element types (`DatumType`), quantization parameters (`QParams`) and the `Datum` trait used by `Tensor`.
2use crate::dim::TDim;
3use crate::internal::*;
4use crate::tensor::Tensor;
5use crate::TVec;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters carried by the quantized `DatumType` variants.
///
/// Two representations are supported; `zp_scale()` turns either of them into a
/// `(zero_point, scale)` pair.
#[derive(Clone, Copy, PartialEq)]
pub enum QParams {
    /// Range of real values, described by its two bounds.
    MinMax { min: f32, max: f32 },
    /// Affine form: an integer zero point and a real scale.
    ZpScale { zero_point: i32, scale: f32 },
}

// `Eq` cannot be derived because of the `f32` fields, so it is asserted by hand.
// NOTE(review): the derived `PartialEq` compares the floats with `==`, so a value holding a
// NaN is not equal to itself -- confirm NaN never reaches these fields.
impl Eq for QParams {}
22
23impl Ord for QParams {
24    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
25        use QParams::*;
26        match (self, other) {
27            (MinMax { .. }, ZpScale { .. }) => std::cmp::Ordering::Less,
28            (ZpScale { .. }, MinMax { .. }) => std::cmp::Ordering::Greater,
29            (MinMax { min: min1, max: max1 }, MinMax { min: min2, max: max2 }) => {
30                min1.total_cmp(min2).then_with(|| max1.total_cmp(max2))
31            }
32            (
33                Self::ZpScale { zero_point: zp1, scale: s1 },
34                Self::ZpScale { zero_point: zp2, scale: s2 },
35            ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
36        }
37    }
38}
39
40impl PartialOrd for QParams {
41    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42        Some(self.cmp(other))
43    }
44}
45
46impl Default for QParams {
47    fn default() -> Self {
48        QParams::ZpScale { zero_point: 0, scale: 1. }
49    }
50}
51
52#[allow(clippy::derived_hash_with_manual_eq)]
53impl Hash for QParams {
54    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
55        match self {
56            QParams::MinMax { min, max } => {
57                0.hash(state);
58                min.to_bits().hash(state);
59                max.to_bits().hash(state);
60            }
61            QParams::ZpScale { zero_point, scale } => {
62                1.hash(state);
63                zero_point.hash(state);
64                scale.to_bits().hash(state);
65            }
66        }
67    }
68}
69
70impl QParams {
71    pub fn zp_scale(&self) -> (i32, f32) {
72        match self {
73            QParams::MinMax { min, max } => {
74                let scale = (max - min) / 255.;
75                ((-(min + max) / 2. / scale) as i32, scale)
76            }
77            QParams::ZpScale { zero_point, scale } => (*zero_point, *scale),
78        }
79    }
80
81    pub fn q(&self, f: f32) -> i32 {
82        let (zp, scale) = self.zp_scale();
83        (f / scale) as i32 + zp
84    }
85
86    pub fn dq(&self, i: i32) -> f32 {
87        let (zp, scale) = self.zp_scale();
88        (i - zp) as f32 * scale
89    }
90}
91
92impl std::fmt::Debug for QParams {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let (zp, scale) = self.zp_scale();
95        write!(f, "Z:{zp} S:{scale}")
96    }
97}
98
99#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
100pub enum DatumType {
101    Bool,
102    U8,
103    U16,
104    U32,
105    U64,
106    I8,
107    I16,
108    I32,
109    I64,
110    F16,
111    F32,
112    F64,
113    TDim,
114    Blob,
115    String,
116    QI8(QParams),
117    QU8(QParams),
118    QI32(QParams),
119    #[cfg(feature = "complex")]
120    ComplexI16,
121    #[cfg(feature = "complex")]
122    ComplexI32,
123    #[cfg(feature = "complex")]
124    ComplexI64,
125    #[cfg(feature = "complex")]
126    ComplexF16,
127    #[cfg(feature = "complex")]
128    ComplexF32,
129    #[cfg(feature = "complex")]
130    ComplexF64,
131    Opaque,
132}
133
134impl DatumType {
135    pub fn super_types(&self) -> TVec<DatumType> {
136        use DatumType::*;
137        if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
138        {
139            return tvec!(*self);
140        }
141        #[cfg(feature = "complex")]
142        if self.is_complex_float() {
143            return [ComplexF16, ComplexF32, ComplexF64]
144                .iter()
145                .filter(|s| s.size_of() >= self.size_of())
146                .copied()
147                .collect();
148        } else if self.is_complex_signed() {
149            return [ComplexI16, ComplexI32, ComplexI64]
150                .iter()
151                .filter(|s| s.size_of() >= self.size_of())
152                .copied()
153                .collect();
154        }
155        if self.is_float() {
156            [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
157        } else if self.is_signed() {
158            [I8, I16, I32, I64, TDim]
159                .iter()
160                .filter(|s| s.size_of() >= self.size_of())
161                .copied()
162                .collect()
163        } else {
164            [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
165        }
166    }
167
168    pub fn super_type_for(
169        i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
170    ) -> Option<DatumType> {
171        let mut iter = i.into_iter();
172        let mut current = match iter.next() {
173            None => return None,
174            Some(it) => *it.borrow(),
175        };
176        for n in iter {
177            match current.common_super_type(*n.borrow()) {
178                None => return None,
179                Some(it) => current = it,
180            }
181        }
182        Some(current)
183    }
184
185    pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
186        for mine in self.super_types() {
187            for theirs in rhs.super_types() {
188                if mine == theirs {
189                    return Some(mine);
190                }
191            }
192        }
193        None
194    }
195
196    pub fn is_unsigned(&self) -> bool {
197        matches!(
198            self.unquantized(),
199            DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
200        )
201    }
202
203    pub fn is_signed(&self) -> bool {
204        matches!(
205            self.unquantized(),
206            DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
207        )
208    }
209
210    pub fn is_float(&self) -> bool {
211        matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
212    }
213
214    pub fn is_number(&self) -> bool {
215        self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
216    }
217
218    pub fn is_tdim(&self) -> bool {
219        *self == DatumType::TDim
220    }
221
222    pub fn is_opaque(&self) -> bool {
223        *self == DatumType::Opaque
224    }
225
226    #[cfg(feature = "complex")]
227    pub fn is_complex(&self) -> bool {
228        self.is_complex_float() || self.is_complex_signed()
229    }
230
231    #[cfg(feature = "complex")]
232    pub fn is_complex_float(&self) -> bool {
233        matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
234    }
235
236    #[cfg(feature = "complex")]
237    pub fn is_complex_signed(&self) -> bool {
238        matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
239    }
240
241    #[cfg(feature = "complex")]
242    pub fn complexify(&self) -> TractResult<DatumType> {
243        match *self {
244            DatumType::I16 => Ok(DatumType::ComplexI16),
245            DatumType::I32 => Ok(DatumType::ComplexI32),
246            DatumType::I64 => Ok(DatumType::ComplexI64),
247            DatumType::F16 => Ok(DatumType::ComplexF16),
248            DatumType::F32 => Ok(DatumType::ComplexF32),
249            DatumType::F64 => Ok(DatumType::ComplexF64),
250            _ => bail!("No complex datum type formed on {:?}", self),
251        }
252    }
253
254    #[cfg(feature = "complex")]
255    pub fn decomplexify(&self) -> TractResult<DatumType> {
256        match *self {
257            DatumType::ComplexI16 => Ok(DatumType::I16),
258            DatumType::ComplexI32 => Ok(DatumType::I32),
259            DatumType::ComplexI64 => Ok(DatumType::I64),
260            DatumType::ComplexF16 => Ok(DatumType::F16),
261            DatumType::ComplexF32 => Ok(DatumType::F32),
262            DatumType::ComplexF64 => Ok(DatumType::F64),
263            _ => bail!("{:?} is not a complex type", self),
264        }
265    }
266
267    pub fn is_copy(&self) -> bool {
268        #[cfg(feature = "complex")]
269        if self.is_complex() {
270            return true;
271        }
272        *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
273    }
274
275    pub fn is_quantized(&self) -> bool {
276        self.qparams().is_some()
277    }
278
279    pub fn qparams(&self) -> Option<QParams> {
280        match self {
281            DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
282                Some(*qparams)
283            }
284            _ => None,
285        }
286    }
287
288    pub fn with_qparams(&self, qparams: QParams) -> DatumType {
289        match self {
290            DatumType::QI8(_) => DatumType::QI8(qparams),
291            DatumType::QU8(_) => DatumType::QI8(qparams),
292            DatumType::QI32(_) => DatumType::QI32(qparams),
293            _ => *self,
294        }
295    }
296
297    pub fn quantize(&self, qparams: QParams) -> DatumType {
298        match self {
299            DatumType::I8 => DatumType::QI8(qparams),
300            DatumType::U8 => DatumType::QU8(qparams),
301            DatumType::I32 => DatumType::QI32(qparams),
302            DatumType::QI8(_) => DatumType::QI8(qparams),
303            DatumType::QU8(_) => DatumType::QU8(qparams),
304            DatumType::QI32(_) => DatumType::QI32(qparams),
305            _ => panic!("Can't quantize {self:?}"),
306        }
307    }
308
309    #[inline(always)]
310    pub fn zp_scale(&self) -> (i32, f32) {
311        self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
312    }
313
314    #[inline(always)]
315    pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
316        self.quantize(QParams::ZpScale { zero_point, scale })
317    }
318
319    pub fn unquantized(&self) -> DatumType {
320        match self {
321            DatumType::QI8(_) => DatumType::I8,
322            DatumType::QU8(_) => DatumType::U8,
323            DatumType::QI32(_) => DatumType::I32,
324            _ => *self,
325        }
326    }
327
328    pub fn integer(signed: bool, size: usize) -> Self {
329        use DatumType::*;
330        match (signed, size) {
331            (false, 8) => U8,
332            (false, 16) => U16,
333            (false, 32) => U32,
334            (false, 64) => U64,
335            (true, 8) => U8,
336            (true, 16) => U16,
337            (true, 32) => U32,
338            (true, 64) => U64,
339            _ => panic!("No integer for signed:{signed} size:{size}"),
340        }
341    }
342
343    pub fn is_integer(&self) -> bool {
344        self.is_signed() || self.is_unsigned()
345    }
346
347    #[inline]
348    pub fn size_of(&self) -> usize {
349        dispatch_datum!(std::mem::size_of(self)())
350    }
351
352    pub fn min_value(&self) -> Tensor {
353        match self {
354            DatumType::QU8(_)
355            | DatumType::U8
356            | DatumType::U16
357            | DatumType::U32
358            | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
359            DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
360            DatumType::QI32(_) => tensor0(i32::MIN),
361            DatumType::I16 => tensor0(i16::MIN),
362            DatumType::I32 => tensor0(i32::MIN),
363            DatumType::I64 => tensor0(i64::MIN),
364            DatumType::F16 => tensor0(f16::MIN),
365            DatumType::F32 => tensor0(f32::MIN),
366            DatumType::F64 => tensor0(f64::MIN),
367            _ => panic!("No min value for datum type {self:?}"),
368        }
369    }
370    pub fn max_value(&self) -> Tensor {
371        match self {
372            DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
373            DatumType::U16 => tensor0(u16::MAX),
374            DatumType::U32 => tensor0(u32::MAX),
375            DatumType::U64 => tensor0(u64::MAX),
376            DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
377            DatumType::I16 => tensor0(i16::MAX),
378            DatumType::I32 => tensor0(i32::MAX),
379            DatumType::I64 => tensor0(i64::MAX),
380            DatumType::QI32(_) => tensor0(i32::MAX),
381            DatumType::F16 => tensor0(f16::MAX),
382            DatumType::F32 => tensor0(f32::MAX),
383            DatumType::F64 => tensor0(f64::MAX),
384            _ => panic!("No max value for datum type {self:?}"),
385        }
386    }
387
388    pub fn is<D: Datum>(&self) -> bool {
389        *self == D::datum_type()
390    }
391}
392
393impl std::str::FromStr for DatumType {
394    type Err = TractError;
395
396    fn from_str(s: &str) -> Result<Self, Self::Err> {
397        if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
398            Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
399        } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
400            Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
401        } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
402            Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
403        } else {
404            match s {
405                "I8" | "i8" => Ok(DatumType::I8),
406                "I16" | "i16" => Ok(DatumType::I16),
407                "I32" | "i32" => Ok(DatumType::I32),
408                "I64" | "i64" => Ok(DatumType::I64),
409                "U8" | "u8" => Ok(DatumType::U8),
410                "U16" | "u16" => Ok(DatumType::U16),
411                "U32" | "u32" => Ok(DatumType::U32),
412                "U64" | "u64" => Ok(DatumType::U64),
413                "F16" | "f16" => Ok(DatumType::F16),
414                "F32" | "f32" => Ok(DatumType::F32),
415                "F64" | "f64" => Ok(DatumType::F64),
416                "Bool" | "bool" => Ok(DatumType::Bool),
417                "Blob" | "blob" => Ok(DatumType::Blob),
418                "String" | "string" => Ok(DatumType::String),
419                "TDim" | "tdim" => Ok(DatumType::TDim),
420                #[cfg(feature = "complex")]
421                "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
422                #[cfg(feature = "complex")]
423                "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
424                #[cfg(feature = "complex")]
425                "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
426                #[cfg(feature = "complex")]
427                "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
428                #[cfg(feature = "complex")]
429                "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
430                #[cfg(feature = "complex")]
431                "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
432                _ => bail!("Unknown type {}", s),
433            }
434        }
435    }
436}
437
/// `1 / f32::EPSILON`, i.e. 2^23: adding it to a smaller positive `f32` and subtracting it
/// again leaves no room for fractional bits in the intermediate result.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer, ties going to the even neighbour.
///
/// Values whose magnitude is at least 2^23, infinities and NaN are returned unchanged.
/// A zero result carries the sign of `x`.
pub fn round_ties_to_even(x: f32) -> f32 {
    const EXPONENT_BIAS: u32 = 0x7f;
    const MANTISSA_BITS: u32 = 23;

    let bits = x.to_bits();
    let biased_exponent = (bits >> MANTISSA_BITS) & 0xff;
    // From this exponent on there are no fractional bits left to round away.
    if biased_exponent >= EXPONENT_BIAS + MANTISSA_BITS {
        return x;
    }
    let negative = bits >> 31 == 1;
    // The detour through +/- 2^23 lets the floating point addition do the rounding.
    let rounded = if negative { x - TOINT + TOINT } else { x + TOINT - TOINT };
    if rounded != 0.0 {
        rounded
    } else if negative {
        -0f32
    } else {
        0f32
    }
}
458
459#[inline]
460pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
461where
462    f32: AsPrimitive<T>,
463{
464    let b = b.as_();
465    (round_ties_to_even(b.abs() * a) * b.signum()).as_()
466}
467
468pub trait ClampCast: PartialOrd + Copy + 'static {
469    #[inline(always)]
470    fn clamp_cast<O>(self) -> O
471    where
472        Self: AsPrimitive<O> + Datum,
473        O: AsPrimitive<Self> + num_traits::Bounded + Datum,
474    {
475        // this fails if we're upcasting, in which case clamping is useless
476        if O::min_value().as_() < O::max_value().as_() {
477            num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
478        } else {
479            self.as_()
480        }
481    }
482}
483impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
484
485pub trait Datum:
486    Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
487{
488    fn name() -> &'static str;
489    fn datum_type() -> DatumType;
490    fn is<D: Datum>() -> bool;
491}
492
/// Wires a Rust type to a `DatumType` variant.
///
/// `datum!(T, Variant)` implements `Datum` for `T` (`name()` is the stringified type,
/// `datum_type()` is `DatumType::Variant`) and `From<T> for Tensor` through `tensor0`.
macro_rules! datum {
    ($rust_ty:ty, $variant:ident) => {
        impl Datum for $rust_ty {
            fn name() -> &'static str {
                stringify!($rust_ty)
            }

            fn datum_type() -> DatumType {
                DatumType::$variant
            }

            fn is<D: Datum>() -> bool {
                <Self as Datum>::datum_type() == <D as Datum>::datum_type()
            }
        }

        impl From<$rust_ty> for Tensor {
            fn from(it: $rust_ty) -> Tensor {
                tensor0(it)
            }
        }
    };
}
516
517datum!(bool, Bool);
518datum!(f16, F16);
519datum!(f32, F32);
520datum!(f64, F64);
521datum!(i8, I8);
522datum!(i16, I16);
523datum!(i32, I32);
524datum!(i64, I64);
525datum!(u8, U8);
526datum!(u16, U16);
527datum!(u32, U32);
528datum!(u64, U64);
529datum!(TDim, TDim);
530datum!(String, String);
531datum!(crate::blob::Blob, Blob);
532datum!(crate::opaque::Opaque, Opaque);
533#[cfg(feature = "complex")]
534datum!(Complex<i16>, ComplexI16);
535#[cfg(feature = "complex")]
536datum!(Complex<i32>, ComplexI32);
537#[cfg(feature = "complex")]
538datum!(Complex<i64>, ComplexI64);
539#[cfg(feature = "complex")]
540datum!(Complex<f16>, ComplexF16);
541#[cfg(feature = "complex")]
542datum!(Complex<f32>, ComplexF32);
543#[cfg(feature = "complex")]
544datum!(Complex<f64>, ComplexF64);
545
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use ndarray::arr1;

    /// An `ndarray` array converted to a tensor can be viewed back as the same array.
    #[test]
    fn test_array_to_tensor_to_array() {
        let source = arr1(&[12i32, 42]);
        let tensor = Tensor::from(source.clone());
        let round_trip = tensor.to_array_view::<i32>().unwrap();
        assert_eq!(source, round_trip.into_dimensionality().unwrap());
    }

    /// `TDim` -> `i32` -> `TDim` gives back the initial tensor.
    #[test]
    fn test_cast_dim_to_dim() {
        let dims: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
        let as_i32 = dims.cast_to::<i32>().unwrap();
        let back = as_i32.cast_to::<TDim>().unwrap().into_owned();
        assert_eq!(dims, back);
    }

    /// Casting `i32` to `TDim` succeeds.
    #[test]
    fn test_cast_i32_to_dim() {
        let ints: Tensor = tensor1(&[0i32, 12]);
        ints.cast_to::<TDim>().unwrap();
    }

    /// Casting `i64` to `bool` succeeds.
    #[test]
    fn test_cast_i64_to_bool() {
        let longs: Tensor = tensor1(&[0i64]);
        longs.cast_to::<bool>().unwrap();
    }

    /// The textual form of a quantized type parses to zero-point/scale parameters.
    #[test]
    fn test_parse_qu8() {
        let parsed: DatumType = "QU8(Z:128 S:0.01)".parse().unwrap();
        let expected = DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 });
        assert_eq!(parsed, expected);
    }
}