1use std::borrow::Cow;
2use std::fmt::Debug;
3use std::mem::MaybeUninit;
4use std::ops::{Index, IndexMut, Range};
5use std::sync::Arc;
6
7use crate::assume_init::AssumeInit;
8use crate::copy::{
9 copy_into, copy_into_slice, copy_into_uninit, copy_range_into_slice, map_into_slice,
10};
11use crate::errors::{DimensionError, ExpandError, FromDataError, ReshapeError, SliceError};
12use crate::iterators::{
13 AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterMut, Iter, IterMut,
14 Lanes, LanesMut, for_each_mut,
15};
16use crate::layout::{
17 AsIndex, BroadcastLayout, DynLayout, IntoLayout, Layout, LayoutExt, MatrixLayout, MutLayout,
18 NdLayout, OverlapPolicy, RemoveDim, ResizeLayout, SliceWith, TrustedLayout,
19};
20use crate::overlap::may_have_internal_overlap;
21use crate::slice_range::{IntoSliceItems, SliceItem};
22use crate::storage::{
23 Alloc, CowData, GlobalAlloc, IntoStorage, Storage, StorageMut, ViewData, ViewMutData,
24};
25use crate::type_num::IndexCount;
26use crate::{Contiguous, RandomSource};
27
/// The base type for multi-dimensional arrays in this crate.
///
/// `S` is the element storage (owned, borrowed, shared or copy-on-write) and
/// `L` is the layout, which maps N-dimensional indices to offsets in storage.
pub struct TensorBase<S: Storage, L: Layout> {
    // Element storage. May contain more elements than `layout.min_data_len()`
    // requires (see `from_storage_and_layout`).
    data: S,

    // Maps indices to offsets within `data`.
    layout: L,
}
51
/// Trait providing read-only operations on tensors, implemented by all tensor
/// kinds (owning, view, copy-on-write) largely by delegating to [`AsView::view`].
pub trait AsView: Layout {
    /// Type of element stored in this tensor.
    type Elem;

    /// Concrete layout type describing the tensor's shape and strides.
    type Layout: Clone + for<'a> Layout<Index<'a> = Self::Index<'a>>;

    /// Return a borrowed view of this tensor.
    fn view(&self) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>;

    /// Return a reference to this tensor's layout.
    fn layout(&self) -> &Self::Layout;

    /// Return a copy-on-write view of this tensor's data.
    fn as_cow(&self) -> TensorBase<CowData<'_, Self::Elem>, Self::Layout>
    where
        [Self::Elem]: ToOwned,
    {
        self.view().as_cow()
    }

