1use crate::blob::Blob;
3use crate::datum::{round_ties_to_even, scale_by, ClampCast, Datum, DatumType, QParams};
4use crate::dim::TDim;
5use crate::internal::*;
6use crate::opaque::Opaque;
7use crate::TVec;
8use half::f16;
9use itertools::Itertools;
10use ndarray::prelude::*;
11#[cfg(feature = "complex")]
12use num_complex::Complex;
13use num_traits::Zero;
14use std::borrow::Cow;
15use std::fmt;
16use std::hash::Hash;
17use std::ops::Range;
18use std::sync::Arc;
19
20pub mod litteral;
21pub mod view;
22
/// Tolerance level used when comparing tensors for approximate equality.
///
/// Each variant is turned into an absolute tolerance, a relative tolerance and an
/// accepted outlier ratio by `Approximation::atol_rtol_outliers`; the resulting values
/// may depend on the datum type being compared.
#[derive(Copy, Clone, Default, Debug)]
pub enum Approximation {
    /// Zero tolerance, zero outliers.
    Exact,
    /// Default level.
    #[default]
    Close,
    Approximate,
    VeryApproximate,
    SuperApproximate,
    UltraApproximate,
    /// Explicit values, in order: absolute tolerance, relative tolerance, outlier ratio.
    Custom(f32, f32, f32),
}
34
35impl PartialEq for Approximation {
36 fn eq(&self, other: &Self) -> bool {
37 use Approximation::Custom;
38 if let (Custom(aa, ar, ao), Custom(ba, br, bo)) = (self, other) {
39 aa == ba && ar == br && bo == ao
40 } else {
41 std::mem::discriminant(self) == std::mem::discriminant(other)
42 }
43 }
44}
45
// NOTE(review): `Custom` carries `f32` fields that `PartialEq` compares with `==`, so a
// `Custom` holding a NaN is not equal to itself even though `Eq` promises reflexivity.
impl Eq for Approximation {}
47
48impl From<bool> for Approximation {
49 fn from(b: bool) -> Self {
50 if b {
51 Self::Approximate
52 } else {
53 Self::Exact
54 }
55 }
56}
57
58impl Approximation {
59 fn atol_rtol_outliers(&self, dt: &DatumType) -> (f64, f64, f64) {
60 use Approximation::*;
61 match (self, dt) {
62 (Exact, _) => (0.0, 0.0, 0.0),
63 (Close, DatumType::F16) => (1e-3, 1e-3, 0.0),
64 (Approximate, DatumType::F16) => (1e-3, 5e-3, 0.0),
65 (Approximate, qp) if qp.is_quantized() => (qp.zp_scale().1 as f64, 0., 0.0),
66 (Close, _) => (1e-7, 1e-7, 0.0),
67 (Approximate, _) => (1e-4, 5e-4, 0.0),
68 (VeryApproximate, _) => (5e-2, 1e-2, 0.0),
69 (SuperApproximate, _) => (0.1, 0.05, 0.0001),
70 (UltraApproximate, _) => (0.2, 0.1, 0.0005),
71 (Custom(atol, rtol, out), _) => (*atol as _, *rtol as _, *out as _),
72 }
73 }
74}
75
/// A typed, shaped buffer of values.
#[derive(Eq)]
pub struct Tensor {
    /// Type of the elements stored in `data`.
    dt: DatumType,
    /// Size of each axis.
    shape: TVec<usize>,
    /// Stride of each axis, counted in elements (not bytes).
    strides: TVec<isize>,
    /// Number of elements.
    len: usize,
    /// Backing storage.
    data: Blob,
}
85
// NOTE(review): these impls assert that a `Tensor` (its `Blob` storage and whatever element
// types it may hold) can be sent to and shared between threads. Nothing in this part of
// the file proves it — confirm against `Blob` and the supported element types.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}
88
/// Hashes the datum type, the shape, the storage alignment and every element.
///
/// NOTE(review): the alignment of the backing storage takes part in the hash; confirm
/// that the `PartialEq` implementation (not visible here) agrees, otherwise two equal
/// tensors could hash differently.
impl Hash for Tensor {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use DatumType::*;
        self.dt.hash(state);
        self.shape.hash(state);
        self.data.layout().align().hash(state);
        unsafe {
            match self.dt {
                Bool => self.as_slice_unchecked::<bool>().hash(state),
                I8 => self.as_slice_unchecked::<i8>().hash(state),
                I16 => self.as_slice_unchecked::<i16>().hash(state),
                I32 => self.as_slice_unchecked::<i32>().hash(state),
                I64 => self.as_slice_unchecked::<i64>().hash(state),
                U8 => self.as_slice_unchecked::<u8>().hash(state),
                U16 => self.as_slice_unchecked::<u16>().hash(state),
                U32 => self.as_slice_unchecked::<u32>().hash(state),
                U64 => self.as_slice_unchecked::<u64>().hash(state),
                // Floats are hashed through their bit pattern, read as the integer type of
                // the same width.
                F16 => self.as_slice_unchecked::<i16>().hash(state),
                F32 => self.as_slice_unchecked::<i32>().hash(state),
                F64 => self.as_slice_unchecked::<i64>().hash(state),
                TDim => self.as_slice_unchecked::<crate::dim::TDim>().hash(state),
                String => self.as_slice_unchecked::<std::string::String>().hash(state),
                Blob => self.as_slice_unchecked::<crate::blob::Blob>().hash(state),
                Opaque => self.as_slice_unchecked::<crate::opaque::Opaque>().hash(state),
                // Quantized types are hashed as their integer storage type.
                QI8(_) => self.as_slice_unchecked::<i8>().hash(state),
                QU8(_) => self.as_slice_unchecked::<u8>().hash(state),
                QI32(_) => self.as_slice_unchecked::<i32>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexI64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
                // Complex floats: same bit-pattern approach as the real floats above.
                #[cfg(feature = "complex")]
                ComplexF16 => self.as_slice_unchecked::<Complex<i16>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF32 => self.as_slice_unchecked::<Complex<i32>>().hash(state),
                #[cfg(feature = "complex")]
                ComplexF64 => self.as_slice_unchecked::<Complex<i64>>().hash(state),
            }
        }
    }
}
132
/// Cloning a tensor delegates to `Tensor::deep_clone`.
impl Clone for Tensor {
    fn clone(&self) -> Tensor {
        self.deep_clone()
    }
}
138
/// The default tensor is the one built by `litteral::tensor0(0f32)`.
impl Default for Tensor {
    fn default() -> Tensor {
        litteral::tensor0(0f32)
    }
}
144
/// Runs the destructor of every element when the tensor holds `Blob`, `String`, `TDim`
/// or `Opaque` values. Other datum types get no per-element drop here.
impl Drop for Tensor {
    fn drop(&mut self) {
        // Expands to: if the tensor stores `$t`, drop each element in place.
        macro_rules! drop_in_place {
            ($t: ty) => {
                if self.dt == <$t>::datum_type() {
                    unsafe {
                        self.as_slice_mut::<$t>()
                            .unwrap()
                            .iter_mut()
                            .for_each(|s| std::ptr::drop_in_place(s as *mut $t));
                    }
                }
            };
        }
        drop_in_place!(Blob);
        drop_in_place!(String);
        drop_in_place!(TDim);
        drop_in_place!(Opaque);
    }
}
165
/// Size in bytes used as the default alignment for tensor storage.
///
/// On `x86_64` this is 64 when `avx512f` is detected at run time and 32 otherwise; on
/// every other architecture it is 16.
pub fn vector_size() -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            512 / 8
        } else {
            256 / 8
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        128 / 8
    }
}
174
175impl Tensor {
    /// Create a tensor of `T` elements with the given shape, without setting its values.
    ///
    /// Forwards to `uninitialized_dt` with `T::datum_type()`.
    ///
    /// # Safety
    ///
    /// Values of plain-data types are left unspecified: the caller must write every
    /// element before reading it.
    #[inline]
    pub unsafe fn uninitialized<T: Datum>(shape: &[usize]) -> TractResult<Tensor> {
        Self::uninitialized_dt(T::datum_type(), shape)
    }
181
    /// Create a tensor of datum type `dt` with the given shape, without setting its values.
    ///
    /// The storage is aligned on `vector_size()` bytes.
    ///
    /// # Safety
    ///
    /// Values of plain-data types are left unspecified: the caller must write every
    /// element before reading it.
    #[inline]
    pub unsafe fn uninitialized_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
        Self::uninitialized_aligned_dt(dt, shape, vector_size())
    }
187
    /// Create a tensor of `T` elements with the given shape and storage alignment, without
    /// setting its values.
    ///
    /// # Safety
    ///
    /// Values of plain-data types are left unspecified: the caller must write every
    /// element before reading it.
    #[inline]
    pub unsafe fn uninitialized_aligned<T: Datum>(
        shape: &[usize],
        alignment: usize,
    ) -> TractResult<Tensor> {
        Self::uninitialized_aligned_dt(T::datum_type(), shape, alignment)
    }
196
    /// Allocate a tensor of datum type `dt`, shape `shape` and storage alignment
    /// `alignment`, doing only the initialization needed to make the tensor droppable.
    ///
    /// * `String` and `Blob` storage is zero-filled.
    /// * `TDim` elements are set to `TDim::zero()`, `Opaque` elements to `Opaque::default()`.
    /// * Other types are left untouched in release builds; debug builds fill them with a
    ///   recognizable pattern (NaN for `F32`, `0xFF` bytes otherwise).
    ///
    /// # Safety
    ///
    /// For plain-data types the content is unspecified: the caller must write every
    /// element before reading it.
    pub unsafe fn uninitialized_aligned_dt(
        dt: DatumType,
        shape: &[usize],
        alignment: usize,
    ) -> TractResult<Tensor> {
        let bytes = shape.iter().cloned().product::<usize>() * dt.size_of();
        let data = Blob::new_for_size_and_align(bytes, alignment);
        let mut tensor = Tensor { strides: tvec!(), dt, shape: shape.into(), data, len: 0 };
        if tensor.shape.len() == 0 {
            // Rank 0: a single element, no strides.
            tensor.len = 1;
        } else {
            tensor.update_strides_and_len();
        }
        if !tensor.data.is_empty() {
            if dt == String::datum_type() || dt == Blob::datum_type() {
                // NOTE(review): relies on the all-zero bit pattern being acceptable for these
                // types until they are overwritten or dropped — confirm.
                tensor.data.fill(0);
            } else if dt == TDim::datum_type() {
                // `ptr::write` does not drop the (garbage) previous content.
                tensor
                    .as_slice_mut_unchecked::<TDim>()
                    .iter_mut()
                    .for_each(|dim| std::ptr::write(dim, TDim::zero()))
            } else if dt == Opaque::datum_type() {
                tensor.as_slice_mut_unchecked::<Opaque>().iter_mut().for_each(|p| {
                    std::ptr::write(p, Opaque::default());
                });
            } else if cfg!(debug_assertions) {
                // Debug builds only: poison the buffer, presumably so that reads of unwritten
                // values stand out.
                assert!(dt.is_copy());
                if dt == DatumType::F32 {
                    tensor.fill_t(f32::NAN).unwrap();
                } else {
                    tensor.as_bytes_mut().iter_mut().for_each(|x| *x = (-1i8) as u8);
                }
            }
        }
        Ok(tensor)
    }
