rten_tensor/
tensor.rs

1use std::borrow::Cow;
2use std::fmt::Debug;
3use std::mem::MaybeUninit;
4use std::ops::{Index, IndexMut, Range};
5use std::sync::Arc;
6
7use crate::assume_init::AssumeInit;
8use crate::copy::{
9    copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
10};
11use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
12use crate::iterators::{
13    AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut, Iter, IterMut,
14    Lanes, LanesMut, for_each_mut,
15};
16use crate::layout::{
17    AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, LayoutExt, MatrixLayout, MutLayout,
18    NdLayout, OverlapPolicy, RemoveDim, ResizeLayout, SliceWith, TrustedLayout,
19};
20use crate::overlap::may_have_internal_overlap;
21use crate::slice_range::{IntoSliceItems, SliceItem};
22use crate::storage::{
23    Alloc, CowData, GlobalAlloc, IntoStorage, Storage, StorageMut, ViewData, ViewMutData,
24};
25use crate::type_num::IndexCount;
26use crate::{Contiguous, RandomSource};
27
28/// The base type for multi-dimensional arrays. This consists of storage for
29/// elements, plus a _layout_ which maps from a multi-dimensional array index
30/// to a storage offset. This base type is not normally used directly but
31/// instead through a type alias which selects the storage type and layout.
32///
33/// The storage can be owned (like a `Vec<T>`), borrowed (like `&[T]`) or
34/// mutably borrowed (like `&mut [T]`). The layout can have a dimension count
35/// that is determined statically (ie. forms part of the tensor's type), see
36/// [`NdLayout`] or is only known at runtime, see [`DynLayout`].
37pub struct TensorBase<S: Storage, L: Layout> {
38    data: S,
39
40    // Layout mapping N-dimensional indices to offsets in `data`.
41    //
42    // Constructors must ensure:
43    //
44    // - Every index that is valid for `layout` must map to an offset that is
45    //   less than `data.len()`. The minimum length for a layout is given by
46    //   `Layout::min_data_len`.
47    // - If `S` is a mutable storage type, no two indices of `layout` can map to
48    //   the same offset. See the `may_have_internal_overlap` function.
49    layout: L,
50}
51
52/// Trait implemented by all variants of [`TensorBase`], which provides a
53/// `view` method to get an immutable view of the tensor, plus methods which
54/// forward to such a view.
55///
56/// The purpose of this trait is to allow methods to be specialized for
57/// immutable views by preserving the lifetime of the underlying data in
58/// return types (eg. `iter` returns `&[T]` in the trait, but `&'a [T]` in
59/// the view). This allows for chaining operations on views together (eg.
60/// `tensor.slice(...).transpose()`) without needing to separate each step
61/// into separate statements.
62///
63/// This trait is conceptually similar to the way [`std::ops::Deref`] in the Rust
64/// standard library allows a `Vec<T>` to have all the methods of an `&[T]`.
65///
66/// If stable Rust gains support for specialization or a `Deref` trait that can
67/// return non-references (see <https://github.com/rust-lang/rfcs/issues/997>)
68/// this will become unnecessary.
69pub trait AsView: Layout {
70    /// Type of element stored in this tensor.
71    type Elem;
72
73    /// The underlying layout of this tensor. It must have the same index
74    /// type (eg. `[usize; N]` or `&[usize]`) as this view.
75    type Layout: Clone + for<'a> Layout<Index<'a> = Self::Index<'a>>;
76
77    /// Return a borrowed view of this tensor.
78    fn view(&self) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>;
79
80    /// Return the layout of this tensor.
81    fn layout(&self) -> &Self::Layout;
82
83    /// Return a view of this tensor using a borrowed [`CowData`] for storage.
84    ///
85    /// Together with [`into_cow`](TensorBase::into_cow), this is useful where
86    /// code needs to conditionally copy or create a new tensor, and get either
87    /// the borrowed or owned tensor into the same type.
88    fn as_cow(&self) -> TensorBase<CowData<'_, Self::Elem>, Self::Layout>
89    where
90        [Self::Elem]: ToOwned,
91    {
92        self.view().as_cow()
93    }
94
95    /// Return a view of this tensor with a dynamic rank.
96    fn as_dyn(&self) -> TensorBase<ViewData<'_, Self::Elem>, DynLayout> {
97        self.view().as_dyn()
98    }
99
100    /// Return an iterator over slices of this tensor along a given axis.
101    fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'_, Self::Elem, Self::Layout>
102    where
103        Self::Layout: MutLayout,
104    {
105        self.view().axis_chunks(dim, chunk_size)
106    }
107
108    /// Return an iterator over slices of this tensor along a given axis.
109    fn axis_iter(&self, dim: usize) -> AxisIter<'_, Self::Elem, Self::Layout>
110    where
111        Self::Layout: MutLayout + RemoveDim,
112    {
113        self.view().axis_iter(dim)
114    }
115
116    /// Broadcast this view to another shape.
117    ///
118    /// If `shape` is an array (`[usize; N]`), the result will have a
119    /// static-rank layout with `N` dims. If `shape` is a slice, the result will
120    /// have a dynamic-rank layout.
121    fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'_, Self::Elem>, S::Layout>
122    where
123        Self::Layout: BroadcastLayout<S::Layout>,
124    {
125        self.view().broadcast(shape)
126    }
127
128    /// Fallible variant of [`broadcast`](AsView::broadcast).
129    fn try_broadcast<S: IntoLayout>(
130        &self,
131        shape: S,
132    ) -> Result<TensorBase<ViewData<'_, Self::Elem>, S::Layout>, ExpandError>
133    where
134        Self::Layout: BroadcastLayout<S::Layout>,
135    {
136        self.view().try_broadcast(shape)
137    }
138
139    /// Copy elements from this tensor into `dest` in logical order.
140    ///
141    /// Returns the initialized slice. Panics if the length of `dest` does
142    /// not match the number of elements in `self`.
143    fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<Self::Elem>]) -> &'a [Self::Elem]
144    where
145        Self::Elem: Copy;
146
147    /// Return the layout of this tensor as a slice, if it is contiguous.
148    fn data(&self) -> Option<&[Self::Elem]>;
149
150    /// Return a reference to the element at a given index, or `None` if the
151    /// index is invalid.
152    fn get<I: AsIndex<Self::Layout>>(&self, index: I) -> Option<&Self::Elem>
153    where
154        Self::Layout: TrustedLayout,
155    {
156        self.view().get(index)
157    }
158
159    /// Return a reference to the element at a given index, without performing
160    /// bounds checks.
161    ///
162    /// # Safety
163    ///
164    /// The caller must ensure that the index is valid for the tensor's shape.
165    unsafe fn get_unchecked<I: AsIndex<Self::Layout>>(&self, index: I) -> &Self::Elem {
166        let view = self.view();
167        unsafe { view.get_unchecked(index) }
168    }
169
170    /// Index the tensor along a given axis.
171    ///
172    /// Returns a view with one dimension removed.
173    ///
174    /// Panics if `axis >= self.ndim()` or `index >= self.size(axis)`.
175    fn index_axis(
176        &self,
177        axis: usize,
178        index: usize,
179    ) -> TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as RemoveDim>::Output>
180    where
181        Self::Layout: MutLayout + RemoveDim,
182    {
183        self.view().index_axis(axis, index)
184    }
185
186    /// Return an iterator over the innermost N dimensions.
187    fn inner_iter<const N: usize>(&self) -> InnerIter<'_, Self::Elem, NdLayout<N>> {
188        self.view().inner_iter()
189    }
190
191    /// Return an iterator over the innermost `n` dimensions.
192    ///
193    /// Prefer [`inner_iter`](AsView::inner_iter) if `N` is known at compile time.
194    fn inner_iter_dyn(&self, n: usize) -> InnerIter<'_, Self::Elem, DynLayout> {
195        self.view().inner_iter_dyn(n)
196    }
197
198    /// Insert a size-1 axis at the given index.
199    fn insert_axis(&mut self, index: usize)
200    where
201        Self::Layout: ResizeLayout;
202
203    /// Remove a size-1 axis at the given index.
204    ///
205    /// This will panic if the index is out of bounds or the size of the index
206    /// is not 1.
207    fn remove_axis(&mut self, index: usize)
208    where
209        Self::Layout: ResizeLayout;
210
211    /// Return the scalar value in this tensor if it has 0 dimensions.
212    fn item(&self) -> Option<&Self::Elem> {
213        self.view().item()
214    }
215
216    /// Return an iterator over elements in this tensor in their logical order.
217    fn iter(&self) -> Iter<'_, Self::Elem>;
218
219    /// Return an iterator over 1D slices of this tensor along a given axis.
220    fn lanes(&self, dim: usize) -> Lanes<'_, Self::Elem>
221    where
222        Self::Layout: RemoveDim,
223    {
224        self.view().lanes(dim)
225    }
226
227    /// Return a new tensor with the same shape, formed by applying `f` to each
228    /// element in this tensor.
229    fn map<F, U>(&self, f: F) -> TensorBase<Vec<U>, Self::Layout>
230    where
231        F: Fn(&Self::Elem) -> U,
232        Self::Layout: MutLayout,
233    {
234        self.view().map(f)
235    }
236
237    /// Variant of [`map`](AsView::map) which takes an allocator.
238    fn map_in<A: Alloc, F, U>(&self, alloc: A, f: F) -> TensorBase<Vec<U>, Self::Layout>
239    where
240        F: Fn(&Self::Elem) -> U,
241        Self::Layout: MutLayout,
242    {
243        self.view().map_in(alloc, f)
244    }
245
246    /// Merge consecutive dimensions to the extent possible without copying
247    /// data or changing the iteration order.
248    ///
249    /// If the tensor is contiguous, this has the effect of flattening the
250    /// tensor into a vector.
251    fn merge_axes(&mut self)
252    where
253        Self::Layout: ResizeLayout;
254
255    /// Re-order the axes of this tensor to move the axis at index `from` to
256    /// `to`.
257    ///
258    /// Panics if `from` or `to` is >= `self.ndim()`.
259    fn move_axis(&mut self, from: usize, to: usize)
260    where
261        Self::Layout: MutLayout;
262
263    /// Convert this tensor to one with the same shape but a static dimension
264    /// count.
265    ///
266    /// Panics if `self.ndim() != N`.
267    fn nd_view<const N: usize>(&self) -> TensorBase<ViewData<'_, Self::Elem>, NdLayout<N>> {
268        self.view().nd_view()
269    }
270
271    /// Permute the dimensions of this tensor.
272    fn permute(&mut self, order: Self::Index<'_>)
273    where
274        Self::Layout: MutLayout;
275
276    /// Return a view with dimensions permuted in the order given by `dims`.
277    fn permuted(&self, order: Self::Index<'_>) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
278    where
279        Self::Layout: MutLayout,
280    {
281        self.view().permuted(order)
282    }
283
284    /// Return either a view or a copy of `self` with the given shape.
285    ///
286    /// The new shape must have the same number of elments as the current
287    /// shape. The result will have a static rank if `shape` is an array or
288    /// a dynamic rank if it is a slice.
289    ///
290    /// If `self` is contiguous this will return a view, as changing the shape
291    /// can be done without moving data. Otherwise it will copy elements into
292    /// a new tensor.
293    ///
294    /// # Panics
295    ///
296    /// Panics if the number of elements in the new shape does not match the
297    /// current shape.
298    fn reshaped<S: Copy + IntoLayout>(
299        &self,
300        shape: S,
301    ) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
302    where
303        Self::Elem: Clone,
304        Self::Layout: MutLayout,
305    {
306        self.view().reshaped(shape)
307    }
308
309    /// A variant of [`reshaped`](AsView::reshaped) that allows specifying the
310    /// allocator to use if a copy is needed.
311    fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
312        &self,
313        alloc: A,
314        shape: S,
315    ) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
316    where
317        Self::Elem: Clone,
318        Self::Layout: MutLayout,
319    {
320        self.view().reshaped_in(alloc, shape)
321    }
322
323    /// Reverse the order of dimensions in this tensor.
324    fn transpose(&mut self)
325    where
326        Self::Layout: MutLayout;
327
328    /// Return a view with the order of dimensions reversed.
329    fn transposed(&self) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
330    where
331        Self::Layout: MutLayout,
332    {
333        self.view().transposed()
334    }
335
336    /// Slice this tensor and return a view.
337    ///
338    /// If both this tensor's layout and the range have a statically-known
339    /// number of index terms, the result will have a static rank. Otherwise it
340    /// will have a dynamic rank.
341    ///
342    /// ```
343    /// use rten_tensor::prelude::*;
344    /// use rten_tensor::NdTensor;
345    ///
346    /// let x = NdTensor::from([[1, 2], [3, 4]]);
347    /// let col = x.slice((.., 1)); // `col` is an `NdTensorView<i32, 1>`
348    /// assert_eq!(col.shape(), [2usize]);
349    /// assert_eq!(col.to_vec(), [2, 4]);
350    /// ```
351    #[allow(clippy::type_complexity)]
352    fn slice<R: IntoSliceItems + IndexCount>(
353        &self,
354        range: R,
355    ) -> TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
356    where
357        Self::Layout: SliceWith<R, R::Count>,
358    {
359        self.view().slice(range)
360    }
361
362    /// Slice this tensor along a given axis.
363    fn slice_axis(
364        &self,
365        axis: usize,
366        range: Range<usize>,
367    ) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
368    where
369        Self::Layout: MutLayout,
370    {
371        self.view().slice_axis(axis, range)
372    }
373
374    /// A variant of [`slice`](Self::slice) that returns a result
375    /// instead of panicking.
376    #[allow(clippy::type_complexity)]
377    fn try_slice<R: IntoSliceItems + IndexCount>(
378        &self,
379        range: R,
380    ) -> Result<
381        TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>,
382        SliceError,
383    >
384    where
385        Self::Layout: SliceWith<R, R::Count>,
386    {
387        self.view().try_slice(range)
388    }
389
390    /// Return a slice of this tensor as an owned tensor.
391    ///
392    /// This is more expensive than [`slice`](AsView::slice) as it copies the
393    /// data, but is more flexible as it supports ranges with negative steps.
394    #[allow(clippy::type_complexity)]
395    fn slice_copy<R: Clone + IntoSliceItems + IndexCount>(
396        &self,
397        range: R,
398    ) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
399    where
400        Self::Elem: Clone,
401        Self::Layout: SliceWith<
402                R,
403                R::Count,
404                Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
405            >,
406    {
407        self.slice_copy_in(GlobalAlloc::new(), range)
408    }
409
410    /// Variant of [`slice_copy`](AsView::slice_copy) which takes an allocator.
411    #[allow(clippy::type_complexity)]
412    fn slice_copy_in<A: Alloc, R: Clone + IntoSliceItems + IndexCount>(
413        &self,
414        pool: A,
415        range: R,
416    ) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
417    where
418        Self::Elem: Clone,
419        Self::Layout: SliceWith<
420                R,
421                R::Count,
422                Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
423            >,
424    {
425        // Fast path for slice ranges supported by `Tensor::slice`. This includes
426        // all ranges except those with a negative step. This benefits from
427        // optimizations that `Tensor::to_tensor` has for slices that are already
428        // contiguous or have a small number of dims.
429        if let Ok(slice_view) = self.try_slice(range.clone()) {
430            return slice_view.to_tensor_in(pool);
431        }
432
433        let items = range.into_slice_items();
434        let sliced_shape: Vec<_> = items
435            .as_ref()
436            .iter()
437            .copied()
438            .enumerate()
439            .filter_map(|(dim, item)| match item {
440                SliceItem::Index(_) => None,
441                SliceItem::Range(range) => Some(range.index_range(self.size(dim)).steps()),
442            })
443            .collect();
444        let sliced_len = sliced_shape.iter().product();
445        let mut sliced_data = pool.alloc(sliced_len);
446
447        copy_range_into_slice(
448            self.as_dyn(),
449            &mut sliced_data.spare_capacity_mut()[..sliced_len],
450            items.as_ref(),
451        );
452
453        // Safety: `copy_range_into_slice` initialized `sliced_len` elements.
454        unsafe {
455            sliced_data.set_len(sliced_len);
456        }
457
458        let sliced_shape = sliced_shape.as_slice().try_into().expect("slice failed");
459
460        TensorBase::from_data(sliced_shape, sliced_data)
461    }
462
463    /// Return a view of this tensor with all dimensions of size 1 removed.
464    fn squeezed(&self) -> TensorView<'_, Self::Elem>
465    where
466        Self::Layout: MutLayout,
467    {
468        self.view().squeezed()
469    }
470
471    /// Return a vector containing the elements of this tensor in their logical
472    /// order, ie. as if the tensor were flattened into one dimension.
473    fn to_vec(&self) -> Vec<Self::Elem>
474    where
475        Self::Elem: Clone;
476
477    /// Variant of [`to_vec`](AsView::to_vec) which takes an allocator.
478    fn to_vec_in<A: Alloc>(&self, alloc: A) -> Vec<Self::Elem>
479    where
480        Self::Elem: Clone;
481
482    /// Return a tensor with the same shape as this tensor/view but with the
483    /// data contiguous in memory and arranged in the same order as the
484    /// logical/iteration order (used by `iter`).
485    ///
486    /// This will return a view if the data is already contiguous or copy
487    /// data into a new buffer otherwise.
488    ///
489    /// Certain operations require or are faster with contiguous tensors.
490    fn to_contiguous(&self) -> Contiguous<TensorBase<CowData<'_, Self::Elem>, Self::Layout>>
491    where
492        Self::Elem: Clone,
493        Self::Layout: MutLayout,
494    {
495        self.view().to_contiguous()
496    }
497
498    /// Variant of [`to_contiguous`](AsView::to_contiguous) which takes an
499    /// allocator.
500    fn to_contiguous_in<A: Alloc>(
501        &self,
502        alloc: A,
503    ) -> Contiguous<TensorBase<CowData<'_, Self::Elem>, Self::Layout>>
504    where
505        Self::Elem: Clone,
506        Self::Layout: MutLayout,
507    {
508        self.view().to_contiguous_in(alloc)
509    }
510
511    /// Return a copy of this tensor with a given shape.
512    fn to_shape<S: IntoLayout>(&self, shape: S) -> TensorBase<Vec<Self::Elem>, S::Layout>
513    where
514        Self::Elem: Clone,
515        Self::Layout: MutLayout;
516
517    /// Return a slice containing the elements of this tensor in their logical
518    /// order, ie. as if the tensor were flattened into one dimension.
519    ///
520    /// Unlike [`data`](AsView::data) this will copy the elements if they are
521    /// not contiguous. Unlike [`to_vec`](AsView::to_vec) this will not copy
522    /// the elements if the tensor is already contiguous.
523    fn to_slice(&self) -> Cow<'_, [Self::Elem]>
524    where
525        Self::Elem: Clone,
526    {
527        self.view().to_slice()
528    }
529
530    /// Return a copy of this tensor/view which uniquely owns its elements.
531    fn to_tensor(&self) -> TensorBase<Vec<Self::Elem>, Self::Layout>
532    where
533        Self::Elem: Clone,
534        Self::Layout: MutLayout,
535    {
536        self.to_tensor_in(GlobalAlloc::new())
537    }
538
539    /// Variant of [`to_tensor`](AsView::to_tensor) which takes an allocator.
540    fn to_tensor_in<A: Alloc>(&self, alloc: A) -> TensorBase<Vec<Self::Elem>, Self::Layout>
541    where
542        Self::Elem: Clone,
543        Self::Layout: MutLayout,
544    {
545        TensorBase::from_data(self.layout().shape(), self.to_vec_in(alloc))
546    }
547
548    /// Return a view which performs "weak" checking when indexing via
549    /// `view[<index>]`. See [`WeaklyCheckedView`] for an explanation.
550    fn weakly_checked_view(&self) -> WeaklyCheckedView<ViewData<'_, Self::Elem>, Self::Layout> {
551        self.view().weakly_checked_view()
552    }
553}
554
555impl<S: Storage, L: Layout> TensorBase<S, L> {
556    /// Construct a new tensor from a given shape and storage.
557    ///
558    /// Panics if the data length does not match the product of `shape`.
559    #[track_caller]
560    pub fn from_data<D: IntoStorage<Output = S>>(shape: L::Index<'_>, data: D) -> TensorBase<S, L>
561    where
562        for<'a> L::Index<'a>: Clone,
563        L: MutLayout,
564    {
565        let data = data.into_storage();
566        let len = data.len();
567        match Self::try_from_data(shape.clone(), data) {
568            Ok(data) => data,
569            Err(_) => panic!(
570                "data length {} does not match shape {:?}",
571                len,
572                shape.as_ref()
573            ),
574        }
575    }
576
577    /// Construct a new tensor from a given shape and storage.
578    ///
579    /// This will fail if the data length does not match the product of `shape`.
580    pub fn try_from_data<D: IntoStorage<Output = S>>(
581        shape: L::Index<'_>,
582        data: D,
583    ) -> Result<TensorBase<S, L>, FromDataError>
584    where
585        L: MutLayout,
586    {
587        let data = data.into_storage();
588        let layout = L::from_shape(shape);
589        if layout.min_data_len() != data.len() {
590            return Err(FromDataError::StorageLengthMismatch);
591        }
592        Ok(TensorBase { data, layout })
593    }
594
595    /// Create a tensor from a pre-created storage and layout.
596    ///
597    /// Panics if the storage length is too short for the layout, or the storage
598    /// is mutable and the layout may map multiple indices to the same offset.
599    pub fn from_storage_and_layout(data: S, layout: L) -> TensorBase<S, L> {
600        assert!(data.len() >= layout.min_data_len());
601        assert!(
602            !S::MUTABLE
603                || !may_have_internal_overlap(layout.shape().as_ref(), layout.strides().as_ref())
604        );
605        TensorBase { data, layout }
606    }
607
608    /// Create a tensor from a pre-created storage and layout.
609    ///
610    /// # Safety
611    ///
612    /// Caller must ensure storage length is sufficient for the layout, and
613    /// that, if the storage is mutable, no two indices in the layout map to the
614    /// same offset.
615    pub(crate) unsafe fn from_storage_and_layout_unchecked(data: S, layout: L) -> TensorBase<S, L> {
616        debug_assert!(data.len() >= layout.min_data_len());
617        debug_assert!(
618            !S::MUTABLE
619                || !may_have_internal_overlap(layout.shape().as_ref(), layout.strides().as_ref())
620        );
621        TensorBase { data, layout }
622    }
623
624    /// Construct a new tensor from a given shape and storage, and custom
625    /// strides.
626    ///
627    /// This will fail if the data length is incorrect for the shape and stride
628    /// combination, or if the strides lead to overlap (see [`OverlapPolicy`]).
629    /// See also [`TensorBase::from_slice_with_strides`] which is a similar method
630    /// for immutable views that does allow overlapping strides.
631    pub fn from_data_with_strides<D: IntoStorage<Output = S>>(
632        shape: L::Index<'_>,
633        data: D,
634        strides: L::Index<'_>,
635    ) -> Result<TensorBase<S, L>, FromDataError>
636    where
637        L: MutLayout,
638    {
639        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::DisallowOverlap)?;
640        let data = data.into_storage();
641        if layout.min_data_len() > data.len() {
642            return Err(FromDataError::StorageTooShort);
643        }
644        Ok(TensorBase { data, layout })
645    }
646
647    /// Convert the current tensor into a dynamic rank tensor without copying
648    /// any data.
649    pub fn into_dyn(self) -> TensorBase<S, DynLayout>
650    where
651        L: Into<DynLayout>,
652    {
653        TensorBase {
654            data: self.data,
655            layout: self.layout.into(),
656        }
657    }
658
659    /// Consume this tensor and return the underlying storage.
660    ///
661    /// Be aware that the underlying elements are not guaranteed to be contiguous.
662    pub(crate) fn into_storage(self) -> S {
663        self.data
664    }
665
666    /// Attempt to convert this tensor's layout to a static-rank layout with `N`
667    /// dimensions.
668    fn nd_layout<const N: usize>(&self) -> Option<NdLayout<N>> {
669        if self.ndim() != N {
670            return None;
671        }
672        let shape: [usize; N] = std::array::from_fn(|i| self.size(i));
673        let strides: [usize; N] = std::array::from_fn(|i| self.stride(i));
674        let layout = NdLayout::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)
675            .expect("invalid layout");
676        Some(layout)
677    }
678
679    /// Return a raw pointer to the tensor's underlying data.
680    pub fn data_ptr(&self) -> *const S::Elem {
681        self.data.as_ptr()
682    }
683}
684
685impl<S: StorageMut, L: Clone + Layout> TensorBase<S, L> {
686    /// Return an iterator over mutable slices of this tensor along a given
687    /// axis. Each view yielded has one dimension fewer than the current layout.
688    pub fn axis_iter_mut(&mut self, dim: usize) -> AxisIterMut<'_, S::Elem, L>
689    where
690        L: RemoveDim,
691    {
692        AxisIterMut::new(self.view_mut(), dim)
693    }
694
695    /// Return an iterator over mutable slices of this tensor along a given
696    /// axis. Each view yielded has the same rank as this tensor, but the
697    /// dimension `dim` will only have `chunk_size` entries.
698    pub fn axis_chunks_mut(
699        &mut self,
700        dim: usize,
701        chunk_size: usize,
702    ) -> AxisChunksMut<'_, S::Elem, L>
703    where
704        L: MutLayout,
705    {
706        AxisChunksMut::new(self.view_mut(), dim, chunk_size)
707    }
708
709    /// Replace each element in this tensor with the result of applying `f` to
710    /// the element.
711    pub fn apply<F: Fn(&S::Elem) -> S::Elem>(&mut self, f: F) {
712        if let Some(data) = self.data_mut() {
713            // Fast path for contiguous tensors.
714            data.iter_mut().for_each(|x| *x = f(x));
715        } else {
716            for_each_mut(self.as_dyn_mut(), |x| *x = f(x));
717        }
718    }
719
720    /// Return a mutable view of this tensor with a dynamic dimension count.
721    pub fn as_dyn_mut(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, DynLayout> {
722        TensorBase {
723            layout: DynLayout::from(&self.layout),
724            data: self.data.view_mut(),
725        }
726    }
727
728    /// Copy elements from another tensor into this tensor.
729    ///
730    /// This tensor and `other` must have the same shape.
731    pub fn copy_from<S2: Storage<Elem = S::Elem>>(&mut self, other: &TensorBase<S2, L>)
732    where
733        S::Elem: Clone,
734        L: Clone,
735    {
736        assert!(
737            self.shape() == other.shape(),
738            "copy dest shape {:?} != src shape {:?}",
739            self.shape(),
740            other.shape()
741        );
742
743        if let Some(dest) = self.data_mut() {
744            if let Some(src) = other.data() {
745                dest.clone_from_slice(src);
746            } else {
747                // Drop all the existing values. This should be compiled away for
748                // `Copy` types.
749                let uninit_dest: &mut [MaybeUninit<S::Elem>] = unsafe { std::mem::transmute(dest) };
750                for x in &mut *uninit_dest {
751                    // Safety: All elements were initialized at the start of this
752                    // block, and we haven't written to the slice yet.
753                    unsafe { x.assume_init_drop() }
754                }
755
756                // Copy source into destination in contiguous order.
757                copy_into_slice(other.as_dyn(), uninit_dest);
758            }
759        } else {
760            copy_into(other.as_dyn(), self.as_dyn_mut());
761        }
762    }
763
764    /// Return the data in this tensor as a slice if it is contiguous.
765    pub fn data_mut(&mut self) -> Option<&mut [S::Elem]> {
766        // The length of `self.data` must be at least the minimum required by
767        // the layout, but it may be larger.
768        let len = self.layout.min_data_len();
769        let data = self.data.slice_mut(0..len);
770
771        self.layout.is_contiguous().then(|| unsafe {
772            // Safety: We verified the layout is contiguous.
773            data.to_slice_mut()
774        })
775    }
776
777    /// Index the tensor along a given axis.
778    ///
779    /// Returns a mutable view with one dimension removed.
780    ///
781    /// Panics if `axis >= self.ndim()` or `index >= self.size(axis)`.
782    pub fn index_axis_mut(
783        &mut self,
784        axis: usize,
785        index: usize,
786    ) -> TensorBase<ViewMutData<'_, S::Elem>, <L as RemoveDim>::Output>
787    where
788        L: MutLayout + RemoveDim,
789    {
790        let (offsets, layout) = self.layout.index_axis(axis, index);
791        TensorBase {
792            data: self.data.slice_mut(offsets),
793            layout,
794        }
795    }
796
797    /// Return a mutable view of the tensor's underlying storage.
798    pub fn storage_mut(&mut self) -> ViewMutData<'_, S::Elem> {
799        self.data.view_mut()
800    }
801
802    /// Replace all elements of this tensor with `value`.
803    pub fn fill(&mut self, value: S::Elem)
804    where
805        S::Elem: Clone,
806    {
807        self.apply(|_| value.clone())
808    }
809
810    /// Return a mutable reference to the element at `index`, or `None` if the
811    /// index is invalid.
812    pub fn get_mut<I: AsIndex<L>>(&mut self, index: I) -> Option<&mut S::Elem>
813    where
814        L: TrustedLayout,
815    {
816        self.offset(index.as_index()).map(|offset| unsafe {
817            // Safety: We verified the offset is in-bounds.
818            self.data.get_unchecked_mut(offset)
819        })
820    }
821
822    /// Return the element at a given index, without performing any bounds-
823    /// checking.
824    ///
825    /// # Safety
826    ///
827    /// The caller must ensure that the index is valid for the tensor's shape.
828    pub unsafe fn get_unchecked_mut<I: AsIndex<L>>(&mut self, index: I) -> &mut S::Elem {
829        let offset = self.layout.offset_unchecked(index.as_index());
830        unsafe { self.data.get_unchecked_mut(offset) }
831    }
832
833    pub(crate) fn mut_view_ref(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, &L> {
834        TensorBase {
835            data: self.data.view_mut(),
836            layout: &self.layout,
837        }
838    }
839
840    /// Return a mutable iterator over the N innermost dimensions of this tensor.
841    pub fn inner_iter_mut<const N: usize>(&mut self) -> InnerIterMut<'_, S::Elem, NdLayout<N>>
842    where
843        L: MutLayout,
844    {
845        InnerIterMut::new(self.view_mut())
846    }
847
848    /// Return a mutable iterator over the n innermost dimensions of this tensor.
849    ///
850    /// Prefer [`inner_iter_mut`](TensorBase::inner_iter_mut) if `N` is known
851    /// at compile time.
852    pub fn inner_iter_dyn_mut(&mut self, n: usize) -> InnerIterMut<'_, S::Elem, DynLayout>
853    where
854        L: MutLayout,
855    {
856        InnerIterMut::new_dyn(self.view_mut(), n)
857    }
858
859    /// Return a mutable iterator over the elements of this tensor, in their
860    /// logical order.
861    pub fn iter_mut(&mut self) -> IterMut<'_, S::Elem> {
862        IterMut::new(self.mut_view_ref())
863    }
864
865    /// Return an iterator over mutable 1D slices of this tensor along a given
866    /// dimension.
867    pub fn lanes_mut(&mut self, dim: usize) -> LanesMut<'_, S::Elem>
868    where
869        L: RemoveDim,
870    {
871        LanesMut::new(self.mut_view_ref(), dim)
872    }
873
874    /// Return a view of this tensor with a static dimension count.
875    ///
876    /// Panics if `self.ndim() != N`.
877    pub fn nd_view_mut<const N: usize>(
878        &mut self,
879    ) -> TensorBase<ViewMutData<'_, S::Elem>, NdLayout<N>> {
880        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
881        TensorBase {
882            layout: self.nd_layout().unwrap(),
883            data: self.data.view_mut(),
884        }
885    }
886
887    /// Permute the order of dimensions according to the given order.
888    ///
889    /// See [`AsView::permuted`].
890    pub fn permuted_mut(&mut self, order: L::Index<'_>) -> TensorBase<ViewMutData<'_, S::Elem>, L>
891    where
892        L: MutLayout,
893    {
894        TensorBase {
895            layout: self.layout.permuted(order),
896            data: self.data.view_mut(),
897        }
898    }
899
900    /// Change the layout of the tensor without moving any data.
901    ///
902    /// This will return an error if the view is not contiguous.
903    ///
904    /// See also [`AsView::reshaped`].
905    pub fn reshaped_mut<SH: IntoLayout>(
906        &mut self,
907        shape: SH,
908    ) -> Result<TensorBase<ViewMutData<'_, S::Elem>, SH::Layout>, ReshapeError>
909    where
910        L: MutLayout,
911    {
912        let layout = self.layout.reshaped_for_view(shape)?;
913        Ok(TensorBase {
914            layout,
915            data: self.data.view_mut(),
916        })
917    }
918
919    /// Slice this tensor along a given axis.
920    pub fn slice_axis_mut(
921        &mut self,
922        axis: usize,
923        range: Range<usize>,
924    ) -> TensorBase<ViewMutData<'_, S::Elem>, L>
925    where
926        L: MutLayout,
927    {
928        let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
929        debug_assert_eq!(sliced_layout.size(axis), range.len());
930        TensorBase {
931            data: self.data.slice_mut(offset_range),
932            layout: sliced_layout,
933        }
934    }
935
936    /// Slice this tensor and return a mutable view.
937    ///
938    /// See [`slice`](AsView::slice) for notes on the layout of the returned
939    /// view.
940    pub fn slice_mut<R: IntoSliceItems + IndexCount>(
941        &mut self,
942        range: R,
943    ) -> TensorBase<ViewMutData<'_, S::Elem>, <L as SliceWith<R, R::Count>>::Layout>
944    where
945        L: SliceWith<R, R::Count>,
946    {
947        self.try_slice_mut(range).expect("slice failed")
948    }
949
950    /// A variant of [`slice_mut`](Self::slice_mut) that returns a
951    /// result instead of panicking.
952    #[allow(clippy::type_complexity)]
953    pub fn try_slice_mut<R: IntoSliceItems + IndexCount>(
954        &mut self,
955        range: R,
956    ) -> Result<
957        TensorBase<ViewMutData<'_, S::Elem>, <L as SliceWith<R, R::Count>>::Layout>,
958        SliceError,
959    >
960    where
961        L: SliceWith<R, R::Count>,
962    {
963        let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
964        Ok(TensorBase {
965            data: self.data.slice_mut(offset_range),
966            layout: sliced_layout,
967        })
968    }
969
970    /// Return a mutable view of this tensor.
971    pub fn view_mut(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, L>
972    where
973        L: Clone,
974    {
975        TensorBase {
976            data: self.data.view_mut(),
977            layout: self.layout.clone(),
978        }
979    }
980
981    /// Return a mutable view that performs only "weak" checking when indexing,
982    /// this is faster but can hide bugs. See [`WeaklyCheckedView`].
983    pub fn weakly_checked_view_mut(&mut self) -> WeaklyCheckedView<ViewMutData<'_, S::Elem>, L> {
984        WeaklyCheckedView {
985            base: self.view_mut(),
986        }
987    }
988}
989
990impl<T, L: Clone + Layout> TensorBase<Vec<T>, L> {
991    /// Create a new 1D tensor filled with an arithmetic sequence of values
992    /// in the range `[start, end)` separated by `step`. If `step` is omitted,
993    /// it defaults to 1.
994    pub fn arange(start: T, end: T, step: Option<T>) -> TensorBase<Vec<T>, L>
995    where
996        T: Copy + PartialOrd + From<bool> + std::ops::Add<Output = T>,
997        [usize; 1]: AsIndex<L>,
998        L: MutLayout,
999    {
1000        let step = step.unwrap_or((true).into());
1001        let mut data = Vec::new();
1002        let mut curr = start;
1003        while curr < end {
1004            data.push(curr);
1005            curr = curr + step;
1006        }
1007        TensorBase::from_data([data.len()].as_index(), data)
1008    }
1009
1010    /// Append elements from `other` to this tensor along a given axis.
1011    ///
1012    /// This will fail if the shapes of `self` and `other` do not match along
1013    /// dimensions other than `axis`, or if the current tensor has
1014    /// insufficient capacity to expand without re-allocating.
1015    pub fn append<S2: Storage<Elem = T>>(
1016        &mut self,
1017        axis: usize,
1018        other: &TensorBase<S2, L>,
1019    ) -> Result<(), ExpandError>
1020    where
1021        T: Copy,
1022        L: MutLayout,
1023    {
1024        let shape_match = self.ndim() == other.ndim()
1025            && (0..self.ndim()).all(|d| d == axis || self.size(d) == other.size(d));
1026        if !shape_match {
1027            return Err(ExpandError::ShapeMismatch);
1028        }
1029
1030        let old_size = self.size(axis);
1031        let new_size = self.size(axis) + other.size(axis);
1032
1033        let Some(new_layout) = self.expanded_layout(axis, new_size) else {
1034            return Err(ExpandError::InsufficientCapacity);
1035        };
1036
1037        let new_data_len = new_layout.min_data_len();
1038        self.layout = new_layout;
1039
1040        // Safety: The `copy_from` call below will initialize all elements
1041        // added to the tensor.
1042        assert!(self.data.capacity() >= new_data_len);
1043        unsafe {
1044            self.data.set_len(new_data_len);
1045        }
1046
1047        self.slice_axis_mut(axis, old_size..new_size)
1048            .copy_from(other);
1049
1050        Ok(())
1051    }
1052
1053    /// Create a new 1D tensor from a `Vec<T>`.
1054    pub fn from_vec(vec: Vec<T>) -> TensorBase<Vec<T>, L>
1055    where
1056        [usize; 1]: AsIndex<L>,
1057        L: MutLayout,
1058    {
1059        TensorBase::from_data([vec.len()].as_index(), vec)
1060    }
1061
1062    /// Clip dimension `dim` to `[range.start, range.end)`. The new size for
1063    /// the dimension must be <= the old size.
1064    ///
1065    /// This currently requires `T: Copy` to support efficiently moving data
1066    /// from the new start offset to the beginning of the element buffer.
1067    pub fn clip_dim(&mut self, dim: usize, range: Range<usize>)
1068    where
1069        T: Copy,
1070        L: MutLayout,
1071    {
1072        let (start, end) = (range.start, range.end);
1073
1074        assert!(start <= end, "start must be <= end");
1075        assert!(end <= self.size(dim), "end must be <= dim size");
1076
1077        self.layout.resize_dim(dim, end - start);
1078
1079        let range = if self.is_empty() {
1080            0..0
1081        } else {
1082            let start_offset = start * self.layout.stride(dim);
1083            let end_offset = start_offset + self.layout.min_data_len();
1084            start_offset..end_offset
1085        };
1086        self.data.copy_within(range.clone(), 0);
1087        self.data.truncate(range.end - range.start);
1088    }
1089
1090    /// Return true if this tensor can be expanded along a given axis to a
1091    /// new size without re-allocating.
1092    pub fn has_capacity(&self, axis: usize, new_size: usize) -> bool
1093    where
1094        L: MutLayout,
1095    {
1096        self.expanded_layout(axis, new_size).is_some()
1097    }
1098
1099    /// Return the layout this tensor would have if the size of `axis` were
1100    /// expanded to `new_size`.
1101    ///
1102    /// Returns `None` if the tensor does not have capacity for the new size.
1103    fn expanded_layout(&self, axis: usize, new_size: usize) -> Option<L>
1104    where
1105        L: MutLayout,
1106    {
1107        let mut new_layout = self.layout.clone();
1108        new_layout.resize_dim(axis, new_size);
1109        let new_data_len = new_layout.min_data_len();
1110
1111        let has_capacity = new_data_len <= self.data.capacity()
1112            && !may_have_internal_overlap(
1113                new_layout.shape().as_ref(),
1114                new_layout.strides().as_ref(),
1115            );
1116
1117        has_capacity.then_some(new_layout)
1118    }
1119
1120    /// Convert the storage of this tensor into an owned [`CowData`].
1121    ///
1122    /// This is useful in contexts where code needs to conditionally copy or
1123    /// create a new tensor. See [`AsView::as_cow`].
1124    pub fn into_cow(self) -> TensorBase<CowData<'static, T>, L> {
1125        let TensorBase { data, layout } = self;
1126        TensorBase {
1127            layout,
1128            data: CowData::Owned(data),
1129        }
1130    }
1131
1132    /// Convert the storage of this tensor to be reference counted.
1133    ///
1134    /// This is a (relatively) cheap operation that does not copy the tensor
1135    /// data.
1136    pub fn into_arc(self) -> TensorBase<Arc<Vec<T>>, L> {
1137        let TensorBase { data, layout } = self;
1138        TensorBase {
1139            layout,
1140            data: Arc::new(data),
1141        }
1142    }
1143
1144    /// Consume self and return the underlying data as a contiguous tensor.
1145    ///
1146    /// See also [`TensorBase::to_vec`].
1147    pub fn into_data(self) -> Vec<T>
1148    where
1149        T: Clone,
1150    {
1151        if self.is_contiguous() {
1152            self.into_non_contiguous_data()
1153        } else {
1154            self.to_vec()
1155        }
1156    }
1157
1158    /// Consume self and return the underlying data in whatever order the
1159    /// elements are currently stored.
1160    pub fn into_non_contiguous_data(mut self) -> Vec<T> {
1161        self.data.truncate(self.layout.min_data_len());
1162        self.data
1163    }
1164
1165    /// Consume self and return a new contiguous tensor with the given shape.
1166    ///
1167    /// This avoids copying the data if it is already contiguous.
1168    #[track_caller]
1169    pub fn into_shape<S: Copy + IntoLayout>(self, shape: S) -> TensorBase<Vec<T>, S::Layout>
1170    where
1171        T: Clone,
1172        L: MutLayout,
1173    {
1174        let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
1175            panic!(
1176                "element count mismatch reshaping {:?} to {:?}",
1177                self.shape(),
1178                shape
1179            );
1180        };
1181        TensorBase {
1182            layout,
1183            data: self.into_data(),
1184        }
1185    }
1186
1187    /// Create a new tensor with a given shape and values generated by calling
1188    /// `f` repeatedly.
1189    ///
1190    /// Each call to `f` will receive an element index and should return the
1191    /// corresponding value. If the function does not need this index, use
1192    /// [`from_simple_fn`](TensorBase::from_simple_fn) instead, as it is faster.
1193    pub fn from_fn<F: FnMut(L::Index<'_>) -> T, Idx>(
1194        shape: L::Index<'_>,
1195        mut f: F,
1196    ) -> TensorBase<Vec<T>, L>
1197    where
1198        L::Indices: Iterator<Item = Idx>,
1199        Idx: AsIndex<L>,
1200        L: MutLayout,
1201    {
1202        let layout = L::from_shape(shape);
1203        let data: Vec<T> = layout.indices().map(|idx| f(idx.as_index())).collect();
1204        TensorBase { data, layout }
1205    }
1206
1207    /// Create a new tensor with a given shape and values generated by calling
1208    /// `f` repeatedly.
1209    pub fn from_simple_fn<F: FnMut() -> T>(shape: L::Index<'_>, f: F) -> TensorBase<Vec<T>, L>
1210    where
1211        L: MutLayout,
1212    {
1213        Self::from_simple_fn_in(GlobalAlloc::new(), shape, f)
1214    }
1215
1216    /// Variant of [`from_simple_fn`](TensorBase::from_simple_fn) that takes
1217    /// an allocator.
1218    pub fn from_simple_fn_in<A: Alloc, F: FnMut() -> T>(
1219        alloc: A,
1220        shape: L::Index<'_>,
1221        mut f: F,
1222    ) -> TensorBase<Vec<T>, L>
1223    where
1224        L: MutLayout,
1225    {
1226        let len = shape.as_ref().iter().product();
1227        let mut data = alloc.alloc(len);
1228        data.extend(std::iter::from_fn(|| Some(f())).take(len));
1229        TensorBase::from_data(shape, data)
1230    }
1231
1232    /// Create a new 0D tensor from a scalar value.
1233    pub fn from_scalar(value: T) -> TensorBase<Vec<T>, L>
1234    where
1235        [usize; 0]: AsIndex<L>,
1236        L: MutLayout,
1237    {
1238        TensorBase::from_data([].as_index(), vec![value])
1239    }
1240
1241    /// Create a new tensor with a given shape and all elements set to `value`.
1242    pub fn full(shape: L::Index<'_>, value: T) -> TensorBase<Vec<T>, L>
1243    where
1244        T: Clone,
1245        L: MutLayout,
1246    {
1247        Self::full_in(GlobalAlloc::new(), shape, value)
1248    }
1249
1250    /// Variant of [`full`](TensorBase::full) which takes an allocator.
1251    pub fn full_in<A: Alloc>(alloc: A, shape: L::Index<'_>, value: T) -> TensorBase<Vec<T>, L>
1252    where
1253        T: Clone,
1254        L: MutLayout,
1255    {
1256        let len = shape.as_ref().iter().product();
1257        let mut data = alloc.alloc(len);
1258        data.resize(len, value);
1259        TensorBase::from_data(shape, data)
1260    }
1261
1262    /// Make the underlying data in this tensor contiguous.
1263    ///
1264    /// This means that after calling `make_contiguous`, the elements are
1265    /// guaranteed to be stored in the same order as the logical order in
1266    /// which `iter` yields elements. This method is cheap if the storage is
1267    /// already contiguous.
1268    pub fn make_contiguous(&mut self)
1269    where
1270        T: Clone,
1271        L: MutLayout,
1272    {
1273        if self.is_contiguous() {
1274            return;
1275        }
1276        self.data = self.to_vec();
1277        self.layout = L::from_shape(self.layout.shape());
1278    }
1279
1280    /// Create a new tensor with a given shape and elements populated using
1281    /// numbers generated by `rand_src`.
1282    ///
1283    /// A more general version of this method that generates values using any
1284    /// function is [`from_simple_fn`](Self::from_simple_fn).
1285    pub fn rand<R: RandomSource<T>>(shape: L::Index<'_>, rand_src: &mut R) -> TensorBase<Vec<T>, L>
1286    where
1287        L: MutLayout,
1288    {
1289        Self::from_simple_fn(shape, || rand_src.next())
1290    }
1291
1292    /// Create a new tensor with a given shape, with all elements set to their
1293    /// default value (ie. zero for numeric types).
1294    pub fn zeros(shape: L::Index<'_>) -> TensorBase<Vec<T>, L>
1295    where
1296        T: Clone + Default,
1297        L: MutLayout,
1298    {
1299        Self::zeros_in(GlobalAlloc::new(), shape)
1300    }
1301
1302    /// Variant of [`zeros`](TensorBase::zeros) which takes an allocator.
1303    pub fn zeros_in<A: Alloc>(alloc: A, shape: L::Index<'_>) -> TensorBase<Vec<T>, L>
1304    where
1305        T: Clone + Default,
1306        L: MutLayout,
1307    {
1308        // We delegate to `full_in` here and rely on compiler optimizations to
1309        // take advantage of the value being statically known to be zero.
1310        Self::full_in(alloc, shape, T::default())
1311    }
1312
1313    /// Return a new tensor containing uninitialized elements.
1314    ///
1315    /// The caller must initialize elements and then call
1316    /// [`assume_init`](TensorBase::assume_init) to convert to an initialized
1317    /// `Tensor<T>`.
1318    pub fn uninit(shape: L::Index<'_>) -> TensorBase<Vec<MaybeUninit<T>>, L>
1319    where
1320        MaybeUninit<T>: Clone,
1321        L: MutLayout,
1322    {
1323        Self::uninit_in(GlobalAlloc::new(), shape)
1324    }
1325
1326    /// Variant of [`uninit`](TensorBase::uninit) which takes an allocator.
1327    pub fn uninit_in<A: Alloc>(alloc: A, shape: L::Index<'_>) -> TensorBase<Vec<MaybeUninit<T>>, L>
1328    where
1329        L: MutLayout,
1330    {
1331        let len = shape.as_ref().iter().product();
1332        let mut data = alloc.alloc(len);
1333
1334        // Safety: Since the contents of the `Vec` are `MaybeUninit`, we don't
1335        // need to initialize them.
1336        unsafe { data.set_len(len) }
1337
1338        TensorBase::from_data(shape, data)
1339    }
1340
1341    /// Create a tensor which initially has zero elements, but can be expanded
1342    /// along a given dimension without reallocating.
1343    ///
1344    /// `shape` specifies the maximum shape that the tensor can be expanded to
1345    /// without reallocating. The initial shape will be the same, except for
1346    /// the dimension specified by `expand_dim`, which will be zero.
1347    pub fn with_capacity(shape: L::Index<'_>, expand_dim: usize) -> TensorBase<Vec<T>, L>
1348    where
1349        T: Copy,
1350        L: MutLayout,
1351    {
1352        Self::with_capacity_in(GlobalAlloc::new(), shape, expand_dim)
1353    }
1354
1355    /// Variant of [`with_capacity`](Self::with_capacity) which takes an allocator.
1356    pub fn with_capacity_in<A: Alloc>(
1357        alloc: A,
1358        shape: L::Index<'_>,
1359        expand_dim: usize,
1360    ) -> TensorBase<Vec<T>, L>
1361    where
1362        T: Copy,
1363        L: MutLayout,
1364    {
1365        let mut tensor = Self::uninit_in(alloc, shape);
1366        tensor.clip_dim(expand_dim, 0..0);
1367
1368        // Safety: Since at least one dimension has a size of zero, the tensor
1369        // has no elements and thus is fully initialized.
1370        unsafe { tensor.assume_init() }
1371    }
1372}
1373
1374impl<T, L: Layout> TensorBase<CowData<'_, T>, L> {
1375    /// Consume self and return the underlying data in whatever order the
1376    /// elements are currently stored, if the storage is owned, or `None` if
1377    /// it is borrowed.
1378    pub fn into_non_contiguous_data(self) -> Option<Vec<T>> {
1379        match self.data {
1380            CowData::Owned(mut vec) => {
1381                vec.truncate(self.layout.min_data_len());
1382                Some(vec)
1383            }
1384            CowData::Borrowed(_) => None,
1385        }
1386    }
1387}
1388
1389impl<T, S: Storage<Elem = MaybeUninit<T>> + AssumeInit, L: Layout + Clone> TensorBase<S, L>
1390where
1391    <S as AssumeInit>::Output: Storage<Elem = T>,
1392{
1393    /// Convert a tensor of potentially uninitialized elements to one of
1394    /// initialized elements.
1395    ///
1396    /// See also [`MaybeUninit::assume_init`].
1397    ///
1398    /// # Safety
1399    ///
1400    /// The caller must guarantee that all elements in this tensor have been
1401    /// initialized before calling `assume_init`.
1402    pub unsafe fn assume_init(self) -> TensorBase<<S as AssumeInit>::Output, L> {
1403        TensorBase {
1404            layout: self.layout,
1405            data: unsafe { self.data.assume_init() },
1406        }
1407    }
1408
1409    /// Initialize this tensor with data from another view.
1410    ///
1411    /// This tensor and `other` must have the same shape.
1412    pub fn init_from<S2: Storage<Elem = T>>(
1413        mut self,
1414        other: &TensorBase<S2, L>,
1415    ) -> TensorBase<<S as AssumeInit>::Output, L>
1416    where
1417        T: Copy,
1418        S: StorageMut<Elem = MaybeUninit<T>>,
1419    {
1420        assert_eq!(self.shape(), other.shape(), "shape mismatch");
1421
1422        match (self.data_mut(), other.data()) {
1423            // Source and dest are contiguous. Use a memcpy.
1424            (Some(self_data), Some(other_data)) => {
1425                let other_data: &[MaybeUninit<T>] = unsafe { std::mem::transmute(other_data) };
1426                self_data.clone_from_slice(other_data);
1427            }
1428            // Dest is contiguous.
1429            (Some(self_data), _) => {
1430                copy_into_slice(other.as_dyn(), self_data);
1431            }
1432            // Neither are contiguous.
1433            _ => {
1434                copy_into_uninit(other.as_dyn(), self.as_dyn_mut());
1435            }
1436        }
1437
1438        unsafe { self.assume_init() }
1439    }
1440}
1441
1442impl<'a, T, L: Clone + Layout> TensorBase<ViewData<'a, T>, L> {
1443    pub fn axis_iter(&self, dim: usize) -> AxisIter<'a, T, L>
1444    where
1445        L: MutLayout + RemoveDim,
1446    {
1447        AxisIter::new(self, dim)
1448    }
1449
1450    pub fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'a, T, L>
1451    where
1452        L: MutLayout,
1453    {
1454        AxisChunks::new(self, dim, chunk_size)
1455    }
1456
1457    /// Return a view of this tensor with a dynamic dimension count.
1458    ///
1459    /// See [`AsView::as_dyn`].
1460    pub fn as_dyn(&self) -> TensorBase<ViewData<'a, T>, DynLayout> {
1461        TensorBase {
1462            data: self.data,
1463            layout: DynLayout::from(&self.layout),
1464        }
1465    }
1466
1467    /// Convert the storage of this view to a borrowed [`CowData`].
1468    ///
1469    /// See [`AsView::as_cow`].
1470    pub fn as_cow(&self) -> TensorBase<CowData<'a, T>, L> {
1471        TensorBase {
1472            layout: self.layout.clone(),
1473            data: CowData::Borrowed(self.data),
1474        }
1475    }
1476
1477    /// Broadcast this view to another shape.
1478    ///
1479    /// See [`AsView::broadcast`].
1480    pub fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'a, T>, S::Layout>
1481    where
1482        L: BroadcastLayout<S::Layout>,
1483    {
1484        self.try_broadcast(shape).unwrap()
1485    }
1486
1487    /// Broadcast this view to another shape.
1488    ///
1489    /// See [`AsView::broadcast`].
1490    pub fn try_broadcast<S: IntoLayout>(
1491        &self,
1492        shape: S,
1493    ) -> Result<TensorBase<ViewData<'a, T>, S::Layout>, ExpandError>
1494    where
1495        L: BroadcastLayout<S::Layout>,
1496    {
1497        Ok(TensorBase {
1498            layout: self.layout.broadcast(shape)?,
1499            data: self.data,
1500        })
1501    }
1502
1503    /// Return the data in this tensor as a slice if it is contiguous, ie.
1504    /// the order of elements in the slice is the same as the logical order
1505    /// yielded by `iter`, and there are no gaps.
1506    pub fn data(&self) -> Option<&'a [T]> {
1507        // The length of `self.data` must be at least the minimum required by
1508        // the layout, but it may be larger.
1509        let len = self.layout.min_data_len();
1510        let data = self.data.slice(0..len);
1511
1512        self.layout.is_contiguous().then(|| unsafe {
1513            // Safety: Storage is contigous
1514            data.as_slice()
1515        })
1516    }
1517
1518    /// Return an immutable view of the tensor's underlying storage.
1519    pub fn storage(&self) -> ViewData<'a, T> {
1520        self.data.view()
1521    }
1522
1523    pub fn get<I: AsIndex<L>>(&self, index: I) -> Option<&'a T>
1524    where
1525        L: TrustedLayout,
1526    {
1527        self.offset(index.as_index()).map(|offset|
1528                // Safety:
1529                // - No logically overlapping mutable view exist.
1530                // - For trusted layouts, offset is promised to be less than
1531                //   the storage length
1532                unsafe {
1533                self.data.get_unchecked(offset)
1534            })
1535    }
1536
1537    /// Create a new view with a given shape and data slice, and custom strides.
1538    ///
1539    /// If you do not need to specify custom strides, use [`TensorBase::from_data`]
1540    /// instead. This method is similar to [`TensorBase::from_data_with_strides`],
1541    /// but allows strides that lead to internal overlap (see [`OverlapPolicy`]).
1542    pub fn from_slice_with_strides(
1543        shape: L::Index<'_>,
1544        data: &'a [T],
1545        strides: L::Index<'_>,
1546    ) -> Result<TensorBase<ViewData<'a, T>, L>, FromDataError>
1547    where
1548        L: MutLayout,
1549    {
1550        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)?;
1551        if layout.min_data_len() > data.as_ref().len() {
1552            return Err(FromDataError::StorageTooShort);
1553        }
1554        Ok(TensorBase {
1555            data: data.into_storage(),
1556            layout,
1557        })
1558    }
1559
1560    /// Return the element at a given index, without performing any bounds-
1561    /// checking.
1562    ///
1563    /// # Safety
1564    ///
1565    /// The caller must ensure that the index is valid for the tensor's shape.
1566    pub unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &'a T {
1567        let offset = self.layout.offset_unchecked(index.as_index());
1568        unsafe { self.data.get_unchecked(offset) }
1569    }
1570
1571    /// Index the tensor along a given axis.
1572    ///
1573    /// Returns a view with one dimension removed.
1574    ///
1575    /// Panics if `axis >= self.ndim()` or `index >= self.size(axis)`.
1576    pub fn index_axis(
1577        &self,
1578        axis: usize,
1579        index: usize,
1580    ) -> TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>
1581    where
1582        L: MutLayout + RemoveDim,
1583    {
1584        let (offsets, layout) = self.layout.index_axis(axis, index);
1585        TensorBase {
1586            data: self.data.slice(offsets),
1587            layout,
1588        }
1589    }
1590
1591    /// Return an iterator over the inner `N` dimensions of this tensor.
1592    ///
1593    /// See [`AsView::inner_iter`].
1594    pub fn inner_iter<const N: usize>(&self) -> InnerIter<'a, T, NdLayout<N>> {
1595        InnerIter::new(self.view())
1596    }
1597
1598    /// Return an iterator over the inner `n` dimensions of this tensor.
1599    ///
1600    /// See [`AsView::inner_iter_dyn`].
1601    pub fn inner_iter_dyn(&self, n: usize) -> InnerIter<'a, T, DynLayout> {
1602        InnerIter::new_dyn(self.view(), n)
1603    }
1604
1605    /// Return the scalar value in this tensor if it has one element.
1606    pub fn item(&self) -> Option<&'a T> {
1607        match self.ndim() {
1608            0 => unsafe {
1609                // Safety: No logically overlapping mutable views exist.
1610                self.data.get(0)
1611            },
1612            _ if self.len() == 1 => self.iter().next(),
1613            _ => None,
1614        }
1615    }
1616
1617    /// Return an iterator over elements of this tensor in their logical order.
1618    ///
1619    /// See [`AsView::iter`].
1620    pub fn iter(&self) -> Iter<'a, T> {
1621        Iter::new(self.view_ref())
1622    }
1623
1624    /// Return an iterator over 1D slices of this tensor along a given dimension.
1625    ///
1626    /// See [`AsView::lanes`].
1627    pub fn lanes(&self, dim: usize) -> Lanes<'a, T>
1628    where
1629        L: RemoveDim,
1630    {
1631        assert!(dim < self.ndim());
1632        Lanes::new(self.view_ref(), dim)
1633    }
1634
1635    /// Return a view of this tensor with a static dimension count.
1636    ///
1637    /// Panics if `self.ndim() != N`.
1638    pub fn nd_view<const N: usize>(&self) -> TensorBase<ViewData<'a, T>, NdLayout<N>> {
1639        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
1640        TensorBase {
1641            data: self.data,
1642            layout: self.nd_layout().unwrap(),
1643        }
1644    }
1645
1646    /// Permute the axes of this tensor according to `order`.
1647    ///
1648    /// See [`AsView::permuted`].
1649    pub fn permuted(&self, order: L::Index<'_>) -> TensorBase<ViewData<'a, T>, L>
1650    where
1651        L: MutLayout,
1652    {
1653        TensorBase {
1654            data: self.data,
1655            layout: self.layout.permuted(order),
1656        }
1657    }
1658
1659    /// Return a view or owned tensor that has the given shape.
1660    ///
1661    /// See [`AsView::reshaped`].
1662    pub fn reshaped<S: Copy + IntoLayout>(&self, shape: S) -> TensorBase<CowData<'a, T>, S::Layout>
1663    where
1664        T: Clone,
1665        L: MutLayout,
1666    {
1667        self.reshaped_in(GlobalAlloc::new(), shape)
1668    }
1669
1670    /// Variant of [`reshaped`](Self::reshaped) that takes an allocator.
1671    pub fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
1672        &self,
1673        alloc: A,
1674        shape: S,
1675    ) -> TensorBase<CowData<'a, T>, S::Layout>
1676    where
1677        T: Clone,
1678        L: MutLayout,
1679    {
1680        if let Ok(layout) = self.layout.reshaped_for_view(shape) {
1681            TensorBase {
1682                data: CowData::Borrowed(self.data),
1683                layout,
1684            }
1685        } else {
1686            let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
1687                panic!(
1688                    "element count mismatch reshaping {:?} to {:?}",
1689                    self.shape(),
1690                    shape
1691                );
1692            };
1693
1694            TensorBase {
1695                data: CowData::Owned(self.to_vec_in(alloc)),
1696                layout,
1697            }
1698        }
1699    }
1700
1701    /// Slice this tensor and return a view. See [`AsView::slice`].
1702    pub fn slice<R: IntoSliceItems + IndexCount>(
1703        &self,
1704        range: R,
1705    ) -> TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>
1706    where
1707        L: SliceWith<R, R::Count>,
1708    {
1709        self.try_slice(range).expect("slice failed")
1710    }
1711
1712    /// Slice this tensor along a given axis.
1713    pub fn slice_axis(&self, axis: usize, range: Range<usize>) -> TensorBase<ViewData<'a, T>, L>
1714    where
1715        L: MutLayout,
1716    {
1717        let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
1718        debug_assert_eq!(sliced_layout.size(axis), range.len());
1719        TensorBase {
1720            data: self.data.slice(offset_range),
1721            layout: sliced_layout,
1722        }
1723    }
1724
1725    /// A variant of [`slice`](Self::slice) that returns a result
1726    /// instead of panicking.
1727    #[allow(clippy::type_complexity)]
1728    pub fn try_slice<R: IntoSliceItems + IndexCount>(
1729        &self,
1730        range: R,
1731    ) -> Result<TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
1732    where
1733        L: SliceWith<R, R::Count>,
1734    {
1735        let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
1736        Ok(TensorBase {
1737            data: self.data.slice(offset_range),
1738            layout: sliced_layout,
1739        })
1740    }
1741
1742    /// Remove all size-one dimensions from this tensor.
1743    ///
1744    /// See [`AsView::squeezed`].
1745    pub fn squeezed(&self) -> TensorView<'a, T>
1746    where
1747        L: MutLayout,
1748    {
1749        TensorBase {
1750            data: self.data.view(),
1751            layout: self.layout.squeezed(),
1752        }
1753    }
1754
1755    /// Divide this tensor into two views along a given axis.
1756    ///
1757    /// Returns a `(left, right)` tuple of views, where the `left` view
1758    /// contains the slice from `[0, mid)` along `axis` and the `right`
1759    /// view contains the slice from `[mid, end)` along `axis`.
1760    #[allow(clippy::type_complexity)]
1761    pub fn split_at(
1762        &self,
1763        axis: usize,
1764        mid: usize,
1765    ) -> (
1766        TensorBase<ViewData<'a, T>, L>,
1767        TensorBase<ViewData<'a, T>, L>,
1768    )
1769    where
1770        L: MutLayout,
1771    {
1772        let (left, right) = self.layout.split(axis, mid);
1773        let (left_offset_range, left_layout) = left;
1774        let (right_offset_range, right_layout) = right;
1775        let left_data = self.data.slice(left_offset_range.clone());
1776        let right_data = self.data.slice(right_offset_range.clone());
1777
1778        debug_assert_eq!(left_data.len(), left_layout.min_data_len());
1779        let left_view = TensorBase {
1780            data: left_data,
1781            layout: left_layout,
1782        };
1783
1784        debug_assert_eq!(right_data.len(), right_layout.min_data_len());
1785        let right_view = TensorBase {
1786            data: right_data,
1787            layout: right_layout,
1788        };
1789
1790        (left_view, right_view)
1791    }
1792
1793    /// Return a view of this tensor with elements stored in contiguous order.
1794    ///
1795    /// If the data is already contiguous, no copy is made, otherwise the
1796    /// elements are copied into a new buffer in contiguous order.
1797    pub fn to_contiguous(&self) -> Contiguous<TensorBase<CowData<'a, T>, L>>
1798    where
1799        T: Clone,
1800        L: MutLayout,
1801    {
1802        self.to_contiguous_in(GlobalAlloc::new())
1803    }
1804
1805    /// Variant of [`to_contiguous`](TensorBase::to_contiguous) which takes
1806    /// an allocator.
1807    pub fn to_contiguous_in<A: Alloc>(&self, alloc: A) -> Contiguous<TensorBase<CowData<'a, T>, L>>
1808    where
1809        T: Clone,
1810        L: MutLayout,
1811    {
1812        let tensor = if let Some(data) = self.data() {
1813            TensorBase {
1814                data: CowData::Borrowed(data.into_storage()),
1815                layout: self.layout.clone(),
1816            }
1817        } else {
1818            let data = self.to_vec_in(alloc);
1819            TensorBase {
1820                data: CowData::Owned(data),
1821                layout: L::from_shape(self.layout.shape()),
1822            }
1823        };
1824        Contiguous::new(tensor).unwrap()
1825    }
1826
1827    /// Return the underlying data as a flat slice if the tensor is contiguous,
1828    /// or a copy of the data as a flat slice otherwise.
1829    ///
1830    /// See [`AsView::to_slice`].
1831    pub fn to_slice(&self) -> Cow<'a, [T]>
1832    where
1833        T: Clone,
1834    {
1835        self.data()
1836            .map(Cow::Borrowed)
1837            .unwrap_or_else(|| Cow::Owned(self.to_vec()))
1838    }
1839
1840    /// Reverse the order of dimensions in this tensor. See [`AsView::transposed`].
1841    pub fn transposed(&self) -> TensorBase<ViewData<'a, T>, L>
1842    where
1843        L: MutLayout,
1844    {
1845        TensorBase {
1846            data: self.data,
1847            layout: self.layout.transposed(),
1848        }
1849    }
1850
1851    pub fn try_slice_dyn<R: IntoSliceItems>(
1852        &self,
1853        range: R,
1854    ) -> Result<TensorView<'a, T>, SliceError>
1855    where
1856        L: MutLayout,
1857    {
1858        let (offset_range, layout) = self.layout.slice_dyn(range.into_slice_items().as_ref())?;
1859        Ok(TensorBase {
1860            data: self.data.slice(offset_range),
1861            layout,
1862        })
1863    }
1864
1865    /// Return a read-only view of this tensor. See [`AsView::view`].
1866    pub fn view(&self) -> TensorBase<ViewData<'a, T>, L> {
1867        TensorBase {
1868            data: self.data,
1869            layout: self.layout.clone(),
1870        }
1871    }
1872
1873    pub(crate) fn view_ref(&self) -> TensorBase<ViewData<'a, T>, &L> {
1874        TensorBase {
1875            data: self.data,
1876            layout: &self.layout,
1877        }
1878    }
1879
1880    pub fn weakly_checked_view(&self) -> WeaklyCheckedView<ViewData<'a, T>, L> {
1881        WeaklyCheckedView { base: self.view() }
1882    }
1883}
1884
1885impl<S: Storage, L: Layout> Layout for TensorBase<S, L> {
1886    type Index<'a> = L::Index<'a>;
1887    type Indices = L::Indices;
1888
1889    fn ndim(&self) -> usize {
1890        self.layout.ndim()
1891    }
1892
1893    fn len(&self) -> usize {
1894        self.layout.len()
1895    }
1896
1897    fn is_empty(&self) -> bool {
1898        self.layout.is_empty()
1899    }
1900
1901    fn shape(&self) -> Self::Index<'_> {
1902        self.layout.shape()
1903    }
1904
1905    fn size(&self, dim: usize) -> usize {
1906        self.layout.size(dim)
1907    }
1908
1909    fn strides(&self) -> Self::Index<'_> {
1910        self.layout.strides()
1911    }
1912
1913    fn stride(&self, dim: usize) -> usize {
1914        self.layout.stride(dim)
1915    }
1916
1917    fn indices(&self) -> Self::Indices {
1918        self.layout.indices()
1919    }
1920
1921    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
1922        self.layout.offset(index)
1923    }
1924}
1925
1926impl<S: Storage, L: Layout + MatrixLayout> MatrixLayout for TensorBase<S, L> {
1927    fn rows(&self) -> usize {
1928        self.layout.rows()
1929    }
1930
1931    fn cols(&self) -> usize {
1932        self.layout.cols()
1933    }
1934
1935    fn row_stride(&self) -> usize {
1936        self.layout.row_stride()
1937    }
1938
1939    fn col_stride(&self) -> usize {
1940        self.layout.col_stride()
1941    }
1942}
1943
// `AsView` implementation for every storage type. Many methods either operate
// on the layout directly or go through a temporary view of the data.
impl<T, S: Storage<Elem = T>, L: Layout + Clone> AsView for TensorBase<S, L> {
    type Elem = T;
    type Layout = L;

