autograph/
tensor.rs

1/*!
2```
3# use krnl::device::Device;
4use autograph::tensor::{Tensor, TensorView};
5use anyhow::Result;
6use ndarray::{Array, arr2, linalg::Dot};
7
8# fn main() -> Result<()> {
9# let device = Device::host();
10// Create a tensor from an array
11let a = Tensor::from(arr2(&[
12    [1f32, 2.],
13    [3., 4.],
14]))
15    // Moves to device, no copy if device is Device::host().
16    .into_device(device.clone())?;
17// Alternatively, create a tensor from an array view.
18let a = arr2(&[
19    [1f32, 2.],
20    [3., 4.],
21]);
22// This will fail if the view is not contiguous.
23let a = TensorView::try_from(a.view()).unwrap()
24    .to_device(device.clone())?;
25// Create a tensor from a vec, same as above.
26let b = Tensor::from(vec![5f32, 6., 7., 8.]).into_shape([2, 2]).unwrap()
27    .into_device(device.clone())?;
28// Compute a dot product (matrix multiplication) and move back to host.
29let c = a.dot(&b)?.into_device(Device::host())?;
30// Borrow as an array view.
31let c_view = c.as_array().unwrap();
32// Move into an array
33let c = c.into_array()?;
34
35# Ok(())
36# }
37```
38*/
39use anyhow::{anyhow, bail, Result};
40use dry::macro_for;
41#[cfg(feature = "serde")]
42use krnl::buffer::{CowBuffer, ScalarCowBuffer};
43#[cfg(feature = "device")]
44use krnl::krnl_core::half::bf16;
45#[cfg(doc)]
46use krnl::{buffer::ArcBuffer, device::error::DeviceLost};
47use krnl::{
48    buffer::{
49        ArcBufferRepr, Buffer, BufferBase, BufferRepr, CowBufferRepr, Data, DataMut, DataOwned,
50        ScalarArcBufferRepr, ScalarBuffer, ScalarBufferBase, ScalarBufferRepr, ScalarCowBufferRepr,
51        ScalarData, ScalarDataMut, ScalarDataOwned, ScalarSlice, ScalarSliceMut,
52        ScalarSliceMutRepr, ScalarSliceRepr, Slice, SliceMut, SliceMutRepr, SliceRepr,
53    },
54    device::Device,
55    scalar::{Scalar, ScalarElem, ScalarType},
56};
57use ndarray::{
58    Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Dimension, IntoDimension, Ix0, Ix1, Ix2, Ix3,
59    Ix4, Ix5, Ix6, IxDyn, RawArrayView, RemoveAxis, ShapeError, StrideShape,
60};
61#[cfg(feature = "device")]
62use num_traits::ToPrimitive;
63use paste::paste;
64#[cfg(feature = "serde")]
65use serde::{Deserialize, Deserializer, Serialize, Serializer};
66use std::fmt::{self, Debug};
67
68mod linalg;
69mod ops;
70pub(crate) mod parallel;
71mod reduce;
72
/// Extracts the strides of `array` as a dimension object of type `D`.
///
/// `ndarray` reports strides as `isize`; `bytemuck::cast_slice` reinterprets
/// them bitwise as `usize` (not a numeric conversion), matching how strides
/// are stored throughout this module. A negative stride therefore wraps to a
/// large `usize`; callers that must reject negative strides check for that
/// separately (see `tensor_buffer_len`).
fn strides_from_array<S, D>(array: &ArrayBase<S, D>) -> D
where
    S: ndarray::RawData,
    D: Dimension,
{
    let strides_slice: &[usize] = bytemuck::cast_slice(array.strides());
    let mut strides = D::zeros(strides_slice.len());
    for (i, s) in strides_slice.iter().copied().enumerate() {
        strides[i] = s;
    }
    strides
}
85
/// Converts `shape` into a `(dim, strides)` pair, honoring any C/Fortran
/// ordering encoded in the [`StrideShape`].
fn dim_strides_from_shape<D: Dimension>(shape: impl Into<StrideShape<D>>) -> (D, D) {
    // SAFETY: the raw view exists only to query `raw_dim()` and `strides()`;
    // the `&()` pointer is never dereferenced and the view never escapes
    // this function.
    let array = unsafe { RawArrayView::from_shape_ptr(shape, &()) };
    let dim = array.raw_dim();
    let strides = strides_from_array(&array);
    (dim, strides)
}
92
93fn into_dimensionality<D1, D2>(dim: &D1, strides: &D1) -> Result<(D2, D2), ShapeError>
94where
95    D1: Dimension,
96    D2: Dimension,
97{
98    D2::from_dimension(dim)
99        .and_then(|dim| D2::from_dimension(strides).map(|strides| (dim, strides)))
100        .ok_or(ShapeError::from_kind(ndarray::ErrorKind::IncompatibleShape))
101}
102
103fn into_shape<D1, E>(dim: &D1, strides: &D1, shape: E) -> Result<(E::Dim, E::Dim), ShapeError>
104where
105    D1: Dimension,
106    E: IntoDimension,
107{
108    use ndarray::ErrorKind;
109
110    let shape = shape.into_dimension();
111    if size_of_shape_checked(&shape)? != dim.size() {
112        Err(ShapeError::from_kind(ErrorKind::IncompatibleShape))
113    } else if is_standard_layout(dim, strides) {
114        let strides = shape.default_strides();
115        Ok((shape, strides))
116    } else if is_fortran_layout(dim, strides) {
117        let strides = shape.fortran_strides();
118        Ok((shape, strides))
119    } else {
120        Err(ShapeError::from_kind(ErrorKind::IncompatibleLayout))
121    }
122}
123
/// Collapses `shape` to `[rows, cols]`: the first axis (or 1 when the shape
/// is empty) and the product of the remaining axes.
pub(crate) fn flatten(shape: &[usize]) -> [usize; 2] {
    match shape.split_first() {
        Some((&first, rest)) => [first, rest.iter().product()],
        None => [1, 1],
    }
}
130
131fn is_contiguous<D: Dimension>(dim: &D, strides: &D) -> bool {
132    is_standard_layout(dim, strides) || is_fortran_layout(dim, strides)
133}
134
135fn is_standard_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
136    debug_assert_eq!(dim.ndim(), strides.ndim());
137    for d in dim.slice().iter().copied() {
138        if d == 0 {
139            return true;
140        }
141    }
142    let mut acc = 1isize;
143    let strides: &[isize] = bytemuck::cast_slice(strides.slice());
144    for (d, s) in dim
145        .slice()
146        .iter()
147        .copied()
148        .zip(strides.iter().copied())
149        .rev()
150    {
151        if !(d == 1 || s == acc) {
152            return false;
153        }
154        acc *= d as isize;
155    }
156    true
157}
158
159fn is_fortran_layout<D: Dimension>(dim: &D, strides: &D) -> bool {
160    debug_assert_eq!(dim.ndim(), strides.ndim());
161    for d in dim.slice().iter().copied() {
162        if d == 0 {
163            return true;
164        }
165    }
166    let mut acc = 1;
167    for (d, s) in dim
168        .slice()
169        .iter()
170        .copied()
171        .zip(strides.slice().iter().copied())
172    {
173        if !(d == 1 || s == acc) {
174            return false;
175        }
176        acc *= d;
177    }
178    true
179}
180
// adapted from https://docs.rs/ndarray/0.15.3/ndarray/struct.ArrayBase.html#method.permuted_axes
/// Permutes `dim` and `strides` by `axes`, where `axes[new_axis]` is the
/// old axis index that is moved to `new_axis`.
///
/// # Panics
/// Panics unless each axis `0 .. ndim` appears in `axes` exactly once.
fn permuted_axes<D: Dimension>(dim: D, strides: D, axes: D) -> (D, D) {
    // Ensure that each axis is used exactly once.
    let mut usage_counts = D::zeros(dim.ndim());
    for axis in axes.slice() {
        usage_counts[*axis] += 1;
    }
    for count in usage_counts.slice() {
        assert_eq!(*count, 1, "each axis must be listed exactly once");
    }
    // Determine the new shape and strides.
    let mut new_dim = usage_counts; // reuse to avoid an allocation
    let mut new_strides = D::zeros(dim.ndim());
    {
        let dim = dim.slice();
        let strides = strides.slice();
        for (new_axis, &axis) in axes.slice().iter().enumerate() {
            new_dim[new_axis] = dim[axis];
            new_strides[new_axis] = strides[axis];
        }
    }
    (new_dim, new_strides)
}
204
205// adapted from https://docs.rs/crate/ndarray/0.15.6/source/src/dimension/mod.rs
206/// Returns the `size` of the `dim`, checking that the product of non-zero axis
207/// lengths does not exceed `isize::MAX`.
208///
209/// If `size_of_checked_shape(dim)` returns `Ok(size)`, the data buffer is a
210/// slice or `Vec` of length `size`, and `strides` are created with
211/// `self.default_strides()` or `self.fortran_strides()`, then the invariants
212/// are met to construct an array from the data buffer, `dim`, and `strides`.
213/// (The data buffer being a slice or `Vec` guarantees that it contains no more
214/// than `isize::MAX` bytes.)
215fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError> {
216    use ndarray::ErrorKind;
217    let size_nonzero = dim
218        .slice()
219        .iter()
220        .filter(|&&d| d != 0)
221        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
222        .ok_or_else(|| ShapeError::from_kind(ErrorKind::Overflow))?;
223    if size_nonzero > isize::MAX as usize {
224        Err(ShapeError::from_kind(ErrorKind::Overflow))
225    } else {
226        Ok(dim.size())
227    }
228}
229
// adapted from https://docs.rs/ndarray/0.15.6/ndarray/struct.ArrayBase.html#method.broadcast
/// Computes the `(dim, strides)` that broadcast `from` / `strides` to `dim`,
/// or `None` when the shapes are not compatible.
fn broadcast<D: Dimension, E: IntoDimension>(
    from: &D,
    strides: &D,
    dim: E,
) -> Option<(E::Dim, E::Dim)> {
    /// Return new stride when trying to grow `from` into shape `to`
    ///
    /// Broadcasting works by returning a "fake stride" where elements
    /// to repeat are in axes with 0 stride, so that several indexes point
    /// to the same element.
    ///
    /// **Note:** Cannot be used for mutable iterators, since repeating
    /// elements would create aliasing pointers.
    fn upcast<D: Dimension, E: Dimension>(to: &D, from: &E, stride: &E) -> Option<D> {
        // Make sure the product of non-zero axis lengths does not exceed
        // `isize::MAX`. This is the only safety check we need to perform
        // because all the other constraints of `ArrayBase` are guaranteed
        // to be met since we're starting from a valid `ArrayBase`.
        let _ = size_of_shape_checked(to).ok()?;

        // `new_stride` starts as a copy of `to`, so each entry initially
        // holds the destination axis length; entries are overwritten below.
        let mut new_stride = to.clone();
        // begin at the back (the least significant dimension)
        // size of the axis has to either agree or `from` has to be 1
        if to.ndim() < from.ndim() {
            return None;
        }

        {
            let mut new_stride_iter = new_stride.slice_mut().iter_mut().rev();
            for ((er, es), dr) in from
                .slice()
                .iter()
                .rev()
                .zip(stride.slice().iter().rev())
                .zip(new_stride_iter.by_ref())
            {
                /* update strides */
                if *dr == *er {
                    /* keep stride */
                    *dr = *es;
                } else if *er == 1 {
                    /* dead dimension, zero stride */
                    *dr = 0
                } else {
                    return None;
                }
            }

            /* set remaining strides to zero */
            for dr in new_stride_iter {
                *dr = 0;
            }
        }
        Some(new_stride)
    }
    let dim = dim.into_dimension();

    // Note: zero strides are safe precisely because we return a read-only view
    let broadcast_strides = match upcast(&dim, from, strides) {
        Some(st) => st,
        None => return None,
    };
    Some((dim, broadcast_strides))
}
295
/// Collapses `axis` of `dims` to length 1 at `index`, returning the element
/// offset of the selected hyperplane (to be added to a tensor's offset).
///
/// # Panics
/// Panics if `index` is out of bounds for the axis.
fn collapse_axis<D: Dimension>(dims: &mut D, strides: &D, Axis(axis): Axis, index: usize) -> isize {
    let dim = dims[axis];
    assert!(index < dim);
    dims.slice_mut()[axis] = 1;
    // Strides are stored as `usize`; the cast back to `isize` restores the
    // signed value they were bit-cast from.
    index as isize * strides[axis] as isize
}
302
/// The minimal buffer length (in elements) addressed by a tensor with
/// `offset`, `shape`, and `strides`, or `None` when a negative stride makes
/// that length unusable here.
///
/// An empty shape (any axis of length 0) yields `Some(0)`.
fn tensor_buffer_len(offset: usize, shape: &[usize], strides: &[isize]) -> Option<usize> {
    if shape.contains(&0) {
        return Some(0);
    }
    if strides.iter().any(|s| s.is_negative()) {
        return None;
    }
    // The largest addressed index is offset + sum((d - 1) * s); the required
    // length is one more than that.
    let max_span: isize = shape
        .iter()
        .zip(strides)
        .map(|(&d, &s)| (d as isize - 1) * s)
        .sum();
    let len = usize::try_from(max_span + offset as isize + 1)
        .expect("non-negative: strides are non-negative and every axis is >= 1");
    Some(len)
}
321
/// Dynamically typed multi-dimensional matrix.
///
/// Use [`TryInto`] to convert into a [`TensorBase`].
/// Use [`From`] to convert from a [`TensorBase`].
#[derive(Clone)]
pub struct ScalarTensorBase<S: ScalarData, D: Dimension> {
    /// The shape of the tensor.
    dim: D,
    /// Strides in elements, stored as `usize` (bit-cast from `isize`).
    strides: D,
    /// Untyped storage; may be longer than `dim.size()` for views.
    buffer: ScalarBufferBase<S>,
    /// Element offset of the first element within `buffer`.
    offset: usize,
}
333
/// Owned Scalar Tensor.
///
/// See [`ScalarTensorBase`].
pub type ScalarTensor<D> = ScalarTensorBase<ScalarBufferRepr, D>;
/// ScalarTensor with 0 dimensions (1 element).
pub type ScalarTensor0 = ScalarTensor<Ix0>;
/// ScalarTensor with 1 dimension.
pub type ScalarTensor1 = ScalarTensor<Ix1>;
/// ScalarTensor with 2 dimensions.
pub type ScalarTensor2 = ScalarTensor<Ix2>;
/// ScalarTensor with 3 dimensions.
pub type ScalarTensor3 = ScalarTensor<Ix3>;
/// ScalarTensor with 4 dimensions.
pub type ScalarTensor4 = ScalarTensor<Ix4>;
/// ScalarTensor with 5 dimensions.
pub type ScalarTensor5 = ScalarTensor<Ix5>;
/// ScalarTensor with 6 dimensions.
pub type ScalarTensor6 = ScalarTensor<Ix6>;
/// ScalarTensor with dynamic dimensions.
pub type ScalarTensorD = ScalarTensor<IxDyn>;