236
    /// Concatenate `tensors` along the existing axis `axis`.
    ///
    /// # Errors
    ///
    /// Fails when `tensors` is empty, when `axis` is not smaller than the rank, or when
    /// the inputs disagree on rank, datum type or any dimension other than `axis`.
    pub fn stack_tensors(
        axis: usize,
        tensors: &[impl std::borrow::Borrow<Tensor>],
    ) -> TractResult<Tensor> {
        ensure!(tensors.len() > 0);
        let rank = tensors[0].borrow().rank();
        ensure!(axis < rank);
        ensure!(tensors.iter().all(|t| t.borrow().rank() == rank));
        let dt = tensors[0].borrow().datum_type();
        ensure!(tensors.iter().all(|t| t.borrow().datum_type() == dt));
        let mut shape: TVec<usize> = tensors[0].borrow().shape().into();
        for ax in 0..rank {
            if ax != axis {
                ensure!(tensors.iter().all(|t| t.borrow().shape()[ax] == shape[ax]));
            }
        }
        // The output dimension along `axis` is the sum of the input dimensions.
        shape[axis] = tensors.iter().map(|v| v.borrow().shape()[axis]).sum();
        unsafe {
            let mut result = Tensor::uninitialized_dt(dt, &shape)?;
            if dt.is_copy() && shape[..axis].iter().all(|d| *d == 1) {
                // Fast path: every axis before `axis` has size 1, so the output bytes are
                // the input buffers laid end to end.
                let mut offset = 0isize;
                for v in tensors {
                    let v = v.borrow();
                    let len = v.data.len();
                    std::ptr::copy_nonoverlapping(
                        v.data.as_ptr(),
                        result.data.as_mut_ptr().offset(offset),
                        len,
                    );
                    offset += len as isize;
                }
            } else {
                // General path: assign each input into its slot along `axis`.
                let mut offset = 0;
                for t in tensors {
                    let t = t.borrow();
                    let len = t.shape()[axis];
                    result.assign_slice_from_resolved(offset..offset + len, t, 0..len, axis);
                    offset += len;
                }
            }

            Ok(result)
        }
    }
281
    /// Set every element to `T::zero()`.
    ///
    /// Fails under the same conditions as `fill_t`.
    pub fn clear<T: Datum + num_traits::Zero + Clone>(&mut self) -> TractResult<()> {
        self.fill_t(T::zero())
    }
285
286 pub fn zero<T: Datum + num_traits::Zero>(shape: &[usize]) -> TractResult<Tensor> {
287 unsafe {
288 let mut t = Tensor::uninitialized::<T>(shape)?;
289 t.clear::<T>()?;
290 Ok(t)
291 }
292 }
293
    /// Create a rank-0 tensor holding `T::zero()`.
    pub fn zero_scalar<T: Datum + num_traits::Zero>() -> TractResult<Tensor> {
        Tensor::zero::<T>(&[])
    }
297
    /// Create a rank-0 zero tensor of datum type `dt` (see `zero_dt`).
    pub fn zero_scalar_dt(dt: DatumType) -> TractResult<Tensor> {
        Tensor::zero_dt(dt, &[])
    }
301
    /// Create a zero tensor of datum type `dt` and the given shape, with storage aligned
    /// on `vector_size()` bytes (see `zero_aligned_dt`).
    pub fn zero_dt(dt: DatumType, shape: &[usize]) -> TractResult<Tensor> {
        Tensor::zero_aligned_dt(dt, shape, vector_size())
    }
305
306 pub fn fill_t<T: Datum + Clone>(&mut self, value: T) -> TractResult<()> {
307 self.as_slice_mut::<T>()?.iter_mut().for_each(|item| *item = value.clone());
308 Ok(())
309 }
310
311 pub fn zero_aligned_dt(
312 dt: DatumType,
313 shape: &[usize],
314 alignment: usize,
315 ) -> TractResult<Tensor> {
316 if shape.iter().product::<usize>() == 0 {
317 unsafe { return Tensor::uninitialized_dt(dt, shape) };
318 }
319 if dt.is_quantized() {
320 unsafe {
321 let mut t = Tensor::uninitialized_dt(dt, shape)?;
322 let zp = dt.zp_scale().0;
323 match dt.unquantized() {
324 DatumType::I8 => {
325 t.as_slice_mut::<i8>()?.iter_mut().for_each(|item| *item = zp as _)
326 }
327 DatumType::U8 => {
328 t.as_slice_mut::<u8>()?.iter_mut().for_each(|item| *item = zp as _)
329 }
330 DatumType::I32 => {
331 t.as_slice_mut::<i32>()?.iter_mut().for_each(|item| *item = zp as _)
332 }
333 _ => unreachable!(),
334 }
335 Ok(t)
336 }
337 } else {
338 dispatch_zerolike!(Self::zero_aligned(dt)(shape, alignment))
339 }
340 }
341
342 pub fn zero_aligned<T: Datum + num_traits::Zero>(
343 shape: &[usize],
344 alignment: usize,
345 ) -> TractResult<Tensor> {
346 unsafe {
347 let mut tensor = Self::uninitialized_aligned::<T>(shape, alignment)?;
348 tensor.clear::<T>()?;
349 Ok(tensor)
350 }
351 }
352
    /// Create a tensor with the given shape from a slice of values, with storage aligned
    /// on `vector_size()` bytes (see `from_shape_align`).
    pub fn from_shape<T: Datum + Copy>(shape: &[usize], data: &[T]) -> TractResult<Tensor> {
        Self::from_shape_align(shape, data, vector_size())
    }
358
359 pub fn from_shape_align<T: Datum + Copy>(
362 shape: &[usize],
363 data: &[T],
364 align: usize,
365 ) -> TractResult<Tensor> {
366 ensure!(
367 data.len() == shape.iter().product::<usize>(),
368 "Shape product must be equal to data length"
369 );
370 unsafe {
371 let bytes = std::slice::from_raw_parts(
372 data.as_ptr() as *const u8,
373 data.len() * T::datum_type().size_of(),
374 );
375 let dt = T::datum_type();
376 Self::from_raw_dt_align(dt, shape, bytes, align)
377 }
378 }
379
    /// Create a tensor of `T` elements with the given shape from raw bytes (see
    /// `from_raw_dt`).
    ///
    /// # Safety
    ///
    /// `content` is copied verbatim into the tensor storage: it must hold valid values
    /// of `T`.
    pub unsafe fn from_raw<T: Datum>(shape: &[usize], content: &[u8]) -> TractResult<Tensor> {
        Tensor::from_raw_dt(T::datum_type(), shape, content)
    }
386
    /// Create a tensor of `T` elements with the given shape and storage alignment from
    /// raw bytes (see `from_raw_dt_align`).
    ///
    /// # Safety
    ///
    /// `content` is copied verbatim into the tensor storage: it must hold valid values
    /// of `T`.
    pub unsafe fn from_raw_aligned<T: Datum>(
        shape: &[usize],
        content: &[u8],
        align: usize,
    ) -> TractResult<Tensor> {
        Tensor::from_raw_dt_align(T::datum_type(), shape, content, align)
    }
394
    /// Create a tensor of datum type `dt` with the given shape from raw bytes, with
    /// storage aligned on `vector_size()` bytes (see `from_raw_dt_align`).
    ///
    /// # Safety
    ///
    /// `content` is copied verbatim into the tensor storage: it must hold valid values
    /// of datum type `dt`.
    pub unsafe fn from_raw_dt(
        dt: DatumType,
        shape: &[usize],
        content: &[u8],
    ) -> TractResult<Tensor> {
        Self::from_raw_dt_align(dt, shape, content, vector_size())
    }
402
403 pub unsafe fn from_raw_dt_align(
404 dt: DatumType,
405 shape: &[usize],
406 content: &[u8],
407 align: usize,
408 ) -> TractResult<Tensor> {
409 let mut tensor = Tensor::uninitialized_aligned_dt(dt, shape, align)?;
410 tensor.as_bytes_mut().copy_from_slice(content);
411 Ok(tensor)
412 }
413
    /// Create a rank-1 tensor with the given storage alignment by copying the bytes of
    /// `content`.
    ///
    /// # Safety
    ///
    /// The values are duplicated byte for byte (there is no `Copy` bound on `T`).
    /// NOTE(review): presumably only sound for plain-data `T` — confirm with callers.
    pub unsafe fn from_slice_align<T: Datum>(content: &[T], align: usize) -> TractResult<Tensor> {
        // An empty slice is mapped to an empty byte slice without touching its pointer.
        let bytes = if content.len() == 0 {
            &[]
        } else {
            std::slice::from_raw_parts(
                content.as_ptr() as *const u8,
                content.len() * T::datum_type().size_of(),
            )
        };
        Self::from_raw_dt_align(T::datum_type(), &[content.len()], bytes, align)
    }
425
    /// Number of axes of the tensor.
    #[inline]
    pub fn rank(&self) -> usize {
        self.shape.len()
    }
431
    /// Size of each axis.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }
437
    /// Number of elements in the tensor.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.len
    }
444
    /// Number of elements in the tensor (same value as `len`).
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn volume(&self) -> usize {
        self.len
    }
451
    /// Stride of each axis, counted in elements.
    #[inline]
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }
457
    /// Recompute `strides` and `len` from the current `shape`.
    ///
    /// A rank-0 tensor gets no strides and a length of 1.
    fn update_strides_and_len(&mut self) {
        self.strides.clear();
        if self.shape.len() == 0 {
            self.len = 1;
            return;
        }
        compute_natural_stride_to(&mut self.strides, &self.shape);
        // The rank is at least 1 here. The element count is the outermost stride times the
        // outermost dimension. NOTE(review): the unchecked access assumes
        // `compute_natural_stride_to` pushed one stride per axis — confirm.
        self.len = unsafe { *self.strides.get_unchecked(0) as usize * self.shape.get_unchecked(0) };
    }
