tract_data/
tensor.rs

1//! `Tensor`, tract main data object of interest.
2use crate::blob::Blob;
3use crate::datum::{round_ties_to_even, scale_by, ClampCast, Datum, DatumType, QParams};
4use crate::dim::TDim;
5use crate::internal::*;
6use crate::opaque::Opaque;
7use crate::TVec;
8use half::f16;
9use itertools::Itertools;
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::Zero;
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance level used by [`Tensor::close_enough`] when comparing two tensors.
///
/// Each variant maps to an `(atol, rtol, outliers)` triple (see `atol_rtol_outliers`), some of
/// which depend on the datum type being compared.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// Zero absolute and relative tolerance, no outliers allowed.
    Exact,
    /// Tight tolerance; this is the default.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Explicit `(atol, rtol, outliers)` where `outliers` is the maximum fraction of elements
    /// allowed to fall outside the tolerance.
    Custom(f32, f32, f32),
}
34
35impl PartialEq for Approximation {
36    fn eq(&self, other: &Self) -> bool {
37        use Approximation::Custom;
38        if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
39            aa == ba && ar == br && bo == ao
40        } else {
41            std::mem::discriminant(self) == std::mem::discriminant(other)
42        }
43    }
44}
45
46impl Eq for Approximation {}
47
48impl From<bool> for Approximation {
49    fn from(b: bool) -> Self {
50        if b {
51            Self::Approximate
52        } else {
53            Self::Exact
54        }
55    }
56}
57
58impl Approximation {
59    fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
60        use Approximation::*;
61        match (self, dt) {
62            (Exact, _) => (0.0, 0.0, 0.0),
63            (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
64            (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
65            (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
66            (Close, _) => (1e-7, 1e-7, 0.0),
67            (Approximate, _) => (1e-4, 5e-4, 0.0),
68            (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
69            (SuperApproximate, _) => (0.1, 0.05, 0.0001),
70            (UltraApproximate, _) => (0.2, 0.1, 0.0005),
71            (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
72        }
73    }
74}
75
/// Tensor is a concrete tensor in tract.
///
/// Elements of type `dt` live in `data`; `shape` and `strides` describe their layout and
/// `len` caches the element count.
#[derive(Eq)]
pub struct Tensor {
    /// Element type of the values stored in `data`.
    dt: DatumType,
    /// Size of each axis.
    shape: TVec<usize>,
    /// Per-axis stride, counted in elements (not bytes).
    strides: TVec<isize>,
    /// Cached element count: product of `shape`, 1 for a rank-0 tensor.
    len: usize,
    /// Backing storage holding the raw element bytes.
    data: Blob,
}
85
// NOTE(review): these impls assert that a `Tensor` may be sent to and shared between threads.
// Presumably needed because `Blob` holds a raw pointer to its allocation; verify that no element
// type a tensor can hold (e.g. `Opaque`) relies on thread-affine state.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
88
/// Hashes the datum type, the shape, the alignment of the backing storage and the element
/// values. Strides are not part of the hash.
///
/// NOTE(review): `PartialEq for Tensor` is not visible here; if it ignores storage alignment,
/// two equal tensors with different alignments hash differently, which would break the
/// `Hash`/`Eq` contract — confirm.
impl Hash for Tensor {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        self.data.layout().align().hash(state);
        // Each arm views the data as a type of the same width as `self.dt`'s elements.
        unsafe {
            match self.dt {
                Bool => self.as_slice_unchecked::<bool>().hash(state),
                I8 => self.as_slice_unchecked::<i8>().hash(state),
                I16 => self.as_slice_unchecked::<i16>().hash(state),
                I32 => self.as_slice_unchecked::<i32>().hash(state),
                I64 => self.as_slice_unchecked::<i64>().hash(state),
                U8 => self.as_slice_unchecked::<u8>().hash(state),
                U16 => self.as_slice_unchecked::<u16>().hash(state),
                U32 => self.as_slice_unchecked::<u32>().hash(state),
                U64 => self.as_slice_unchecked::<u64>().hash(state),
                // Floats do not implement `Hash`: hash their bit patterns through a same-width
                // integer view (so e.g. 0.0 and -0.0 hash differently).
                F16 => self.as_slice_unchecked::<i16>().hash(state),
                F32 => self.as_slice_unchecked::<i32>().hash(state),
                F64 => self.as_slice_unchecked::<i64>().hash(state),
                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                String => self.as_slice_unchecked::<std::string::String>().hash(state),
                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
                // Quantized types hash their stored integers; the quantization parameters are
                // presumably covered by the `self.dt` hash above.
                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                // Complex floats: same bit-pattern trick as the real float types.
                #[cfg(feature = "complex")]
                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
            }
        }
    }
}
132
133impl Clone for Tensor {
134    fn clone(&self) -> Tensor {
135        self.deep_clone()
136    }
137}
138
139impl Default for Tensor {
140    fn default() -> Tensor {
141        litteral::tensor0(0f32)
142    }
143}
144
145impl Drop for Tensor {
146    fn drop(&mut self) {
147        macro_rules! drop_in_place {
148            ($t: ty) => {
149                if self.dt == <$t>::datum_type() {
150                    unsafe {
151                        self.as_slice_mut::<$t>()
152                            .unwrap()
153                            .iter_mut()
154                            .for_each(|s| std::ptr::drop_in_place(s as *mut $t));
155                    }
156                }
157            };
158        }
159        drop_in_place!(Blob);
160        drop_in_place!(String);
161        drop_in_place!(TDim);
162        drop_in_place!(Opaque);
163    }
164}
165
/// Default storage alignment for tensors, in bytes: 64 when AVX-512F is detected at runtime,
/// 32 otherwise on x86_64.
#[cfg(target_arch = "x86_64")]
pub fn vector_size() -> usize {
    if is_x86_feature_detected!("avx512f") {
        512 / 8
    } else {
        256 / 8
    }
}

/// Default storage alignment for tensors, in bytes: 16 (128-bit vectors) outside x86_64.
#[cfg(not(target_arch = "x86_64"))]
pub fn vector_size() -> usize {
    128 / 8
}
174
175impl Tensor {
176    /// Create an uninitialized tensor (dt as type paramater).
177    #[inline]
178    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
179        Self::uninitialized_dt(T::datum_type(), shape)
180    }
181
182    /// Create an uninitialized tensor (dt as regular parameter).
183    #[inline]
184    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
185        Self::uninitialized_aligned_dt(dt, shape, vector_size())
186    }
187
188    /// Create an uninitialized tensor with a given alignment (in bytes).
189    #[inline]
190    pub unsafe fn uninitialized_aligned<T: Datum>(
191        shape: &[usize],
192        alignment: usize,
193    ) -> TractResult<Tensor> {
194        Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment)
195    }
196
197    /// Create an uninitialized tensor with a given alignment (in bytes).
198    pub unsafe fn uninitialized_aligned_dt(
199        dt: DatumType,
200        shape: &[usize],
201        alignment: usize,
202    ) -> TractResult<Tensor> {
203        let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
204        let data = Blob::new_for_size_and_align(bytes, alignment);
205        let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
206        if tensor.shape.len() == 0 {
207            tensor.len = 1;
208        } else {
209            tensor.update_strides_and_len();
210        }
211        if !tensor.data.is_empty() {
212            if dt == String::datum_type() || dt == Blob::datum_type() {
213                // assumes zero-initialized string and blob are valid
214                tensor.data.fill(0);
215            } else if dt == TDim::datum_type() {
216                tensor
217                    .as_slice_mut_unchecked::<TDim>()
218                    .iter_mut()
219                    .for_each(|dim| std::ptr::write(dim, TDim::zero()))
220            } else if dt == Opaque::datum_type() {
221                tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
222                    std::ptr::write(p, Opaque::default());
223                });
224            } else if cfg!(debug_assertions) {
225                assert!(dt.is_copy());
226                if dt == DatumType::F32 {
227                    tensor.fill_t(f32::NAN).unwrap();
228                } else {
229                    // safe, non copy types have been dealt with
230                    tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
231                }
232            }
233        }
234        Ok(tensor)
235    }
236
    /// Concatenates `tensors` along `axis` into a freshly allocated tensor.
    ///
    /// No axis is added: all inputs must share rank and datum type and agree on every
    /// dimension except `axis`. The output size on `axis` is the sum of the inputs' sizes on
    /// that axis.
    ///
    /// # Errors
    /// Fails if `tensors` is empty, `axis` is not below the rank, or ranks, datum types or the
    /// non-`axis` dimensions disagree.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            // Fast path: when every axis before `axis` has size 1, each input occupies one
            // contiguous run of the output, so the inputs are copied back to back as raw bytes.
            // NOTE(review): copies `v.data.len()` bytes per input, which assumes each input's
            // blob holds exactly its elements in natural (contiguous) layout — confirm.
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    let len = v.data.len();
                    std::ptr::copy_nonoverlapping(
                        v.data.as_ptr(),
                        result.data.as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General path: copy each input into its own slice of the output along `axis`.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
281
282    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
283        self.fill_t(T::zero())
284    }
285
286    pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
287        unsafe {
288            let mut t = Tensor::uninitialized::<T>(shape)?;
289            t.clear::<T>()?;
290            Ok(t)
291        }
292    }
293
294    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
295        Tensor::zero::<T>(&[])
296    }
297
298    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
299        Tensor::zero_dt(dt, &[])
300    }
301
302    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
303        Tensor::zero_aligned_dt(dt, shape, vector_size())
304    }
305
306    pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
307        self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
308        Ok(())
309    }
310
311    pub fn zero_aligned_dt(
312        dt: DatumType,
313        shape: &[usize],
314        alignment: usize,
315    ) -> TractResult<Tensor> {
316        if shape.iter().product::<usize>() == 0 {
317            unsafe { return Tensor::uninitialized_dt(dt, shape) };
318        }
319        if dt.is_quantized() {
320            unsafe {
321                let mut t = Tensor::uninitialized_dt(dt, shape)?;
322                let zp = dt.zp_scale().0;
323                match dt.unquantized() {
324                    DatumType::I8 => {
325                        t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
326                    }
327                    DatumType::U8 => {
328                        t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
329                    }
330                    DatumType::I32 => {
331                        t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
332                    }
333                    _ => unreachable!(),
334                }
335                Ok(t)
336            }
337        } else {
338            dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
339        }
340    }
341
342    pub fn zero_aligned<T: Datum + num_traits::Zero>(
343        shape: &[usize],
344        alignment: usize,
345    ) -> TractResult<Tensor> {
346        unsafe {
347            let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
348            tensor.clear::<T>()?;
349            Ok(tensor)
350        }
351    }
352
353    /// Create a tensor with a given shape and a slice of elements.
354    /// The data is copied and aligned to size of T.
355    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
356        Self::from_shape_align(shape, data, vector_size())
357    }
358
359    /// Create a tensor with a given shape and a slice of elements.
360    /// The data is copied and aligned to given alignment.
361    pub fn from_shape_align<T: Datum + Copy>(
362        shape: &[usize],
363        data: &[T],
364        align: usize,
365    ) -> TractResult<Tensor> {
366        ensure!(
367            data.len() == shape.iter().product::<usize>(),
368            "Shape product must be equal to data length"
369        );
370        unsafe {
371            let bytes = std::slice::from_raw_parts(
372                data.as_ptr() as *const u8,
373                data.len() * T::datum_type().size_of(),
374            );
375            let dt = T::datum_type();
376            Self::from_raw_dt_align(dt, shape, bytes, align)
377        }
378    }
379
380    /// Create a tensor from raw data.
381    ///
382    /// It copies the data, aligning it to the size of T.
383    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
384        Tensor::from_raw_dt(T::datum_type(), shape, content)
385    }
386
387    pub unsafe fn from_raw_aligned<T: Datum>(
388        shape: &[usize],
389        content: &[u8],
390        align: usize,
391    ) -> TractResult<Tensor> {
392        Tensor::from_raw_dt_align(T::datum_type(), shape, content, align)
393    }
394
395    pub unsafe fn from_raw_dt(
396        dt: DatumType,
397        shape: &[usize],
398        content: &[u8],
399    ) -> TractResult<Tensor> {
400        Self::from_raw_dt_align(dt, shape, content, vector_size())
401    }
402
403    pub unsafe fn from_raw_dt_align(
404        dt: DatumType,
405        shape: &[usize],
406        content: &[u8],
407        align: usize,
408    ) -> TractResult<Tensor> {
409        let mut tensor = Tensor::uninitialized_aligned_dt(dt, shape, align)?;
410        tensor.as_bytes_mut().copy_from_slice(content);
411        Ok(tensor)
412    }
413
414    pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
415        let bytes = if content.len() == 0 {
416            &[]
417        } else {
418            std::slice::from_raw_parts(
419                content.as_ptr() as *const u8,
420                content.len() * T::datum_type().size_of(),
421            )
422        };
423        Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align)
424    }
425
426    /// Get the number of dimensions (or axes) of the tensor.
427    #[inline]
428    pub fn rank(&self) -> usize {
429        self.shape.len()
430    }
431
432    /// Get the shape of the tensor.
433    #[inline]
434    pub fn shape(&self) -> &[usize] {
435        &self.shape
436    }
437
438    /// Get the number of values in the tensor.
439    #[inline]
440    #[allow(clippy::len_without_is_empty)]
441    pub fn len(&self) -> usize {
442        self.len
443    }
444
445    /// Get the number of valeus in the tensor.
446    #[inline]
447    #[allow(clippy::len_without_is_empty)]
448    pub fn volume(&self) -> usize {
449        self.len
450    }
451
452    /// Get the shape of the tensor.
453    #[inline]
454    pub fn strides(&self) -> &[isize] {
455        &self.strides
456    }
457
458    fn update_strides_and_len(&mut self) {
459        self.strides.clear();
460        if self.shape.len() == 0 {
461            self.len = 1;
462            return;
463        }
464        compute_natural_stride_to(&mut self.strides, &self.shape);
465        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
466    }
467
468    /// Force the tensor shape, no consistency check.
469    pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
470        if shape != &*self.shape {
471            self.shape.clear();
472            self.shape.extend_from_slice(shape);
473            self.update_strides_and_len();
474        }
475    }
476
477    /// Force the tensor shape and strides, no consistency check.
478    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
479        self.shape.clear();
480        self.shape.extend_from_slice(shape);
481        self.strides.clear();
482        self.strides.extend_from_slice(strides);
483    }
484
485    /// Force the tensor shape.
486    pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
487        if self.len() != shape.iter().product::<usize>() {
488            bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
489        }
490        unsafe { self.set_shape_unchecked(shape) }
491        Ok(())
492    }
493
494    pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
495        ensure!(axes.iter().duplicates().next().is_none());
496        ensure!(axes.iter().all(|a| *a < self.rank()));
497        unsafe {
498            #[inline]
499            unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
500                input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor()
501            }
502            let dt = self.datum_type();
503            let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
504            t.set_datum_type(dt);
505            Ok(t)
506        }
507    }
508
509    pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
510        let mut permutation: Vec<usize> = (0..self.rank()).collect();
511        permutation.remove(from);
512        permutation.insert(to, from);
513        self.permute_axes(&permutation)
514    }
515
516    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
517        let removed = self.shape.remove(axis + 1);
518        self.shape[axis] *= removed;
519        self.update_strides_and_len();
520        self
521    }
522
523    pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
524        if self.shape[axis] % outer_dim != 0 {
525            bail!(
526                "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
527                self.shape,
528                axis,
529                outer_dim
530            );
531        }
532        self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
533        self.shape[axis] = outer_dim;
534        self.update_strides_and_len();
535        Ok(self)
536    }
537
538    /// Reshape the tensor to `shape`.
539    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
540        self.set_shape(shape)?;
541        Ok(self)
542    }
543
544    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
545        self.shape.insert(axis, 1);
546        self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
547        Ok(())
548    }
549
550    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
551        ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
552        self.shape.remove(axis);
553        self.strides.remove(axis);
554        Ok(())
555    }
556
557    pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
558        self.broadcast_to_rank(rank)?;
559        self.update_strides_and_len();
560        Ok(self)
561    }
562
563    pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
564        if rank < self.rank() {
565            bail!("Can only broadcast to higher rank")
566        }
567        while self.shape.len() < rank {
568            self.shape.insert(0, 1)
569        }
570        self.update_strides_and_len();
571        Ok(())
572    }
573
574    pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
575        if self.rank() > 0 {
576            bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
577        }
578        unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
579            let value: &T = src.to_scalar_unchecked::<T>();
580            dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone());
581        }
582        unsafe {
583            let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
584            dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
585            Ok(t)
586        }
587    }
588
589    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
590        unsafe {
591            let view = self.to_array_view_unchecked::<T>();
592            let mut output = view
593                .broadcast(shape)
594                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
595                .into_owned()
596                .into_tensor();
597            output.set_datum_type(self.datum_type());
598            Ok(output)
599        }
600    }
601
602    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
603        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
604    }
605
606    pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
607        ensure!(self.rank() == 1);
608        ensure!(shape[axis] == self.len());
609        if !self.datum_type().is_copy() {
610            let mut vec_shape = vec![1; shape.len()];
611            vec_shape[axis] = self.len();
612            return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
613        }
614        unsafe {
615            let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
616            if output.len() == 0 {
617                return Ok(output);
618            }
619            let inner_len = shape[axis + 1..].iter().product::<usize>();
620
621            unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
622            where
623                T: Datum + Copy,
624            {
625                for ix in 0..input.len() {
626                    let value: T = input.as_slice_unchecked()[ix];
627                    output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
628                        .iter_mut()
629                        .for_each(|item| *item = value);
630                }
631            }
632            dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));
633
634            let outer_len = shape[0..axis].iter().product::<usize>();
635            let repeat_bytes_len = inner_len * self.as_bytes().len();
636            let bytes = output.as_bytes_mut();
637            for ix in 1..outer_len {
638                bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
639            }
640
641            Ok(output)
642        }
643    }
644
645    fn clip_range_bounds(
646        &self,
647        axis: usize,
648        range: impl std::ops::RangeBounds<usize>,
649    ) -> Range<usize> {
650        use std::ops::Bound;
651        let start = match range.start_bound() {
652            Bound::Included(ix) => *ix,
653            Bound::Excluded(ix) => ix + 1,
654            Bound::Unbounded => 0,
655        };
656        let end = match range.end_bound() {
657            Bound::Included(ix) => *ix + 1,
658            Bound::Excluded(ix) => *ix,
659            Bound::Unbounded => self.shape()[axis],
660        };
661        start..end
662    }
663
664    pub fn assign_slice(
665        &mut self,
666        range: impl std::ops::RangeBounds<usize>,
667        src: &Tensor,
668        src_range: impl std::ops::RangeBounds<usize>,
669        axis: usize,
670    ) -> TractResult<()> {
671        let range = self.clip_range_bounds(axis, range);
672        let src_range = src.clip_range_bounds(axis, src_range);
673        ensure!(
674            src.datum_type() == self.datum_type(),
675            "Attempt to assign into {:?} from {:?}, datum type mismatch",
676            self.datum_type(),
677            src.datum_type()
678        );
679        ensure!(
680            src_range.len() == range.len(),
681            "Attempt to assign a range of {:?} from a range of {:?}",
682            range,
683            src_range,
684        );
685        ensure!(
686            self.rank() == src.rank()
687                && itertools::izip!(0.., self.shape(), src.shape())
688                    .all(|(ix, dst, src)| ix == axis || src == dst),
689            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
690            axis,
691            self,
692            src
693        );
694        ensure!(
695            src_range.end <= src.shape()[axis],
696            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
697            axis,
698            src_range,
699            src
700        );
701        ensure!(
702            range.end <= self.shape()[axis],
703            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
704            axis,
705            range,
706            self
707        );
708        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
709        Ok(())
710    }
711
712    pub unsafe fn assign_slice_unchecked(
713        &mut self,
714        range: impl std::ops::RangeBounds<usize>,
715        src: &Tensor,
716        src_range: impl std::ops::RangeBounds<usize>,
717        axis: usize,
718    ) {
719        let range = self.clip_range_bounds(axis, range);
720        let src_range = src.clip_range_bounds(axis, src_range);
721        self.assign_slice_from_resolved(range, src, src_range, axis);
722    }
723
    /// Copies `src`'s slice `src_range` along `axis` into `self`'s slice `range`. Both ranges
    /// are already resolved to half-open `Range`s.
    ///
    /// # Safety
    /// Nothing is validated here, so the caller must ensure all of the following (which
    /// `assign_slice` checks):
    /// - both tensors share datum type and rank;
    /// - they agree on every dimension except `axis`;
    /// - both ranges have the same length and fit within their axis.
    // clippy::ptr_eq is allowed because the data pointers are compared with `!=` on purpose.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        use ndarray::Slice;
        // Generic fallback: element-wise assignment through ndarray views.
        unsafe fn assign_slice_t<T: Datum>(
            to: &mut Tensor,
            to_range: Range<usize>,
            from: &Tensor,
            from_range: Range<usize>,
            axis: usize,
        ) {
            to.to_array_view_mut_unchecked::<T>()
                .slice_axis_mut(Axis(axis), Slice::from(to_range))
                .assign(
                    &from
                        .to_array_view_unchecked::<T>()
                        .slice_axis(Axis(axis), Slice::from(from_range)),
                )
        }
        // Fast path for plain-data types when every axis before `axis` has size 1: both slices
        // are then contiguous byte runs starting at `index * stride(axis)`.
        // `self`'s stride is reused for `src`, relying on both tensors having the same
        // dimensions after `axis`. NOTE(review): this also presumably requires natural strides
        // on both sides — confirm.
        if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
            // Byte distance between consecutive indices along `axis`.
            let stride = self.strides[axis] as usize * self.datum_type().size_of();
            let dst_start = (stride * range.start) as isize;
            let src_start = (stride * src_range.start) as isize;
            let len = stride * range.len();
            if len > 0 {
                if self.data.as_ptr() != src.data.as_ptr() {
                    std::ptr::copy_nonoverlapping(
                        src.data.as_ptr().offset(src_start),
                        self.data.as_mut_ptr().offset(dst_start),
                        len,
                    );
                } else {
                    // Same buffer on both sides: the ranges may overlap, so use the
                    // overlap-tolerant `copy`.
                    std::ptr::copy(
                        src.data.as_ptr().offset(src_start),
                        self.data.as_mut_ptr().offset(dst_start),
                        len,
                    );
                }
            }
        } else {
            dispatch_datum!(assign_slice_t(self.datum_type())(self, range, src, src_range, axis));
        }
    }