/// Shared Scalar Tensor.
///
/// See [`ScalarTensorBase`].
pub type ScalarArcTensor<D> = ScalarTensorBase<ScalarArcBufferRepr, D>;
/// ScalarArcTensor with 0 dimensions (1 element).
pub type ScalarArcTensor0 = ScalarArcTensor<Ix0>;
/// ScalarArcTensor with 1 dimension.
pub type ScalarArcTensor1 = ScalarArcTensor<Ix1>;
/// ScalarArcTensor with 2 dimensions.
pub type ScalarArcTensor2 = ScalarArcTensor<Ix2>;
/// ScalarArcTensor with 3 dimensions.
pub type ScalarArcTensor3 = ScalarArcTensor<Ix3>;
/// ScalarArcTensor with 4 dimensions.
pub type ScalarArcTensor4 = ScalarArcTensor<Ix4>;
/// ScalarArcTensor with 5 dimensions.
pub type ScalarArcTensor5 = ScalarArcTensor<Ix5>;
/// ScalarArcTensor with 6 dimensions.
pub type ScalarArcTensor6 = ScalarArcTensor<Ix6>;
/// ScalarArcTensor with dynamic dimensions.
pub type ScalarArcTensorD = ScalarArcTensor<IxDyn>;

/// Borrowed Scalar Tensor.
///
/// See [`ScalarTensorBase`].
pub type ScalarTensorView<'a, D> = ScalarTensorBase<ScalarSliceRepr<'a>, D>;
/// ScalarTensorView with 0 dimensions (1 element).
pub type ScalarTensorView0<'a> = ScalarTensorView<'a, Ix0>;
/// ScalarTensorView with 1 dimension.
pub type ScalarTensorView1<'a> = ScalarTensorView<'a, Ix1>;
/// ScalarTensorView with 2 dimensions.
pub type ScalarTensorView2<'a> = ScalarTensorView<'a, Ix2>;
/// ScalarTensorView with 3 dimensions.
pub type ScalarTensorView3<'a> = ScalarTensorView<'a, Ix3>;
/// ScalarTensorView with 4 dimensions.
pub type ScalarTensorView4<'a> = ScalarTensorView<'a, Ix4>;
/// ScalarTensorView with 5 dimensions.
pub type ScalarTensorView5<'a> = ScalarTensorView<'a, Ix5>;
/// ScalarTensorView with 6 dimensions.
pub type ScalarTensorView6<'a> = ScalarTensorView<'a, Ix6>;
/// ScalarTensorView with dynamic dimensions.
pub type ScalarTensorViewD<'a> = ScalarTensorView<'a, IxDyn>;

/// Mutably borrowed Scalar Tensor.
///
/// See [`ScalarTensorBase`].
pub type ScalarTensorViewMut<'a, D> = ScalarTensorBase<ScalarSliceMutRepr<'a>, D>;
/// ScalarTensorViewMut with 0 dimensions (1 element).
pub type ScalarTensorViewMut0<'a> = ScalarTensorViewMut<'a, Ix0>;
/// ScalarTensorViewMut with 1 dimension.
pub type ScalarTensorViewMut1<'a> = ScalarTensorViewMut<'a, Ix1>;
/// ScalarTensorViewMut with 2 dimensions.
pub type ScalarTensorViewMut2<'a> = ScalarTensorViewMut<'a, Ix2>;
/// ScalarTensorViewMut with 3 dimensions.
pub type ScalarTensorViewMut3<'a> = ScalarTensorViewMut<'a, Ix3>;
/// ScalarTensorViewMut with 4 dimensions.
pub type ScalarTensorViewMut4<'a> = ScalarTensorViewMut<'a, Ix4>;
/// ScalarTensorViewMut with 5 dimensions.
pub type ScalarTensorViewMut5<'a> = ScalarTensorViewMut<'a, Ix5>;
/// ScalarTensorViewMut with 6 dimensions.
pub type ScalarTensorViewMut6<'a> = ScalarTensorViewMut<'a, Ix6>;
/// ScalarTensorViewMut with dynamic dimensions.
pub type ScalarTensorViewMutD<'a> = ScalarTensorViewMut<'a, IxDyn>;