467
    /// Replace the shape and recompute strides and length, without checking that the new
    /// shape is compatible with the stored data. Does nothing when the shape is unchanged.
    ///
    /// # Safety
    ///
    /// The caller must make sure the backing storage is large enough for `shape`.
    pub unsafe fn set_shape_unchecked(&mut self, shape: &[usize]) {
        if shape != &*self.shape {
            self.shape.clear();
            self.shape.extend_from_slice(shape);
            self.update_strides_and_len();
        }
    }
476
    /// Replace both shape and strides with the given values, without any check.
    ///
    /// NOTE(review): `len` is not recomputed here — confirm callers keep the element
    /// count unchanged.
    ///
    /// # Safety
    ///
    /// The caller must make sure `shape` and `strides` describe a layout that stays within
    /// the backing storage.
    pub unsafe fn set_geometry_unchecked(&mut self, shape: &[usize], strides: &[isize]) {
        self.shape.clear();
        self.shape.extend_from_slice(shape);
        self.strides.clear();
        self.strides.extend_from_slice(strides);
    }
484
485 pub fn set_shape(&mut self, shape: &[usize]) -> TractResult<()> {
487 if self.len() != shape.iter().product::<usize>() {
488 bail!("Invalid reshape {:?} to {:?}", self.shape, shape);
489 }
490 unsafe { self.set_shape_unchecked(shape) }
491 Ok(())
492 }
493
494 pub fn permute_axes(self, axes: &[usize]) -> TractResult<Tensor> {
495 ensure!(axes.iter().duplicates().next().is_none());
496 ensure!(axes.iter().all(|a| *a < self.rank()));
497 unsafe {
498 #[inline]
499 unsafe fn permute<T: Datum>(axes: &[usize], input: Tensor) -> Tensor {
500 input.into_array_unchecked::<T>().permuted_axes(axes).into_tensor()
501 }
502 let dt = self.datum_type();
503 let mut t = dispatch_datum_by_size!(permute(self.datum_type())(axes, self));
504 t.set_datum_type(dt);
505 Ok(t)
506 }
507 }
508
509 pub fn move_axis(self, from: usize, to: usize) -> TractResult<Tensor> {
510 let mut permutation: Vec<usize> = (0..self.rank()).collect();
511 permutation.remove(from);
512 permutation.insert(to, from);
513 self.permute_axes(&permutation)
514 }
515
    /// Merge axis `axis` with the axis that follows it: the resulting axis has the product
    /// of both sizes.
    ///
    /// NOTE(review): there is no bounds check; `axis + 1` must be an existing axis or the
    /// shape accesses will presumably panic.
    pub fn collapse_axis_with_next(mut self, axis: usize) -> Tensor {
        let removed = self.shape.remove(axis + 1);
        self.shape[axis] *= removed;
        self.update_strides_and_len();
        self
    }
522
523 pub fn split_axis(mut self, axis: usize, outer_dim: usize) -> TractResult<Tensor> {
524 if self.shape[axis] % outer_dim != 0 {
525 bail!(
526 "Invalid axis split, shape is {:?}, axis split at {}, outer {}",
527 self.shape,
528 axis,
529 outer_dim
530 );
531 }
532 self.shape.insert(axis + 1, self.shape[axis] / outer_dim);
533 self.shape[axis] = outer_dim;
534 self.update_strides_and_len();
535 Ok(self)
536 }
537
    /// Consume the tensor and return it reshaped; fails under the same conditions as
    /// `set_shape`.
    pub fn into_shape(mut self, shape: &[usize]) -> TractResult<Tensor> {
        self.set_shape(shape)?;
        Ok(self)
    }
543
544 pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
545 self.shape.insert(axis, 1);
546 self.strides.insert(axis, self.strides.get(axis).copied().unwrap_or(1));
547 Ok(())
548 }
549
550 pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
551 ensure!(self.shape[axis] == 1, "Remove a non-1 axis: axis {} in {:?}", axis, self);
552 self.shape.remove(axis);
553 self.strides.remove(axis);
554 Ok(())
555 }
556
557 pub fn broadcast_into_rank(mut self, rank: usize) -> TractResult<Tensor> {
558 self.broadcast_to_rank(rank)?;
559 self.update_strides_and_len();
560 Ok(self)
561 }
562
563 pub fn broadcast_to_rank(&mut self, rank: usize) -> TractResult<()> {
564 if rank < self.rank() {
565 bail!("Can only broadcast to higher rank")
566 }
567 while self.shape.len() < rank {
568 self.shape.insert(0, 1)
569 }
570 self.update_strides_and_len();
571 Ok(())
572 }
573
574 pub fn broadcast_scalar_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
575 if self.rank() > 0 {
576 bail!("broadcast_scalar_to_shape called on {:?}, which is not a salar", self);
577 }
578 unsafe fn make<T: Datum>(src: &Tensor, dst: &mut Tensor) {
579 let value: &T = src.to_scalar_unchecked::<T>();
580 dst.as_slice_mut_unchecked::<T>().iter_mut().for_each(|item| *item = value.clone());
581 }
582 unsafe {
583 let mut t = Tensor::uninitialized_dt(self.datum_type(), shape)?;
584 dispatch_datum_by_size!(make(self.datum_type())(self, &mut t));
585 Ok(t)
586 }
587 }
588
    /// Typed implementation of `broadcast_to_shape`: broadcast through an ndarray view of
    /// `T`, copy the result into a new tensor, then restore the original datum type.
    fn broadcast_to_shape_t<T: Datum>(&self, shape: &[usize]) -> TractResult<Tensor> {
        unsafe {
            let view = self.to_array_view_unchecked::<T>();
            let mut output = view
                .broadcast(shape)
                .with_context(|| format!("Broadcasting {view:?} to {shape:?}"))?
                .into_owned()
                .into_tensor();
            output.set_datum_type(self.datum_type());
            Ok(output)
        }
    }
601
    /// Broadcast the tensor to `shape`, returning a new tensor.
    ///
    /// Fails when the current shape can not be broadcast to `shape`.
    pub fn broadcast_to_shape(&self, shape: &[usize]) -> TractResult<Tensor> {
        dispatch_datum!(Self::broadcast_to_shape_t(self.dt)(self, shape))
    }
605
    /// Broadcast a rank-1 tensor to `shape`, the vector being laid along axis `axis`.
    ///
    /// # Errors
    ///
    /// Fails when `self` is not of rank 1 or when `shape[axis]` differs from its length.
    /// NOTE(review): `shape[axis]` is indexed directly, so an `axis` outside of `shape`
    /// panics rather than returning an error.
    pub fn broadcast_vector_to_shape(&self, shape: &[usize], axis: usize) -> TractResult<Tensor> {
        ensure!(self.rank() == 1);
        ensure!(shape[axis] == self.len());
        if !self.datum_type().is_copy() {
            // Non-copy types take the generic route: reshape to [1, .., len, .., 1] and
            // broadcast.
            let mut vec_shape = vec![1; shape.len()];
            vec_shape[axis] = self.len();
            return self.clone().into_shape(&vec_shape)?.broadcast_to_shape(shape);
        }
        unsafe {
            let mut output = Tensor::uninitialized_dt(self.datum_type(), shape)?;
            if output.len() == 0 {
                return Ok(output);
            }
            // Number of elements covered by one step along `axis`.
            let inner_len = shape[axis + 1..].iter().product::<usize>();

            // Write the first outer block: each input value repeated `inner_len` times.
            unsafe fn splat<T>(input: &Tensor, output: &mut Tensor, inner_len: usize)
            where
                T: Datum + Copy,
            {
                for ix in 0..input.len() {
                    let value: T = input.as_slice_unchecked()[ix];
                    output.as_slice_mut_unchecked::<T>()[ix * inner_len..(ix + 1) * inner_len]
                        .iter_mut()
                        .for_each(|item| *item = value);
                }
            }
            dispatch_copy_by_size!(splat(self.datum_type())(&self, &mut output, inner_len));

            // Replicate that first block, byte for byte, once per remaining outer index.
            let outer_len = shape[0..axis].iter().product::<usize>();
            let repeat_bytes_len = inner_len * self.as_bytes().len();
            let bytes = output.as_bytes_mut();
            for ix in 1..outer_len {
                bytes.copy_within(0..repeat_bytes_len, ix * repeat_bytes_len);
            }

            Ok(output)
        }
    }
644
645 fn clip_range_bounds(
646 &self,
647 axis: usize,
648 range: impl std::ops::RangeBounds<usize>,
649 ) -> Range<usize> {
650 use std::ops::Bound;
651 let start = match range.start_bound() {
652 Bound::Included(ix) => *ix,
653 Bound::Excluded(ix) => ix + 1,
654 Bound::Unbounded => 0,
655 };
656 let end = match range.end_bound() {
657 Bound::Included(ix) => *ix + 1,
658 Bound::Excluded(ix) => *ix,
659 Bound::Unbounded => self.shape()[axis],
660 };
661 start..end
662 }
663
    /// Copy the slice `src_range` of `src` along `axis` into the slice `range` of `self`
    /// along the same axis.
    ///
    /// # Errors
    ///
    /// Fails when the datum types differ, when both ranges do not have the same length,
    /// when ranks or any dimension other than `axis` differ, or when a range ends beyond
    /// the corresponding dimension.
    pub fn assign_slice(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) -> TractResult<()> {
        // NOTE(review): resolving an unbounded range indexes `shape()[axis]` before the
        // rank checks below, so an out-of-range `axis` can panic here.
        let range = self.clip_range_bounds(axis, range);
        let src_range = src.clip_range_bounds(axis, src_range);
        ensure!(
            src.datum_type() == self.datum_type(),
            "Attempt to assign into {:?} from {:?}, datum type mismatch",
            self.datum_type(),
            src.datum_type()
        );
        ensure!(
            src_range.len() == range.len(),
            "Attempt to assign a range of {:?} from a range of {:?}",
            range,
            src_range,
        );
        ensure!(
            self.rank() == src.rank()
                && itertools::izip!(0.., self.shape(), src.shape())
                    .all(|(ix, dst, src)| ix == axis || src == dst),
            "Attempt to assign a {}-axis range of {:?} from a range of {:?}",
            axis,
            self,
            src
        );
        ensure!(
            src_range.end <= src.shape()[axis],
            "Assigning from invalid slice (axis {}, {:?}) of {:?}",
            axis,
            src_range,
            src
        );
        ensure!(
            range.end <= self.shape()[axis],
            "Assigning to invalid slice (axis {}, {:?}) of {:?}",
            axis,
            range,
            self
        );
        unsafe { self.assign_slice_from_resolved(range, src, src_range, axis) };
        Ok(())
    }
