Skip to main content

tract_data/
datum.rs

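//! Element types of tract tensors: [`DatumType`] and its quantization parameters
//! ([`QParams`]), the [`Datum`] trait tying Rust scalar types to a `DatumType`, and a few
//! numeric helpers (rounding, scaling, saturating casts).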
2use crate::TVec;
3use crate::dim::TDim;
4use crate::internal::*;
5use crate::tensor::Tensor;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters carried by the quantized [`DatumType`] variants.
///
/// Parameters are given either as an explicit `(zero_point, scale)` pair or as a real
/// `[min, max]` range, from which [`QParams::zp_scale`] derives the pair.
///
/// Equality, ordering and hashing all look at the `f32` fields bit-for-bit (through
/// `f32::total_cmp` and `f32::to_bits`), so `PartialEq`, `Ord` and `Hash` agree with each
/// other: `0.0` and `-0.0` are different values, and a `NaN` field is equal to itself.
#[derive(Copy, Clone)]
pub enum QParams {
    /// Parameters described by the real interval `[min, max]`.
    MinMax { min: f32, max: f32 },
    /// Parameters described by an integer zero point and a real scale.
    ZpScale { zero_point: i32, scale: f32 },
}

// `PartialEq` is written by hand instead of derived. A derived (IEEE) float comparison
// would make `NaN != NaN`, violating `Eq` reflexivity. It would also make `0.0 == -0.0`
// even though the two hash differently under `to_bits()`, violating the `Hash`/`Eq`
// contract. Delegating to `Ord` keeps the three impls consistent.
impl PartialEq for QParams {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for QParams {}

// Total order: every `MinMax` sorts before every `ZpScale`. Within a variant, fields are
// compared in declaration order, floats with `f32::total_cmp`.
impl Ord for QParams {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        use QParams::*;
        match (self, other) {
            (MinMax { .. }, ZpScale { .. }) => Ordering::Less,
            (ZpScale { .. }, MinMax { .. }) => Ordering::Greater,
            (MinMax { min: min1, max: max1 }, MinMax { min: min2, max: max2 }) => {
                min1.total_cmp(min2).then_with(|| max1.total_cmp(max2))
            }
            (
                ZpScale { zero_point: zp1, scale: s1 },
                ZpScale { zero_point: zp2, scale: s2 },
            ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
        }
    }
}

impl PartialOrd for QParams {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Default for QParams {
    /// Identity quantization: zero point `0`, scale `1`.
    fn default() -> Self {
        QParams::ZpScale { zero_point: 0, scale: 1. }
    }
}

// Hashes a variant tag followed by the raw bit patterns of the fields, which matches the
// bit-for-bit equality defined above.
impl Hash for QParams {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        match self {
            QParams::MinMax { min, max } => {
                0.hash(state);
                min.to_bits().hash(state);
                max.to_bits().hash(state);
            }
            QParams::ZpScale { zero_point, scale } => {
                1.hash(state);
                zero_point.hash(state);
                scale.to_bits().hash(state);
            }
        }
    }
}

impl QParams {
    /// Returns the `(zero_point, scale)` pair.
    ///
    /// For `MinMax`, `scale = (max - min) / 255`. The zero point is
    /// `-(min + max) / 2 / scale`, converted with `as i32`, which truncates toward zero and
    /// saturates.
    pub fn zp_scale(&self) -> (i32, f32) {
        match self {
            QParams::MinMax { min, max } => {
                let scale = (max - min) / 255.;
                ((-(min + max) / 2. / scale) as i32, scale)
            }
            QParams::ZpScale { zero_point, scale } => (*zero_point, *scale),
        }
    }

    /// Quantizes `f`: `f / scale` converted with `as i32` (truncation toward zero, not
    /// rounding), then offset by the zero point.
    pub fn q(&self, f: f32) -> i32 {
        let (zp, scale) = self.zp_scale();
        (f / scale) as i32 + zp
    }

    /// Dequantizes `i`: `(i - zero_point) * scale`.
    pub fn dq(&self, i: i32) -> f32 {
        let (zp, scale) = self.zp_scale();
        (i - zp) as f32 * scale
    }
}

impl std::fmt::Debug for QParams {
    /// Renders as `Z:<zero_point> S:<scale>`. `MinMax` is first converted with `zp_scale`.
    /// This is the form `DatumType`'s `FromStr` accepts inside `QU8(...)` / `QI8(...)` /
    /// `QI32(...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (zp, scale) = self.zp_scale();
        write!(f, "Z:{zp} S:{scale}")
    }
}
98
/// The element type of a tract tensor.
///
/// The quantized variants `QI8`, `QU8` and `QI32` carry their [`QParams`]. They correspond
/// to `I8`, `U8` and `I32` respectively (see `DatumType::unquantized`). The complex variants
/// only exist when the `complex` feature is enabled.
///
/// `Ord`, `PartialOrd` and `Hash` are derived, so the declaration order of the variants is
/// part of the derived ordering; reordering them changes how datum types sort.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum DatumType {
    /// Rust `bool`.
    Bool,
    /// Rust `u8`.
    U8,
    /// Rust `u16`.
    U16,
    /// Rust `u32`.
    U32,
    /// Rust `u64`.
    U64,
    /// Rust `i8`.
    I8,
    /// Rust `i16`.
    I16,
    /// Rust `i32`.
    I32,
    /// Rust `i64`.
    I64,
    /// `half::f16`.
    F16,
    /// Rust `f32`.
    F32,
    /// Rust `f64`.
    F64,
    /// `crate::dim::TDim`.
    TDim,
    /// `crate::blob::Blob`.
    Blob,
    /// Rust `String`.
    String,
    /// `i8` values interpreted with the given quantization parameters.
    QI8(QParams),
    /// `u8` values interpreted with the given quantization parameters.
    QU8(QParams),
    /// `i32` values interpreted with the given quantization parameters.
    QI32(QParams),
    /// `num_complex::Complex<i16>`.
    #[cfg(feature = "complex")]
    ComplexI16,
    /// `num_complex::Complex<i32>`.
    #[cfg(feature = "complex")]
    ComplexI32,
    /// `num_complex::Complex<i64>`.
    #[cfg(feature = "complex")]
    ComplexI64,
    /// `num_complex::Complex<f16>`.
    #[cfg(feature = "complex")]
    ComplexF16,
    /// `num_complex::Complex<f32>`.
    #[cfg(feature = "complex")]
    ComplexF32,
    /// `num_complex::Complex<f64>`.
    #[cfg(feature = "complex")]
    ComplexF64,
}
132
133impl DatumType {
134    pub fn super_types(&self) -> TVec<DatumType> {
135        use DatumType::*;
136        if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
137        {
138            return tvec!(*self);
139        }
140        #[cfg(feature = "complex")]
141        if self.is_complex_float() {
142            return [ComplexF16, ComplexF32, ComplexF64]
143                .iter()
144                .filter(|s| s.size_of() >= self.size_of())
145                .copied()
146                .collect();
147        } else if self.is_complex_signed() {
148            return [ComplexI16, ComplexI32, ComplexI64]
149                .iter()
150                .filter(|s| s.size_of() >= self.size_of())
151                .copied()
152                .collect();
153        }
154        if self.is_float() {
155            [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
156        } else if self.is_signed() {
157            [I8, I16, I32, I64, TDim]
158                .iter()
159                .filter(|s| s.size_of() >= self.size_of())
160                .copied()
161                .collect()
162        } else {
163            [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
164        }
165    }
166
167    pub fn super_type_for(
168        i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
169    ) -> Option<DatumType> {
170        let mut iter = i.into_iter();
171        let mut current = match iter.next() {
172            None => return None,
173            Some(it) => *it.borrow(),
174        };
175        for n in iter {
176            match current.common_super_type(*n.borrow()) {
177                None => return None,
178                Some(it) => current = it,
179            }
180        }
181        Some(current)
182    }
183
184    pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
185        for mine in self.super_types() {
186            for theirs in rhs.super_types() {
187                if mine == theirs {
188                    return Some(mine);
189                }
190            }
191        }
192        None
193    }
194
195    pub fn is_unsigned(&self) -> bool {
196        matches!(
197            self.unquantized(),
198            DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
199        )
200    }
201
202    pub fn is_signed(&self) -> bool {
203        matches!(
204            self.unquantized(),
205            DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
206        )
207    }
208
209    pub fn is_float(&self) -> bool {
210        matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
211    }
212
213    pub fn is_number(&self) -> bool {
214        self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
215    }
216
217    pub fn is_tdim(&self) -> bool {
218        *self == DatumType::TDim
219    }
220
221    #[cfg(feature = "complex")]
222    pub fn is_complex(&self) -> bool {
223        self.is_complex_float() || self.is_complex_signed()
224    }
225
226    #[cfg(feature = "complex")]
227    pub fn is_complex_float(&self) -> bool {
228        matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
229    }
230
231    #[cfg(feature = "complex")]
232    pub fn is_complex_signed(&self) -> bool {
233        matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
234    }
235
236    #[cfg(feature = "complex")]
237    pub fn complexify(&self) -> TractResult<DatumType> {
238        match *self {
239            DatumType::I16 => Ok(DatumType::ComplexI16),
240            DatumType::I32 => Ok(DatumType::ComplexI32),
241            DatumType::I64 => Ok(DatumType::ComplexI64),
242            DatumType::F16 => Ok(DatumType::ComplexF16),
243            DatumType::F32 => Ok(DatumType::ComplexF32),
244            DatumType::F64 => Ok(DatumType::ComplexF64),
245            _ => bail!("No complex datum type formed on {:?}", self),
246        }
247    }
248
249    #[cfg(feature = "complex")]
250    pub fn decomplexify(&self) -> TractResult<DatumType> {
251        match *self {
252            DatumType::ComplexI16 => Ok(DatumType::I16),
253            DatumType::ComplexI32 => Ok(DatumType::I32),
254            DatumType::ComplexI64 => Ok(DatumType::I64),
255            DatumType::ComplexF16 => Ok(DatumType::F16),
256            DatumType::ComplexF32 => Ok(DatumType::F32),
257            DatumType::ComplexF64 => Ok(DatumType::F64),
258            _ => bail!("{:?} is not a complex type", self),
259        }
260    }
261
262    pub fn is_copy(&self) -> bool {
263        #[cfg(feature = "complex")]
264        if self.is_complex() {
265            return true;
266        }
267        *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
268    }
269
270    pub fn is_quantized(&self) -> bool {
271        self.qparams().is_some()
272    }
273
274    pub fn qparams(&self) -> Option<QParams> {
275        match self {
276            DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
277                Some(*qparams)
278            }
279            _ => None,
280        }
281    }
282
283    pub fn with_qparams(&self, qparams: QParams) -> DatumType {
284        match self {
285            DatumType::QI8(_) => DatumType::QI8(qparams),
286            DatumType::QU8(_) => DatumType::QI8(qparams),
287            DatumType::QI32(_) => DatumType::QI32(qparams),
288            _ => *self,
289        }
290    }
291
292    pub fn quantize(&self, qparams: QParams) -> DatumType {
293        match self {
294            DatumType::I8 => DatumType::QI8(qparams),
295            DatumType::U8 => DatumType::QU8(qparams),
296            DatumType::I32 => DatumType::QI32(qparams),
297            DatumType::QI8(_) => DatumType::QI8(qparams),
298            DatumType::QU8(_) => DatumType::QU8(qparams),
299            DatumType::QI32(_) => DatumType::QI32(qparams),
300            _ => panic!("Can't quantize {self:?}"),
301        }
302    }
303
304    #[inline(always)]
305    pub fn zp_scale(&self) -> (i32, f32) {
306        self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
307    }
308
309    #[inline(always)]
310    pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
311        self.quantize(QParams::ZpScale { zero_point, scale })
312    }
313
314    pub fn unquantized(&self) -> DatumType {
315        match self {
316            DatumType::QI8(_) => DatumType::I8,
317            DatumType::QU8(_) => DatumType::U8,
318            DatumType::QI32(_) => DatumType::I32,
319            _ => *self,
320        }
321    }
322
323    pub fn integer(signed: bool, size: usize) -> Self {
324        use DatumType::*;
325        match (signed, size) {
326            (false, 8) => U8,
327            (false, 16) => U16,
328            (false, 32) => U32,
329            (false, 64) => U64,
330            (true, 8) => U8,
331            (true, 16) => U16,
332            (true, 32) => U32,
333            (true, 64) => U64,
334            _ => panic!("No integer for signed:{signed} size:{size}"),
335        }
336    }
337
338    pub fn is_integer(&self) -> bool {
339        self.is_signed() || self.is_unsigned()
340    }
341
342    #[inline]
343    pub fn size_of(&self) -> usize {
344        dispatch_datum!(std::mem::size_of(self)())
345    }
346
347    pub fn min_value(&self) -> Tensor {
348        match self {
349            DatumType::QU8(_)
350            | DatumType::U8
351            | DatumType::U16
352            | DatumType::U32
353            | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
354            DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
355            DatumType::QI32(_) => tensor0(i32::MIN),
356            DatumType::I16 => tensor0(i16::MIN),
357            DatumType::I32 => tensor0(i32::MIN),
358            DatumType::I64 => tensor0(i64::MIN),
359            DatumType::F16 => tensor0(f16::MIN),
360            DatumType::F32 => tensor0(f32::MIN),
361            DatumType::F64 => tensor0(f64::MIN),
362            _ => panic!("No min value for datum type {self:?}"),
363        }
364    }
365    pub fn max_value(&self) -> Tensor {
366        match self {
367            DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
368            DatumType::U16 => tensor0(u16::MAX),
369            DatumType::U32 => tensor0(u32::MAX),
370            DatumType::U64 => tensor0(u64::MAX),
371            DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
372            DatumType::I16 => tensor0(i16::MAX),
373            DatumType::I32 => tensor0(i32::MAX),
374            DatumType::I64 => tensor0(i64::MAX),
375            DatumType::QI32(_) => tensor0(i32::MAX),
376            DatumType::F16 => tensor0(f16::MAX),
377            DatumType::F32 => tensor0(f32::MAX),
378            DatumType::F64 => tensor0(f64::MAX),
379            _ => panic!("No max value for datum type {self:?}"),
380        }
381    }
382
383    pub fn is<D: Datum>(&self) -> bool {
384        *self == D::datum_type()
385    }
386}
387
388impl std::str::FromStr for DatumType {
389    type Err = TractError;
390
391    fn from_str(s: &str) -> Result<Self, Self::Err> {
392        if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
393            Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
394        } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
395            Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
396        } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
397            Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
398        } else {
399            match s {
400                "I8" | "i8" => Ok(DatumType::I8),
401                "I16" | "i16" => Ok(DatumType::I16),
402                "I32" | "i32" => Ok(DatumType::I32),
403                "I64" | "i64" => Ok(DatumType::I64),
404                "U8" | "u8" => Ok(DatumType::U8),
405                "U16" | "u16" => Ok(DatumType::U16),
406                "U32" | "u32" => Ok(DatumType::U32),
407                "U64" | "u64" => Ok(DatumType::U64),
408                "F16" | "f16" => Ok(DatumType::F16),
409                "F32" | "f32" => Ok(DatumType::F32),
410                "F64" | "f64" => Ok(DatumType::F64),
411                "Bool" | "bool" => Ok(DatumType::Bool),
412                "Blob" | "blob" => Ok(DatumType::Blob),
413                "String" | "string" => Ok(DatumType::String),
414                "TDim" | "tdim" => Ok(DatumType::TDim),
415                #[cfg(feature = "complex")]
416                "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
417                #[cfg(feature = "complex")]
418                "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
419                #[cfg(feature = "complex")]
420                "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
421                #[cfg(feature = "complex")]
422                "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
423                #[cfg(feature = "complex")]
424                "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
425                #[cfg(feature = "complex")]
426                "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
427                _ => bail!("Unknown type {}", s),
428            }
429        }
430    }
431}
432
/// `2^23`: at this magnitude consecutive `f32` values are exactly `1.0` apart.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer, sending exact halves to the even neighbour
/// (e.g. `2.5 -> 2.0`, `3.5 -> 4.0`, `-2.5 -> -2.0`).
///
/// Adding `±2^23` pushes the fractional bits out of the mantissa, so the FPU's default
/// round-to-nearest-even mode does the rounding. Subtracting it again restores the
/// magnitude. A zero result keeps the sign of `x`. Values with `|x| >= 2^23`, as well as
/// infinities and NaN, are returned untouched.
pub fn round_ties_to_even(x: f32) -> f32 {
    let bits = x.to_bits();
    let biased_exponent = (bits >> 23) & 0xff;
    // No fractional part left to round (or not a finite number).
    if biased_exponent >= 0x7f + 23 {
        return x;
    }
    let negative = bits >> 31 == 1;
    let rounded = if negative { (x - TOINT) + TOINT } else { (x + TOINT) - TOINT };
    match (rounded == 0.0, negative) {
        (true, true) => -0f32,
        (true, false) => 0f32,
        (false, _) => rounded,
    }
}
445
446#[inline]
447pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
448where
449    f32: AsPrimitive<T>,
450{
451    let b = b.as_();
452    (round_ties_to_even(b.abs() * a) * b.signum()).as_()
453}
454
455pub trait ClampCast: PartialOrd + Copy + 'static {
456    #[inline(always)]
457    fn clamp_cast<O>(self) -> O
458    where
459        Self: AsPrimitive<O> + Datum,
460        O: AsPrimitive<Self> + num_traits::Bounded + Datum,
461    {
462        // this fails if we're upcasting, in which case clamping is useless
463        if O::min_value().as_() < O::max_value().as_() {
464            num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
465        } else {
466            self.as_()
467        }
468    }
469}
470impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
471
/// A Rust type with a corresponding [`DatumType`] tag.
///
/// Implemented by the `datum!` macro for the scalar types listed after it.
pub trait Datum:
    Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
{
    /// A human-readable name for the type. The `datum!` implementations return the Rust
    /// type as written in the macro invocation.
    fn name() -> &'static str;
    /// The [`DatumType`] variant for this Rust type.
    fn datum_type() -> DatumType;
    /// Whether `Self` and `D` correspond to the same datum type. The `datum!`
    /// implementations compare their `datum_type()`s.
    fn is<D: Datum>() -> bool;
}
479
/// Implements [`Datum`] for the Rust type `$t`, mapping it to `DatumType::$v`. Also adds a
/// `From<$t> for Tensor` conversion that wraps a single value with `tensor0`.
macro_rules! datum {
    ($t:ty, $v:ident) => {
        // Scalar-to-tensor conversion, delegating to `tensor0`.
        impl From<$t> for Tensor {
            fn from(it: $t) -> Tensor {
                tensor0(it)
            }
        }

        impl Datum for $t {
            // The type as spelled in the macro invocation, e.g. `f32`.
            fn name() -> &'static str {
                stringify!($t)
            }

            fn datum_type() -> DatumType {
                DatumType::$v
            }

            fn is<D: Datum>() -> bool {
                Self::datum_type() == D::datum_type()
            }
        }
    };
}

// One `Datum` impl per non-quantized `DatumType` variant. The quantized variants reuse the
// `i8`/`u8`/`i32` element types and have no dedicated Rust type here.
datum!(bool, Bool);
datum!(f16, F16);
datum!(f32, F32);
datum!(f64, F64);
datum!(i8, I8);
datum!(i16, I16);
datum!(i32, I32);
datum!(i64, I64);
datum!(u8, U8);
datum!(u16, U16);
datum!(u32, U32);
datum!(u64, U64);
datum!(TDim, TDim);
datum!(String, String);
datum!(crate::blob::Blob, Blob);
#[cfg(feature = "complex")]
datum!(Complex<i16>, ComplexI16);
#[cfg(feature = "complex")]
datum!(Complex<i32>, ComplexI32);
#[cfg(feature = "complex")]
datum!(Complex<i64>, ComplexI64);
#[cfg(feature = "complex")]
datum!(Complex<f16>, ComplexF16);
#[cfg(feature = "complex")]
datum!(Complex<f32>, ComplexF32);
#[cfg(feature = "complex")]
datum!(Complex<f64>, ComplexF64);
531
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use ndarray::arr1;

    /// An ndarray array converted into a `Tensor` views back as the same array.
    #[test]
    fn test_array_to_tensor_to_array() {
        let expected = arr1(&[12i32, 42]);
        let tensor = Tensor::from(expected.clone());
        let as_view = tensor.to_plain_array_view::<i32>().unwrap();
        assert_eq!(expected, as_view.into_dimensionality().unwrap());
    }

    /// `TDim -> i32 -> TDim` casting round-trips concrete dimensions.
    #[test]
    fn test_cast_dim_to_dim() {
        let dims: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
        let ints = dims.cast_to::<i32>().unwrap();
        let round_tripped = ints.cast_to::<TDim>().unwrap().into_owned();
        assert_eq!(dims, round_tripped);
    }

    /// Casting `i32` values to `TDim` succeeds.
    #[test]
    fn test_cast_i32_to_dim() {
        let ints: Tensor = tensor1(&[0i32, 12]);
        ints.cast_to::<TDim>().unwrap();
    }

    /// Casting `i64` values to `bool` succeeds.
    #[test]
    fn test_cast_i64_to_bool() {
        let longs: Tensor = tensor1(&[0i64]);
        longs.cast_to::<bool>().unwrap();
    }

    /// The `Z:<zp> S:<scale>` quantized spelling parses into a `QU8` datum type.
    #[test]
    fn test_parse_qu8() {
        let expected = DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 });
        let parsed = "QU8(Z:128 S:0.01)".parse::<DatumType>().unwrap();
        assert_eq!(parsed, expected);
    }
}
572}