/// Scalar Tensor that is either borrowed or owned.
///
/// See [`ScalarTensorBase`].
pub type ScalarCowTensor<'a, D> = ScalarTensorBase<ScalarCowBufferRepr<'a>, D>;
/// ScalarCowTensor with 0 dimensions (1 element).
pub type ScalarCowTensor0<'a> = ScalarCowTensor<'a, Ix0>;
/// ScalarCowTensor with 1 dimension.
pub type ScalarCowTensor1<'a> = ScalarCowTensor<'a, Ix1>;
/// ScalarCowTensor with 2 dimensions.
pub type ScalarCowTensor2<'a> = ScalarCowTensor<'a, Ix2>;
/// ScalarCowTensor with 3 dimensions.
pub type ScalarCowTensor3<'a> = ScalarCowTensor<'a, Ix3>;
/// ScalarCowTensor with 4 dimensions.
pub type ScalarCowTensor4<'a> = ScalarCowTensor<'a, Ix4>;
/// ScalarCowTensor with 5 dimensions.
pub type ScalarCowTensor5<'a> = ScalarCowTensor<'a, Ix5>;
/// ScalarCowTensor with 6 dimensions.
pub type ScalarCowTensor6<'a> = ScalarCowTensor<'a, Ix6>;
/// ScalarCowTensor with dynamic dimensions.
pub type ScalarCowTensorD<'a> = ScalarCowTensor<'a, IxDyn>;
438
439impl<S: ScalarDataOwned, D: Dimension> ScalarTensorBase<S, D> {
440    /// Allocates a scalar tensor on `device` with `shape`.
441    ///
442    /// # Safety
443    ///
444    /// The tensor is not initialized.
445    ///
446    /// # Errors
447    /// See [`ScalarBuffer::uninit()`].
448    pub unsafe fn uninit<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
449    where
450        Sh: ndarray::ShapeBuilder<Dim = D>,
451    {
452        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
453        let buffer = unsafe { ScalarBufferBase::uninit(device, dim.size(), scalar_type)? };
454        Ok(Self {
455            dim,
456            strides,
457            buffer,
458            offset: 0,
459        })
460    }
461    /// Creates a tensor on `device` with `shape` filled with `elem`.
462    ///
463    /// # Errors
464    /// See [`ScalarBuffer::from_elem()`].
465    pub fn from_elem<Sh>(device: Device, shape: Sh, elem: ScalarElem) -> Result<Self>
466    where
467        Sh: ndarray::ShapeBuilder<Dim = D>,
468    {
469        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
470        let buffer = ScalarBufferBase::from_elem(device, dim.size(), elem)?;
471        Ok(Self {
472            dim,
473            strides,
474            buffer,
475            offset: 0,
476        })
477    }
478    /// Creates a tensor on `device` with `shape` filled with 0's.
479    ///
480    /// # Errors
481    /// See [`ScalarBuffer::zeros()`].
482    pub fn zeros<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
483    where
484        Sh: ndarray::ShapeBuilder<Dim = D>,
485    {
486        Self::from_elem(device, shape, ScalarElem::zero(scalar_type))
487    }
488    /// Creates a tensor on `device` with `shape` filled with 1's.
489    ///
490    /// # Errors
491    /// See [`ScalarBuffer::ones()`].
492    pub fn ones<Sh>(device: Device, shape: Sh, scalar_type: ScalarType) -> Result<Self>
493    where
494        Sh: ndarray::ShapeBuilder<Dim = D>,
495    {
496        Self::from_elem(device, shape, ScalarElem::one(scalar_type))
497    }
498}
499
500impl<S: ScalarData, D: Dimension> ScalarTensorBase<S, D> {
    /// The device of the tensor.
    pub fn device(&self) -> Device {
        self.buffer.device()
    }
    /// The scalar type of the tensor.
    pub fn scalar_type(&self) -> ScalarType {
        self.buffer.scalar_type()
    }
    /// The dimensions of the tensor in pattern form.
    pub fn dim(&self) -> D::Pattern {
        self.dim.clone().into_pattern()
    }
    /// The dimensions of the tensor.
    pub fn raw_dim(&self) -> D {
        self.dim.clone()
    }
    /// The dimensions of the tensor as a slice.
    pub fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
    /// The strides of the tensor as a slice.
    ///
    /// Strides are stored internally as `usize` and bit-cast to `isize` here.
    pub fn strides(&self) -> &[isize] {
        bytemuck::cast_slice(self.strides.slice())
    }
    /// The length of the tensor.
    ///
    /// The product of the axis lengths.
    pub fn len(&self) -> usize {
        self.dim.size()
    }
    /// Whether the tensor is empty.
    ///
    /// True if any axis has length 0 (equivalently, `len() == 0`).
    pub fn is_empty(&self) -> bool {
        self.shape().iter().any(|x| *x == 0)
    }
    /// The dimensionality of the tensor.
    pub fn ndim(&self) -> usize {
        self.dim.ndim()
    }
    /// Converts the tensor into dimension `D2`.
    ///
    /// Typically this is used to downcast from [`IxDyn`](type@ndarray::IxDyn) to a static dimensionality. For conversions to [`IxDyn`](type@ndarray::IxDyn), use [`.into_dyn()`](TensorBase::into_dyn()).
    ///
    /// # Errors
    /// - The number of axes of `D2` must be the same as `D`.
    pub fn into_dimensionality<D2>(self) -> Result<ScalarTensorBase<S, D2>, ShapeError>
    where
        D2: Dimension,
    {
        let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
        Ok(ScalarTensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Converts the dimensionality of the tensor to [`IxDyn`](type@ndarray::IxDyn).
    ///
    /// Infallible; only the dimension type changes, not the layout or data.
    pub fn into_dyn(self) -> ScalarTensorBase<S, IxDyn> {
        ScalarTensorBase {
            dim: self.dim.into_dyn(),
            strides: self.strides.into_dyn(),
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Returns the tensor with dim `shape`.
    ///
    /// # Errors
    /// The tensor must be contiguous, with default strides.
    ///
    /// # Panics
    /// Asserts that the offset is 0, ie the tensor is not a sliced view.
    pub fn into_shape<E>(self, shape: E) -> Result<ScalarTensorBase<S, E::Dim>, ShapeError>
    where
        E: IntoDimension,
    {
        let shape = shape.into_dimension();
        let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
        assert_eq!(self.offset, 0);
        Ok(ScalarTensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Act like a larger size and/or shape array by *broadcasting* into a larger shape, if possible.
    ///
    /// Returns `None` when the shapes are not compatible. Broadcast axes get
    /// zero stride, so the result is a read-only view.
    ///
    /// See [`TensorBase::broadcast`].
    pub fn broadcast<E>(&self, dim: E) -> Option<ScalarTensorView<E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
        Some(ScalarTensorView {
            dim,
            strides,
            buffer: self.buffer.as_scalar_slice(),
            offset: self.offset,
        })
    }
    /// Borrows the tensor as a [`ScalarTensorView`].
    pub fn view(&self) -> ScalarTensorView<D> {
        ScalarTensorView {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.as_scalar_slice(),
            offset: self.offset,
        }
    }
    /// Borrows the tensor as a [`ScalarTensorViewMut`].
    pub fn view_mut(&mut self) -> ScalarTensorViewMut<D>
    where
        S: ScalarDataMut,
    {
        ScalarTensorViewMut {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.as_scalar_slice_mut(),
            offset: self.offset,
        }
    }
    /// Mutably borrows the tensor as a mutable view if possible.
    ///
    /// Returns `None` when the tensor has a nonzero offset, is not
    /// contiguous, or the underlying buffer cannot be mutably borrowed.
    pub fn get_view_mut(&mut self) -> Option<ScalarTensorViewMut<D>> {
        if self.offset == 0 && self.is_contiguous() {
            let buffer = self.buffer.get_scalar_slice_mut()?;
            Some(ScalarTensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer,
                offset: 0,
            })
        } else {
            None
        }
    }
    /// Mutably borrows the tensor as a mutable view.
    ///
    /// When the tensor has a nonzero offset or is not contiguous, it is
    /// first copied into a new owned buffer via `to_owned`, and `self` is
    /// replaced with that copy.
    ///
    /// See [`TensorBase::make_view_mut`].
    pub fn make_view_mut(&mut self) -> Result<ScalarTensorViewMut<D>>
    where
        S: ScalarDataOwned,
    {
        if self.offset == 0 && self.is_contiguous() {
            Ok(ScalarTensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.make_scalar_slice_mut()?,
                offset: 0,
            })
        } else {
            let tensor = self.to_owned()?;
            *self = Self {
                dim: tensor.dim,
                strides: tensor.strides,
                buffer: ScalarBufferBase::from_scalar_buffer(tensor.buffer),
                offset: 0,
            };
            Ok(ScalarTensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                // The buffer was just replaced with an exclusively owned one,
                // so this cannot fail.
                buffer: self.buffer.get_scalar_slice_mut().unwrap(),
                offset: 0,
            })
        }
    }
    /// Whether the tensor is contiguous.
    ///
    /// Contiguous is either C (Standard) or Fortran layout. Empty tensors
    /// are considered contiguous.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.dim, &self.strides)
    }
    /// Whether the tensor is standard layout.
    ///
    /// In standard layout, the strides increase from right to left by the product of each dimension.
    pub fn is_standard_layout(&self) -> bool {
        is_standard_layout(&self.dim, &self.strides)
    }
    /// Permute the axes of the tensor.
    ///
    /// Reorders the dimensions of the tensor, where for each a in `axes`, a is the index of that axis in the new tensor.
    ///
    /// # Note
    /// This operation merely reorders the dimensions / strides and does not copy the data. Combine with [`.into_standard_layout()`](TensorBase::into_standard_layout()) to execute the operation, returning a tensor in standard layout.
    ///
    /// # Panics
    /// Each axis 0 .. ndim must be used exactly once.
    pub fn permuted_axes<A>(self, axes: A) -> Self
    where
        A: IntoDimension<Dim = D>,
    {
        let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
        Self {
            dim,
            strides,
            ..self
        }
    }
    /// Reverses (transposes) the axes of the tensor.
    pub fn reversed_axes(mut self) -> Self {
        self.dim.slice_mut().reverse();
        self.strides.slice_mut().reverse();
        self
    }
    /// Returns a view with reversed (transposed) axes.
    pub fn t(&self) -> ScalarTensorView<D> {
        self.view().reversed_axes()
    }
    /// Returns a view restricted to `index` along the `axis`, with the axis removed.
    ///
    /// See [`TensorBase::index_axis`].
    pub fn index_axis(&self, axis: Axis, index: usize) -> ScalarTensorView<D::Smaller>
    where
        D: RemoveAxis,
    {
        self.view().index_axis_into(axis, index)
    }
    /// Returns a mutable view restricted to index along the `axis`, with the `axis` removed.
    ///
    /// See [`TensorBase::index_axis_mut`].
    pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ScalarTensorViewMut<D::Smaller>
    where
        S: ScalarDataMut,
        D: RemoveAxis,
    {
        self.view_mut().index_axis_into(axis, index)
    }
    /// Returns a mutable view restricted to index along the `axis`, with the `axis` removed.
    ///
    /// See [`TensorBase::index_axis_into`].
    pub fn index_axis_into(mut self, axis: Axis, index: usize) -> ScalarTensorBase<S, D::Smaller>
    where
        D: RemoveAxis,
    {
        // Collapse the axis (folding `index` into the offset), then drop it
        // from the dimensions and strides.
        self.collapse_axis(axis, index);
        let dim = self.dim.remove_axis(axis);
        let strides = self.strides.remove_axis(axis);
        ScalarTensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Selects `index` along the `axis`, collapsing the axis into length one.
    ///
    /// # Panics
    ///  `axis` or `index` is out of bounds.
    ///
    /// See [`TensorBase::collapse_axis`].
    pub fn collapse_axis(&mut self, axis: Axis, index: usize) {
        let offset =
            collapse_axis(&mut self.dim, &self.strides, axis, index) + self.offset as isize;
        debug_assert!(offset >= 0);
        self.offset = offset as usize;
        debug_assert!(self.offset < self.buffer.len());
    }
    /// Borrows the tensor as a [`ScalarSlice`] if standard layout.
    ///
    /// Returns `None` otherwise.
    pub fn as_scalar_slice(&self) -> Option<ScalarSlice> {
        if self.is_standard_layout() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset();
            Some(slice)
        } else {
            None
        }
    }
    /// Borrows the tensor as a [`ScalarSlice`] if contiguous.
    ///
    /// Contiguous is either C (standard) or Fortran layout; returns `None`
    /// otherwise.
    pub fn as_scalar_slice_memory_order(&self) -> Option<ScalarSlice> {
        if self.is_contiguous() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset();
            Some(slice)
        } else {
            None
        }
    }
    /// Mutably borrows the tensor as a [`ScalarSliceMut`] if standard layout.
    ///
    /// Returns `None` otherwise.
    pub fn as_scalar_slice_mut(&mut self) -> Option<ScalarSliceMut>
    where
        S: ScalarDataMut,
    {
        if self.is_standard_layout() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset_mut();
            Some(slice)
        } else {
            None
        }
    }
    /// Mutably borrows the tensor as a [`ScalarSliceMut`] if contiguous.
    ///
    /// Returns `None` otherwise.
    pub fn as_scalar_slice_memory_order_mut(&mut self) -> Option<ScalarSliceMut>
    where
        S: ScalarDataMut,
    {
        if self.is_contiguous() {
            let (slice, _offset) = self.as_raw_scalar_slice_offset_mut();
            Some(slice)
        } else {
            None
        }
    }
    /// Borrows the tensor as a slice and offset.
    ///
    /// When the addressed buffer length is computable from the layout (no
    /// negative strides; see `tensor_buffer_len`), returns a slice trimmed
    /// to exactly the addressed elements and an offset of 0. Otherwise
    /// returns the whole buffer together with the tensor's element offset.
    pub fn as_raw_scalar_slice_offset(&self) -> (ScalarSlice, usize) {
        let strides: &[isize] = Self::strides(self);
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self.buffer.slice(self.offset..self.offset + len).unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_scalar_slice(), self.offset)
        }
    }
    /// Mutably borrows the tensor as a mutable slice and offset.
    ///
    /// See [`as_raw_scalar_slice_offset`](Self::as_raw_scalar_slice_offset)
    /// for how the slice and offset are chosen.
    pub fn as_raw_scalar_slice_offset_mut(&mut self) -> (ScalarSliceMut, usize)
    where
        S: ScalarDataMut,
    {
        let strides: &[isize] = Self::strides(self);
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self
                .buffer
                .slice_mut(self.offset..self.offset + len)
                .unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_scalar_slice_mut(), self.offset)
        }
    }
    /// Transfers the tensor into the `device`.
    ///
    /// When already on `device`, this is equivalent to
    /// [`into_owned`](Self::into_owned). Otherwise the data is transferred
    /// from a contiguous slice, first copying into an owned contiguous
    /// tensor when necessary.
    ///
    /// See [`TensorBase::into_device`].
    pub fn into_device(self, device: Device) -> Result<ScalarTensor<D>> {
        if self.device() == device {
            self.into_owned()
        } else if let Some(slice) = self.as_scalar_slice_memory_order() {
            let buffer = slice.to_device(device)?;
            Ok(ScalarTensor {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            })
        } else {
            self.into_owned()?.into_device(device)
        }
    }
    /// Transfers the tensor to the `device`.
    ///
    /// When already on `device`, this is equivalent to `to_owned`.
    ///
    /// See [`Tensor::to_device`].
    pub fn to_device(&self, device: Device) -> Result<ScalarTensor<D>> {
        if self.device() == device {
            self.to_owned()
        } else {
            self.view().into_device(device)
        }
    }
    /// Transfers the tensor into the `device` in place.
    ///
    /// Returns immediately when already on `device`; otherwise replaces
    /// `self` with a copy obtained via [`to_device`](Self::to_device).
    ///
    /// See [`Tensor::to_device_mut`].
    pub fn to_device_mut(&mut self, device: Device) -> Result<()>
    where
        S: ScalarDataOwned,
    {
        if self.device() == device {
            return Ok(());
        }
        let ScalarTensor {
            dim,
            strides,
            buffer,
            offset,
        } = self.to_device(device)?;
        *self = Self {
            dim,
            strides,
            buffer: ScalarBufferBase::from_scalar_buffer(buffer),
            offset,
        };
        Ok(())
    }