711
    /// Same as `assign_slice`, without any of its validity checks.
    ///
    /// # Safety
    ///
    /// The caller must uphold everything `assign_slice` checks: same datum type, ranges of
    /// equal length lying within both tensors, same rank and same dimensions outside of
    /// `axis`.
    pub unsafe fn assign_slice_unchecked(
        &mut self,
        range: impl std::ops::RangeBounds<usize>,
        src: &Tensor,
        src_range: impl std::ops::RangeBounds<usize>,
        axis: usize,
    ) {
        let range = self.clip_range_bounds(axis, range);
        let src_range = src.clip_range_bounds(axis, src_range);
        self.assign_slice_from_resolved(range, src, src_range, axis);
    }
723
    /// Copy `src_range` of `src` into `range` of `self` along `axis`, the ranges being
    /// already resolved to `start..end` form. No check is performed.
    ///
    /// # Safety
    ///
    /// Same contract as `assign_slice_unchecked`.
    #[allow(clippy::ptr_eq)]
    unsafe fn assign_slice_from_resolved(
        &mut self,
        range: std::ops::Range<usize>,
        src: &Tensor,
        src_range: std::ops::Range<usize>,
        axis: usize,
    ) {
        use ndarray::Slice;
        // Generic route: element-wise assignment between ndarray views.
        unsafe fn assign_slice_t<T: Datum>(
            to: &mut Tensor,
            to_range: Range<usize>,
            from: &Tensor,
            from_range: Range<usize>,
            axis: usize,
        ) {
            to.to_array_view_mut_unchecked::<T>()
                .slice_axis_mut(Axis(axis), Slice::from(to_range))
                .assign(
                    &from
                        .to_array_view_unchecked::<T>()
                        .slice_axis(Axis(axis), Slice::from(from_range)),
                )
        }
        if self.datum_type().is_copy() && self.shape[..axis].iter().all(|d| *d == 1) {
            // Fast path: every axis before `axis` has size 1, so each slice is one
            // contiguous run of bytes. `self`'s stride is also used for `src`.
            // NOTE(review): this assumes both tensors share their dimensions after
            // `axis` — `assign_slice` checks it, unchecked callers must guarantee it.
            let stride = self.strides[axis] as usize * self.datum_type().size_of();
            let dst_start = (stride * range.start) as isize;
            let src_start = (stride * src_range.start) as isize;
            let len = stride * range.len();
            if len > 0 {
                if self.data.as_ptr() != src.data.as_ptr() {
                    std::ptr::copy_nonoverlapping(
                        src.data.as_ptr().offset(src_start),
                        self.data.as_mut_ptr().offset(dst_start),
                        len,
                    );
                } else {
                    // Same buffer: the regions may overlap, use the overlap-safe copy.
                    std::ptr::copy(
                        src.data.as_ptr().offset(src_start),
                        self.data.as_mut_ptr().offset(dst_start),
                        len,
                    );
                }
            }
        } else {
            dispatch_datum!(assign_slice_t(self.datum_type())(self, range, src, src_range, axis));
        }
    }
772
    /// Datum type of the elements.
    #[inline]
    pub fn datum_type(&self) -> DatumType {
        self.dt
    }
778
    /// Overwrite the datum type without touching the data.
    ///
    /// # Safety
    ///
    /// The stored bytes are reinterpreted as `dt`: the caller must make sure the existing
    /// content is valid for the new type.
    #[inline]
    pub unsafe fn set_datum_type(&mut self, dt: DatumType) {
        self.dt = dt
    }
784
    /// Build a human-readable description of the tensor: shape, datum type, then values.
    ///
    /// At most 12 values are shown, followed by `...`, unless `force_full` is set.
    /// Quantized values are shown as `[raw](dequantized)`.
    pub fn dump(&self, force_full: bool) -> TractResult<String> {
        // Format the first `n` values of `tensor`, read as `D`.
        unsafe fn dump_t<D: Datum>(tensor: &Tensor, n: usize) -> String {
            if let Some(qp) = tensor.datum_type().qparams() {
                let integers = tensor.cast_to::<i32>().unwrap();
                integers.as_slice_unchecked::<i32>()[0..n]
                    .iter()
                    .map(|x| format!("[{}]({})", x, qp.dq(*x)))
                    .join(", ")
            } else {
                tensor.as_slice_unchecked::<D>()[0..n].iter().join(", ")
            }
        }
        unsafe {
            let trunc = self.len() > 12 && !force_full;
            let data = dispatch_datum!(dump_t(self.datum_type())(
                self,
                if trunc { 12 } else { self.len() }
            ));
            Ok(format!(
                "{},{:?} {}{}",
                self.shape.iter().join(","),
                self.dt,
                data,
                if trunc { "..." } else { "" }
            ))
        }
    }
815
    /// Compare two tensors value by value, within the tolerances given by `approx`.
    ///
    /// Both tensors are cast to `f32`. A pair `(a, b)` matches when both are NaN, when
    /// both are infinite with the same sign, or when `|a - b| <= atol + rtol * |b|`.
    ///
    /// # Errors
    ///
    /// Fails when the shapes differ, when a cast to `f32` fails, or when the ratio of
    /// non-matching pairs exceeds the outlier ratio allowed by `approx`.
    pub fn close_enough(
        &self,
        other: &Self,
        approx: impl Into<Approximation> + std::fmt::Debug,
    ) -> TractResult<()> {
        let approx = approx.into();
        if self.shape() != other.shape() {
            bail!("Shape mismatch {:?} != {:?}", self.shape(), other.shape())
        }
        // The tolerances are picked from `self`'s datum type only.
        let (atol, rtol, outliers) = approx.atol_rtol_outliers(&self.datum_type());
        let ma = self.cast_to::<f32>()?;
        let ma = ma.to_array_view::<f32>()?;
        let mb = other.cast_to::<f32>()?;
        let mb = mb.to_array_view::<f32>()?;
        let mut first_outlier = None;
        let mut outliers_count = 0;
        ndarray::indices_of(&ma).into_iter().for_each(|indices| {
            let a = ma[&indices];
            let b = mb[&indices];
            if !((a.is_nan() && b.is_nan())
                || (a.is_infinite() && b.is_infinite() && a.signum() == b.signum())
                || (a - b).abs() <= atol as f32 + rtol as f32 * b.abs())
            {
                // Remember where the first mismatch is, for the error message.
                if outliers_count == 0 {
                    first_outlier = Some(indices.as_array_view().to_vec());
                }
                outliers_count += 1;
            }
        });
        if self.volume() > 0 && outliers_count as f64 / self.volume() as f64 > outliers {
            // A ratio above `outliers` (which is >= 0 for the built-in levels) implies at
            // least one outlier was recorded.
            let indices = first_outlier.unwrap();
            let a = ma[&*indices];
            let b = mb[&*indices];
            // NOTE(review): the message below has an unbalanced `)` after the datum type.
            bail!(
                "Mismatch. First outlier: {:?} for {:?}) at {:?} {} != {}. Outliers: {} / {} = {:0.5} > {:0.5}.",
                approx,
                self.datum_type(),
                indices,
                a,
                b,
                outliers_count,
                self.volume(),
                outliers_count as f64 / self.volume() as f64,
                outliers
            );
        }
        Ok(())
    }
865
    /// Convert the tensor into an owned ndarray array of `D`.
    ///
    /// Fails when the tensor can not be accessed as `D` (see `to_array_view`).
    pub fn into_array<D: Datum>(self) -> TractResult<ArrayD<D>> {
        Ok(self.to_array_view::<D>()?.to_owned())
    }
870
    /// Convert the tensor into an owned ndarray array of `D`, without checking the datum
    /// type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when read as `D`.
    pub unsafe fn into_array_unchecked<D: Datum>(self) -> ArrayD<D> {
        self.to_array_view_unchecked::<D>().to_owned()
    }
875
876 fn check_for_access<D: Datum>(&self) -> TractResult<()> {
877 ensure!(
878 self.datum_type().unquantized() == D::datum_type().unquantized(),
879 "Tensor datum type error: tensor is {:?}, accessed as {:?}",
880 self.datum_type(),
881 D::datum_type(),
882 );
883 Ok(())
884 }
885
    /// Borrow the tensor as an ndarray view of `D`.
    ///
    /// Fails when the datum type does not match `D` (see `check_for_access`).
    pub fn to_array_view<D: Datum>(&self) -> TractResult<ArrayViewD<D>> {
        self.check_for_access::<D>()?;
        unsafe { Ok(self.to_array_view_unchecked()) }
    }
891
    /// Borrow the tensor as a mutable ndarray view of `D`.
    ///
    /// Fails when the datum type does not match `D` (see `check_for_access`).
    pub fn to_array_view_mut<D: Datum>(&mut self) -> TractResult<ArrayViewMutD<D>> {
        self.check_for_access::<D>()?;
        unsafe { Ok(self.to_array_view_mut_unchecked()) }
    }
897
    /// Borrow the tensor as an ndarray view of `D`, without checking the datum type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when read as `D`.
    pub unsafe fn to_array_view_unchecked<D: Datum>(&self) -> ArrayViewD<D> {
        if self.len() != 0 {
            ArrayViewD::from_shape_ptr(&*self.shape, self.data.as_ptr() as *const D)
        } else {
            // Empty tensor: build the view over an empty slice rather than the data pointer.
            ArrayViewD::from_shape(&*self.shape, &[]).unwrap()
        }
    }
906
    /// Borrow the tensor as a mutable ndarray view of `D`, without checking the datum type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when read as `D`.
    pub unsafe fn to_array_view_mut_unchecked<D: Datum>(&mut self) -> ArrayViewMutD<D> {
        if self.len() != 0 {
            ArrayViewMutD::from_shape_ptr(&*self.shape, self.data.as_mut_ptr() as *mut D)
        } else {
            // Empty tensor: build the view over an empty slice rather than the data pointer.
            ArrayViewMutD::from_shape(&*self.shape, &mut []).unwrap()
        }
    }