    /// Return a view of this tensor with a dynamic-rank layout.
    fn as_dyn(&self) -> TensorBase<ViewData<'_, Self::Elem>, DynLayout> {
        self.view().as_dyn()
    }

    /// Return an iterator over chunks of at most `chunk_size` along `dim`.
    fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'_, Self::Elem, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().axis_chunks(dim, chunk_size)
    }

    /// Return an iterator over sub-views indexed along `dim`.
    fn axis_iter(&self, dim: usize) -> AxisIter<'_, Self::Elem, Self::Layout>
    where
        Self::Layout: MutLayout + RemoveDim,
    {
        self.view().axis_iter(dim)
    }

    /// Broadcast this tensor to `shape`. Panics if broadcasting fails.
    fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'_, Self::Elem>, S::Layout>
    where
        Self::Layout: BroadcastLayout<S::Layout>,
    {
        self.view().broadcast(shape)
    }

    /// Fallible version of [`AsView::broadcast`].
    fn try_broadcast<S: IntoLayout>(
        &self,
        shape: S,
    ) -> Result<TensorBase<ViewData<'_, Self::Elem>, S::Layout>, ExpandError>
    where
        Self::Layout: BroadcastLayout<S::Layout>,
    {
        self.view().try_broadcast(shape)
    }

    /// Copy this tensor's elements into `dest` and return the initialized
    /// slice.
    fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<Self::Elem>]) -> &'a [Self::Elem]
    where
        Self::Elem: Copy;

    /// Return this tensor's data as a slice, if it is contiguous.
    fn data(&self) -> Option<&[Self::Elem]>;

    /// Return the element at `index`, or `None` if the index is invalid.
    fn get<I: AsIndex<Self::Layout>>(&self, index: I) -> Option<&Self::Elem>
    where
        Self::Layout: TrustedLayout,
    {
        self.view().get(index)
    }

    /// Return the element at `index` without bounds-checking.
    ///
    /// # Safety
    ///
    /// The caller must ensure `index` is valid for this tensor's shape.
    unsafe fn get_unchecked<I: AsIndex<Self::Layout>>(&self, index: I) -> &Self::Elem {
        let view = self.view();
        // SAFETY: Validity of `index` is guaranteed by this function's caller.
        unsafe { view.get_unchecked(index) }
    }

    /// Return a view with the dimension at `axis` removed by selecting
    /// position `index` along it.
    fn index_axis(
        &self,
        axis: usize,
        index: usize,
    ) -> TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as RemoveDim>::Output>
    where
        Self::Layout: MutLayout + RemoveDim,
    {
        self.view().index_axis(axis, index)
    }

    /// Return an iterator over views of the innermost `N` dimensions.
    fn inner_iter<const N: usize>(&self) -> InnerIter<'_, Self::Elem, NdLayout<N>> {
        self.view().inner_iter()
    }

    /// Like [`AsView::inner_iter`] but with a runtime-chosen `n` and a
    /// dynamic-rank layout.
    fn inner_iter_dyn(&self, n: usize) -> InnerIter<'_, Self::Elem, DynLayout> {
        self.view().inner_iter_dyn(n)
    }

    /// Insert a new axis at position `index`.
    fn insert_axis(&mut self, index: usize)
    where
        Self::Layout: ResizeLayout;

    /// Remove the axis at position `index`.
    fn remove_axis(&mut self, index: usize)
    where
        Self::Layout: ResizeLayout;

    /// Return the single element of this tensor, or `None` if it does not
    /// have exactly one element.
    fn item(&self) -> Option<&Self::Elem> {
        self.view().item()
    }

    /// Return an iterator over elements in logical order.
    fn iter(&self) -> Iter<'_, Self::Elem>;

    /// Return an iterator over 1D lanes along `dim`.
    fn lanes(&self, dim: usize) -> Lanes<'_, Self::Elem>
    where
        Self::Layout: RemoveDim,
    {
        self.view().lanes(dim)
    }

    /// Return a new tensor produced by applying `f` to each element.
    fn map<F, U>(&self, f: F) -> TensorBase<Vec<U>, Self::Layout>
    where
        F: Fn(&Self::Elem) -> U,
        Self::Layout: MutLayout,
    {
        self.view().map(f)
    }

    /// Variant of [`AsView::map`] which allocates from `alloc`.
    fn map_in<A: Alloc, F, U>(&self, alloc: A, f: F) -> TensorBase<Vec<U>, Self::Layout>
    where
        F: Fn(&Self::Elem) -> U,
        Self::Layout: MutLayout,
    {
        self.view().map_in(alloc, f)
    }

    /// Merge axes in this tensor's layout where possible. See `ResizeLayout`.
    fn merge_axes(&mut self)
    where
        Self::Layout: ResizeLayout;

    /// Move the axis at position `from` to position `to`.
    fn move_axis(&mut self, from: usize, to: usize)
    where
        Self::Layout: MutLayout;

    /// Return a view with a static-rank layout of `N` dimensions.
    fn nd_view<const N: usize>(&self) -> TensorBase<ViewData<'_, Self::Elem>, NdLayout<N>> {
        self.view().nd_view()
    }

    /// Permute this tensor's axes in place according to `order`.
    fn permute(&mut self, order: Self::Index<'_>)
    where
        Self::Layout: MutLayout;

    /// Return a view with axes permuted according to `order`.
    fn permuted(&self, order: Self::Index<'_>) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().permuted(order)
    }

    /// Return a tensor with the same elements but a different `shape`,
    /// as a copy-on-write value (borrowing when a view reshape is possible).
    fn reshaped<S: Copy + IntoLayout>(
        &self,
        shape: S,
    ) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().reshaped(shape)
    }

    /// Variant of [`AsView::reshaped`] which allocates from `alloc` if a copy
    /// is required.
    fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
        &self,
        alloc: A,
        shape: S,
    ) -> TensorBase<CowData<'_, Self::Elem>, S::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().reshaped_in(alloc, shape)
    }

    /// Transpose this tensor in place.
    fn transpose(&mut self)
    where
        Self::Layout: MutLayout;

    /// Return a transposed view of this tensor.
    fn transposed(&self) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().transposed()
    }

    /// Slice this tensor and return a view. Panics if `range` is invalid.
    #[allow(clippy::type_complexity)]
    fn slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
    where
        Self::Layout: SliceWith<R, R::Count>,
    {
        self.view().slice(range)
    }

    /// Return a view of a sub-range of positions along `axis`.
    fn slice_axis(
        &self,
        axis: usize,
        range: Range<usize>,
    ) -> TensorBase<ViewData<'_, Self::Elem>, Self::Layout>
    where
        Self::Layout: MutLayout,
    {
        self.view().slice_axis(axis, range)
    }

    /// Fallible version of [`AsView::slice`].
    #[allow(clippy::type_complexity)]
    fn try_slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> Result<
        TensorBase<ViewData<'_, Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>,
        SliceError,
    >
    where
        Self::Layout: SliceWith<R, R::Count>,
    {
        self.view().try_slice(range)
    }

    /// Slice this tensor and copy the selected elements into a new tensor.
    #[allow(clippy::type_complexity)]
    fn slice_copy<R: Clone + IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: SliceWith<
            R,
            R::Count,
            Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
        >,
    {
        self.slice_copy_in(GlobalAlloc::new(), range)
    }

    /// Variant of [`AsView::slice_copy`] which allocates from `pool`.
    #[allow(clippy::type_complexity)]
    fn slice_copy_in<A: Alloc, R: Clone + IntoSliceItems + IndexCount>(
        &self,
        pool: A,
        range: R,
    ) -> TensorBase<Vec<Self::Elem>, <Self::Layout as SliceWith<R, R::Count>>::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: SliceWith<
            R,
            R::Count,
            Layout: for<'a> Layout<Index<'a>: TryFrom<&'a [usize], Error: Debug>>,
        >,
    {
        // Fast path: slice into a view first, then copy the view's elements.
        if let Ok(slice_view) = self.try_slice(range.clone()) {
            return slice_view.to_tensor_in(pool);
        }

        // Fallback for ranges `try_slice` rejects: compute the output shape
        // (index selections remove a dimension; ranges contribute their step
        // count) and copy element ranges directly.
        let items = range.into_slice_items();
        let sliced_shape: Vec<_> = items
            .as_ref()
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(dim, item)| match item {
                SliceItem::Index(_) => None,
                SliceItem::Range(range) => Some(range.index_range(self.size(dim)).steps()),
            })
            .collect();
        let sliced_len = sliced_shape.iter().product();
        let mut sliced_data = pool.alloc(sliced_len);

        copy_range_into_slice(
            self.as_dyn(),
            &mut sliced_data.spare_capacity_mut()[..sliced_len],
            items.as_ref(),
        );

        // SAFETY: relies on `copy_range_into_slice` having initialized the
        // first `sliced_len` elements of the spare capacity.
        unsafe {
            sliced_data.set_len(sliced_len);
        }

        let sliced_shape = sliced_shape.as_slice().try_into().expect("slice failed");

        TensorBase::from_data(sliced_shape, sliced_data)
    }

    /// Return a squeezed view of this tensor (see `Layout::squeezed`).
    fn squeezed(&self) -> TensorView<'_, Self::Elem>
    where
        Self::Layout: MutLayout,
    {
        self.view().squeezed()
    }

    /// Copy this tensor's elements into a `Vec` in logical order.
    fn to_vec(&self) -> Vec<Self::Elem>
    where
        Self::Elem: Clone;

    /// Variant of [`AsView::to_vec`] which allocates from `alloc`.
    fn to_vec_in<A: Alloc>(&self, alloc: A) -> Vec<Self::Elem>
    where
        Self::Elem: Clone;

    /// Return a contiguous version of this tensor, as a copy-on-write value.
    fn to_contiguous(&self) -> Contiguous<TensorBase<CowData<'_, Self::Elem>, Self::Layout>>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().to_contiguous()
    }

    /// Variant of [`AsView::to_contiguous`] which allocates from `alloc`.
    fn to_contiguous_in<A: Alloc>(
        &self,
        alloc: A,
    ) -> Contiguous<TensorBase<CowData<'_, Self::Elem>, Self::Layout>>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.view().to_contiguous_in(alloc)
    }

    /// Return a copy of this tensor with a new `shape`.
    fn to_shape<S: IntoLayout>(&self, shape: S) -> TensorBase<Vec<Self::Elem>, S::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout;

    /// Return this tensor's elements as a slice, borrowing when contiguous
    /// and copying otherwise.
    fn to_slice(&self) -> Cow<'_, [Self::Elem]>
    where
        Self::Elem: Clone,
    {
        self.view().to_slice()
    }

    /// Return a new owned tensor with the same shape and elements.
    fn to_tensor(&self) -> TensorBase<Vec<Self::Elem>, Self::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        self.to_tensor_in(GlobalAlloc::new())
    }

    /// Variant of [`AsView::to_tensor`] which allocates from `alloc`.
    fn to_tensor_in<A: Alloc>(&self, alloc: A) -> TensorBase<Vec<Self::Elem>, Self::Layout>
    where
        Self::Elem: Clone,
        Self::Layout: MutLayout,
    {
        TensorBase::from_data(self.layout().shape(), self.to_vec_in(alloc))
    }

    /// Return a wrapper view which performs weaker index validation (see
    /// `WeaklyCheckedView`).
    fn weakly_checked_view(&self) -> WeaklyCheckedView<ViewData<'_, Self::Elem>, Self::Layout> {
        self.view().weakly_checked_view()
    }
}
554
impl<S: Storage, L: Layout> TensorBase<S, L> {
    /// Construct a tensor from a shape and storage.
    ///
    /// Panics if the storage length does not match the length required by the
    /// layout for `shape`.
    #[track_caller]
    pub fn from_data<D: IntoStorage<Output = S>>(shape: L::Index<'_>, data: D) -> TensorBase<S, L>
    where
        for<'a> L::Index<'a>: Clone,
        L: MutLayout,
    {
        let data = data.into_storage();
        let len = data.len();
        match Self::try_from_data(shape.clone(), data) {
            Ok(data) => data,
            Err(_) => panic!(
                "data length {} does not match shape {:?}",
                len,
                shape.as_ref()
            ),
        }
    }

    /// Fallible version of [`TensorBase::from_data`]: errors if the storage
    /// length does not exactly match the layout's requirement.
    pub fn try_from_data<D: IntoStorage<Output = S>>(
        shape: L::Index<'_>,
        data: D,
    ) -> Result<TensorBase<S, L>, FromDataError>
    where
        L: MutLayout,
    {
        let data = data.into_storage();
        let layout = L::from_shape(shape);
        if layout.min_data_len() != data.len() {
            return Err(FromDataError::StorageLengthMismatch);
        }
        Ok(TensorBase { data, layout })
    }

    /// Construct a tensor directly from storage and a layout.
    ///
    /// Panics if the storage is too short for the layout, or if a mutable
    /// tensor's layout may map multiple indices to the same storage offset.
    pub fn from_storage_and_layout(data: S, layout: L) -> TensorBase<S, L> {
        assert!(data.len() >= layout.min_data_len());
        // Overlapping layouts are only a problem for mutable storage, where
        // they would alias `&mut` references to the same element.
        assert!(
            !S::MUTABLE
                || !may_have_internal_overlap(layout.shape().as_ref(), layout.strides().as_ref())
        );
        TensorBase { data, layout }
    }

    /// Like [`TensorBase::from_storage_and_layout`] but only checks the
    /// invariants in debug builds. The caller must uphold the same invariants
    /// that `from_storage_and_layout` asserts.
    pub(crate) unsafe fn from_storage_and_layout_unchecked(data: S, layout: L) -> TensorBase<S, L> {
        debug_assert!(data.len() >= layout.min_data_len());
        debug_assert!(
            !S::MUTABLE
                || !may_have_internal_overlap(layout.shape().as_ref(), layout.strides().as_ref())
        );
        TensorBase { data, layout }
    }

    /// Construct a tensor from a shape, storage and custom strides.
    ///
    /// Fails if the strides would create internal overlap
    /// (`OverlapPolicy::DisallowOverlap`) or the storage is too short.
    pub fn from_data_with_strides<D: IntoStorage<Output = S>>(
        shape: L::Index<'_>,
        data: D,
        strides: L::Index<'_>,
    ) -> Result<TensorBase<S, L>, FromDataError>
    where
        L: MutLayout,
    {
        let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::DisallowOverlap)?;
        let data = data.into_storage();
        if layout.min_data_len() > data.len() {
            return Err(FromDataError::StorageTooShort);
        }
        Ok(TensorBase { data, layout })
    }

    /// Convert this tensor into one with a dynamic-rank layout.
    pub fn into_dyn(self) -> TensorBase<S, DynLayout>
    where
        L: Into<DynLayout>,
    {
        TensorBase {
            data: self.data,
            layout: self.layout.into(),
        }
    }

    /// Consume the tensor and return its backing storage.
    pub(crate) fn into_storage(self) -> S {
        self.data
    }

    /// Convert this tensor's layout into a static-rank layout with `N`
    /// dimensions, or `None` if `self.ndim() != N`.
    fn nd_layout<const N: usize>(&self) -> Option<NdLayout<N>> {
        if self.ndim() != N {
            return None;
        }
        let shape: [usize; N] = std::array::from_fn(|i| self.size(i));
        let strides: [usize; N] = std::array::from_fn(|i| self.stride(i));
        // Overlap is allowed here: the existing layout was already validated
        // when this tensor was created.
        let layout = NdLayout::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)
            .expect("invalid layout");
        Some(layout)
    }

    /// Return a raw pointer to the start of this tensor's storage.
    pub fn data_ptr(&self) -> *const S::Elem {
        self.data.as_ptr()
    }
}
684
impl<S: StorageMut, L: Clone + Layout> TensorBase<S, L> {
    /// Return an iterator over mutable sub-views indexed along `dim`.
    pub fn axis_iter_mut(&mut self, dim: usize) -> AxisIterMut<'_, S::Elem, L>
    where
        L: RemoveDim,
    {
        AxisIterMut::new(self.view_mut(), dim)
    }

    /// Return an iterator over mutable chunks of at most `chunk_size` along
    /// `dim`.
    pub fn axis_chunks_mut(
        &mut self,
        dim: usize,
        chunk_size: usize,
    ) -> AxisChunksMut<'_, S::Elem, L>
    where
        L: MutLayout,
    {
        AxisChunksMut::new(self.view_mut(), dim, chunk_size)
    }