873    /// Transfers the tensor into the `device` as a scalar arc tensor.
874    ///
875    /// See [`TensorBase::into_device_shared`].
876    pub fn into_device_shared(self, device: Device) -> Result<ScalarArcTensor<D>> {
877        if self.device() == device {
878            self.into_shared()
879        } else {
880            self.to_device(device).map(Into::into)
881        }
882    }
883    /// Transfers the tensor to the `device` as a scalar arc tensor.
884    ///
885    /// See [`TensorBase::to_device_shared`].
886    pub fn to_device_shared(&self, device: Device) -> Result<ScalarArcTensor<D>> {
887        if device == self.device() {
888            self.to_shared()
889        } else {
890            self.to_device(device).map(Into::into)
891        }
892    }
    /// Converts into a [`ScalarTensor`].
    pub fn into_owned(self) -> Result<ScalarTensor<D>> {
        if self.offset == 0 && self.is_contiguous() {
            // Fast path: the buffer holds exactly this tensor's data.
            return Ok(ScalarTensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.into_owned()?,
                offset: 0,
            });
        }
        if let Some(slice) = self.as_scalar_slice_memory_order() {
            // Contiguous (possibly offset) view: copy only the viewed span.
            let buffer = slice.to_owned()?;
            return Ok(ScalarTensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            });
        }
        // Non-contiguous: allocate an uninitialized tensor and copy into it.
        // SAFETY: the uninitialized contents are fully overwritten by `assign`.
        let mut output =
            unsafe { ScalarTensor::uninit(self.device(), self.raw_dim(), self.scalar_type())? };
        output.assign(&self)?;
        Ok(output)
    }
917    /// Converts to a [`ScalarTensor`].
918    pub fn to_owned(&self) -> Result<ScalarTensor<D>> {
919        self.view().into_owned()
920    }
    /// Converts into an [`ScalarArcTensor`].
    pub fn into_shared(self) -> Result<ScalarArcTensor<D>> {
        if self.offset == 0 && self.is_contiguous() {
            // The buffer holds exactly this tensor's data: share it directly.
            Ok(ScalarTensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.into_shared()?,
                offset: 0,
            })
        } else {
            // Offset view or non-contiguous: repack into standard layout first.
            self.as_standard_layout()?.into_shared()
        }
    }
934    /// Converts to an [`ScalarArcTensor`].
935    pub fn to_shared(&self) -> Result<ScalarArcTensor<D>> {
936        if !self.is_contiguous() {
937            return self.as_standard_layout()?.to_shared();
938        }
939        Ok(ScalarTensorBase {
940            dim: self.dim.clone(),
941            strides: self.strides.clone(),
942            buffer: self.buffer.to_shared()?,
943            offset: 0,
944        })
945    }
946}
947
impl<D: Dimension> ScalarTensor<D> {
    /// Attempt to convert to a tensor.
    ///
    /// On failure, the original tensor is returned unchanged as the error.
    pub fn try_into_tensor<T: Scalar>(self) -> Result<Tensor<T, D>, Self> {
        self.try_into()
    }
}

impl<D: Dimension> ScalarArcTensor<D> {
    /// Attempt to convert to an arc tensor.
    ///
    /// On failure, the original tensor is returned unchanged as the error.
    pub fn try_into_arc_tensor<T: Scalar>(self) -> Result<ArcTensor<T, D>, Self> {
        self.try_into()
    }
}

impl<'a, D: Dimension> ScalarTensorView<'a, D> {
    /// Attempt to convert to a tensor view.
    ///
    /// On failure, the original view is returned unchanged as the error.
    pub fn try_into_tensor_view<T: Scalar>(self) -> Result<TensorView<'a, T, D>, Self> {
        self.try_into()
    }
}

impl<'a, D: Dimension> ScalarTensorViewMut<'a, D> {
    /// Attempt to convert to a mutable tensor view.
    ///
    /// On failure, the original view is returned unchanged as the error.
    pub fn try_into_tensor_view_mut<T: Scalar>(self) -> Result<TensorViewMut<'a, T, D>, Self> {
        self.try_into()
    }
}
975
impl<D: Dimension> ScalarArcTensor<D> {
    /// Act like a larger size and/or shape array by *broadcasting* into a larger shape, if possible.
    ///
    /// Returns `None` when `dim` is not broadcast-compatible with the current
    /// shape. The buffer is shared (cloned reference), not copied; the
    /// broadcast strides come from the `broadcast` helper.
    ///
    /// See [`ArcTensor::broadcast_shared()`].
    pub fn broadcast_shared<E>(&self, dim: E) -> Option<ScalarArcTensor<E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
        Some(ScalarArcTensor {
            dim,
            strides,
            buffer: self.buffer.clone(),
            offset: self.offset,
        })
    }
}
993
994impl<S: ScalarDataOwned> From<ScalarBuffer> for ScalarTensorBase<S, Ix1> {
995    fn from(buffer: ScalarBuffer) -> Self {
996        let dim = buffer.len().into_dimension();
997        let strides = dim.default_strides();
998        let buffer = ScalarBufferBase::from_scalar_buffer(buffer);
999        Self {
1000            dim,
1001            strides,
1002            buffer,
1003            offset: 0,
1004        }
1005    }
1006}
1007
1008impl<S: ScalarDataOwned, T: Scalar, D: Dimension> From<Tensor<T, D>> for ScalarTensorBase<S, D> {
1009    fn from(tensor: Tensor<T, D>) -> Self {
1010        Self {
1011            dim: tensor.dim,
1012            strides: tensor.strides,
1013            buffer: tensor.buffer.into(),
1014            offset: tensor.offset,
1015        }
1016    }
1017}
1018
1019impl<D: Dimension> From<ScalarTensor<D>> for ScalarArcTensor<D> {
1020    fn from(tensor: ScalarTensor<D>) -> Self {
1021        Self {
1022            dim: tensor.dim,
1023            strides: tensor.strides,
1024            buffer: tensor.buffer.into(),
1025            offset: tensor.offset,
1026        }
1027    }
1028}
1029
1030impl<T: Scalar, D: Dimension> From<ArcTensor<T, D>> for ScalarArcTensor<D> {
1031    fn from(tensor: ArcTensor<T, D>) -> Self {
1032        Self {
1033            dim: tensor.dim,
1034            strides: tensor.strides,
1035            buffer: tensor.buffer.into(),
1036            offset: tensor.offset,
1037        }
1038    }
1039}
1040
1041impl<D: Dimension> From<ScalarTensor<D>> for ScalarCowTensor<'_, D> {
1042    fn from(tensor: ScalarTensor<D>) -> Self {
1043        Self {
1044            dim: tensor.dim,
1045            strides: tensor.strides,
1046            buffer: tensor.buffer.into(),
1047            offset: tensor.offset,
1048        }
1049    }
1050}
1051
1052impl<'a, D: Dimension> From<ScalarTensorView<'a, D>> for ScalarCowTensor<'a, D> {
1053    fn from(tensor: ScalarTensorView<'a, D>) -> Self {
1054        Self {
1055            dim: tensor.dim,
1056            strides: tensor.strides,
1057            buffer: tensor.buffer.into(),
1058            offset: tensor.offset,
1059        }
1060    }
1061}
1062
// Fallible conversions from the type-erased (scalar) tensors back to typed
// tensors, for the owned and shared representations. When the buffer cannot
// be converted to the typed form, the original tensor is reconstructed from
// its parts and returned as the error.
macro_for!($Tensor in [Tensor, ArcTensor] {
    paste! {
        impl<T: Scalar, D: Dimension> TryFrom<[<Scalar $Tensor>]<D>> for $Tensor<T, D> {
            type Error = [<Scalar $Tensor>]<D>;
            fn try_from(tensor: [<Scalar $Tensor>]<D>) -> Result<Self, Self::Error> {
                // Only the buffer conversion can fail; dim/strides/offset
                // carry over unchanged in both outcomes.
                match tensor.buffer.try_into() {
                    Ok(buffer) => Ok(Self {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    }),
                    Err(buffer) => Err(Self::Error {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    })
                }
            }
        }
    }
});
1086
// For the borrowed representations (views and cow tensors): infallible
// conversions into the type-erased form, plus fallible conversions back to
// the typed form (the original is returned as the error on failure).
macro_for!($Tensor in [TensorView, TensorViewMut, CowTensor] {
    paste! {
        impl<'a, T: Scalar, D: Dimension> From<$Tensor<'a, T, D>> for [<Scalar $Tensor>]<'a, D> {
            fn from(tensor: $Tensor<'a, T, D>) -> Self {
                Self {
                    dim: tensor.dim,
                    strides: tensor.strides,
                    buffer: tensor.buffer.into(),
                    offset: tensor.offset,
                }
            }
        }
        impl<'a, T: Scalar, D: Dimension> TryFrom<[<Scalar $Tensor>]<'a, D>> for $Tensor<'a, T, D> {
            type Error = [<Scalar $Tensor>]<'a, D>;
            fn try_from(tensor: [<Scalar $Tensor>]<'a, D>) -> Result<Self, Self::Error> {
                // Only the buffer conversion can fail; dim/strides/offset
                // carry over unchanged in both outcomes.
                match tensor.buffer.try_into() {
                    Ok(buffer) => Ok(Self {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    }),
                    Err(buffer) => Err(Self::Error {
                        dim: tensor.dim,
                        strides: tensor.strides,
                        buffer,
                        offset: tensor.offset,
                    })
                }
            }
        }
    }
});
1120
1121impl<S: ScalarData, D: Dimension> Debug for ScalarTensorBase<S, D> {
1122    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1123        let mut builder = f.debug_struct("TensorBase");
1124        builder
1125            .field("device", &self.device())
1126            .field("scalar_type", &self.scalar_type())
1127            .field("shape", &self.shape());
1128        if self.strides != self.dim.default_strides() {
1129            builder.field("strides", &self.strides());
1130        }
1131        if self.offset > 0 {
1132            builder.field("offset", &self.offset);
1133        }
1134        builder.finish()
1135    }
1136}
1137
1138/// Casts
1139impl<S: ScalarData, D: Dimension> ScalarTensorBase<S, D> {
1140    /// Casts the tensor into a new tensor.
1141    ///
1142    /// See [`BufferBase::cast_into()`].
1143    pub fn cast_into(self, scalar_type: ScalarType) -> Result<ScalarTensor<D>> {
1144        if self.scalar_type() == scalar_type {
1145            self.into_owned()
1146        } else {
1147            self.cast(scalar_type)
1148        }
1149    }
1150    /// Casts the tensor to a new tensor.
1151    ///
1152    /// See [`BufferBase::cast()`].
1153    pub fn cast(&self, scalar_type: ScalarType) -> Result<ScalarTensor<D>> {
1154        if self.scalar_type() == scalar_type {
1155            self.to_owned()
1156        } else if !self.is_contiguous() {
1157            self.scaled_cast(ScalarElem::one(scalar_type))
1158        } else {
1159            Ok(ScalarTensorBase {
1160                dim: self.dim.clone(),
1161                strides: self.strides.clone(),
1162                buffer: self.buffer.cast(scalar_type)?,
1163                offset: 0,
1164            })
1165        }
1166    }
1167    /// Casts the tensor in place.
1168    ///
1169    /// See [`BufferBase::cast()`].
1170    pub fn cast_mut(&mut self, scalar_type: ScalarType) -> Result<()>
1171    where
1172        S: ScalarDataOwned,
1173    {
1174        if self.scalar_type() == scalar_type {
1175            return Ok(());
1176        }
1177        let ScalarTensor {
1178            dim,
1179            strides,
1180            buffer,
1181            offset,
1182        } = self.cast(scalar_type)?;
1183        *self = Self {
1184            dim,
1185            strides,
1186            buffer: ScalarBufferBase::from_scalar_buffer(buffer),
1187            offset,
1188        };
1189        Ok(())
1190    }
1191    /// Casts the tensor into a new tensor.
1192    ///
1193    /// See [`BufferBase::cast()`].
1194    pub fn cast_into_tensor<T: Scalar>(self) -> Result<Tensor<T, D>> {
1195        Ok(self.cast_into(T::SCALAR_TYPE)?.try_into().unwrap())
1196    }
1197}
1198
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "S: ScalarData, D: Dimension + Serialize",
    deserialize = "S: ScalarDataOwned, D: Dimension + Deserialize<'de>"
))]
#[serde(rename = "Tensor")]
/// Serialization proxy: a tensor is serialized as its dimensions plus a
/// buffer; strides and offset are not serialized.
struct ScalarTensorSerde<S: ScalarData, D: Dimension> {
    dim: D,
    buffer: ScalarBufferBase<S>,
}

