1use anyhow::{anyhow, bail, Result};
40use dry::macro_for;
41#[cfg(feature = "serde")]
42use krnl::buffer::{CowBuffer, ScalarCowBuffer};
43#[cfg(feature = "device")]
44use krnl::krnl_core::half::bf16;
45#[cfg(doc)]
46use krnl::{buffer::ArcBuffer, device::error::DeviceLost};
47use krnl::{
48 buffer::{
49 ArcBufferRepr, Buffer, BufferBase, BufferRepr, CowBufferRepr, Data, DataMut, DataOwned,
50 ScalarArcBufferRepr, ScalarBuffer, ScalarBufferBase, ScalarBufferRepr, ScalarCowBufferRepr,
51 ScalarData, ScalarDataMut, ScalarDataOwned, ScalarSlice, ScalarSliceMut,
52 ScalarSliceMutRepr, ScalarSliceRepr, Slice, SliceMut, SliceMutRepr, SliceRepr,
53 },
54 device::Device,
55 scalar::{Scalar, ScalarElem, ScalarType},
56};
57use ndarray::{
58 Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Dimension, IntoDimension, Ix0, Ix1, Ix2, Ix3,
59 Ix4, Ix5, Ix6, IxDyn, RawArrayView, RemoveAxis, ShapeError, StrideShape,
60};
61#[cfg(feature = "device")]
62use num_traits::ToPrimitive;
63use paste::paste;
64#[cfg(feature = "serde")]
65use serde::{Deserialize, Deserializer, Serialize, Serializer};
66use std::fmt::{self, Debug};
67
68mod linalg;
69mod ops;
70pub(crate) mod parallel;
71mod reduce;
72
73fn strides_from_array<S, D>(array: &ArrayBase<S, D>) -> D
74where
75 S: ndarray::RawData,
76 D: Dimension,
77{
78 let strides_slice: &[usize] = bytemuck::cast_slice(array.strides());
79 let mut strides = D::zeros(strides_slice.len());
80 for (i, s) in strides_slice.iter().copied().enumerate() {
81 strides[i] = s;
82 }
83 strides
84}
85
/// Computes the dimensions and (default or fortran) strides described by
/// `shape` without allocating any element data.
fn dim_strides_from_shape<D: Dimension>(shape: impl Into<StrideShape<D>>) -> (D, D) {
    // SAFETY: the raw view is never dereferenced; it exists only so ndarray
    // computes dim/strides for `shape`. The `&()` element pointer is valid
    // and no element access occurs.
    let array = unsafe { RawArrayView::from_shape_ptr(shape, &()) };
    let dim = array.raw_dim();
    let strides = strides_from_array(&array);
    (dim, strides)
}
92
93fn into_dimensionality<D1, D2>(dim: &D1, strides: &D1) -> Result<(D2, D2), ShapeError>
94where
95 D1: Dimension,
96 D2: Dimension,
97{
98 D2::from_dimension(dim)
99 .and_then(|dim| D2::from_dimension(strides).map(|strides| (dim, strides)))
100 .ok_or(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape))
101}
102
103fn into_shape<D1, E>(dim: &D1, strides: &D1, shape: E) -> Result<(E::Dim, E::Dim), ShapeError>
104where
105 D1: Dimension,
106 E: IntoDimension,
107{
108 use ndarray::ErrorKind;
109
110 let shape = shape.into_dimension();
111 if size_of_shape_checked(&shape)? != dim.size() {
112 Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
113 } else if is_standard_layout(dim, strides) {
114 let strides = shape.default_strides();
115 Ok((shape, strides))
116 } else if is_fortran_layout(dim, strides) {
117 let strides = shape.fortran_strides();
118 Ok((shape, strides))
119 } else {
120 Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout))
121 }
122}
123
/// Flattens `shape` to 2 dimensions: `[first_axis, product_of_the_rest]`.
///
/// An empty shape yields `[1, 1]` (a 0-dimensional tensor has one element).
pub(crate) fn flatten(shape: &[usize]) -> [usize; 2] {
    match shape.split_first() {
        Some((rows, rest)) => [*rows, rest.iter().product()],
        None => [1, 1],
    }
}
130
/// Whether the layout is contiguous in either C (row-major) or Fortran
/// (column-major) order.
fn is_contiguous<D: Dimension>(dim: &D, strides: &D) -> bool {
    is_standard_layout(dim, strides) || is_fortran_layout(dim, strides)
}
134
135fn is_standard_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
136 debug_assert_eq!(dim.ndim(), strides.ndim());
137 for d in dim.slice().iter().copied() {
138 if d == 0 {
139 return true;
140 }
141 }
142 let mut acc = 1isize;
143 let strides: &[isize] = bytemuck::cast_slice(strides.slice());
144 for (d, s) in dim
145 .slice()
146 .iter()
147 .copied()
148 .zip(strides.iter().copied())
149 .rev()
150 {
151 if !(d == 1 || s == acc) {
152 return false;
153 }
154 acc *= d as isize;
155 }
156 true
157}
158
159fn is_fortran_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
160 debug_assert_eq!(dim.ndim(), strides.ndim());
161 for d in dim.slice().iter().copied() {
162 if d == 0 {
163 return true;
164 }
165 }
166 let mut acc = 1;
167 for (d, s) in dim
168 .slice()
169 .iter()
170 .copied()
171 .zip(strides.slice().iter().copied())
172 {
173 if !(d == 1 || s == acc) {
174 return false;
175 }
176 acc *= d;
177 }
178 true
179}
180
181fn permuted_axes<D: Dimension>(dim: D, strides: D, axes: D) -> (D, D) {
183 let mut usage_counts = D::zeros(dim.ndim());
185 for axis in axes.slice() {
186 usage_counts[*axis] += 1;
187 }
188 for count in usage_counts.slice() {
189 assert_eq!(*count, 1, "each axis must be listed exactly once");
190 }
191 let mut new_dim = usage_counts; let mut new_strides = D::zeros(dim.ndim());
194 {
195 let dim = dim.slice();
196 let strides = strides.slice();
197 for (new_axis, &axis) in axes.slice().iter().enumerate() {
198 new_dim[new_axis] = dim[axis];
199 new_strides[new_axis] = strides[axis];
200 }
201 }
202 (new_dim, new_strides)
203}
204
/// Returns the number of elements of `dim`, checking for overflow.
///
/// Errors with [`ndarray::ErrorKind::Overflow`] if the product of the
/// non-zero axis lengths overflows `usize` or exceeds `isize::MAX`.
fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError> {
    use ndarray::ErrorKind;
    // Zero-length axes are skipped so that an otherwise-overflowing shape
    // with a zero axis (total size 0) is still accepted.
    let size_nonzero = dim
        .slice()
        .iter()
        .filter(|&&d| d != 0)
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .ok_or_else(|| ShapeError::from_kind(ErrorKind::Overflow))?;
    if size_nonzero > isize::MAX as usize {
        Err(ShapeError::from_kind(ErrorKind::Overflow))
    } else {
        Ok(dim.size())
    }
}
229
/// Broadcasts (`from`, `strides`) to shape `dim`, returning the new dim and
/// strides, or `None` if broadcasting is not possible.
///
/// Follows numpy/ndarray broadcasting rules: trailing axes must either match
/// or have length 1 in the source (stride becomes 0); leading axes added by
/// the broadcast also get stride 0.
fn broadcast<D: Dimension, E: IntoDimension>(
    from: &D,
    strides: &D,
    dim: E,
) -> Option<(E::Dim, E::Dim)> {
    /// Computes the broadcast strides for going from (`from`, `stride`) to `to`.
    fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
        // Reject target shapes whose element count overflows.
        let _ = size_of_shape_checked(to).ok()?;

        let mut new_stride = to.clone();
        // Broadcasting may only add axes, never remove them.
        if to.ndim() < from.ndim() {
            return None;
        }

        {
            // Align the source axes with the trailing axes of the target.
            let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
            for ((er, es), dr) in from
                .slice()
                .iter()
                .rev()
                .zip(stride.slice().iter().rev())
                .zip(new_stride_iter.by_ref())
            {
                if *dr == *er {
                    // Matching lengths: keep the source stride.
                    *dr = *es;
                } else if *er == 1 {
                    // Length-1 source axis: repeat by striding 0.
                    *dr = 0
                } else {
                    return None;
                }
            }

            // Any remaining (leading) target axes are new: stride 0.
            for dr in new_stride_iter {
                *dr = 0;
            }
        }
        Some(new_stride)
    }
    let dim = dim.into_dimension();

    let broadcast_strides = match upcast(&dim, from, strides) {
        Some(st) => st,
        None => return None,
    };
    Some((dim, broadcast_strides))
}
295
/// Collapses `axis` of `dims` to length 1, selecting `index` along it.
///
/// Returns the element offset (in elements, may rely on the usize-encoded
/// stride convention of this module) that must be added to address the
/// selected hyperplane. Panics if `index` is out of bounds.
fn collapse_axis<D: Dimension>(dims: &mut D, strides: &D, Axis(axis): Axis, index: usize) -> isize {
    let dim = dims[axis];
    assert!(index < dim);
    dims.slice_mut()[axis] = 1;
    index as isize * strides[axis] as isize
}
302
/// Length in elements of the buffer region spanned by a tensor starting at
/// `offset` with the given `shape` and `strides`.
///
/// Returns `Some(0)` for empty tensors and `None` when any stride is
/// negative (the span cannot be expressed as a forward range).
fn tensor_buffer_len(offset: usize, shape: &[usize], strides: &[isize]) -> Option<usize> {
    if shape.contains(&0) {
        return Some(0);
    }
    if strides.iter().any(|&s| s < 0) {
        return None;
    }
    // Offset of the last element relative to the start of the buffer.
    let last: isize = shape
        .iter()
        .zip(strides)
        .map(|(&d, &s)| (d as isize - 1) * s)
        .sum();
    let len = usize::try_from(last + offset as isize + 1).unwrap();
    Some(len)
}
321
/// Dynamically typed multi-dimensional tensor.
///
/// Generic over the buffer representation `S` (owned / shared / borrowed /
/// copy-on-write) and the dimensionality `D`.
#[derive(Clone)]
pub struct ScalarTensorBase<S: ScalarData, D: Dimension> {
    // Length of each axis.
    dim: D,
    // Stride of each axis, bit-cast from isize to usize (see
    // `strides_from_array`).
    strides: D,
    // Underlying storage; may span more elements than `dim.size()`.
    buffer: ScalarBufferBase<S>,
    // Element offset of the first element within `buffer`.
    offset: usize,
}
333
/// Owned scalar tensor.
pub type ScalarTensor<D> = ScalarTensorBase<ScalarBufferRepr, D>;
/// [`ScalarTensor`] with 0 dimensions.
pub type ScalarTensor0 = ScalarTensor<Ix0>;
/// [`ScalarTensor`] with 1 dimension.
pub type ScalarTensor1 = ScalarTensor<Ix1>;
/// [`ScalarTensor`] with 2 dimensions.
pub type ScalarTensor2 = ScalarTensor<Ix2>;
/// [`ScalarTensor`] with 3 dimensions.
pub type ScalarTensor3 = ScalarTensor<Ix3>;
/// [`ScalarTensor`] with 4 dimensions.
pub type ScalarTensor4 = ScalarTensor<Ix4>;
/// [`ScalarTensor`] with 5 dimensions.
pub type ScalarTensor5 = ScalarTensor<Ix5>;
/// [`ScalarTensor`] with 6 dimensions.
pub type ScalarTensor6 = ScalarTensor<Ix6>;
/// [`ScalarTensor`] with dynamic dimensionality.
pub type ScalarTensorD = ScalarTensor<IxDyn>;