    /// Iterate over elements via a temporary view.
    fn iter(&self) -> Iter<'_, T> {
        self.view().iter()
    }

    /// Copy all elements into `dest` and return it as an initialized slice.
    ///
    /// Panics if `dest` has a different length than the source (see the
    /// `copy_from_slice` note below for the contiguous path).
    fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<T>]) -> &'a [T]
    where
        T: Copy,
    {
        if let Some(data) = self.data() {
            // Contiguous fast path: a single bulk copy of the backing slice.
            // Safety: `[T]` and `[MaybeUninit<T>]` have same layout.
            let src_uninit = unsafe { std::mem::transmute::<&[T], &[MaybeUninit<T>]>(data) };
            dest.copy_from_slice(src_uninit);
            // Safety: `copy_from_slice` initializes the whole slice or panics
            // if there is a length mismatch.
            unsafe { dest.assume_init() }
        } else {
            // Non-contiguous: delegate to the general `copy_into_slice` helper.
            copy_into_slice(self.as_dyn(), dest)
        }
    }

    /// Return the elements as a slice if the data is contiguous.
    fn data(&self) -> Option<&[Self::Elem]> {
        self.view().data()
    }

    fn insert_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.insert_axis(index)
    }

    #[track_caller]
    fn remove_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.remove_axis(index)
    }

    fn merge_axes(&mut self)
    where
        L: ResizeLayout,
    {
        self.layout.merge_axes()
    }

    fn layout(&self) -> &L {
        &self.layout
    }

    /// Apply `f` to every element, using the global allocator for the result.
    fn map<F, U>(&self, f: F) -> TensorBase<Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
        L: MutLayout,
    {
        self.map_in(GlobalAlloc::new(), f)
    }

    /// Apply `f` to every element, allocating the output buffer with `alloc`.
    /// The result has the same shape as `self`.
    fn map_in<A: Alloc, F, U>(&self, alloc: A, f: F) -> TensorBase<Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
        L: MutLayout,
    {
        let len = self.len();
        let mut buf = alloc.alloc(len);
        if let Some(data) = self.data() {
            // Fast path for contiguous tensors.
            buf.extend(data.iter().map(f));
        } else {
            // Write straight into the buffer's uninitialized spare capacity.
            let dest = &mut buf.spare_capacity_mut()[..len];
            map_into_slice(self.as_dyn(), dest, f);

            // Safety: `map_into_slice` initialized all elements of `dest`.
            unsafe {
                buf.set_len(len);
            }
        };
        TensorBase::from_data(self.shape(), buf)
    }

    fn move_axis(&mut self, from: usize, to: usize)
    where
        L: MutLayout,
    {
        self.layout.move_axis(from, to);
    }

    /// Borrow this tensor as a read-only view with a cloned layout.
    fn view(&self) -> TensorBase<ViewData<'_, T>, L> {
        TensorBase {
            data: self.data.view(),
            layout: self.layout.clone(),
        }
    }

    // For `get` and `get_unchecked` we override the default implementation in
    // the trait to skip view creation.

    fn get<I: AsIndex<L>>(&self, index: I) -> Option<&Self::Elem> {
        self.offset(index.as_index()).map(|offset| unsafe {
            // Safety: We verified the offset is in-bounds
            self.data.get_unchecked(offset)
        })
    }

    unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &T {
        // No bounds checks are done here; validity of `index` is the caller's
        // responsibility under this `unsafe fn`'s contract.
        let offset = self.layout.offset_unchecked(index.as_index());
        unsafe { self.data.get_unchecked(offset) }
    }

    fn permute(&mut self, order: Self::Index<'_>)
    where
        L: MutLayout,
    {
        self.layout = self.layout.permuted(order);
    }

    fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.to_vec_in(GlobalAlloc::new())
    }

    /// Copy the elements into a new `Vec` allocated with `alloc`.
    fn to_vec_in<A: Alloc>(&self, alloc: A) -> Vec<T>
    where
        T: Clone,
    {
        let len = self.len();
        let mut buf = alloc.alloc(len);

        if let Some(data) = self.data() {
            buf.extend_from_slice(data);
        } else {
            copy_into_slice(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);

            // Safety: We initialized `len` elements.
            unsafe { buf.set_len(len) }
        }

        buf
    }

    /// Copy the elements into a new tensor with layout derived from `shape`.
    ///
    /// Panics with "reshape failed" if `shape` is not compatible.
    fn to_shape<SH: IntoLayout>(&self, shape: SH) -> TensorBase<Vec<Self::Elem>, SH::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        TensorBase {
            data: self.to_vec(),
            layout: self
                .layout
                .reshaped_for_copy(shape)
                .expect("reshape failed"),
        }
    }

    fn transpose(&mut self)
    where
        L: MutLayout,
    {
        self.layout = self.layout.transposed();
    }
}
2111
2112impl<T, S: Storage<Elem = T>, const N: usize> TensorBase<S, NdLayout<N>> {
2113    /// Load an array of `M` elements from successive entries of a tensor along
2114    /// the `dim` axis.
2115    ///
2116    /// eg. If `base` is `[0, 1, 2]`, dim=0 and `M` = 4 this will return an
2117    /// array with values from indices `[0, 1, 2]`, `[1, 1, 2]` ... `[3, 1, 2]`.
2118    ///
2119    /// Panics if any of the array indices are out of bounds.
2120    #[inline]
2121    pub fn get_array<const M: usize>(&self, base: [usize; N], dim: usize) -> [T; M]
2122    where
2123        T: Copy + Default,
2124    {
2125        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
2126        let mut result = [T::default(); M];
2127        for i in 0..M {
2128            // Safety: `array_offsets` returns valid offsets
2129            result[i] = unsafe { *self.data.get_unchecked(offsets[i]) };
2130        }
2131        result
2132    }
2133}
2134
2135impl<T> TensorBase<Vec<T>, DynLayout> {
2136    /// Reshape this tensor in place. This is cheap if the tensor is contiguous,
2137    /// as only the layout will be changed, but requires copying data otherwise.
2138    #[track_caller]
2139    pub fn reshape(&mut self, shape: &[usize])
2140    where
2141        T: Clone,
2142    {
2143        self.reshape_in(GlobalAlloc::new(), shape)
2144    }
2145
2146    /// Variant of [`reshape`](TensorBase::reshape) which takes an allocator.
2147    #[track_caller]
2148    pub fn reshape_in<A: Alloc>(&mut self, alloc: A, shape: &[usize])
2149    where
2150        T: Clone,
2151    {
2152        if !self.is_contiguous() {
2153            self.data = self.to_vec_in(alloc);
2154        }
2155        let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
2156            panic!(
2157                "element count mismatch reshaping {:?} to {:?}",
2158                self.shape(),
2159                shape
2160            );
2161        };
2162        self.layout = layout;
2163    }
2164}
2165
2166impl<'a, T, L: Layout> TensorBase<ViewMutData<'a, T>, L> {
2167    /// Divide this tensor into two mutable views along a given axis.
2168    ///
2169    /// Returns a `(left, right)` tuple of views, where the `left` view
2170    /// contains the slice from `[0, mid)` along `axis` and the `right`
2171    /// view contains the slice from `[mid, end)` along `axis`.
2172    #[allow(clippy::type_complexity)]
2173    pub fn split_at_mut(
2174        self,
2175        axis: usize,
2176        mid: usize,
2177    ) -> (
2178        TensorBase<ViewMutData<'a, T>, L>,
2179        TensorBase<ViewMutData<'a, T>, L>,
2180    )
2181    where
2182        L: MutLayout,
2183    {
2184        let (left, right) = self.layout.split(axis, mid);
2185        let (left_offset_range, left_layout) = left;
2186        let (right_offset_range, right_layout) = right;
2187        let (left_data, right_data) = self
2188            .data
2189            .split_mut(left_offset_range.clone(), right_offset_range.clone());
2190
2191        debug_assert_eq!(left_data.len(), left_layout.min_data_len());
2192        let left_view = TensorBase {
2193            data: left_data,
2194            layout: left_layout,
2195        };
2196
2197        debug_assert_eq!(right_data.len(), right_layout.min_data_len());
2198        let right_view = TensorBase {
2199            data: right_data,
2200            layout: right_layout,
2201        };
2202
2203        (left_view, right_view)
2204    }
2205
2206    /// Consume this view and return a mutable slice, if the tensor is
2207    /// contiguous.
2208    pub fn into_slice_mut(self) -> Option<&'a mut [T]> {
2209        let len = self.layout.min_data_len();
2210        self.is_contiguous().then(|| {
2211            // Safety: We verified that the slice is contiguous.
2212            let slice = unsafe { self.data.to_slice_mut() };
2213            &mut slice[..len]
2214        })
2215    }
2216}
2217
2218impl<T, L: MutLayout> FromIterator<T> for TensorBase<Vec<T>, L>
2219where
2220    [usize; 1]: AsIndex<L>,
2221{
2222    /// Create a new 1D tensor filled with an arithmetic sequence of values
2223    /// in the range `[start, end)` separated by `step`. If `step` is omitted,
2224    /// it defaults to 1.
2225    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> TensorBase<Vec<T>, L> {
2226        let data: Vec<T> = iter.into_iter().collect();
2227        TensorBase::from_data([data.len()].as_index(), data)
2228    }
2229}
2230
2231impl<T, L: MutLayout> From<Vec<T>> for TensorBase<Vec<T>, L>
2232where
2233    [usize; 1]: AsIndex<L>,
2234{
2235    /// Create a 1D tensor from a vector.
2236    fn from(vec: Vec<T>) -> Self {
2237        Self::from_data([vec.len()].as_index(), vec)
2238    }
2239}
2240
2241impl<'a, T, L: MutLayout> From<&'a [T]> for TensorBase<ViewData<'a, T>, L>
2242where
2243    [usize; 1]: AsIndex<L>,
2244{
2245    /// Create a 1D view from a slice.
2246    fn from(slice: &'a [T]) -> Self {
2247        Self::from_data([slice.len()].as_index(), slice)
2248    }
2249}
2250
2251impl<'a, T, L: MutLayout, const N: usize> From<&'a [T; N]> for TensorBase<ViewData<'a, T>, L>
2252where
2253    [usize; 1]: AsIndex<L>,
2254{
2255    /// Create a 1D view from a slice of known length.
2256    fn from(slice: &'a [T; N]) -> Self {
2257        Self::from_data([slice.len()].as_index(), slice.as_slice())
2258    }
2259}
2260
2261/// Return the offsets of `M` successive elements along the `dim` axis, starting
2262/// at index `base`.
2263///
2264/// Panics if any of the M element indices are out of bounds.
2265fn array_offsets<const N: usize, const M: usize>(
2266    layout: &NdLayout<N>,
2267    base: [usize; N],
2268    dim: usize,
2269) -> [usize; M] {
2270    assert!(
2271        base[dim] < usize::MAX - M && layout.size(dim) >= base[dim] + M,
2272        "array indices invalid"
2273    );
2274
2275    let offset = layout.must_offset(base);
2276    let stride = layout.stride(dim);
2277    let mut offsets = [0; M];
2278    for i in 0..M {
2279        offsets[i] = offset + i * stride;
2280    }
2281    offsets
2282}
2283
2284impl<T, S: StorageMut<Elem = T>, const N: usize> TensorBase<S, NdLayout<N>> {
2285    /// Store an array of `M` elements into successive entries of a tensor along
2286    /// the `dim` axis.
2287    ///
2288    /// See [`TensorBase::get_array`] for more details.
2289    #[inline]
2290    pub fn set_array<const M: usize>(&mut self, base: [usize; N], dim: usize, values: [T; M])
2291    where
2292        T: Copy,
2293    {
2294        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
2295
2296        for i in 0..M {
2297            // Safety: `array_offsets` returns valid offsets.
2298            unsafe { *self.data.get_unchecked_mut(offsets[i]) = values[i] };
2299        }
2300    }
2301}
2302
2303impl<T, S: Storage<Elem = T>> TensorBase<S, NdLayout<1>> {
2304    /// Convert this vector to a static array of length `M`.
2305    ///
2306    /// Panics if the length of this vector is not M.
2307    #[inline]
2308    pub fn to_array<const M: usize>(&self) -> [T; M]
2309    where
2310        T: Copy + Default,
2311    {
2312        self.get_array([0], 0)
2313    }
2314}
2315
2316impl<T, S: StorageMut<Elem = T>> TensorBase<S, NdLayout<1>> {
2317    /// Fill this vector with values from a static array of length `M`.
2318    ///
2319    /// Panics if the length of this vector is not M.
2320    #[inline]
2321    pub fn assign_array<const M: usize>(&mut self, values: [T; M])
2322    where
2323        T: Copy + Default,
2324    {
2325        self.set_array([0], 0, values)
2326    }
2327}
2328
/// Read-only view of a tensor with `N` dimensions.
pub type NdTensorView<'a, T, const N: usize> = TensorBase<ViewData<'a, T>, NdLayout<N>>;