#[cfg(feature = "serde")]
impl<S1: ScalarData, D: Dimension + Serialize> Serialize for ScalarTensorBase<S1, D> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;
        // Borrow the data directly when possible; otherwise transfer to the
        // host first.
        // NOTE(review): a host tensor that is contiguous but not standard
        // layout appears to serialize its buffer in memory order while
        // deserialization assumes default strides — verify.
        let buffer = if let Some(slice) = self.as_scalar_slice() {
            ScalarCowBuffer::from(slice)
        } else {
            self.to_device(Device::host())
                .map_err(S::Error::custom)?
                .buffer
                .into()
        };
        ScalarTensorSerde {
            dim: self.dim.clone(),
            buffer,
        }
        .serialize(serializer)
    }
}

#[cfg(feature = "serde")]
impl<'de, S: ScalarDataOwned, D1: Dimension + Deserialize<'de>> Deserialize<'de>
    for ScalarTensorBase<S, D1>
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;
        // Deserialize a flat host buffer, then reshape to the stored dims.
        let ScalarTensorSerde { dim, buffer } =
            ScalarTensorSerde::<ScalarBufferRepr, D1>::deserialize(deserializer)?;
        ScalarTensorBase::from(buffer)
            .into_shape(dim)
            .map_err(D::Error::custom)
    }
}
1250
/// Multi-dimensional matrix.
///
/// Use [`Into`] to convert to a [`ScalarTensorBase`].
/// Use [`TryFrom`] to convert from a [`ScalarTensorBase`].
#[derive(Clone)]
pub struct TensorBase<S: Data, D: Dimension> {
    // Shape of the tensor.
    dim: D,
    // Per-axis strides in elements (stored unsigned; see `strides()`).
    strides: D,
    // Underlying storage; may span more elements than the tensor views.
    buffer: BufferBase<S>,
    // Index of the tensor's first element within `buffer`.
    offset: usize,
}
1262
/// Owned Tensor.
///
/// See [`TensorBase`].
pub type Tensor<T, D> = TensorBase<BufferRepr<T>, D>;
/// Tensor with 0 dimensions (1 element).
pub type Tensor0<T> = Tensor<T, Ix0>;
/// Tensor with 1 dimension.
pub type Tensor1<T> = Tensor<T, Ix1>;
/// Tensor with 2 dimensions.
pub type Tensor2<T> = Tensor<T, Ix2>;
/// Tensor with 3 dimensions.
pub type Tensor3<T> = Tensor<T, Ix3>;
/// Tensor with 4 dimensions.
pub type Tensor4<T> = Tensor<T, Ix4>;
/// Tensor with 5 dimensions.
pub type Tensor5<T> = Tensor<T, Ix5>;
/// Tensor with 6 dimensions.
pub type Tensor6<T> = Tensor<T, Ix6>;
/// Tensor with dynamic dimensions.
pub type TensorD<T> = Tensor<T, IxDyn>;

/// Shared Tensor.
///
/// See [`TensorBase`].
pub type ArcTensor<T, D> = TensorBase<ArcBufferRepr<T>, D>;
/// ArcTensor with 0 dimensions (1 element).
pub type ArcTensor0<T> = ArcTensor<T, Ix0>;
/// ArcTensor with 1 dimension.
pub type ArcTensor1<T> = ArcTensor<T, Ix1>;
/// ArcTensor with 2 dimensions.
pub type ArcTensor2<T> = ArcTensor<T, Ix2>;
/// ArcTensor with 3 dimensions.
pub type ArcTensor3<T> = ArcTensor<T, Ix3>;
/// ArcTensor with 4 dimensions.
pub type ArcTensor4<T> = ArcTensor<T, Ix4>;
/// ArcTensor with 5 dimensions.
pub type ArcTensor5<T> = ArcTensor<T, Ix5>;
/// ArcTensor with 6 dimensions.
pub type ArcTensor6<T> = ArcTensor<T, Ix6>;
/// ArcTensor with dynamic dimensions.
pub type ArcTensorD<T> = ArcTensor<T, IxDyn>;

/// Borrowed Tensor.
///
/// See [`TensorBase`].
pub type TensorView<'a, T, D> = TensorBase<SliceRepr<'a, T>, D>;
/// TensorView with 0 dimensions (1 element).
pub type TensorView0<'a, T> = TensorView<'a, T, Ix0>;
/// TensorView with 1 dimension.
pub type TensorView1<'a, T> = TensorView<'a, T, Ix1>;
/// TensorView with 2 dimensions.
pub type TensorView2<'a, T> = TensorView<'a, T, Ix2>;
/// TensorView with 3 dimensions.
pub type TensorView3<'a, T> = TensorView<'a, T, Ix3>;
/// TensorView with 4 dimensions.
pub type TensorView4<'a, T> = TensorView<'a, T, Ix4>;
/// TensorView with 5 dimensions.
pub type TensorView5<'a, T> = TensorView<'a, T, Ix5>;
/// TensorView with 6 dimensions.
pub type TensorView6<'a, T> = TensorView<'a, T, Ix6>;
/// TensorView with dynamic dimensions.
pub type TensorViewD<'a, T> = TensorView<'a, T, IxDyn>;

/// Mutably borrowed Tensor.
///
/// See [`TensorBase`].
pub type TensorViewMut<'a, T, D> = TensorBase<SliceMutRepr<'a, T>, D>;
/// TensorViewMut with 0 dimensions (1 element).
pub type TensorViewMut0<'a, T> = TensorViewMut<'a, T, Ix0>;
/// TensorViewMut with 1 dimension.
pub type TensorViewMut1<'a, T> = TensorViewMut<'a, T, Ix1>;
/// TensorViewMut with 2 dimensions.
pub type TensorViewMut2<'a, T> = TensorViewMut<'a, T, Ix2>;
/// TensorViewMut with 3 dimensions.
pub type TensorViewMut3<'a, T> = TensorViewMut<'a, T, Ix3>;
/// TensorViewMut with 4 dimensions.
pub type TensorViewMut4<'a, T> = TensorViewMut<'a, T, Ix4>;
/// TensorViewMut with 5 dimensions.
pub type TensorViewMut5<'a, T> = TensorViewMut<'a, T, Ix5>;
/// TensorViewMut with 6 dimensions.
pub type TensorViewMut6<'a, T> = TensorViewMut<'a, T, Ix6>;
/// TensorViewMut with dynamic dimensions.
pub type TensorViewMutD<'a, T> = TensorViewMut<'a, T, IxDyn>;