915
    /// Pointer to the first element, typed as `D`.
    ///
    /// Fails when the datum type does not match `D` (see `check_for_access`).
    pub fn as_ptr<D: Datum>(&self) -> TractResult<*const D> {
        self.check_for_access::<D>()?;
        Ok(self.data.as_ptr() as *const D)
    }
921
    /// Pointer to the first element, typed as `D`, without checking the datum type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when read as `D`.
    pub unsafe fn as_ptr_unchecked<D: Datum>(&self) -> *const D {
        self.data.as_ptr() as *const D
    }
926
    /// Mutable pointer to the first element, typed as `D`, without checking the datum type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when accessed as `D`.
    pub unsafe fn as_ptr_mut_unchecked<D: Datum>(&mut self) -> *mut D {
        self.data.as_mut_ptr() as *mut D
    }
931
    /// Mutable pointer to the first element, typed as `D`.
    ///
    /// Obtained by casting the pointer returned by `as_ptr`; fails under the same
    /// conditions.
    pub fn as_ptr_mut<D: Datum>(&mut self) -> TractResult<*mut D> {
        self.as_ptr::<D>().map(|p| p as *mut D)
    }
936
937 pub fn as_slice<D: Datum>(&self) -> TractResult<&[D]> {
939 let ptr: *const D = self.as_ptr()?;
940 if self.data.len() == 0 {
941 Ok(&[])
942 } else {
943 unsafe { Ok(std::slice::from_raw_parts::<D>(ptr, self.len())) }
944 }
945 }
946
947 pub fn as_slice_mut<D: Datum>(&mut self) -> TractResult<&mut [D]> {
949 let ptr: *mut D = self.as_ptr_mut()?;
950 if self.data.len() == 0 {
951 Ok(&mut [])
952 } else {
953 unsafe { Ok(std::slice::from_raw_parts_mut::<D>(ptr, self.len())) }
954 }
955 }
956
    /// Borrow the elements as a flat slice of `D`, without checking the datum type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when read as `D`.
    pub unsafe fn as_slice_unchecked<D: Datum>(&self) -> &[D] {
        if self.data.len() == 0 {
            &[]
        } else {
            std::slice::from_raw_parts::<D>(self.as_ptr_unchecked(), self.len())
        }
    }
965
    /// Borrow the elements as a flat mutable slice of `D`, without checking the datum type.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when accessed as `D`.
    pub unsafe fn as_slice_mut_unchecked<D: Datum>(&mut self) -> &mut [D] {
        if self.data.len() == 0 {
            &mut []
        } else {
            std::slice::from_raw_parts_mut::<D>(self.as_ptr_mut_unchecked(), self.len())
        }
    }
974
975 pub fn to_scalar<D: Datum>(&self) -> TractResult<&D> {
977 self.check_for_access::<D>()?;
978 if self.len() == 0 {
979 bail!("to_scalar called on empty tensor ({:?})", self)
980 }
981 if self.len() > 1 {
982 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
983 }
984 unsafe { Ok(self.to_scalar_unchecked()) }
985 }
986
    /// Build a new tensor with `litteral::tensor0` from a clone of the single value of
    /// this tensor.
    ///
    /// Fails under the same conditions as `to_scalar`.
    pub fn to_scalar_tensor(&self) -> TractResult<Tensor> {
        fn to_scalar_tensor_t<D: Datum>(t: &Tensor) -> TractResult<Tensor> {
            Ok(litteral::tensor0(t.to_scalar::<D>()?.clone()))
        }
        dispatch_datum!(to_scalar_tensor_t(self.datum_type())(self))
    }
994
    /// Borrow the first element as `D`, without checking datum type or element count.
    ///
    /// # Safety
    ///
    /// The tensor must hold at least one element, valid when read as `D`.
    pub unsafe fn to_scalar_unchecked<D: Datum>(&self) -> &D {
        &*(self.data.as_ptr() as *const D)
    }
999
1000 pub fn to_scalar_mut<D: Datum>(&mut self) -> TractResult<&mut D> {
1002 self.check_for_access::<D>()?;
1003 if self.len() == 0 {
1004 bail!("to_scalar_mut called on empty tensor ({:?})", self)
1005 }
1006 if self.len() > 1 {
1007 bail!("to_scalar called on a tensor with multiple values ({:?})", self)
1008 }
1009 unsafe { Ok(self.to_scalar_mut_unchecked()) }
1010 }
1011
    /// Mutably borrow the first element as `D`, without checking datum type or element
    /// count.
    ///
    /// # Safety
    ///
    /// The tensor must hold at least one element, valid when accessed as `D`.
    pub unsafe fn to_scalar_mut_unchecked<D: Datum>(&mut self) -> &mut D {
        &mut *(self.data.as_mut_ptr() as *mut D)
    }
1016
    /// Raw bytes of the backing storage.
    pub fn as_bytes(&self) -> &[u8] {
        self.data.as_bytes()
    }
1020
    /// Raw bytes of the backing storage, mutable.
    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
        self.data.as_bytes_mut()
    }
1024
    /// Typed implementation of `is_uniform`: true when every element equals the first one.
    ///
    /// # Safety
    ///
    /// The stored data must be valid when read as `T`. The slicing below panics on an
    /// empty tensor; `is_uniform` only calls this with at least two elements.
    unsafe fn is_uniform_t<T: Datum>(&self) -> bool {
        let slice = self.as_slice_unchecked::<T>();
        slice[1..].iter().all(|x| x == &slice[0])
    }
1029
    /// True when all elements are equal. Tensors with zero or one element are uniform.
    pub fn is_uniform(&self) -> bool {
        if self.len() <= 1 {
            return true;
        }
        unsafe { dispatch_datum!(Tensor::is_uniform_t(self.datum_type())(self)) }
    }
1036
    /// Typed implementation of `as_uniform`: build a tensor with `litteral::tensor0` from
    /// a clone of the first element.
    ///
    /// # Safety
    ///
    /// The tensor must hold at least one element, valid when read as `T`.
    unsafe fn as_uniform_t<T: Datum>(&self) -> Tensor {
        let v: T = self.as_slice_unchecked::<T>()[0].clone();
        litteral::tensor0(v)
    }
1041
    /// When the tensor is not empty and all its elements are equal, return a tensor
    /// holding that single value, with the same datum type; `None` otherwise.
    pub fn as_uniform(&self) -> Option<Tensor> {
        if self.len() >= 1 && self.is_uniform() {
            unsafe {
                let mut t = dispatch_datum!(Tensor::as_uniform_t(self.datum_type())(self));
                // Restore the exact datum type on the freshly built tensor.
                t.set_datum_type(self.datum_type());
                Some(t)
            }
        } else {
            None
        }
    }
1053
1054 pub fn is_all_zero(&self) -> TractResult<bool> {
1055 Ok(self.len() == 0 || self.as_uniform().map(|t| t.is_zero().unwrap()).unwrap_or(false))
1056 }
1057
    /// True when the tensor is equal to the rank-0 zero tensor of its own datum type.
    ///
    /// Fails when that zero tensor can not be built.
    pub fn is_zero(&self) -> TractResult<bool> {
        Ok(self == &Tensor::zero_scalar_dt(self.dt)?)
    }
1061
1062 unsafe fn natural_cast<
1063 Source: Datum + num_traits::AsPrimitive<Target>,
1064 Target: Datum + Copy,
1065 >(
1066 &self,
1067 other: &mut Tensor,
1068 ) {
1069 self.as_slice_unchecked::<Source>()
1070 .iter()
1071 .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
1072 .for_each(|(s, d)| *d = s.as_());
1073 }
1074
1075 unsafe fn cast_number_to_bool<Source: Datum + num_traits::Zero>(&self, other: &mut Tensor) {
1076 self.as_slice_unchecked::<Source>()
1077 .iter()
1078 .zip(other.as_slice_mut_unchecked::<bool>().iter_mut())
1079 .for_each(|(s, d)| *d = !s.is_zero());
1080 }
1081
    /// Parse every `String` element of `self` as `Target` and store it at the same
    /// position in `other`.
    ///
    /// # Errors
    ///
    /// Stops at the first element that does not parse; elements before it have already
    /// been written to `other`.
    ///
    /// # Safety
    ///
    /// `self` must hold `String` values and `other` must hold `Target` values.
    unsafe fn cast_from_string<Target: Datum + core::str::FromStr>(
        &self,
        other: &mut Tensor,
    ) -> TractResult<()> {
        for (s, d) in self
            .as_slice_unchecked::<String>()
            .iter()
            .zip(other.as_slice_mut_unchecked::<Target>().iter_mut())
        {
            *d = s
                .parse()
                .map_err(|_| format_err!("Can not parse as {:?}", Target::datum_type()))?;
        }
        Ok(())
    }
1097
1098 unsafe fn cast_to_string<Source: Datum>(&self, other: &mut Tensor) {
1099 for (s, d) in self
1100 .as_slice_unchecked::<Source>()
1101 .iter()
1102 .zip(other.as_slice_mut_unchecked::<String>().iter_mut())
1103 {
1104 *d = s.to_string()
1105 }
1106 }
1107
    /// Cast the tensor to the datum type of `D` (see `cast_to_dt`).
    pub fn cast_to<D: Datum>(&self) -> TractResult<Cow<Tensor>> {
        self.cast_to_dt(D::datum_type())
    }