/// Owned tensor with `N` dimensions, storing its elements in a `Vec<T>`.
pub type NdTensor<T, const N: usize> = TensorBase<Vec<T>, NdLayout<N>>;

/// Mutable view of a tensor with `N` dimensions.
pub type NdTensorViewMut<'a, T, const N: usize> = TensorBase<ViewMutData<'a, T>, NdLayout<N>>;

/// Owned or borrowed tensor with `N` dimensions.
///
/// `CowNdTensor`s can be created using [`as_cow`](TensorBase::as_cow) (to
/// borrow) or [`into_cow`](TensorBase::into_cow).
///
/// The name comes from [`std::borrow::Cow`].
pub type CowNdTensor<'a, T, const N: usize> = TensorBase<CowData<'a, T>, NdLayout<N>>;

/// Read-only view of a 2D tensor. The element type defaults to `f32`.
pub type Matrix<'a, T = f32> = NdTensorView<'a, T, 2>;

/// Mutable view of a 2D tensor. The element type defaults to `f32`.
pub type MatrixMut<'a, T = f32> = NdTensorViewMut<'a, T, 2>;

/// Owned tensor with a dynamic dimension count, storing its elements in a
/// `Vec<T>`. The element type defaults to `f32`.
pub type Tensor<T = f32> = TensorBase<Vec<T>, DynLayout>;