    /// Replace each element `x` with `f(&x)`.
    pub fn apply<F: Fn(&S::Elem) -> S::Elem>(&mut self, f: F) {
        if let Some(data) = self.data_mut() {
            // Fast path for contiguous tensors: iterate the slice directly.
            data.iter_mut().for_each(|x| *x = f(x));
        } else {
            for_each_mut(self.as_dyn_mut(), |x| *x = f(x));
        }
    }

    /// Return a mutable view with a dynamic-rank layout.
    pub fn as_dyn_mut(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, DynLayout> {
        TensorBase {
            layout: DynLayout::from(&self.layout),
            data: self.data.view_mut(),
        }
    }

    /// Copy elements from `other` into this tensor.
    ///
    /// Panics if the shapes do not match.
    pub fn copy_from<S2: Storage<Elem = S::Elem>>(&mut self, other: &TensorBase<S2, L>)
    where
        S::Elem: Clone,
        L: Clone,
    {
        assert!(
            self.shape() == other.shape(),
            "copy dest shape {:?} != src shape {:?}",
            self.shape(),
            other.shape()
        );

        if let Some(dest) = self.data_mut() {
            if let Some(src) = other.data() {
                // Both tensors are contiguous: slice-to-slice copy.
                dest.clone_from_slice(src);
            } else {
                // Contiguous destination, non-contiguous source. Drop the
                // destination's current elements, then clone into the slice
                // as if it were uninitialized.
                let uninit_dest: &mut [MaybeUninit<S::Elem>] = unsafe { std::mem::transmute(dest) };
                for x in &mut *uninit_dest {
                    // SAFETY: These elements come from the initialized
                    // destination slice.
                    unsafe { x.assume_init_drop() }
                }

                // NOTE(review): if a clone inside `copy_into_slice` panics,
                // the already-dropped elements could be dropped again by the
                // tensor's destructor — confirm clones here cannot panic or
                // that `Elem` drops are trivial.
                copy_into_slice(other.as_dyn(), uninit_dest);
            }
        } else {
            copy_into(other.as_dyn(), self.as_dyn_mut());
        }
    }

    /// Return this tensor's data as a mutable slice, if it is contiguous.
    pub fn data_mut(&mut self) -> Option<&mut [S::Elem]> {
        let len = self.layout.min_data_len();
        let data = self.data.slice_mut(0..len);

        self.layout.is_contiguous().then(|| unsafe {
            // SAFETY: the layout is contiguous, so every offset in `0..len`
            // is addressed by the layout — relies on `to_slice_mut`'s
            // contract that such storage is fully initialized.
            data.to_slice_mut()
        })
    }

    /// Return a mutable view with the dimension at `axis` removed by
    /// selecting position `index` along it.
    pub fn index_axis_mut(
        &mut self,
        axis: usize,
        index: usize,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, <L as RemoveDim>::Output>
    where
        L: MutLayout + RemoveDim,
    {
        let (offsets, layout) = self.layout.index_axis(axis, index);
        TensorBase {
            data: self.data.slice_mut(offsets),
            layout,
        }
    }

    /// Return a mutable view of this tensor's storage.
    pub fn storage_mut(&mut self) -> ViewMutData<'_, S::Elem> {
        self.data.view_mut()
    }

    /// Set every element of this tensor to `value`.
    pub fn fill(&mut self, value: S::Elem)
    where
        S::Elem: Clone,
    {
        self.apply(|_| value.clone())
    }

    /// Return a mutable reference to the element at `index`, or `None` if
    /// the index is invalid.
    pub fn get_mut<I: AsIndex<L>>(&mut self, index: I) -> Option<&mut S::Elem>
    where
        L: TrustedLayout,
    {
        self.offset(index.as_index()).map(|offset| unsafe {
            // SAFETY: `offset` was validated by `self.offset` above.
            self.data.get_unchecked_mut(offset)
        })
    }

    /// Return a mutable reference to the element at `index` without bounds
    /// checks.
    ///
    /// # Safety
    ///
    /// The caller must ensure `index` is valid for this tensor's shape.
    pub unsafe fn get_unchecked_mut<I: AsIndex<L>>(&mut self, index: I) -> &mut S::Elem {
        let offset = self.layout.offset_unchecked(index.as_index());
        // SAFETY: Validity of `index` is guaranteed by this function's caller.
        unsafe { self.data.get_unchecked_mut(offset) }
    }

    /// Return a mutable view which borrows this tensor's layout instead of
    /// cloning it.
    pub(crate) fn mut_view_ref(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, &L> {
        TensorBase {
            data: self.data.view_mut(),
            layout: &self.layout,
        }
    }

    /// Return a mutable iterator over views of the innermost `N` dimensions.
    pub fn inner_iter_mut<const N: usize>(&mut self) -> InnerIterMut<'_, S::Elem, NdLayout<N>>
    where
        L: MutLayout,
    {
        InnerIterMut::new(self.view_mut())
    }

    /// Like [`TensorBase::inner_iter_mut`] but with a runtime-chosen `n` and
    /// a dynamic-rank layout.
    pub fn inner_iter_dyn_mut(&mut self, n: usize) -> InnerIterMut<'_, S::Elem, DynLayout>
    where
        L: MutLayout,
    {
        InnerIterMut::new_dyn(self.view_mut(), n)
    }

    /// Return a mutable iterator over elements in logical order.
    pub fn iter_mut(&mut self) -> IterMut<'_, S::Elem> {
        IterMut::new(self.mut_view_ref())
    }

    /// Return a mutable iterator over 1D lanes along `dim`.
    pub fn lanes_mut(&mut self, dim: usize) -> LanesMut<'_, S::Elem>
    where
        L: RemoveDim,
    {
        LanesMut::new(self.mut_view_ref(), dim)
    }

    /// Return a mutable view with a static-rank layout of `N` dimensions.
    ///
    /// Panics if `self.ndim() != N`.
    pub fn nd_view_mut<const N: usize>(
        &mut self,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, NdLayout<N>> {
        assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
        TensorBase {
            layout: self.nd_layout().unwrap(),
            data: self.data.view_mut(),
        }
    }

    /// Return a mutable view with axes permuted according to `order`.
    pub fn permuted_mut(&mut self, order: L::Index<'_>) -> TensorBase<ViewMutData<'_, S::Elem>, L>
    where
        L: MutLayout,
    {
        TensorBase {
            layout: self.layout.permuted(order),
            data: self.data.view_mut(),
        }
    }

    /// Return a mutable view with shape `shape`, or an error if the layout
    /// cannot be reshaped without copying (see `reshaped_for_view`).
    pub fn reshaped_mut<SH: IntoLayout>(
        &mut self,
        shape: SH,
    ) -> Result<TensorBase<ViewMutData<'_, S::Elem>, SH::Layout>, ReshapeError>
    where
        L: MutLayout,
    {
        let layout = self.layout.reshaped_for_view(shape)?;
        Ok(TensorBase {
            layout,
            data: self.data.view_mut(),
        })
    }

    /// Return a mutable view of a sub-range of positions along `axis`.
    pub fn slice_axis_mut(
        &mut self,
        axis: usize,
        range: Range<usize>,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, L>
    where
        L: MutLayout,
    {
        let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
        debug_assert_eq!(sliced_layout.size(axis), range.len());
        TensorBase {
            data: self.data.slice_mut(offset_range),
            layout: sliced_layout,
        }
    }

    /// Slice this tensor and return a mutable view. Panics if `range` is
    /// invalid.
    pub fn slice_mut<R: IntoSliceItems + IndexCount>(
        &mut self,
        range: R,
    ) -> TensorBase<ViewMutData<'_, S::Elem>, <L as SliceWith<R, R::Count>>::Layout>
    where
        L: SliceWith<R, R::Count>,
    {
        self.try_slice_mut(range).expect("slice failed")
    }

    /// Fallible version of [`TensorBase::slice_mut`].
    #[allow(clippy::type_complexity)]
    pub fn try_slice_mut<R: IntoSliceItems + IndexCount>(
        &mut self,
        range: R,
    ) -> Result<
        TensorBase<ViewMutData<'_, S::Elem>, <L as SliceWith<R, R::Count>>::Layout>,
        SliceError,
    >
    where
        L: SliceWith<R, R::Count>,
    {
        let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
        Ok(TensorBase {
            data: self.data.slice_mut(offset_range),
            layout: sliced_layout,
        })
    }

    /// Return a mutable view of this whole tensor.
    pub fn view_mut(&mut self) -> TensorBase<ViewMutData<'_, S::Elem>, L>
    where
        L: Clone,
    {
        TensorBase {
            data: self.data.view_mut(),
            layout: self.layout.clone(),
        }
    }

    /// Return a mutable wrapper view which performs weaker index validation
    /// (see `WeaklyCheckedView`).
    pub fn weakly_checked_view_mut(&mut self) -> WeaklyCheckedView<ViewMutData<'_, S::Elem>, L> {
        WeaklyCheckedView {
            base: self.view_mut(),
        }
    }
}
989
990impl<T, L: Clone + Layout> TensorBase<Vec<T>, L> {
991 pub fn arange(start: T, end: T, step: Option<T>) -> TensorBase<Vec<T>, L>
995 where
996 T: Copy + PartialOrd + From<bool> + std::ops::Add<Output = T>,
997 [usize; 1]: AsIndex<L>,
998 L: MutLayout,
999 {
1000 let step = step.unwrap_or((true).into());
1001 let mut data = Vec::new();
1002 let mut curr = start;
1003 while curr < end {
1004 data.push(curr);
1005 curr = curr + step;
1006 }
1007 TensorBase::from_data([data.len()].as_index(), data)
1008 }
1009
1010 pub fn append<S2: Storage<Elem = T>>(
1016 &mut self,
1017 axis: usize,
1018 other: &TensorBase<S2, L>,
1019 ) -> Result<(), ExpandError>
1020 where
1021 T: Copy,
1022 L: MutLayout,
1023 {
1024 let shape_match = self.ndim() == other.ndim()
1025 && (0..self.ndim()).all(|d| d == axis || self.size(d) == other.size(d));
1026 if !shape_match {
1027 return Err(ExpandError::ShapeMismatch);
1028 }
1029
1030 let old_size = self.size(axis);
1031 let new_size = self.size(axis) + other.size(axis);
1032
1033 let Some(new_layout) = self.expanded_layout(axis, new_size) else {
1034 return Err(ExpandError::InsufficientCapacity);
1035 };
1036
1037 let new_data_len = new_layout.min_data_len();
1038 self.layout = new_layout;
1039
1040 assert!(self.data.capacity() >= new_data_len);
1043 unsafe {
1044 self.data.set_len(new_data_len);
1045 }
1046
1047 self.slice_axis_mut(axis, old_size..new_size)
1048 .copy_from(other);
1049
1050 Ok(())
1051 }
1052
1053 pub fn from_vec(vec: Vec<T>) -> TensorBase<Vec<T>, L>
1055 where
1056 [usize; 1]: AsIndex<L>,
1057 L: MutLayout,
1058 {
1059 TensorBase::from_data([vec.len()].as_index(), vec)
1060 }
1061
1062 pub fn clip_dim(&mut self, dim: usize, range: Range<usize>)
1068 where
1069 T: Copy,
1070 L: MutLayout,
1071 {
1072 let (start, end) = (range.start, range.end);
1073
1074 assert!(start <= end, "start must be <= end");
1075 assert!(end <= self.size(dim), "end must be <= dim size");
1076
1077 self.layout.resize_dim(dim, end - start);
1078
1079 let range = if self.is_empty() {
1080 0..0
1081 } else {
1082 let start_offset = start * self.layout.stride(dim);
1083 let end_offset = start_offset + self.layout.min_data_len();
1084 start_offset..end_offset
1085 };
1086 self.data.copy_within(range.clone(), 0);
1087 self.data.truncate(range.end - range.start);
1088 }
1089
1090 pub fn has_capacity(&self, axis: usize, new_size: usize) -> bool
1093 where
1094 L: MutLayout,
1095 {
1096 self.expanded_layout(axis, new_size).is_some()
1097 }
1098
1099 fn expanded_layout(&self, axis: usize, new_size: usize) -> Option<L>
1104 where
1105 L: MutLayout,
1106 {
1107 let mut new_layout = self.layout.clone();
1108 new_layout.resize_dim(axis, new_size);
1109 let new_data_len = new_layout.min_data_len();
1110
1111 let has_capacity = new_data_len <= self.data.capacity()
1112 && !may_have_internal_overlap(
1113 new_layout.shape().as_ref(),
1114 new_layout.strides().as_ref(),
1115 );
1116
1117 has_capacity.then_some(new_layout)
1118 }
1119
1120 pub fn into_cow(self) -> TensorBase<CowData<'static, T>, L> {
1125 let TensorBase { data, layout } = self;
1126 TensorBase {
1127 layout,
1128 data: CowData::Owned(data),
1129 }
1130 }
1131
1132 pub fn into_arc(self) -> TensorBase<Arc<Vec<T>>, L> {
1137 let TensorBase { data, layout } = self;
1138 TensorBase {
1139 layout,
1140 data: Arc::new(data),
1141 }
1142 }
1143
1144 pub fn into_data(self) -> Vec<T>
1148 where
1149 T: Clone,
1150 {
1151 if self.is_contiguous() {
1152 self.into_non_contiguous_data()
1153 } else {
1154 self.to_vec()
1155 }
1156 }
1157
1158 pub fn into_non_contiguous_data(mut self) -> Vec<T> {
1161 self.data.truncate(self.layout.min_data_len());
1162 self.data
1163 }
1164
1165 #[track_caller]
1169 pub fn into_shape<S: Copy + IntoLayout>(self, shape: S) -> TensorBase<Vec<T>, S::Layout>
1170 where
1171 T: Clone,
1172 L: MutLayout,
1173 {
1174 let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
1175 panic!(
1176 "element count mismatch reshaping {:?} to {:?}",
1177 self.shape(),
1178 shape
1179 );
1180 };
1181 TensorBase {
1182 layout,
1183 data: self.into_data(),
1184 }
1185 }
1186
1187 pub fn from_fn<F: FnMut(L::Index<'_>) -> T, Idx>(
1194 shape: L::Index<'_>,
1195 mut f: F,
1196 ) -> TensorBase<Vec<T>, L>
1197 where
1198 L::Indices: Iterator<Item = Idx>,
1199 Idx: AsIndex<L>,
1200 L: MutLayout,
1201 {
1202 let layout = L::from_shape(shape);
1203 let data: Vec<T> = layout.indices().map(|idx| f(idx.as_index())).collect();
1204 TensorBase { data, layout }
1205 }
1206
1207 pub fn from_simple_fn<F: FnMut() -> T>(shape: L::Index<'_>, f: F) -> TensorBase<Vec<T>, L>
1210 where
1211 L: MutLayout,
1212 {
1213 Self::from_simple_fn_in(GlobalAlloc::new(), shape, f)
1214 }
1215
1216 pub fn from_simple_fn_in<A: Alloc, F: FnMut() -> T>(
1219 alloc: A,
1220 shape: L::Index<'_>,
1221 mut f: F,
1222 ) -> TensorBase<Vec<T>, L>
1223 where
1224 L: MutLayout,
1225 {
1226 let len = shape.as_ref().iter().product();
1227 let mut data = alloc.alloc(len);
1228 data.extend(std::iter::from_fn(|| Some(f())).take(len));
1229 TensorBase::from_data(shape, data)
1230 }
1231
1232 pub fn from_scalar(value: T) -> TensorBase<Vec<T>, L>
1234 where
1235 [usize; 0]: AsIndex<L>,
1236 L: MutLayout,
1237 {
1238 TensorBase::from_data([].as_index(), vec![value])
1239 }
1240
1241 pub fn full(shape: L::Index<'_>, value: T) -> TensorBase<Vec<T>, L>
1243 where
1244 T: Clone,
1245 L: MutLayout,
1246 {
1247 Self::full_in(GlobalAlloc::new(), shape, value)
1248 }
1249
1250 pub fn full_in<A: Alloc>(alloc: A, shape: L::Index<'_>, value: T) -> TensorBase<Vec<T>, L>
1252 where
1253 T: Clone,
1254 L: MutLayout,
1255 {
1256 let len = shape.as_ref().iter().product();
1257 let mut data = alloc.alloc(len);
1258 data.resize(len, value);
1259 TensorBase::from_data(shape, data)
1260 }
1261
1262 pub fn make_contiguous(&mut self)
1269 where
1270 T: Clone,
1271 L: MutLayout,
1272 {
1273 if self.is_contiguous() {
1274 return;
1275 }
1276 self.data = self.to_vec();
1277 self.layout = L::from_shape(self.layout.shape());
1278 }
1279
1280 pub fn rand<R: RandomSource<T>>(shape: L::Index<'_>, rand_src: &mut R) -> TensorBase<Vec<T>, L>
1286 where
1287 L: MutLayout,
1288 {
1289 Self::from_simple_fn(shape, || rand_src.next())
1290 }
1291
1292 pub fn zeros(shape: L::Index<'_>) -> TensorBase<Vec<T>, L>
1295 where
1296 T: Clone + Default,
1297 L: MutLayout,
1298 {
1299 Self::zeros_in(GlobalAlloc::new(), shape)
1300 }
1301
1302 pub fn zeros_in<A: Alloc>(alloc: A, shape: L::Index<'_>) -> TensorBase<Vec<T>, L>
1304 where
1305 T: Clone + Default,
1306 L: MutLayout,
1307 {
1308 Self::full_in(alloc, shape, T::default())
1311 }
1312
1313 pub fn uninit(shape: L::Index<'_>) -> TensorBase<Vec<MaybeUninit<T>>, L>
1319 where
1320 MaybeUninit<T>: Clone,
1321 L: MutLayout,
1322 {
1323 Self::uninit_in(GlobalAlloc::new(), shape)
1324 }
1325
1326 pub fn uninit_in<A: Alloc>(alloc: A, shape: L::Index<'_>) -> TensorBase<Vec<MaybeUninit<T>>, L>
1328 where
1329 L: MutLayout,
1330 {
1331 let len = shape.as_ref().iter().product();
1332 let mut data = alloc.alloc(len);
1333
1334 unsafe { data.set_len(len) }
1337
1338 TensorBase::from_data(shape, data)
1339 }
1340
1341 pub fn with_capacity(shape: L::Index<'_>, expand_dim: usize) -> TensorBase<Vec<T>, L>
1348 where
1349 T: Copy,
1350 L: MutLayout,
1351 {
1352 Self::with_capacity_in(GlobalAlloc::new(), shape, expand_dim)
1353 }
1354
1355 pub fn with_capacity_in<A: Alloc>(
1357 alloc: A,
1358 shape: L::Index<'_>,
1359 expand_dim: usize,
1360 ) -> TensorBase<Vec<T>, L>
1361 where
1362 T: Copy,
1363 L: MutLayout,
1364 {
1365 let mut tensor = Self::uninit_in(alloc, shape);
1366 tensor.clip_dim(expand_dim, 0..0);
1367
1368 unsafe { tensor.assume_init() }
1371 }
1372}
1373
1374impl<T, L: Layout> TensorBase<CowData<'_, T>, L> {
1375 pub fn into_non_contiguous_data(self) -> Option<Vec<T>> {
1379 match self.data {
1380 CowData::Owned(mut vec) => {
1381 vec.truncate(self.layout.min_data_len());
1382 Some(vec)
1383 }
1384 CowData::Borrowed(_) => None,
1385 }
1386 }
1387}
1388
impl<T, S: Storage<Elem = MaybeUninit<T>> + AssumeInit, L: Layout + Clone> TensorBase<S, L>
where
    <S as AssumeInit>::Output: Storage<Elem = T>,
{
    /// Convert a tensor of possibly-uninitialized elements into one of
    /// initialized elements.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that every element reachable through the
    /// layout has been initialized.
    pub unsafe fn assume_init(self) -> TensorBase<<S as AssumeInit>::Output, L> {
        TensorBase {
            layout: self.layout,
            // SAFETY: Initialization is guaranteed by this function's caller.
            data: unsafe { self.data.assume_init() },
        }
    }

    /// Initialize this tensor by copying elements from `other`, then return
    /// the initialized tensor.
    ///
    /// Panics if the shapes do not match.
    pub fn init_from<S2: Storage<Elem = T>>(
        mut self,
        other: &TensorBase<S2, L>,
    ) -> TensorBase<<S as AssumeInit>::Output, L>
    where
        T: Copy,
        S: StorageMut<Elem = MaybeUninit<T>>,
    {
        assert_eq!(self.shape(), other.shape(), "shape mismatch");

        match (self.data_mut(), other.data()) {
            (Some(self_data), Some(other_data)) => {
                // Both contiguous: slice-to-slice copy. Viewing the source as
                // `&[MaybeUninit<T>]` is sound because `MaybeUninit<T>` has
                // the same size and alignment as `T`.
                let other_data: &[MaybeUninit<T>] = unsafe { std::mem::transmute(other_data) };
                self_data.clone_from_slice(other_data);
            }
            (Some(self_data), _) => {
                // Contiguous destination, non-contiguous source.
                copy_into_slice(other.as_dyn(), self_data);
            }
            _ => {
                // Fully general element-by-element copy.
                copy_into_uninit(other.as_dyn(), self.as_dyn_mut());
            }
        }

        // SAFETY: Every element was initialized by one of the copies above.
        unsafe { self.assume_init() }
    }
}
1441
1442impl<'a, T, L: Clone + Layout> TensorBase<ViewData<'a, T>, L> {
1443 pub fn axis_iter(&self, dim: usize) -> AxisIter<'a, T, L>
1444 where
1445 L: MutLayout + RemoveDim,
1446 {
1447 AxisIter::new(self, dim)
1448 }
1449
1450 pub fn axis_chunks(&self, dim: usize, chunk_size: usize) -> AxisChunks<'a, T, L>
1451 where
1452 L: MutLayout,
1453 {
1454 AxisChunks::new(self, dim, chunk_size)
1455 }
1456
1457 pub fn as_dyn(&self) -> TensorBase<ViewData<'a, T>, DynLayout> {
1461 TensorBase {
1462 data: self.data,
1463 layout: DynLayout::from(&self.layout),
1464 }
1465 }
1466
1467 pub fn as_cow(&self) -> TensorBase<CowData<'a, T>, L> {
1471 TensorBase {
1472 layout: self.layout.clone(),
1473 data: CowData::Borrowed(self.data),
1474 }
1475 }
1476
1477 pub fn broadcast<S: IntoLayout>(&self, shape: S) -> TensorBase<ViewData<'a, T>, S::Layout>
1481 where
1482 L: BroadcastLayout<S::Layout>,
1483 {
1484 self.try_broadcast(shape).unwrap()
1485 }
1486
1487 pub fn try_broadcast<S: IntoLayout>(
1491 &self,
1492 shape: S,
1493 ) -> Result<TensorBase<ViewData<'a, T>, S::Layout>, ExpandError>
1494 where
1495 L: BroadcastLayout<S::Layout>,
1496 {
1497 Ok(TensorBase {
1498 layout: self.layout.broadcast(shape)?,
1499 data: self.data,
1500 })
1501 }
1502
1503 pub fn data(&self) -> Option<&'a [T]> {
1507 let len = self.layout.min_data_len();
1510 let data = self.data.slice(0..len);
1511
1512 self.layout.is_contiguous().then(|| unsafe {
1513 data.as_slice()
1515 })
1516 }
1517
1518 pub fn storage(&self) -> ViewData<'a, T> {
1520 self.data.view()
1521 }
1522
1523 pub fn get<I: AsIndex<L>>(&self, index: I) -> Option<&'a T>
1524 where
1525 L: TrustedLayout,
1526 {
1527 self.offset(index.as_index()).map(|offset|
1528 unsafe {
1533 self.data.get_unchecked(offset)
1534 })
1535 }
1536
1537 pub fn from_slice_with_strides(
1543 shape: L::Index<'_>,
1544 data: &'a [T],
1545 strides: L::Index<'_>,
1546 ) -> Result<TensorBase<ViewData<'a, T>, L>, FromDataError>
1547 where
1548 L: MutLayout,
1549 {
1550 let layout = L::from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap)?;
1551 if layout.min_data_len() > data.as_ref().len() {
1552 return Err(FromDataError::StorageTooShort);
1553 }
1554 Ok(TensorBase {
1555 data: data.into_storage(),
1556 layout,
1557 })
1558 }
1559
1560 pub unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &'a T {
1567 let offset = self.layout.offset_unchecked(index.as_index());
1568 unsafe { self.data.get_unchecked(offset) }
1569 }
1570
1571 pub fn index_axis(
1577 &self,
1578 axis: usize,
1579 index: usize,
1580 ) -> TensorBase<ViewData<'a, T>, <L as RemoveDim>::Output>
1581 where
1582 L: MutLayout + RemoveDim,
1583 {
1584 let (offsets, layout) = self.layout.index_axis(axis, index);
1585 TensorBase {
1586 data: self.data.slice(offsets),
1587 layout,
1588 }
1589 }
1590
1591 pub fn inner_iter<const N: usize>(&self) -> InnerIter<'a, T, NdLayout<N>> {
1595 InnerIter::new(self.view())
1596 }
1597
1598 pub fn inner_iter_dyn(&self, n: usize) -> InnerIter<'a, T, DynLayout> {
1602 InnerIter::new_dyn(self.view(), n)
1603 }
1604
1605 pub fn item(&self) -> Option<&'a T> {
1607 match self.ndim() {
1608 0 => unsafe {
1609 self.data.get(0)
1611 },
1612 _ if self.len() == 1 => self.iter().next(),
1613 _ => None,
1614 }
1615 }
1616
1617 pub fn iter(&self) -> Iter<'a, T> {
1621 Iter::new(self.view_ref())
1622 }
1623
1624 pub fn lanes(&self, dim: usize) -> Lanes<'a, T>
1628 where
1629 L: RemoveDim,
1630 {
1631 assert!(dim < self.ndim());
1632 Lanes::new(self.view_ref(), dim)
1633 }
1634
1635 pub fn nd_view<const N: usize>(&self) -> TensorBase<ViewData<'a, T>, NdLayout<N>> {
1639 assert!(self.ndim() == N, "ndim {} != {}", self.ndim(), N);
1640 TensorBase {
1641 data: self.data,
1642 layout: self.nd_layout().unwrap(),
1643 }
1644 }
1645
1646 pub fn permuted(&self, order: L::Index<'_>) -> TensorBase<ViewData<'a, T>, L>
1650 where
1651 L: MutLayout,
1652 {
1653 TensorBase {
1654 data: self.data,
1655 layout: self.layout.permuted(order),
1656 }
1657 }
1658
1659 pub fn reshaped<S: Copy + IntoLayout>(&self, shape: S) -> TensorBase<CowData<'a, T>, S::Layout>
1663 where
1664 T: Clone,
1665 L: MutLayout,
1666 {
1667 self.reshaped_in(GlobalAlloc::new(), shape)
1668 }
1669
    /// Return a tensor with the same elements but a different `shape`,
    /// allocating from `alloc` if a copy is required.
    ///
    /// If the current layout permits reshaping without copying, the result
    /// borrows this tensor's data; otherwise the elements are copied into a
    /// new buffer.
    ///
    /// Panics if the element count of `shape` does not match this tensor's
    /// element count.
    pub fn reshaped_in<A: Alloc, S: Copy + IntoLayout>(
        &self,
        alloc: A,
        shape: S,
    ) -> TensorBase<CowData<'a, T>, S::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        if let Ok(layout) = self.layout.reshaped_for_view(shape) {
            // Fast path: reshape is representable over the existing data.
            TensorBase {
                data: CowData::Borrowed(self.data),
                layout,
            }
        } else {
            let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
                panic!(
                    "element count mismatch reshaping {:?} to {:?}",
                    self.shape(),
                    shape
                );
            };

            // Slow path: copy elements into a newly allocated buffer.
            TensorBase {
                data: CowData::Owned(self.to_vec_in(alloc)),
                layout,
            }
        }
    }