/// Shared (reference-counted) scalar tensor.
pub type ScalarArcTensor<D> = ScalarTensorBase<ScalarArcBufferRepr, D>;
/// [`ScalarArcTensor`] with 0 dimensions.
pub type ScalarArcTensor0 = ScalarArcTensor<Ix0>;
/// [`ScalarArcTensor`] with 1 dimension.
pub type ScalarArcTensor1 = ScalarArcTensor<Ix1>;
/// [`ScalarArcTensor`] with 2 dimensions.
pub type ScalarArcTensor2 = ScalarArcTensor<Ix2>;
/// [`ScalarArcTensor`] with 3 dimensions.
pub type ScalarArcTensor3 = ScalarArcTensor<Ix3>;
/// [`ScalarArcTensor`] with 4 dimensions.
pub type ScalarArcTensor4 = ScalarArcTensor<Ix4>;
/// [`ScalarArcTensor`] with 5 dimensions.
pub type ScalarArcTensor5 = ScalarArcTensor<Ix5>;
/// [`ScalarArcTensor`] with 6 dimensions.
pub type ScalarArcTensor6 = ScalarArcTensor<Ix6>;
/// [`ScalarArcTensor`] with dynamic dimensionality.
pub type ScalarArcTensorD = ScalarArcTensor<IxDyn>;

/// Borrowed scalar tensor view.
pub type ScalarTensorView<'a, D> = ScalarTensorBase<ScalarSliceRepr<'a>, D>;
/// [`ScalarTensorView`] with 0 dimensions.
pub type ScalarTensorView0<'a> = ScalarTensorView<'a, Ix0>;
/// [`ScalarTensorView`] with 1 dimension.
pub type ScalarTensorView1<'a> = ScalarTensorView<'a, Ix1>;
/// [`ScalarTensorView`] with 2 dimensions.
pub type ScalarTensorView2<'a> = ScalarTensorView<'a, Ix2>;
/// [`ScalarTensorView`] with 3 dimensions.
pub type ScalarTensorView3<'a> = ScalarTensorView<'a, Ix3>;
/// [`ScalarTensorView`] with 4 dimensions.
pub type ScalarTensorView4<'a> = ScalarTensorView<'a, Ix4>;
/// [`ScalarTensorView`] with 5 dimensions.
pub type ScalarTensorView5<'a> = ScalarTensorView<'a, Ix5>;
/// [`ScalarTensorView`] with 6 dimensions.
pub type ScalarTensorView6<'a> = ScalarTensorView<'a, Ix6>;
/// [`ScalarTensorView`] with dynamic dimensionality.
pub type ScalarTensorViewD<'a> = ScalarTensorView<'a, IxDyn>;

/// Mutably borrowed scalar tensor view.
pub type ScalarTensorViewMut<'a, D> = ScalarTensorBase<ScalarSliceMutRepr<'a>, D>;
/// [`ScalarTensorViewMut`] with 0 dimensions.
pub type ScalarTensorViewMut0<'a> = ScalarTensorViewMut<'a, Ix0>;
/// [`ScalarTensorViewMut`] with 1 dimension.
pub type ScalarTensorViewMut1<'a> = ScalarTensorViewMut<'a, Ix1>;
/// [`ScalarTensorViewMut`] with 2 dimensions.
pub type ScalarTensorViewMut2<'a> = ScalarTensorViewMut<'a, Ix2>;
/// [`ScalarTensorViewMut`] with 3 dimensions.
pub type ScalarTensorViewMut3<'a> = ScalarTensorViewMut<'a, Ix3>;
/// [`ScalarTensorViewMut`] with 4 dimensions.
pub type ScalarTensorViewMut4<'a> = ScalarTensorViewMut<'a, Ix4>;
/// [`ScalarTensorViewMut`] with 5 dimensions.
pub type ScalarTensorViewMut5<'a> = ScalarTensorViewMut<'a, Ix5>;
/// [`ScalarTensorViewMut`] with 6 dimensions.
pub type ScalarTensorViewMut6<'a> = ScalarTensorViewMut<'a, Ix6>;
/// [`ScalarTensorViewMut`] with dynamic dimensionality.
pub type ScalarTensorViewMutD<'a> = ScalarTensorViewMut<'a, IxDyn>;