/// Read-only view of a tensor with a dynamic dimension count.
pub type TensorView<'a, T = f32> = TensorBase<ViewData<'a, T>, DynLayout>;

/// Mutable view of a tensor with a dynamic dimension count.
pub type TensorViewMut<'a, T = f32> = TensorBase<ViewMutData<'a, T>, DynLayout>;

/// Owned or borrowed tensor with a dynamic dimension count.
///
/// `CowTensor`s can be created using [`as_cow`](TensorBase::as_cow) (to
/// borrow) or [`into_cow`](TensorBase::into_cow).
///
/// The name comes from [`std::borrow::Cow`].
pub type CowTensor<'a, T> = TensorBase<CowData<'a, T>, DynLayout>;

/// Reference-counted tensor with a dynamic dimension count.
///
/// This uses `Arc<Vec<T>>` rather than `Arc<[T]>` as the backing storage. This
/// adds an extra indirection when accessing the data, but it enables cheap
/// conversion between owned and reference-counted tensors.
pub type ArcTensor<T> = TensorBase<Arc<Vec<T>>, DynLayout>;

/// Reference-counted tensor with `N` dimensions.
///
/// See also the notes for [`ArcTensor`].
pub type ArcNdTensor<T, const N: usize> = TensorBase<Arc<Vec<T>>, NdLayout<N>>;
2380
2381impl<T, S: Storage<Elem = T>, L: TrustedLayout, I: AsIndex<L>> Index<I> for TensorBase<S, L> {
2382    type Output = T;
2383
2384    /// Return the element at a given index.
2385    ///
2386    /// Panics if the index is out of bounds along any dimension.
2387    fn index(&self, index: I) -> &Self::Output {
2388        let offset = self.layout.must_offset(index.as_index());
2389
2390        // Safety: `TrustedLayout` guarantees offsets are < `min_data_len`.
2391        // TensorBase guarantees storage length is >= `min_data_len`.
2392        unsafe { self.data.get_unchecked(offset) }
2393    }
2394}
2395
2396impl<T, S: StorageMut<Elem = T>, L: TrustedLayout, I: AsIndex<L>> IndexMut<I> for TensorBase<S, L> {
2397    /// Return the element at a given index.
2398    ///
2399    /// Panics if the index is out of bounds along any dimension.
2400    fn index_mut(&mut self, index: I) -> &mut Self::Output {
2401        let index = index.as_index();
2402        let offset = self.layout.must_offset(index);
2403
2404        // Safety: `TrustedLayout` guarantees offsets are < `min_data_len`.
2405        // TensorBase guarantees storage length is >= `min_data_len`.
2406        unsafe { self.data.get_unchecked_mut(offset) }
2407    }
2408}
2409
2410impl<T, S: Storage<Elem = T> + Clone, L: Layout + Clone> Clone for TensorBase<S, L> {
2411    fn clone(&self) -> TensorBase<S, L> {
2412        let data = self.data.clone();
2413        TensorBase {
2414            data,
2415            layout: self.layout.clone(),
2416        }
2417    }
2418}
2419
2420impl<T, S: Storage<Elem = T> + Copy, L: Layout + Copy> Copy for TensorBase<S, L> {}
2421
2422impl<T: PartialEq, S: Storage<Elem = T>, L: Layout + Clone, V: AsView<Elem = T>> PartialEq<V>
2423    for TensorBase<S, L>
2424{
2425    fn eq(&self, other: &V) -> bool {
2426        self.shape().as_ref() == other.shape().as_ref() && self.iter().eq(other.iter())
2427    }
2428}
2429
2430impl<T, S: Storage<Elem = T>, const N: usize> From<TensorBase<S, NdLayout<N>>>
2431    for TensorBase<S, DynLayout>
2432{
2433    fn from(tensor: TensorBase<S, NdLayout<N>>) -> Self {
2434        Self {
2435            data: tensor.data,
2436            layout: tensor.layout.into(),
2437        }
2438    }
2439}
2440
2441impl<T, S1: Storage<Elem = T>, S2: Storage<Elem = T>, const N: usize>
2442    TryFrom<TensorBase<S1, DynLayout>> for TensorBase<S2, NdLayout<N>>
2443where
2444    S1: Into<S2>,
2445{
2446    type Error = DimensionError;
2447
2448    /// Convert a tensor or view with dynamic rank into a static rank one.
2449    ///
2450    /// Fails if `value` does not have `N` dimensions.
2451    fn try_from(value: TensorBase<S1, DynLayout>) -> Result<Self, Self::Error> {
2452        let layout: NdLayout<N> = value.layout().try_into()?;
2453        Ok(TensorBase {
2454            data: value.data.into(),
2455            layout,
2456        })
2457    }
2458}
2459
/// Marker trait for scalar (ie. non-array) element types.
///
/// It stops generic types from being inferred as array types in
/// [`TensorBase::from`].
pub trait Scalar {}

