1use crate::TVec;
3use crate::dim::TDim;
4use crate::internal::*;
5use crate::tensor::Tensor;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters attached to the quantized datum types (`QI8`, `QU8`, `QI32`).
///
/// Either an explicit `(zero_point, scale)` pair, or a `[min, max]` float range from which a
/// zero point and a scale are derived.
#[derive(Copy, Clone)]
pub enum QParams {
    MinMax { min: f32, max: f32 },
    ZpScale { zero_point: i32, scale: f32 },
}

/// Equality compares the float fields by bit pattern.
///
/// `QParams` implements `Eq`, orders floats with `f32::total_cmp` and hashes them through
/// `f32::to_bits`. A derived `PartialEq` would use IEEE `==` instead. That is not reflexive for
/// NaN (violating `Eq`), and it treats `0.0` and `-0.0` as equal even though they hash and order
/// differently (violating the `Hash` and `Ord` contracts). Comparing bits keeps all of them
/// consistent.
impl PartialEq for QParams {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (QParams::MinMax { min: min1, max: max1 }, QParams::MinMax { min: min2, max: max2 }) => {
                min1.to_bits() == min2.to_bits() && max1.to_bits() == max2.to_bits()
            }
            (
                QParams::ZpScale { zero_point: zp1, scale: s1 },
                QParams::ZpScale { zero_point: zp2, scale: s2 },
            ) => zp1 == zp2 && s1.to_bits() == s2.to_bits(),
            // Different variants are never equal.
            _ => false,
        }
    }
}