/// Scalar tensor with copy-on-write (borrowed or owned) storage.
pub type ScalarCowTensor<'a, D> = ScalarTensorBase<ScalarCowBufferRepr<'a>, D>;
/// [`ScalarCowTensor`] with 0 dimensions.
pub type ScalarCowTensor0<'a> = ScalarCowTensor<'a, Ix0>;
/// [`ScalarCowTensor`] with 1 dimension.
pub type ScalarCowTensor1<'a> = ScalarCowTensor<'a, Ix1>;
/// [`ScalarCowTensor`] with 2 dimensions.
pub type ScalarCowTensor2<'a> = ScalarCowTensor<'a, Ix2>;
/// [`ScalarCowTensor`] with 3 dimensions.
pub type ScalarCowTensor3<'a> = ScalarCowTensor<'a, Ix3>;
/// [`ScalarCowTensor`] with 4 dimensions.
pub type ScalarCowTensor4<'a> = ScalarCowTensor<'a, Ix4>;
/// [`ScalarCowTensor`] with 5 dimensions.
pub type ScalarCowTensor5<'a> = ScalarCowTensor<'a, Ix5>;
/// [`ScalarCowTensor`] with 6 dimensions.
pub type ScalarCowTensor6<'a> = ScalarCowTensor<'a, Ix6>;
/// [`ScalarCowTensor`] with dynamic dimensionality.
pub type ScalarCowTensorD<'a> = ScalarCowTensor<'a, IxDyn>;
impl<S: ScalarDataOwned, D: Dimension> ScalarTensorBase<S, D> {
    /// Allocates a tensor on `device` with `shape` and `scalar_type`,
    /// leaving the elements uninitialized.
    ///
    /// # Safety
    /// The elements must be written before they are read.
    pub unsafe fn uninit<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
        // SAFETY: the caller upholds the uninitialized-read contract.
        let buffer = unsafe { ScalarBufferBase::uninit(device, dim.size(), scalar_type)? };
        Ok(Self {
            dim,
            strides,
            buffer,
            offset: 0,
        })
    }
    /// Creates a tensor on `device` with `shape`, filled with `elem`.
    pub fn from_elem<Sh>(device: Device, shape: Sh, elem: ScalarElem) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
        let buffer = ScalarBufferBase::from_elem(device, dim.size(), elem)?;
        Ok(Self {
            dim,
            strides,
            buffer,
            offset: 0,
        })
    }
    /// Creates a tensor on `device` with `shape`, filled with zeros of
    /// `scalar_type`.
    pub fn zeros<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        Self::from_elem(device, shape, ScalarElem::zero(scalar_type))
    }
    /// Creates a tensor on `device` with `shape`, filled with ones of
    /// `scalar_type`.
    pub fn ones<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        Self::from_elem(device, shape, ScalarElem::one(scalar_type))
    }
}
499
impl<S: ScalarData, D: Dimension> ScalarTensorBase<S, D> {
    /// The device the tensor is stored on.
    pub fn device(&self) -> Device {
        self.buffer.device()
    }
    /// The scalar type of the elements.
    pub fn scalar_type(&self) -> ScalarType {
        self.buffer.scalar_type()
    }
    /// The dimensions of the tensor in pattern form.
    pub fn dim(&self) -> D::Pattern {
        self.dim.clone().into_pattern()
    }
    /// The dimensions of the tensor.
    pub fn raw_dim(&self) -> D {
        self.dim.clone()
    }
    /// The dimensions of the tensor as a slice.
    pub fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
    /// The strides of the tensor as a slice.
    pub fn strides(&self) -> &[isize] {
        bytemuck::cast_slice(self.strides.slice())
    }
    /// The number of elements.
    pub fn len(&self) -> usize {
        self.dim.size()
    }
    /// Whether the tensor has no elements (any axis has length 0).
    pub fn is_empty(&self) -> bool {
        self.shape().iter().any(|x| *x == 0)
    }
    /// The number of dimensions (axes).
    pub fn ndim(&self) -> usize {
        self.dim.ndim()
    }
    /// Converts into dimensionality `D2`, failing if the number of axes
    /// does not fit.
    pub fn into_dimensionality<D2>(self) -> Result<ScalarTensorBase<S, D2>, ShapeError>
    where
        D2: Dimension,
    {
        let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
        Ok(ScalarTensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Converts into a dynamically dimensioned tensor.
    pub fn into_dyn(self) -> ScalarTensorBase<S, IxDyn> {
        ScalarTensorBase {
            dim: self.dim.into_dyn(),
            strides: self.strides.into_dyn(),
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Reshapes to `shape`; the element count must match and the layout
    /// must be contiguous.
    ///
    /// NOTE(review): panics (even in release) if `offset != 0`, whereas
    /// `TensorBase::into_shape` only debug-asserts — confirm which is
    /// intended.
    pub fn into_shape<E>(self, shape: E) -> Result<ScalarTensorBase<S, E::Dim>, ShapeError>
    where
        E: IntoDimension,
    {
        let shape = shape.into_dimension();
        let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
        assert_eq!(self.offset, 0);
        Ok(ScalarTensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Broadcasts to `dim` as a view, or `None` if not broadcastable.
    pub fn broadcast<E>(&self, dim: E) -> Option<ScalarTensorView<E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
        Some(ScalarTensorView {
            dim,
            strides,
            buffer: self.buffer.as_scalar_slice(),
            offset: self.offset,
        })
    }
    /// Borrows the tensor as a view.
    pub fn view(&self) -> ScalarTensorView<D> {
        ScalarTensorView {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.as_scalar_slice(),
            offset: self.offset,
        }
    }
    /// Mutably borrows the tensor as a view.
    pub fn view_mut(&mut self) -> ScalarTensorViewMut<D>
    where
        S: ScalarDataMut,
    {
        ScalarTensorViewMut {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.as_scalar_slice_mut(),
            offset: self.offset,
        }
    }
    /// Mutably borrows as a view if the data is exclusively held,
    /// contiguous, and has no offset.
    pub fn get_view_mut(&mut self) -> Option<ScalarTensorViewMut<D>> {
        if self.offset == 0 && self.is_contiguous() {
            let buffer = self.buffer.get_scalar_slice_mut()?;
            Some(ScalarTensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer,
                offset: 0,
            })
        } else {
            None
        }
    }
    /// Mutably borrows as a view, copying into owned standard-layout
    /// storage first if necessary.
    pub fn make_view_mut(&mut self) -> Result<ScalarTensorViewMut<D>>
    where
        S: ScalarDataOwned,
    {
        if self.offset == 0 && self.is_contiguous() {
            Ok(ScalarTensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.make_scalar_slice_mut()?,
                offset: 0,
            })
        } else {
            // Not mutably viewable in place: replace self with an owned copy.
            let tensor = self.to_owned()?;
            *self = Self {
                dim: tensor.dim,
                strides: tensor.strides,
                buffer: ScalarBufferBase::from_scalar_buffer(tensor.buffer),
                offset: 0,
            };
            Ok(ScalarTensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.get_scalar_slice_mut().unwrap(),
                offset: 0,
            })
        }
    }
    /// Whether the layout is contiguous in C or Fortran order.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.dim, &self.strides)
    }
    /// Whether the layout is C (row-major) contiguous.
    pub fn is_standard_layout(&self) -> bool {
        is_standard_layout(&self.dim, &self.strides)
    }
    /// Permutes the axes so output axis `i` is input axis `axes[i]`.
    ///
    /// Panics unless `axes` is a permutation of `0..self.ndim()`.
    pub fn permuted_axes<A>(self, axes: A) -> Self
    where
        A: IntoDimension<Dim = D>,
    {
        let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
        Self {
            dim,
            strides,
            ..self
        }
    }
    /// Reverses the order of the axes (full transpose).
    pub fn reversed_axes(mut self) -> Self {
        self.dim.slice_mut().reverse();
        self.strides.slice_mut().reverse();
        self
    }
    /// Borrows a transposed view.
    pub fn t(&self) -> ScalarTensorView<D> {
        self.view().reversed_axes()
    }
    /// Borrows a view of the hyperplane at `index` along `axis`.
    pub fn index_axis(&self, axis: Axis, index: usize) -> ScalarTensorView<D::Smaller>
    where
        D: RemoveAxis,
    {
        self.view().index_axis_into(axis, index)
    }
    /// Mutably borrows a view of the hyperplane at `index` along `axis`.
    pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ScalarTensorViewMut<D::Smaller>
    where
        S: ScalarDataMut,
        D: RemoveAxis,
    {
        self.view_mut().index_axis_into(axis, index)
    }
    /// Consumes self, selecting the hyperplane at `index` along `axis`.
    pub fn index_axis_into(mut self, axis: Axis, index: usize) -> ScalarTensorBase<S, D::Smaller>
    where
        D: RemoveAxis,
    {
        self.collapse_axis(axis, index);
        let dim = self.dim.remove_axis(axis);
        let strides = self.strides.remove_axis(axis);
        ScalarTensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Collapses `axis` to length 1, selecting `index` along it by
    /// adjusting the offset. Panics if `index` is out of bounds.
    pub fn collapse_axis(&mut self, axis: Axis, index: usize) {
        let offset =
            collapse_axis(&mut self.dim, &self.strides, axis, index) + self.offset as isize;
        debug_assert!(offset >= 0);
        self.offset = offset as usize;
        debug_assert!(self.offset < self.buffer.len());
    }
    /// Borrows the data as a slice when in standard layout.
    pub fn as_scalar_slice(&self) -> Option<ScalarSlice> {
        if self.is_standard_layout() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset();
            Some(slice)
        } else {
            None
        }
    }
    /// Borrows the data as a slice in memory order when contiguous
    /// (C or Fortran).
    pub fn as_scalar_slice_memory_order(&self) -> Option<ScalarSlice> {
        if self.is_contiguous() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset();
            Some(slice)
        } else {
            None
        }
    }
    /// Mutably borrows the data as a slice when in standard layout.
    pub fn as_scalar_slice_mut(&mut self) -> Option<ScalarSliceMut>
    where
        S: ScalarDataMut,
    {
        if self.is_standard_layout() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset_mut();
            Some(slice)
        } else {
            None
        }
    }
    /// Mutably borrows the data as a slice in memory order when contiguous.
    pub fn as_scalar_slice_memory_order_mut(&mut self) -> Option<ScalarSliceMut>
    where
        S: ScalarDataMut,
    {
        if self.is_contiguous() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset_mut();
            Some(slice)
        } else {
            None
        }
    }
    /// Borrows the raw underlying slice and the element offset into it.
    ///
    /// When the spanned region can be computed (no negative strides), the
    /// slice is narrowed to exactly that region and the offset is 0.
    pub fn as_raw_scalar_slice_offset(&self) -> (ScalarSlice, usize) {
        let strides: &[isize] = Self::strides(self);
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self.buffer.slice(self.offset..self.offset + len).unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_scalar_slice(), self.offset)
        }
    }
    /// Mutably borrows the raw underlying slice and the element offset
    /// into it. See [`.as_raw_scalar_slice_offset()`](Self::as_raw_scalar_slice_offset).
    pub fn as_raw_scalar_slice_offset_mut(&mut self) -> (ScalarSliceMut, usize)
    where
        S: ScalarDataMut,
    {
        let strides: &[isize] = Self::strides(self);
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self
                .buffer
                .slice_mut(self.offset..self.offset + len)
                .unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_scalar_slice_mut(), self.offset)
        }
    }
    /// Transfers to `device`, converting into an owned tensor.
    pub fn into_device(self, device: Device) -> Result<ScalarTensor<D>> {
        if self.device() == device {
            self.into_owned()
        } else if let Some(slice) = self.as_scalar_slice_memory_order() {
            let buffer = slice.to_device(device)?;
            Ok(ScalarTensor {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            })
        } else {
            // Non-contiguous: pack into an owned tensor first, then transfer.
            self.into_owned()?.into_device(device)
        }
    }
    /// Copies to `device` as an owned tensor.
    pub fn to_device(&self, device: Device) -> Result<ScalarTensor<D>> {
        if self.device() == device {
            self.to_owned()
        } else {
            self.view().into_device(device)
        }
    }
    /// Moves self to `device` in place, copying only if necessary.
    pub fn to_device_mut(&mut self, device: Device) -> Result<()>
    where
        S: ScalarDataOwned,
    {
        if self.device() == device {
            return Ok(());
        }
        let ScalarTensor {
            dim,
            strides,
            buffer,
            offset,
        } = self.to_device(device)?;
        *self = Self {
            dim,
            strides,
            buffer: ScalarBufferBase::from_scalar_buffer(buffer),
            offset,
        };
        Ok(())
    }
    /// Transfers to `device`, converting into a shared tensor.
    pub fn into_device_shared(self, device: Device) -> Result<ScalarArcTensor<D>> {
        if self.device() == device {
            self.into_shared()
        } else {
            self.to_device(device).map(Into::into)
        }
    }
    /// Copies to `device` as a shared tensor.
    pub fn to_device_shared(&self, device: Device) -> Result<ScalarArcTensor<D>> {
        if device == self.device() {
            self.to_shared()
        } else {
            self.to_device(device).map(Into::into)
        }
    }
    /// Converts into an owned tensor, copying only when required.
    pub fn into_owned(self) -> Result<ScalarTensor<D>> {
        if self.offset == 0 && self.is_contiguous() {
            return Ok(ScalarTensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.into_owned()?,
                offset: 0,
            });
        }
        if let Some(slice) = self.as_scalar_slice_memory_order() {
            let buffer = slice.to_owned()?;
            return Ok(ScalarTensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            });
        }
        // Non-contiguous: materialize via an elementwise copy.
        let mut output =
            unsafe { ScalarTensor::uninit(self.device(), self.raw_dim(), self.scalar_type())? };
        output.assign(&self)?;
        Ok(output)
    }
    /// Copies into a new owned tensor.
    pub fn to_owned(&self) -> Result<ScalarTensor<D>> {
        self.view().into_owned()
    }
    /// Converts into a shared tensor, copying to standard layout only when
    /// required.
    pub fn into_shared(self) -> Result<ScalarArcTensor<D>> {
        if self.offset == 0 && self.is_contiguous() {
            Ok(ScalarTensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.into_shared()?,
                offset: 0,
            })
        } else {
            self.as_standard_layout()?.into_shared()
        }
    }
    /// Copies into a shared tensor.
    pub fn to_shared(&self) -> Result<ScalarArcTensor<D>> {
        if !self.is_contiguous() {
            return self.as_standard_layout()?.to_shared();
        }
        Ok(ScalarTensorBase {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.to_shared()?,
            offset: 0,
        })
    }
}
947
impl<D: Dimension> ScalarTensor<D> {
    /// Converts into a statically typed [`Tensor`], failing with `self` if
    /// the scalar type is not `T`.
    pub fn try_into_tensor<T: Scalar>(self) -> Result<Tensor<T, D>, Self> {
        self.try_into()
    }
}

