1use crate::dim::TDim;
3use crate::internal::*;
4use crate::tensor::Tensor;
5use crate::TVec;
6use half::f16;
7#[cfg(feature = "complex")]
8use num_complex::Complex;
9use scan_fmt::scan_fmt;
10use std::fmt;
11use std::hash::Hash;
12
13use num_traits::AsPrimitive;
14
/// Quantization parameters, given either as the real-valued range the quantized values
/// should cover, or directly as a zero point / scale pair.
#[derive(Clone, Copy, PartialEq)]
pub enum QParams {
    /// Real-valued range `[min, max]`; zero point and scale are derived from it.
    MinMax { min: f32, max: f32 },
    /// Explicit zero point and scale.
    ZpScale { zero_point: i32, scale: f32 },
}
20
// `Eq` is asserted on top of the derived, float-based `PartialEq` so that `Ord` can be
// implemented and `DatumType` (which embeds `QParams`) can derive `Eq`.
// NOTE(review): reflexivity only holds while no field is NaN (NaN != NaN).
impl Eq for QParams {}
22
23impl Ord for QParams {
24 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
25 use QParams::*;
26 match (self, other) {
27 (MinMax { .. }, ZpScale { .. }) => std::cmp::Ordering::Less,
28 (ZpScale { .. }, MinMax { .. }) => std::cmp::Ordering::Greater,
29 (MinMax { min: min1, max: max1 }, MinMax { min: min2, max: max2 }) => {
30 min1.total_cmp(min2).then_with(|| max1.total_cmp(max2))
31 }
32 (
33 Self::ZpScale { zero_point: zp1, scale: s1 },
34 Self::ZpScale { zero_point: zp2, scale: s2 },
35 ) => zp1.cmp(zp2).then_with(|| s1.total_cmp(s2)),
36 }
37 }
38}
39
40impl PartialOrd for QParams {
41 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
42 Some(self.cmp(other))
43 }
44}
45
46impl Default for QParams {
47 fn default() -> Self {
48 QParams::ZpScale { zero_point: 0, scale: 1. }
49 }
50}
51
52#[allow(clippy::derived_hash_with_manual_eq)]
53impl Hash for QParams {
54 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
55 match self {
56 QParams::MinMax { min, max } => {
57 0.hash(state);
58 min.to_bits().hash(state);
59 max.to_bits().hash(state);
60 }
61 QParams::ZpScale { zero_point, scale } => {
62 1.hash(state);
63 zero_point.hash(state);
64 scale.to_bits().hash(state);
65 }
66 }
67 }
68}
69
70impl QParams {
71 pub fn zp_scale(&self) -> (i32, f32) {
72 match self {
73 QParams::MinMax { min, max } => {
74 let scale = (max - min) / 255.;
75 ((-(min + max) / 2. / scale) as i32, scale)
76 }
77 QParams::ZpScale { zero_point, scale } => (*zero_point, *scale),
78 }
79 }
80
81 pub fn q(&self, f: f32) -> i32 {
82 let (zp, scale) = self.zp_scale();
83 (f / scale) as i32 + zp
84 }
85
86 pub fn dq(&self, i: i32) -> f32 {
87 let (zp, scale) = self.zp_scale();
88 (i - zp) as f32 * scale
89 }
90}
91
92impl std::fmt::Debug for QParams {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 let (zp, scale) = self.zp_scale();
95 write!(f, "Z:{zp} S:{scale}")
96 }
97}
98
/// Runtime tag for the element type of a tensor.
///
/// Quantized variants (`QI8`, `QU8`, `QI32`) carry their [`QParams`]; the complex variants
/// only exist when the `complex` feature is enabled.
///
/// NOTE: the derived `Ord`/`PartialOrd` follow variant declaration order, so reordering the
/// variants changes how datum types sort.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum DatumType {
    /// `bool`.
    Bool,
    /// `u8`.
    U8,
    /// `u16`.
    U16,
    /// `u32`.
    U32,
    /// `u64`.
    U64,
    /// `i8`.
    I8,
    /// `i16`.
    I16,
    /// `i32`.
    I32,
    /// `i64`.
    I64,
    /// Half-precision float (`half::f16`).
    F16,
    /// `f32`.
    F32,
    /// `f64`.
    F64,
    /// Symbolic dimension (`crate::dim::TDim`).
    TDim,
    /// `crate::blob::Blob`.
    Blob,
    /// `String`.
    String,
    /// `i8` storage with quantization parameters.
    QI8(QParams),
    /// `u8` storage with quantization parameters.
    QU8(QParams),
    /// `i32` storage with quantization parameters.
    QI32(QParams),
    /// `Complex<i16>`.
    #[cfg(feature = "complex")]
    ComplexI16,
    /// `Complex<i32>`.
    #[cfg(feature = "complex")]
    ComplexI32,
    /// `Complex<i64>`.
    #[cfg(feature = "complex")]
    ComplexI64,
    /// `Complex<f16>`.
    #[cfg(feature = "complex")]
    ComplexF16,
    /// `Complex<f32>`.
    #[cfg(feature = "complex")]
    ComplexF32,
    /// `Complex<f64>`.
    #[cfg(feature = "complex")]
    ComplexF64,
    /// `crate::opaque::Opaque`.
    Opaque,
}
133
134impl DatumType {
135 pub fn super_types(&self) -> TVec<DatumType> {
136 use DatumType::*;
137 if *self == String || *self == TDim || *self == Blob || *self == Bool || self.is_quantized()
138 {
139 return tvec!(*self);
140 }
141 #[cfg(feature = "complex")]
142 if self.is_complex_float() {
143 return [ComplexF16, ComplexF32, ComplexF64]
144 .iter()
145 .filter(|s| s.size_of() >= self.size_of())
146 .copied()
147 .collect();
148 } else if self.is_complex_signed() {
149 return [ComplexI16, ComplexI32, ComplexI64]
150 .iter()
151 .filter(|s| s.size_of() >= self.size_of())
152 .copied()
153 .collect();
154 }
155 if self.is_float() {
156 [F16, F32, F64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
157 } else if self.is_signed() {
158 [I8, I16, I32, I64, TDim]
159 .iter()
160 .filter(|s| s.size_of() >= self.size_of())
161 .copied()
162 .collect()
163 } else {
164 [U8, U16, U32, U64].iter().filter(|s| s.size_of() >= self.size_of()).copied().collect()
165 }
166 }
167
168 pub fn super_type_for(
169 i: impl IntoIterator<Item = impl std::borrow::Borrow<DatumType>>,
170 ) -> Option<DatumType> {
171 let mut iter = i.into_iter();
172 let mut current = match iter.next() {
173 None => return None,
174 Some(it) => *it.borrow(),
175 };
176 for n in iter {
177 match current.common_super_type(*n.borrow()) {
178 None => return None,
179 Some(it) => current = it,
180 }
181 }
182 Some(current)
183 }
184
185 pub fn common_super_type(&self, rhs: DatumType) -> Option<DatumType> {
186 for mine in self.super_types() {
187 for theirs in rhs.super_types() {
188 if mine == theirs {
189 return Some(mine);
190 }
191 }
192 }
193 None
194 }
195
196 pub fn is_unsigned(&self) -> bool {
197 matches!(
198 self.unquantized(),
199 DatumType::U8 | DatumType::U16 | DatumType::U32 | DatumType::U64
200 )
201 }
202
203 pub fn is_signed(&self) -> bool {
204 matches!(
205 self.unquantized(),
206 DatumType::I8 | DatumType::I16 | DatumType::I32 | DatumType::I64
207 )
208 }
209
210 pub fn is_float(&self) -> bool {
211 matches!(self, DatumType::F16 | DatumType::F32 | DatumType::F64)
212 }
213
214 pub fn is_number(&self) -> bool {
215 self.is_signed() | self.is_unsigned() | self.is_float() | self.is_quantized()
216 }
217
218 pub fn is_tdim(&self) -> bool {
219 *self == DatumType::TDim
220 }
221
222 pub fn is_opaque(&self) -> bool {
223 *self == DatumType::Opaque
224 }
225
226 #[cfg(feature = "complex")]
227 pub fn is_complex(&self) -> bool {
228 self.is_complex_float() || self.is_complex_signed()
229 }
230
231 #[cfg(feature = "complex")]
232 pub fn is_complex_float(&self) -> bool {
233 matches!(self, DatumType::ComplexF16 | DatumType::ComplexF32 | DatumType::ComplexF64)
234 }
235
236 #[cfg(feature = "complex")]
237 pub fn is_complex_signed(&self) -> bool {
238 matches!(self, DatumType::ComplexI16 | DatumType::ComplexI32 | DatumType::ComplexI64)
239 }
240
241 #[cfg(feature = "complex")]
242 pub fn complexify(&self) -> TractResult<DatumType> {
243 match *self {
244 DatumType::I16 => Ok(DatumType::ComplexI16),
245 DatumType::I32 => Ok(DatumType::ComplexI32),
246 DatumType::I64 => Ok(DatumType::ComplexI64),
247 DatumType::F16 => Ok(DatumType::ComplexF16),
248 DatumType::F32 => Ok(DatumType::ComplexF32),
249 DatumType::F64 => Ok(DatumType::ComplexF64),
250 _ => bail!("No complex datum type formed on {:?}", self),
251 }
252 }
253
254 #[cfg(feature = "complex")]
255 pub fn decomplexify(&self) -> TractResult<DatumType> {
256 match *self {
257 DatumType::ComplexI16 => Ok(DatumType::I16),
258 DatumType::ComplexI32 => Ok(DatumType::I32),
259 DatumType::ComplexI64 => Ok(DatumType::I64),
260 DatumType::ComplexF16 => Ok(DatumType::F16),
261 DatumType::ComplexF32 => Ok(DatumType::F32),
262 DatumType::ComplexF64 => Ok(DatumType::F64),
263 _ => bail!("{:?} is not a complex type", self),
264 }
265 }
266
267 pub fn is_copy(&self) -> bool {
268 #[cfg(feature = "complex")]
269 if self.is_complex() {
270 return true;
271 }
272 *self == DatumType::Bool || self.is_unsigned() || self.is_signed() || self.is_float()
273 }
274
275 pub fn is_quantized(&self) -> bool {
276 self.qparams().is_some()
277 }
278
279 pub fn qparams(&self) -> Option<QParams> {
280 match self {
281 DatumType::QI8(qparams) | DatumType::QU8(qparams) | DatumType::QI32(qparams) => {
282 Some(*qparams)
283 }
284 _ => None,
285 }
286 }
287
288 pub fn with_qparams(&self, qparams: QParams) -> DatumType {
289 match self {
290 DatumType::QI8(_) => DatumType::QI8(qparams),
291 DatumType::QU8(_) => DatumType::QI8(qparams),
292 DatumType::QI32(_) => DatumType::QI32(qparams),
293 _ => *self,
294 }
295 }
296
297 pub fn quantize(&self, qparams: QParams) -> DatumType {
298 match self {
299 DatumType::I8 => DatumType::QI8(qparams),
300 DatumType::U8 => DatumType::QU8(qparams),
301 DatumType::I32 => DatumType::QI32(qparams),
302 DatumType::QI8(_) => DatumType::QI8(qparams),
303 DatumType::QU8(_) => DatumType::QU8(qparams),
304 DatumType::QI32(_) => DatumType::QI32(qparams),
305 _ => panic!("Can't quantize {self:?}"),
306 }
307 }
308
309 #[inline(always)]
310 pub fn zp_scale(&self) -> (i32, f32) {
311 self.qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.))
312 }
313
314 #[inline(always)]
315 pub fn with_zp_scale(&self, zero_point: i32, scale: f32) -> DatumType {
316 self.quantize(QParams::ZpScale { zero_point, scale })
317 }
318
319 pub fn unquantized(&self) -> DatumType {
320 match self {
321 DatumType::QI8(_) => DatumType::I8,
322 DatumType::QU8(_) => DatumType::U8,
323 DatumType::QI32(_) => DatumType::I32,
324 _ => *self,
325 }
326 }
327
328 pub fn integer(signed: bool, size: usize) -> Self {
329 use DatumType::*;
330 match (signed, size) {
331 (false, 8) => U8,
332 (false, 16) => U16,
333 (false, 32) => U32,
334 (false, 64) => U64,
335 (true, 8) => U8,
336 (true, 16) => U16,
337 (true, 32) => U32,
338 (true, 64) => U64,
339 _ => panic!("No integer for signed:{signed} size:{size}"),
340 }
341 }
342
343 pub fn is_integer(&self) -> bool {
344 self.is_signed() || self.is_unsigned()
345 }
346
347 #[inline]
348 pub fn size_of(&self) -> usize {
349 dispatch_datum!(std::mem::size_of(self)())
350 }
351
352 pub fn min_value(&self) -> Tensor {
353 match self {
354 DatumType::QU8(_)
355 | DatumType::U8
356 | DatumType::U16
357 | DatumType::U32
358 | DatumType::U64 => Tensor::zero_dt(*self, &[1]).unwrap(),
359 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MIN),
360 DatumType::QI32(_) => tensor0(i32::MIN),
361 DatumType::I16 => tensor0(i16::MIN),
362 DatumType::I32 => tensor0(i32::MIN),
363 DatumType::I64 => tensor0(i64::MIN),
364 DatumType::F16 => tensor0(f16::MIN),
365 DatumType::F32 => tensor0(f32::MIN),
366 DatumType::F64 => tensor0(f64::MIN),
367 _ => panic!("No min value for datum type {self:?}"),
368 }
369 }
370 pub fn max_value(&self) -> Tensor {
371 match self {
372 DatumType::U8 | DatumType::QU8(_) => tensor0(u8::MAX),
373 DatumType::U16 => tensor0(u16::MAX),
374 DatumType::U32 => tensor0(u32::MAX),
375 DatumType::U64 => tensor0(u64::MAX),
376 DatumType::I8 | DatumType::QI8(_) => tensor0(i8::MAX),
377 DatumType::I16 => tensor0(i16::MAX),
378 DatumType::I32 => tensor0(i32::MAX),
379 DatumType::I64 => tensor0(i64::MAX),
380 DatumType::QI32(_) => tensor0(i32::MAX),
381 DatumType::F16 => tensor0(f16::MAX),
382 DatumType::F32 => tensor0(f32::MAX),
383 DatumType::F64 => tensor0(f64::MAX),
384 _ => panic!("No max value for datum type {self:?}"),
385 }
386 }
387
388 pub fn is<D: Datum>(&self) -> bool {
389 *self == D::datum_type()
390 }
391}
392
393impl std::str::FromStr for DatumType {
394 type Err = TractError;
395
396 fn from_str(s: &str) -> Result<Self, Self::Err> {
397 if let Ok((z, s)) = scan_fmt!(s, "QU8(Z:{d} S:{f})", i32, f32) {
398 Ok(DatumType::QU8(QParams::ZpScale { zero_point: z, scale: s }))
399 } else if let Ok((z, s)) = scan_fmt!(s, "QI8(Z:{d} S:{f})", i32, f32) {
400 Ok(DatumType::QI8(QParams::ZpScale { zero_point: z, scale: s }))
401 } else if let Ok((z, s)) = scan_fmt!(s, "QI32(Z:{d} S:{f})", i32, f32) {
402 Ok(DatumType::QI32(QParams::ZpScale { zero_point: z, scale: s }))
403 } else {
404 match s {
405 "I8" | "i8" => Ok(DatumType::I8),
406 "I16" | "i16" => Ok(DatumType::I16),
407 "I32" | "i32" => Ok(DatumType::I32),
408 "I64" | "i64" => Ok(DatumType::I64),
409 "U8" | "u8" => Ok(DatumType::U8),
410 "U16" | "u16" => Ok(DatumType::U16),
411 "U32" | "u32" => Ok(DatumType::U32),
412 "U64" | "u64" => Ok(DatumType::U64),
413 "F16" | "f16" => Ok(DatumType::F16),
414 "F32" | "f32" => Ok(DatumType::F32),
415 "F64" | "f64" => Ok(DatumType::F64),
416 "Bool" | "bool" => Ok(DatumType::Bool),
417 "Blob" | "blob" => Ok(DatumType::Blob),
418 "String" | "string" => Ok(DatumType::String),
419 "TDim" | "tdim" => Ok(DatumType::TDim),
420 #[cfg(feature = "complex")]
421 "ComplexI16" | "complexi16" => Ok(DatumType::ComplexI16),
422 #[cfg(feature = "complex")]
423 "ComplexI32" | "complexi32" => Ok(DatumType::ComplexI32),
424 #[cfg(feature = "complex")]
425 "ComplexI64" | "complexi64" => Ok(DatumType::ComplexI64),
426 #[cfg(feature = "complex")]
427 "ComplexF16" | "complexf16" => Ok(DatumType::ComplexF16),
428 #[cfg(feature = "complex")]
429 "ComplexF32" | "complexf32" => Ok(DatumType::ComplexF32),
430 #[cfg(feature = "complex")]
431 "ComplexF64" | "complexf64" => Ok(DatumType::ComplexF64),
432 _ => bail!("Unknown type {}", s),
433 }
434 }
435 }
436}
437
/// 2^23 (`1 / f32::EPSILON`). Adding then subtracting it pushes the fractional bits of an
/// `f32` with magnitude below 2^23 out of the mantissa, so the FPU's round-to-nearest-even
/// mode performs the rounding.
const TOINT: f32 = 1.0f32 / f32::EPSILON;