/// Tensor that is either borrowed or owned.
///
/// See [`TensorBase`].
pub type CowTensor<'a, T, D> = TensorBase<CowBufferRepr<'a, T>, D>;
/// CowTensor with 0 dimensions (1 element).
pub type CowTensor0<'a, T> = CowTensor<'a, T, Ix0>;
/// CowTensor with 1 dimension.
pub type CowTensor1<'a, T> = CowTensor<'a, T, Ix1>;
/// CowTensor with 2 dimensions.
pub type CowTensor2<'a, T> = CowTensor<'a, T, Ix2>;
/// CowTensor with 3 dimensions.
pub type CowTensor3<'a, T> = CowTensor<'a, T, Ix3>;
/// CowTensor with 4 dimensions.
pub type CowTensor4<'a, T> = CowTensor<'a, T, Ix4>;
/// CowTensor with 5 dimensions.
pub type CowTensor5<'a, T> = CowTensor<'a, T, Ix5>;
/// CowTensor with 6 dimensions.
pub type CowTensor6<'a, T> = CowTensor<'a, T, Ix6>;
/// CowTensor with dynamic dimensions.
pub type CowTensorD<'a, T> = CowTensor<'a, T, IxDyn>;
1367
1368impl<T: Scalar, S: DataOwned<Elem = T>, D: Dimension> TensorBase<S, D> {
1369    /// Allocates a tensor on `device` with `shape`.
1370    ///
1371    /// # Safety
1372    ///
1373    /// The tensor is not initialized.
1374    ///
1375    /// # Errors
1376    /// See [`Buffer::uninit()`].
1377    pub unsafe fn uninit<Sh>(device: Device, shape: Sh) -> Result<Self>
1378    where
1379        Sh: ndarray::ShapeBuilder<Dim = D>,
1380    {
1381        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
1382        let buffer = unsafe { BufferBase::uninit(device, dim.size())? };
1383        Ok(Self {
1384            dim,
1385            strides,
1386            buffer,
1387            offset: 0,
1388        })
1389    }
1390    /// Creates a tensor on `device` with `shape` filled with `elem`.
1391    ///
1392    /// # Errors
1393    /// See [`Buffer::from_elem()`].
1394    pub fn from_elem<Sh>(device: Device, shape: Sh, elem: T) -> Result<Self>
1395    where
1396        Sh: ndarray::ShapeBuilder<Dim = D>,
1397    {
1398        let (dim, strides) = dim_strides_from_shape(shape.into_shape());
1399        let buffer = BufferBase::from_elem(device, dim.size(), elem)?;
1400        Ok(Self {
1401            dim,
1402            strides,
1403            buffer,
1404            offset: 0,
1405        })
1406    }
1407    /// Creates a tensor on `device` with `shape` filled with 0's.
1408    ///
1409    /// # Errors
1410    /// See [`Buffer::zeros()`].
1411    pub fn zeros<Sh>(device: Device, shape: Sh) -> Result<Self>
1412    where
1413        Sh: ndarray::ShapeBuilder<Dim = D>,
1414    {
1415        Self::from_elem(device, shape, T::default())
1416    }
1417    /// Creates a tensor on `device` with `shape` filled with 1's.
1418    ///
1419    /// # Errors
1420    /// See [`Buffer::ones()`].
1421    pub fn ones<Sh>(device: Device, shape: Sh) -> Result<Self>
1422    where
1423        Sh: ndarray::ShapeBuilder<Dim = D>,
1424    {
1425        Self::from_elem(device, shape, T::one())
1426    }
1427}
1428
1429impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// The device of the tensor.
    pub fn device(&self) -> Device {
        self.buffer.device()
    }
    /// The scalar type of the tensor.
    pub fn scalar_type(&self) -> ScalarType {
        T::SCALAR_TYPE
    }
    /// The dimensions of the tensor in pattern form.
    pub fn dim(&self) -> D::Pattern {
        self.dim.clone().into_pattern()
    }
    /// The dimensions of the tensor.
    pub fn raw_dim(&self) -> D {
        self.dim.clone()
    }
    /// The dimensions of the tensor as a slice.
    pub fn shape(&self) -> &[usize] {
        self.dim.slice()
    }
    /// The strides of the tensor as a slice.
    pub fn strides(&self) -> &[isize] {
        // Strides are stored as `usize` in `D`; reinterpret the slice as
        // `isize` (same size and alignment).
        bytemuck::cast_slice(self.strides.slice())
    }
    /// The length of the tensor.
    pub fn len(&self) -> usize {
        // Product of all dimensions.
        self.dim.size()
    }
    /// Whether the tensor is empty.
    pub fn is_empty(&self) -> bool {
        // Any zero-length axis makes the tensor empty (len() == 0).
        self.shape().iter().any(|x| *x == 0)
    }
    /// The dimensionality of the tensor.
    pub fn ndim(&self) -> usize {
        self.dim.ndim()
    }
    /// Converts the tensor into dimension `D2`.
    ///
    /// Typically this is used to downcast from [`IxDyn`](type@ndarray::IxDyn) to a static dimensionality. For conversions to [`IxDyn`](type@ndarray::IxDyn), use [`.into_dyn()`](TensorBase::into_dyn()).
    ///
    /// # Errors
    /// The number of axes of `D2` must be the same as `D`.
    pub fn into_dimensionality<D2>(self) -> Result<TensorBase<S, D2>, ShapeError>
    where
        D2: Dimension,
    {
        // The free-function helper validates the axis count and converts
        // dim and strides; buffer and offset are untouched.
        let (dim, strides) = into_dimensionality(&self.dim, &self.strides)?;
        Ok(TensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
1484    /// Converts the dimensionality of the tensor to [`IxDyn`](type@ndarray::IxDyn).
1485    pub fn into_dyn(self) -> TensorBase<S, IxDyn> {
1486        TensorBase {
1487            dim: self.dim.into_dyn(),
1488            strides: self.strides.into_dyn(),
1489            buffer: self.buffer,
1490            offset: self.offset,
1491        }
1492    }
    /// Returns the tensor with dim `shape`.
    ///
    /// # Errors
    /// The tensor must be contiguous, with default strides.
    pub fn into_shape<E>(self, shape: E) -> Result<TensorBase<S, E::Dim>, ShapeError>
    where
        E: IntoDimension,
    {
        let shape = shape.into_dimension();
        let (dim, strides) = into_shape(&self.dim, &self.strides, shape)?;
        // `into_shape` only succeeds for default-strided tensors, which are
        // expected to carry no offset (debug-only check).
        debug_assert_eq!(self.offset, 0);
        Ok(TensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        })
    }
    /// Flattens the trailing dimensions into a 2 dimensional tensor.
    ///
    /// The output has shape [d0, d1 * d2 .. * dn].
    ///
    /// # Errors
    /// See [`TensorBase::into_shape()`].
    pub fn flatten(self) -> Result<TensorBase<S, Ix2>, ShapeError> {
        // Compute the target 2d shape first; `into_shape` consumes self.
        let dim = flatten(self.shape());
        self.into_shape(dim)
    }
    /// Act like a larger size and/or shape array by *broadcasting* into a larger shape, if possible.
    ///
    /// Returns `None` when `dim` is not broadcast-compatible with the
    /// current shape; the data is borrowed, not copied.
    ///
    /// See [`ArrayBase::broadcast()`].
    pub fn broadcast<E>(&self, dim: E) -> Option<TensorView<T, E::Dim>>
    where
        E: IntoDimension,
    {
        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
        Some(TensorView {
            dim,
            strides,
            buffer: self.buffer.as_slice(),
            offset: self.offset,
        })
    }
1536    /// Borrows the tensor as a [`TensorView`].
1537    pub fn view(&self) -> TensorView<T, D> {
1538        TensorView {
1539            dim: self.dim.clone(),
1540            strides: self.strides.clone(),
1541            buffer: self.buffer.as_slice(),
1542            offset: self.offset,
1543        }
1544    }
1545    /// Borrows the tensor as a [`TensorViewMut`].
1546    pub fn view_mut(&mut self) -> TensorViewMut<T, D>
1547    where
1548        S: DataMut,
1549    {
1550        TensorViewMut {
1551            dim: self.dim.clone(),
1552            strides: self.strides.clone(),
1553            buffer: self.buffer.as_slice_mut(),
1554            offset: self.offset,
1555        }
1556    }
1557    /// Mutably borrows the tensor as a mutable view if possible.
1558    pub fn get_view_mut(&mut self) -> Option<TensorViewMut<T, D>> {
1559        if self.offset == 0 && self.is_contiguous() {
1560            let buffer = self.buffer.get_slice_mut()?;
1561            Some(TensorViewMut {
1562                dim: self.dim.clone(),
1563                strides: self.strides.clone(),
1564                buffer,
1565                offset: 0,
1566            })
1567        } else {
1568            None
1569        }
1570    }
    /// Mutably borrows the tensor as a mutable view.
    ///
    /// Copies the data into a new tensor if necessary.
    ///
    /// See [`TensorBase::to_owned()`].
    pub fn make_view_mut(&mut self) -> Result<TensorViewMut<T, D>>
    where
        S: DataOwned,
    {
        if self.offset == 0 && self.is_contiguous() {
            // The buffer holds exactly this tensor's data: borrow it mutably
            // (copy-on-write happens inside `make_slice_mut` if shared).
            Ok(TensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                buffer: self.buffer.make_slice_mut()?,
                offset: 0,
            })
        } else {
            // Offset view or non-contiguous: repack into an owned tensor and
            // replace our contents before borrowing mutably.
            let tensor = self.to_owned()?;
            *self = Self {
                dim: tensor.dim,
                strides: tensor.strides,
                buffer: BufferBase::from_buffer(tensor.buffer),
                offset: 0,
            };
            Ok(TensorViewMut {
                dim: self.dim.clone(),
                strides: self.strides.clone(),
                // Freshly owned, so a mutable slice is always available.
                buffer: self.buffer.get_slice_mut().unwrap(),
                offset: 0,
            })
        }
    }
    /// Whether the tensor is contiguous.
    ///
    /// Contiguous is either C (Standard) or Fortran layout.
    pub fn is_contiguous(&self) -> bool {
        is_contiguous(&self.dim, &self.strides)
    }
    /// Whether the tensor is standard layout.
    ///
    /// In standard layout, the strides increase from right to left by the product of each dimension.
    pub fn is_standard_layout(&self) -> bool {
        is_standard_layout(&self.dim, &self.strides)
    }
    /// Permute the axes of the tensor.
    ///
    /// Reorders the dimensions of the tensor, where for each a in `axes`, a is the index of that axis in the new tensor.
    ///
    /// # Note
    /// This operation merely reorders the dimensions / strides and does not copy the data. Combine with [`.into_standard_layout()`](TensorBase::into_standard_layout()) to execute the operation, returning a tensor in standard layout.
    ///
    /// # Panics
    /// Each axis 0 .. ndim must be used exactly once.
    pub fn permuted_axes<A>(self, axes: A) -> Self
    where
        A: IntoDimension<Dim = D>,
    {
        // The helper reorders dim and strides together; buffer and offset
        // are carried over unchanged.
        let (dim, strides) = permuted_axes(self.dim, self.strides, axes.into_dimension());
        Self {
            dim,
            strides,
            ..self
        }
    }
    /// Reverses (transposes) the axes of the tensor.
    pub fn reversed_axes(mut self) -> Self {
        // Reversing dim and strides in lockstep transposes without copying.
        self.dim.slice_mut().reverse();
        self.strides.slice_mut().reverse();
        self
    }
    /// Returns a view with reversed (transposed) axes.
    pub fn t(&self) -> TensorView<T, D> {
        self.view().reversed_axes()
    }
    /// Returns a view restricted to index along the `axis`, with the `axis` removed.
    ///
    /// # Panics
    /// `axis` or `index` is out of bounds.
    ///
    /// See [`ArrayBase::index_axis()`].
    pub fn index_axis(&self, axis: Axis, index: usize) -> TensorView<T, D::Smaller>
    where
        D: RemoveAxis,
    {
        self.view().index_axis_into(axis, index)
    }
    /// Returns a mutable view restricted to index along the `axis`, with the `axis` removed.
    ///
    /// # Panics
    ///  `axis` or `index` is out of bounds.
    ///
    /// See [`ArrayBase::index_axis_mut()`].
    pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> TensorViewMut<T, D::Smaller>
    where
        S: DataMut,
        D: RemoveAxis,
    {
        self.view_mut().index_axis_into(axis, index)
    }
    /// Returns a tensor restricted to index along the `axis`, with the `axis` removed.
    ///
    /// # Panics
    /// `axis` or `index` is out of bounds.
    ///
    /// See [`.index_axis()`](Self::index_axis).
    pub fn index_axis_into(mut self, axis: Axis, index: usize) -> TensorBase<S, D::Smaller>
    where
        D: RemoveAxis,
    {
        // Fold the selected index into `offset`, then drop the
        // now-length-one axis from dim and strides.
        self.collapse_axis(axis, index);
        let dim = self.dim.remove_axis(axis);
        let strides = self.strides.remove_axis(axis);
        TensorBase {
            dim,
            strides,
            buffer: self.buffer,
            offset: self.offset,
        }
    }
    /// Selects `index` along the `axis`, collapsing the axis into length one.
    ///
    /// # Panics
    ///  `axis` or `index` is out of bounds.
    ///
    /// See [`ArrayBase::collapse_axis()`].
    pub fn collapse_axis(&mut self, axis: Axis, index: usize) {
        // The helper returns the (possibly negative) element offset of the
        // selected index, which is folded into the tensor's base offset.
        let offset =
            collapse_axis(&mut self.dim, &self.strides, axis, index) + self.offset as isize;
        debug_assert!(offset >= 0);
        let offset = offset as usize;
        debug_assert!(offset < self.buffer.len());
        self.offset = offset;
    }