impl<D: Dimension> ScalarArcTensor<D> {
    /// Converts into a statically typed [`ArcTensor`], failing with `self`
    /// if the scalar type is not `T`.
    pub fn try_into_arc_tensor<T: Scalar>(self) -> Result<ArcTensor<T, D>, Self> {
        self.try_into()
    }
}

impl<'a, D: Dimension> ScalarTensorView<'a, D> {
    /// Converts into a statically typed [`TensorView`], failing with `self`
    /// if the scalar type is not `T`.
    pub fn try_into_tensor_view<T: Scalar>(self) -> Result<TensorView<'a, T, D>, Self> {
        self.try_into()
    }
}

impl<'a, D: Dimension> ScalarTensorViewMut<'a, D> {
    /// Converts into a statically typed [`TensorViewMut`], failing with
    /// `self` if the scalar type is not `T`.
    pub fn try_into_tensor_view_mut<T: Scalar>(self) -> Result<TensorViewMut<'a, T, D>, Self> {
        self.try_into()
    }
}
975
impl<D: Dimension> ScalarArcTensor<D> {
    /// Broadcasts to `dim` as a shared tensor, cloning only the shared
    /// buffer handle (no element copy). Returns `None` if not
    /// broadcastable.
    pub fn broadcast_shared<E>(&self, dim: E) -> Option<ScalarArcTensor<E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
        Some(ScalarArcTensor {
            dim,
            strides,
            buffer: self.buffer.clone(),
            offset: self.offset,
        })
    }
}
993
/// A buffer converts into a 1-dimensional tensor with default strides.
impl<S: ScalarDataOwned> From<ScalarBuffer> for ScalarTensorBase<S, Ix1> {
    fn from(buffer: ScalarBuffer) -> Self {
        let dim = buffer.len().into_dimension();
        let strides = dim.default_strides();
        let buffer = ScalarBufferBase::from_scalar_buffer(buffer);
        Self {
            dim,
            strides,
            buffer,
            offset: 0,
        }
    }
}

/// Erases the static element type of an owned tensor.
impl<S: ScalarDataOwned, T: Scalar, D: Dimension> From<Tensor<T, D>> for ScalarTensorBase<S, D> {
    fn from(tensor: Tensor<T, D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            buffer: tensor.buffer.into(),
            offset: tensor.offset,
        }
    }
}

/// Converts an owned scalar tensor into a shared one.
impl<D: Dimension> From<ScalarTensor<D>> for ScalarArcTensor<D> {
    fn from(tensor: ScalarTensor<D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            buffer: tensor.buffer.into(),
            offset: tensor.offset,
        }
    }
}

/// Erases the static element type of a shared tensor.
impl<T: Scalar, D: Dimension> From<ArcTensor<T, D>> for ScalarArcTensor<D> {
    fn from(tensor: ArcTensor<T, D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            buffer: tensor.buffer.into(),
            offset: tensor.offset,
        }
    }
}

/// Converts an owned scalar tensor into a copy-on-write tensor (owned side).
impl<D: Dimension> From<ScalarTensor<D>> for ScalarCowTensor<'_, D> {
    fn from(tensor: ScalarTensor<D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            buffer: tensor.buffer.into(),
            offset: tensor.offset,
        }
    }
}

/// Converts a scalar tensor view into a copy-on-write tensor (borrowed side).
impl<'a, D: Dimension> From<ScalarTensorView<'a, D>> for ScalarCowTensor<'a, D> {
    fn from(tensor: ScalarTensorView<'a, D>) -> Self {
        Self {
            dim: tensor.dim,
            strides: tensor.strides,
            buffer: tensor.buffer.into(),
            offset: tensor.offset,
        }
    }
}
1062
// Fallible conversions from dynamically typed owned/shared tensors to their
// statically typed counterparts; the original tensor is returned on a
// scalar-type mismatch.
macro_for!($Tensor in [Tensor, ArcTensor] {
    paste! {
        impl<T: Scalar, D: Dimension> TryFrom<[<Scalar $Tensor>]<D>> for $Tensor<T, D> {
            type Error = [<Scalar $Tensor>]<D>;
            fn try_from(tensor: [<Scalar $Tensor>]<D>) -> Result<Self, Self::Error> {
                match tensor.buffer.try_into() {
                    Ok(buffer) => Ok(Self {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    }),
                    Err(buffer) => Err(Self::Error {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    })
                }
            }
        }
    }
});

// For borrowed/copy-on-write tensors: infallible type erasure in one
// direction, fallible conversion back (tensor returned on type mismatch).
macro_for!($Tensor in [TensorView, TensorViewMut, CowTensor] {
    paste! {
        impl<'a, T: Scalar, D: Dimension> From<$Tensor<'a, T, D>> for [<Scalar $Tensor>]<'a, D> {
            fn from(tensor: $Tensor<'a, T, D>) -> Self {
                Self {
                    dim: tensor.dim,
                    strides: tensor.strides,
                    buffer: tensor.buffer.into(),
                    offset: tensor.offset,
                }
            }
        }
        impl<'a, T: Scalar, D: Dimension> TryFrom<[<Scalar $Tensor>]<'a, D>> for $Tensor<'a, T, D> {
            type Error = [<Scalar $Tensor>]<'a, D>;
            fn try_from(tensor: [<Scalar $Tensor>]<'a, D>) -> Result<Self, Self::Error> {
                match tensor.buffer.try_into() {
                    Ok(buffer) => Ok(Self {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    }),
                    Err(buffer) => Err(Self::Error {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    })
                }
            }
        }
    }
});
1120
impl<S: ScalarData, D: Dimension> Debug for ScalarTensorBase<S, D> {
    /// Formats device, scalar type, and shape; strides and offset are shown
    /// only when they differ from the defaults, keeping the common case
    /// concise.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("TensorBase");
        builder
            .field("device", &self.device())
            .field("scalar_type", &self.scalar_type())
            .field("shape", &self.shape());
        if self.strides != self.dim.default_strides() {
            builder.field("strides", &self.strides());
        }
        if self.offset > 0 {
            builder.field("offset", &self.offset);
        }
        builder.finish()
    }
}
1137
1138impl<S: ScalarData, D: Dimension> ScalarTensorBase<S, D> {
1140 pub fn cast_into(self, scalar_type: ScalarType) -> Result<ScalarTensor<D>> {
1144 if self.scalar_type() == scalar_type {
1145 self.into_owned()
1146 } else {
1147 self.cast(scalar_type)
1148 }
1149 }
1150 pub fn cast(&self, scalar_type: ScalarType) -> Result<ScalarTensor<D>> {
1154 if self.scalar_type() == scalar_type {
1155 self.to_owned()
1156 } else if !self.is_contiguous() {
1157 self.scaled_cast(ScalarElem::one(scalar_type))
1158 } else {
1159 Ok(ScalarTensorBase {
1160 dim: self.dim.clone(),
1161 strides: self.strides.clone(),
1162 buffer: self.buffer.cast(scalar_type)?,
1163 offset: 0,
1164 })
1165 }
1166 }
1167 pub fn cast_mut(&mut self, scalar_type: ScalarType) -> Result<()>
1171 where
1172 S: ScalarDataOwned,
1173 {
1174 if self.scalar_type() == scalar_type {
1175 return Ok(());
1176 }
1177 let ScalarTensor {
1178 dim,
1179 strides,
1180 buffer,
1181 offset,
1182 } = self.cast(scalar_type)?;
1183 *self = Self {
1184 dim,
1185 strides,
1186 buffer: ScalarBufferBase::from_scalar_buffer(buffer),
1187 offset,
1188 };
1189 Ok(())
1190 }
1191 pub fn cast_into_tensor<T: Scalar>(self) -> Result<Tensor<T, D>> {
1195 Ok(self.cast_into(T::SCALAR_TYPE)?.try_into().unwrap())
1196 }
1197}
1198
/// Intermediate serialization format: only the dimensions and the buffer are
/// stored; tensors are serialized in standard layout so strides can be
/// recomputed on deserialization.
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "S: ScalarData, D: Dimension + Serialize",
    deserialize = "S: ScalarDataOwned, D: Dimension + Deserialize<'de>"
))]
#[serde(rename = "Tensor")]
struct ScalarTensorSerde<S: ScalarData, D: Dimension> {
    dim: D,
    buffer: ScalarBufferBase<S>,
}