1700
    /// Slice this view with `range` and return the resulting sub-view.
    ///
    /// Panics if the range is invalid for this tensor's shape.
    pub fn slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>
    where
        L: SliceWith<R, R::Count>,
    {
        self.try_slice(range).expect("slice failed")
    }
1711
    /// Return a view restricted to `range` along `axis`, keeping the rank
    /// unchanged.
    ///
    /// Panics (via `unwrap`) if the axis or range are invalid.
    pub fn slice_axis(&self, axis: usize, range: Range<usize>) -> TensorBase<ViewData<'a, T>, L>
    where
        L: MutLayout,
    {
        let (offset_range, sliced_layout) = self.layout.slice_axis(axis, range.clone()).unwrap();
        debug_assert_eq!(sliced_layout.size(axis), range.len());
        TensorBase {
            data: self.data.slice(offset_range),
            layout: sliced_layout,
        }
    }
1724
    /// Fallible version of [`slice`](Self::slice): returns an error instead
    /// of panicking if `range` is invalid for this tensor's shape.
    #[allow(clippy::type_complexity)]
    pub fn try_slice<R: IntoSliceItems + IndexCount>(
        &self,
        range: R,
    ) -> Result<TensorBase<ViewData<'a, T>, <L as SliceWith<R, R::Count>>::Layout>, SliceError>
    where
        L: SliceWith<R, R::Count>,
    {
        let (offset_range, sliced_layout) = self.layout.slice_with(range)?;
        Ok(TensorBase {
            data: self.data.slice(offset_range),
            layout: sliced_layout,
        })
    }