1704    /// Borrows the tensor as a [`Slice`] if standard layout.
1705    pub fn as_slice(&self) -> Option<Slice<T>> {
1706        if self.is_standard_layout() {
1707            let (slice, _offset) = self.as_raw_slice_offset();
1708            Some(slice)
1709        } else {
1710            None
1711        }
1712    }
1713    /// Borrows the tensor as a [`Slice`] if contiguous.
1714    pub fn as_slice_memory_order(&self) -> Option<Slice<T>> {
1715        if self.is_contiguous() {
1716            let (slice, _offset) = self.as_raw_slice_offset();
1717            Some(slice)
1718        } else {
1719            None
1720        }
1721    }
1722    /// Mutably borrows the tensor as a [`SliceMut`] if standard layout.
1723    pub fn as_slice_mut(&mut self) -> Option<SliceMut<T>>
1724    where
1725        S: DataMut,
1726    {
1727        if self.is_standard_layout() {
1728            let (slice, _offset) = self.as_raw_slice_offset_mut();
1729            Some(slice)
1730        } else {
1731            None
1732        }
1733    }
1734    /// Mutably borrows the tensor as a [`SliceMut`] if contiguous.
1735    pub fn as_slice_memory_order_mut(&mut self) -> Option<SliceMut<T>>
1736    where
1737        S: DataMut,
1738    {
1739        if self.is_contiguous() {
1740            let (slice, _offset) = self.as_raw_slice_offset_mut();
1741            Some(slice)
1742        } else {
1743            None
1744        }
1745    }
    /// Borrows the tensor as a slice and offset.
    pub fn as_raw_slice_offset(&self) -> (Slice<T>, usize) {
        let strides: &[isize] = Self::strides(self);
        // When `tensor_buffer_len` can compute a span for this layout
        // (presumably the minimal buffer length the view addresses — see
        // `tensor_buffer_len`), hand out only that sub-slice with offset 0.
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self.buffer.slice(self.offset..self.offset + len).unwrap();
            (slice, 0)
        } else {
            // Otherwise expose the whole buffer and return `offset` for the
            // caller to apply.
            (self.buffer.as_slice(), self.offset)
        }
    }
    /// Mutably borrows the tensor as a mutable slice and offset.
    pub fn as_raw_slice_offset_mut(&mut self) -> (SliceMut<T>, usize)
    where
        S: DataMut,
    {
        let strides: &[isize] = Self::strides(self);
        // Same logic as `as_raw_slice_offset`: slice out just the addressed span
        // when `tensor_buffer_len` can compute it, else expose the whole buffer
        // plus the offset.
        if let Some(len) = tensor_buffer_len(self.offset, self.shape(), strides) {
            let slice = self
                .buffer
                .slice_mut(self.offset..self.offset + len)
                .unwrap();
            (slice, 0)
        } else {
            (self.buffer.as_slice_mut(), self.offset)
        }
    }
1772    /// Transfers the tensor to the `device`.
1773    ///
1774    /// See [`Tensor::into_device()`].
1775    pub fn to_device(&self, device: Device) -> Result<Tensor<T, D>> {
1776        if self.device() == device {
1777            self.to_owned()
1778        } else {
1779            self.view().into_device(device)
1780        }
1781    }
1782    /// Transfers the tensor to the `device`.
1783    ///
1784    /// See [`Tensor::to_device()`].
1785    pub fn to_device_shared(&self, device: Device) -> Result<ArcTensor<T, D>> {
1786        if self.device() == device {
1787            self.to_shared()
1788        } else {
1789            self.to_device(device).map(Into::into)
1790        }
1791    }
1792    /// Transfers the tensor to the `device` in place.
1793    ///
1794    /// See [`Buffer::to_device_mut()`].
1795    pub fn to_device_mut(&mut self, device: Device) -> Result<()>
1796    where
1797        S: DataOwned,
1798    {
1799        if self.device() == device {
1800            return Ok(());
1801        }
1802        let Tensor {
1803            dim,
1804            strides,
1805            buffer,
1806            offset,
1807        } = self.to_device(device)?;
1808        *self = Self {
1809            dim,
1810            strides,
1811            buffer: BufferBase::from_buffer(buffer),
1812            offset,
1813        };
1814        Ok(())
1815    }
    /// Transfers the tensor into the `device`.
    ///
    /// See [`Buffer::into_device()`].
    pub fn into_device(self, device: Device) -> Result<Tensor<T, D>> {
        if device == self.device() {
            // Already on the target device: only take ownership.
            self.into_owned()
        } else if !self.is_contiguous() {
            // Pack non-contiguous layouts into standard layout before the transfer.
            self.as_standard_layout()?.to_device(device)
        } else {
            let buffer = self.buffer.to_device(device)?;
            // NOTE(review): the entire buffer is transferred and `offset` is reset
            // to 0 — confirm contiguous tensors always have `offset == 0` and a
            // buffer of exactly `self.len()` here (e.g. after `collapse_axis`),
            // otherwise the view's starting element is lost.
            Ok(Tensor {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            })
        }
    }
    /// Transfers the tensor into the `device`.
    ///
    /// See [`ArcBuffer::into_device_shared()`].
    pub fn into_device_shared(self, device: Device) -> Result<ArcTensor<T, D>> {
        if device == self.device() {
            // Already on the target device: wrap in shared storage.
            self.into_shared()
        } else if !self.is_contiguous() {
            // Pack non-contiguous layouts into standard layout before the transfer.
            self.view()
                .into_standard_layout()?
                .into_device_shared(device)
        } else {
            let buffer = self.buffer.to_device_shared(device)?;
            // NOTE(review): as in `into_device`, the whole buffer is transferred
            // and `offset` is reset to 0 — confirm contiguous tensors can't carry
            // a nonzero `offset` here.
            Ok(ArcTensor {
                dim: self.dim,
                strides: self.strides,
                buffer,
                offset: 0,
            })
        }
    }
    /// Converts into a [`Tensor`].
    pub fn into_owned(self) -> Result<Tensor<T, D>> {
        if !self.is_contiguous() {
            // Non-contiguous views are packed into a fresh standard-layout tensor.
            return self.into_standard_layout();
        }
        // NOTE(review): this takes the whole buffer and resets `offset` to 0,
        // whereas `into_shared` preserves `offset` — confirm contiguous tensors
        // can't carry a nonzero `offset` here (e.g. after `collapse_axis`).
        Ok(TensorBase {
            dim: self.dim,
            strides: self.strides,
            buffer: self.buffer.into_owned()?,
            offset: 0,
        })
    }
1866    /// Converts to a [`Tensor`].
1867    pub fn to_owned(&self) -> Result<Tensor<T, D>> {
1868        self.view().into_owned()
1869    }
1870    /// Converts into an [`ArcTensor`], copying if necessary.
1871    pub fn into_shared(self) -> Result<ArcTensor<T, D>> {
1872        if self.is_contiguous() {
1873            Ok(TensorBase {
1874                dim: self.dim,
1875                strides: self.strides,
1876                buffer: self.buffer.into_shared()?,
1877                offset: self.offset,
1878            })
1879        } else {
1880            self.as_standard_layout()?.into_shared()
1881        }
1882    }
1883    /// Converts to an [`ArcTensor`].
1884    ///
1885    /// Converts to an [`ArcTensor`], copying if necessary.
1886    pub fn to_shared(&self) -> Result<ArcTensor<T, D>> {
1887        if self.is_contiguous() {
1888            Ok(TensorBase {
1889                dim: self.dim.clone(),
1890                strides: self.strides.clone(),
1891                buffer: self.buffer.to_shared()?,
1892                offset: self.offset,
1893            })
1894        } else {
1895            self.to_owned()?.into_shared()
1896        }
1897    }
    /// Fills the tensor with `elem`.
    ///
    /// # Errors
    /// Device tensors must be contiguous.
    ///
    /// See [`BufferBase::fill()`].
    pub fn fill(&mut self, elem: T) -> Result<()>
    where
        S: DataMut,
    {
        if self.is_contiguous() {
            // NOTE(review): this fills the entire underlying buffer slice —
            // confirm `offset == 0` and `buffer.len() == self.len()` always hold
            // for contiguous tensors, otherwise elements outside the view get
            // overwritten.
            self.buffer.as_slice_mut().fill(elem)
        } else if let Some(mut array) = self.as_array_mut() {
            // Non-contiguous host tensor: fill element-wise through ndarray.
            array.fill(elem);
            Ok(())
        } else {
            // Non-contiguous device tensors are unsupported.
            bail!("TensorBase::fill tensor is not contiguous!")
        }
    }
    /// Moves the tensor into an [`Array`].
    ///
    /// # Errors
    /// Device tensors must be contiguous.
    ///
    /// See [`Buffer::into_vec()`].
    pub fn into_array(self) -> Result<Array<T, D>> {
        if self.is_contiguous() {
            use ndarray::ShapeBuilder;

            // NOTE(review): `into_vec` consumes the whole buffer and the original
            // strides are reapplied from element 0 — confirm `offset == 0` for
            // contiguous tensors so the data lines up.
            let vec = self.buffer.into_vec()?;
            Ok(Array::from_shape_vec(self.dim.strides(self.strides), vec).unwrap())
        } else if let Some(array) = self.as_array() {
            // Non-contiguous host tensor: copy element-wise via a borrowed view.
            Ok(array.into_owned())
        } else {
            bail!("TensorBase::into_array tensor is not contiguous!")
        }
    }
    /// Borrows the tensor as an array view if on the host.
    pub fn as_array(&self) -> Option<ArrayView<T, D>> {
        use ndarray::ShapeBuilder;

        // SAFETY: `dim`/`strides` describe a view within
        // `host_slice[self.offset..]`, and the borrow of `self` keeps the
        // storage alive for the returned view's lifetime.
        // NOTE(review): `&host_slice[self.offset]` panics when
        // `offset == host_slice.len()` — confirm empty tensors can't reach this
        // with a trailing offset.
        self.buffer.as_host_slice().map(|host_slice| unsafe {
            ArrayView::from_shape_ptr(
                self.dim.clone().strides(self.strides.clone()),
                &host_slice[self.offset] as *const T,
            )
        })
    }
    /// Mutably borrows the tensor as a mutable array view if on the host.
    pub fn as_array_mut(&mut self) -> Option<ArrayViewMut<T, D>>
    where
        S: DataMut,
    {
        use ndarray::ShapeBuilder;

        if let Some(host_slice) = self.buffer.as_host_slice_mut() {
            // SAFETY: reborrow the host slice via raw parts to detach the
            // returned view's lifetime from this local borrow; `&mut self` still
            // guarantees exclusive access for the view's lifetime.
            let host_slice = unsafe {
                std::slice::from_raw_parts_mut(host_slice.as_mut_ptr(), host_slice.len())
            };
            // SAFETY: `dim`/`strides` describe a view within
            // `host_slice[self.offset..]`.
            Some(unsafe {
                ArrayViewMut::from_shape_ptr(
                    self.dim.clone().strides(self.strides.clone()),
                    host_slice[self.offset..].as_mut_ptr(),
                )
            })
        } else {
            None
        }
    }