772
773    /// Get the datum type of the tensor.
774    #[inline]
775    pub fn datum_type(&self) -> DatumType {
776        self.dt
777    }
778
779    /// Set the datum type of the tensor.
780    #[inline]
781    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
782        self.dt = dt
783    }
784
785    /// Dump the tensor in a human readable form.
786    ///
787    /// `force_full` will force the tensor to be dump in full even if it is big.
788    pub fn dump(&self, force_full: bool) -> TractResult<String> {
789        unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
790            if let Some(qp) = tensor.datum_type().qparams() {
791                let integers = tensor.cast_to::<i32>().unwrap();
792                integers.as_slice_unchecked::<i32>()[0..n]
793                    .iter()
794                    .map(|x| format!("[{}]({})", x, qp.dq(*x)))
795                    .join(", ")
796            } else {
797                tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
798            }
799        }
800        unsafe {
801            let trunc = self.len() > 12 && !force_full;
802            let data = dispatch_datum!(dump_t(self.datum_type())(
803                self,
804                if trunc { 12 } else { self.len() }
805            ));
806            Ok(format!(
807                "{},{:?} {}{}",
808                self.shape.iter().join(","),
809                self.dt,
810                data,
811                if trunc { "..." } else { "" }
812            ))
813        }
814    }
815
816    /// Compare two tensors, allowing for rounding errors.
817    pub fn close_enough(
818        &self,
819        other: &Self,
820        approx: impl Into<Approximation> + std::fmt::Debug,
821    ) -> TractResult<()> {
822        let approx = approx.into();
823        if self.shape() != other.shape() {
824            bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
825        }
826        let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
827        let ma = self.cast_to::<f32>()?;
828        let ma = ma.to_array_view::<f32>()?;
829        let mb = other.cast_to::<f32>()?;
830        let mb = mb.to_array_view::<f32>()?;
831        let mut first_outlier = None;
832        let mut outliers_count = 0;
833        ndarray::indices_of(&ma).into_iter().for_each(|indices| {
834            let a = ma[&indices];
835            let b = mb[&indices];
836            if !((a.is_nan() && b.is_nan())
837                || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
838                || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
839            {
840                if outliers_count == 0 {
841                    first_outlier = Some(indices.as_array_view().to_vec());
842                }
843                outliers_count += 1;
844            }
845        });
846        if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
847            let indices = first_outlier.unwrap();
848            let a = ma[&*indices];
849            let b = mb[&*indices];
850            bail!(
851                "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
852                approx,
853                self.datum_type(),
854                indices,
855                a,
856                b,
857                outliers_count,
858                self.volume(),
859                outliers_count as f64 / self.volume() as f64,
860                outliers
861            );
862        }
863        Ok(())
864    }
865
866    /// Transform the tensor into a `ndarray::Array`.
867    pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
868        Ok(self.to_array_view::<D>()?.to_owned())
869    }
870
871    /// Transform the tensor into a `ndarray::Array`.
872    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
873        self.to_array_view_unchecked::<D>().to_owned()
874    }
875
876    fn check_for_access<D: Datum>(&self) -> TractResult<()> {
877        ensure!(
878            self.datum_type().unquantized() == D::datum_type().unquantized(),
879            "Tensor datum type error: tensor is {:?}, accessed as {:?}",
880            self.datum_type(),
881            D::datum_type(),
882        );
883        Ok(())
884    }
885
886    /// Transform the data as a `ndarray::Array`.
887    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<D>> {
888        self.check_for_access::<D>()?;
889        unsafe { Ok(self.to_array_view_unchecked()) }
890    }
891
892    /// Transform the data as a mutable `ndarray::Array`.
893    pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<D>> {
894        self.check_for_access::<D>()?;
895        unsafe { Ok(self.to_array_view_mut_unchecked()) }
896    }
897
898    /// Transform the data as a `ndarray::Array`.
899    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<D> {
900        if self.len() != 0 {
901            ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D)
902        } else {
903            ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
904        }
905    }
906
907    /// Transform the data as a mutable `ndarray::Array`.
908    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<D> {
909        if self.len() != 0 {
910            ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D)
911        } else {
912            ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
913        }
914    }
915
916    /// Access the data as a pointer.
917    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
918        self.check_for_access::<D>()?;
919        Ok(self.data.as_ptr() as *const D)
920    }
921
922    /// Access the data as a pointer.
923    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
924        self.data.as_ptr() as *const D
925    }
926
927    /// Access the data as a pointer.
928    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
929        self.data.as_mut_ptr() as *mut D
930    }
931
932    /// Access the data as a mutable pointer.
933    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
934        self.as_ptr::<D>().map(|p| p as *mut D)
935    }
936
937    /// Access the data as a slice.
938    pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
939        let ptr: *const D = self.as_ptr()?;
940        if self.data.len() == 0 {
941            Ok(&[])
942        } else {
943            unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
944        }
945    }
946
947    /// Access the data as a mutable slice.
948    pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
949        let ptr: *mut D = self.as_ptr_mut()?;
950        if self.data.len() == 0 {
951            Ok(&mut [])
952        } else {
953            unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
954        }
955    }
956
957    /// Access the data as a slice.
958    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
959        if self.data.len() == 0 {
960            &[]
961        } else {
962            std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len())
963        }
964    }
965
966    /// Access the data as a mutable slice.
967    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
968        if self.data.len() == 0 {
969            &mut []
970        } else {
971            std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len())
972        }
973    }
974
975    /// Access the data as a scalar.
976    pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
977        self.check_for_access::<D>()?;
978        if self.len() == 0 {
979            bail!("to_scalar called on empty tensor ({:?})", self)
980        }
981        if self.len() > 1 {
982            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
983        }
984        unsafe { Ok(self.to_scalar_unchecked()) }
985    }
986
987    /// Make the tensor a scalar tensor (assumes it contains a single value).
988    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
989        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
990            Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
991        }
992        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
993    }
994
995    /// Access the data as a scalar.
996    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
997        &*(self.data.as_ptr() as *const D)
998    }
999
1000    /// Mutable access the data as a scalar.
1001    pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1002        self.check_for_access::<D>()?;
1003        if self.len() == 0 {
1004            bail!("to_scalar_mut called on empty tensor ({:?})", self)
1005        }
1006        if self.len() > 1 {
1007            bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1008        }
1009        unsafe { Ok(self.to_scalar_mut_unchecked()) }
1010    }
1011
1012    /// Mutable access the data as a scalar.
1013    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
1014        &mut *(self.data.as_mut_ptr() as *mut D)
1015    }
1016
1017    pub fn as_bytes(&self) -> &[u8] {
1018        self.data.as_bytes()
1019    }
1020
1021    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
1022        self.data.as_bytes_mut()
1023    }
1024
1025    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
1026        let slice = self.as_slice_unchecked::<T>();
1027        slice[1..].iter().all(|x| x == &slice[0])
1028    }
1029
1030    pub fn is_uniform(&self) -> bool {
1031        if self.len() <= 1 {
1032            return true;
1033        }
1034        unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
1035    }
1036
1037    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
1038        let v: T = self.as_slice_unchecked::<T>()[0].clone();
1039        litteral::tensor0(v)
1040    }
1041
1042    pub fn as_uniform(&self) -> Option<Tensor> {
1043        if self.len() >= 1 && self.is_uniform() {
1044            unsafe {
1045                let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
1046                t.set_datum_type(self.datum_type());
1047                Some(t)
1048            }
1049        } else {
1050            None
1051        }
1052    }
1053
1054    pub fn is_all_zero(&self) -> TractResult<bool> {
1055        Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1056    }
1057
1058    pub fn is_zero(&self) -> TractResult<bool> {
1059        Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
1060    }
1061
1062    unsafe fn natural_cast<
1063        Source: Datum + num_traits::AsPrimitive<Target>,
1064        Target: Datum + Copy,
1065    >(
1066        &self,
1067        other: &mut Tensor,
1068    ) {
1069        self.as_slice_unchecked::<Source>()
1070            .iter()
1071            .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1072            .for_each(|(s, d)| *d = s.as_());
1073    }
1074
1075    unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1076        self.as_slice_unchecked::<Source>()
1077            .iter()
1078            .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1079            .for_each(|(s, d)| *d = !s.is_zero());
1080    }
1081
1082    unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
1083        &self,
1084        other: &mut Tensor,
1085    ) -> TractResult<()> {
1086        for (s, d) in self
1087            .as_slice_unchecked::<String>()
1088            .iter()
1089            .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1090        {
1091            *d = s
1092                .parse()
1093                .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
1094        }
1095        Ok(())
1096    }
1097
1098    unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1099        for (s, d) in self
1100            .as_slice_unchecked::<Source>()
1101            .iter()
1102            .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1103        {
1104            *d = s.to_string()
1105        }
1106    }
1107
1108    /// Optionnaly convert data to a tensor for a new DatumType.
1109    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<Tensor>> {
1110        self.cast_to_dt(D::datum_type())
1111    }
1112
1113    /// Optionnaly convert data to a tensor for a new DatumType.
1114    #[allow(clippy::redundant_closure_call)]
1115    pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<Tensor>> {
1116        unsafe {
1117            if self.dt == dst_dt {
1118                return Ok(Cow::Borrowed(self));
1119            }
1120            if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1121                let slice = self.as_slice_unchecked::<TDim>();
1122                let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1123                let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1124                for i in 0..self.len() {
1125                    ints_slice[i] = slice[i].to_i64()?;
1126                }
1127                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1128            }
1129            if self.dt == bool::datum_type()
1130                && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1131            {
1132                let slice = self.as_slice_unchecked::<bool>();
1133                let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1134                let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1135                for i in 0..self.len() {
1136                    ints_slice[i] = slice[i] as usize as i8;
1137                }
1138                return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1139            }
1140            let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1141            if self.dt == DatumType::String {
1142                dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1143                return Ok(Cow::Owned(result));
1144            }
1145            if dst_dt == DatumType::String {
1146                dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1147                return Ok(Cow::Owned(result));
1148            }
1149            macro_rules! n {
1150                ($source:ty) => {
1151                    if <$source>::datum_type() == self.datum_type() {
1152                        match dst_dt {
1153                            DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1154                            DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1155                            DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1156                            DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1157                            DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1158                            DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1159                            DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1160                            DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1161                            DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1162                            DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1163                            DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1164                            DatumType::TDim => {
1165                                let ints = self.cast_to::<i32>()?;
1166                                let slice = ints.as_slice_unchecked::<i32>();
1167                                let result = result.as_slice_mut_unchecked::<TDim>();
1168                                for i in 0..self.len() {
1169                                    result[i] = slice[i].into();
1170                                }
1171                            }
1172                            DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1173                            _ => todo!(),
1174                        }
1175                        return Ok(Cow::Owned(result));
1176                    };
1177                };
1178            }
1179            //If there is no quantization
1180            if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1181                n!(u8);
1182                n!(u16);
1183                n!(u32);
1184                n!(u64);
1185                n!(i8);
1186                n!(i16);
1187                n!(i32);
1188                n!(i64);
1189                n!(f16);
1190                n!(f32);
1191                n!(f64);
1192            } else {
1193                let (s_zp, s_scale) = self.datum_type().zp_scale();
1194                let (d_zp, d_scale) = dst_dt.zp_scale();
1195                if self.datum_type().is_quantized() && dst_dt.is_float() {
1196                    macro_rules! q_to_fp {
1197                        ($source:ty, $dest:ty) => {
1198                            if <$source>::datum_type().unquantized()
1199                                == self.datum_type().unquantized()
1200                                && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1201                            {
1202                                self.as_slice_unchecked::<$source>()
1203                                    .iter()
1204                                    .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1205                                    .for_each(|(&s, d)| {
1206                                        *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1207                                    });
1208                                return Ok(Cow::Owned(result));
1209                            }
1210                        };
1211                    }
1212                    q_to_fp!(i8, f64);
1213                    q_to_fp!(i8, f32);
1214                    q_to_fp!(u8, f64);
1215                    q_to_fp!(u8, f32);
1216                }
1217                //TODO: optimize scale_by
1218                macro_rules! q8_to_q8 {
1219                    ($typ:ty) => {
1220                        if dst_dt.unquantized() == <$typ>::datum_type() {
1221                            self.as_slice_unchecked::<$typ>()
1222                                .iter()
1223                                .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1224                                .for_each(|(&s, d)| {
1225                                    *d = (d_zp as i32
1226                                        + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1227                                    .clamp_cast()
1228                                });
1229                            return Ok(Cow::Owned(result));
1230                        }
1231                    };
1232                }
1233
1234                macro_rules! q_via_f32 {
1235                    ($source:ty, $dest:ty, $round:expr) => {
1236                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1237                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1238                        {
1239                            self.as_slice_unchecked::<$source>()
1240                                .iter()
1241                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1242                                .for_each(|(&s, d)| {
1243                                    let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1244                                    let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1245                                    *d = $round(d_float);
1246                                });
1247                            return Ok(Cow::Owned(result));
1248                        }
1249                    };
1250                }
1251
1252                macro_rules! q_n {
1253                    (clamp $source:ty, $dest:ty) => {{
1254                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1255                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1256                        {
1257                            self.as_slice_unchecked::<$source>()
1258                                .iter()
1259                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1260                                .for_each(|(&s, d)| {
1261                                    *d = s.clamp_cast();
1262                                });
1263                            return Ok(Cow::Owned(result));
1264                        }
1265                    }};
1266                    ($source:ty, $dest:ty) => {{
1267                        if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1268                            && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1269                        {
1270                            self.as_slice_unchecked::<$source>()
1271                                .iter()
1272                                .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1273                                .for_each(|(&s, d)| {
1274                                    *d = s as $dest;
1275                                });
1276                            return Ok(Cow::Owned(result));
1277                        }
1278                    }};
1279                }
1280
1281                if dst_dt.unquantized() == self.datum_type().unquantized()
1282                    && dst_dt.is_quantized()
1283                    && self.datum_type().is_quantized()
1284                {
1285                    q8_to_q8!(i8);
1286                    q8_to_q8!(u8);
1287                }
1288
1289                q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1290                q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1291                q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1292                q_via_f32!(i8, f32, |f| f);
1293                q_via_f32!(u8, f32, |f| f);
1294                q_via_f32!(i32, f32, |f| f);
1295
1296                if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1297                    q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1298                    q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1299                    q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1300                    q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1301                    q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1302                    q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1303
1304                    // ensure cast to different scale offset work
1305                    q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1306                    q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1307                }
1308
1309                q_n!(i8, i32);
1310                q_n!(i8, u32);
1311                q_n!(u8, i32);
1312                q_n!(u8, u32);
1313                q_n!(clamp i32, i8);
1314                q_n!(clamp i32, u8);
1315                q_n!(clamp u32, i8);
1316                q_n!(clamp u32, u8);
1317                q_n!(i8, i8);
1318                q_n!(u8, u8);
1319                q_n!(i32, i32);
1320                q_n!(u32, u32);
1321            }
1322
1323            bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1324        }
1325    }
1326
1327    /// Access the data as a scalar, after a cast.
1328    pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1329        let casted = self.cast_to::<D>()?;
1330        casted.to_scalar::<D>().copied()
1331    }
1332
1333    /// Access the nth element of the tensor, returned as a 0-rank Tensor
1334    pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1335        if nth >= self.len() {
1336            bail!(
1337                "nth called with {}th element on a tensor of len {} ({:?}",
1338                nth,
1339                self.len(),
1340                self
1341            );
1342        }
1343        unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1344            let value = me.as_slice_unchecked::<T>()[nth].clone();
1345            output.as_slice_mut_unchecked::<T>()[0] = value;
1346        }
1347        unsafe {
1348            let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1349            dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1350            Ok(output)
1351        }
1352    }
1353
1354    /// Strict equality test on tensors.
1355    fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1356        unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> bool {
1357            me.as_slice_unchecked::<D>() == other.as_slice_unchecked::<D>()
1358        }
1359
1360        unsafe {
1361            Ok(self.datum_type() == other.datum_type()
1362                && self.shape() == other.shape()
1363                && dispatch_datum!(eq_t(self.dt)(self, other)))
1364        }
1365    }
1366
    /// Build a tensor from an owned `ndarray` array, moving or copying its elements.
    ///
    /// Three strategies are tried in order:
    /// 1. standard-layout array (`as_slice_mut` succeeds): one bytewise copy for `Copy`
    ///    datum types, otherwise each element is moved out with `std::mem::take`;
    /// 2. array contiguous in memory with all strides positive (e.g. a permuted/transposed
    ///    array): its memory is scattered into the tensor with `scatter_contig_data`;
    /// 3. anything else: the array is consumed in logical order with `into_iter`.
    fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
        unsafe {
            // Destination allocated with the array's shape and the tensor's natural strides.
            let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
            if let Some(slice) = it.as_slice_mut() {
                if t.datum_type().is_copy() {
                    // Plain data: copy the whole buffer at once.
                    std::ptr::copy_nonoverlapping(
                        slice.as_ptr() as *const i8,
                        t.as_ptr_mut_unchecked(),
                        t.data.layout().size(),
                    );
                } else {
                    // Non-`Copy` data: move each element out, leaving a default value in
                    // `it`, which is still dropped at the end of the function.
                    t.as_slice_mut_unchecked::<T>()
                        .iter_mut()
                        .zip(slice.iter_mut())
                        .for_each(|(t, s)| *t = std::mem::take(s));
                }
                return t;
            }
            if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
                // (axis length, destination stride) runs, visited in increasing source-stride
                // order (i.e. source memory order). An axis whose destination stride equals
                // the previous run's stride * length continues that run and is merged into it.
                let mut len_and_strides: TVec<(usize, usize)> = tvec!();
                for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                    .sorted_by_key(|(_, src, _)| *src)
                    .map(|(l, _, dst)| (*l as isize, *dst))
                {
                    if !len_and_strides.is_empty()
                        && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                            == stride as usize
                    {
                        len_and_strides.last_mut().unwrap().0 *= len as usize;
                    } else {
                        len_and_strides.push((len as usize, stride as usize));
                    }
                }
                // Outermost run first; presumably the order `scatter_contig_data` expects.
                len_and_strides.reverse();
                // NOTE(review): elements are transferred through raw pointers while `it` is
                // still dropped afterwards. Unless `scatter_contig_data` handles ownership,
                // this path is only sound for `Copy` datum types (a permuted `String` array
                // would be double-dropped) — confirm.
                crate::scatter::scatter_contig_data(
                    it.as_ptr(),
                    t.as_ptr_mut_unchecked(),
                    &len_and_strides,
                );
                return t;
            }
            // finally use ndarray into_iter()
            // (it yields elements in logical row-major order, the tensor's natural layout).
            t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
            t
        }
    }