1741
    /// Return a dynamic-rank view of this tensor with all size-1 dimensions
    /// removed.
    pub fn squeezed(&self) -> TensorView<'a, T>
    where
        L: MutLayout,
    {
        TensorBase {
            data: self.data.view(),
            layout: self.layout.squeezed(),
        }
    }
1754
1755 #[allow(clippy::type_complexity)]
1761 pub fn split_at(
1762 &self,
1763 axis: usize,
1764 mid: usize,
1765 ) -> (
1766 TensorBase<ViewData<'a, T>, L>,
1767 TensorBase<ViewData<'a, T>, L>,
1768 )
1769 where
1770 L: MutLayout,
1771 {
1772 let (left, right) = self.layout.split(axis, mid);
1773 let (left_offset_range, left_layout) = left;
1774 let (right_offset_range, right_layout) = right;
1775 let left_data = self.data.slice(left_offset_range.clone());
1776 let right_data = self.data.slice(right_offset_range.clone());
1777
1778 debug_assert_eq!(left_data.len(), left_layout.min_data_len());
1779 let left_view = TensorBase {
1780 data: left_data,
1781 layout: left_layout,
1782 };
1783
1784 debug_assert_eq!(right_data.len(), right_layout.min_data_len());
1785 let right_view = TensorBase {
1786 data: right_data,
1787 layout: right_layout,
1788 };
1789
1790 (left_view, right_view)
1791 }
1792
    /// Return a tensor with contiguous (row-major) element order, borrowing
    /// the data if it is already contiguous and copying otherwise.
    pub fn to_contiguous(&self) -> Contiguous<TensorBase<CowData<'a, T>, L>>
    where
        T: Clone,
        L: MutLayout,
    {
        self.to_contiguous_in(GlobalAlloc::new())
    }