/// Implement [`Scalar`] for each of the listed types.
macro_rules! impl_scalar {
    ($($ty:ty),+ $(,)?) => {
        $(impl Scalar for $ty {})+
    };
}

impl_scalar!(
    bool, u8, i8, u16, i16, u32, i32, u64, i64, usize, isize, f32, f64, String
);
2485
2486// The `T: Scalar` bound avoids ambiguity when choosing a `Tensor::from`
2487// impl for a nested array literal, as it prevents `T` from matching an array
2488// type.
2489
2490impl<T: Clone + Scalar, L: MutLayout> From<T> for TensorBase<Vec<T>, L>
2491where
2492    [usize; 0]: AsIndex<L>,
2493{
2494    /// Construct a scalar tensor from a scalar value.
2495    fn from(value: T) -> Self {
2496        Self::from_scalar(value)
2497    }
2498}
2499
2500impl<T: Clone + Scalar, L: MutLayout, const D0: usize> From<[T; D0]> for TensorBase<Vec<T>, L>
2501where
2502    [usize; 1]: AsIndex<L>,
2503{
2504    /// Construct a 1D tensor from a 1D array.
2505    fn from(value: [T; D0]) -> Self {
2506        let data: Vec<T> = value.iter().cloned().collect();
2507        Self::from_data([D0].as_index(), data)
2508    }
2509}
2510
2511impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize> From<[[T; D1]; D0]>
2512    for TensorBase<Vec<T>, L>
2513where
2514    [usize; 2]: AsIndex<L>,
2515{
2516    /// Construct a 2D tensor from a nested array.
2517    fn from(value: [[T; D1]; D0]) -> Self {
2518        let data: Vec<_> = value.iter().flat_map(|y| y.iter()).cloned().collect();
2519        Self::from_data([D0, D1].as_index(), data)
2520    }
2521}
2522
2523impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize, const D2: usize>
2524    From<[[[T; D2]; D1]; D0]> for TensorBase<Vec<T>, L>
2525where
2526    [usize; 3]: AsIndex<L>,
2527{
2528    /// Construct a 3D tensor from a nested array.
2529    fn from(value: [[[T; D2]; D1]; D0]) -> Self {
2530        let data: Vec<_> = value
2531            .iter()
2532            .flat_map(|y| y.iter().flat_map(|z| z.iter()))
2533            .cloned()
2534            .collect();
2535        Self::from_data([D0, D1, D2].as_index(), data)
2536    }
2537}
2538
/// A view of a tensor which does "weak" checking when indexing via
/// `view[<index>]`. This means that it does not bounds-check individual
/// dimensions, but does bounds-check the computed offset.
///
/// This offers a middle-ground between regular indexing, which bounds-checks
/// each index element, and unchecked indexing, which does no bounds-checking
/// at all and is thus unsafe.
///
/// Because individual dimensions are not checked, an index with an
/// out-of-range component may still map to an in-bounds offset and return
/// some other element of the storage rather than panicking.
pub struct WeaklyCheckedView<S: Storage, L: Layout> {
    // Wrapped tensor whose layout computes offsets and whose storage is read.
    base: TensorBase<S, L>,
}
2549
2550impl<T, S: Storage<Elem = T>, L: Layout> Layout for WeaklyCheckedView<S, L> {
2551    type Index<'a> = L::Index<'a>;
2552    type Indices = L::Indices;
2553
2554    fn ndim(&self) -> usize {
2555        self.base.ndim()
2556    }
2557
2558    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
2559        self.base.offset(index)
2560    }
2561
2562    fn len(&self) -> usize {
2563        self.base.len()
2564    }
2565
2566    fn shape(&self) -> Self::Index<'_> {
2567        self.base.shape()
2568    }
2569
2570    fn strides(&self) -> Self::Index<'_> {
2571        self.base.strides()
2572    }
2573
2574    fn indices(&self) -> Self::Indices {
2575        self.base.indices()
2576    }
2577}
2578
2579impl<T, S: Storage<Elem = T>, L: Layout, I: AsIndex<L>> Index<I> for WeaklyCheckedView<S, L> {
2580    type Output = T;
2581    fn index(&self, index: I) -> &Self::Output {
2582        let offset = self.base.layout.offset_unchecked(index.as_index());
2583        unsafe {
2584            // Safety: See comments in [Storage] trait.
2585            self.base.data.get(offset).expect("invalid offset")
2586        }
2587    }
2588}
2589
2590impl<T, S: StorageMut<Elem = T>, L: Layout, I: AsIndex<L>> IndexMut<I> for WeaklyCheckedView<S, L> {
2591    fn index_mut(&mut self, index: I) -> &mut Self::Output {
2592        let offset = self.base.layout.offset_unchecked(index.as_index());
2593        unsafe {
2594            // Safety: See comments in [Storage] trait.
2595            self.base.data.get_mut(offset).expect("invalid offset")
2596        }
2597    }
2598}
2599
2600#[cfg(test)]
2601mod tests;