#[cfg(feature = "serde")]
impl<S1: ScalarData, D: Dimension + Serialize> Serialize for ScalarTensorBase<S1, D> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;
        // Borrow the buffer when the tensor is already in standard layout;
        // otherwise make an owned copy on the host first.
        let buffer = if let Some(slice) = self.as_scalar_slice() {
            ScalarCowBuffer::from(slice)
        } else {
            self.to_device(Device::host())
                .map_err(S::Error::custom)?
                .buffer
                .into()
        };
        ScalarTensorSerde {
            dim: self.dim.clone(),
            buffer,
        }
        .serialize(serializer)
    }
}

#[cfg(feature = "serde")]
impl<'de, S: ScalarDataOwned, D1: Dimension + Deserialize<'de>> Deserialize<'de>
    for ScalarTensorBase<S, D1>
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;
        // Deserialize into an owned buffer, then reshape to the stored dim
        // (strides are reconstructed as standard layout by into_shape).
        let ScalarTensorSerde { dim, buffer } =
            ScalarTensorSerde::<ScalarBufferRepr, D1>::deserialize(deserializer)?;
        ScalarTensorBase::from(buffer)
            .into_shape(dim)
            .map_err(D::Error::custom)
    }
}
1250
/// Statically typed multi-dimensional tensor.
///
/// Generic over the buffer representation `S` (owned / shared / borrowed /
/// copy-on-write) and the dimensionality `D`.
#[derive(Clone)]
pub struct TensorBase<S: Data, D: Dimension> {
    // Length of each axis.
    dim: D,
    // Stride of each axis, bit-cast from isize to usize (see
    // `strides_from_array`).
    strides: D,
    // Underlying storage; may span more elements than `dim.size()`.
    buffer: BufferBase<S>,
    // Element offset of the first element within `buffer`.
    offset: usize,
}
1262
/// Owned tensor with element type `T`.
pub type Tensor<T, D> = TensorBase<BufferRepr<T>, D>;
/// [`Tensor`] with 0 dimensions.
pub type Tensor0<T> = Tensor<T, Ix0>;
/// [`Tensor`] with 1 dimension.
pub type Tensor1<T> = Tensor<T, Ix1>;
/// [`Tensor`] with 2 dimensions.
pub type Tensor2<T> = Tensor<T, Ix2>;
/// [`Tensor`] with 3 dimensions.
pub type Tensor3<T> = Tensor<T, Ix3>;
/// [`Tensor`] with 4 dimensions.
pub type Tensor4<T> = Tensor<T, Ix4>;
/// [`Tensor`] with 5 dimensions.
pub type Tensor5<T> = Tensor<T, Ix5>;
/// [`Tensor`] with 6 dimensions.
pub type Tensor6<T> = Tensor<T, Ix6>;
/// [`Tensor`] with dynamic dimensionality.
pub type TensorD<T> = Tensor<T, IxDyn>;

/// Shared (reference-counted) tensor.
pub type ArcTensor<T, D> = TensorBase<ArcBufferRepr<T>, D>;
/// [`ArcTensor`] with 0 dimensions.
pub type ArcTensor0<T> = ArcTensor<T, Ix0>;
/// [`ArcTensor`] with 1 dimension.
pub type ArcTensor1<T> = ArcTensor<T, Ix1>;
/// [`ArcTensor`] with 2 dimensions.
pub type ArcTensor2<T> = ArcTensor<T, Ix2>;
/// [`ArcTensor`] with 3 dimensions.
pub type ArcTensor3<T> = ArcTensor<T, Ix3>;
/// [`ArcTensor`] with 4 dimensions.
pub type ArcTensor4<T> = ArcTensor<T, Ix4>;
/// [`ArcTensor`] with 5 dimensions.
pub type ArcTensor5<T> = ArcTensor<T, Ix5>;
/// [`ArcTensor`] with 6 dimensions.
pub type ArcTensor6<T> = ArcTensor<T, Ix6>;
/// [`ArcTensor`] with dynamic dimensionality.
pub type ArcTensorD<T> = ArcTensor<T, IxDyn>;

/// Borrowed tensor view.
pub type TensorView<'a, T, D> = TensorBase<SliceRepr<'a, T>, D>;
/// [`TensorView`] with 0 dimensions.
pub type TensorView0<'a, T> = TensorView<'a, T, Ix0>;
/// [`TensorView`] with 1 dimension.
pub type TensorView1<'a, T> = TensorView<'a, T, Ix1>;
/// [`TensorView`] with 2 dimensions.
pub type TensorView2<'a, T> = TensorView<'a, T, Ix2>;
/// [`TensorView`] with 3 dimensions.
pub type TensorView3<'a, T> = TensorView<'a, T, Ix3>;
/// [`TensorView`] with 4 dimensions.
pub type TensorView4<'a, T> = TensorView<'a, T, Ix4>;
/// [`TensorView`] with 5 dimensions.
pub type TensorView5<'a, T> = TensorView<'a, T, Ix5>;
/// [`TensorView`] with 6 dimensions.
pub type TensorView6<'a, T> = TensorView<'a, T, Ix6>;
/// [`TensorView`] with dynamic dimensionality.
pub type TensorViewD<'a, T> = TensorView<'a, T, IxDyn>;

/// Mutably borrowed tensor view.
pub type TensorViewMut<'a, T, D> = TensorBase<SliceMutRepr<'a, T>, D>;
/// [`TensorViewMut`] with 0 dimensions.
pub type TensorViewMut0<'a, T> = TensorViewMut<'a, T, Ix0>;
/// [`TensorViewMut`] with 1 dimension.
pub type TensorViewMut1<'a, T> = TensorViewMut<'a, T, Ix1>;
/// [`TensorViewMut`] with 2 dimensions.
pub type TensorViewMut2<'a, T> = TensorViewMut<'a, T, Ix2>;
/// [`TensorViewMut`] with 3 dimensions.
pub type TensorViewMut3<'a, T> = TensorViewMut<'a, T, Ix3>;
/// [`TensorViewMut`] with 4 dimensions.
pub type TensorViewMut4<'a, T> = TensorViewMut<'a, T, Ix4>;
/// [`TensorViewMut`] with 5 dimensions.
pub type TensorViewMut5<'a, T> = TensorViewMut<'a, T, Ix5>;
/// [`TensorViewMut`] with 6 dimensions.
pub type TensorViewMut6<'a, T> = TensorViewMut<'a, T, Ix6>;
/// [`TensorViewMut`] with dynamic dimensionality.
pub type TensorViewMutD<'a, T> = TensorViewMut<'a, T, IxDyn>;