/// Rounds `x` to the nearest integer, resolving ties to the even neighbour.
///
/// Values whose magnitude is already at least 2^23 (hence integral), as well as infinities
/// and NaN, are returned unchanged. A result of zero keeps the sign of `x`.
pub fn round_ties_to_even(x: f32) -> f32 {
    const MANTISSA_BITS: u32 = 23;
    const EXPONENT_BIAS: u32 = 0x7f;

    let biased_exponent = (x.to_bits() >> MANTISSA_BITS) & 0xff;
    if biased_exponent >= EXPONENT_BIAS + MANTISSA_BITS {
        return x;
    }

    let negative = x.is_sign_negative();
    let rounded = if negative { x - TOINT + TOINT } else { x + TOINT - TOINT };
    match (rounded == 0.0, negative) {
        (true, true) => -0f32,
        (true, false) => 0f32,
        (false, _) => rounded,
    }
}
458
459#[inline]
460pub fn scale_by<T: Datum + AsPrimitive<f32>>(b: T, a: f32) -> T
461where
462 f32: AsPrimitive<T>,
463{
464 let b = b.as_();
465 (round_ties_to_even(b.abs() * a) * b.signum()).as_()
466}
467
/// Numeric cast that first clamps the value into the destination type's range when possible.
///
/// Available on every `PartialOrd + Copy + 'static` type through the blanket impl below. The
/// method additionally requires both types to be [`Datum`]s convertible into each other with
/// `as`-style [`AsPrimitive`] casts.
pub trait ClampCast: PartialOrd + Copy + 'static {
    /// Casts `self` to `O`, clamping to `O::min_value()..=O::max_value()` (expressed in
    /// `Self`) when those bounds still form a valid range after conversion to `Self`.
    #[inline(always)]
    fn clamp_cast<O>(self) -> O
    where
        Self: AsPrimitive<O> + Datum,
        O: AsPrimitive<Self> + num_traits::Bounded + Datum,
    {
        // `O`'s bounds are converted into `Self` with `as` semantics. When `O` is wider than
        // `Self` they can wrap (e.g. `i32::MIN as i8 == 0`, `i32::MAX as i8 == -1`) and stop
        // forming a range. Clamping is pointless there, so a plain cast is used instead.
        // NOTE(review): same-width signedness changes (e.g. u8 -> i8, i32 -> u32) also fail
        // this test, so those values wrap instead of saturating. Confirm this is intended.
        if O::min_value().as_() < O::max_value().as_() {
            num_traits::clamp(self, O::min_value().as_(), O::max_value().as_()).as_()
        } else {
            self.as_()
        }
    }
}
impl<T: PartialOrd + Copy + 'static> ClampCast for T {}
484
/// A Rust element type that has a corresponding runtime [`DatumType`].
///
/// This file implements it, via the `datum!` macro, for `bool`, the primitive integer and
/// float types, `TDim`, `String`, `Blob`, `Opaque` and (with the `complex` feature)
/// `Complex<_>`.
pub trait Datum:
    Clone + Send + Sync + fmt::Debug + fmt::Display + Default + 'static + PartialEq
{
    /// Name of the type; the `datum!` implementations return the Rust type as written.
    fn name() -> &'static str;
    /// Runtime tag for this type.
    fn datum_type() -> DatumType;
    /// Whether `D` is the same datum type as `Self`; `datum!` implements it by comparing
    /// `datum_type()` values.
    fn is<D: Datum>() -> bool;
}
492
/// Registers the Rust type `$t` under the `DatumType::$v` tag.
///
/// Generates `From<$t> for Tensor` (building a rank-0 tensor via `tensor0`) and the
/// [`Datum`] implementation for `$t`.
macro_rules! datum {
    ($t:ty, $v:ident) => {
        impl From<$t> for Tensor {
            fn from(value: $t) -> Tensor {
                tensor0(value)
            }
        }

        impl Datum for $t {
            fn name() -> &'static str {
                stringify!($t)
            }

            fn datum_type() -> DatumType {
                DatumType::$v
            }

            fn is<D: Datum>() -> bool {
                D::datum_type() == Self::datum_type()
            }
        }
    };
}
516
// Map each supported Rust element type to its `DatumType` tag. Each line generates the
// `Datum` impl and a `From<T> for Tensor` conversion through the `datum!` macro above.
datum!(bool, Bool);
datum!(f16, F16);
datum!(f32, F32);
datum!(f64, F64);
datum!(i8, I8);
datum!(i16, I16);
datum!(i32, I32);
datum!(i64, I64);
datum!(u8, U8);
datum!(u16, U16);
datum!(u32, U32);
datum!(u64, U64);
datum!(TDim, TDim);
datum!(String, String);
datum!(crate::blob::Blob, Blob);
datum!(crate::opaque::Opaque, Opaque);
// Complex element types only exist with the `complex` feature.
#[cfg(feature = "complex")]
datum!(Complex<i16>, ComplexI16);
#[cfg(feature = "complex")]
datum!(Complex<i32>, ComplexI32);
#[cfg(feature = "complex")]
datum!(Complex<i64>, ComplexI64);
#[cfg(feature = "complex")]
datum!(Complex<f16>, ComplexF16);
#[cfg(feature = "complex")]
datum!(Complex<f32>, ComplexF32);
#[cfg(feature = "complex")]
datum!(Complex<f64>, ComplexF64);
545
#[cfg(test)]
mod tests {
    use crate::internal::*;
    use ndarray::arr1;