1804
    /// Variant of [`to_contiguous`](Self::to_contiguous) which allocates from
    /// `alloc` when a copy is needed.
    pub fn to_contiguous_in<A: Alloc>(&self, alloc: A) -> Contiguous<TensorBase<CowData<'a, T>, L>>
    where
        T: Clone,
        L: MutLayout,
    {
        // `data()` returns `Some` only when the tensor is already contiguous.
        let tensor = if let Some(data) = self.data() {
            TensorBase {
                data: CowData::Borrowed(data.into_storage()),
                layout: self.layout.clone(),
            }
        } else {
            // Copy into a fresh buffer with a default (contiguous) layout.
            let data = self.to_vec_in(alloc);
            TensorBase {
                data: CowData::Owned(data),
                layout: L::from_shape(self.layout.shape()),
            }
        };
        // Both branches produce a contiguous tensor, so this cannot fail.
        Contiguous::new(tensor).unwrap()
    }
1826
1827 pub fn to_slice(&self) -> Cow<'a, [T]>
1832 where
1833 T: Clone,
1834 {
1835 self.data()
1836 .map(Cow::Borrowed)
1837 .unwrap_or_else(|| Cow::Owned(self.to_vec()))
1838 }
1839
    /// Return a view with the order of dimensions reversed.
    pub fn transposed(&self) -> TensorBase<ViewData<'a, T>, L>
    where
        L: MutLayout,
    {
        TensorBase {
            data: self.data,
            layout: self.layout.transposed(),
        }
    }
1850
    /// Fallible slicing which always produces a dynamic-rank view, regardless
    /// of the number of slice items in `range`.
    pub fn try_slice_dyn<R: IntoSliceItems>(
        &self,
        range: R,
    ) -> Result<TensorView<'a, T>, SliceError>
    where
        L: MutLayout,
    {
        let (offset_range, layout) = self.layout.slice_dyn(range.into_slice_items().as_ref())?;
        Ok(TensorBase {
            data: self.data.slice(offset_range),
            layout,
        })
    }
1864
    /// Return an immutable view of this tensor with its own copy of the
    /// layout.
    pub fn view(&self) -> TensorBase<ViewData<'a, T>, L> {
        TensorBase {
            data: self.data,
            layout: self.layout.clone(),
        }
    }
1872
    /// Variant of [`view`](Self::view) which borrows the layout instead of
    /// cloning it. Used internally to build iterators cheaply.
    pub(crate) fn view_ref(&self) -> TensorBase<ViewData<'a, T>, &L> {
        TensorBase {
            data: self.data,
            layout: &self.layout,
        }
    }
1879
    /// Return a wrapper around this view which performs only weak validity
    /// checks when indexing. See [`WeaklyCheckedView`].
    pub fn weakly_checked_view(&self) -> WeaklyCheckedView<ViewData<'a, T>, L> {
        WeaklyCheckedView { base: self.view() }
    }
1883}
1884
// Forward all layout queries to the tensor's layout.
impl<S: Storage, L: Layout> Layout for TensorBase<S, L> {
    type Index<'a> = L::Index<'a>;
    type Indices = L::Indices;

    fn ndim(&self) -> usize {
        self.layout.ndim()
    }

    fn len(&self) -> usize {
        self.layout.len()
    }

    fn is_empty(&self) -> bool {
        self.layout.is_empty()
    }

    fn shape(&self) -> Self::Index<'_> {
        self.layout.shape()
    }

    fn size(&self, dim: usize) -> usize {
        self.layout.size(dim)
    }

    fn strides(&self) -> Self::Index<'_> {
        self.layout.strides()
    }

    fn stride(&self, dim: usize) -> usize {
        self.layout.stride(dim)
    }

    fn indices(&self) -> Self::Indices {
        self.layout.indices()
    }

    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
        self.layout.offset(index)
    }
}
1925
// Forward matrix-layout queries to the tensor's layout.
impl<S: Storage, L: Layout + MatrixLayout> MatrixLayout for TensorBase<S, L> {
    fn rows(&self) -> usize {
        self.layout.rows()
    }

    fn cols(&self) -> usize {
        self.layout.cols()
    }

    fn row_stride(&self) -> usize {
        self.layout.row_stride()
    }

    fn col_stride(&self) -> usize {
        self.layout.col_stride()
    }
}
1943
impl<T, S: Storage<Elem = T>, L: Layout + Clone> AsView for TensorBase<S, L> {
    type Elem = T;
    type Layout = L;