impl Eq for QParams {}
22
23impl Ord for QParams {
24 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
25 use QParams::*;
26 match (self, other) {
27 (MinMax { .. }, ZpScale { .. }) => std::cmp::Ordering::Less,
28 (ZpScale { .. }, MinMax { .. }) => std::cmp::Ordering::Greater,
29 (MinMax { min: min1, max: max1 }, MinMax { min: min2, max: max2 }) => {
30 min1.total_cmp(min2).then_with(|| max1.total_cmp(max2))
31 }
32 (
33 Self::ZpScale { zero_point: zp1, scale: s1 },
34 Self::ZpScale { zero_point: zp2, scale: s2 },
35 ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
36 }
37 }
38}
39
40impl PartialOrd for QParams {
41 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42 Some(self.cmp(other))
43 }
44}
45
46impl Default for QParams {
47 fn default() -> Self {
48 QParams::ZpScale { zero_point: 0, scale: 1. }
49 }
50}
51
52#[allow(clippy::derived_hash_with_manual_eq)]
53impl Hash for QParams {
54 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
55 match self {
56 QParams::MinMax { min, max } => {
57 0.hash(state);
58 min.to_bits().hash(state);
59 max.to_bits().hash(state);
60 }
61 QParams::ZpScale { zero_point, scale } => {
62 1.hash(state);
63 zero_point.hash(state);
64 scale.to_bits().hash(state);
65 }
66 }
67 }
68}
69
70impl QParams {
71 pub fn zp_scale(&self) -> (i32, f32) {
72 match self {
73 QParams::MinMax { min, max } => {
74 let scale = (max - min) / 255.;
75 ((-(min + max) / 2. / scale) as i32, scale)
76 }
77 QParams::ZpScale { zero_point, scale } => (*zero_point, *scale),
78 }
79 }
80
81 pub fn q(&self, f: f32) -> i32 {
82 let (zp, scale) = self.zp_scale();
83 (f / scale) as i32 + zp
84 }
85
86 pub fn dq(&self, i: i32) -> f32 {
87 let (zp, scale) = self.zp_scale();
88 (i - zp) as f32 * scale
89 }
90}
91
92impl std::fmt::Debug for QParams {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 let (zp, scale) = self.zp_scale();
95 write!(f, "Z:{zp} S:{scale}")
96 }
97}
98
/// Runtime tag describing the element type of a tensor.
///
/// The derived `Ord` compares variants in declaration order, so reordering the variants
/// changes how `DatumType` values sort.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum DatumType {
    Bool,
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F16,
    F32,
    F64,
    /// Dimension value (`crate::dim::TDim`).
    TDim,
    Blob,
    String,
    /// Quantized type whose unquantized storage type is `I8`.
    QI8(QParams),
    /// Quantized type whose unquantized storage type is `U8`.
    QU8(QParams),
    /// Quantized type whose unquantized storage type is `I32`.
    QI32(QParams),
    #[cfg(feature = "complex")]
    ComplexI16,
    #[cfg(feature = "complex")]
    ComplexI32,
    #[cfg(feature = "complex")]
    ComplexI64,
    #[cfg(feature = "complex")]
    ComplexF16,
    #[cfg(feature = "complex")]
    ComplexF32,
    #[cfg(feature = "complex")]
    ComplexF64,
}
132
133impl DatumType {
134 pub fn super_types(&self) -> TVec<DatumType> {
135 use DatumType::*;
136 if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
137 {
138 return tvec!(*self);
139 }
140 #[cfg(feature = "complex")]
141 if self.is_complex_float() {
142 return [ComplexF16, ComplexF32, ComplexF64]
143 .iter()
144 .filter(|s| s.size_of() >= self.size_of())
145 .copied()
146 .collect();
147 } else if self.is_complex_signed() {
148 return [ComplexI16, ComplexI32, ComplexI64]
149 .iter()
150 .filter(|s| s.size_of() >= self.size_of())
151 .copied()
152 .collect();
153 }
154 if self.is_float() {
155 [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
156 } else if self.is_signed() {
157 [I8, I16, I32, I64, TDim]
158 .iter()
159 .filter(|s| s.size_of() >= self.size_of())
160 .copied()
161 .collect()
162 } else {
163 [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
164 }
165 }
166
167 pub fn super_type_for(
168 i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
169 ) -> Option<DatumType> {
170 let mut iter = i.into_iter();
171 let mut current = match iter.next() {
172 None => return None,
173 Some(it) => *it.borrow(),
174 };
175 for n in iter {
176 match current.common_super_type(*n.borrow()) {
177 None => return None,
178 Some(it) => current = it,
179 }
180 }
181 Some(current)
182 }
183
184 pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
185 for mine in self.super_types() {
186 for theirs in rhs.super_types() {
187 if mine == theirs {
188 return Some(mine);
189 }
190 }
191 }
192 None
193 }
194
195 pub fn is_unsigned(&self) -> bool {
196 matches!(
197 self.unquantized(),
198 DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
199 )
200 }
201
202 pub fn is_signed(&self) -> bool {
203 matches!(
204 self.unquantized(),
205 DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
206 )
207 }
208
209 pub fn is_float(&self) -> bool {
210 matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
211 }
212
213 pub fn is_number(&self) -> bool {
214 self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
215 }
216
217 pub fn is_tdim(&self) -> bool {
218 *self == DatumType::TDim
219 }
220
221 #[cfg(feature = "complex")]
222 pub fn is_complex(&self) -> bool {
223 self.is_complex_float() || self.is_complex_signed()
224 }
225
226 #[cfg(feature = "complex")]
227 pub fn is_complex_float(&self) -> bool {
228 matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
229 }
230
231 #[cfg(feature = "complex")]
232 pub fn is_complex_signed(&self) -> bool {
233 matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
234 }
235
236 #[cfg(feature = "complex")]
237 pub fn complexify(&self) -> TractResult<DatumType> {
238 match *self {
239 DatumType::I16 => Ok(DatumType::ComplexI16),
240 DatumType::I32 => Ok(DatumType::ComplexI32),
241 DatumType::I64 => Ok(DatumType::ComplexI64),
242 DatumType::F16 => Ok(DatumType::ComplexF16),
243 DatumType::F32 => Ok(DatumType::ComplexF32),
244 DatumType::F64 => Ok(DatumType::ComplexF64),
245 _ => bail!("No complex datum type formed on {:?}", self),
246 }
247 }
248
249 #[cfg(feature = "complex")]
250 pub fn decomplexify(&self) -> TractResult<DatumType> {
251 match *self {
252 DatumType::ComplexI16 => Ok(DatumType::I16),
253 DatumType::ComplexI32 => Ok(DatumType::I32),
254 DatumType::ComplexI64 => Ok(DatumType::I64),
255 DatumType::ComplexF16 => Ok(DatumType::F16),
256 DatumType::ComplexF32 => Ok(DatumType::F32),
257 DatumType::ComplexF64 => Ok(DatumType::F64),
258 _ => bail!("{:?} is not a complex type", self),
259 }
260 }
261
262 pub fn is_copy(&self) -> bool {
263 #[cfg(feature = "complex")]
264 if self.is_complex() {
265 return true;
266 }
267 *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
268 }
269
270 pub fn is_quantized(&self) -> bool {
271 self.qparams().is_some()
272 }
273
274 pub fn qparams(&self) -> Option<QParams> {
275 match self {
276 DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
277 Some(*qparams)
278 }
279 _ => None,
280 }
281 }
282
283 pub fn with_qparams(&self, qparams: QParams) -> DatumType {
284 match self {
285 DatumType::QI8(_) => DatumType::QI8(qparams),
286 DatumType::QU8(_) => DatumType::QI8(qparams),
287 DatumType::QI32(_) => DatumType::QI32(qparams),
288 _ => *self,
289 }
290 }
291
292 pub fn quantize(&self, qparams: QParams) -> DatumType {
293 match self {
294 DatumType::I8 => DatumType::QI8(qparams),
295 DatumType::U8 => DatumType::QU8(qparams),
296 DatumType::I32 => DatumType::QI32(qparams),
297 DatumType::QI8(_) => DatumType::QI8(qparams),
298 DatumType::QU8(_) => DatumType::QU8(qparams),
299 DatumType::QI32(_) => DatumType::QI32(qparams),
300 _ => panic!("Can't quantize {self:?}"),
301 }
302 }
303
304 #[inline(always)]
305 pub fn zp_scale(&self) -> (i32, f32) {
306 self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
307 }
308
309 #[inline(always)]
310 pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
311 self.quantize(QParams::ZpScale { zero_point, scale })
312 }
313
314 pub fn unquantized(&self) -> DatumType {
315 match self {
316 DatumType::QI8(_) => DatumType::I8,
317 DatumType::QU8(_) => DatumType::U8,
318 DatumType::QI32(_) => DatumType::I32,
319 _ => *self,
320 }
321 }
322
323 pub fn integer(signed: bool, size: usize) -> Self {
324 use DatumType::*;
325 match (signed, size) {
326 (false, 8) => U8,
327 (false, 16) => U16,
328 (false, 32) => U32,
329 (false, 64) => U64,
330 (true, 8) => U8,
331 (true, 16) => U16,
332 (true, 32) => U32,
333 (true, 64) => U64,
334 _ => panic!("No integer for signed:{signed} size:{size}"),
335 }
336 }
337
338 pub fn is_integer(&self) -> bool {
339 self.is_signed() || self.is_unsigned()
340 }
341
342 #[inline]
343 pub fn size_of(&self) -> usize {
344 dispatch_datum!(std::mem::size_of(self)())
345 }
346
347 pub fn min_value(&self) -> Tensor {
348 match self {
349 DatumType::QU8(_)
350 | DatumType::U8
351 | DatumType::U16
352 | DatumType::U32
353 | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
354 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
355 DatumType::QI32(_) => tensor0(i32::MIN),
356 DatumType::I16 => tensor0(i16::MIN),
357 DatumType::I32 => tensor0(i32::MIN),
358 DatumType::I64 => tensor0(i64::MIN),
359 DatumType::F16 => tensor0(f16::MIN),
360 DatumType::F32 => tensor0(f32::MIN),
361 DatumType::F64 => tensor0(f64::MIN),
362 _ => panic!("No min value for datum type {self:?}"),
363 }
364 }
365 pub fn max_value(&self) -> Tensor {
366 match self {
367 DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
368 DatumType::U16 => tensor0(u16::MAX),
369 DatumType::U32 => tensor0(u32::MAX),
370 DatumType::U64 => tensor0(u64::MAX),
371 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
372 DatumType::I16 => tensor0(i16::MAX),
373 DatumType::I32 => tensor0(i32::MAX),
374 DatumType::I64 => tensor0(i64::MAX),
375 DatumType::QI32(_) => tensor0(i32::MAX),
376 DatumType::F16 => tensor0(f16::MAX),
377 DatumType::F32 => tensor0(f32::MAX),
378 DatumType::F64 => tensor0(f64::MAX),
379 _ => panic!("No max value for datum type {self:?}"),
380 }
381 }
382
383 pub fn is<D: Datum>(&self) -> bool {
384 *self == D::datum_type()
385 }
386}
387
388impl std::str::FromStr for DatumType {
389 type Err = TractError;
390
391 fn from_str(s: &str) -> Result<Self, Self::Err> {
392 if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
393 Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
394 } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
395 Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
396 } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
397 Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
398 } else {
399 match s {
400 "I8" | "i8" => Ok(DatumType::I8),
401 "I16" | "i16" => Ok(DatumType::I16),
402 "I32" | "i32" => Ok(DatumType::I32),
403 "I64" | "i64" => Ok(DatumType::I64),
404 "U8" | "u8" => Ok(DatumType::U8),
405 "U16" | "u16" => Ok(DatumType::U16),
406 "U32" | "u32" => Ok(DatumType::U32),
407 "U64" | "u64" => Ok(DatumType::U64),
408 "F16" | "f16" => Ok(DatumType::F16),
409 "F32" | "f32" => Ok(DatumType::F32),
410 "F64" | "f64" => Ok(DatumType::F64),
411 "Bool" | "bool" => Ok(DatumType::Bool),
412 "Blob" | "blob" => Ok(DatumType::Blob),
413 "String" | "string" => Ok(DatumType::String),
414 "TDim" | "tdim" => Ok(DatumType::TDim),
415 #[cfg(feature = "complex")]
416 "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
417 #[cfg(feature = "complex")]
418 "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
419 #[cfg(feature = "complex")]
420 "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
421 #[cfg(feature = "complex")]
422 "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
423 #[cfg(feature = "complex")]
424 "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
425 #[cfg(feature = "complex")]
426 "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
427 _ => bail!("Unknown type {}", s),
428 }
429 }
430 }
431}
432
/// `2^23` (`1 / f32::EPSILON`). Adding and then subtracting it makes the default
/// round-to-nearest-even float addition snap any smaller-magnitude `f32` to an integer.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer, breaking ties toward the even neighbour.
///
/// Values whose exponent already makes them integral (this includes infinities and NaN) are
/// returned unchanged. A result that rounds to zero keeps the sign of `x`.
pub fn round_ties_to_even(x: f32) -> f32 {
    let bits = x.to_bits();
    let biased_exponent = (bits >> 23) & 0xff;
    // From 2^23 upward every finite f32 is already an integer; exponent 0xff is inf/NaN.
    if biased_exponent >= 0x7f + 23 {
        return x;
    }
    let negative = bits >> 31 == 1;
    // Shift away from zero by 2^23 and back, so the FPU does the rounding.
    let rounded = if negative { x - TOINT + TOINT } else { x + TOINT - TOINT };
    if rounded != 0.0 {
        rounded
    } else if negative {
        -0f32
    } else {
        0f32
    }
}
445
446#[inline]
447pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
448where
449 f32: AsPrimitive<T>,
450{
451 let b = b.as_();
452 (round_ties_to_even(b.abs() * a) * b.signum()).as_()
453}
454
455pub trait ClampCast: PartialOrd + Copy + 'static {
456 #[inline(always)]
457 fn clamp_cast<O>(self) -> O
458 where
459 Self: AsPrimitive<O> + Datum,
460 O: AsPrimitive<Self> + num_traits::Bounded + Datum,
461 {
462 if O::min_value().as_() < O::max_value().as_() {
464 num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
465 } else {
466 self.as_()
467 }
468 }
469}
470impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
471
472pub trait Datum:
473 Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
474{
475 fn name() -> &'static str;
476 fn datum_type() -> DatumType;
477 fn is<D: Datum>() -> bool;
478}
479
/// Implements `From<$t> for Tensor` (as a `tensor0` scalar) and `Datum for $t`, tagging the
/// Rust type `$t` with the `DatumType::$v` variant.
macro_rules! datum {
    ($t:ty, $v:ident) => {
        impl From<$t> for Tensor {
            fn from(value: $t) -> Tensor {
                tensor0(value)
            }
        }

        impl Datum for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn datum_type() -> DatumType {
                DatumType::$v
            }

            fn is<D: Datum>() -> bool {
                D::datum_type() == <Self as Datum>::datum_type()
            }
        }
    };
}
503
504datum!(bool, Bool);
505datum!(f16, F16);
506datum!(f32, F32);
507datum!(f64, F64);
508datum!(i8, I8);
509datum!(i16, I16);
510datum!(i32, I32);
511datum!(i64, I64);
512datum!(u8, U8);
513datum!(u16, U16);
514datum!(u32, U32);
515datum!(u64, U64);
516datum!(TDim, TDim);
517datum!(String, String);
518datum!(crate::blob::Blob, Blob);
519#[cfg(feature = "complex")]
520datum!(Complex<i16>, ComplexI16);
521#[cfg(feature = "complex")]
522datum!(Complex<i32>, ComplexI32);
523#[cfg(feature = "complex")]
524datum!(Complex<i64>, ComplexI64);
525#[cfg(feature = "complex")]
526datum!(Complex<f16>, ComplexF16);
527#[cfg(feature = "complex")]
528datum!(Complex<f32>, ComplexF32);
529#[cfg(feature = "complex")]
530datum!(Complex<f64>, ComplexF64);
531
532#[cfg(test)]
533mod tests {
534 use crate::internal::*;
535 use ndarray::arr1;
536
537 #[test]
538 fn test_array_to_tensor_to_array() {
539 let array = arr1(&[12i32, 42]);
540 let tensor = Tensor::from(array.clone());
541 let view = tensor.to_plain_array_view::<i32>().unwrap();
542 assert_eq!(array, view.into_dimensionality().unwrap());
543 }
544
545 #[test]
546 fn test_cast_dim_to_dim() {
547 let t_dim: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
548 let t_i32 = t_dim.cast_to::<i32>().unwrap();
549 let t_dim_2 = t_i32.cast_to::<TDim>().unwrap().into_owned();
550 assert_eq!(t_dim, t_dim_2);
551 }
552
553 #[test]
554 fn test_cast_i32_to_dim() {
555 let t_i32: Tensor = tensor1(&[0i32, 12]);
556 t_i32.cast_to::<TDim>().unwrap();
557 }
558
559 #[test]
560 fn test_cast_i64_to_bool() {
561 let t_i64: Tensor = tensor1(&[0i64]);
562 t_i64.cast_to::<bool>().unwrap();
563 }
564
565 #[test]
566 fn test_parse_qu8() {
567 assert_eq!(
568 "QU8(Z:128 S:0.01)".parse::<DatumType>().unwrap(),
569 DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 })
570 );
571 }
572}