    // An ndarray turned into a Tensor can be viewed back as an equal array.
    #[test]
    fn test_array_to_tensor_to_array() {
        let array = arr1(&[12i32, 42]);
        let tensor = Tensor::from(array.clone());
        let view = tensor.to_array_view::<i32>().unwrap();
        assert_eq!(array, view.into_dimensionality().unwrap());
    }

    // Concrete TDim values survive a TDim -> i32 -> TDim cast round trip.
    #[test]
    fn test_cast_dim_to_dim() {
        let t_dim: Tensor = tensor1(&[12isize.to_dim(), 42isize.to_dim()]);
        let t_i32 = t_dim.cast_to::<i32>().unwrap();
        let t_dim_2 = t_i32.cast_to::<TDim>().unwrap().into_owned();
        assert_eq!(t_dim, t_dim_2);
    }

    // Casting i32 to TDim succeeds (only the absence of an error is checked).
    #[test]
    fn test_cast_i32_to_dim() {
        let t_i32: Tensor = tensor1(&[0i32, 12]);
        t_i32.cast_to::<TDim>().unwrap();
    }

    // Casting i64 to bool succeeds (only the absence of an error is checked).
    #[test]
    fn test_cast_i64_to_bool() {
        let t_i64: Tensor = tensor1(&[0i64]);
        t_i64.cast_to::<bool>().unwrap();
    }

    // The Debug-style text of a quantized type parses back into that type.
    #[test]
    fn test_parse_qu8() {
        assert_eq!(
            "QU8(Z:128 S:0.01)".parse::<DatumType>().unwrap(),
            DatumType::QU8(QParams::ZpScale { zero_point: 128, scale: 0.01 })
        );
    }
}