    fn iter(&self) -> Iter<'_, T> {
        self.view().iter()
    }

    fn copy_into_slice<'a>(&self, dest: &'a mut [MaybeUninit<T>]) -> &'a [T]
    where
        T: Copy,
    {
        if let Some(data) = self.data() {
            // Fast path: tensor is contiguous, so a single memcpy suffices.
            // SAFETY: `MaybeUninit<T>` has the same layout as `T`, and
            // converting initialized `&[T]` to `&[MaybeUninit<T>]` is sound.
            let src_uninit = unsafe { std::mem::transmute::<&[T], &[MaybeUninit<T>]>(data) };
            dest.copy_from_slice(src_uninit);
            // SAFETY: `copy_from_slice` initialized every element of `dest`.
            unsafe { dest.assume_init() }
        } else {
            // Slow path: element-by-element copy in logical order.
            copy_into_slice(self.as_dyn(), dest)
        }
    }

    fn data(&self) -> Option<&[Self::Elem]> {
        self.view().data()
    }

    fn insert_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.insert_axis(index)
    }

    #[track_caller]
    fn remove_axis(&mut self, index: usize)
    where
        L: ResizeLayout,
    {
        self.layout.remove_axis(index)
    }

    fn merge_axes(&mut self)
    where
        L: ResizeLayout,
    {
        self.layout.merge_axes()
    }

    fn layout(&self) -> &L {
        &self.layout
    }

    fn map<F, U>(&self, f: F) -> TensorBase<Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
        L: MutLayout,
    {
        self.map_in(GlobalAlloc::new(), f)
    }

    fn map_in<A: Alloc, F, U>(&self, alloc: A, f: F) -> TensorBase<Vec<U>, L>
    where
        F: Fn(&Self::Elem) -> U,
        L: MutLayout,
    {
        let len = self.len();
        let mut buf = alloc.alloc(len);
        if let Some(data) = self.data() {
            // Contiguous fast path: map elements in storage order.
            buf.extend(data.iter().map(f));
        } else {
            // Map elements in logical order into the spare capacity.
            let dest = &mut buf.spare_capacity_mut()[..len];
            map_into_slice(self.as_dyn(), dest, f);

            // SAFETY: `map_into_slice` initialized the first `len` elements.
            unsafe {
                buf.set_len(len);
            }
        };
        TensorBase::from_data(self.shape(), buf)
    }

    fn move_axis(&mut self, from: usize, to: usize)
    where
        L: MutLayout,
    {
        self.layout.move_axis(from, to);
    }

    fn view(&self) -> TensorBase<ViewData<'_, T>, L> {
        TensorBase {
            data: self.data.view(),
            layout: self.layout.clone(),
        }
    }

    fn get<I: AsIndex<L>>(&self, index: I) -> Option<&Self::Elem> {
        self.offset(index.as_index()).map(|offset| unsafe {
            // SAFETY: `offset` was validated against the layout by `offset`.
            self.data.get_unchecked(offset)
        })
    }

    unsafe fn get_unchecked<I: AsIndex<L>>(&self, index: I) -> &T {
        // SAFETY (caller contract): `index` must be valid for this tensor's
        // shape; the resulting offset is then in bounds.
        let offset = self.layout.offset_unchecked(index.as_index());
        unsafe { self.data.get_unchecked(offset) }
    }

    fn permute(&mut self, order: Self::Index<'_>)
    where
        L: MutLayout,
    {
        self.layout = self.layout.permuted(order);
    }

    fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.to_vec_in(GlobalAlloc::new())
    }

    fn to_vec_in<A: Alloc>(&self, alloc: A) -> Vec<T>
    where
        T: Clone,
    {
        let len = self.len();
        let mut buf = alloc.alloc(len);

        if let Some(data) = self.data() {
            // Contiguous fast path.
            buf.extend_from_slice(data);
        } else {
            // Copy elements in logical order into the spare capacity.
            copy_into_slice(self.as_dyn(), &mut buf.spare_capacity_mut()[..len]);

            // SAFETY: `copy_into_slice` initialized the first `len` elements.
            unsafe { buf.set_len(len) }
        }

        buf
    }

    fn to_shape<SH: IntoLayout>(&self, shape: SH) -> TensorBase<Vec<Self::Elem>, SH::Layout>
    where
        T: Clone,
        L: MutLayout,
    {
        TensorBase {
            data: self.to_vec(),
            layout: self
                .layout
                .reshaped_for_copy(shape)
                .expect("reshape failed"),
        }
    }

    fn transpose(&mut self)
    where
        L: MutLayout,
    {
        self.layout = self.layout.transposed();
    }
}
2111
impl<T, S: Storage<Elem = T>, const N: usize> TensorBase<S, NdLayout<N>> {
    /// Read `M` consecutive elements along `dim`, starting at index `base`,
    /// into an array.
    ///
    /// Panics (in `array_offsets`) if any of the `M` indices are invalid.
    #[inline]
    pub fn get_array<const M: usize>(&self, base: [usize; N], dim: usize) -> [T; M]
    where
        T: Copy + Default,
    {
        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);
        let mut result = [T::default(); M];
        for i in 0..M {
            // SAFETY: all `M` offsets were validated by `array_offsets`.
            result[i] = unsafe { *self.data.get_unchecked(offsets[i]) };
        }
        result
    }
}
2134
impl<T> TensorBase<Vec<T>, DynLayout> {
    /// Reshape this tensor in place to `shape`.
    ///
    /// Panics if the element count of `shape` does not match the current
    /// element count.
    #[track_caller]
    pub fn reshape(&mut self, shape: &[usize])
    where
        T: Clone,
    {
        self.reshape_in(GlobalAlloc::new(), shape)
    }

    /// Variant of [`reshape`](Self::reshape) which allocates from `alloc` if
    /// the data must first be made contiguous.
    #[track_caller]
    pub fn reshape_in<A: Alloc>(&mut self, alloc: A, shape: &[usize])
    where
        T: Clone,
    {
        // Reshaping assumes contiguous (row-major) element order, so copy
        // first if needed.
        if !self.is_contiguous() {
            self.data = self.to_vec_in(alloc);
        }
        let Ok(layout) = self.layout.reshaped_for_copy(shape) else {
            panic!(
                "element count mismatch reshaping {:?} to {:?}",
                self.shape(),
                shape
            );
        };
        self.layout = layout;
    }
}
2165
impl<'a, T, L: Layout> TensorBase<ViewMutData<'a, T>, L> {
    /// Split this mutable view into two non-overlapping mutable sub-views
    /// along `axis`: indices `[0, mid)` on the left and `[mid, size)` on the
    /// right.
    #[allow(clippy::type_complexity)]
    pub fn split_at_mut(
        self,
        axis: usize,
        mid: usize,
    ) -> (
        TensorBase<ViewMutData<'a, T>, L>,
        TensorBase<ViewMutData<'a, T>, L>,
    )
    where
        L: MutLayout,
    {
        let (left, right) = self.layout.split(axis, mid);
        let (left_offset_range, left_layout) = left;
        let (right_offset_range, right_layout) = right;
        // `split_mut` partitions the storage into two disjoint mutable
        // views, preserving exclusive access for each half.
        let (left_data, right_data) = self
            .data
            .split_mut(left_offset_range.clone(), right_offset_range.clone());

        debug_assert_eq!(left_data.len(), left_layout.min_data_len());
        let left_view = TensorBase {
            data: left_data,
            layout: left_layout,
        };

        debug_assert_eq!(right_data.len(), right_layout.min_data_len());
        let right_view = TensorBase {
            data: right_data,
            layout: right_layout,
        };

        (left_view, right_view)
    }

    /// Consume this view and return the underlying data as a mutable slice,
    /// if the layout is contiguous.
    pub fn into_slice_mut(self) -> Option<&'a mut [T]> {
        let len = self.layout.min_data_len();
        self.is_contiguous().then(|| {
            // SAFETY: this view has exclusive access to its storage
            // (`ViewMutData`); see `to_slice_mut`'s contract.
            let slice = unsafe { self.data.to_slice_mut() };
            &mut slice[..len]
        })
    }
}
2217
2218impl<T, L: MutLayout> FromIterator<T> for TensorBase<Vec<T>, L>
2219where
2220 [usize; 1]: AsIndex<L>,
2221{
2222 fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> TensorBase<Vec<T>, L> {
2226 let data: Vec<T> = iter.into_iter().collect();
2227 TensorBase::from_data([data.len()].as_index(), data)
2228 }
2229}
2230
2231impl<T, L: MutLayout> From<Vec<T>> for TensorBase<Vec<T>, L>
2232where
2233 [usize; 1]: AsIndex<L>,
2234{
2235 fn from(vec: Vec<T>) -> Self {
2237 Self::from_data([vec.len()].as_index(), vec)
2238 }
2239}
2240
2241impl<'a, T, L: MutLayout> From<&'a [T]> for TensorBase<ViewData<'a, T>, L>
2242where
2243 [usize; 1]: AsIndex<L>,
2244{
2245 fn from(slice: &'a [T]) -> Self {
2247 Self::from_data([slice.len()].as_index(), slice)
2248 }
2249}
2250
2251impl<'a, T, L: MutLayout, const N: usize> From<&'a [T; N]> for TensorBase<ViewData<'a, T>, L>
2252where
2253 [usize; 1]: AsIndex<L>,
2254{
2255 fn from(slice: &'a [T; N]) -> Self {
2257 Self::from_data([slice.len()].as_index(), slice.as_slice())
2258 }
2259}
2260
2261fn array_offsets<const N: usize, const M: usize>(
2266 layout: &NdLayout<N>,
2267 base: [usize; N],
2268 dim: usize,
2269) -> [usize; M] {
2270 assert!(
2271 base[dim] < usize::MAX - M && layout.size(dim) >= base[dim] + M,
2272 "array indices invalid"
2273 );
2274
2275 let offset = layout.must_offset(base);
2276 let stride = layout.stride(dim);
2277 let mut offsets = [0; M];
2278 for i in 0..M {
2279 offsets[i] = offset + i * stride;
2280 }
2281 offsets
2282}
2283
impl<T, S: StorageMut<Elem = T>, const N: usize> TensorBase<S, NdLayout<N>> {
    /// Write `M` consecutive elements along `dim`, starting at index `base`,
    /// from `values`.
    ///
    /// Panics (in `array_offsets`) if any of the `M` indices are invalid.
    #[inline]
    pub fn set_array<const M: usize>(&mut self, base: [usize; N], dim: usize, values: [T; M])
    where
        T: Copy,
    {
        let offsets: [usize; M] = array_offsets(&self.layout, base, dim);

        for i in 0..M {
            // SAFETY: all `M` offsets were validated by `array_offsets`.
            unsafe { *self.data.get_unchecked_mut(offsets[i]) = values[i] };
        }
    }
}
2302
impl<T, S: Storage<Elem = T>> TensorBase<S, NdLayout<1>> {
    /// Copy the first `M` elements of this 1D tensor into an array.
    ///
    /// Panics if the tensor has fewer than `M` elements.
    #[inline]
    pub fn to_array<const M: usize>(&self) -> [T; M]
    where
        T: Copy + Default,
    {
        self.get_array([0], 0)
    }
}
2315
impl<T, S: StorageMut<Elem = T>> TensorBase<S, NdLayout<1>> {
    /// Write `values` into the first `M` elements of this 1D tensor.
    ///
    /// Panics if the tensor has fewer than `M` elements.
    #[inline]
    pub fn assign_array<const M: usize>(&mut self, values: [T; M])
    where
        T: Copy + Default,
    {
        self.set_array([0], 0, values)
    }
}
2328
/// Immutable view of a tensor with a static rank of `N`.
pub type NdTensorView<'a, T, const N: usize> = TensorBase<ViewData<'a, T>, NdLayout<N>>;

/// Owned tensor with a static rank of `N`.
pub type NdTensor<T, const N: usize> = TensorBase<Vec<T>, NdLayout<N>>;