1413
1414    pub fn deep_clone(&self) -> Tensor {
1415        unsafe {
1416            let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1417            if self.len() > 0 {
1418                if self.dt.is_copy() {
1419                    self.data.as_ptr().copy_to_nonoverlapping(
1420                        tensor.as_bytes_mut().as_mut_ptr(),
1421                        self.data.layout().size(),
1422                    )
1423                } else if self.dt == DatumType::String {
1424                    tensor
1425                        .as_slice_mut_unchecked::<String>()
1426                        .clone_from_slice(self.as_slice_unchecked());
1427                } else if self.dt == DatumType::Blob {
1428                    tensor
1429                        .as_slice_mut_unchecked::<Blob>()
1430                        .clone_from_slice(self.as_slice_unchecked());
1431                } else if self.dt == DatumType::Opaque {
1432                    tensor
1433                        .as_slice_mut_unchecked::<Opaque>()
1434                        .clone_from_slice(self.as_slice_unchecked());
1435                } else if self.dt == DatumType::TDim {
1436                    tensor
1437                        .as_slice_mut_unchecked::<TDim>()
1438                        .clone_from_slice(self.as_slice_unchecked());
1439                }
1440            }
1441            tensor
1442        }
1443    }
1444
1445    pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1446        if axis >= self.rank() {
1447            bail!("Can not slice at axis {} tensor {:?}", axis, self);
1448        }
1449        if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1450            bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1451        }
1452        fn slice_t<T: Datum>(
1453            t: &Tensor,
1454            axis: usize,
1455            start: usize,
1456            end: usize,
1457        ) -> TractResult<Tensor> {
1458            Ok(t.to_array_view::<T>()?
1459                .slice_axis(ndarray::Axis(axis), (start..end).into())
1460                .into_owned()
1461                .into_tensor())
1462        }
1463        dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1464    }
1465
1466    #[inline]
1467    pub fn view(&self) -> view::TensorView {
1468        unsafe { view::TensorView::view(self) }
1469    }
1470
1471    #[inline]
1472    pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView> {
1473        view::TensorView::at_prefix(self, prefix)
1474    }
1475
1476    #[inline]
1477    pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView> {
1478        view::TensorView::offsetting(self, coords)
1479    }
1480
1481    #[inline]
1482    pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView {
1483        view::TensorView::offsetting_unchecked(self, coords)
1484    }
1485
1486    #[inline]
1487    pub fn view_mut(&mut self) -> view::TensorView {
1488        unsafe { view::TensorView::view(self) }
1489    }
1490
1491    #[inline]
1492    pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView> {
1493        view::TensorView::at_prefix(self, prefix)
1494    }
1495
1496    #[inline]
1497    pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView> {
1498        view::TensorView::offsetting(self, coords)
1499    }
1500
1501    /// Offsets the tensor as an i8 type if it's an u8 type, otherwise passes it unchanged.
1502    pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1503        let mut t = if let DatumType::U8 = self.dt.unquantized() {
1504            self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1505        } else {
1506            return self.clone();
1507        };
1508
1509        if let DatumType::QU8(qp) = self.dt {
1510            if let QParams::ZpScale { zero_point, scale } = qp {
1511                t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1512            } else {
1513                t.dt = DatumType::QI8(qp);
1514            }
1515        }
1516
1517        t.into_arc_tensor()
1518    }
1519
1520    /// Offsets the tensor as an u8 type if it's an i8 type, otherwise passes it unchanged.
1521    pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1522        let mut t = if let DatumType::I8 = self.dt.unquantized() {
1523            self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1524        } else {
1525            return self.clone();
1526        };
1527
1528        if let DatumType::QI8(qp) = self.dt {
1529            if let QParams::ZpScale { zero_point, scale } = qp {
1530                t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1531            } else {
1532                t.dt = DatumType::QU8(qp);
1533            }
1534        }
1535        t.into_arc_tensor()
1536    }
1537
1538    pub fn to_aligned_default(&self) -> TractResult<Self> {
1539        if self.dt.is_copy() {
1540            unsafe {
1541                let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1542                t.as_bytes_mut().copy_from_slice(self.as_bytes());
1543                Ok(t)
1544            }
1545        } else {
1546            let mut t = Self::zero_dt(self.dt, &self.shape)?;
1547            if self.dt == String::datum_type() {
1548                t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1549            } else if self.dt == Blob::datum_type() {
1550                t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1551            } else if self.dt == TDim::datum_type() {
1552                t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1553            }
1554            Ok(t)
1555        }
1556    }
1557
1558    pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1559        let mut strides = tvec!();
1560        compute_natural_stride_to(&mut strides, shape);
1561        strides
1562    }
1563
1564    pub fn into_blob(mut self) -> TractResult<Blob> {
1565        ensure!(self.dt.is_copy());
1566        Ok(std::mem::take(&mut self.data))
1567    }
1568}
1569
1570impl PartialEq for Tensor {
1571    fn eq(&self, other: &Tensor) -> bool {
1572        if self.dt != other.dt || self.shape != other.shape {
1573            return false;
1574        }
1575        self.eq_dt(other).unwrap_or(false)
1576    }
1577}
1578
1579impl fmt::Debug for Tensor {
1580    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1581        let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1582        write!(formatter, "{content}")
1583    }
1584}
1585
/// Reinterpret a tensor whose last axis has length 2 as a tensor of complex numbers, dropping
/// that axis. The data is not copied: only shape, datum type and strides are updated.
///
/// # Errors
///
/// Fails when the last axis is not of length 2, or the datum type has no complex counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    let last_is_pair = t.shape().last() == Some(&2);
    ensure!(last_is_pair, "The last dimension in the tensor shape {:?} must be 2", t.shape());
    unsafe {
        t.shape.pop();
        let complex_dt = t.datum_type().complexify()?;
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1600
/// Inverse of `reinterpret_inner_dim_as_complex`: views a complex tensor as a real tensor with
/// an extra innermost axis of length 2 holding the two components of each element.
///
/// No element data is copied; only the shape, datum type, strides and length are updated.
///
/// # Errors
///
/// Fails if the datum type cannot be decomplexified.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    let real_dt = t.datum_type().decomplexify()?;
    // SAFETY: assumes each complex element is stored as two consecutive scalars of `real_dt`,
    // which the new trailing axis of length 2 exposes — TODO confirm for all complex types.
    unsafe {
        t.shape.push(2);
        t.set_datum_type(real_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1610
1611pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1612    let mut strides = tvec!();
1613    compute_natural_stride_to(&mut strides, shape);
1614    strides
1615}
1616
1617fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1618    match shape.len() {
1619        0 => (),
1620        1 => strides.push(1),
1621        2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1622        3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1623        4 => strides.extend_from_slice(&[
1624            (shape[1] * shape[2] * shape[3]) as isize,
1625            (shape[2] * shape[3]) as _,
1626            shape[3] as _,
1627            1,
1628        ]),
1629        _ => {
1630            strides.push(1);
1631            for dim in shape.as_ref().iter().skip(1).rev() {
1632                let previous = *strides.last().unwrap();
1633                strides.push(previous * *dim as isize)
1634            }
1635            strides.reverse();
1636        }
1637    }
1638}
1639
1640impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
1641    fn from(it: Array<T, D>) -> Tensor {
1642        Tensor::from_datum(it.into_dyn())
1643    }
1644}
1645
1646/// Convenient conversion to Tensor.
1647pub trait IntoTensor: Sized {
1648    /// Convert Self to a Tensor.
1649    ///
1650    /// May perform a copy
1651    fn into_tensor(self) -> Tensor;
1652}
1653
1654/// Convenient conversion to Arc<Tensor>.
1655pub trait IntoArcTensor: Sized {
1656    /// Convert Self to a Arc<Tensor>.
1657    ///
1658    /// May perform a copy
1659    fn into_arc_tensor(self) -> Arc<Tensor>;
1660}
1661
1662impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
1663    fn into_tensor(self) -> Tensor {
1664        Tensor::from(self)
1665    }
1666}
1667
1668impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
1669    fn into_arc_tensor(self) -> Arc<Tensor> {
1670        Arc::new(Tensor::from(self))
1671    }
1672}
1673
1674impl IntoTensor for Tensor {
1675    fn into_tensor(self) -> Tensor {
1676        self
1677    }
1678}
1679
1680impl IntoTensor for Arc<Tensor> {
1681    fn into_tensor(self) -> Tensor {
1682        Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
1683    }
1684}
1685
1686impl IntoArcTensor for Tensor {
1687    fn into_arc_tensor(self) -> Arc<Tensor> {
1688        Arc::new(self)
1689    }
1690}
1691
1692impl IntoArcTensor for Arc<Tensor> {
1693    fn into_arc_tensor(self) -> Arc<Tensor> {
1694        self
1695    }
1696}
1697
// Tests for tensor construction from permuted ndarrays, vector broadcasting, complex
// reinterpretation and cloning of symbolic tensors.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    // A shape plus a permutation of its axes. The check is that building a `Tensor` from a
    // permuted ndarray preserves the array's logical element order.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        // Rank in 0..8, every dimension in 1..5, and the permutation is a shuffle of 0..rank.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        // Fills `shape` with successive integers 1, 2, 3, ... and then permutes the axes.
        // ndarray's `permuted_axes` only adjusts dimensions and strides, so the array handed
        // to `Tensor::from` generally does not have standard strides.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        // Expected tensor: the permuted array's elements in iteration (logical) order,
        // reshaped to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        // Tensor under test, built directly from the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed regression cases: 2-D transposes with and without a unit dimension.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    // Checks `broadcast_vector_to_shape(shape, axis)` against the reference path: reshape the
    // vector to a shape of ones with `vec.len()` at `axis`, then `broadcast_to_shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        // Base shape of rank 0..4 with dimensions in 0..5 (zero-sized dims included), a vector
        // of 0..5 floats, and an insertion axis in 0..=rank. The vector length is inserted into
        // the shape at that axis, so `shape[axis] == vec.len()` always holds.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // A 3x2 f32 tensor becomes a 1-D tensor of three complex values, one per row.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Same with a rank-3 i32 tensor: only the last axis (length 2) is folded into complex.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic `TDim` must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}