1967}
1968
1969impl<T: Scalar, D: Dimension> Tensor<T, D> {
1970    /// Converts to a scalar tensor.
1971    pub fn into_scalar_tensor(self) -> ScalarTensor<D> {
1972        self.into()
1973    }
1974}
1975
1976impl<'a, T: Scalar, D: Dimension> CowTensor<'a, T, D> {
1977    /// Converts to a scalar cow tensor.
1978    pub fn into_scalar_cow_tensor(self) -> ScalarCowTensor<'a, D> {
1979        self.into()
1980    }
1981}
1982
1983impl<T: Scalar, D: Dimension> ArcTensor<T, D> {
1984    /// Act like a larger size and/or shape array by *broadcasting* into a larger shape, if possible.
1985    ///
1986    /// See [`TensorBase::broadcast()`].
1987    pub fn broadcast_shared<E>(&self, dim: E) -> Option<ArcTensor<T, E::Dim>>
1988    where
1989        E: IntoDimension,
1990    {
1991        let (dim, strides) = broadcast(&self.dim, &self.strides, dim)?;
1992        Some(ArcTensor {
1993            dim,
1994            strides,
1995            buffer: self.buffer.clone(),
1996            offset: self.offset,
1997        })
1998    }
1999}
2000
2001impl<T: Scalar, S: DataOwned<Elem = T>> From<Buffer<T>> for TensorBase<S, Ix1> {
2002    fn from(buffer: Buffer<T>) -> Self {
2003        let dim = buffer.len().into_dimension();
2004        let strides = dim.default_strides();
2005        let buffer = BufferBase::from_buffer(buffer);
2006        Self {
2007            dim,
2008            strides,
2009            buffer,
2010            offset: 0,
2011        }
2012    }
2013}
2014
2015impl<T: Scalar, S: DataOwned<Elem = T>> From<Vec<T>> for TensorBase<S, Ix1> {
2016    fn from(vec: Vec<T>) -> Self {
2017        let dim = vec.len().into_dimension();
2018        let strides = dim.default_strides();
2019        let buffer = BufferBase::from_buffer(Buffer::from(vec));
2020        Self {
2021            dim,
2022            strides,
2023            buffer,
2024            offset: 0,
2025        }
2026    }
2027}
2028
2029impl<'a, T: Scalar> From<Slice<'a, T>> for TensorView<'a, T, Ix1> {
2030    fn from(slice: Slice<'a, T>) -> Self {
2031        let dim = slice.len().into_dimension();
2032        let strides = dim.default_strides();
2033        Self {
2034            dim,
2035            strides,
2036            buffer: slice,
2037            offset: 0,
2038        }
2039    }
2040}
2041
2042impl<'a, T: Scalar> From<SliceMut<'a, T>> for TensorViewMut<'a, T, Ix1> {
2043    fn from(slice: SliceMut<'a, T>) -> Self {
2044        let dim = slice.len().into_dimension();
2045        let strides = dim.default_strides();
2046        Self {
2047            dim,
2048            strides,
2049            buffer: slice,
2050            offset: 0,
2051        }
2052    }
2053}
2054
impl<T: Scalar, S: DataOwned<Elem = T>, D: Dimension> From<Array<T, D>> for TensorBase<S, D> {
    /// Converts the array into an owned tensor without copying the data.
    fn from(array: Array<T, D>) -> Self {
        let dim = array.raw_dim();
        // Preserve the array's memory layout (e.g. Fortran order) via its strides.
        let strides = strides_from_array(&array);
        // NOTE(review): `into_raw_vec` assumes the array's data starts at the
        // beginning of its allocation — confirm arrays with an internal offset
        // (e.g. from `slice_move`) can't reach this conversion.
        let buffer = BufferBase::from_vec(array.into_raw_vec());
        Self {
            dim,
            strides,
            buffer,
            offset: 0,
        }
    }
}
2068
2069impl<'a, T: Scalar, D: Dimension> From<ArrayView<'a, T, D>> for CowTensor<'a, T, D> {
2070    fn from(array: ArrayView<'a, T, D>) -> Self {
2071        if let Some(slice) = array.to_slice_memory_order() {
2072            let dim = array.raw_dim();
2073            let strides = strides_from_array(&array);
2074            let buffer = Slice::from(slice).into();
2075            Self {
2076                dim,
2077                strides,
2078                buffer,
2079                offset: 0,
2080            }
2081        } else {
2082            Self::from(array.to_owned())
2083        }
2084    }
2085}
2086
impl<'a, T: Scalar, D: Dimension> TryFrom<ArrayView<'a, T, D>> for TensorView<'a, T, D> {
    type Error = anyhow::Error;
    /// # Errors
    /// The `array` is not contiguous.
    fn try_from(array: ArrayView<'a, T, D>) -> Result<Self> {
        let slice = array
            .as_slice_memory_order()
            .ok_or_else(|| anyhow!("Not contiguous!"))?;
        // SAFETY: reborrow via raw parts so the slice carries the view's 'a
        // lifetime instead of a fresh borrow of the local `array`; the data
        // outlives 'a by construction of `ArrayView<'a, _, _>`.
        let slice = unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) };
        let dim = array.raw_dim();
        // Preserve memory order (C or F) via the original strides.
        let strides = strides_from_array(&array);
        Ok(Self {
            dim,
            strides,
            buffer: slice.into(),
            offset: 0,
        })
    }
}
2107
2108impl<'a, T: Scalar, D: Dimension> From<TensorView<'a, T, D>> for CowTensor<'a, T, D> {
2109    fn from(view: TensorView<'a, T, D>) -> Self {
2110        Self {
2111            dim: view.dim,
2112            strides: view.strides,
2113            buffer: view.buffer.into(),
2114            offset: view.offset,
2115        }
2116    }
2117}
2118
2119impl<T: Scalar, D: Dimension> From<Tensor<T, D>> for CowTensor<'_, T, D> {
2120    fn from(tensor: Tensor<T, D>) -> Self {
2121        Self {
2122            dim: tensor.dim,
2123            strides: tensor.strides,
2124            buffer: tensor.buffer.into(),
2125            offset: tensor.offset,
2126        }
2127    }
2128}
2129
2130impl<T: Scalar, D: Dimension> From<Tensor<T, D>> for ArcTensor<T, D> {
2131    fn from(tensor: Tensor<T, D>) -> Self {
2132        Self {
2133            dim: tensor.dim,
2134            strides: tensor.strides,
2135            buffer: tensor.buffer.into(),
2136            offset: tensor.offset,
2137        }
2138    }
2139}
2140
2141impl<S: Data, D: Dimension> Debug for TensorBase<S, D> {
2142    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2143        ScalarTensorView::from(self.view()).fmt(f)
2144    }
2145}
2146
/// Casts
impl<T: Scalar, S: Data<Elem = T>, D: Dimension> TensorBase<S, D> {
    /// Casts the tensor into a new tensor.
    ///
    /// See [`BufferBase::cast_into()`].
    pub fn cast_into<Y: Scalar>(self) -> Result<Tensor<Y, D>> {
        // Same scalar type and contiguous: consume the buffer directly
        // (presumably reusing the allocation — see `BufferBase::cast_into`);
        // anything else goes through the copying `cast` path.
        if T::SCALAR_TYPE == Y::SCALAR_TYPE && self.is_contiguous() {
            Ok(TensorBase {
                dim: self.dim,
                strides: self.strides,
                buffer: self.buffer.cast_into()?,
                offset: 0,
            })
        } else {
            self.cast()
        }
    }
    /// Casts the tensor to a new tensor.
    ///
    /// See [`BufferBase::cast()`].
    pub fn cast<Y: Scalar>(&self) -> Result<Tensor<Y, D>> {
        if !self.is_contiguous() {
            // Non-contiguous layouts: cast element-wise with a scale of one,
            // which handles arbitrary strides.
            return self.scaled_cast(Y::one());
        }
        Ok(TensorBase {
            dim: self.dim.clone(),
            strides: self.strides.clone(),
            buffer: self.buffer.cast()?,
            offset: 0,
        })
    }
}
2179
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "S: Data, D: Dimension + Serialize",
    deserialize = "S: DataOwned, D: Dimension + Deserialize<'de>"
))]
#[serde(rename = "Tensor")]
/// Wire format for tensor (de)serialization: only the shape and a host
/// buffer are stored; strides are not serialized and are reconstructed as
/// default (C-order) strides on deserialization.
struct TensorSerde<S: Data, D: Dimension> {
    /// Shape of the tensor.
    dim: D,
    /// The tensor's elements.
    buffer: BufferBase<S>,
}
2191
#[cfg(feature = "serde")]
impl<S1: Data, D: Dimension + Serialize> Serialize for TensorBase<S1, D> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::Error;
        // Serialize a host copy of the data: borrow directly when the tensor is
        // already standard layout, otherwise transfer/copy to the host first.
        let buffer = if let Some(slice) = self.as_slice() {
            CowBuffer::from(slice)
        } else {
            // NOTE(review): a host tensor that is contiguous but not standard
            // layout (e.g. Fortran order) keeps its memory-order data here,
            // while deserialization applies default strides — confirm such
            // tensors round-trip correctly.
            self.to_device(Device::host())
                .map_err(S::Error::custom)?
                .buffer
                .into()
        };
        TensorSerde {
            dim: self.dim.clone(),
            buffer,
        }
        .serialize(serializer)
    }
}
2214
#[cfg(feature = "serde")]
impl<'de, S: DataOwned, D1: Dimension + Deserialize<'de>> Deserialize<'de> for TensorBase<S, D1> {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        use serde::de::Error;
        // Deserialize into an owned buffer, then reshape to the stored
        // dimensions; the reshape applies default (C-order) strides.
        let TensorSerde { dim, buffer } =
            TensorSerde::<BufferRepr<S::Elem>, D1>::deserialize(deserializer)?;
        TensorBase::from(buffer)
            .into_shape(dim)
            .map_err(D::Error::custom)
    }
}
2229
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;
    use serde_test::{assert_tokens, Token};

    // Round-trips a small u32 tensor through serde and pins the exact token
    // stream produced by the `TensorSerde` wire format.
    #[test]
    fn tensor_serde() {
        let data = vec![1u32, 2, 3, 4];
        // The expected token stream packs elements into u64 words converted
        // with `to_be` — mirror that here (expected to match the buffer's
        // serde implementation).
        let items: Vec<u64> = bytemuck::cast_slice(data.as_slice()).to_vec();
        let tensor = Tensor::from(Buffer::from(data));
        let tokens = [
            Token::Struct {
                name: "Tensor",
                len: 2,
            },
            Token::Str("dim"),
            // Ix1 serializes as a 1-tuple holding the length.
            Token::Tuple { len: 1 },
            Token::U64(4),
            Token::TupleEnd,
            Token::Str("buffer"),
            Token::TupleStruct {
                name: "Buffer",
                len: 3,
            },
            // Scalar type tag and element count, then the packed words.
            Token::Str("U32"),
            Token::U64(4),
            Token::Seq { len: Some(2) },
            Token::U64(items[0].to_be()),
            Token::U64(items[1].to_be()),
            Token::SeqEnd,
            Token::TupleStructEnd,
            Token::StructEnd,
        ];

        #[derive(Debug, Serialize, Deserialize)]
        #[serde(transparent)]
        struct TensorWrap(Tensor1<u32>);

        // Equality by element values on the host (Tensor itself has no PartialEq).
        impl PartialEq for TensorWrap {
            fn eq(&self, other: &Self) -> bool {
                self.0.as_array().unwrap() == other.0.as_array().unwrap()
            }
        }

        impl Eq for TensorWrap {}

        assert_tokens(&TensorWrap(tensor), &tokens);
    }
}