1112
1113 #[allow(clippy::redundant_closure_call)]
1115 pub fn cast_to_dt(&self, dst_dt: DatumType) -> TractResult<Cow<Tensor>> {
1116 unsafe {
1117 if self.dt == dst_dt {
1118 return Ok(Cow::Borrowed(self));
1119 }
1120 if self.dt == TDim::datum_type() && (dst_dt.is_integer() || dst_dt.is_float()) {
1121 let slice = self.as_slice_unchecked::<TDim>();
1122 let mut ints = Self::uninitialized::<i64>(&self.shape)?;
1123 let ints_slice = ints.as_slice_mut_unchecked::<i64>();
1124 for i in 0..self.len() {
1125 ints_slice[i] = slice[i].to_i64()?;
1126 }
1127 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1128 }
1129 if self.dt == bool::datum_type()
1130 && (dst_dt.is_integer() || dst_dt.is_float() || dst_dt == TDim::datum_type())
1131 {
1132 let slice = self.as_slice_unchecked::<bool>();
1133 let mut ints = Self::uninitialized::<i8>(&self.shape)?;
1134 let ints_slice = ints.as_slice_mut_unchecked::<i8>();
1135 for i in 0..self.len() {
1136 ints_slice[i] = slice[i] as usize as i8;
1137 }
1138 return Ok(Cow::Owned(ints.cast_to_dt(dst_dt)?.into_owned()));
1139 }
1140 let mut result = Self::uninitialized_dt(dst_dt, &self.shape)?;
1141 if self.dt == DatumType::String {
1142 dispatch_numbers!(Self::cast_from_string(dst_dt)(self, &mut result))?;
1143 return Ok(Cow::Owned(result));
1144 }
1145 if dst_dt == DatumType::String {
1146 dispatch_datum!(Self::cast_to_string(self.dt)(self, &mut result));
1147 return Ok(Cow::Owned(result));
1148 }
1149 macro_rules! n {
1150 ($source:ty) => {
1151 if <$source>::datum_type() == self.datum_type() {
1152 match dst_dt {
1153 DatumType::I8 => self.natural_cast::<$source, i8>(&mut result),
1154 DatumType::I16 => self.natural_cast::<$source, i16>(&mut result),
1155 DatumType::I32 => self.natural_cast::<$source, i32>(&mut result),
1156 DatumType::I64 => self.natural_cast::<$source, i64>(&mut result),
1157 DatumType::U8 => self.natural_cast::<$source, u8>(&mut result),
1158 DatumType::U16 => self.natural_cast::<$source, u16>(&mut result),
1159 DatumType::U32 => self.natural_cast::<$source, u32>(&mut result),
1160 DatumType::U64 => self.natural_cast::<$source, u64>(&mut result),
1161 DatumType::F16 => self.natural_cast::<$source, f16>(&mut result),
1162 DatumType::F32 => self.natural_cast::<$source, f32>(&mut result),
1163 DatumType::F64 => self.natural_cast::<$source, f64>(&mut result),
1164 DatumType::TDim => {
1165 let ints = self.cast_to::<i32>()?;
1166 let slice = ints.as_slice_unchecked::<i32>();
1167 let result = result.as_slice_mut_unchecked::<TDim>();
1168 for i in 0..self.len() {
1169 result[i] = slice[i].into();
1170 }
1171 }
1172 DatumType::Bool => self.cast_number_to_bool::<$source>(&mut result),
1173 _ => todo!(),
1174 }
1175 return Ok(Cow::Owned(result));
1176 };
1177 };
1178 }
1179 if !dst_dt.is_quantized() && !self.datum_type().is_quantized() {
1181 n!(u8);
1182 n!(u16);
1183 n!(u32);
1184 n!(u64);
1185 n!(i8);
1186 n!(i16);
1187 n!(i32);
1188 n!(i64);
1189 n!(f16);
1190 n!(f32);
1191 n!(f64);
1192 } else {
1193 let (s_zp, s_scale) = self.datum_type().zp_scale();
1194 let (d_zp, d_scale) = dst_dt.zp_scale();
1195 if self.datum_type().is_quantized() && dst_dt.is_float() {
1196 macro_rules! q_to_fp {
1197 ($source:ty, $dest:ty) => {
1198 if <$source>::datum_type().unquantized()
1199 == self.datum_type().unquantized()
1200 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1201 {
1202 self.as_slice_unchecked::<$source>()
1203 .iter()
1204 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1205 .for_each(|(&s, d)| {
1206 *d = (s as $dest - s_zp as $dest) * s_scale as $dest;
1207 });
1208 return Ok(Cow::Owned(result));
1209 }
1210 };
1211 }
1212 q_to_fp!(i8, f64);
1213 q_to_fp!(i8, f32);
1214 q_to_fp!(u8, f64);
1215 q_to_fp!(u8, f32);
1216 }
1217 macro_rules! q8_to_q8 {
1219 ($typ:ty) => {
1220 if dst_dt.unquantized() == <$typ>::datum_type() {
1221 self.as_slice_unchecked::<$typ>()
1222 .iter()
1223 .zip(result.as_slice_mut_unchecked::<$typ>().iter_mut())
1224 .for_each(|(&s, d)| {
1225 *d = (d_zp as i32
1226 + scale_by(s as i32 - s_zp as i32, s_scale / d_scale))
1227 .clamp_cast()
1228 });
1229 return Ok(Cow::Owned(result));
1230 }
1231 };
1232 }
1233
1234 macro_rules! q_via_f32 {
1235 ($source:ty, $dest:ty, $round:expr) => {
1236 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1237 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1238 {
1239 self.as_slice_unchecked::<$source>()
1240 .iter()
1241 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1242 .for_each(|(&s, d)| {
1243 let s_float = (s as f32 - s_zp as f32) * s_scale as f32;
1244 let d_float = s_float as f32 / d_scale as f32 + d_zp as f32;
1245 *d = $round(d_float);
1246 });
1247 return Ok(Cow::Owned(result));
1248 }
1249 };
1250 }
1251
1252 macro_rules! q_n {
1253 (clamp $source:ty, $dest:ty) => {{
1254 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1255 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1256 {
1257 self.as_slice_unchecked::<$source>()
1258 .iter()
1259 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1260 .for_each(|(&s, d)| {
1261 *d = s.clamp_cast();
1262 });
1263 return Ok(Cow::Owned(result));
1264 }
1265 }};
1266 ($source:ty, $dest:ty) => {{
1267 if <$source>::datum_type().unquantized() == self.datum_type().unquantized()
1268 && <$dest>::datum_type().unquantized() == dst_dt.unquantized()
1269 {
1270 self.as_slice_unchecked::<$source>()
1271 .iter()
1272 .zip(result.as_slice_mut_unchecked::<$dest>().iter_mut())
1273 .for_each(|(&s, d)| {
1274 *d = s as $dest;
1275 });
1276 return Ok(Cow::Owned(result));
1277 }
1278 }};
1279 }
1280
1281 if dst_dt.unquantized() == self.datum_type().unquantized()
1282 && dst_dt.is_quantized()
1283 && self.datum_type().is_quantized()
1284 {
1285 q8_to_q8!(i8);
1286 q8_to_q8!(u8);
1287 }
1288
1289 q_via_f32!(f32, i8, |f| round_ties_to_even(f).clamp_cast());
1290 q_via_f32!(f32, u8, |f| round_ties_to_even(f).clamp_cast());
1291 q_via_f32!(f32, i32, |f| round_ties_to_even(f).clamp_cast());
1292 q_via_f32!(i8, f32, |f| f);
1293 q_via_f32!(u8, f32, |f| f);
1294 q_via_f32!(i32, f32, |f| f);
1295
1296 if dst_dt.is_quantized() && self.datum_type().is_quantized() {
1297 q_via_f32!(u8, i8, |f| round_ties_to_even(f).clamp_cast());
1298 q_via_f32!(i8, u8, |f| round_ties_to_even(f).clamp_cast());
1299 q_via_f32!(i32, u8, |f| round_ties_to_even(f).clamp_cast());
1300 q_via_f32!(i32, i8, |f| round_ties_to_even(f).clamp_cast());
1301 q_via_f32!(u8, i32, |f| round_ties_to_even(f).clamp_cast());
1302 q_via_f32!(i8, i32, |f| round_ties_to_even(f).clamp_cast());
1303
1304 q_via_f32!(i8, i8, |f| round_ties_to_even(f).clamp_cast());
1306 q_via_f32!(u8, u8, |f| round_ties_to_even(f).clamp_cast());
1307 }
1308
1309 q_n!(i8, i32);
1310 q_n!(i8, u32);
1311 q_n!(u8, i32);
1312 q_n!(u8, u32);
1313 q_n!(clamp i32, i8);
1314 q_n!(clamp i32, u8);
1315 q_n!(clamp u32, i8);
1316 q_n!(clamp u32, u8);
1317 q_n!(i8, i8);
1318 q_n!(u8, u8);
1319 q_n!(i32, i32);
1320 q_n!(u32, u32);
1321 }
1322
1323 bail!("Unsupported cast from {:?} to {:?}", self.dt, dst_dt)
1324 }
1325 }
1326
1327 pub fn cast_to_scalar<D: Datum + Copy>(&self) -> TractResult<D> {
1329 let casted = self.cast_to::<D>()?;
1330 casted.to_scalar::<D>().copied()
1331 }
1332
1333 pub fn nth(&self, nth: usize) -> TractResult<Tensor> {
1335 if nth >= self.len() {
1336 bail!(
1337 "nth called with {}th element on a tensor of len {} ({:?}",
1338 nth,
1339 self.len(),
1340 self
1341 );
1342 }
1343 unsafe fn nth_t<T: Datum>(me: &Tensor, nth: usize, output: &mut Tensor) {
1344 let value = me.as_slice_unchecked::<T>()[nth].clone();
1345 output.as_slice_mut_unchecked::<T>()[0] = value;
1346 }
1347 unsafe {
1348 let mut output = Tensor::uninitialized_dt(self.datum_type(), &[])?;
1349 dispatch_datum_by_size!(nth_t(self.datum_type())(self, nth, &mut output));
1350 Ok(output)
1351 }
1352 }
1353
1354 fn eq_dt(&self, other: &Tensor) -> TractResult<bool> {
1356 unsafe fn eq_t<D: Datum>(me: &Tensor, other: &Tensor) -> bool {
1357 me.as_slice_unchecked::<D>() == other.as_slice_unchecked::<D>()
1358 }
1359
1360 unsafe {
1361 Ok(self.datum_type() == other.datum_type()
1362 && self.shape() == other.shape()
1363 && dispatch_datum!(eq_t(self.dt)(self, other)))
1364 }
1365 }
1366
/// Builds a tensor from an owned dynamic-rank ndarray, whatever its memory
/// layout.
///
/// Three strategies are tried in order:
/// 1. the array exposes a plain slice (`as_slice_mut()` is `Some`): the data
///    is copied byte-wise for copy datum types, or moved out element by
///    element otherwise;
/// 2. all strides are positive and the array is contiguous in memory order:
///    the data is rearranged with `scatter_contig_data`;
/// 3. fallback: elements are moved one by one, in the array's iteration order.
///
/// Panics if the destination tensor can not be allocated (`unwrap`).
fn from_datum<T: Datum>(mut it: ArrayD<T>) -> Tensor {
    unsafe {
        let mut t = Self::uninitialized::<T>(it.shape()).unwrap();
        if let Some(slice) = it.as_slice_mut() {
            if t.datum_type().is_copy() {
                // Raw byte copy; the length is the size of the tensor's own buffer.
                std::ptr::copy_nonoverlapping(
                    slice.as_ptr() as *const i8,
                    t.as_ptr_mut_unchecked(),
                    t.data.layout().size(),
                );
            } else {
                // Non-copy elements are moved out of the array, leaving
                // default values behind.
                t.as_slice_mut_unchecked::<T>()
                    .iter_mut()
                    .zip(slice.iter_mut())
                    .for_each(|(t, s)| *t = std::mem::take(s));
            }
            return t;
        }
        if it.strides().iter().all(|&s| s > 0) && it.as_slice_memory_order().is_some() {
            // Walk the axes by increasing source stride, pairing each axis
            // length with the matching destination stride. Consecutive axes
            // are merged when the destination stride of the next one equals
            // the previous destination stride times the previous length.
            let mut len_and_strides: TVec<(usize, usize)> = tvec!();
            for (len, stride) in itertools::izip!(it.shape(), it.strides(), t.strides())
                .sorted_by_key(|(_, src, _)| *src)
                .map(|(l, _, dst)| (*l as isize, *dst))
            {
                if !len_and_strides.is_empty()
                    && len_and_strides.last().unwrap().1 * len_and_strides.last().unwrap().0
                        == stride as usize
                {
                    len_and_strides.last_mut().unwrap().0 *= len as usize;
                } else {
                    len_and_strides.push((len as usize, stride as usize));
                }
            }
            // NOTE(review): presumably `scatter_contig_data` expects the
            // largest source stride first — confirm in `crate::scatter`.
            len_and_strides.reverse();
            crate::scatter::scatter_contig_data(
                it.as_ptr(),
                t.as_ptr_mut_unchecked(),
                &len_and_strides,
            );
            return t;
        }
        // Generic fallback: consume the array in its iteration order.
        t.as_slice_mut_unchecked().iter_mut().zip(it).for_each(|(t, a)| *t = a);
        t
    }
}
1413
1414 pub fn deep_clone(&self) -> Tensor {
1415 unsafe {
1416 let mut tensor = Tensor::uninitialized_dt(self.datum_type(), self.shape()).unwrap();
1417 if self.len() > 0 {
1418 if self.dt.is_copy() {
1419 self.data.as_ptr().copy_to_nonoverlapping(
1420 tensor.as_bytes_mut().as_mut_ptr(),
1421 self.data.layout().size(),
1422 )
1423 } else if self.dt == DatumType::String {
1424 tensor
1425 .as_slice_mut_unchecked::<String>()
1426 .clone_from_slice(self.as_slice_unchecked());
1427 } else if self.dt == DatumType::Blob {
1428 tensor
1429 .as_slice_mut_unchecked::<Blob>()
1430 .clone_from_slice(self.as_slice_unchecked());
1431 } else if self.dt == DatumType::Opaque {
1432 tensor
1433 .as_slice_mut_unchecked::<Opaque>()
1434 .clone_from_slice(self.as_slice_unchecked());
1435 } else if self.dt == DatumType::TDim {
1436 tensor
1437 .as_slice_mut_unchecked::<TDim>()
1438 .clone_from_slice(self.as_slice_unchecked());
1439 }
1440 }
1441 tensor
1442 }
1443 }
1444
1445 pub fn slice(&self, axis: usize, start: usize, end: usize) -> TractResult<Tensor> {
1446 if axis >= self.rank() {
1447 bail!("Can not slice at axis {} tensor {:?}", axis, self);
1448 }
1449 if start > self.shape[axis] || end > self.shape[axis] || start >= end {
1450 bail!("Invalid slicing range {start}..{end} on axis {axis} for {self:?}");
1451 }
1452 fn slice_t<T: Datum>(
1453 t: &Tensor,
1454 axis: usize,
1455 start: usize,
1456 end: usize,
1457 ) -> TractResult<Tensor> {
1458 Ok(t.to_array_view::<T>()?
1459 .slice_axis(ndarray::Axis(axis), (start..end).into())
1460 .into_owned()
1461 .into_tensor())
1462 }
1463 dispatch_datum!(slice_t(self.datum_type())(self, axis, start, end))
1464 }
1465
/// Returns a `TensorView` over this tensor, built with `TensorView::view`.
#[inline]
pub fn view(&self) -> view::TensorView {
    // NOTE(review): `TensorView::view` is unsafe; its contract lives in the
    // `view` module and is not visible from here.
    unsafe { view::TensorView::view(self) }
}
1470
/// Returns a `TensorView` built by `TensorView::at_prefix` for `prefix`.
///
/// Presumably the view covers the sub-tensor found at the given leading
/// coordinates — confirm in the `view` module. Errors are those reported by
/// `TensorView::at_prefix`.
#[inline]
pub fn view_at_prefix(&self, prefix: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::at_prefix(self, prefix)
}
1475
/// Returns a `TensorView` built by `TensorView::offsetting` for `coords`.
///
/// Errors are those reported by `TensorView::offsetting`.
#[inline]
pub fn view_offsetting(&self, coords: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::offsetting(self, coords)
}
1480
/// Unchecked counterpart of `view_offsetting`: forwards to
/// `TensorView::offsetting_unchecked` and returns the view directly.
///
/// # Safety
/// The caller must uphold the contract of
/// `TensorView::offsetting_unchecked` (presumably `coords` must be valid
/// for this tensor — confirm in the `view` module).
#[inline]
pub unsafe fn view_offsetting_unchecked(&self, coords: &[usize]) -> view::TensorView {
    view::TensorView::offsetting_unchecked(self, coords)
}
1485
/// Same as `view`, but requires exclusive access to the tensor.
#[inline]
pub fn view_mut(&mut self) -> view::TensorView {
    // NOTE(review): `TensorView::view` is unsafe; its contract lives in the
    // `view` module and is not visible from here.
    unsafe { view::TensorView::view(self) }
}
1490
/// Same as `view_at_prefix`, but requires exclusive access to the tensor.
///
/// Errors are those reported by `TensorView::at_prefix`.
#[inline]
pub fn view_at_prefix_mut(&mut self, prefix: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::at_prefix(self, prefix)
}
1495
/// Same as `view_offsetting`, but requires exclusive access to the tensor.
///
/// Errors are those reported by `TensorView::offsetting`.
#[inline]
pub fn view_offsetting_mut(&mut self, coords: &[usize]) -> TractResult<view::TensorView> {
    view::TensorView::offsetting(self, coords)
}
1500
1501 pub fn offset_u8_as_i8(self: &Arc<Self>) -> Arc<Self> {
1503 let mut t = if let DatumType::U8 = self.dt.unquantized() {
1504 self.to_array_view::<u8>().unwrap().mapv(|v| v.wrapping_sub(128) as i8).into_tensor()
1505 } else {
1506 return self.clone();
1507 };
1508
1509 if let DatumType::QU8(qp) = self.dt {
1510 if let QParams::ZpScale { zero_point, scale } = qp {
1511 t.dt = DatumType::QI8(QParams::ZpScale { zero_point: zero_point - 128, scale });
1512 } else {
1513 t.dt = DatumType::QI8(qp);
1514 }
1515 }
1516
1517 t.into_arc_tensor()
1518 }
1519
1520 pub fn offset_i8_as_u8(self: &Arc<Self>) -> Arc<Self> {
1522 let mut t = if let DatumType::I8 = self.dt.unquantized() {
1523 self.to_array_view::<i8>().unwrap().mapv(|v| (v as u8).wrapping_add(128)).into_tensor()
1524 } else {
1525 return self.clone();
1526 };
1527
1528 if let DatumType::QI8(qp) = self.dt {
1529 if let QParams::ZpScale { zero_point, scale } = qp {
1530 t.dt = DatumType::QU8(QParams::ZpScale { zero_point: zero_point + 128, scale });
1531 } else {
1532 t.dt = DatumType::QU8(qp);
1533 }
1534 }
1535 t.into_arc_tensor()
1536 }
1537
1538 pub fn to_aligned_default(&self) -> TractResult<Self> {
1539 if self.dt.is_copy() {
1540 unsafe {
1541 let mut t = Self::uninitialized_dt(self.dt, &self.shape)?;
1542 t.as_bytes_mut().copy_from_slice(self.as_bytes());
1543 Ok(t)
1544 }
1545 } else {
1546 let mut t = Self::zero_dt(self.dt, &self.shape)?;
1547 if self.dt == String::datum_type() {
1548 t.as_slice_mut::<String>()?.clone_from_slice(self.as_slice()?);
1549 } else if self.dt == Blob::datum_type() {
1550 t.as_slice_mut::<Blob>()?.clone_from_slice(self.as_slice()?);
1551 } else if self.dt == TDim::datum_type() {
1552 t.as_slice_mut::<TDim>()?.clone_from_slice(self.as_slice()?);
1553 }
1554 Ok(t)
1555 }
1556 }
1557
1558 pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
1559 let mut strides = tvec!();
1560 compute_natural_stride_to(&mut strides, shape);
1561 strides
1562 }
1563
/// Consumes the tensor and returns its underlying data blob.
///
/// # Errors
/// Fails when the datum type is not a copy type.
pub fn into_blob(mut self) -> TractResult<Blob> {
    ensure!(self.dt.is_copy());
    // The blob is swapped out for a default one rather than moved out of `self`.
    Ok(std::mem::take(&mut self.data))
}
1568}
1569
1570impl PartialEq for Tensor {
1571 fn eq(&self, other: &Tensor) -> bool {
1572 if self.dt != other.dt || self.shape != other.shape {
1573 return false;
1574 }
1575 self.eq_dt(other).unwrap_or(false)
1576 }
1577}
1578
1579impl fmt::Debug for Tensor {
1580 fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1581 let content = self.dump(false).unwrap_or_else(|e| format!("Error : {e:?}"));
1582 write!(formatter, "{content}")
1583 }
1584}
1585
/// Reinterprets a tensor whose last dimension is 2 as a tensor of complex
/// numbers: the last axis is dropped and the datum type is replaced by its
/// `complexify()` counterpart.
///
/// # Errors
/// Fails when the last dimension is not 2, or when the datum type has no
/// complex counterpart.
#[cfg(feature = "complex")]
pub fn reinterpret_inner_dim_as_complex(mut t: Tensor) -> TractResult<Tensor> {
    ensure!(
        t.shape().last() == Some(&2),
        "The last dimension in the tensor shape {:?} must be 2",
        t.shape()
    );
    let complex_dt = t.datum_type().complexify()?;
    t.shape.pop();
    unsafe {
        t.set_datum_type(complex_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1600
/// Reinterprets a tensor of complex numbers as a tensor of their components:
/// an axis of size 2 is appended and the datum type is replaced by its
/// `decomplexify()` counterpart.
///
/// # Errors
/// Fails when `decomplexify()` rejects the datum type.
#[cfg(feature = "complex")]
pub fn reinterpret_complex_as_inner_dim(mut t: Tensor) -> TractResult<Tensor> {
    let component_dt = t.datum_type().decomplexify()?;
    t.shape.push(2);
    unsafe {
        t.set_datum_type(component_dt);
        t.update_strides_and_len();
    }
    Ok(t)
}
1610
/// Computes the strides, in elements, of a tensor of the given `shape` whose
/// last axis has stride 1 and whose other axes have a stride equal to the
/// product of all the dimensions that follow them.
pub fn natural_strides(shape: &[usize]) -> TVec<isize> {
    let mut strides = tvec!();
    compute_natural_stride_to(&mut strides, shape);
    strides
}
1616
1617fn compute_natural_stride_to(strides: &mut TVec<isize>, shape: &[usize]) {
1618 match shape.len() {
1619 0 => (),
1620 1 => strides.push(1),
1621 2 => strides.extend_from_slice(&[shape[1] as isize, 1]),
1622 3 => strides.extend_from_slice(&[(shape[1] * shape[2]) as isize, shape[2] as _, 1]),
1623 4 => strides.extend_from_slice(&[
1624 (shape[1] * shape[2] * shape[3]) as isize,
1625 (shape[2] * shape[3]) as _,
1626 shape[3] as _,
1627 1,
1628 ]),
1629 _ => {
1630 strides.push(1);
1631 for dim in shape.as_ref().iter().skip(1).rev() {
1632 let previous = *strides.last().unwrap();
1633 strides.push(previous * *dim as isize)
1634 }
1635 strides.reverse();
1636 }
1637 }
1638}
1639
/// Converts an owned ndarray of any dimensionality into a `Tensor`.
impl<D: ::ndarray::Dimension, T: Datum> From<Array<T, D>> for Tensor {
    fn from(it: Array<T, D>) -> Tensor {
        // The array is first made dynamic-rank, then handed to `from_datum`.
        Tensor::from_datum(it.into_dyn())
    }
}
1645
/// Conversion of an owned value into a `Tensor`.
pub trait IntoTensor: Sized {
    /// Consumes `self` and returns the resulting tensor.
    fn into_tensor(self) -> Tensor;
}
1653
/// Conversion of an owned value into a shared, reference-counted `Tensor`.
pub trait IntoArcTensor: Sized {
    /// Consumes `self` and returns the resulting tensor wrapped in an `Arc`.
    fn into_arc_tensor(self) -> Arc<Tensor>;
}
1661
/// An owned ndarray converts to a tensor through `Tensor::from`.
impl<D: ::ndarray::Dimension, T: Datum> IntoTensor for Array<T, D> {
    fn into_tensor(self) -> Tensor {
        Tensor::from(self)
    }
}
1667
/// An owned ndarray converts to a tensor through `Tensor::from`, which is
/// then wrapped in a new `Arc`.
impl<D: ::ndarray::Dimension, T: Datum> IntoArcTensor for Array<T, D> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        Arc::new(Tensor::from(self))
    }
}
1673
/// Identity conversion: a tensor is returned as is.
impl IntoTensor for Tensor {
    fn into_tensor(self) -> Tensor {
        self
    }
}
1679
/// Extracts the tensor from the `Arc`, cloning it only when the `Arc` is
/// still shared.
impl IntoTensor for Arc<Tensor> {
    fn into_tensor(self) -> Tensor {
        // `try_unwrap` succeeds when this is the only strong reference;
        // otherwise the tensor is cloned and the shared one is left in place.
        Arc::try_unwrap(self).unwrap_or_else(|t| (*t).clone())
    }
}
1685
/// Wraps the tensor in a new `Arc`.
impl IntoArcTensor for Tensor {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        Arc::new(self)
    }
}
1691
/// Identity conversion: the `Arc` is returned as is.
impl IntoArcTensor for Arc<Tensor> {
    fn into_arc_tensor(self) -> Arc<Tensor> {
        self
    }
}
1697
// Unit and property tests for tensor construction from ndarrays, vector
// broadcasting, complex reinterpretation and cloning.
#[cfg(test)]
mod tests {
    use crate::dim::SymbolScope;
    use crate::prelude::tensor1;