/// Tensor with copy-on-write (borrowed or owned) storage.
pub type CowTensor<'a, T, D> = TensorBase<CowBufferRepr<'a, T>, D>;
/// [`CowTensor`] with 0 dimensions.
pub type CowTensor0<'a, T> = CowTensor<'a, T, Ix0>;
/// [`CowTensor`] with 1 dimension.
pub type CowTensor1<'a, T> = CowTensor<'a, T, Ix1>;
/// [`CowTensor`] with 2 dimensions.
pub type CowTensor2<'a, T> = CowTensor<'a, T, Ix2>;
/// [`CowTensor`] with 3 dimensions.
pub type CowTensor3<'a, T> = CowTensor<'a, T, Ix3>;
/// [`CowTensor`] with 4 dimensions.
pub type CowTensor4<'a, T> = CowTensor<'a, T, Ix4>;
/// [`CowTensor`] with 5 dimensions.
pub type CowTensor5<'a, T> = CowTensor<'a, T, Ix5>;
/// [`CowTensor`] with 6 dimensions.
pub type CowTensor6<'a, T> = CowTensor<'a, T, Ix6>;
/// [`CowTensor`] with dynamic dimensionality.
pub type CowTensorD<'a, T> = CowTensor<'a, T, IxDyn>;
1367
impl<T: Scalar, S: DataOwned<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// Allocates a tensor on `device` with `shape`, leaving the elements
    /// uninitialized.
    ///
    /// # Safety
    /// The elements must be written before they are read.
    pub unsafe fn uninit<Sh>(device: Device, shape: Sh) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
        // SAFETY: the caller upholds the uninitialized-read contract.
        let buffer = unsafe { BufferBase::uninit(device, dim.size())? };
        Ok(Self {
            dim,
            strides,
            buffer,
            offset: 0,
        })
    }
    /// Creates a tensor on `device` with `shape`, filled with `elem`.
    pub fn from_elem<Sh>(device: Device, shape: Sh, elem: T) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
        let buffer = BufferBase::from_elem(device, dim.size(), elem)?;
        Ok(Self {
            dim,
            strides,
            buffer,
            offset: 0,
        })
    }
    /// Creates a tensor on `device` with `shape`, filled with `T::default()`.
    pub fn zeros<Sh>(device: Device, shape: Sh) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        Self::from_elem(device, shape, T::default())
    }
    /// Creates a tensor on `device` with `shape`, filled with ones.
    pub fn ones<Sh>(device: Device, shape: Sh) -> Result<Self>
    where
        Sh: ndarray::ShapeBuilder<Dim = D>,
    {
        Self::from_elem(device, shape, T::one())
    }
}
1428
impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// The device the tensor's buffer is stored on.
    pub fn device(&self) -> Device {
        self.buffer.device()
    }
    /// The scalar type of the elements.
    pub fn scalar_type(&self) -> ScalarType {
        T::SCALAR_TYPE
    }
    /// The dimensions as a pattern (eg a tuple of `usize` for fixed dims).
    pub fn dim(&self) -> D::Pattern {
        self.dim.clone().into_pattern()
    }
    /// The dimensions.
    pub fn raw_dim(&self) -> D {
        self.dim.clone()
    }
    /// The dimensions as a slice.
    pub fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
    /// The strides as a slice.
    pub fn strides(&self) -> &[isize] {
        // Strides are stored as `usize` internally; reinterpret as `isize`
        // for the ndarray-style signed-stride interface.
        bytemuck::cast_slice(self.strides.slice())
    }
    /// The total number of elements (product of the dimensions).
    pub fn len(&self) -> usize {
        self.dim.size()
    }
    /// Whether the tensor has no elements, ie any axis has length 0.
    pub fn is_empty(&self) -> bool {
        self.shape().iter().any(|x| *x == 0)
    }
    /// The number of dimensions.
    pub fn ndim(&self) -> usize {
        self.dim.ndim()
    }
    /// Converts into dimensionality `D2`.
    ///
    /// # Errors
    /// Fails if the dimensions are not compatible with `D2`.
    pub fn into_dimensionality<D2>(self) -> Result<TensorBase<S, D2>, ShapeError>
    where
        D2: Dimension,
    {
        let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
        Ok(TensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Converts to a tensor with dynamic dimensionality.
    pub fn into_dyn(self) -> TensorBase<S, IxDyn> {
        TensorBase {
            dim: self.dim.into_dyn(),
            strides: self.strides.into_dyn(),
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Reshapes into `shape` without copying.
    ///
    /// # Errors
    /// Fails if the number of elements or the layout is incompatible with
    /// `shape`.
    pub fn into_shape<E>(self, shape: E) -> Result<TensorBase<S, E::Dim>, ShapeError>
    where
        E: IntoDimension,
    {
        let shape = shape.into_dimension();
        let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
        // NOTE(review): reshaping assumes the tensor starts at the beginning
        // of its buffer; only checked in debug builds — confirm callers
        // never reshape a view with a nonzero offset.
        debug_assert_eq!(self.offset, 0);
        Ok(TensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Flattens into 2 dimensions via [`flatten`] of the shape.
    ///
    /// # Errors
    /// Fails if the tensor cannot be reshaped (see [`into_shape`](Self::into_shape)).
    pub fn flatten(self) -> Result<TensorBase<S, Ix2>, ShapeError> {
        let dim = flatten(self.shape());
        self.into_shape(dim)
    }
    /// Broadcasts a view with shape `dim`, or returns `None` if the shapes
    /// are not compatible.
    pub fn broadcast<E>(&self, dim: E) -> Option<TensorView<T, E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
        Some(TensorView {
            dim,
            strides,
            buffer: self.buffer.as_slice(),
            offset: self.offset,
        })
    }
    /// Borrows the tensor as a [`TensorView`].
    pub fn view(&self) -> TensorView<T, D> {
        TensorView {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.as_slice(),
            offset: self.offset,
        }
    }
    /// Mutably borrows the tensor as a [`TensorViewMut`].
    pub fn view_mut(&mut self) -> TensorViewMut<T, D>
    where
        S: DataMut,
    {
        TensorViewMut {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.as_slice_mut(),
            offset: self.offset,
        }
    }
    /// Tries to mutably borrow without copying; returns `None` if the tensor
    /// is offset, not contiguous, or the buffer is shared.
    pub fn get_view_mut(&mut self) -> Option<TensorViewMut<T, D>> {
        if self.offset == 0 && self.is_contiguous() {
            // `get_slice_mut` returns `None` when exclusive access to the
            // buffer cannot be obtained (eg a shared `Arc` buffer).
            let buffer = self.buffer.get_slice_mut()?;
            Some(TensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer,
                offset: 0,
            })
        } else {
            None
        }
    }
    /// Mutably borrows, copying the data into a fresh owned buffer if
    /// necessary (shared, offset, or non-contiguous tensors).
    ///
    /// # Errors
    /// Fails if the copy fails.
    pub fn make_view_mut(&mut self) -> Result<TensorViewMut<T, D>>
    where
        S: DataOwned,
    {
        if self.offset == 0 && self.is_contiguous() {
            Ok(TensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.make_slice_mut()?,
                offset: 0,
            })
        } else {
            // Copy into an owned, standard-layout tensor, replace self with
            // it, then hand out a mutable view of the fresh buffer.
            let tensor = self.to_owned()?;
            *self = Self {
                dim: tensor.dim,
                strides: tensor.strides,
                buffer: BufferBase::from_buffer(tensor.buffer),
                offset: 0,
            };
            // The buffer was just created and is uniquely owned, so this
            // cannot fail.
            Ok(TensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.get_slice_mut().unwrap(),
                offset: 0,
            })
        }
    }
    /// Whether the elements occupy a single contiguous region of the buffer
    /// (in some memory order).
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.dim, &self.strides)
    }
    /// Whether the layout is standard (C / row-major) order.
    pub fn is_standard_layout(&self) -> bool {
        is_standard_layout(&self.dim, &self.strides)
    }
    /// Permutes the axes per `axes` without moving data.
    pub fn permuted_axes<A>(self, axes: A) -> Self
    where
        A: IntoDimension<Dim = D>,
    {
        let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
        Self {
            dim,
            strides,
            ..self
        }
    }
    /// Reverses the order of the axes (full transpose) without moving data.
    pub fn reversed_axes(mut self) -> Self {
        self.dim.slice_mut().reverse();
        self.strides.slice_mut().reverse();
        self
    }
    /// A transposed view (see [`reversed_axes`](Self::reversed_axes)).
    pub fn t(&self) -> TensorView<T, D> {
        self.view().reversed_axes()
    }
    /// A view of the subtensor at `index` along `axis`, with that axis
    /// removed.
    pub fn index_axis(&self, axis: Axis, index: usize) -> TensorView<T, D::Smaller>
    where
        D: RemoveAxis,
    {
        self.view().index_axis_into(axis, index)
    }
    /// A mutable view of the subtensor at `index` along `axis`, with that
    /// axis removed.
    pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> TensorViewMut<T, D::Smaller>
    where
        S: DataMut,
        D: RemoveAxis,
    {
        self.view_mut().index_axis_into(axis, index)
    }
    /// Consumes the tensor, selecting `index` along `axis` and removing that
    /// axis.
    pub fn index_axis_into(mut self, axis: Axis, index: usize) -> TensorBase<S, D::Smaller>
    where
        D: RemoveAxis,
    {
        self.collapse_axis(axis, index);
        let dim = self.dim.remove_axis(axis);
        let strides = self.strides.remove_axis(axis);
        TensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Collapses `axis` to length 1 at `index`, adjusting the tensor's
    /// buffer offset accordingly.
    pub fn collapse_axis(&mut self, axis: Axis, index: usize) {
        // `collapse_axis` returns the (signed) element offset introduced by
        // selecting `index`; accumulate it into the existing offset.
        let offset =
            collapse_axis(&mut self.dim, &self.strides, axis, index) + self.offset as isize;
        debug_assert!(offset >= 0);
        let offset = offset as usize;
        debug_assert!(offset < self.buffer.len());
        self.offset = offset;
    }
    /// The elements as a [`Slice`] if the layout is standard, else `None`.
    pub fn as_slice(&self) -> Option<Slice<T>> {
        if self.is_standard_layout() {
            let (slice, _offset) = self.as_raw_slice_offset();
            Some(slice)
        } else {
            None
        }
    }
    /// The elements in memory order as a [`Slice`] if contiguous, else `None`.
    pub fn as_slice_memory_order(&self) -> Option<Slice<T>> {
        if self.is_contiguous() {
            let (slice, _offset) = self.as_raw_slice_offset();
            Some(slice)
        } else {
            None
        }
    }
    /// The elements as a [`SliceMut`] if the layout is standard, else `None`.
    pub fn as_slice_mut(&mut self) -> Option<SliceMut<T>>
    where
        S: DataMut,
    {
        if self.is_standard_layout() {
            let (slice, _offset) = self.as_raw_slice_offset_mut();
            Some(slice)
        } else {
            None
        }
    }
    /// The elements in memory order as a [`SliceMut`] if contiguous, else
    /// `None`.
    pub fn as_slice_memory_order_mut(&mut self) -> Option<SliceMut<T>>
    where
        S: DataMut,
    {
        if self.is_contiguous() {
            let (slice, _offset) = self.as_raw_slice_offset_mut();
            Some(slice)
        } else {
            None
        }
    }
    /// The underlying buffer region and the offset of the first element
    /// within it.
    ///
    /// When the occupied span can be computed (`tensor_buffer_len`), returns
    /// a slice trimmed to exactly that span with offset 0; otherwise returns
    /// the whole buffer together with the tensor's offset.
    pub fn as_raw_slice_offset(&self) -> (Slice<T>, usize) {
        let strides: &[isize] = Self::strides(self);
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            // The span is in-bounds by construction, so slicing cannot fail.
            let slice = self.buffer.slice(self.offset..self.offset + len).unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_slice(), self.offset)
        }
    }
    /// Mutable variant of [`as_raw_slice_offset`](Self::as_raw_slice_offset).
    pub fn as_raw_slice_offset_mut(&mut self) -> (SliceMut<T>, usize)
    where
        S: DataMut,
    {
        let strides: &[isize] = Self::strides(self);
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self
                .buffer
                .slice_mut(self.offset..self.offset + len)
                .unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_slice_mut(), self.offset)
        }
    }
    /// Copies the tensor to `device` (copies even if already on `device`).
    ///
    /// # Errors
    /// Fails if the transfer or allocation fails.
    pub fn to_device(&self, device: Device) -> Result<Tensor<T, D>> {
        if self.device() == device {
            self.to_owned()
        } else {
            self.view().into_device(device)
        }
    }
    /// Copies the tensor to `device` as an [`ArcTensor`].
    ///
    /// # Errors
    /// Fails if the transfer or allocation fails.
    pub fn to_device_shared(&self, device: Device) -> Result<ArcTensor<T, D>> {
        if self.device() == device {
            self.to_shared()
        } else {
            self.to_device(device).map(Into::into)
        }
    }
    /// Moves the tensor to `device` in place; a no-op when already there.
    ///
    /// # Errors
    /// Fails if the transfer fails.
    pub fn to_device_mut(&mut self, device: Device) -> Result<()>
    where
        S: DataOwned,
    {
        if self.device() == device {
            return Ok(());
        }
        let Tensor {
            dim,
            strides,
            buffer,
            offset,
        } = self.to_device(device)?;
        *self = Self {
            dim,
            strides,
            buffer: BufferBase::from_buffer(buffer),
            offset,
        };
        Ok(())
    }
    /// Moves the tensor into `device`, avoiding a copy when already there
    /// and owned.
    ///
    /// # Errors
    /// Fails if the transfer or allocation fails.
    pub fn into_device(self, device: Device) -> Result<Tensor<T, D>> {
        if device == self.device() {
            self.into_owned()
        } else if !self.is_contiguous() {
            // Device transfer copies the raw buffer, so make the layout
            // standard first.
            self.as_standard_layout()?.to_device(device)
        } else {
            let buffer = self.buffer.to_device(device)?;
            Ok(Tensor {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            })
        }
    }
    /// Moves the tensor into `device` as an [`ArcTensor`].
    ///
    /// # Errors
    /// Fails if the transfer or allocation fails.
    pub fn into_device_shared(self, device: Device) -> Result<ArcTensor<T, D>> {
        if device == self.device() {
            self.into_shared()
        } else if !self.is_contiguous() {
            self.view()
                .into_standard_layout()?
                .into_device_shared(device)
        } else {
            let buffer = self.buffer.to_device_shared(device)?;
            Ok(ArcTensor {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            })
        }
    }
    /// Converts into an owned [`Tensor`], copying if borrowed or
    /// non-contiguous.
    ///
    /// # Errors
    /// Fails if the copy fails.
    pub fn into_owned(self) -> Result<Tensor<T, D>> {
        if !self.is_contiguous() {
            return self.into_standard_layout();
        }
        Ok(TensorBase {
            dim: self.dim,
            strides: self.strides,
            buffer: self.buffer.into_owned()?,
            offset: 0,
        })
    }
    /// Copies into a new owned [`Tensor`].
    ///
    /// # Errors
    /// Fails if the copy fails.
    pub fn to_owned(&self) -> Result<Tensor<T, D>> {
        self.view().into_owned()
    }
    /// Converts into an [`ArcTensor`], copying into standard layout if not
    /// contiguous.
    ///
    /// # Errors
    /// Fails if a copy is needed and fails.
    pub fn into_shared(self) -> Result<ArcTensor<T, D>> {
        if self.is_contiguous() {
            Ok(TensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.into_shared()?,
                offset: self.offset,
            })
        } else {
            self.as_standard_layout()?.into_shared()
        }
    }
    /// Copies (or clones the shared buffer) into an [`ArcTensor`].
    ///
    /// # Errors
    /// Fails if a copy is needed and fails.
    pub fn to_shared(&self) -> Result<ArcTensor<T, D>> {
        if self.is_contiguous() {
            Ok(TensorBase {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.to_shared()?,
                offset: self.offset,
            })
        } else {
            self.to_owned()?.into_shared()
        }
    }
    /// Fills the tensor with `elem`.
    ///
    /// # Errors
    /// Fails if the tensor is neither contiguous nor on the host.
    pub fn fill(&mut self, elem: T) -> Result<()>
    where
        S: DataMut,
    {
        if self.is_contiguous() {
            self.buffer.as_slice_mut().fill(elem)
        } else if let Some(mut array) = self.as_array_mut() {
            // Non-contiguous but on the host: fill element-wise via ndarray.
            array.fill(elem);
            Ok(())
        } else {
            bail!("TensorBase::fill tensor is not contiguous!")
        }
    }
    /// Converts into an [`Array`].
    ///
    /// # Errors
    /// Fails if the tensor is neither contiguous nor on the host, or if the
    /// host transfer fails.
    pub fn into_array(self) -> Result<Array<T, D>> {
        if self.is_contiguous() {
            use ndarray::ShapeBuilder;

            let vec = self.buffer.into_vec()?;
            // Shape and strides were validated at construction, so this
            // cannot fail.
            Ok(Array::from_shape_vec(self.dim.strides(self.strides), vec).unwrap())
        } else if let Some(array) = self.as_array() {
            Ok(array.into_owned())
        } else {
            bail!("TensorBase::into_array tensor is not contiguous!")
        }
    }
    /// Borrows as an [`ArrayView`] if the tensor is on the host, else `None`.
    pub fn as_array(&self) -> Option<ArrayView<T, D>> {
        use ndarray::ShapeBuilder;

        self.buffer.as_host_slice().map(|host_slice| unsafe {
            // SAFETY: dim/strides/offset describe a region within
            // `host_slice`, which lives as long as `self`.
            // NOTE(review): `host_slice[self.offset]` indexes the slice;
            // assumes the buffer is non-empty — confirm behavior for
            // zero-length tensors.
            ArrayView::from_shape_ptr(
                self.dim.clone().strides(self.strides.clone()),
                &host_slice[self.offset] as *const T,
            )
        })
    }
    /// Mutably borrows as an [`ArrayViewMut`] if the tensor is on the host,
    /// else `None`.
    pub fn as_array_mut(&mut self) -> Option<ArrayViewMut<T, D>>
    where
        S: DataMut,
    {
        use ndarray::ShapeBuilder;

        if let Some(host_slice) = self.buffer.as_host_slice_mut() {
            // SAFETY: reborrow the host slice through raw parts to decouple
            // its lifetime from the local borrow; it is valid for as long as
            // `self` is mutably borrowed.
            let host_slice = unsafe {
                std::slice::from_raw_parts_mut(host_slice.as_mut_ptr(), host_slice.len())
            };
            Some(unsafe {
                // SAFETY: dim/strides/offset describe a region within
                // `host_slice`.
                ArrayViewMut::from_shape_ptr(
                    self.dim.clone().strides(self.strides.clone()),
                    host_slice[self.offset..].as_mut_ptr(),
                )
            })
        } else {
            None
        }
    }
}
1968
1969impl<T: Scalar, D: Dimension> Tensor<T, D> {
1970 pub fn into_scalar_tensor(self) -> ScalarTensor<D> {
1972 self.into()
1973 }
1974}
1975
1976impl<'a, T: Scalar, D: Dimension> CowTensor<'a, T, D> {
1977 pub fn into_scalar_cow_tensor(self) -> ScalarCowTensor<'a, D> {
1979 self.into()
1980 }
1981}
1982
1983impl<T: Scalar, D: Dimension> ArcTensor<T, D> {
1984 pub fn broadcast_shared<E>(&self, dim: E) -> Option<ArcTensor<T, E::Dim>>
1988 where
1989 E: IntoDimension,
1990 {
1991 let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
1992 Some(ArcTensor {
1993 dim,
1994 strides,
1995 buffer: self.buffer.clone(),
1996 offset: self.offset,
1997 })
1998 }
1999}
2000
2001impl<T: Scalar, S: DataOwned<Elem = T>> From<Buffer<T>> for TensorBase<S, Ix1> {
2002 fn from(buffer: Buffer<T>) -> Self {
2003 let dim = buffer.len().into_dimension();
2004 let strides = dim.default_strides();
2005 let buffer = BufferBase::from_buffer(buffer);
2006 Self {
2007 dim,
2008 strides,
2009 buffer,
2010 offset: 0,
2011 }
2012 }
2013}
2014
2015impl<T: Scalar, S: DataOwned<Elem = T>> From<Vec<T>> for TensorBase<S, Ix1> {
2016 fn from(vec: Vec<T>) -> Self {
2017 let dim = vec.len().into_dimension();
2018 let strides = dim.default_strides();
2019 let buffer = BufferBase::from_buffer(Buffer::from(vec));
2020 Self {
2021 dim,
2022 strides,
2023 buffer,
2024 offset: 0,
2025 }
2026 }
2027}
2028
2029impl<'a, T: Scalar> From<Slice<'a, T>> for TensorView<'a, T, Ix1> {
2030 fn from(slice: Slice<'a, T>) -> Self {
2031 let dim = slice.len().into_dimension();
2032 let strides = dim.default_strides();
2033 Self {
2034 dim,
2035 strides,
2036 buffer: slice,
2037 offset: 0,
2038 }
2039 }
2040}
2041
2042impl<'a, T: Scalar> From<SliceMut<'a, T>> for TensorViewMut<'a, T, Ix1> {
2043 fn from(slice: SliceMut<'a, T>) -> Self {
2044 let dim = slice.len().into_dimension();
2045 let strides = dim.default_strides();
2046 Self {
2047 dim,
2048 strides,
2049 buffer: slice,
2050 offset: 0,
2051 }
2052 }
2053}
2054
2055impl<T: Scalar, S: DataOwned<Elem = T>, D: Dimension> From<Array<T, D>> for TensorBase<S, D> {
2056 fn from(array: Array<T, D>) -> Self {
2057 let dim = array.raw_dim();
2058 let strides = strides_from_array(&array);
2059 let buffer = BufferBase::from_vec(array.into_raw_vec());
2060 Self {
2061 dim,
2062 strides,
2063 buffer,
2064 offset: 0,
2065 }
2066 }
2067}
2068
2069impl<'a, T: Scalar, D: Dimension> From<ArrayView<'a, T, D>> for CowTensor<'a, T, D> {
2070 fn from(array: ArrayView<'a, T, D>) -> Self {
2071 if let Some(slice) = array.to_slice_memory_order() {
2072 let dim = array.raw_dim();
2073 let strides = strides_from_array(&array);
2074 let buffer = Slice::from(slice).into();
2075 Self {
2076 dim,
2077 strides,
2078 buffer,
2079 offset: 0,
2080 }
2081 } else {
2082 Self::from(array.to_owned())
2083 }
2084 }
2085}
2086
impl<'a, T: Scalar, D: Dimension> TryFrom<ArrayView<'a, T, D>> for TensorView<'a, T, D> {
    type Error = anyhow::Error;
    /// Views an ndarray view as a tensor view without copying.
    ///
    /// # Errors
    /// Fails if the view is not contiguous in memory order.
    fn try_from(array: ArrayView<'a, T, D>) -> Result<Self> {
        let slice = array
            .as_slice_memory_order()
            .ok_or_else(|| anyhow!("Not contiguous!"))?;
        // SAFETY: `as_slice_memory_order` borrows `array` locally, but the
        // underlying data is valid for 'a (it belongs to the `ArrayView<'a>`);
        // rebuild the slice to carry that longer lifetime.
        let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
        let dim = array.raw_dim();
        let strides = strides_from_array(&array);
        Ok(Self {
            dim,
            strides,
            buffer: slice.into(),
            offset: 0,
        })
    }
}
2107
2108impl<'a, T: Scalar, D: Dimension> From<TensorView<'a, T, D>> for CowTensor<'a, T, D> {
2109 fn from(view: TensorView<'a, T, D>) -> Self {
2110 Self {
2111 dim: view.dim,
2112 strides: view.strides,
2113 buffer: view.buffer.into(),
2114 offset: view.offset,
2115 }
2116 }
2117}
2118
2119impl<T: Scalar, D: Dimension> From<Tensor<T, D>> for CowTensor<'_, T, D> {
2120 fn from(tensor: Tensor<T, D>) -> Self {
2121 Self {
2122 dim: tensor.dim,
2123 strides: tensor.strides,
2124 buffer: tensor.buffer.into(),
2125 offset: tensor.offset,
2126 }
2127 }
2128}
2129
2130impl<T: Scalar, D: Dimension> From<Tensor<T, D>> for ArcTensor<T, D> {
2131 fn from(tensor: Tensor<T, D>) -> Self {
2132 Self {
2133 dim: tensor.dim,
2134 strides: tensor.strides,
2135 buffer: tensor.buffer.into(),
2136 offset: tensor.offset,
2137 }
2138 }
2139}
2140
2141impl<S: Data, D: Dimension> Debug for TensorBase<S, D> {
2142 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2143 ScalarTensorView::from(self.view()).fmt(f)
2144 }
2145}
2146
impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// Casts into a tensor of `Y`, reusing the buffer when the element type
    /// already matches and the tensor is contiguous.
    ///
    /// # Errors
    /// Fails if the cast or required copy fails.
    pub fn cast_into<Y: Scalar>(self) -> Result<Tensor<Y, D>> {
        if T::SCALAR_TYPE == Y::SCALAR_TYPE && self.is_contiguous() {
            // NOTE(review): resets offset to 0 while casting the whole
            // buffer; assumes contiguous tensors have offset 0 — confirm.
            Ok(TensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.cast_into()?,
                offset: 0,
            })
        } else {
            self.cast()
        }
    }
    /// Casts into a new tensor of `Y`.
    ///
    /// # Errors
    /// Fails if the cast fails.
    pub fn cast<Y: Scalar>(&self) -> Result<Tensor<Y, D>> {
        if !self.is_contiguous() {
            // Non-contiguous tensors are handled by the strided
            // scaled-cast path with a scale of one.
            return self.scaled_cast(Y::one());
        }
        Ok(TensorBase {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.cast()?,
            offset: 0,
        })
    }
}
2179
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "S: Data, D: Dimension + Serialize",
    deserialize = "S: DataOwned, D: Dimension + Deserialize<'de>"
))]
#[serde(rename = "Tensor")]
// Wire format for (de)serializing tensors: only the dimensions and the
// buffer are stored; strides are not serialized, so tensors round-trip in
// standard layout.
struct TensorSerde<S: Data, D: Dimension> {
    dim: D,
    buffer: BufferBase<S>,
}
2191
#[cfg(feature = "serde")]
impl<S1: Data, D: Dimension + Serialize> Serialize for TensorBase<S1, D> {
    /// Serializes as [`TensorSerde`] (dimensions + buffer), transferring to
    /// the host first when the data is not directly accessible as a slice.
    ///
    /// NOTE(review): strides are not serialized. A contiguous tensor in a
    /// non-standard memory order (eg transposed) would serialize its buffer
    /// in memory order yet deserialize as standard layout — confirm
    /// round-trip behavior for such tensors.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;
        let buffer = if let Some(slice) = self.as_slice() {
            // Standard layout on the host: borrow without copying.
            CowBuffer::from(slice)
        } else {
            // Device tensor (or not standard layout): copy to the host.
            self.to_device(Device::host())
                .map_err(S::Error::custom)?
                .buffer
                .into()
        };
        TensorSerde {
            dim: self.dim.clone(),
            buffer,
        }
        .serialize(serializer)
    }
}
2214
#[cfg(feature = "serde")]
impl<'de, S: DataOwned, D1: Dimension + Deserialize<'de>> Deserialize<'de> for TensorBase<S, D1> {
    /// Deserializes the [`TensorSerde`] wire format: a flat buffer plus
    /// dimensions, reshaped into a standard-layout tensor.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;
        let TensorSerde { dim, buffer } =
            TensorSerde::<BufferRepr<S::Elem>, D1>::deserialize(deserializer)?;
        // Build a 1d tensor from the buffer, then reshape to the stored
        // dimensions; fails if the element count does not match.
        TensorBase::from(buffer)
            .into_shape(dim)
            .map_err(D::Error::custom)
    }
}
2229
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;
    use serde_test::{assert_tokens, Token};

    /// Pins the serde wire format of a 1d `u32` tensor: a "Tensor" struct
    /// with its dim tuple and a "Buffer" tuple-struct payload.
    #[test]
    fn tensor_serde() {
        let data = vec![1u32, 2, 3, 4];
        // The buffer serializes its bytes as u64 words; reinterpret the
        // four u32s as two u64s to build the expected tokens.
        let items: Vec<u64> = bytemuck::cast_slice(data.as_slice()).to_vec();
        let tensor = Tensor::from(Buffer::from(data));
        let tokens = [
            Token::Struct {
                name: "Tensor",
                len: 2,
            },
            Token::Str("dim"),
            Token::Tuple { len: 1 },
            Token::U64(4),
            Token::TupleEnd,
            Token::Str("buffer"),
            Token::TupleStruct {
                name: "Buffer",
                len: 3,
            },
            // Buffer payload: scalar type tag, element count, then the data
            // as a sequence of u64 words.
            Token::Str("U32"),
            Token::U64(4),
            Token::Seq { len: Some(2) },
            // `to_be()` byte-swaps on little-endian hosts; presumably the
            // buffer stores its words big-endian on the wire — the tokens
            // must match that representation.
            Token::U64(items[0].to_be()),
            Token::U64(items[1].to_be()),
            Token::SeqEnd,
            Token::TupleStructEnd,
            Token::StructEnd,
        ];

        // Wrapper providing the element-wise equality `assert_tokens`
        // requires; tensors themselves do not implement `PartialEq`.
        #[derive(Debug, Serialize, Deserialize)]
        #[serde(transparent)]
        struct TensorWrap(Tensor1<u32>);

        impl PartialEq for TensorWrap {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_array().unwrap() == other.0.as_array().unwrap()
            }
        }

        impl Eq for TensorWrap {}

        // Checks both serialization and deserialization against `tokens`.
        assert_tokens(&TensorWrap(tensor), &tokens);
    }
}