/// Mutable view of a tensor with a static rank of `N`.
pub type NdTensorViewMut<'a, T, const N: usize> = TensorBase<ViewMutData<'a, T>, NdLayout<N>>;

/// Tensor with copy-on-write storage and a static rank of `N`.
pub type CowNdTensor<'a, T, const N: usize> = TensorBase<CowData<'a, T>, NdLayout<N>>;

/// Immutable view of a 2D tensor (matrix).
pub type Matrix<'a, T = f32> = NdTensorView<'a, T, 2>;

/// Mutable view of a 2D tensor (matrix).
pub type MatrixMut<'a, T = f32> = NdTensorViewMut<'a, T, 2>;

/// Owned tensor with a dynamic rank.
pub type Tensor<T = f32> = TensorBase<Vec<T>, DynLayout>;

/// Immutable view of a tensor with a dynamic rank.
pub type TensorView<'a, T = f32> = TensorBase<ViewData<'a, T>, DynLayout>;

/// Mutable view of a tensor with a dynamic rank.
pub type TensorViewMut<'a, T = f32> = TensorBase<ViewMutData<'a, T>, DynLayout>;

/// Tensor with copy-on-write storage and a dynamic rank.
pub type CowTensor<'a, T> = TensorBase<CowData<'a, T>, DynLayout>;

/// Tensor with shared (`Arc`) storage and a dynamic rank.
pub type ArcTensor<T> = TensorBase<Arc<Vec<T>>, DynLayout>;

/// Tensor with shared (`Arc`) storage and a static rank of `N`.
pub type ArcNdTensor<T, const N: usize> = TensorBase<Arc<Vec<T>>, NdLayout<N>>;
2380
impl<T, S: Storage<Elem = T>, L: TrustedLayout, I: AsIndex<L>> Index<I> for TensorBase<S, L> {
    type Output = T;

    /// Return the element at `index`.
    ///
    /// Panics (in `must_offset`) if the index is out of bounds.
    fn index(&self, index: I) -> &Self::Output {
        let offset = self.layout.must_offset(index.as_index());

        // SAFETY: `must_offset` validated the index, and `L: TrustedLayout`
        // means offsets for valid indices are trusted to be in bounds.
        unsafe { self.data.get_unchecked(offset) }
    }
}
2395
impl<T, S: StorageMut<Elem = T>, L: TrustedLayout, I: AsIndex<L>> IndexMut<I> for TensorBase<S, L> {
    /// Return a mutable reference to the element at `index`.
    ///
    /// Panics (in `must_offset`) if the index is out of bounds.
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let index = index.as_index();
        let offset = self.layout.must_offset(index);

        // SAFETY: `must_offset` validated the index, and `L: TrustedLayout`
        // means offsets for valid indices are trusted to be in bounds.
        unsafe { self.data.get_unchecked_mut(offset) }
    }
}
2409
2410impl<T, S: Storage<Elem = T> + Clone, L: Layout + Clone> Clone for TensorBase<S, L> {
2411 fn clone(&self) -> TensorBase<S, L> {
2412 let data = self.data.clone();
2413 TensorBase {
2414 data,
2415 layout: self.layout.clone(),
2416 }
2417 }
2418}
2419
// A tensor is `Copy` when both its storage and layout are (e.g. borrowed
// views with fixed-rank layouts).
impl<T, S: Storage<Elem = T> + Copy, L: Layout + Copy> Copy for TensorBase<S, L> {}
2421
2422impl<T: PartialEq, S: Storage<Elem = T>, L: Layout + Clone, V: AsView<Elem = T>> PartialEq<V>
2423 for TensorBase<S, L>
2424{
2425 fn eq(&self, other: &V) -> bool {
2426 self.shape().as_ref() == other.shape().as_ref() && self.iter().eq(other.iter())
2427 }
2428}
2429
impl<T, S: Storage<Elem = T>, const N: usize> From<TensorBase<S, NdLayout<N>>>
    for TensorBase<S, DynLayout>
{
    /// Convert a static-rank tensor into a dynamic-rank one, keeping the
    /// same storage.
    fn from(tensor: TensorBase<S, NdLayout<N>>) -> Self {
        Self {
            data: tensor.data,
            layout: tensor.layout.into(),
        }
    }
}
2440
impl<T, S1: Storage<Elem = T>, S2: Storage<Elem = T>, const N: usize>
    TryFrom<TensorBase<S1, DynLayout>> for TensorBase<S2, NdLayout<N>>
where
    S1: Into<S2>,
{
    type Error = DimensionError;

    /// Convert a dynamic-rank tensor into a static-rank one.
    ///
    /// Fails with [`DimensionError`] if the layout cannot be represented
    /// with `N` dimensions.
    fn try_from(value: TensorBase<S1, DynLayout>) -> Result<Self, Self::Error> {
        let layout: NdLayout<N> = value.layout().try_into()?;
        Ok(TensorBase {
            data: value.data.into(),
            layout,
        })
    }
}
2459
/// Marker trait implemented for primitive scalar (and `String`) element
/// types. Used to bound the scalar/nested-array `From` conversions for
/// tensors below.
pub trait Scalar {}

macro_rules! impl_scalar {
    ($($ty:ty),+ $(,)?) => {
        $(impl Scalar for $ty {})+
    };
}

impl_scalar!(
    bool, u8, i8, u16, i16, u32, i32, u64, i64, usize, isize, f32, f64, String
);
2485
impl<T: Clone + Scalar, L: MutLayout> From<T> for TensorBase<Vec<T>, L>
where
    [usize; 0]: AsIndex<L>,
{
    /// Create a zero-dimensional tensor holding a single scalar value.
    fn from(value: T) -> Self {
        Self::from_scalar(value)
    }
}
2499
2500impl<T: Clone + Scalar, L: MutLayout, const D0: usize> From<[T; D0]> for TensorBase<Vec<T>, L>
2501where
2502 [usize; 1]: AsIndex<L>,
2503{
2504 fn from(value: [T; D0]) -> Self {
2506 let data: Vec<T> = value.iter().cloned().collect();
2507 Self::from_data([D0].as_index(), data)
2508 }
2509}
2510
2511impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize> From<[[T; D1]; D0]>
2512 for TensorBase<Vec<T>, L>
2513where
2514 [usize; 2]: AsIndex<L>,
2515{
2516 fn from(value: [[T; D1]; D0]) -> Self {
2518 let data: Vec<_> = value.iter().flat_map(|y| y.iter()).cloned().collect();
2519 Self::from_data([D0, D1].as_index(), data)
2520 }
2521}
2522
2523impl<T: Clone + Scalar, L: MutLayout, const D0: usize, const D1: usize, const D2: usize>
2524 From<[[[T; D2]; D1]; D0]> for TensorBase<Vec<T>, L>
2525where
2526 [usize; 3]: AsIndex<L>,
2527{
2528 fn from(value: [[[T; D2]; D1]; D0]) -> Self {
2530 let data: Vec<_> = value
2531 .iter()
2532 .flat_map(|y| y.iter().flat_map(|z| z.iter()))
2533 .cloned()
2534 .collect();
2535 Self::from_data([D0, D1, D2].as_index(), data)
2536 }
2537}
2538
/// A tensor view wrapper with weakened indexing checks: indexing computes
/// the storage offset without validating each index against the shape, and
/// only checks that the resulting offset lies within the storage (see the
/// `Index`/`IndexMut` impls below).
pub struct WeaklyCheckedView<S: Storage, L: Layout> {
    // The wrapped view; all layout queries are forwarded to it.
    base: TensorBase<S, L>,
}
2549
// Forward layout queries to the wrapped view.
impl<T, S: Storage<Elem = T>, L: Layout> Layout for WeaklyCheckedView<S, L> {
    type Index<'a> = L::Index<'a>;
    type Indices = L::Indices;

    fn ndim(&self) -> usize {
        self.base.ndim()
    }

    fn offset(&self, index: Self::Index<'_>) -> Option<usize> {
        self.base.offset(index)
    }

    fn len(&self) -> usize {
        self.base.len()
    }

    fn shape(&self) -> Self::Index<'_> {
        self.base.shape()
    }

    fn strides(&self) -> Self::Index<'_> {
        self.base.strides()
    }

    fn indices(&self) -> Self::Indices {
        self.base.indices()
    }
}
2578
impl<T, S: Storage<Elem = T>, L: Layout, I: AsIndex<L>> Index<I> for WeaklyCheckedView<S, L> {
    type Output = T;
    /// Return the element at `index`.
    ///
    /// The index is not validated against the shape; only the computed
    /// offset is checked against the storage length, panicking with
    /// "invalid offset" if it is out of range.
    fn index(&self, index: I) -> &Self::Output {
        let offset = self.base.layout.offset_unchecked(index.as_index());
        unsafe {
            // NOTE(review): `Storage::get` is an unsafe fn (its exact
            // contract is defined in the storage module); the offset itself
            // is bounds-checked via the returned `Option`.
            self.base.data.get(offset).expect("invalid offset")
        }
    }
}
2589
impl<T, S: StorageMut<Elem = T>, L: Layout, I: AsIndex<L>> IndexMut<I> for WeaklyCheckedView<S, L> {
    /// Return a mutable reference to the element at `index`.
    ///
    /// As with `Index`, only the computed offset is checked against the
    /// storage length, panicking with "invalid offset" if it is out of range.
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        let offset = self.base.layout.offset_unchecked(index.as_index());
        unsafe {
            // NOTE(review): `Storage::get_mut` is an unsafe fn; the offset
            // itself is bounds-checked via the returned `Option`.
            self.base.data.get_mut(offset).expect("invalid offset")
        }
    }
}
2599
2600#[cfg(test)]
2601mod tests;