    use super::*;
    use litteral::tensor0;
    use proptest::collection::vec;
    use proptest::prelude::*;

    /// Test case: an array of the given `shape` whose axes are permuted
    /// according to `permutation` before being converted to a tensor.
    #[derive(Debug)]
    struct PermuteAxisProblem {
        shape: Vec<usize>,
        permutation: Vec<usize>,
    }

    impl Arbitrary for PermuteAxisProblem {
        type Strategy = BoxedStrategy<PermuteAxisProblem>;
        type Parameters = ();

        /// Generates ranks in `0..8`, dimensions in `1..5` and a shuffled
        /// identity permutation of the axes.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            (0..8usize)
                .prop_flat_map(|rank| {
                    let permute: Vec<usize> = (0..rank).collect();
                    (proptest::collection::vec(1..5usize, rank), Just(permute).prop_shuffle())
                })
                .prop_map(|(shape, permutation)| PermuteAxisProblem { shape, permutation })
                .boxed()
        }
    }

    impl PermuteAxisProblem {
        /// Builds the array filled with 1, 2, 3, ... and permutes its axes.
        fn input(&self) -> ArrayD<i32> {
            let mut i = 0;
            ArrayD::from_shape_simple_fn(&*self.shape, || {
                i += 1;
                i
            })
            .permuted_axes(&*self.permutation)
        }

        /// Expected tensor: the values collected from the array iterator,
        /// reshaped to the permuted shape.
        fn reference(&self) -> Tensor {
            let values: Vec<i32> = self.input().iter().copied().collect();
            let shape = self.permutation.iter().map(|ix| self.shape[*ix]).collect::<TVec<usize>>();
            super::litteral::tensor1(&values).into_shape(&shape).unwrap()
        }

        /// Tensor under test: direct conversion of the permuted array.
        fn tract(&self) -> Tensor {
            Tensor::from(self.input())
        }

        fn check(&self) -> proptest::test_runner::TestCaseResult {
            prop_assert_eq!(self.tract(), self.reference());
            Ok(())
        }
    }

    proptest::proptest! {
        #[test]
        fn prop(pb: PermuteAxisProblem) {
            pb.check().unwrap();
        }
    }

    // Fixed regression cases for the permutation property.
    #[test]
    fn t_1_2() {
        PermuteAxisProblem { shape: vec![2, 1], permutation: vec![1, 0] }.check().unwrap();
    }

    #[test]
    fn t_2_2() {
        PermuteAxisProblem { shape: vec![2, 2], permutation: vec![1, 0] }.check().unwrap();
    }

    /// Test case: broadcasting the vector `vec`, placed on `axis`, to `shape`.
    #[derive(Debug)]
    struct BroadcastVecToShape {
        vec: Vec<f32>,
        axis: usize,
        shape: TVec<usize>,
    }

    impl BroadcastVecToShape {
        /// Compares `broadcast_vector_to_shape` with a reshape of the vector
        /// to an all-ones shape (except on `axis`) followed by
        /// `broadcast_to_shape`.
        fn check(&self) -> proptest::test_runner::TestCaseResult {
            let input = tensor1(&self.vec);
            let mut intermediate = tvec![1usize; self.shape.len()];
            intermediate[self.axis] = self.vec.len();
            let reference = input
                .clone()
                .into_shape(&intermediate)
                .unwrap()
                .broadcast_to_shape(&self.shape)
                .unwrap();
            prop_assert_eq!(
                reference,
                input.broadcast_vector_to_shape(&self.shape, self.axis).unwrap()
            );
            Ok(())
        }
    }

    impl Arbitrary for BroadcastVecToShape {
        type Strategy = BoxedStrategy<BroadcastVecToShape>;
        type Parameters = ();

        /// Generates a shape of up to 3 dimensions (each in `0..5`), a vector
        /// of up to 4 values, and inserts the vector length at a random axis.
        fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
            vec(0usize..5, 0usize..4)
                .prop_flat_map(|shape| {
                    (vec(-10f32..10f32, 0usize..5), Just(shape.clone()), 0..shape.len() + 1)
                })
                .prop_map(|(vec, mut shape, axis)| {
                    shape.insert(axis, vec.len());
                    BroadcastVecToShape { vec, shape: shape.into(), axis }
                })
                .boxed()
        }
    }

    proptest::proptest! {
        #[test]
        fn broadcast_vector_to_shape_prop(pb: BroadcastVecToShape) {
            pb.check().unwrap()
        }
    }

    // A [3, 2] f32 tensor becomes a vector of 3 complex numbers.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex() -> TractResult<()> {
        let input = crate::internal::tensor2(&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor1(&[
            Complex::new(1.0f32, 2.0),
            Complex::new(3.0, 4.0),
            Complex::new(5.0, 6.0),
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // A [3, 2, 2] i32 tensor becomes a [3, 2] tensor of complex numbers.
    #[test]
    #[cfg(feature = "complex")]
    fn test_reinterpret_inner_dim_as_complex_2() -> TractResult<()> {
        let input =
            crate::internal::tensor3(&[[[1i32, 2], [1, 2]], [[3, 4], [3, 4]], [[5, 6], [5, 6]]]);
        let cplx_input = reinterpret_inner_dim_as_complex(input)?;
        let expected = crate::internal::tensor2(&[
            [Complex::new(1i32, 2), Complex::new(1, 2)],
            [Complex::new(3, 4), Complex::new(3, 4)],
            [Complex::new(5, 6), Complex::new(5, 6)],
        ]);
        assert_eq!(expected, cplx_input);
        Ok(())
    }

    // Smoke test: cloning a scalar tensor holding a symbolic TDim must not panic.
    #[test]
    fn clone_tdim_tensor() {
        let symbols = SymbolScope::default();
        let a = symbols.sym("a");
        let t = tensor0(TDim::from(a));
        let _ = t.clone();
    }
}